modrec-workflow/scripts/model_builder/plot_data.py

145 lines
4.2 KiB
Python
Raw Normal View History

import os
import numpy as np
import torch
from sklearn.metrics import classification_report
os.environ["NNPACK"] = "0"
from matplotlib import pyplot as plt
from mobilenetv3 import RFClassifier, mobilenetv3
from modulation_dataset import ModulationH5Dataset
from helpers.app_settings import get_app_settings
2025-06-18 13:44:29 -04:00
def load_validation_data(
    dataset_path: str = "data/dataset/val.h5",
) -> tuple[np.ndarray, np.ndarray, list]:
    """
    Load the validation split and return it as NumPy arrays.

    Args:
        dataset_path: path of the HDF5 file handed to ``ModulationH5Dataset``.
            Defaults to the location the script has always used.

    Returns:
        A ``(X, y, class_names)`` tuple where
        ``X`` stacks every sample's ``.numpy()`` array along a new first axis,
        ``y`` holds the matching ``.item()`` label for each sample, and
        ``class_names`` is ``list(val_dataset.label_encoder.classes_)``.

    Raises:
        ValueError: from ``np.stack`` when the dataset yields no samples.
    """
    val_dataset = ModulationH5Dataset(
        dataset_path, label_name="modulation", data_key="validation_data"
    )

    # Walk the dataset once and collect sample/label pairs together; iterating
    # it separately for X and for y would read every item twice and would
    # take the two arrays from different passes.
    samples = []
    labels = []
    for sample, label in val_dataset:
        samples.append(sample.numpy())
        labels.append(label.item())

    X = np.stack(samples)  # (N, C, L) per the original annotation - TODO confirm
    y = np.array(labels)  # (N,)

    class_names = list(val_dataset.label_encoder.classes_)
    return X, y, class_names
2025-06-18 13:44:29 -04:00
def build_model_from_ckpt(
    ckpt_path: str, in_channels: int, num_classes: int
) -> torch.nn.Module:
    """
    Build and return a PyTorch model loaded from a checkpoint.

    Args:
        ckpt_path: path of the checkpoint file; it must contain a
            ``"state_dict"`` entry.
        in_channels: forwarded to ``mobilenetv3`` as ``in_chans``.
        num_classes: forwarded to ``mobilenetv3`` as ``num_classes``.

    Returns:
        The ``RFClassifier`` wrapping a ``mobilenetv3_small_050`` backbone,
        with the checkpoint weights loaded and switched to eval mode. The
        model is never moved off the device it was constructed on.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` key.
    """
    model = RFClassifier(
        model=mobilenetv3(
            model_size="mobilenetv3_small_050",
            num_classes=num_classes,
            in_chans=in_channels,
        )
    )

    # The model is not transferred to CUDA anywhere in this function, so
    # materialising the checkpoint on a GPU first would only cost device
    # memory and an extra copy back into the CPU parameters. Load straight
    # onto the CPU instead; the resulting weights are identical.
    checkpoint = torch.load(ckpt_path, weights_only=True, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return model
def evaluate_checkpoint(ckpt_path: str):
    """
    Evaluate a checkpointed model on the validation set.

    Loads the validation data, rebuilds the model from ``ckpt_path``, runs
    it on each validation sample in turn, prints a classification report and
    finally draws a normalized confusion matrix.
    """
    X_val, y_true, class_names = load_validation_data()

    # Channel count comes from axis 1 of the stacked validation array.
    model = build_model_from_ckpt(
        ckpt_path,
        in_channels=X_val.shape[1],
        num_classes=len(class_names),
    )

    # One forward pass per sample (batch of size 1), gradients disabled.
    with torch.no_grad():
        y_pred = [
            torch.argmax(
                model(torch.tensor(sample[np.newaxis, ...], dtype=torch.float32)),
                dim=1,
            ).item()
            for sample in X_val
        ]

    print("\nClassification Report:")
    report = classification_report(
        y_true, y_pred, target_names=class_names, zero_division=0
    )
    print(report)

    plot_confusion_matrix_with_counts(
        y_true=np.array(y_true),
        y_pred=np.array(y_pred),
        classes=class_names,
        normalize=True,
        title="Normalized Confusion Matrix",
    )
def plot_confusion_matrix_with_counts(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    classes: list[str],
    normalize: bool = True,
    title: str = "Confusion Matrix (counts and normalized)",
) -> None:
    """
    Plot a confusion matrix showing raw counts and, optionally, normalized values.

    Args:
        y_true: true labels (integers 0..C-1)
        y_pred: predicted labels (same shape as y_true)
        classes: list of class-name strings in index order
        normalize: if True, each row is normalized to sum to 1 and every cell
            is annotated with both its count and its normalized value; if
            False, the raw counts are drawn and each cell shows the count only
        title: title for the plot

    Raises:
        IndexError: if a label is not a valid index into the C x C matrix.
    """
    # 1) Build the raw count matrix: rows are true labels, columns predictions.
    n_classes = len(classes)
    counts = np.zeros((n_classes, n_classes), dtype=int)
    for true_idx, pred_idx in zip(y_true, y_pred):
        counts[true_idx, pred_idx] += 1

    # 2) Normalize per row if requested. Rows with no samples divide 0 by 0;
    #    the warning is suppressed and the resulting NaNs are mapped to 0.
    if normalize:
        with np.errstate(divide="ignore", invalid="ignore"):
            shown = counts.astype(float) / counts.sum(axis=1, keepdims=True)
        shown = np.nan_to_num(shown)
    else:
        shown = counts

    # 3) Draw the matrix.
    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(shown, interpolation="nearest")
    ax.set_title(title)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    ticks = np.arange(n_classes)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(classes, rotation=45, ha="right")
    ax.set_yticklabels(classes)

    # 4) Annotate every cell. Without normalization the drawn value *is* the
    #    count, so printing both would just repeat the same number twice.
    for row in range(n_classes):
        for col in range(n_classes):
            if normalize:
                cell_text = f"{counts[row, col]}\n{shown[row, col]:.2f}"
            else:
                cell_text = f"{counts[row, col]}"
            ax.text(col, row, cell_text, ha="center", va="center")

    fig.colorbar(im, ax=ax, label="Normalized value" if normalize else "Count")
    plt.tight_layout()
    plt.show()
if __name__ == "__main__":
    # Script entry point: evaluate the default inference checkpoint.
    # NOTE(review): `settings` is not used afterwards; the call is kept in
    # case get_app_settings() has side effects - confirm before removing.
    settings = get_app_settings()
    default_ckpt = os.path.join(
        "checkpoint_files", "inference_recognition_model.ckpt"
    )
    evaluate_checkpoint(default_ckpt)