import os

import numpy as np
import torch
from matplotlib import pyplot as plt  # NOTE(review): not used in this module
from mobilenetv3 import RFClassifier, mobilenetv3
from modulation_dataset import ModulationH5Dataset
from sklearn.metrics import classification_report

from helpers.app_settings import get_app_settings

# NOTE(review): this is set after `torch` has already been imported; confirm
# that whatever reads it still picks the value up at this point.
os.environ["NNPACK"] = "0"


def load_validation_data():
    """Load the whole validation split into memory.

    Returns:
        ``(x, y, class_names)`` where ``x`` stacks the per-sample arrays
        returned by the dataset (shape: (N, C, L)), ``y`` is the label array
        of shape (N,), and ``class_names`` lists the label-encoder classes in
        index order.

    Raises:
        ValueError: if the validation dataset yields no samples.
    """
    val_dataset = ModulationH5Dataset(
        "data/dataset/val.h5", label_name="modulation", data_key="validation_data"
    )

    # Single pass: every item is read from the dataset once, not twice.
    samples, labels = [], []
    for sample, label in val_dataset:
        samples.append(sample.numpy())
        labels.append(label.item())

    if not samples:
        raise ValueError("validation dataset is empty")

    x = np.stack(samples)  # shape: (N, C, L)
    y = np.array(labels)  # shape: (N,)
    class_names = list(val_dataset.label_encoder.classes_)
    return x, y, class_names


def build_model_from_ckpt(
    ckpt_path: str, in_channels: int, num_classes: int
) -> torch.nn.Module:
    """Build a MobileNetV3-small classifier and load its weights from a checkpoint.

    Args:
        ckpt_path: path to a checkpoint whose ``"state_dict"`` entry holds the
            model weights.
        in_channels: number of input channels for the backbone.
        num_classes: number of output classes.

    Returns:
        The model in eval mode. No ``.to(device)`` is applied, so the
        parameters stay wherever ``RFClassifier`` creates them.
    """
    model = RFClassifier(
        model=mobilenetv3(
            model_size="mobilenetv3_small_050",
            num_classes=num_classes,
            in_chans=in_channels,
        )
    )
    # The model is never moved off its construction device, so load the
    # tensors on the CPU; mapping them to CUDA would only cost a GPU
    # allocation plus a copy back during load_state_dict.
    checkpoint = torch.load(ckpt_path, weights_only=True, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return model


def evaluate_checkpoint(ckpt_path: str, batch_size: int = 256):
    """Evaluate a checkpoint on the validation set.

    Prints an sklearn classification report followed by a row-normalized
    confusion matrix (rows = ground truth, columns = predictions).

    Args:
        ckpt_path: path to the checkpoint to evaluate.
        batch_size: number of samples per forward pass. Inference runs in
            eval mode under ``torch.no_grad()``; this assumes the model
            treats samples in a batch independently, in which case batching
            only affects speed.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Load validation data
    X_val, y_true, class_names = load_validation_data()
    num_classes = len(class_names)
    in_channels = X_val.shape[1]

    # Load model
    model = build_model_from_ckpt(
        ckpt_path, in_channels=in_channels, num_classes=num_classes
    )
    # Send inputs to wherever the model's weights live.
    device = next(model.parameters()).device

    # Batched inference
    y_pred = np.empty(len(X_val), dtype=np.int64)
    with torch.no_grad():
        for start in range(0, len(X_val), batch_size):
            stop = start + batch_size
            batch = torch.as_tensor(
                X_val[start:stop], dtype=torch.float32, device=device
            )
            logits = model(batch)
            y_pred[start:stop] = torch.argmax(logits, dim=1).cpu().numpy()

    # Passing `labels` explicitly keeps the report valid even when some class
    # never appears in y_true or y_pred (otherwise sklearn rejects
    # target_names whose length differs from the observed label set).
    print("\nClassification Report:")
    print(
        classification_report(
            y_true,
            y_pred,
            labels=list(range(num_classes)),
            target_names=class_names,
            zero_division=0,
        )
    )

    print_confusion_matrix(
        y_true=np.asarray(y_true),
        y_pred=y_pred,
        classes=class_names,
        normalize=True,
        title="Normalized Confusion Matrix",
    )


def print_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    classes: list[str],
    normalize: bool = True,
    title: str = "Confusion Matrix (counts and normalized)",
) -> None:
    """Print a confusion matrix as a text table.

    Rows are ground-truth classes, columns are predicted classes.

    Args:
        y_true: true labels (integers 0..C-1).
        y_pred: predicted labels (same shape as y_true).
        classes: list of class-name strings in index order; C = len(classes).
        normalize: if True, print each row divided by its sum (per-class
            recall); rows with no samples stay all-zero. Otherwise print raw
            counts.
        title: heading printed above the table.

    Raises:
        ValueError: if y_true and y_pred differ in shape, or contain labels
            outside 0..C-1.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred shapes differ: {y_true.shape} vs {y_pred.shape}"
        )

    c = len(classes)
    y_true = y_true.astype(np.intp).ravel()
    y_pred = y_pred.astype(np.intp).ravel()
    # Reject out-of-range labels explicitly: negative ones would otherwise
    # wrap around and be counted in the wrong cell.
    for name, labels in (("y_true", y_true), ("y_pred", y_pred)):
        if labels.size and (labels.min() < 0 or labels.max() >= c):
            raise ValueError(f"{name} contains labels outside 0..{c - 1}")

    # Raw counts: cm[t, p] = number of samples with truth t predicted as p.
    cm = np.zeros((c, c), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)

    print(f"\n{title}")
    print_confusion_matrix_helper(cm, classes, normalize=normalize)


