import os
import sys

# Disable NNPACK before torch/lightning are imported.
os.environ["NNPACK"] = "0"

import lightning as L
import torch
import torch.nn.functional as F
import torchmetrics
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar

from helpers.app_settings import get_app_settings
from modulation_dataset import ModulationH5Dataset
import mobilenetv3

# Make the project root importable.
script_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)


class CustomProgressBar(TQDMProgressBar):
    def __init__(self):
        super().__init__(refresh_rate=128)  # refresh the bar every 128 batches


def train_model():
    settings = get_app_settings()
    training_cfg = settings.training
    dataset_cfg = settings.dataset

    train_flag = True
    batch_size = training_cfg.batch_size
    epochs = training_cfg.epochs
    checkpoint_dir = training_cfg.checkpoint_dir
    checkpoint_filename = training_cfg.checkpoint_filename

    train_data = f"{dataset_cfg.output_dir}/train.h5"
    val_data = f"{dataset_cfg.output_dir}/val.h5"
    label = "modulation"

    torch.set_float32_matmul_precision("high")

    ds_train = ModulationH5Dataset(train_data, label, data_key="training_data")
    ds_val = ModulationH5Dataset(val_data, label, data_key="validation_data")

    train_loader = torch.utils.data.DataLoader(
        dataset=ds_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=ds_val,
        batch_size=2048,
        shuffle=False,
        num_workers=8,
    )

    # Sanity-check one batch before training.
    for x, y in train_loader:
        print("X shape:", x.shape)
        print("Y values:", y[:10])
        break

    unique_labels = list({row[label].decode("utf-8") for row in ds_train.metadata})
    num_classes = len(ds_train.label_encoder.classes_)

    hparams = {
        "drop_path_rate": 0.2,
        "drop_rate": 0.5,
        "learning_rate": float(training_cfg.learning_rate),
        "wd": 0.01,
    }

    class RFClassifier(L.LightningModule):
        def __init__(self, model):
            super().__init__()
            self.model = model
            self.accuracy = torchmetrics.Accuracy(
                task="multiclass", num_classes=num_classes
            )

        def forward(self, x):
            return self.model(x)

        def configure_optimizers(self):
            optimizer = torch.optim.AdamW(
                self.parameters(),
                lr=hparams["learning_rate"],
                weight_decay=hparams["wd"],
            )
            # Cosine schedule restarts once per epoch (T_0 = steps per epoch),
            # stepped every batch via interval="step".
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=len(train_loader),
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
            }

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            self.log("train_loss", loss, on_epoch=True, prog_bar=False)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            self.accuracy(y_hat, y)
            self.log("val_loss", loss, prog_bar=False)
            self.log("val_acc", self.accuracy, prog_bar=False)

    model = RFClassifier(
        mobilenetv3.mobilenetv3(
            model_size="mobilenetv3_small_050",
            num_classes=num_classes,
            drop_rate=hparams["drop_rate"],
            drop_path_rate=hparams["drop_path_rate"],
        )
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=checkpoint_filename,
        save_top_k=1,
        verbose=True,  # print a message whenever a checkpoint is saved
        monitor="val_acc",
        mode="max",
        enable_version_counter=False,
    )

    trainer = L.Trainer(
        max_epochs=epochs,
        callbacks=[checkpoint_callback, CustomProgressBar()],
        accelerator="gpu",
        devices=1,
        benchmark=True,
        precision="16-mixed",
        logger=False,
        enable_progress_bar=True,
        enable_model_summary=False,  # skip the model summary printout
    )

    if train_flag:
        trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    train_model()