Compare commits


No commits in common. "main" and "v0.1.0" have entirely different histories.
main ... v0.1.0

11 changed files with 67 additions and 81 deletions

View File

@@ -2,9 +2,11 @@ name: Modulation Recognition Demo
 on:
   push:
-    branches: [main]
+    branches:
+      [main]
   pull_request:
-    branches: [main]
+    branches:
+      [main]
 jobs:
   ria-demo:
@@ -44,24 +46,22 @@ jobs:
           fi
           pip install -r requirements.txt
       - name: 1. Generate Recordings
        run: |
          mkdir -p data/recordings
          PYTHONPATH=. python scripts/dataset_manager/data_gen.py --output-dir data/recordings
-      - name: 📦 Compress Recordings
-        run: tar -czf recordings.tar.gz -C data/recordings .
       - name: ⬆️ Upload recordings
         uses: actions/upload-artifact@v3
         with:
           name: recordings
-          path: recordings.tar.gz
+          path: data/recordings/**
       - name: 2. Build HDF5 Dataset
         run: |
           mkdir -p data/dataset
           PYTHONPATH=. python scripts/dataset_manager/produce_dataset.py
         shell: bash
       - name: ⬆️ Upload Dataset
@@ -72,16 +72,16 @@ jobs:
       - name: 3. Train Model
         env:
           NO_NNPACK: 1
           PYTORCH_NO_NNPACK: 1
         run: |
           mkdir -p checkpoint_files
           PYTHONPATH=. python scripts/model_builder/train.py 2>/dev/null
       - name: 4. Plot Model
         env:
           NO_NNPACK: 1
           PYTORCH_NO_NNPACK: 1
         run: |
           PYTHONPATH=. python scripts/model_builder/plot_data.py 2>/dev/null
@@ -98,7 +98,7 @@ jobs:
         run: |
           mkdir -p onnx_files
           MKL_DISABLE_FAST_MM=1 PYTHONPATH=. python scripts/application_packager/convert_to_onnx.py 2>/dev/null
       - name: ⬆️ Upload ONNX file
         uses: actions/upload-artifact@v3
         with:
@@ -108,13 +108,13 @@ jobs:
      - name: 6. Profile ONNX model
        run: |
          PYTHONPATH=. python scripts/application_packager/profile_onnx.py
      - name: ⬆️ Upload JSON trace
        uses: actions/upload-artifact@v3
        with:
          name: profile-data
-          path: "**/onnxruntime_profile_*.json"
+          path: '**/onnxruntime_profile_*.json'
      - name: 7. Convert ONNX graph to an ORT file
        run: |
          PYTHONPATH=. python scripts/application_packager/convert_to_ort.py

View File

@@ -24,7 +24,7 @@ dataset:
   snr_step: 3
   # Number of iterations (signal recordings) per modulation and SNR combination
-  num_iterations: 10
+  num_iterations: 3
   # Modulation scheme settings; keys must match the `modulation_types` list above
   # Each entry includes:
@@ -50,7 +50,7 @@ dataset:
   # Training and validation split ratios; must sum to 1
   train_split: 0.8
-  val_split: 0.2
+  val_split : 0.2
 training:
   # Number of training examples processed together before the model updates its weights

View File

View File

@@ -41,11 +41,7 @@ class AppConfig:
 class AppSettings:
-    """
-    Application settings,
-    to be initialized from
-    app.yaml configuration file.
-    """
+    """Application settings, to be initialized from app.yaml configuration file."""
     def __init__(self, config_file: str):
         # Load the YAML configuration file

View File

@@ -2,9 +2,9 @@ import os
 import numpy as np
 import torch
-from scripts.model_builder.mobilenetv3 import RFClassifier, mobilenetv3
 from helpers.app_settings import get_app_settings
+from scripts.model_builder.mobilenetv3 import RFClassifier, mobilenetv3
 def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
@@ -21,7 +21,7 @@ def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
     in_channels = 2
     batch_size = 1
-    slice_length = int(dataset_cfg.recording_length / dataset_cfg.num_slices)
+    slice_length = int(1024 / dataset_cfg.num_slices)
     num_classes = len(dataset_cfg.modulation_types)
     model = RFClassifier(
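(Aside: the shape arithmetic behind the export input, as a standalone sketch. in_channels, batch_size, and the hard-coded 1024 appear in the hunk above; num_slices = 8 is an assumed config value, and the final input layout is presumed, not shown in the diff.)

# Sketch of the dummy-input sizing above; num_slices is a hypothetical value.
recording_length = 1024                            # hard-coded on the new side
num_slices = 8                                     # assumed config value
slice_length = int(recording_length / num_slices)  # -> 128
# The ONNX export sample is then presumably (batch_size, in_channels,
# slice_length), i.e. (1, 2, 128) for these example numbers.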
@@ -42,7 +42,7 @@ def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
     model.eval()
     # Generate random sample data
-    base, _ = os.path.splitext(os.path.basename(ckpt_path))
+    base, ext = os.path.splitext(os.path.basename(ckpt_path))
     if fp16:
         output_path = os.path.join("onnx_files", f"{base}.onnx")
         sample_input = torch.from_numpy(

View File

@@ -29,7 +29,7 @@ def generate_modulated_signals(output_dir: str) -> None:
     for modulation in settings.modulation_types:
         for snr in np.arange(settings.snr_start, settings.snr_stop, settings.snr_step):
-            for _ in range(settings.num_iterations):
+            for i in range(3):
                 recording_length = settings.recording_length
                 beta = (
                     settings.beta
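(Aside: np.arange is half-open, which matters when choosing snr_stop. A minimal sketch, with illustrative start/stop values and the config's snr_step of 3:)

import numpy as np

# np.arange never generates the stop value itself: with illustrative bounds
# of 0 and 12 dB and a step of 3, the SNR sweep covers 0, 3, 6, and 9 dB only.
snrs = np.arange(0, 12, 3)
assert snrs.tolist() == [0, 3, 6, 9]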

View File

@@ -49,6 +49,8 @@ def write_hdf5_file(records: List, output_path: str, dataset_name: str = "data")
             int(md["sps"]),
         )
+    first_rec, _ = records[0]  # records[0] is a tuple of (data, md)
     with h5py.File(output_path, "w") as hf:
         data_arr = np.stack([rec[0] for rec in records])
         dset = hf.create_dataset(dataset_name, data=data_arr, compression="gzip")
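(Aside: a minimal sketch of reading such a file back. The file path is illustrative; "data" mirrors the default dataset_name above, and h5py decompresses gzip datasets transparently:)

import h5py

# Read back a dataset written by write_hdf5_file (path is a toy example).
with h5py.File("data/dataset/train.h5", "r") as hf:
    data = hf["data"][:]  # one stacked array, one row per recording
    print(data.shape)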

View File

@@ -90,7 +90,7 @@ def split_recording(
     snippet_list = []
     for data, md in recording_list:
-        _, N = data.shape
+        C, N = data.shape
         L = N // num_snippets
         for i in range(num_snippets):
             start = i * L
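(Aside: a toy run of the slicing logic above; all shapes here are illustrative:)

import numpy as np

# One 2-channel recording of 1024 samples cut into 4 equal, non-overlapping
# snippets of N // num_snippets = 256 samples each.
data = np.zeros((2, 1024))
num_snippets = 4
C, N = data.shape
L = N // num_snippets
snippets = [data[:, i * L:(i + 1) * L] for i in range(num_snippets)]
assert all(s.shape == (2, 256) for s in snippets)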

View File

@@ -1,4 +1,5 @@
 import lightning as L
+import numpy as np
 import timm
 import torch
 from torch import nn

View File

@@ -2,26 +2,26 @@ import os
 import numpy as np
 import torch
-from sklearn.metrics import classification_report
-os.environ["NNPACK"] = "0"
 from matplotlib import pyplot as plt
 from mobilenetv3 import RFClassifier, mobilenetv3
 from modulation_dataset import ModulationH5Dataset
+from sklearn.metrics import classification_report
 from helpers.app_settings import get_app_settings
+os.environ["NNPACK"] = "0"
 def load_validation_data():
     val_dataset = ModulationH5Dataset(
         "data/dataset/val.h5", label_name="modulation", data_key="validation_data"
     )
-    x = np.stack([x.numpy() for x, _ in val_dataset])  # shape: (N, C, L)
+    X = np.stack([x.numpy() for x, _ in val_dataset])  # shape: (N, C, L)
     y = np.array([y.item() for _, y in val_dataset])  # shape: (N,)
     class_names = list(val_dataset.label_encoder.classes_)
-    return x, y, class_names
+    return X, y, class_names
 def build_model_from_ckpt(
@@ -75,7 +75,7 @@ def evaluate_checkpoint(ckpt_path: str):
         classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
     )
-    print_confusion_matrix(
+    plot_confusion_matrix_with_counts(
         y_true=np.array(y_true),
         y_pred=np.array(y_pred),
         classes=class_names,
@@ -84,7 +84,7 @@ def evaluate_checkpoint(ckpt_path: str):
     )
-def print_confusion_matrix(
+def plot_confusion_matrix_with_counts(
     y_true: np.ndarray,
     y_pred: np.ndarray,
     classes: list[str],
@ -94,7 +94,6 @@ def print_confusion_matrix(
""" """
Plot a confusion matrix showing both raw counts and (optionally) normalized values. Plot a confusion matrix showing both raw counts and (optionally) normalized values.
Args: Args:
y_true: true labels (integers 0..C-1) y_true: true labels (integers 0..C-1)
y_pred: predicted labels (same shape as y_true) y_pred: predicted labels (same shape as y_true)
@ -103,8 +102,8 @@ def print_confusion_matrix(
title: title for the plot title: title for the plot
""" """
# 1) build raw CM # 1) build raw CM
c = len(classes) C = len(classes)
cm = np.zeros((c, c), dtype=int) cm = np.zeros((C, C), dtype=int)
for t, p in zip(y_true, y_pred): for t, p in zip(y_true, y_pred):
cm[t, p] += 1 cm[t, p] += 1
@ -113,48 +112,33 @@ def print_confusion_matrix(
with np.errstate(divide="ignore", invalid="ignore"): with np.errstate(divide="ignore", invalid="ignore"):
cm_norm = cm.astype(float) / cm.sum(axis=1)[:, None] cm_norm = cm.astype(float) / cm.sum(axis=1)[:, None]
cm_norm = np.nan_to_num(cm_norm) cm_norm = np.nan_to_num(cm_norm)
print_confusion_matrix_helper(cm_norm, classes)
else: else:
print_confusion_matrix_helper(cm, classes) cm_norm = cm
# 3) plot
fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(cm_norm, interpolation="nearest")
ax.set_title(title)
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
ax.set_xticks(np.arange(C))
ax.set_yticks(np.arange(C))
ax.set_xticklabels(classes, rotation=45, ha="right")
ax.set_yticklabels(classes)
import numpy as np # 4) annotate
for i in range(C):
for j in range(C):
count = cm[i, j]
val = cm_norm[i, j]
txt = f"{count}\n{val:.2f}"
ax.text(j, i, txt, ha="center", va="center")
fig.colorbar(im, ax=ax, label="Normalized value" if normalize else "Count")
def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=2): plt.tight_layout()
""" plt.show()
Pretty prints a confusion matrix with x/y labels.
Parameters:
- matrix: square 2D numpy array
- labels: list of class labels (default: range(num_classes))
- normalize: whether to normalize rows to sum to 1
- digits: number of decimal places to show for normalized values
"""
matrix = np.array(matrix)
num_classes = matrix.shape[0]
labels = classes or list(range(num_classes))
# Header
print(" " * 9 + "Ground Truth →")
header = "Pred ↓ | " + " ".join([f"{str(label):>6}" for label in labels])
print(header)
print("-" * len(header))
# Rows
for i in range(num_classes):
row_vals = matrix[i]
if normalize:
row_sum = row_vals.sum()
row_vals = row_vals / row_sum if row_sum != 0 else row_vals
row_str = " ".join([f"{val:>6.{digits}f}" for val in row_vals])
else:
row_str = " ".join([f"{int(val):>6}" for val in row_vals])
print(f"{str(labels[i]):>7} | {row_str}")
if __name__ == "__main__": if __name__ == "__main__":
settings = get_app_settings() settings = get_app_settings()
evaluate_checkpoint( evaluate_checkpoint(os.path.join("checkpoint_files", "inference_recognition_model.ckpt"))
os.path.join("checkpoint_files", "inference_recognition_model.ckpt")
)
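(Aside: a toy direct call of the renamed helper. The class names and label arrays are illustrative, and the normalize/title keywords are inferred from the docstring and body above rather than shown in full by the diff:)

import numpy as np

# Hypothetical usage of plot_confusion_matrix_with_counts with 3 toy classes.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])
plot_confusion_matrix_with_counts(
    y_true=y_true,
    y_pred=y_pred,
    classes=["bpsk", "qpsk", "qam16"],  # toy class names
    normalize=True,                      # keyword inferred from the docstring
    title="Toy confusion matrix",
)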

View File

@@ -1,22 +1,23 @@
 import os
 import sys
-os.environ["NNPACK"] = "0"
 import lightning as L
 import mobilenetv3
 import torch
 import torch.nn.functional as F
 import torchmetrics
-from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
+from lightning.pytorch.callbacks import ModelCheckpoint
 from modulation_dataset import ModulationH5Dataset
 from helpers.app_settings import get_app_settings
+os.environ["NNPACK"] = "0"
 script_dir = os.path.dirname(os.path.abspath(__file__))
 data_dir = os.path.abspath(os.path.join(script_dir, ".."))
 project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
 if project_root not in sys.path:
     sys.path.insert(0, project_root)
+from lightning.pytorch.callbacks import TQDMProgressBar
 class CustomProgressBar(TQDMProgressBar):
@@ -58,6 +59,8 @@ def train_model():
         print("X shape:", x.shape)
         print("Y values:", y[:10])
         break
+    unique_labels = list(set([row[label].decode("utf-8") for row in ds_train.metadata]))
     num_classes = len(ds_train.label_encoder.classes_)
     hparams = {