formatting
This commit is contained in:
parent
6d531ae5f3
commit
1c7ddef5cb
|
@ -17,12 +17,10 @@ def load_validation_data():
|
|||
"data/dataset/val.h5", label_name="modulation", data_key="validation_data"
|
||||
)
|
||||
|
||||
|
||||
x = np.stack([x.numpy() for x, _ in val_dataset]) # shape: (N, C, L)
|
||||
y = np.array([y.item() for _, y in val_dataset]) # shape: (N,)
|
||||
class_names = list(val_dataset.label_encoder.classes_)
|
||||
|
||||
|
||||
return x, y, class_names
|
||||
|
||||
|
||||
|
@ -46,27 +44,22 @@ def build_model_from_ckpt(
|
|||
return model
|
||||
|
||||
|
||||
|
||||
|
||||
def evaluate_checkpoint(ckpt_path: str):
|
||||
"""
|
||||
Loads the model from checkpoint and evaluates it on a validation set.
|
||||
Prints classification metrics and plots a confusion matrix.
|
||||
"""
|
||||
|
||||
|
||||
# Load validation data
|
||||
X_val, y_true, class_names = load_validation_data()
|
||||
num_classes = len(class_names)
|
||||
in_channels = X_val.shape[1]
|
||||
|
||||
|
||||
# Load model
|
||||
model = build_model_from_ckpt(
|
||||
ckpt_path, in_channels=in_channels, num_classes=num_classes
|
||||
)
|
||||
|
||||
|
||||
# Inference
|
||||
y_pred = []
|
||||
with torch.no_grad():
|
||||
|
@ -76,15 +69,12 @@ def evaluate_checkpoint(ckpt_path: str):
|
|||
pred = torch.argmax(logits, dim=1).item()
|
||||
y_pred.append(pred)
|
||||
|
||||
|
||||
# Print classification report
|
||||
print("\nClassification Report:")
|
||||
print(
|
||||
classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
|
||||
)
|
||||
|
||||
|
||||
|
||||
print_confusion_matrix(
|
||||
y_true=np.array(y_true),
|
||||
y_pred=np.array(y_pred),
|
||||
|
@ -94,8 +84,6 @@ def evaluate_checkpoint(ckpt_path: str):
|
|||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def print_confusion_matrix(
|
||||
y_true: np.ndarray,
|
||||
y_pred: np.ndarray,
|
||||
|
@ -120,7 +108,6 @@ def print_confusion_matrix(
|
|||
for t, p in zip(y_true, y_pred):
|
||||
cm[t, p] += 1
|
||||
|
||||
|
||||
# 2) normalize if requested
|
||||
if normalize:
|
||||
with np.errstate(divide="ignore", invalid="ignore"):
|
||||
|
@ -131,10 +118,9 @@ def print_confusion_matrix(
|
|||
print_confusion_matrix_helper(cm, classes)
|
||||
|
||||
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=2):
|
||||
"""
|
||||
Pretty prints a confusion matrix with x/y labels.
|
||||
|
@ -169,6 +155,6 @@ def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=
|
|||
|
||||
if __name__ == "__main__":
|
||||
settings = get_app_settings()
|
||||
evaluate_checkpoint(os.path.join("checkpoint_files", "inference_recognition_model.ckpt"))
|
||||
|
||||
|
||||
evaluate_checkpoint(
|
||||
os.path.join("checkpoint_files", "inference_recognition_model.ckpt")
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user