fixed plot.py
Some checks failed
RIA Hub Workflow Demo / ria-demo (push) Failing after 2m25s

Liyu Xiao 2025-06-18 13:44:29 -04:00
parent 8ba21251bf
commit 91bab4acbd
7 changed files with 52 additions and 49 deletions

View File

@@ -15,7 +15,7 @@ def generate_modulated_signals(output_dir: str) -> None:
     adds AWGN noise, and saves the resulting samples as .npy files to the given output directory.
     The function uses modulation parameters defined in app.yaml and supports modulation types like
     PSK and QAM through configurable constellation settings. The generated recordings are tagged
     with metadata such as modulation type, SNR, roll-off factor (beta), and samples-per-symbol (sps).

     Parameters:
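As a rough illustration of what this docstring describes — PSK symbols at a given samples-per-symbol, with AWGN added at a target SNR — here is a minimal numpy sketch. The constellation, sps, and SNR values are assumptions rather than values from app.yaml, and the real code would apply an RRC pulse with roll-off beta instead of the rectangular pulse used here:

    import numpy as np

    rng = np.random.default_rng(0)
    sps, num_symbols, snr_db = 8, 1024, 10  # assumed values, not from app.yaml
    symbols = np.exp(1j * (np.pi / 4 + (np.pi / 2) * rng.integers(0, 4, num_symbols)))  # QPSK
    iq = np.repeat(symbols, sps)  # rectangular pulse; real code would use RRC(beta)
    noise_power = 10 ** (-snr_db / 10)  # unit-power signal assumed
    noise = np.sqrt(noise_power / 2) * (rng.standard_normal(iq.size) + 1j * rng.standard_normal(iq.size))
    samples = np.stack([(iq + noise).real, (iq + noise).imag])  # (2, N) array saved as .npy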

View File

@@ -144,7 +144,9 @@ def main():
     print("📦 Generating training and validation datasets...")
     print(f" ➤ Slicing each recording into {dataset_cfg.num_slices} snippets")
-    print(f" ➤ Train/Val split: {int(dataset_cfg.train_split * 100)}% / {int((1 - dataset_cfg.train_split) * 100)}%")
+    print(
+        f" ➤ Train/Val split: {int(dataset_cfg.train_split * 100)}% / {int((1 - dataset_cfg.train_split) * 100)}%"
+    )
     print(f" ➤ Output directory: data/dataset\n")

     train_path, val_path = generate_datasets(dataset_cfg)
@@ -160,6 +162,5 @@ def main():
     print(f" 🔸 Validation samples saved to: {val_path} ({num_val} samples)")

 if __name__ == "__main__":
     main()

View File

@@ -3,12 +3,15 @@ from collections import defaultdict
 from typing import List, Tuple, Dict

 import numpy as np

 def split(
     dataset: List[Tuple[np.ndarray, Dict[str, any]]],
     train_frac: float,
     seed: int,
-    label_key: str = "modulation"
-) -> Tuple[List[Tuple[np.ndarray, Dict[str, any]]], List[Tuple[np.ndarray, Dict[str, any]]]]:
+    label_key: str = "modulation",
+) -> Tuple[
+    List[Tuple[np.ndarray, Dict[str, any]]], List[Tuple[np.ndarray, Dict[str, any]]]
+]:
     """
     Splits a dataset of modulated IQ signal recordings into training and validation subsets.
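For readers skimming the diff, here is a minimal sketch of the stratified split the docstring describes: group recordings by label_key, then split each group by train_frac. This is an illustration, not the repo's split():

    from collections import defaultdict
    import random

    def stratified_split(dataset, train_frac, seed, label_key="modulation"):
        # Group (data, metadata) pairs by modulation label so each class
        # appears in both subsets at roughly the same ratio.
        groups = defaultdict(list)
        for data, meta in dataset:
            groups[meta[label_key]].append((data, meta))
        rng = random.Random(seed)
        train, val = [], []
        for items in groups.values():
            rng.shuffle(items)
            cut = int(len(items) * train_frac)
            train.extend(items[:cut])
            val.extend(items[cut:])
        return train, val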
@@ -61,9 +64,8 @@ def split(
 def split_recording(
-    recording_list: List[Tuple[np.ndarray, Dict[str, any]]],
-    num_snippets: int
+    recording_list: List[Tuple[np.ndarray, Dict[str, any]]], num_snippets: int
 ) -> List[Tuple[np.ndarray, Dict[str, any]]]:
     """
     Splits each full recording into a specified number of smaller snippets.
@@ -75,13 +77,13 @@ def split_recording(
     array into `num_snippets` contiguous chunks of shape (2, N // num_snippets).

     Parameters:
         recording_list (List[Tuple[np.ndarray, dict]]):
             List of (data, metadata) tuples to be split.
         num_snippets (int):
             Number of equal-length segments to divide each recording into.

     Returns:
         List[Tuple[np.ndarray, dict]]:
             A flat list containing all resulting (snippet, metadata) pairs.
     """
     snippet_list = []
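The chunking rule in this docstring can be shown in a short sketch (illustrative names, not the repo's code): a (2, N) recording becomes num_snippets contiguous (2, N // num_snippets) slices.

    import numpy as np

    def chunk_recording(data: np.ndarray, num_snippets: int):
        n = data.shape[1] // num_snippets  # any remainder samples are dropped
        return [data[:, i * n:(i + 1) * n] for i in range(num_snippets)]

    snippets = chunk_recording(np.zeros((2, 4096)), 8)  # eight (2, 512) chunks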

View File

@@ -7,11 +7,7 @@ from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier
 from helpers.app_settings import get_app_settings

-def convert_to_onnx(
-    ckpt_path: str,
-    fp16: bool=False
-) -> None :
+def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
     """
     Convert a PyTorch model to ONNX format.
@@ -35,11 +31,9 @@ def convert_to_onnx(
             in_chans=in_channels,
         )
     )
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    checkpoint = torch.load(
-        ckpt_path, weights_only=True, map_location=device
-    )
+    checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
     model.load_state_dict(checkpoint["state_dict"])

     if fp16:
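The hunk ends just before the export itself; a hedged sketch of the usual torch.onnx.export call follows. The dummy input shape (1, 2, 1024), output path, and dynamic axes are assumptions, not taken from this repo:

    import torch

    model.eval()
    dummy = torch.randn(1, 2, 1024, device=device)  # assumed (batch, I/Q, samples) shape
    torch.onnx.export(
        model,
        dummy,
        "model.onnx",  # placeholder output path
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes={"input": {0: "batch"}},  # let batch size vary at runtime
    )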

View File

@@ -5,7 +5,10 @@ import os
 import time
 import json

-def profile_onnx_model(path_to_onnx: str, num_runs: int = 100, warmup_runs: int = 5) -> None:
+def profile_onnx_model(
+    path_to_onnx: str, num_runs: int = 100, warmup_runs: int = 5
+) -> None:
     """
     Profiles an ONNX model by running inference multiple times and collecting performance data.
@@ -58,7 +61,9 @@ def profile_onnx_model(path_to_onnx: str, num_runs: int = 100, warmup_runs: int
         times.append(t1 - t0)

     avg_time = sum(times) / len(times)
-    print(f"[Timing] Avg inference time (excluding {warmup_runs} warm-ups): {avg_time:.6f} sec")
+    print(
+        f"[Timing] Avg inference time (excluding {warmup_runs} warm-ups): {avg_time:.6f} sec"
+    )

     # End profiling & parse JSON
     profile_file = session.end_profiling()
@@ -71,7 +76,9 @@ def profile_onnx_model(path_to_onnx: str, num_runs: int = 100, warmup_runs: int
         print(f"[Profile] Number of nodes executed: {len(nodes)}")
         if nodes:
             top = max(nodes, key=lambda x: x.get("dur", 0))
-            print(f"[Profile] Most expensive op: {top['name']}{top['dur'] / 1e6:.3f} ms")
+            print(
+                f"[Profile] Most expensive op: {top['name']}{top['dur'] / 1e6:.3f} ms"
+            )
     except Exception as e:
         print(f"[Warning] Failed to parse profiling JSON: {e}")

View File

@@ -1,35 +1,33 @@
 import os

 import torch
 import numpy as np
-import h5py
 from sklearn.metrics import classification_report
-import matplotlib
 from matplotlib import pyplot as plt

 from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier
 from helpers.app_settings import get_app_settings
 from cm_plotter import plot_confusion_matrix
+from scripts.training.modulation_dataset import ModulationH5Dataset

-def load_validation_data(h5_path:str ="data/datasets/val.h5"):
-    """
-    Loads validation data from an HDF5 file.
-
-    Returns:
-        X_val: np.ndarray of shape (N, C, L)
-        y_val: np.ndarray of shape (N,)
-        class_names: list of class names
-    """
-    with h5py.File(h5_path, "r") as f:
-        X = f["X"][:]  # shape: (N, C, L)
-        y = f["y"][:]  # shape: (N,)
-        if "class_names" in f:
-            class_names = [s.decode("utf-8") for s in f["class_names"][:]]
-        else:
-            class_names = [str(i) for i in np.unique(y)]
+def load_validation_data():
+    val_dataset = ModulationH5Dataset(
+        "data/dataset/val.h5",
+        label_name="modulation",
+        data_key="validation_data"
+    )
+
+    X = np.stack([x.numpy() for x, _ in val_dataset])  # shape: (N, C, L)
+    y = np.array([y.item() for _, y in val_dataset])  # shape: (N,)
+    class_names = list(val_dataset.label_encoder.classes_)

     return X, y, class_names

-def build_model_from_ckpt(ckpt_path: str, in_channels: int, num_classes: int) -> torch.nn.Module:
+def build_model_from_ckpt(
+    ckpt_path: str, in_channels: int, num_classes: int
+) -> torch.nn.Module:
     """
     Build and return a PyTorch model loaded from a checkpoint.
     """
@@ -37,13 +35,11 @@ def build_model_from_ckpt(ckpt_path: str, in_channels: int, num_classes: int) ->
         model=mobilenetv3(
             model_size="mobilenetv3_small_050",
             num_classes=num_classes,
-            in_chans=in_channels
+            in_chans=in_channels,
         )
     )
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    checkpoint = torch.load(
-        ckpt_path, weights_only=True, map_location=device
-    )
+    checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
     model.load_state_dict(checkpoint["state_dict"])
     model.eval()
     return model
@@ -54,13 +50,16 @@ def evaluate_checkpoint(ckpt_path: str):
     Loads the model from checkpoint and evaluates it on a validation set.
     Prints classification metrics and plots a confusion matrix.
     """
     # Load validation data
     X_val, y_true, class_names = load_validation_data()
     num_classes = len(class_names)
     in_channels = X_val.shape[1]

     # Load model
-    model = build_model_from_ckpt(ckpt_path, in_channels=in_channels, num_classes=num_classes)
+    model = build_model_from_ckpt(
+        ckpt_path, in_channels=in_channels, num_classes=num_classes
+    )

     # Inference
     y_pred = []
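The loop that fills y_pred is elided by the diff; a plausible sketch is shown below. The batch size and argmax decoding are assumptions about code this hunk does not show:

    import torch

    with torch.no_grad():
        for i in range(0, len(X_val), 64):  # batch size is an assumption
            batch = torch.from_numpy(X_val[i:i + 64]).float()
            logits = model(batch)
            y_pred.extend(logits.argmax(dim=1).cpu().numpy().tolist())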
@@ -73,7 +72,7 @@ def evaluate_checkpoint(ckpt_path: str):
     # Print classification report
     print("\nClassification Report:")
-    print(classification_report(y_true, y_pred, target_names=class_names))
+    print(classification_report(y_true, y_pred, target_names=class_names, zero_division=0))

     # Plot confusion matrix
     plot_confusion_matrix(
@@ -81,7 +80,7 @@ def evaluate_checkpoint(ckpt_path: str):
         y_pred=np.array(y_pred),
         classes=class_names,
         normalize=True,
-        title="Normalized Confusion Matrix"
+        title="Normalized Confusion Matrix",
     )
     plt.show()

View File

@@ -131,7 +131,7 @@ def train_model():
     trainer = L.Trainer(
         max_epochs=epochs,
         callbacks=[checkpoint_callback, CustomProgressBar()],
-        accelerator="gpu",
+        accelerator="cpu",
         devices=1,
         benchmark=True,
         precision="16-mixed",
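Note that this change pins accelerator="cpu" while keeping precision="16-mixed"; fp16 autocast is primarily a GPU feature, so a device-aware configuration may be safer. A hedged sketch reusing the names from this hunk:

    import torch
    import lightning as L

    use_gpu = torch.cuda.is_available()
    trainer = L.Trainer(
        max_epochs=epochs,
        callbacks=[checkpoint_callback, CustomProgressBar()],
        accelerator="gpu" if use_gpu else "cpu",  # fall back to CPU when no GPU is present
        devices=1,
        benchmark=True,
        precision="16-mixed" if use_gpu else "32-true",  # fp16 autocast assumed GPU-only here
    )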