modrec-workflow/scripts/model_builder/plot_data.py

161 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import numpy as np
import torch
from matplotlib import pyplot as plt
from mobilenetv3 import RFClassifier, mobilenetv3
from modulation_dataset import ModulationH5Dataset
from sklearn.metrics import classification_report
from helpers.app_settings import get_app_settings
# Set NNPACK=0 in this process's environment before any of the code below runs.
# NOTE(review): presumably meant to switch off PyTorch's NNPACK backend; torch
# has already been imported above at this point — confirm the variable is
# still honored when set this late.
os.environ["NNPACK"] = "0"
def load_validation_data():
    """
    Load the validation split and materialize it as NumPy arrays.

    Opens ``data/dataset/val.h5`` through ``ModulationH5Dataset`` and walks the
    dataset exactly once, so every sample and its label come from the same
    iteration step and the file is not read twice.

    Returns:
        A tuple ``(x, y, class_names)``:
        x: all samples stacked along a new leading axis.
        y: 1-D array with one label value per sample.
        class_names: list of class names from the dataset's label encoder.

    Raises:
        ValueError: propagated from ``np.stack`` when the dataset is empty.
    """
    val_dataset = ModulationH5Dataset(
        "data/dataset/val.h5", label_name="modulation", data_key="validation_data"
    )

    samples = []
    labels = []
    # Single traversal: iterating the dataset separately for samples and for
    # labels would repeat every read and pair values from different passes.
    for sample, label in val_dataset:
        samples.append(sample.numpy())
        labels.append(label.item())

    x = np.stack(samples)  # shape: (N, C, L) per the original author's note
    y = np.array(labels)  # shape: (N,)
    class_names = list(val_dataset.label_encoder.classes_)
    return x, y, class_names
def build_model_from_ckpt(
    ckpt_path: str, in_channels: int, num_classes: int
) -> torch.nn.Module:
    """
    Construct the classifier and restore its weights from a checkpoint file.

    Args:
        ckpt_path: path to the checkpoint; it must contain a "state_dict" entry.
        in_channels: number of input channels passed to the backbone.
        num_classes: number of output classes passed to the backbone.

    Returns:
        The classifier with the checkpoint weights loaded, in eval mode.
    """
    backbone = mobilenetv3(
        model_size="mobilenetv3_small_050",
        num_classes=num_classes,
        in_chans=in_channels,
    )
    classifier = RFClassifier(model=backbone)

    # Checkpoint tensors are mapped onto the GPU when one is available.
    # NOTE(review): the classifier itself is never moved with .to(...) in this
    # function — confirm that is intended.
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    restored = torch.load(ckpt_path, weights_only=True, map_location=target)
    classifier.load_state_dict(restored["state_dict"])

    classifier.eval()
    return classifier
def evaluate_checkpoint(ckpt_path: str):
    """
    Evaluate a checkpointed model on the validation set.

    Loads the validation data, rebuilds the model from ``ckpt_path``, predicts
    every validation sample one at a time, then prints a classification report
    and hands the labels to ``print_confusion_matrix`` with ``normalize=True``.

    Args:
        ckpt_path: path of the checkpoint file to evaluate.
    """
    features, y_true, class_names = load_validation_data()

    # Model dimensions are derived from the data: axis 1 of the sample array
    # gives the channel count, the label encoder gives the class count.
    model = build_model_from_ckpt(
        ckpt_path,
        in_channels=features.shape[1],
        num_classes=len(class_names),
    )

    def predict_one(sample):
        # Add a leading batch axis of size 1 before calling the model.
        batch = torch.tensor(sample[np.newaxis, ...], dtype=torch.float32)
        return torch.argmax(model(batch), dim=1).item()

    with torch.no_grad():
        y_pred = [predict_one(sample) for sample in features]

    print("\nClassification Report:")
    report = classification_report(
        y_true, y_pred, target_names=class_names, zero_division=0
    )
    print(report)

    print_confusion_matrix(
        y_true=np.array(y_true),
        y_pred=np.array(y_pred),
        classes=class_names,
        normalize=True,
        title="Normalized Confusion Matrix",
    )
