# modrec-workflow/data/training/train.py

import sys, os

# Set this before torch is imported; it is intended to silence NNPACK
# warnings on hardware that does not support it.
os.environ["NNPACK"] = "0"

# Put the project root on sys.path before importing project-local modules.
script_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
import torch
import torch.nn.functional as F
import torchmetrics

from helpers.app_settings import get_app_settings
from modulation_dataset import ModulationH5Dataset
import mobilenetv3

class CustomProgressBar(TQDMProgressBar):
    def __init__(self):
        super().__init__(refresh_rate=128)  # redraw every 128 batches

def train_model():
    settings = get_app_settings()
    training_cfg = settings.training
    dataset_cfg = settings.dataset

    train_flag = True
    batch_size = training_cfg.batch_size
    epochs = training_cfg.epochs
    checkpoint_dir = training_cfg.checkpoint_dir
    checkpoint_filename = training_cfg.checkpoint_filename

    train_data = f"{dataset_cfg.output_dir}/train.h5"
    val_data = f"{dataset_cfg.output_dir}/val.h5"

    dataset_name = "Modulation Inference - Initial Model"
    metadata_names = "Modulation"
    label = "modulation"

    torch.set_float32_matmul_precision("high")

    ds_train = ModulationH5Dataset(train_data, label, data_key="training_data")
    ds_val = ModulationH5Dataset(val_data, label, data_key="validation_data")
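
    # ModulationH5Dataset is assumed (from its use further down) to expose a
    # `metadata` record list and a fitted `label_encoder`; both are read below
    # to derive the class count.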

    train_loader = torch.utils.data.DataLoader(
        dataset=ds_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=ds_val,
        batch_size=2048,
        shuffle=False,
        num_workers=8,
    )
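
    # Sanity-check one batch (shapes and a few labels) before training starts.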
    for x, y in train_loader:
        print("X shape:", x.shape)
        print("Y values:", y[:10])
        break

    # unique_labels is informational only; the class count comes from the encoder.
    unique_labels = list({row[label].decode("utf-8") for row in ds_train.metadata})
    num_classes = len(ds_train.label_encoder.classes_)

    hparams = {
        "drop_path_rate": 0.2,
        "drop_rate": 0.5,
        "learning_rate": float(training_cfg.learning_rate),
        "wd": 0.01,
    }
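
    # Only the learning rate is read from app settings; the dropout rates feed
    # the model builder below and `wd` is the AdamW weight decay.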

    class RFClassifier(L.LightningModule):
        def __init__(self, model):
            super().__init__()
            self.model = model
            self.accuracy = torchmetrics.Accuracy(
                task="multiclass", num_classes=num_classes
            )

        def forward(self, x):
            return self.model(x)

        def configure_optimizers(self):
            optimizer = torch.optim.AdamW(
                self.parameters(),
                lr=hparams["learning_rate"],
                weight_decay=hparams["wd"],
            )
            # T_0 = steps per epoch, so the cosine schedule restarts once per
            # epoch; "interval": "step" makes Lightning step it every batch.
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=len(train_loader),
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
            }

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            self.log("train_loss", loss, on_epoch=True, prog_bar=False)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            # Logging the metric object lets torchmetrics handle epoch-level
            # accumulation and reset.
            self.accuracy(y_hat, y)
            self.log("val_loss", loss, prog_bar=False)
            self.log("val_acc", self.accuracy, prog_bar=False)

    model = RFClassifier(
        mobilenetv3.mobilenetv3(
            model_size="mobilenetv3_small_050",
            num_classes=num_classes,
            drop_rate=hparams["drop_rate"],
            drop_path_rate=hparams["drop_path_rate"],
        )
    )
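
    # `mobilenetv3.mobilenetv3` is a project-local factory; the `model_size`
    # string above matches a timm MobileNetV3 variant name, so it presumably
    # wraps a timm constructor adapted to this input format.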

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=checkpoint_filename,
        save_top_k=1,
        verbose=True,  # print a line whenever a new best checkpoint is saved
        monitor="val_acc",
        mode="max",
        enable_version_counter=False,  # overwrite instead of appending -v1, -v2, ...
    )

    trainer = L.Trainer(
        max_epochs=epochs,
        callbacks=[checkpoint_callback, CustomProgressBar()],
        accelerator="gpu",
        devices=1,
        benchmark=True,  # let cuDNN pick the fastest kernels for fixed input shapes
        precision="16-mixed",
        logger=False,  # no external logger; metrics only drive checkpointing
        enable_progress_bar=True,  # keep the custom progress bar
        enable_model_summary=False,  # skip the model summary printout
    )

    if train_flag:
        trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    train_model()
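
# Note: project_root is computed from the current working directory rather than
# the script location, so the imports of project modules assume this script is
# launched from one directory below the project root.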