import os

import torch
import numpy as np
import h5py
from sklearn.metrics import classification_report
from matplotlib import pyplot as plt

from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier
from helpers.app_settings import get_app_settings
from cm_plotter import plot_confusion_matrix


def load_validation_data(h5_path: str = "data/datasets/validation.h5"):
    """
    Load validation data from an HDF5 file.

    Args:
        h5_path: Path to an HDF5 file containing the datasets ``"X"`` and
            ``"y"`` and, optionally, ``"class_names"``.

    Returns:
        X_val: np.ndarray read from dataset ``"X"``, expected shape (N, C, L).
        y_val: np.ndarray read from dataset ``"y"``, expected shape (N,).
        class_names: list of class names. These are taken from
            ``"class_names"`` when present, otherwise they are the
            stringified unique values of ``y``.
    """
    with h5py.File(h5_path, "r") as f:
        X = f["X"][:]  # expected shape: (N, C, L)
        y = f["y"][:]  # expected shape: (N,)
        if "class_names" in f:
            # Depending on how the strings were written, h5py may return
            # bytes or str; decode only when necessary.
            class_names = [
                s.decode("utf-8") if isinstance(s, bytes) else str(s)
                for s in f["class_names"][:]
            ]
        else:
            class_names = [str(i) for i in np.unique(y)]
    return X, y, class_names


def build_model_from_ckpt(
    ckpt_path: str,
    in_channels: int,
    num_classes: int,
    model_size: str = "mobilenetv3_small_050",
    device: "torch.device | str | None" = None,
) -> torch.nn.Module:
    """
    Build an ``RFClassifier``-wrapped MobileNetV3 and load checkpoint weights.

    Args:
        ckpt_path: Path to a checkpoint whose ``"state_dict"`` entry matches
            the ``RFClassifier`` built here.
        in_channels: Number of input channels passed to ``mobilenetv3``.
        num_classes: Number of output classes passed to ``mobilenetv3``.
        model_size: ``mobilenetv3`` variant name; must match the checkpoint.
        device: Device to place the model on. ``None`` (default) leaves the
            model on the CPU, which is where it has always been returned.

    Returns:
        The model in eval mode.
    """
    model = RFClassifier(
        model=mobilenetv3(
            model_size=model_size,
            num_classes=num_classes,
            in_chans=in_channels,
        )
    )
    # Load onto the CPU: the freshly built model lives there, and
    # load_state_dict copies values into the existing parameters, so
    # loading onto a GPU first would only add a round-trip.
    checkpoint = torch.load(ckpt_path, weights_only=True, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    if device is not None:
        model.to(device)
    model.eval()
    return model


def _predict(
    model: torch.nn.Module,
    X: np.ndarray,
    device: torch.device,
    batch_size: int = 256,
) -> np.ndarray:
    """Return the argmax class index for every sample of ``X``, computed in batches on ``device``."""
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    preds = []
    with torch.no_grad():
        for start in range(0, len(X), batch_size):
            batch = torch.as_tensor(
                X[start:start + batch_size], dtype=torch.float32, device=device
            )
            logits = model(batch)
            preds.append(torch.argmax(logits, dim=1).cpu())
    if not preds:
        return np.empty(0, dtype=np.int64)
    return torch.cat(preds).numpy()


def evaluate_checkpoint(
    ckpt_path: str,
    h5_path: str = "data/datasets/validation.h5",
    batch_size: int = 256,
):
    """
    Load the model from a checkpoint and evaluate it on a validation set.

    Prints a classification report and shows a normalized confusion matrix.

    Args:
        ckpt_path: Path to the model checkpoint.
        h5_path: Path to the validation HDF5 file (see ``load_validation_data``).
        batch_size: Number of samples per forward pass during inference.
    """
    X_val, y_true, class_names = load_validation_data(h5_path)
    num_classes = len(class_names)
    in_channels = X_val.shape[1]

    # Run on the GPU when one is available; inputs are sent to the same device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model_from_ckpt(
        ckpt_path,
        in_channels=in_channels,
        num_classes=num_classes,
        device=device,
    )

    y_pred = _predict(model, X_val, device=device, batch_size=batch_size)

    print("\nClassification Report:")
    # Pass every class index explicitly so target_names stays aligned even
    # when some class is absent from both y_true and y_pred. Without it,
    # sklearn raises ValueError in that case.
    print(
        classification_report(
            y_true,
            y_pred,
            labels=np.arange(num_classes),
            target_names=class_names,
        )
    )

    plot_confusion_matrix(
        y_true=np.asarray(y_true),
        y_pred=np.asarray(y_pred),
        classes=class_names,
        normalize=True,
        title="Normalized Confusion Matrix",
    )
    plt.show()


if __name__ == "__main__":
    # NOTE(review): `settings` is never used below; the call is kept in case
    # get_app_settings() has initialization side effects — confirm and drop.
    settings = get_app_settings()
    ckpt_path = os.path.join("checkpoint_files", "inference_recognition_model.ckpt")
    evaluate_checkpoint(ckpt_path)