formatting

This commit is contained in:
Liyu Xiao 2025-08-21 10:33:50 -04:00
parent 6d531ae5f3
commit 1c7ddef5cb

View File

@ -17,12 +17,10 @@ def load_validation_data():
"data/dataset/val.h5", label_name="modulation", data_key="validation_data" "data/dataset/val.h5", label_name="modulation", data_key="validation_data"
) )
x = np.stack([x.numpy() for x, _ in val_dataset]) # shape: (N, C, L) x = np.stack([x.numpy() for x, _ in val_dataset]) # shape: (N, C, L)
y = np.array([y.item() for _, y in val_dataset]) # shape: (N,) y = np.array([y.item() for _, y in val_dataset]) # shape: (N,)
class_names = list(val_dataset.label_encoder.classes_) class_names = list(val_dataset.label_encoder.classes_)
return x, y, class_names return x, y, class_names
@ -46,27 +44,22 @@ def build_model_from_ckpt(
return model return model
def evaluate_checkpoint(ckpt_path: str): def evaluate_checkpoint(ckpt_path: str):
""" """
Loads the model from checkpoint and evaluates it on a validation set. Loads the model from checkpoint and evaluates it on a validation set.
Prints classification metrics and plots a confusion matrix. Prints classification metrics and plots a confusion matrix.
""" """
# Load validation data # Load validation data
X_val, y_true, class_names = load_validation_data() X_val, y_true, class_names = load_validation_data()
num_classes = len(class_names) num_classes = len(class_names)
in_channels = X_val.shape[1] in_channels = X_val.shape[1]
# Load model # Load model
model = build_model_from_ckpt( model = build_model_from_ckpt(
ckpt_path, in_channels=in_channels, num_classes=num_classes ckpt_path, in_channels=in_channels, num_classes=num_classes
) )
# Inference # Inference
y_pred = [] y_pred = []
with torch.no_grad(): with torch.no_grad():
@ -76,15 +69,12 @@ def evaluate_checkpoint(ckpt_path: str):
pred = torch.argmax(logits, dim=1).item() pred = torch.argmax(logits, dim=1).item()
y_pred.append(pred) y_pred.append(pred)
# Print classification report # Print classification report
print("\nClassification Report:") print("\nClassification Report:")
print( print(
classification_report(y_true, y_pred, target_names=class_names, zero_division=0) classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
) )
print_confusion_matrix( print_confusion_matrix(
y_true=np.array(y_true), y_true=np.array(y_true),
y_pred=np.array(y_pred), y_pred=np.array(y_pred),
@ -94,8 +84,6 @@ def evaluate_checkpoint(ckpt_path: str):
) )
def print_confusion_matrix( def print_confusion_matrix(
y_true: np.ndarray, y_true: np.ndarray,
y_pred: np.ndarray, y_pred: np.ndarray,
@ -120,7 +108,6 @@ def print_confusion_matrix(
for t, p in zip(y_true, y_pred): for t, p in zip(y_true, y_pred):
cm[t, p] += 1 cm[t, p] += 1
# 2) normalize if requested # 2) normalize if requested
if normalize: if normalize:
with np.errstate(divide="ignore", invalid="ignore"): with np.errstate(divide="ignore", invalid="ignore"):
@ -131,10 +118,9 @@ def print_confusion_matrix(
print_confusion_matrix_helper(cm, classes) print_confusion_matrix_helper(cm, classes)
import numpy as np import numpy as np
def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=2): def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=2):
""" """
Pretty prints a confusion matrix with x/y labels. Pretty prints a confusion matrix with x/y labels.
@ -169,6 +155,6 @@ def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=
if __name__ == "__main__": if __name__ == "__main__":
settings = get_app_settings() settings = get_app_settings()
evaluate_checkpoint(os.path.join("checkpoint_files", "inference_recognition_model.ckpt")) evaluate_checkpoint(
os.path.join("checkpoint_files", "inference_recognition_model.ckpt")
)