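"""Evaluate a trained modulation classification checkpoint.

Loads the validation split from an HDF5 file, runs the checkpointed
MobileNetV3-based classifier over it, and prints a classification report
plus a text confusion matrix.
"""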
import os

import numpy as np
import torch
from mobilenetv3 import RFClassifier, mobilenetv3
from modulation_dataset import ModulationH5Dataset
from sklearn.metrics import classification_report

from helpers.app_settings import get_app_settings

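# Presumably here to disable NNPACK and silence "Could not initialize NNPACK"
# warnings on CPUs that do not support it; PyTorch falls back to its default backend.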
os.environ["NNPACK"] = "0"


def load_validation_data():
    val_dataset = ModulationH5Dataset(
        "data/dataset/val.h5", label_name="modulation", data_key="validation_data"
    )

    x = np.stack([x.numpy() for x, _ in val_dataset])  # shape: (N, C, L)
    y = np.array([y.item() for _, y in val_dataset])  # shape: (N,)
    class_names = list(val_dataset.label_encoder.classes_)

    return x, y, class_names

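# Shapes for orientation (values are illustrative; the real N, C, L come from val.h5):
#   X, y, names = load_validation_data()
#   X.shape -> (N, C, L), e.g. (N, 2, 1024) for two-channel I/Q frames
#   y.shape -> (N,), integer-encoded labels that index into names

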
def build_model_from_ckpt(
    ckpt_path: str, in_channels: int, num_classes: int
) -> torch.nn.Module:
    """
    Build and return a PyTorch model loaded from a checkpoint.
    """
    model = RFClassifier(
        model=mobilenetv3(
            model_size="mobilenetv3_small_050",
            num_classes=num_classes,
            in_chans=in_channels,
        )
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return model

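# Note: torch.load(..., weights_only=True) needs a reasonably recent PyTorch
# (the flag first appeared in 1.13); on older versions, drop that argument.

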
def evaluate_checkpoint(ckpt_path: str):
    """
    Loads the model from a checkpoint and evaluates it on the validation set.
    Prints classification metrics and a confusion matrix.
    """
    # Load validation data
    X_val, y_true, class_names = load_validation_data()
    num_classes = len(class_names)
    in_channels = X_val.shape[1]

    # Load model
    model = build_model_from_ckpt(
        ckpt_path, in_channels=in_channels, num_classes=num_classes
    )

    # Inference
    y_pred = []
    with torch.no_grad():
        for x in X_val:
            x_tensor = torch.tensor(x[np.newaxis, ...], dtype=torch.float32)
            logits = model(x_tensor)
            pred = torch.argmax(logits, dim=1).item()
            y_pred.append(pred)

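    # Samples are classified one at a time for simplicity; batching the forward
    # passes would be faster, but is not needed for a modest validation set.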
    # Print classification report
    print("\nClassification Report:")
    print(
        classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
    )

    print_confusion_matrix(
        y_true=np.array(y_true),
        y_pred=np.array(y_pred),
        classes=class_names,
        normalize=True,
        title="Normalized Confusion Matrix",
    )

def print_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    classes: list[str],
    normalize: bool = True,
    title: str = "Confusion Matrix",
) -> None:
    """
    Print a confusion matrix, either as raw counts or with each row normalized.

    Args:
        y_true: true labels (integers 0..C-1)
        y_pred: predicted labels (same shape as y_true)
        classes: list of class-name strings in index order
        normalize: if True, each row is normalized to sum to 1
        title: title printed above the matrix
    """
    print(f"\n{title}")

    # 1) build raw CM: rows are true classes, columns are predicted classes
    c = len(classes)
    cm = np.zeros((c, c), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1

    # 2) normalize if requested
    if normalize:
        with np.errstate(divide="ignore", invalid="ignore"):
            cm_norm = cm.astype(float) / cm.sum(axis=1)[:, None]
        cm_norm = np.nan_to_num(cm_norm)
        # normalize=True selects the float formatting in the helper (its own row
        # re-normalization is a no-op here, since rows already sum to 1 or 0).
        print_confusion_matrix_helper(cm_norm, classes, normalize=True)
    else:
        print_confusion_matrix_helper(cm, classes)

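# The raw counts above could also be computed with scikit-learn, which is
# already imported for the classification report:
#   from sklearn.metrics import confusion_matrix
#   cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))

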
def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=2):
    """
    Pretty-prints a confusion matrix with axis labels.

    Parameters:
    - matrix: square 2D numpy array (rows: true classes, columns: predicted classes)
    - classes: list of class labels (default: range(num_classes))
    - normalize: whether to normalize rows to sum to 1
    - digits: number of decimal places to show for normalized values
    """
    matrix = np.array(matrix)
    num_classes = matrix.shape[0]
    labels = classes or list(range(num_classes))

    # Header: columns are predicted classes, rows are true classes
    print(" " * 9 + "Predicted →")
    header = f"{'True ↓':>7} | " + " ".join([f"{str(label):>6}" for label in labels])
    print(header)
    print("-" * len(header))

    # Rows
    for i in range(num_classes):
        row_vals = matrix[i]
        if normalize:
            row_sum = row_vals.sum()
            row_vals = row_vals / row_sum if row_sum != 0 else row_vals
            row_str = " ".join([f"{val:>6.{digits}f}" for val in row_vals])
        else:
            row_str = " ".join([f"{int(val):>6}" for val in row_vals])
        print(f"{str(labels[i]):>7} | {row_str}")

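# Example (class names are made up for illustration):
#   print_confusion_matrix_helper(np.array([[8, 2], [1, 9]]), classes=["bpsk", "qpsk"])
# prints the 2x2 count matrix with true classes as rows and predictions as columns.

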
if __name__ == "__main__":
    settings = get_app_settings()
    evaluate_checkpoint(
        os.path.join("checkpoint_files", "inference_recognition_model.ckpt")
    )