From 04e5f4db8a711b8d13fdc7a07b33b17777fe6516 Mon Sep 17 00:00:00 2001 From: Liyu Xiao Date: Thu, 21 Aug 2025 10:56:23 -0400 Subject: [PATCH] clean up code, adjusted default parameters --- conf/app.yaml | 5 +---- helpers/app_settings.py | 7 +++++-- scripts/dataset_manager/data_gen.py | 2 +- scripts/model_builder/mobilenetv3.py | 1 - scripts/model_builder/plot_data.py | 6 +++--- scripts/model_builder/train.py | 7 ++----- 6 files changed, 12 insertions(+), 16 deletions(-) diff --git a/conf/app.yaml b/conf/app.yaml index ff17ac9..b74b93d 100644 --- a/conf/app.yaml +++ b/conf/app.yaml @@ -5,9 +5,6 @@ dataset: # Number of samples per recording recording_length: 1024 - # Set this to scale the number of generated recordings - mult_factor: 5 - # List of signal modulation schemes to include in the dataset modulation_types: - bpsk @@ -27,7 +24,7 @@ dataset: snr_step: 3 # Number of iterations (signal recordings) per modulation and SNR combination - num_iterations: 3 + num_iterations: 100 # Modulation scheme settings; keys must match the `modulation_types` list above # Each entry includes: diff --git a/helpers/app_settings.py b/helpers/app_settings.py index 8eeaf2f..1939447 100644 --- a/helpers/app_settings.py +++ b/helpers/app_settings.py @@ -9,7 +9,6 @@ import yaml @dataclass class DataSetConfig: num_slices: int - mult_factor: int train_split: float seed: int modulation_types: list @@ -42,7 +41,11 @@ class AppConfig: class AppSettings: - """Application settings, to be initialized from app.yaml configuration file.""" + """ + Application settings, + to be initialized from + app.yaml configuration file. + """ def __init__(self, config_file: str): # Load the YAML configuration file diff --git a/scripts/dataset_manager/data_gen.py b/scripts/dataset_manager/data_gen.py index 1b36209..7efe1bc 100644 --- a/scripts/dataset_manager/data_gen.py +++ b/scripts/dataset_manager/data_gen.py @@ -29,7 +29,7 @@ def generate_modulated_signals(output_dir: str) -> None: for modulation in settings.modulation_types: for snr in np.arange(settings.snr_start, settings.snr_stop, settings.snr_step): - for _ in range(settings.mult_factor): + for _ in range(settings.num_iterations): recording_length = settings.recording_length beta = ( settings.beta diff --git a/scripts/model_builder/mobilenetv3.py b/scripts/model_builder/mobilenetv3.py index 9de7f27..41a3ab6 100644 --- a/scripts/model_builder/mobilenetv3.py +++ b/scripts/model_builder/mobilenetv3.py @@ -1,5 +1,4 @@ import lightning as L -import numpy as np import timm import torch from torch import nn diff --git a/scripts/model_builder/plot_data.py b/scripts/model_builder/plot_data.py index 0aafff7..25fe88e 100644 --- a/scripts/model_builder/plot_data.py +++ b/scripts/model_builder/plot_data.py @@ -2,15 +2,15 @@ import os import numpy as np import torch -from sklearn.metrics import classification_report - -os.environ["NNPACK"] = "0" from matplotlib import pyplot as plt from mobilenetv3 import RFClassifier, mobilenetv3 from modulation_dataset import ModulationH5Dataset +from sklearn.metrics import classification_report from helpers.app_settings import get_app_settings +os.environ["NNPACK"] = "0" + def load_validation_data(): val_dataset = ModulationH5Dataset( diff --git a/scripts/model_builder/train.py b/scripts/model_builder/train.py index 560a9d5..cbd9d9d 100644 --- a/scripts/model_builder/train.py +++ b/scripts/model_builder/train.py @@ -1,23 +1,22 @@ import os import sys -os.environ["NNPACK"] = "0" import lightning as L import mobilenetv3 import torch import torch.nn.functional as F import torchmetrics -from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar from modulation_dataset import ModulationH5Dataset from helpers.app_settings import get_app_settings +os.environ["NNPACK"] = "0" script_dir = os.path.dirname(os.path.abspath(__file__)) data_dir = os.path.abspath(os.path.join(script_dir, "..")) project_root = os.path.abspath(os.path.join(os.getcwd(), "..")) if project_root not in sys.path: sys.path.insert(0, project_root) -from lightning.pytorch.callbacks import TQDMProgressBar class CustomProgressBar(TQDMProgressBar): @@ -59,8 +58,6 @@ def train_model(): print("X shape:", x.shape) print("Y values:", y[:10]) break - - unique_labels = list(set([row[label].decode("utf-8") for row in ds_train.metadata])) num_classes = len(ds_train.label_encoder.classes_) hparams = {