diff --git a/scripts/model_builder/plot_data.py b/scripts/model_builder/plot_data.py
index 8acc1ff..d5bb542 100644
--- a/scripts/model_builder/plot_data.py
+++ b/scripts/model_builder/plot_data.py
@@ -13,132 +13,162 @@ from helpers.app_settings import get_app_settings
 
 def load_validation_data():
-    val_dataset = ModulationH5Dataset(
-        "data/dataset/val.h5", label_name="modulation", data_key="validation_data"
-    )
+    val_dataset = ModulationH5Dataset(
+        "data/dataset/val.h5", label_name="modulation", data_key="validation_data"
+    )
 
-    X = np.stack([x.numpy() for x, _ in val_dataset])  # shape: (N, C, L)
-    y = np.array([y.item() for _, y in val_dataset])  # shape: (N,)
-    class_names = list(val_dataset.label_encoder.classes_)
-    return X, y, class_names
+    x = np.stack([x.numpy() for x, _ in val_dataset])  # shape: (N, C, L)
+    y = np.array([y.item() for _, y in val_dataset])  # shape: (N,)
+    class_names = list(val_dataset.label_encoder.classes_)
+
+
+    return x, y, class_names
 
 
 def build_model_from_ckpt(
-    ckpt_path: str, in_channels: int, num_classes: int
+    ckpt_path: str, in_channels: int, num_classes: int
 ) -> torch.nn.Module:
-    """
-    Build and return a PyTorch model loaded from a checkpoint.
-    """
-    model = RFClassifier(
-        model=mobilenetv3(
-            model_size="mobilenetv3_small_050",
-            num_classes=num_classes,
-            in_chans=in_channels,
-        )
-    )
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
-    model.load_state_dict(checkpoint["state_dict"])
-    model.eval()
-    return model
+    """
+    Build and return a PyTorch model loaded from a checkpoint.
+    """
+    model = RFClassifier(
+        model=mobilenetv3(
+            model_size="mobilenetv3_small_050",
+            num_classes=num_classes,
+            in_chans=in_channels,
+        )
+    )
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
+    model.load_state_dict(checkpoint["state_dict"])
+    model.eval()
+    return model
+
+
 def evaluate_checkpoint(ckpt_path: str):
-    """
-    Loads the model from checkpoint and evaluates it on a validation set.
-    Prints classification metrics and plots a confusion matrix.
+ """ -def plot_confusion_matrix_with_counts( - y_true: np.ndarray, - y_pred: np.ndarray, - classes: list[str], - normalize: bool = True, - title: str = "Confusion Matrix (counts and normalized)", + # Load validation data + X_val, y_true, class_names = load_validation_data() + num_classes = len(class_names) + in_channels = X_val.shape[1] + + + # Load model + model = build_model_from_ckpt( + ckpt_path, in_channels=in_channels, num_classes=num_classes + ) + + + # Inference + y_pred = [] + with torch.no_grad(): + for x in X_val: + x_tensor = torch.tensor(x[np.newaxis, ...], dtype=torch.float32) + logits = model(x_tensor) + pred = torch.argmax(logits, dim=1).item() + y_pred.append(pred) + + + # Print classification report + print("\nClassification Report:") + print( + classification_report(y_true, y_pred, target_names=class_names, zero_division=0) + ) + + + + print_confusion_matrix( + y_true=np.array(y_true), + y_pred=np.array(y_pred), + classes=class_names, + normalize=True, + title="Normalized Confusion Matrix", + ) + + + + +def print_confusion_matrix( + y_true: np.ndarray, + y_pred: np.ndarray, + classes: list[str], + normalize: bool = True, + title: str = "Confusion Matrix (counts and normalized)", ) -> None: + """ + Plot a confusion matrix showing both raw counts and (optionally) normalized values. + + + Args: + y_true: true labels (integers 0..C-1) + y_pred: predicted labels (same shape as y_true) + classes: list of class‐name strings in index order + normalize: if True, each row is normalized to sum=1 + title: title for the plot + """ + # 1) build raw CM + c = len(classes) + cm = np.zeros((c, c), dtype=int) + for t, p in zip(y_true, y_pred): + cm[t, p] += 1 + + + # 2) normalize if requested + if normalize: + with np.errstate(divide="ignore", invalid="ignore"): + cm_norm = cm.astype(float) / cm.sum(axis=1)[:, None] + cm_norm = np.nan_to_num(cm_norm) + print_confusion_matrix_helper(cm_norm, classes) + else: + print_confusion_matrix_helper(cm, classes) + + + + +import numpy as np + +def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=2): """ - Plot a confusion matrix showing both raw counts and (optionally) normalized values. + Pretty prints a confusion matrix with x/y labels. 
 
-    Args:
-        y_true: true labels (integers 0..C-1)
-        y_pred: predicted labels (same shape as y_true)
-        classes: list of class‐name strings in index order
-        normalize: if True, each row is normalized to sum=1
-        title: title for the plot
+    Parameters:
+    - matrix: square 2D numpy array
+    - classes: list of class labels (default: range(num_classes))
+    - normalize: whether to normalize rows to sum to 1
+    - digits: number of decimal places to show for normalized values
     """
-    # 1) build raw CM
-    C = len(classes)
-    cm = np.zeros((C, C), dtype=int)
-    for t, p in zip(y_true, y_pred):
-        cm[t, p] += 1
-
-    # 2) normalize if requested
-    if normalize:
-        with np.errstate(divide="ignore", invalid="ignore"):
-            cm_norm = cm.astype(float) / cm.sum(axis=1)[:, None]
-            cm_norm = np.nan_to_num(cm_norm)
-    else:
-        cm_norm = cm
-
-    # 3) plot
-    fig, ax = plt.subplots(figsize=(8, 8))
-    im = ax.imshow(cm_norm, interpolation="nearest")
-    ax.set_title(title)
-    ax.set_xlabel("Predicted label")
-    ax.set_ylabel("True label")
-    ax.set_xticks(np.arange(C))
-    ax.set_yticks(np.arange(C))
-    ax.set_xticklabels(classes, rotation=45, ha="right")
-    ax.set_yticklabels(classes)
-
-    # 4) annotate
-    for i in range(C):
-        for j in range(C):
-            count = cm[i, j]
-            val = cm_norm[i, j]
-            txt = f"{count}\n{val:.2f}"
-            ax.text(j, i, txt, ha="center", va="center")
-
-    fig.colorbar(im, ax=ax, label="Normalized value" if normalize else "Count")
-    plt.tight_layout()
-    plt.show()
-
+    matrix = np.array(matrix)
+    num_classes = matrix.shape[0]
+    labels = classes or list(range(num_classes))
+
+    # Header
+    print(" " * 9 + "Ground Truth →")
+    header = "Pred ↓ | " + " ".join([f"{str(label):>6}" for label in labels])
+    print(header)
+    print("-" * len(header))
+
+    # Rows
+    for i in range(num_classes):
+        row_vals = matrix[i]
+        if normalize:
+            row_sum = row_vals.sum()
+            row_vals = row_vals / row_sum if row_sum != 0 else row_vals
+            row_str = " ".join([f"{val:>6.{digits}f}" for val in row_vals])
+        else:
+            row_str = " ".join([f"{int(val):>6}" for val in row_vals])
+        print(f"{str(labels[i]):>7} | {row_str}")
+
 
 if __name__ == "__main__":
-    settings = get_app_settings()
-    evaluate_checkpoint(os.path.join("checkpoint_files", "inference_recognition_model.ckpt"))
+    settings = get_app_settings()
+    evaluate_checkpoint(os.path.join("checkpoint_files", "inference_recognition_model.ckpt"))
+
+
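
As a quick sanity check of the new text-based output, the snippet below runs a few hand-made labels through print_confusion_matrix. This is a hypothetical usage sketch, not part of the patch: the import path scripts.model_builder.plot_data and the modulation class names are assumptions and may need to be adjusted to the actual project layout.

# Hypothetical usage sketch: exercise the new print-based confusion matrix
# with hand-made labels instead of running the full validation pass.
import numpy as np

# Assumed import path; adjust to the real package layout.
from scripts.model_builder.plot_data import print_confusion_matrix

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])

# Placeholder class names; in evaluate_checkpoint they come from the dataset's label encoder.
print_confusion_matrix(
    y_true=y_true,
    y_pred=y_pred,
    classes=["bpsk", "qpsk", "qam16"],
    normalize=True,
)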