modrec-workflow/data/training/train.py
Liyu Xiao 5f0ac995c8
Some checks failed
RIA Hub Workflow Demo / ria-demo (push) Has been cancelled
fixed training model error; an extra argument was being passed into the callback class
2025-05-23 10:37:56 -04:00

190 lines
5.3 KiB
Python

import sys, os

# Must be set before torch is imported: disables NNPACK to avoid noisy
# warnings/failures on CPUs that do not support it.
os.environ["NNPACK"] = "0"

# Make the project root importable regardless of the current working
# directory (the script lives at <root>/modrec-workflow/data/training/).
script_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
project_root = os.path.abspath(os.path.join(script_dir, "../.."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import lightning as L
import torch
import torch.nn.functional as F
import torchmetrics
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar

from helpers.app_settings import get_app_settings
from modulation_dataset import ModulationH5Dataset
import mobilenetv3
class EpochOnlyProgressBar(TQDMProgressBar):
    """Progress bar that suppresses per-batch updates and instead prints a
    concise per-epoch summary of the logged train/val metrics."""

    def __init__(self):
        super().__init__()
        # Lightning consults this flag when deciding whether to render the bar.
        self.enable = True

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Intentionally skip batch-level updates to keep console output quiet.
        pass

    def on_train_epoch_end(self, trainer, pl_module):
        """Print the epoch number plus any available metrics.

        Uses explicit ``is not None`` checks so legitimately zero-valued
        metrics are still reported (the previous truthiness test would
        silently hide a metric equal to 0).
        """
        epoch = trainer.current_epoch + 1
        train_loss = trainer.callback_metrics.get("train_loss")
        val_loss = trainer.callback_metrics.get("val_loss")
        val_acc = trainer.callback_metrics.get("val_acc")
        print(f"\nEpoch {epoch}:")
        if train_loss is not None:
            print(f" Train Loss: {train_loss:.4f}")
        if val_loss is not None:
            print(f" Val Loss: {val_loss:.4f}")
        if val_acc is not None:
            print(f" Val Acc: {val_acc:.4f}")
        print("-" * 30)
def train_model():
    """Train the modulation-classification model on the prepared HDF5 data.

    Reads dataset/checkpoint locations from the application settings, builds
    train/val dataloaders, wraps the ``mobilenetv3`` backbone in a
    LightningModule with AdamW + cosine-annealing warm restarts, and fits it.
    The best checkpoint (highest ``val_acc``) is written to the configured
    checkpoint directory.
    """
    settings = get_app_settings()
    training_cfg = settings.training
    dataset_cfg = settings.dataset

    train_flag = True  # set False to build everything without fitting (debug aid)
    batch_size = 128
    epochs = 50

    checkpoint_dir = training_cfg.checkpoint_dir
    checkpoint_filename = training_cfg.checkpoint_filename

    train_data = f"{dataset_cfg.output_dir}/train.h5"
    val_data = f"{dataset_cfg.output_dir}/val.h5"
    label = "modulation"

    # Trade a little matmul precision for throughput on recent GPUs.
    torch.set_float32_matmul_precision("high")

    ds_train = ModulationH5Dataset(train_data, label, data_key="training_data")
    ds_val = ModulationH5Dataset(val_data, label, data_key="validation_data")

    train_loader = torch.utils.data.DataLoader(
        dataset=ds_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=ds_val,
        batch_size=2048,  # large batch is fine for forward-only validation
        shuffle=False,
        num_workers=8,
    )

    # Sanity-check the first batch's shape and labels before training starts.
    for x, y in train_loader:
        print("X shape:", x.shape)
        print("Y values:", y[:10])
        break

    num_classes = len(ds_train.label_encoder.classes_)

    hparams = {
        "drop_path_rate": 0.2,
        "drop_rate": 0.5,
        "learning_rate": 1e-4,
        "wd": 0.01,
    }

    class RFClassifier(L.LightningModule):
        """LightningModule wrapper adding loss/accuracy logging around ``model``."""

        def __init__(self, model):
            super().__init__()
            self.model = model
            self.accuracy = torchmetrics.Accuracy(
                task="multiclass", num_classes=num_classes
            )

        def forward(self, x):
            return self.model(x)

        def configure_optimizers(self):
            optimizer = torch.optim.AdamW(
                self.parameters(),
                lr=hparams["learning_rate"],
                weight_decay=hparams["wd"],
            )
            # One cosine cycle per epoch: T_0 equals the number of optimizer
            # steps per epoch, stepped every batch ("interval": "step").
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=len(train_loader),
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
            }

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            self.log("train_loss", loss, on_epoch=True, prog_bar=True)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            # Logging the metric object lets Lightning aggregate it correctly
            # across validation batches.
            self.accuracy(y_hat, y)
            self.log("val_loss", loss, prog_bar=True)
            self.log("val_acc", self.accuracy, prog_bar=True)

    model = RFClassifier(
        mobilenetv3.mobilenetv3(
            model_size="mobilenetv3_small_050",
            num_classes=num_classes,
            drop_rate=hparams["drop_rate"],
            drop_path_rate=hparams["drop_path_rate"],
        )
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=checkpoint_filename,
        save_top_k=1,  # ModelCheckpoint expects an int; was save_top_k=True
        verbose=True,
        monitor="val_acc",
        mode="max",
        enable_version_counter=False,
    )
    progress_bar = EpochOnlyProgressBar()

    trainer = L.Trainer(
        max_epochs=epochs,
        callbacks=[checkpoint_callback, progress_bar],
        accelerator="gpu",  # NOTE(review): hard-codes GPU; "auto" would also allow CPU runs
        devices=1,
        benchmark=True,
        precision="bf16-mixed",
        logger=False,
        enable_progress_bar=True,
    )

    if train_flag:
        trainer.fit(model, train_loader, val_loader)
# Entry point: run training only when executed directly as a script.
if __name__ == "__main__":
    train_model()