added in confusion matrix for the output
All checks were successful
Modulation Recognition Demo / ria-demo (push) Successful in 2m46s
All checks were successful
Modulation Recognition Demo / ria-demo (push) Successful in 2m46s
This commit is contained in:
parent
9979d84e29
commit
6d531ae5f3
|
@ -17,11 +17,13 @@ def load_validation_data():
|
||||||
"data/dataset/val.h5", label_name="modulation", data_key="validation_data"
|
"data/dataset/val.h5", label_name="modulation", data_key="validation_data"
|
||||||
)
|
)
|
||||||
|
|
||||||
X = np.stack([x.numpy() for x, _ in val_dataset]) # shape: (N, C, L)
|
|
||||||
|
x = np.stack([x.numpy() for x, _ in val_dataset]) # shape: (N, C, L)
|
||||||
y = np.array([y.item() for _, y in val_dataset]) # shape: (N,)
|
y = np.array([y.item() for _, y in val_dataset]) # shape: (N,)
|
||||||
class_names = list(val_dataset.label_encoder.classes_)
|
class_names = list(val_dataset.label_encoder.classes_)
|
||||||
|
|
||||||
return X, y, class_names
|
|
||||||
|
return x, y, class_names
|
||||||
|
|
||||||
|
|
||||||
def build_model_from_ckpt(
|
def build_model_from_ckpt(
|
||||||
|
@ -44,22 +46,27 @@ def build_model_from_ckpt(
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_checkpoint(ckpt_path: str):
|
def evaluate_checkpoint(ckpt_path: str):
|
||||||
"""
|
"""
|
||||||
Loads the model from checkpoint and evaluates it on a validation set.
|
Loads the model from checkpoint and evaluates it on a validation set.
|
||||||
Prints classification metrics and plots a confusion matrix.
|
Prints classification metrics and plots a confusion matrix.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
# Load validation data
|
# Load validation data
|
||||||
X_val, y_true, class_names = load_validation_data()
|
X_val, y_true, class_names = load_validation_data()
|
||||||
num_classes = len(class_names)
|
num_classes = len(class_names)
|
||||||
in_channels = X_val.shape[1]
|
in_channels = X_val.shape[1]
|
||||||
|
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model = build_model_from_ckpt(
|
model = build_model_from_ckpt(
|
||||||
ckpt_path, in_channels=in_channels, num_classes=num_classes
|
ckpt_path, in_channels=in_channels, num_classes=num_classes
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
y_pred = []
|
y_pred = []
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -69,13 +76,16 @@ def evaluate_checkpoint(ckpt_path: str):
|
||||||
pred = torch.argmax(logits, dim=1).item()
|
pred = torch.argmax(logits, dim=1).item()
|
||||||
y_pred.append(pred)
|
y_pred.append(pred)
|
||||||
|
|
||||||
|
|
||||||
# Print classification report
|
# Print classification report
|
||||||
print("\nClassification Report:")
|
print("\nClassification Report:")
|
||||||
print(
|
print(
|
||||||
classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
|
classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
|
||||||
)
|
)
|
||||||
|
|
||||||
plot_confusion_matrix_with_counts(
|
|
||||||
|
|
||||||
|
print_confusion_matrix(
|
||||||
y_true=np.array(y_true),
|
y_true=np.array(y_true),
|
||||||
y_pred=np.array(y_pred),
|
y_pred=np.array(y_pred),
|
||||||
classes=class_names,
|
classes=class_names,
|
||||||
|
@ -84,7 +94,9 @@ def evaluate_checkpoint(ckpt_path: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def plot_confusion_matrix_with_counts(
|
|
||||||
|
|
||||||
|
def print_confusion_matrix(
|
||||||
y_true: np.ndarray,
|
y_true: np.ndarray,
|
||||||
y_pred: np.ndarray,
|
y_pred: np.ndarray,
|
||||||
classes: list[str],
|
classes: list[str],
|
||||||
|
@ -94,6 +106,7 @@ def plot_confusion_matrix_with_counts(
|
||||||
"""
|
"""
|
||||||
Plot a confusion matrix showing both raw counts and (optionally) normalized values.
|
Plot a confusion matrix showing both raw counts and (optionally) normalized values.
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
y_true: true labels (integers 0..C-1)
|
y_true: true labels (integers 0..C-1)
|
||||||
y_pred: predicted labels (same shape as y_true)
|
y_pred: predicted labels (same shape as y_true)
|
||||||
|
@ -102,43 +115,60 @@ def plot_confusion_matrix_with_counts(
|
||||||
title: title for the plot
|
title: title for the plot
|
||||||
"""
|
"""
|
||||||
# 1) build raw CM
|
# 1) build raw CM
|
||||||
C = len(classes)
|
c = len(classes)
|
||||||
cm = np.zeros((C, C), dtype=int)
|
cm = np.zeros((c, c), dtype=int)
|
||||||
for t, p in zip(y_true, y_pred):
|
for t, p in zip(y_true, y_pred):
|
||||||
cm[t, p] += 1
|
cm[t, p] += 1
|
||||||
|
|
||||||
|
|
||||||
# 2) normalize if requested
|
# 2) normalize if requested
|
||||||
if normalize:
|
if normalize:
|
||||||
with np.errstate(divide="ignore", invalid="ignore"):
|
with np.errstate(divide="ignore", invalid="ignore"):
|
||||||
cm_norm = cm.astype(float) / cm.sum(axis=1)[:, None]
|
cm_norm = cm.astype(float) / cm.sum(axis=1)[:, None]
|
||||||
cm_norm = np.nan_to_num(cm_norm)
|
cm_norm = np.nan_to_num(cm_norm)
|
||||||
|
print_confusion_matrix_helper(cm_norm, classes)
|
||||||
else:
|
else:
|
||||||
cm_norm = cm
|
print_confusion_matrix_helper(cm, classes)
|
||||||
|
|
||||||
# 3) plot
|
|
||||||
fig, ax = plt.subplots(figsize=(8, 8))
|
|
||||||
im = ax.imshow(cm_norm, interpolation="nearest")
|
|
||||||
ax.set_title(title)
|
|
||||||
ax.set_xlabel("Predicted label")
|
|
||||||
ax.set_ylabel("True label")
|
|
||||||
ax.set_xticks(np.arange(C))
|
|
||||||
ax.set_yticks(np.arange(C))
|
|
||||||
ax.set_xticklabels(classes, rotation=45, ha="right")
|
|
||||||
ax.set_yticklabels(classes)
|
|
||||||
|
|
||||||
# 4) annotate
|
|
||||||
for i in range(C):
|
|
||||||
for j in range(C):
|
|
||||||
count = cm[i, j]
|
|
||||||
val = cm_norm[i, j]
|
|
||||||
txt = f"{count}\n{val:.2f}"
|
|
||||||
ax.text(j, i, txt, ha="center", va="center")
|
|
||||||
|
|
||||||
fig.colorbar(im, ax=ax, label="Normalized value" if normalize else "Count")
|
|
||||||
plt.tight_layout()
|
import numpy as np
|
||||||
plt.show()
|
|
||||||
|
def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=2):
|
||||||
|
"""
|
||||||
|
Pretty prints a confusion matrix with x/y labels.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- matrix: square 2D numpy array
|
||||||
|
- labels: list of class labels (default: range(num_classes))
|
||||||
|
- normalize: whether to normalize rows to sum to 1
|
||||||
|
- digits: number of decimal places to show for normalized values
|
||||||
|
"""
|
||||||
|
matrix = np.array(matrix)
|
||||||
|
num_classes = matrix.shape[0]
|
||||||
|
labels = classes or list(range(num_classes))
|
||||||
|
|
||||||
|
# Header
|
||||||
|
print(" " * 9 + "Ground Truth →")
|
||||||
|
header = "Pred ↓ | " + " ".join([f"{str(label):>6}" for label in labels])
|
||||||
|
print(header)
|
||||||
|
print("-" * len(header))
|
||||||
|
|
||||||
|
# Rows
|
||||||
|
for i in range(num_classes):
|
||||||
|
row_vals = matrix[i]
|
||||||
|
if normalize:
|
||||||
|
row_sum = row_vals.sum()
|
||||||
|
row_vals = row_vals / row_sum if row_sum != 0 else row_vals
|
||||||
|
row_str = " ".join([f"{val:>6.{digits}f}" for val in row_vals])
|
||||||
|
else:
|
||||||
|
row_str = " ".join([f"{int(val):>6}" for val in row_vals])
|
||||||
|
print(f"{str(labels[i]):>7} | {row_str}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
settings = get_app_settings()
|
settings = get_app_settings()
|
||||||
evaluate_checkpoint(os.path.join("checkpoint_files", "inference_recognition_model.ckpt"))
|
evaluate_checkpoint(os.path.join("checkpoint_files", "inference_recognition_model.ckpt"))
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user