fixed plot.py

This commit is contained in:
Liyu Xiao 2025-06-18 13:44:29 -04:00
parent 8ba21251bf
commit 91bab4acbd
7 changed files with 52 additions and 49 deletions

View File

@ -144,7 +144,9 @@ def main():
print("📦 Generating training and validation datasets...") print("📦 Generating training and validation datasets...")
print(f" ➤ Slicing each recording into {dataset_cfg.num_slices} snippets") print(f" ➤ Slicing each recording into {dataset_cfg.num_slices} snippets")
print(f" ➤ Train/Val split: {int(dataset_cfg.train_split * 100)}% / {int((1 - dataset_cfg.train_split) * 100)}%") print(
f" ➤ Train/Val split: {int(dataset_cfg.train_split * 100)}% / {int((1 - dataset_cfg.train_split) * 100)}%"
)
print(f" ➤ Output directory: data/dataset\n") print(f" ➤ Output directory: data/dataset\n")
train_path, val_path = generate_datasets(dataset_cfg) train_path, val_path = generate_datasets(dataset_cfg)
@ -160,6 +162,5 @@ def main():
print(f" 🔸 Validation samples saved to: {val_path} ({num_val} samples)") print(f" 🔸 Validation samples saved to: {val_path} ({num_val} samples)")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -3,12 +3,15 @@ from collections import defaultdict
from typing import List, Tuple, Dict from typing import List, Tuple, Dict
import numpy as np import numpy as np
def split( def split(
dataset: List[Tuple[np.ndarray, Dict[str, any]]], dataset: List[Tuple[np.ndarray, Dict[str, any]]],
train_frac: float, train_frac: float,
seed: int, seed: int,
label_key: str = "modulation" label_key: str = "modulation",
) -> Tuple[List[Tuple[np.ndarray, Dict[str, any]]], List[Tuple[np.ndarray, Dict[str, any]]]]: ) -> Tuple[
List[Tuple[np.ndarray, Dict[str, any]]], List[Tuple[np.ndarray, Dict[str, any]]]
]:
""" """
Splits a dataset of modulated IQ signal recordings into training and validation subsets. Splits a dataset of modulated IQ signal recordings into training and validation subsets.
@ -61,8 +64,7 @@ def split(
def split_recording( def split_recording(
recording_list: List[Tuple[np.ndarray, Dict[str, any]]], recording_list: List[Tuple[np.ndarray, Dict[str, any]]], num_snippets: int
num_snippets: int
) -> List[Tuple[np.ndarray, Dict[str, any]]]: ) -> List[Tuple[np.ndarray, Dict[str, any]]]:
""" """
Splits each full recording into a specified number of smaller snippets. Splits each full recording into a specified number of smaller snippets.

View File

@ -7,11 +7,7 @@ from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier
from helpers.app_settings import get_app_settings from helpers.app_settings import get_app_settings
def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
def convert_to_onnx(
ckpt_path: str,
fp16: bool=False
) -> None :
""" """
Convert a PyTorch model to ONNX format. Convert a PyTorch model to ONNX format.
@ -37,9 +33,7 @@ def convert_to_onnx(
) )
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load( checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
ckpt_path, weights_only=True, map_location=device
)
model.load_state_dict(checkpoint["state_dict"]) model.load_state_dict(checkpoint["state_dict"])
if fp16: if fp16:

View File

