liyu-dev #3

@@ -2,11 +2,9 @@ name: Modulation Recognition Demo
 on:
   push:
-    branches:
-      [main]
+    branches: [main]
   pull_request:
-    branches:
-      [main]
+    branches: [main]

 jobs:
   ria-demo:

@@ -46,17 +44,19 @@ jobs:
           fi
           pip install -r requirements.txt

-
       - name: 1. Generate Recordings
         run: |
           mkdir -p data/recordings
           PYTHONPATH=. python scripts/dataset_manager/data_gen.py --output-dir data/recordings

+      - name: 📦 Compress Recordings
+        run: tar -czf recordings.tar.gz -C data/recordings .
+
       - name: ⬆️ Upload recordings
         uses: actions/upload-artifact@v3
         with:
           name: recordings
-          path: data/recordings/**
+          path: recordings.tar.gz

       - name: 2. Build HDF5 Dataset
         run: |
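
The upload step above now publishes a single recordings.tar.gz instead of the raw data/recordings/** tree. As an illustrative aside (not part of this change), a downstream consumer of that artifact could restore the original directory layout roughly like this; the paths mirror the ones used in the workflow:

# Illustrative sketch: unpack the uploaded artifact back into data/recordings,
# matching the layout produced before compression.
import os
import tarfile

os.makedirs("data/recordings", exist_ok=True)
with tarfile.open("recordings.tar.gz", "r:gz") as tar:
    tar.extractall(path="data/recordings")  # equivalent of `tar -xzf recordings.tar.gz -C data/recordings`
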
@@ -113,7 +113,7 @@ jobs:
         uses: actions/upload-artifact@v3
         with:
           name: profile-data
-          path: '**/onnxruntime_profile_*.json'
+          path: "**/onnxruntime_profile_*.json"

       - name: 7. Convert ONNX graph to an ORT file
         run: |

@@ -24,7 +24,7 @@ dataset:
   snr_step: 3

   # Number of iterations (signal recordings) per modulation and SNR combination
-  num_iterations: 3
+  num_iterations: 100

   # Modulation scheme settings; keys must match the `modulation_types` list above
   # Each entry includes:

@@ -57,7 +57,7 @@ training:
   batch_size: 256

   # Number of complete passes through the training dataset during training
-  epochs: 5
+  epochs: 30

   # Learning rate: step size for weight updates after each batch
   # Recommended range for fine-tuning: 1e-6 to 1e-4
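
For context on the two settings above: one epoch is a full pass over the training set, and the learning rate is applied once per batch, so the total number of weight updates scales with both values. A rough, purely illustrative calculation (the training-set size is a placeholder; only batch_size: 256 and epochs: 30 come from this config):

# Illustrative sketch: number of weight updates implied by the training config.
import math

num_train_examples = 10_000   # placeholder; depends on the generated dataset
batch_size = 256              # from app.yaml
epochs = 30                   # value set in this PR

steps_per_epoch = math.ceil(num_train_examples / batch_size)   # 40 with these numbers
total_updates = epochs * steps_per_epoch                       # 1200 with these numbers
print(steps_per_epoch, total_updates)
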
@@ -41,7 +41,11 @@ class AppConfig:


 class AppSettings:
-    """Application settings, to be initialized from app.yaml configuration file."""
+    """
+    Application settings,
+    to be initialized from
+    app.yaml configuration file.
+    """

     def __init__(self, config_file: str):
         # Load the YAML configuration file
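
The __init__ comment above says the class loads the app.yaml configuration file. As an illustrative aside (not part of this change), that pattern usually looks roughly like the following; the class name, section names, and use of PyYAML are assumptions, not code from this repository:

# Illustrative sketch: backing a settings object with a YAML file via PyYAML.
import yaml

class AppSettingsSketch:
    """Hypothetical stand-in for AppSettings; section names are assumptions."""

    def __init__(self, config_file: str):
        # Load the YAML configuration file into a plain dict
        with open(config_file, "r") as f:
            self._cfg = yaml.safe_load(f)

    @property
    def dataset(self) -> dict:
        return self._cfg.get("dataset", {})

    @property
    def training(self) -> dict:
        return self._cfg.get("training", {})
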
@@ -2,9 +2,9 @@ import os

 import numpy as np
 import torch
-from scripts.model_builder.mobilenetv3 import RFClassifier, mobilenetv3

 from helpers.app_settings import get_app_settings
+from scripts.model_builder.mobilenetv3 import RFClassifier, mobilenetv3


 def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
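
The body of convert_to_onnx is not shown in this hunk; only its imports move. As an illustrative aside (not part of this change), an ONNX export of this kind typically looks like the sketch below. The stand-in model, the (1, 2, 1024) IQ input shape, and the output file name are placeholders, not values taken from this repository:

# Illustrative sketch: exporting a PyTorch model to ONNX with a dummy IQ-shaped input.
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv1d(2, 8, kernel_size=3),   # 2 input channels: I and Q (assumed)
    nn.AdaptiveAvgPool1d(1),
    nn.Flatten(),
    nn.Linear(8, 4),                  # 4 placeholder classes
)
model.eval()

dummy = torch.randn(1, 2, 1024)       # (batch, channels, samples) -- assumed shape
torch.onnx.export(
    model,
    dummy,
    "model.onnx",
    input_names=["input"],
    output_names=["logits"],
    opset_version=17,
)
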
@@ -29,7 +29,7 @@ def generate_modulated_signals(output_dir: str) -> None:

     for modulation in settings.modulation_types:
         for snr in np.arange(settings.snr_start, settings.snr_stop, settings.snr_step):
-            for i in range(3):
+            for _ in range(settings.num_iterations):
                 recording_length = settings.recording_length
                 beta = (
                     settings.beta
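
This change ties the recording count to the config: the nested loops emit one recording per (modulation, SNR, iteration) combination. A rough, purely illustrative count (the modulation list and SNR range are placeholders; only snr_step: 3 and num_iterations: 100 appear in this PR's config changes):

# Illustrative sketch: how many recordings the nested loops above generate.
import numpy as np

modulation_types = ["bpsk", "qpsk", "qam16"]   # placeholder list
snr_start, snr_stop, snr_step = 0, 21, 3       # placeholder range; step matches app.yaml
num_iterations = 100                           # value set in this PR

num_snrs = len(np.arange(snr_start, snr_stop, snr_step))      # 7 with these numbers
total = len(modulation_types) * num_snrs * num_iterations     # 3 * 7 * 100 = 2100
print(total)
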
@@ -49,8 +49,6 @@ def write_hdf5_file(records: List, output_path: str, dataset_name: str = "data")
             int(md["sps"]),
         )

-    first_rec, _ = records[0]  # records[0] is a tuple of (data, md)
-
     with h5py.File(output_path, "w") as hf:
         data_arr = np.stack([rec[0] for rec in records])
         dset = hf.create_dataset(dataset_name, data=data_arr, compression="gzip")
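
As an illustrative aside (not part of this change), a file written with create_dataset as above can be read back like this; the path and dataset key are assumptions based on the validation loader elsewhere in this PR (data/dataset/val.h5 with data_key="validation_data"):

# Illustrative sketch: reading back an HDF5 dataset written with create_dataset.
import h5py

with h5py.File("data/dataset/val.h5", "r") as hf:    # assumed path
    data = hf["validation_data"][...]                 # assumed key; -> (N, C, L) array
    print(data.shape, data.dtype)
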
@@ -1,5 +1,4 @@
 import lightning as L
 import numpy as np
 import timm
 import torch
 from torch import nn

@@ -2,27 +2,25 @@ import os

 import numpy as np
 import torch
-from sklearn.metrics import classification_report

-os.environ["NNPACK"] = "0"
 from matplotlib import pyplot as plt
 from mobilenetv3 import RFClassifier, mobilenetv3
 from modulation_dataset import ModulationH5Dataset
+from sklearn.metrics import classification_report

 from helpers.app_settings import get_app_settings

+os.environ["NNPACK"] = "0"


 def load_validation_data():
     val_dataset = ModulationH5Dataset(
         "data/dataset/val.h5", label_name="modulation", data_key="validation_data"
     )

-
     x = np.stack([x.numpy() for x, _ in val_dataset])  # shape: (N, C, L)
     y = np.array([y.item() for _, y in val_dataset])  # shape: (N,)
     class_names = list(val_dataset.label_encoder.classes_)

-
     return x, y, class_names

@@ -46,27 +44,22 @@ def build_model_from_ckpt(
     return model


-
-
 def evaluate_checkpoint(ckpt_path: str):
     """
     Loads the model from checkpoint and evaluates it on a validation set.
     Prints classification metrics and plots a confusion matrix.
     """

-
     # Load validation data
     X_val, y_true, class_names = load_validation_data()
     num_classes = len(class_names)
     in_channels = X_val.shape[1]

-
     # Load model
     model = build_model_from_ckpt(
         ckpt_path, in_channels=in_channels, num_classes=num_classes
     )

-
     # Inference
     y_pred = []
     with torch.no_grad():
@@ -76,15 +69,12 @@ def evaluate_checkpoint(ckpt_path: str):
             pred = torch.argmax(logits, dim=1).item()
             y_pred.append(pred)

-
     # Print classification report
     print("\nClassification Report:")
     print(
         classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
     )

-
-
     print_confusion_matrix(
         y_true=np.array(y_true),
         y_pred=np.array(y_pred),
@@ -94,8 +84,6 @@ def evaluate_checkpoint(ckpt_path: str):
     )


-
-
 def print_confusion_matrix(
     y_true: np.ndarray,
     y_pred: np.ndarray,
@@ -120,7 +108,6 @@ def print_confusion_matrix(
     for t, p in zip(y_true, y_pred):
         cm[t, p] += 1

-
     # 2) normalize if requested
     if normalize:
         with np.errstate(divide="ignore", invalid="ignore"):
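
The np.errstate guard shown above silences divide-by-zero and invalid-value warnings when a true class has no samples (an all-zero row). A small standalone illustration of the row normalization it protects (not code from this repository):

# Illustrative sketch: row-normalizing a confusion matrix while tolerating empty rows.
import numpy as np

cm = np.array([[5, 1], [0, 0]], dtype=float)        # second class has no true samples
with np.errstate(divide="ignore", invalid="ignore"):
    cm_norm = cm / cm.sum(axis=1, keepdims=True)
cm_norm = np.nan_to_num(cm_norm)                    # 0/0 rows become 0 instead of NaN
print(cm_norm)                                      # [[0.833 0.167], [0. 0.]]
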
@@ -131,10 +118,9 @@
     print_confusion_matrix_helper(cm, classes)


-import numpy as np


 def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=2):
     """
     Pretty prints a confusion matrix with x/y labels.
@@ -169,6 +155,6 @@ def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=

 if __name__ == "__main__":
     settings = get_app_settings()
-    evaluate_checkpoint(os.path.join("checkpoint_files", "inference_recognition_model.ckpt"))
-
-
+    evaluate_checkpoint(
+        os.path.join("checkpoint_files", "inference_recognition_model.ckpt")
+    )

@@ -1,23 +1,22 @@
 import os
 import sys

-os.environ["NNPACK"] = "0"
 import lightning as L
 import mobilenetv3
 import torch
 import torch.nn.functional as F
 import torchmetrics
-from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
 from modulation_dataset import ModulationH5Dataset

 from helpers.app_settings import get_app_settings

+os.environ["NNPACK"] = "0"
 script_dir = os.path.dirname(os.path.abspath(__file__))
 data_dir = os.path.abspath(os.path.join(script_dir, ".."))
 project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
 if project_root not in sys.path:
     sys.path.insert(0, project_root)
-from lightning.pytorch.callbacks import TQDMProgressBar


 class CustomProgressBar(TQDMProgressBar):
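
CustomProgressBar subclasses Lightning's TQDMProgressBar, which is now imported on the same line as ModelCheckpoint. As an illustrative aside (not part of this change), callbacks like these are handed to a Lightning Trainer roughly as follows; the monitor key, refresh rate, and epoch count are placeholders, not values from this repository:

# Illustrative sketch: wiring checkpoint and progress-bar callbacks into a Lightning Trainer.
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar

checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
progress_cb = TQDMProgressBar(refresh_rate=10)

trainer = L.Trainer(max_epochs=30, callbacks=[checkpoint_cb, progress_cb])
# trainer.fit(model, train_loader, val_loader)  # model and dataloaders defined elsewhere
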
@@ -59,8 +58,6 @@ def train_model():
         print("X shape:", x.shape)
         print("Y values:", y[:10])
         break

-    unique_labels = list(set([row[label].decode("utf-8") for row in ds_train.metadata]))
-
     num_classes = len(ds_train.label_encoder.classes_)

     hparams = {