This commit is contained in:
parent
8ba21251bf
commit
91bab4acbd
|
@ -144,7 +144,9 @@ def main():
|
||||||
|
|
||||||
print("📦 Generating training and validation datasets...")
|
print("📦 Generating training and validation datasets...")
|
||||||
print(f" ➤ Slicing each recording into {dataset_cfg.num_slices} snippets")
|
print(f" ➤ Slicing each recording into {dataset_cfg.num_slices} snippets")
|
||||||
print(f" ➤ Train/Val split: {int(dataset_cfg.train_split * 100)}% / {int((1 - dataset_cfg.train_split) * 100)}%")
|
print(
|
||||||
|
f" ➤ Train/Val split: {int(dataset_cfg.train_split * 100)}% / {int((1 - dataset_cfg.train_split) * 100)}%"
|
||||||
|
)
|
||||||
print(f" ➤ Output directory: data/dataset\n")
|
print(f" ➤ Output directory: data/dataset\n")
|
||||||
|
|
||||||
train_path, val_path = generate_datasets(dataset_cfg)
|
train_path, val_path = generate_datasets(dataset_cfg)
|
||||||
|
@ -160,6 +162,5 @@ def main():
|
||||||
print(f" 🔸 Validation samples saved to: {val_path} ({num_val} samples)")
|
print(f" 🔸 Validation samples saved to: {val_path} ({num_val} samples)")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -3,12 +3,15 @@ from collections import defaultdict
|
||||||
from typing import List, Tuple, Dict
|
from typing import List, Tuple, Dict
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def split(
|
def split(
|
||||||
dataset: List[Tuple[np.ndarray, Dict[str, any]]],
|
dataset: List[Tuple[np.ndarray, Dict[str, any]]],
|
||||||
train_frac: float,
|
train_frac: float,
|
||||||
seed: int,
|
seed: int,
|
||||||
label_key: str = "modulation"
|
label_key: str = "modulation",
|
||||||
) -> Tuple[List[Tuple[np.ndarray, Dict[str, any]]], List[Tuple[np.ndarray, Dict[str, any]]]]:
|
) -> Tuple[
|
||||||
|
List[Tuple[np.ndarray, Dict[str, any]]], List[Tuple[np.ndarray, Dict[str, any]]]
|
||||||
|
]:
|
||||||
"""
|
"""
|
||||||
Splits a dataset of modulated IQ signal recordings into training and validation subsets.
|
Splits a dataset of modulated IQ signal recordings into training and validation subsets.
|
||||||
|
|
||||||
|
@ -61,8 +64,7 @@ def split(
|
||||||
|
|
||||||
|
|
||||||
def split_recording(
|
def split_recording(
|
||||||
recording_list: List[Tuple[np.ndarray, Dict[str, any]]],
|
recording_list: List[Tuple[np.ndarray, Dict[str, any]]], num_snippets: int
|
||||||
num_snippets: int
|
|
||||||
) -> List[Tuple[np.ndarray, Dict[str, any]]]:
|
) -> List[Tuple[np.ndarray, Dict[str, any]]]:
|
||||||
"""
|
"""
|
||||||
Splits each full recording into a specified number of smaller snippets.
|
Splits each full recording into a specified number of smaller snippets.
|
||||||
|
|
|
@ -7,11 +7,7 @@ from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier
|
||||||
from helpers.app_settings import get_app_settings
|
from helpers.app_settings import get_app_settings
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
|
||||||
def convert_to_onnx(
|
|
||||||
ckpt_path: str,
|
|
||||||
fp16: bool=False
|
|
||||||
) -> None :
|
|
||||||
"""
|
"""
|
||||||
Convert a PyTorch model to ONNX format.
|
Convert a PyTorch model to ONNX format.
|
||||||
|
|
||||||
|
@ -37,9 +33,7 @@ def convert_to_onnx(
|
||||||
)
|
)
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
checkpoint = torch.load(
|
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
|
||||||
ckpt_path, weights_only=True, map_location=device
|
|
||||||
)
|
|
||||||
model.load_state_dict(checkpoint["state_dict"])
|
model.load_state_dict(checkpoint["state_dict"])
|
||||||
|
|
||||||
if fp16:
|
if fp16:
|
||||||
|
|
|
@ -5,7 +5,10 @@ import os
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
|
||||||
def profile_onnx_model(path_to_onnx: str, num_runs: int = 100, warmup_runs: int = 5) -> None:
|
|
||||||
|
def profile_onnx_model(
|
||||||
|
path_to_onnx: str, num_runs: int = 100, warmup_runs: int = 5
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Profiles an ONNX model by running inference multiple times and collecting performance data.
|
Profiles an ONNX model by running inference multiple times and collecting performance data.
|
||||||
|
|
||||||
|
@ -58,7 +61,9 @@ def profile_onnx_model(path_to_onnx: str, num_runs: int = 100, warmup_runs: int
|
||||||
times.append(t1 - t0)
|
times.append(t1 - t0)
|
||||||
|
|
||||||
avg_time = sum(times) / len(times)
|
avg_time = sum(times) / len(times)
|
||||||
print(f"[Timing] Avg inference time (excluding {warmup_runs} warm-ups): {avg_time:.6f} sec")
|
print(
|
||||||
|
f"[Timing] Avg inference time (excluding {warmup_runs} warm-ups): {avg_time:.6f} sec"
|
||||||
|
)
|
||||||
|
|
||||||
# End profiling & parse JSON
|
# End profiling & parse JSON
|
||||||
profile_file = session.end_profiling()
|
profile_file = session.end_profiling()
|
||||||
|
@ -71,7 +76,9 @@ def profile_onnx_model(path_to_onnx: str, num_runs: int = 100, warmup_runs: int
|
||||||
print(f"[Profile] Number of nodes executed: {len(nodes)}")
|
print(f"[Profile] Number of nodes executed: {len(nodes)}")
|
||||||
if nodes:
|
if nodes:
|
||||||
top = max(nodes, key=lambda x: x.get("dur", 0))
|
top = max(nodes, key=lambda x: x.get("dur", 0))
|
||||||
print(f"[Profile] Most expensive op: {top['name']} — {top['dur'] / 1e6:.3f} ms")
|
print(
|
||||||
|
f"[Profile] Most expensive op: {top['name']} — {top['dur'] / 1e6:.3f} ms"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[Warning] Failed to parse profiling JSON: {e}")
|
print(f"[Warning] Failed to parse profiling JSON: {e}")
|
||||||
|
|
||||||
|
|
|
@ -1,35 +1,33 @@
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import h5py
|
|
||||||
from sklearn.metrics import classification_report
|
from sklearn.metrics import classification_report
|
||||||
|
import matplotlib
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier
|
from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier
|
||||||
from helpers.app_settings import get_app_settings
|
from helpers.app_settings import get_app_settings
|
||||||
from cm_plotter import plot_confusion_matrix
|
from cm_plotter import plot_confusion_matrix
|
||||||
|
from scripts.training.modulation_dataset import ModulationH5Dataset
|
||||||
|
|
||||||
|
|
||||||
def load_validation_data(h5_path:str ="data/datasets/val.h5"):
|
def load_validation_data():
|
||||||
"""
|
val_dataset = ModulationH5Dataset(
|
||||||
Loads validation data from an HDF5 file.
|
"data/dataset/val.h5",
|
||||||
|
label_name="modulation",
|
||||||
|
data_key="validation_data"
|
||||||
|
)
|
||||||
|
|
||||||
|
X = np.stack([x.numpy() for x, _ in val_dataset]) # shape: (N, C, L)
|
||||||
|
y = np.array([y.item() for _, y in val_dataset]) # shape: (N,)
|
||||||
|
class_names = list(val_dataset.label_encoder.classes_)
|
||||||
|
|
||||||
Returns:
|
|
||||||
X_val: np.ndarray of shape (N, C, L)
|
|
||||||
y_val: np.ndarray of shape (N,)
|
|
||||||
class_names: list of class names
|
|
||||||
"""
|
|
||||||
with h5py.File(h5_path, "r") as f:
|
|
||||||
X = f["X"][:] # shape: (N, C, L)
|
|
||||||
y = f["y"][:] # shape: (N,)
|
|
||||||
if "class_names" in f:
|
|
||||||
class_names = [s.decode("utf-8") for s in f["class_names"][:]]
|
|
||||||
else:
|
|
||||||
class_names = [str(i) for i in np.unique(y)]
|
|
||||||
return X, y, class_names
|
return X, y, class_names
|
||||||
|
|
||||||
|
|
||||||
def build_model_from_ckpt(ckpt_path: str, in_channels: int, num_classes: int) -> torch.nn.Module:
|
def build_model_from_ckpt(
|
||||||
|
ckpt_path: str, in_channels: int, num_classes: int
|
||||||
|
) -> torch.nn.Module:
|
||||||
"""
|
"""
|
||||||
Build and return a PyTorch model loaded from a checkpoint.
|
Build and return a PyTorch model loaded from a checkpoint.
|
||||||
"""
|
"""
|
||||||
|
@ -37,13 +35,11 @@ def build_model_from_ckpt(ckpt_path: str, in_channels: int, num_classes: int) ->
|
||||||
model=mobilenetv3(
|
model=mobilenetv3(
|
||||||
model_size="mobilenetv3_small_050",
|
model_size="mobilenetv3_small_050",
|
||||||
num_classes=num_classes,
|
num_classes=num_classes,
|
||||||
in_chans=in_channels
|
in_chans=in_channels,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
checkpoint = torch.load(
|
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
|
||||||
ckpt_path, weights_only=True, map_location=device
|
|
||||||
)
|
|
||||||
model.load_state_dict(checkpoint["state_dict"])
|
model.load_state_dict(checkpoint["state_dict"])
|
||||||
model.eval()
|
model.eval()
|
||||||
return model
|
return model
|
||||||
|
@ -54,13 +50,16 @@ def evaluate_checkpoint(ckpt_path: str):
|
||||||
Loads the model from checkpoint and evaluates it on a validation set.
|
Loads the model from checkpoint and evaluates it on a validation set.
|
||||||
Prints classification metrics and plots a confusion matrix.
|
Prints classification metrics and plots a confusion matrix.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Load validation data
|
# Load validation data
|
||||||
X_val, y_true, class_names = load_validation_data()
|
X_val, y_true, class_names = load_validation_data()
|
||||||
num_classes = len(class_names)
|
num_classes = len(class_names)
|
||||||
in_channels = X_val.shape[1]
|
in_channels = X_val.shape[1]
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model = build_model_from_ckpt(ckpt_path, in_channels=in_channels, num_classes=num_classes)
|
model = build_model_from_ckpt(
|
||||||
|
ckpt_path, in_channels=in_channels, num_classes=num_classes
|
||||||
|
)
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
y_pred = []
|
y_pred = []
|
||||||
|
@ -73,7 +72,7 @@ def evaluate_checkpoint(ckpt_path: str):
|
||||||
|
|
||||||
# Print classification report
|
# Print classification report
|
||||||
print("\nClassification Report:")
|
print("\nClassification Report:")
|
||||||
print(classification_report(y_true, y_pred, target_names=class_names))
|
print(classification_report(y_true, y_pred, target_names=class_names, zero_division=0))
|
||||||
|
|
||||||
# Plot confusion matrix
|
# Plot confusion matrix
|
||||||
plot_confusion_matrix(
|
plot_confusion_matrix(
|
||||||
|
@ -81,7 +80,7 @@ def evaluate_checkpoint(ckpt_path: str):
|
||||||
y_pred=np.array(y_pred),
|
y_pred=np.array(y_pred),
|
||||||
classes=class_names,
|
classes=class_names,
|
||||||
normalize=True,
|
normalize=True,
|
||||||
title="Normalized Confusion Matrix"
|
title="Normalized Confusion Matrix",
|
||||||
)
|
)
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
|
@ -131,7 +131,7 @@ def train_model():
|
||||||
trainer = L.Trainer(
|
trainer = L.Trainer(
|
||||||
max_epochs=epochs,
|
max_epochs=epochs,
|
||||||
callbacks=[checkpoint_callback, CustomProgressBar()],
|
callbacks=[checkpoint_callback, CustomProgressBar()],
|
||||||
accelerator="gpu",
|
accelerator="cpu",
|
||||||
devices=1,
|
devices=1,
|
||||||
benchmark=True,
|
benchmark=True,
|
||||||
precision="16-mixed",
|
precision="16-mixed",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user