def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=2):
    """
    Pretty-print a square matrix as a labelled confusion-matrix table.

    Rows are labelled as ground-truth classes and columns as predicted classes
    (the layout built by print_confusion_matrix). Column widths grow to fit the
    longest label or value so the table stays aligned.

    Parameters:
    - matrix: square 2D array-like
    - classes: row/column labels in index order (default: 0..num_classes-1)
    - normalize: whether to divide each row by its sum before printing; rows
      summing to zero are left unchanged
    - digits: number of decimal places for floating-point values (used when
      normalize is True or the matrix already has a float dtype)

    Raises:
    - ValueError: if matrix is not square 2D, or classes has the wrong length
    """
    matrix = np.asarray(matrix)
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError(f"matrix must be square 2D, got shape {matrix.shape}")
    num_classes = matrix.shape[0]

    if classes is None or len(classes) == 0:
        labels = [str(i) for i in range(num_classes)]
    else:
        labels = [str(label) for label in classes]
    if len(labels) != num_classes:
        raise ValueError(
            f"got {len(labels)} class labels for a "
            f"{num_classes}x{num_classes} matrix"
        )

    if normalize:
        matrix = matrix.astype(float)  # copy, so the caller's array is untouched
        row_sums = matrix.sum(axis=1, keepdims=True)
        # Divide only non-empty rows; empty rows keep their original values.
        np.divide(matrix, row_sums, out=matrix, where=row_sums != 0)

    # Float matrices (e.g. already-normalized ones) must not be truncated to int.
    if normalize or np.issubdtype(matrix.dtype, np.floating):
        cells = [[f"{val:.{digits}f}" for val in row] for row in matrix]
    else:
        cells = [[str(int(val)) for val in row] for row in matrix]

    row_title = "True ↓"
    label_width = max([len(row_title), *map(len, labels)])
    col_width = max(
        [6, *map(len, labels), *(len(cell) for row in cells for cell in row)]
    )
    prefix = " | "

    # Header
    print(" " * (label_width + len(prefix)) + "Predicted →")
    header = f"{row_title:>{label_width}}{prefix}" + " ".join(
        f"{label:>{col_width}}" for label in labels
    )
    print(header)
    print("-" * len(header))

    # Rows
    for label, row in zip(labels, cells):
        row_str = " ".join(f"{cell:>{col_width}}" for cell in row)
        print(f"{label:>{label_width}}{prefix}{row_str}")


if __name__ == "__main__":
    # NOTE(review): the returned settings are unused here; the call is kept in
    # case loading settings has side effects — confirm.
    get_app_settings()
    evaluate_checkpoint(
        os.path.join("checkpoint_files", "inference_recognition_model.ckpt")
    )