@ -5,7 +5,10 @@ import os
import time import time
import json import json
def profile_onnx_model(path_to_onnx: str, num_runs: int = 100, warmup_runs: int = 5) -> None:
def profile_onnx_model(
path_to_onnx: str, num_runs: int = 100, warmup_runs: int = 5
) -> None:
""" """
Profiles an ONNX model by running inference multiple times and collecting performance data. Profiles an ONNX model by running inference multiple times and collecting performance data.
@ -58,7 +61,9 @@ def profile_onnx_model(path_to_onnx: str, num_runs: int = 100, warmup_runs: int
times.append(t1 - t0) times.append(t1 - t0)
avg_time = sum(times) / len(times) avg_time = sum(times) / len(times)
print(f"[Timing] Avg inference time (excluding {warmup_runs} warm-ups): {avg_time:.6f} sec") print(
f"[Timing] Avg inference time (excluding {warmup_runs} warm-ups): {avg_time:.6f} sec"
)
# End profiling & parse JSON # End profiling & parse JSON
profile_file = session.end_profiling() profile_file = session.end_profiling()
@ -71,7 +76,9 @@ def profile_onnx_model(path_to_onnx: str, num_runs: int = 100, warmup_runs: int
print(f"[Profile] Number of nodes executed: {len(nodes)}") print(f"[Profile] Number of nodes executed: {len(nodes)}")
if nodes: if nodes:
top = max(nodes, key=lambda x: x.get("dur", 0)) top = max(nodes, key=lambda x: x.get("dur", 0))
print(f"[Profile] Most expensive op: {top['name']}{top['dur'] / 1e6:.3f} ms") print(
f"[Profile] Most expensive op: {top['name']}{top['dur'] / 1e6:.3f} ms"
)
except Exception as e: except Exception as e:
print(f"[Warning] Failed to parse profiling JSON: {e}") print(f"[Warning] Failed to parse profiling JSON: {e}")

View File

@ -1,35 +1,33 @@
import os import os
import torch import torch
import numpy as np import numpy as np
import h5py
from sklearn.metrics import classification_report from sklearn.metrics import classification_report
import matplotlib
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier
from helpers.app_settings import get_app_settings from helpers.app_settings import get_app_settings
from cm_plotter import plot_confusion_matrix from cm_plotter import plot_confusion_matrix
from scripts.training.modulation_dataset import ModulationH5Dataset
def load_validation_data(h5_path:str ="data/datasets/val.h5"): def load_validation_data():
""" val_dataset = ModulationH5Dataset(
Loads validation data from an HDF5 file. "data/dataset/val.h5",
label_name="modulation",
data_key="validation_data"
)
X = np.stack([x.numpy() for x, _ in val_dataset]) # shape: (N, C, L)
y = np.array([y.item() for _, y in val_dataset]) # shape: (N,)
class_names = list(val_dataset.label_encoder.classes_)
Returns:
X_val: np.ndarray of shape (N, C, L)
y_val: np.ndarray of shape (N,)
class_names: list of class names
"""
with h5py.File(h5_path, "r") as f:
X = f["X"][:] # shape: (N, C, L)
y = f["y"][:] # shape: (N,)
if "class_names" in f:
class_names = [s.decode("utf-8") for s in f["class_names"][:]]
else:
class_names = [str(i) for i in np.unique(y)]
return X, y, class_names return X, y, class_names
def build_model_from_ckpt(ckpt_path: str, in_channels: int, num_classes: int) -> torch.nn.Module: def build_model_from_ckpt(
ckpt_path: str, in_channels: int, num_classes: int
) -> torch.nn.Module:
""" """
Build and return a PyTorch model loaded from a checkpoint. Build and return a PyTorch model loaded from a checkpoint.
""" """
@ -37,13 +35,11 @@ def build_model_from_ckpt(ckpt_path: str, in_channels: int, num_classes: int) ->
model=mobilenetv3( model=mobilenetv3(
model_size="mobilenetv3_small_050", model_size="mobilenetv3_small_050",
num_classes=num_classes, num_classes=num_classes,
in_chans=in_channels in_chans=in_channels,
) )
) )
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load( checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
ckpt_path, weights_only=True, map_location=device
)
model.load_state_dict(checkpoint["state_dict"]) model.load_state_dict(checkpoint["state_dict"])
model.eval() model.eval()
return model return model
@ -54,13 +50,16 @@ def evaluate_checkpoint(ckpt_path: str):
Loads the model from checkpoint and evaluates it on a validation set. Loads the model from checkpoint and evaluates it on a validation set.
Prints classification metrics and plots a confusion matrix. Prints classification metrics and plots a confusion matrix.
""" """
# Load validation data # Load validation data
X_val, y_true, class_names = load_validation_data() X_val, y_true, class_names = load_validation_data()
num_classes = len(class_names) num_classes = len(class_names)
in_channels = X_val.shape[1] in_channels = X_val.shape[1]
# Load model # Load model
model = build_model_from_ckpt(ckpt_path, in_channels=in_channels, num_classes=num_classes) model = build_model_from_ckpt(
ckpt_path, in_channels=in_channels, num_classes=num_classes
)
# Inference # Inference
y_pred = [] y_pred = []
@ -73,7 +72,7 @@ def evaluate_checkpoint(ckpt_path: str):
# Print classification report # Print classification report
print("\nClassification Report:") print("\nClassification Report:")
print(classification_report(y_true, y_pred, target_names=class_names)) print(classification_report(y_true, y_pred, target_names=class_names, zero_division=0))
# Plot confusion matrix # Plot confusion matrix
plot_confusion_matrix( plot_confusion_matrix(
@ -81,7 +80,7 @@ def evaluate_checkpoint(ckpt_path: str):
y_pred=np.array(y_pred), y_pred=np.array(y_pred),
classes=class_names, classes=class_names,
normalize=True, normalize=True,
title="Normalized Confusion Matrix" title="Normalized Confusion Matrix",
) )
plt.show() plt.show()

View File

@ -131,7 +131,7 @@ def train_model():
trainer = L.Trainer( trainer = L.Trainer(
max_epochs=epochs, max_epochs=epochs,
callbacks=[checkpoint_callback, CustomProgressBar()], callbacks=[checkpoint_callback, CustomProgressBar()],
accelerator="gpu", accelerator="cpu",
devices=1, devices=1,
benchmark=True, benchmark=True,
precision="16-mixed", precision="16-mixed",