"""Train a MobileNetV3-small modulation classifier on HDF5 datasets with Lightning."""
import os
import sys

script_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
project_root = os.path.abspath(os.path.join(script_dir, "../.."))

# Make the repository root importable so ``helpers`` resolves.
if project_root not in sys.path:
    sys.path.insert(0, project_root)
# Also expose the script's parent directory (data_dir). It is derived from
# __file__ rather than os.getcwd() so resolution does not depend on the launch
# directory. NOTE(review): presumably sibling modules such as modulation_dataset
# / mobilenetv3 are found here or in script_dir — confirm.
if data_dir not in sys.path:
    sys.path.insert(0, data_dir)

import lightning as L
import torch
import torch.nn.functional as F
import torchmetrics

from helpers.app_settings import get_app_settings
from modulation_dataset import ModulationH5Dataset
import mobilenetv3

# Default artifact locations, resolved relative to the repository root instead
# of a user-specific absolute path. Assumes the layout <root>/data/dataset/*.h5
# and <root>/results/ — confirm against the repo.
DEFAULT_TRAIN_DATA = os.path.join(project_root, "data", "dataset", "train.h5")
DEFAULT_VAL_DATA = os.path.join(project_root, "data", "dataset", "val.h5")
DEFAULT_CHECKPOINT = os.path.join(
    project_root, "results", "interference_recognition_model"
)

# Regularisation / optimiser hyperparameters (read-only).
DEFAULT_HPARAMS = {
    "drop_path_rate": 0.2,
    "drop_rate": 0.5,
    "learning_rate": 3e-4,
    "wd": 0.2,
}


class RFClassifier(L.LightningModule):
    """Lightning wrapper: cross-entropy loss, AdamW, per-step cosine warm restarts.

    Defined at module level (not inside ``train_model``) so it is importable,
    e.g. for ``RFClassifier.load_from_checkpoint``.

    Args:
        model: Backbone module whose output is passed to ``F.cross_entropy``
            as logits.
        num_classes: Number of target classes, used by the accuracy metric.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        restart_period: ``T_0`` of ``CosineAnnealingWarmRestarts``, in
            optimizer steps. It is clamped to at least 1 because the
            scheduler rejects non-positive values.
    """

    def __init__(self, model, num_classes, learning_rate, weight_decay, restart_period):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.restart_period = max(1, int(restart_period))
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        # The scheduler is stepped every batch ('interval': 'step'). With
        # T_0 = batches per epoch, the cosine cycle restarts once per epoch.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=self.restart_period,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        # Update the metric state, then log the metric object itself.
        # Lightning then handles epoch-level compute/reset.
        self.accuracy(y_hat, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.accuracy, prog_bar=True)


def train_model(
    *,
    train_data=DEFAULT_TRAIN_DATA,
    val_data=DEFAULT_VAL_DATA,
    checkpoint_path=DEFAULT_CHECKPOINT,
    batch_size=128,
    epochs=1,
    train_flag=True,
):
    """Build the datasets, model and trainer, then optionally fit.

    Args:
        train_data: Path to the training HDF5 file (read with
            ``data_key="training_data"``).
        val_data: Path to the validation HDF5 file (read with
            ``data_key="validation_data"``).
        checkpoint_path: Checkpoint path without extension. Lightning appends
            ``.ckpt``. Only the best model by ``val_acc`` is kept.
        batch_size: Training batch size. Validation uses a fixed 2048.
        epochs: Maximum number of training epochs.
        train_flag: When False, everything is built but ``trainer.fit`` is
            skipped.
    """
    label = "modulation"
    hparams = DEFAULT_HPARAMS

    torch.set_float32_matmul_precision("high")

    ds_train = ModulationH5Dataset(train_data, label, data_key="training_data")
    ds_val = ModulationH5Dataset(val_data, label, data_key="validation_data")

    train_loader = torch.utils.data.DataLoader(
        dataset=ds_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=ds_val,
        batch_size=2048,
        shuffle=False,
        num_workers=8,
    )

    # Print one training batch as a quick sanity check. Skipped silently if
    # the loader is empty.
    first_batch = next(iter(train_loader), None)
    if first_batch is not None:
        x, y = first_batch
        print("X shape:", x.shape)
        print("Y values:", y[:10])

    num_classes = len(ds_train.label_encoder.classes_)

    model = RFClassifier(
        mobilenetv3.mobilenetv3(
            model_size="mobilenetv3_small_050",
            num_classes=num_classes,
            drop_rate=hparams["drop_rate"],
            drop_path_rate=hparams["drop_path_rate"],
        ),
        num_classes=num_classes,
        learning_rate=hparams["learning_rate"],
        weight_decay=hparams["wd"],
        restart_period=len(train_loader),
    )

    # ModelCheckpoint's ``filename`` is meant to be relative to ``dirpath``.
    # Split the target path explicitly rather than relying on os.path.join
    # letting an absolute filename override the directory.
    checkpoint_dir, checkpoint_name = os.path.split(os.path.abspath(checkpoint_path))
    checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=checkpoint_name,
        save_top_k=1,
        verbose=True,
        monitor="val_acc",
        mode="max",
        enable_version_counter=False,
    )

    trainer = L.Trainer(
        max_epochs=epochs,
        callbacks=[checkpoint_callback],
        accelerator="gpu",
        devices=1,
        benchmark=True,
        precision="bf16-mixed",
        logger=False,
    )

    if train_flag:
        trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    train_model()