diff --git a/checkpoint_files/__init__.py b/checkpoint_files/__init__.py
index 0123ec0..17de6b4 100644
--- a/checkpoint_files/__init__.py
+++ b/checkpoint_files/__init__.py
@@ -3,4 +3,4 @@ import os
 CHECKPOINTS_DIR = os.path.dirname(os.path.abspath(__file__))
 
 if __name__ == "__main__":
-    print(CHECKPOINTS_DIR)
\ No newline at end of file
+    print(CHECKPOINTS_DIR)
diff --git a/convert_to_onnx.py b/convert_to_onnx.py
index dc6b6b9..97d13c4 100644
--- a/convert_to_onnx.py
+++ b/convert_to_onnx.py
@@ -8,7 +8,6 @@ from onnx_files import ONNX_DIR
 
 from helpers.app_settings import get_app_settings
 
-
 def convert_to_onnx(ckpt_path, fp16=False):
     """
     Convert a PyTorch model to ONNX format.
@@ -19,16 +18,15 @@ def convert_to_onnx(ckpt_path, fp16=False):
         output_path (str): The path to save the converted ONNX model.
     """
     settings = get_app_settings()
-
+
     inference_cfg = settings.inference
     dataset_cfg = settings.dataset
-
-
+
     in_channels = 2
     batch_size = 1
-    slice_length = int(1024/dataset_cfg.num_slices)
+    slice_length = int(1024 / dataset_cfg.num_slices)
    num_classes = inference_cfg.num_classes
-
+
     model = RFClassifier(
         model=mobilenetv3(
             model_size="mobilenetv3_small_050",
@@ -36,18 +34,17 @@ def convert_to_onnx(ckpt_path, fp16=False):
             in_chans=in_channels,
         )
     )
-
+
     checkpoint = torch.load(
-        ckpt_path, weights_only = True, map_location=torch.device("cpu")
+        ckpt_path, weights_only=True, map_location=torch.device("cpu")
     )
     model.load_state_dict(checkpoint["state_dict"])
-
+
     if fp16:
         model.half()
-
+
     model.eval()
-
-
+
     # Generate random sample data
     base, ext = os.path.splitext(os.path.basename(ckpt_path))
     if fp16:
@@ -58,8 +55,7 @@ def convert_to_onnx(ckpt_path, fp16=False):
     else:
         output_path = os.path.join(ONNX_DIR, f"{base}_fp32.onnx")
     sample_input = torch.rand(batch_size, in_channels, slice_length)
-
-
+
     torch.onnx.export(
         model=model,
         args=sample_input,
@@ -71,14 +67,17 @@ def convert_to_onnx(ckpt_path, fp16=False):
         output_names=["output"],
         dynamo=False, # Requires onnxscript
     )
 
-
+
+
 if __name__ == "__main__":
     from checkpoint_files import CHECKPOINTS_DIR
 
     model_checkpoint = "interference_recognition_model.ckpt"
-
+
     print("Converting to ONNX...")
-    convert_to_onnx(ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False)
-
-    print("Conversion complete stored at: ", os.path.join(ONNX_DIR, model_checkpoint))
\ No newline at end of file
+    convert_to_onnx(
+        ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False
+    )
+
+    print("Conversion complete. Stored at:", os.path.join(ONNX_DIR, os.path.splitext(model_checkpoint)[0] + "_fp32.onnx"))
diff --git a/data/scripts/produce_dataset.py b/data/scripts/produce_dataset.py
index e64daef..17ca878 100644
--- a/data/scripts/produce_dataset.py
+++ b/data/scripts/produce_dataset.py
@@ -23,7 +23,6 @@ info_dtype = np.dtype(
 )
 
 
-
 def write_hdf5_file(records, output_path, dataset_name="data"):
     """
     Writes a list of records to an HDF5 file.
@@ -77,11 +76,12 @@ def write_hdf5_file(records, output_path, dataset_name="data"):
 
     return output_path
 
+
 def complex_to_channel(data):
     """
     Convert complex-valued IQ data of shape (1, N) to 2-channel real array of shape (2, N).
""" - assert np.iscomplexobj(data) #check if the data is in the form a+bi + assert np.iscomplexobj(data) # check if the data is in the form a+bi real = np.real(data[0]) # (N,) imag = np.imag(data[0]) # (N,) stacked = np.stack([real, imag], axis=0) # shape (2, N) @@ -114,21 +114,17 @@ def generate_datasets(cfg): for fname in files: rec = from_npy(os.path.join(cfg.input_dir, fname)) - data = rec.data #here data is a numpy array with the shape (1, N) - + data = rec.data # here data is a numpy array with the shape (1, N) + data = complex_to_channel(data) # convert to 2-channel real array - - + md = rec.metadata # pull metadata from the recording md.setdefault("recid", len(records)) records.append((data, md)) # split each recording into snippets each - - + records = split_recording(records, cfg.num_slices) - - train_records, val_records = split(records, cfg.train_split, cfg.seed) @@ -147,6 +143,6 @@ def main(): train_path, val_path = generate_datasets(dataset_cfg) print(f"āœ… Train: {train_path}\nāœ… Val: {val_path}") - + if __name__ == "__main__": main() diff --git a/data/scripts/split_dataset.py b/data/scripts/split_dataset.py index 061d7d5..a35cdff 100644 --- a/data/scripts/split_dataset.py +++ b/data/scripts/split_dataset.py @@ -2,7 +2,7 @@ import random from collections import defaultdict -def split(dataset, train_frac=0.8, seed=42, label_key = "modulation"): +def split(dataset, train_frac=0.8, seed=42, label_key="modulation"): """ Splits a dataset into smaller datasets based on the specified lengths. @@ -16,31 +16,29 @@ def split(dataset, train_frac=0.8, seed=42, label_key = "modulation"): rec_buckets = defaultdict(list) for data, md in dataset: rec_buckets[md["recid"]].append((data, md)) - - - rec_labels = {} #store labels for each recording + + rec_labels = {} # store labels for each recording for rec_id, group in rec_buckets.items(): label = group[0][1][label_key] - if isinstance(label, bytes): #if the label is a byte string + if isinstance(label, bytes): # if the label is a byte string label = label.decode("utf-8") rec_labels[rec_id] = label - - label_rec_ids = defaultdict(list) #group rec_ids by label + + label_rec_ids = defaultdict(list) # group rec_ids by label for rec_id, label in rec_labels.items(): label_rec_ids[label].append(rec_id) - + random.seed(seed) train_recs, val_recs = set(), set() - + for label, rec_ids in label_rec_ids.items(): random.shuffle(rec_ids) split_idx = int(len(rec_ids) * train_frac) - train_recs.update(rec_ids[:split_idx]) #pulls train_frac or rec_ids per label, guarantees all modulations are represented + train_recs.update( + rec_ids[:split_idx] + ) # pulls train_frac or rec_ids per label, guarantees all modulations are represented val_recs.update(rec_ids[split_idx:]) - - - - + # add the assigned recordings to the train and val datasets train_dataset, val_dataset = [], [] for rec_id, group in rec_buckets.items(): @@ -48,10 +46,10 @@ def split(dataset, train_frac=0.8, seed=42, label_key = "modulation"): train_dataset.extend(group) elif rec_id in val_recs: val_dataset.extend(group) - - + return train_dataset, val_dataset + def split_recording(recording_list, num_snippets): """ Splits a list of recordings into smaller chunks. 
@@ -71,7 +69,7 @@ def split_recording(recording_list, num_snippets):
         start = i * L
         end = (i + 1) * L
         snippet = data[:, start:end]
-
+
         # copy the metadata, adding a snippet index
         snippet_md = md.copy()
         snippet_md["snippet_idx"] = i
diff --git a/data/training/mobilenetv3.py b/data/training/mobilenetv3.py
index 754296a..83556bd 100644
--- a/data/training/mobilenetv3.py
+++ b/data/training/mobilenetv3.py
@@ -197,4 +197,4 @@ class RFClassifier(L.LightningModule):
         self.model = model
 
     def forward(self, x):
-        return self.model(x)
\ No newline at end of file
+        return self.model(x)
diff --git a/data/training/train.py b/data/training/train.py
index 3b63a81..2dd2284 100644
--- a/data/training/train.py
+++ b/data/training/train.py
@@ -12,7 +12,7 @@ from helpers.app_settings import get_app_settings
 project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
 if project_root not in sys.path:
     sys.path.insert(0, project_root)
-
+
 import lightning as L
 import torch
 import torch.nn.functional as F
@@ -24,7 +24,6 @@ from modulation_dataset import ModulationH5Dataset
 
 import mobilenetv3
 
-
 def train_model():
     settings = get_app_settings()
     dataset = settings.dataset.modulation_types
@@ -33,16 +32,20 @@ def train_model():
     batch_size = 128
     epochs = 1
 
-    checkpoint_filename = f'/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model'
+    checkpoint_filename = "/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model"
 
-    train_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5'
-    val_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5'
+    train_data = (
+        "/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5"
+    )
+    val_data = (
+        "/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5"
+    )
 
-    dataset_name = 'Modulation Inference - Initial Model'
-    metadata_names = 'Modulation'
-    label = 'modulation'
+    dataset_name = "Modulation Inference - Initial Model"
+    metadata_names = "Modulation"
+    label = "modulation"
 
-    torch.set_float32_matmul_precision('high')
+    torch.set_float32_matmul_precision("high")
 
     ds_train = ModulationH5Dataset(train_data, label, data_key="training_data")
     ds_val = ModulationH5Dataset(val_data, label, data_key="validation_data")
@@ -69,17 +72,19 @@ def train_model():
     num_classes = len(ds_train.label_encoder.classes_)
 
     hparams = {
-        'drop_path_rate': 0.2,
-        'drop_rate': 0.5,
-        'learning_rate': 3e-4,
-        'wd': 0.2
+        "drop_path_rate": 0.2,
+        "drop_rate": 0.5,
+        "learning_rate": 3e-4,
+        "wd": 0.2,
     }
 
     class RFClassifier(L.LightningModule):
         def __init__(self, model):
             super().__init__()
             self.model = model
-            self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
+            self.accuracy = torchmetrics.Accuracy(
+                task="multiclass", num_classes=num_classes
+            )
 
         def forward(self, x):
             return self.model(x)
@@ -87,26 +92,23 @@ def train_model():
         def configure_optimizers(self):
             optimizer = torch.optim.AdamW(
                 self.parameters(),
-                lr=hparams['learning_rate'],
-                weight_decay=hparams['wd'],
+                lr=hparams["learning_rate"],
+                weight_decay=hparams["wd"],
             )
             lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                 optimizer,
                 T_0=len(train_loader),
             )
             return {
-                'optimizer': optimizer,
-                'lr_scheduler': {
-                    'scheduler': lr_scheduler,
-                    'interval': 'step'
-                }
+                "optimizer": optimizer,
+                "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
             }
 
         def training_step(self, batch, batch_idx):
             x, y = batch
             y_hat = self(x)
             loss = F.cross_entropy(y_hat, y)
-            self.log('train_loss', loss, on_epoch=True, prog_bar=True)
+            self.log("train_loss", loss, on_epoch=True, prog_bar=True)
             return loss
 
         def validation_step(self, batch, batch_idx):
@@ -114,15 +116,15 @@ def train_model():
             y_hat = self(x)
             loss = F.cross_entropy(y_hat, y)
             self.accuracy(y_hat, y)
-            self.log('val_loss', loss, prog_bar=True)
-            self.log('val_acc', self.accuracy, prog_bar=True)
+            self.log("val_loss", loss, prog_bar=True)
+            self.log("val_acc", self.accuracy, prog_bar=True)
 
     model = RFClassifier(
         mobilenetv3.mobilenetv3(
-            model_size='mobilenetv3_small_050',
+            model_size="mobilenetv3_small_050",
             num_classes=num_classes,
-            drop_rate=hparams['drop_rate'],
-            drop_path_rate=hparams['drop_path_rate']
+            drop_rate=hparams["drop_rate"],
+            drop_path_rate=hparams["drop_path_rate"],
         )
     )
 
@@ -130,24 +132,24 @@ def train_model():
     checkpoint_callback = ModelCheckpoint(
         filename=checkpoint_filename,
         save_top_k=True,
         verbose=True,
-        monitor='val_acc',
-        mode='max',
+        monitor="val_acc",
+        mode="max",
         enable_version_counter=False,
     )
 
     trainer = L.Trainer(
         max_epochs=epochs,
         callbacks=[checkpoint_callback],
-        accelerator='gpu',
+        accelerator="gpu",
         devices=1,
         benchmark=True,
-        precision='bf16-mixed',
-        logger=False
+        precision="bf16-mixed",
+        logger=False,
     )
 
     if train_flag:
         trainer.fit(model, train_loader, val_loader)
 
-if __name__ == '__main__':
-    train_model()
\ No newline at end of file
+if __name__ == "__main__":
+    train_model()
diff --git a/helpers/app_settings.py b/helpers/app_settings.py
index 2b2c861..4b593b0 100644
--- a/helpers/app_settings.py
+++ b/helpers/app_settings.py
@@ -4,10 +4,12 @@ from functools import lru_cache
 
 import yaml
 
+
 @dataclass
 class GeneralConfig:
-    run_mode: str
-
+    run_mode: str
+
+
 @dataclass
 class DataSetConfig:
     input_dir: str
@@ -17,7 +19,8 @@ class DataSetConfig:
     modulation_types: list
     val_split: float
     output_dir: str
-
+
+
 @dataclass
 class TrainingConfig:
     batch_size: int
@@ -26,16 +29,19 @@ class TrainingConfig:
     checkpoint_path: str
     use_gpu: bool
 
+
 @dataclass
 class InferenceConfig:
     model_path: str
     num_classes: int
     output_path: str
 
+
 @dataclass
 class AppConfig:
     build_dir: str
-
+
+
 class AppSettings:
     """Application settings, to be initialized from app.yaml configuration file."""
 
@@ -51,6 +57,7 @@ class AppSettings:
         self.inference = InferenceConfig(**config_data["inference"])
         self.app = AppConfig(**config_data["app"])
 
+
 @lru_cache
 def get_app_settings() -> AppSettings:
     """Return application configuration settings."""
@@ -60,4 +67,4 @@ def get_app_settings() -> AppSettings:
 
 
 if __name__ == "__main__":
-    s = get_app_settings()
\ No newline at end of file
+    s = get_app_settings()
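
Reviewer note, not part of the diff: below is a minimal smoke test for the exported model. It is a sketch under two assumptions: that onnxruntime is installed, and that app.yaml sets dataset.num_slices to 2, which gives slice_length = 1024 // 2 = 512 (adjust the dummy input shape to match your config). The "input" and "output" tensor names come from the input_names and output_names arguments passed to torch.onnx.export in convert_to_onnx.py.

# smoke_test_onnx.py -- reviewer sketch, not part of this diff
import os

import numpy as np
import onnxruntime as ort  # assumed installed: pip install onnxruntime

from onnx_files import ONNX_DIR

# fp16=False in __main__ above, so the exported file carries the _fp32 suffix.
model_path = os.path.join(ONNX_DIR, "interference_recognition_model_fp32.onnx")
session = ort.InferenceSession(model_path)

# Shape must match the export-time sample_input: (batch_size, in_channels, slice_length).
# slice_length = 512 assumes num_slices = 2 in app.yaml; adjust if yours differs.
dummy = np.random.rand(1, 2, 512).astype(np.float32)

# "input" and "output" are the names given to torch.onnx.export.
(logits,) = session.run(["output"], {"input": dummy})
print("logits shape:", logits.shape)  # expected: (1, num_classes)

Since the export passes no dynamic_axes, the runtime will reject inputs whose shape differs from the export-time sample_input; worth keeping in mind if batched or variable-length inference is needed later.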