# Forked from qoherent/modrec-workflow.
# 93 lines, 2.8 KiB, Python.
import os
|
|
import torch
|
|
import numpy as np
|
|
import h5py
|
|
from sklearn.metrics import classification_report
|
|
from matplotlib import pyplot as plt
|
|
|
|
from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier
|
|
from helpers.app_settings import get_app_settings
|
|
from cm_plotter import plot_confusion_matrix
|
|
|
|
|
|
def load_validation_data(h5_path: str = "data/datasets/val.h5"):
    """
    Load validation data from an HDF5 file.

    Args:
        h5_path: Path to an HDF5 file containing the datasets ``"X"`` and
            ``"y"`` and, optionally, ``"class_names"``.

    Returns:
        X_val: np.ndarray of shape (N, C, L)
        y_val: np.ndarray of shape (N,)
        class_names: list of class names. Read from the ``"class_names"``
            dataset when it exists; otherwise the string form of each distinct
            label in ``y``, in the sorted order produced by ``np.unique``.
    """
    with h5py.File(h5_path, "r") as f:
        X = f["X"][:]  # shape: (N, C, L)
        y = f["y"][:]  # shape: (N,)

        if "class_names" in f:
            # Depending on the h5py version and on how the strings were
            # stored, elements come back either as bytes (needs decoding) or
            # already as str; handle both instead of assuming bytes.
            class_names = [
                s.decode("utf-8") if isinstance(s, bytes) else str(s)
                for s in f["class_names"][:]
            ]
        else:
            class_names = [str(i) for i in np.unique(y)]

    return X, y, class_names
|
|
|
|
|
|
def build_model_from_ckpt(
    ckpt_path: str,
    in_channels: int,
    num_classes: int,
    model_size: str = "mobilenetv3_small_050",
) -> torch.nn.Module:
    """
    Build an RFClassifier-wrapped MobileNetV3 and load its weights from a checkpoint.

    Args:
        ckpt_path: Path to a checkpoint whose ``"state_dict"`` entry holds the
            RFClassifier weights (presumably a PyTorch Lightning checkpoint --
            confirm).
        in_channels: Number of input channels, forwarded to the backbone as
            ``in_chans``.
        num_classes: Number of output classes of the backbone.
        model_size: Backbone variant name forwarded to ``mobilenetv3``.

    Returns:
        The model in eval mode. Nothing here moves it to an accelerator, so
        its parameters stay on the device they were constructed on.
    """
    model = RFClassifier(
        model=mobilenetv3(
            model_size=model_size,
            num_classes=num_classes,
            in_chans=in_channels,
        )
    )

    # The model is not moved to any other device, so load the checkpoint onto
    # the CPU. Mapping it to CUDA would only allocate GPU memory that
    # load_state_dict immediately copies back into the model's own parameters.
    checkpoint = torch.load(ckpt_path, weights_only=True, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return model
|
|
|
|
|
|
def _predict_labels(model: torch.nn.Module, X: np.ndarray, batch_size: int) -> np.ndarray:
    """Return the argmax class index predicted by ``model`` for every sample in ``X``."""
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Feed inputs to whatever device the model's parameters live on, so this
    # works whether the model was left on the CPU or moved to a GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    batches = []
    with torch.no_grad():
        # Batched inference; the model is expected to be in eval mode so that
        # samples within a batch do not influence each other's outputs.
        for start in range(0, len(X), batch_size):
            x_batch = torch.as_tensor(
                X[start:start + batch_size], dtype=torch.float32, device=device
            )
            logits = model(x_batch)
            batches.append(torch.argmax(logits, dim=1).cpu().numpy())

    if not batches:
        return np.empty(0, dtype=np.int64)
    return np.concatenate(batches)


def evaluate_checkpoint(ckpt_path: str, batch_size: int = 256):
    """
    Load the model from a checkpoint and evaluate it on the validation set.

    Prints a scikit-learn classification report and plots a normalized
    confusion matrix.

    Args:
        ckpt_path: Path to the checkpoint passed to ``build_model_from_ckpt``.
        batch_size: Number of validation samples per forward pass.
    """
    # Load validation data
    X_val, y_true, class_names = load_validation_data()
    num_classes = len(class_names)
    in_channels = X_val.shape[1]

    # Load model
    model = build_model_from_ckpt(ckpt_path, in_channels=in_channels, num_classes=num_classes)

    # Inference
    y_true = np.asarray(y_true)
    y_pred = _predict_labels(model, X_val, batch_size)

    # Print classification report. Pin the label set to the model's output
    # indices so ``target_names`` lines up even when a class never occurs in
    # either y_true or y_pred (otherwise scikit-learn raises ValueError).
    print("\nClassification Report:")
    print(
        classification_report(
            y_true,
            y_pred,
            labels=np.arange(num_classes),
            target_names=class_names,
        )
    )

    # Plot confusion matrix
    plot_confusion_matrix(
        y_true=y_true,
        y_pred=y_pred,
        classes=class_names,
        normalize=True,
        title="Normalized Confusion Matrix",
    )
    plt.show()
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate a model checkpoint on the validation set."
    )
    parser.add_argument(
        "--ckpt",
        default=os.path.join("checkpoint_files", "inference_recognition_model.ckpt"),
        help="Path to the checkpoint to evaluate (default: %(default)s).",
    )
    args = parser.parse_args()

    # NOTE(review): the returned settings are not used by this script; the call
    # is kept in case get_app_settings() loads or validates configuration as a
    # side effect -- confirm, or remove it.
    get_app_settings()

    evaluate_checkpoint(args.ckpt)
|