def print_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    classes: list[str],
    normalize: bool = True,
    title: str = "Confusion Matrix (counts and normalized)",
) -> None:
    """
    Print a confusion matrix built from true and predicted labels.

    The matrix is indexed as ``cm[true, predicted]``: each row collects the
    samples of one true class, each column one predicted class.

    Args:
        y_true: true labels (integers 0..C-1)
        y_pred: predicted labels (same length as y_true)
        classes: list of class-name strings in index order
        normalize: if True, each row is shown normalized to sum=1;
            otherwise raw counts are shown
        title: heading printed above the matrix

    Raises:
        ValueError: if y_true and y_pred do not contain the same number of
            labels.
    """
    true_idx = np.asarray(y_true, dtype=np.intp).ravel()
    pred_idx = np.asarray(y_pred, dtype=np.intp).ravel()
    if true_idx.shape != pred_idx.shape:
        raise ValueError(
            f"y_true and y_pred must have the same length, "
            f"got {true_idx.size} and {pred_idx.size}"
        )

    # 1) build raw CM
    c = len(classes)
    cm = np.zeros((c, c), dtype=int)
    # np.add.at accumulates repeated (true, pred) pairs; a plain fancy-indexed
    # "+=" would count each distinct pair only once.
    np.add.at(cm, (true_idx, pred_idx), 1)

    # 2) print, letting the helper normalize if requested
    print(f"\n{title}")
    # The helper formats cells as integers unless normalize=True is passed, so
    # the flag must be forwarded: handing it pre-normalized fractions with the
    # default flag would print every value truncated to 0 or 1. The helper
    # performs the row normalization itself and leaves all-zero rows as zeros.
    print_confusion_matrix_helper(cm, classes, normalize=normalize)
import numpy as np
def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=2):
    """
    Pretty prints a confusion matrix with x/y labels.

    Rows are captioned as the true class and columns as the predicted class,
    which matches a matrix indexed ``matrix[true, predicted]`` and the
    per-row normalization performed here.

    Parameters:
    - matrix: square 2D numpy array (or nested sequence)
    - classes: sequence of class labels (default: range(num_classes))
    - normalize: whether to normalize rows to sum to 1
    - digits: number of decimal places to show for normalized values
    """
    matrix = np.asarray(matrix)
    num_classes = matrix.shape[0]

    # Explicit None/empty test: "classes or ..." raises for NumPy arrays
    # because their truth value is ambiguous.
    if classes is None or len(classes) == 0:
        labels = list(range(num_classes))
    else:
        labels = list(classes)

    # Header
    print(" " * 9 + "Predicted →")
    header = "True ↓ | " + " ".join(f"{str(label):>6}" for label in labels)
    print(header)
    print("-" * len(header))

    # Rows
    for i, row_vals in enumerate(matrix):
        if normalize:
            row_sum = row_vals.sum()
            # A row with no samples is left as zeros instead of dividing by 0.
            if row_sum != 0:
                row_vals = row_vals / row_sum
            row_str = " ".join(f"{float(val):>6.{digits}f}" for val in row_vals)
        else:
            row_str = " ".join(f"{int(val):>6}" for val in row_vals)
        print(f"{str(labels[i]):>7} | {row_str}")
if __name__ == "__main__":
    # NOTE(review): the returned settings object is not used below; the call
    # is kept as-is in case it is needed for its side effects — confirm.
    settings = get_app_settings()

    # Evaluate the checkpoint stored under the relative checkpoint_files folder.
    checkpoint_file = os.path.join(
        "checkpoint_files", "inference_recognition_model.ckpt"
    )
    evaluate_checkpoint(checkpoint_file)