All checks were successful
Modulation Recognition Demo / ria-demo (push) Successful in 2m52s
Documentation and formatting updates: - Updates to project README. - Adding project health files (`LICENSE` and `SECURITY.md`) - A few minor formatting changes throughout - A few typo fixes, removal of unused code, cleanup of shadowed variables, and fixed import ordering with isort. **Note:** These changes have not been tested. Co-authored-by: Michael Luciuk <michael.luciuk@gmail.com> Co-authored-by: Liyu Xiao <liyu@qoherent.ai> Reviewed-on: https://git.riahub.ai/qoherent/modrec-workflow/pulls/1 Reviewed-by: Liyux <liyux@noreply.localhost> Co-authored-by: Michael Luciuk <michael@qoherent.ai> Co-committed-by: Michael Luciuk <michael@qoherent.ai>
68 lines
1.9 KiB
Python
68 lines
1.9 KiB
Python
from typing import Optional
|
|
|
|
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
from sklearn.metrics import confusion_matrix
|
|
|
|
|
|
def plot_confusion_matrix(
    y_true: np.array,
    y_pred: np.array,
    classes: list,
    normalize: bool = True,
    title: Optional[str] = None,
    text: bool = True,
    rotate_x_text: int = 90,
    figsize: tuple = (16, 9),
    cmap: plt.cm = plt.cm.Blues,
):
    """Plot a confusion matrix as an annotated heat map.

    Adapted from the scikit-learn example:
    https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    Args:
        y_true: Ground-truth labels, passed to
            ``sklearn.metrics.confusion_matrix``.
        y_pred: Predicted labels, passed to
            ``sklearn.metrics.confusion_matrix``.
        classes: Tick labels for both axes. Its length must match the number
            of rows/columns of the computed confusion matrix.
        normalize: If true, divide each row by its sum so every row shows the
            distribution of predictions for that true label. Rows whose sum
            is zero are left as all zeros instead of becoming NaN.
        title: Plot title. When falsy, a default title is chosen based on
            ``normalize``.
        text: If true, write each cell's value inside the cell.
        rotate_x_text: Rotation, in degrees, of the x-axis tick labels.
        figsize: Figure size in inches as ``(width, height)``.
        cmap: Colormap used for the heat map.

    Returns:
        The matplotlib axes the confusion matrix was drawn on.
    """
    if not title:
        title = (
            "Normalized confusion matrix"
            if normalize
            else "Confusion matrix, without normalization"
        )

    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # A label that only shows up in y_pred has an all-zero row; dividing
        # by its zero sum would yield NaN (and a RuntimeWarning), which in turn
        # poisons cm.max() below. Leave such rows as zeros.
        cm = np.divide(
            cm.astype(float),
            row_sums,
            out=np.zeros(cm.shape, dtype=float),
            where=row_sums != 0,
        )

    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(cm, interpolation="none", cmap=cmap)
    fig.colorbar(im, ax=ax)

    # One tick per matrix row/column, labelled with the class names.
    ax.set_xticks(np.arange(cm.shape[1]))
    ax.set_yticks(np.arange(cm.shape[0]))
    ax.set_xticklabels(classes, rotation=rotate_x_text)
    ax.set_yticklabels(classes)
    ax.set(title=title, ylabel="True label", xlabel="Predicted label")

    if text:
        # Annotate every cell; switch to white text on dark cells so the
        # value stays readable against the colormap.
        fmt = ".2f" if normalize else "d"
        thresh = cm.max() / 2.0
        for (i, j), value in np.ndenumerate(cm):
            ax.text(
                j,
                i,
                format(value, fmt),
                ha="center",
                va="center",
                color="white" if value > thresh else "black",
            )

    if len(classes) == 2:
        # Pin the limits of the 2x2 case explicitly (inverted y so the first
        # class is on top). NOTE(review): presumably a workaround for cropped
        # rows in some matplotlib releases -- confirm before removing.
        ax.axis([-0.5, 1.5, 1.5, -0.5])

    fig.tight_layout()

    return ax
|