modrec-workflow/data/training/cm_plotter.py
2025-05-22 14:11:18 -04:00

67 lines
1.9 KiB
Python

import numpy as np
from typing import Optional
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix
def plot_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    classes: list,
    normalize: bool = True,
    title: Optional[str] = None,
    text: bool = True,
    rotate_x_text: int = 90,
    figsize: tuple = (16, 9),
    cmap: plt.cm = plt.cm.Blues,
    labels: Optional[list] = None,
):
    """Plot a confusion matrix of ``y_pred`` against ``y_true``.

    Adapted from
    https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, aligned element-wise with ``y_true``.
        classes: Display names for the matrix rows/columns. They must be in
            the same order as the matrix: the order of ``labels`` if given,
            otherwise sklearn's default (labels seen in ``y_true`` or
            ``y_pred``, in sorted order).
        normalize: If True, divide each row by its total so each cell is the
            fraction of that true class predicted as the column class. Rows
            with no true samples are shown as 0 rather than NaN.
        title: Axes title. If empty/None, a default is chosen based on
            ``normalize``.
        text: If True, write each cell's value inside the cell.
        rotate_x_text: Rotation, in degrees, of the x-axis tick labels.
        figsize: Figure size in inches, ``(width, height)``.
        cmap: Colormap passed to ``Axes.imshow``.
        labels: Optional label values forwarded to
            ``sklearn.metrics.confusion_matrix``. Use it to fix the row/column
            order or to include classes that never occur in the data.
            Defaults to None, which keeps sklearn's default behaviour.

    Returns:
        The matplotlib ``Axes`` holding the plot. The figure is left open;
        closing it is the caller's responsibility.
    """
    if not title:
        title = (
            "Normalized confusion matrix"
            if normalize
            else "Confusion matrix, without normalization"
        )

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    if normalize:
        # A class with no true samples has a zero row sum. Dividing by it
        # yields NaN, which would also make cm.max() NaN and break the text
        # contrast threshold below. Like sklearn's own normalize="true", map
        # those cells to 0 instead.
        with np.errstate(divide="ignore", invalid="ignore"):
            cm = cm.astype("float") / cm.sum(axis=1, keepdims=True)
        cm = np.nan_to_num(cm)

    n_rows, n_cols = cm.shape
    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(cm, interpolation="none", cmap=cmap)
    fig.colorbar(im, ax=ax)
    ax.set(
        xticks=np.arange(n_cols),
        yticks=np.arange(n_rows),
        yticklabels=classes,
        title=title,
        ylabel="True label",
        xlabel="Predicted label",
    )
    ax.set_xticklabels(classes, rotation=rotate_x_text)

    # Pin the view to the cell edges (row 0 at the top, as imshow draws it).
    # The original code did this only for 2 classes, presumably to work
    # around a matplotlib release that clipped the first/last heatmap rows.
    # These are imshow's default limits, so applying them for every size is
    # harmless. Operating on `ax` avoids relying on pyplot's current axes.
    ax.set_xlim(-0.5, n_cols - 0.5)
    ax.set_ylim(n_rows - 0.5, -0.5)

    if text:
        fmt = ".2f" if normalize else "d"
        # Cells darker than half the max get white text for contrast.
        thresh = cm.max() / 2.0
        for i, j in np.ndindex(cm.shape):  # row-major, same order as nested loops
            ax.text(
                j,
                i,
                format(cm[i, j], fmt),
                ha="center",
                va="center",
                color="white" if cm[i, j] > thresh else "black",
            )

    fig.tight_layout()
    return ax