# modrec-workflow/data/training/train.py
# 2025-05-22 14:11:18 -04:00
#
# 153 lines
# 4.3 KiB
# Python

import os
import sys

# Resolve locations from this file rather than the current working directory.
# This file lives at <repo>/data/training/train.py.
script_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
project_root = os.path.abspath(os.path.join(script_dir, "../.."))
# Make the repository root importable so the `helpers` package resolves.
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from helpers.app_settings import get_app_settings

# Also put the parent of the cwd on sys.path, as the original script did
# (same entries, same order). It is kept under its own name so `project_root`
# is no longer overwritten with a cwd-dependent value.
cwd_parent = os.path.abspath(os.path.join(os.getcwd(), ".."))
if cwd_parent not in sys.path:
    sys.path.insert(0, cwd_parent)

import lightning as L
import torch
import torch.nn.functional as F
import torchmetrics

# Project-local modules. Presumably resolved from this script's directory,
# which Python puts on sys.path when the file is run directly, or from
# `cwd_parent`. TODO confirm where they live.
from modulation_dataset import ModulationH5Dataset
import mobilenetv3
class RFClassifier(L.LightningModule):
    """LightningModule that trains a classifier backbone with cross-entropy loss.

    Defined at module level, rather than inside ``train_model``, so it can be
    imported elsewhere (for example to rebuild the model for inference). It
    no longer reads module or closure globals; everything it needs is passed
    to the constructor.

    Args:
        model: Backbone module. Its output is passed to ``F.cross_entropy``
            as logits.
        num_classes: Number of target classes, used by the accuracy metric.
        steps_per_epoch: Number of training batches per epoch. It is used as
            ``T_0`` of the cosine schedule, so the learning rate restarts
            once per epoch.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
    """

    def __init__(self, model, num_classes, steps_per_epoch,
                 learning_rate=3e-4, weight_decay=0.2):
        super().__init__()
        self.model = model
        self.steps_per_epoch = steps_per_epoch
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=self.steps_per_epoch,
        )
        return {
            'optimizer': optimizer,
            # 'step' interval: the scheduler advances every optimizer step, so
            # with T_0 = steps_per_epoch the cosine cycle restarts each epoch.
            'lr_scheduler': {
                'scheduler': lr_scheduler,
                'interval': 'step',
            },
        }

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        # Update the metric object and log it; Lightning computes and resets
        # it at epoch end. ModelCheckpoint monitors this 'val_acc' value.
        self.accuracy(y_hat, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.accuracy, prog_bar=True)


def train_model(*, train_flag=True, batch_size=128, epochs=1):
    """Train a MobileNetV3 modulation classifier and checkpoint the best model.

    Paths are resolved relative to this file (``<repo>/data/training/``), so
    the script no longer depends on one developer's home directory:

    * training data:   ``<repo>/data/dataset/train.h5``
    * validation data: ``<repo>/data/dataset/val.h5``
    * best checkpoint, by highest ``val_acc``:
      ``<repo>/results/interference_recognition_model.ckpt``
      (Lightning's ModelCheckpoint appends the ``.ckpt`` extension).

    Args:
        train_flag: If False, build the data, model and trainer but skip
            ``trainer.fit``.
        batch_size: Training batch size.
        epochs: Maximum number of training epochs.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    repo_root = os.path.abspath(os.path.join(script_dir, "..", ".."))
    dataset_dir = os.path.join(repo_root, "data", "dataset")
    train_data = os.path.join(dataset_dir, "train.h5")
    val_data = os.path.join(dataset_dir, "val.h5")
    checkpoint_dir = os.path.join(repo_root, "results")
    checkpoint_name = "interference_recognition_model"
    label = 'modulation'

    torch.set_float32_matmul_precision('high')

    ds_train = ModulationH5Dataset(train_data, label, data_key="training_data")
    ds_val = ModulationH5Dataset(val_data, label, data_key="validation_data")
    train_loader = torch.utils.data.DataLoader(
        dataset=ds_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=ds_val,
        batch_size=2048,
        shuffle=False,
        num_workers=8,
    )

    # Sanity check: print one batch's shape and labels before training.
    # (shuffle=True uses a RandomSampler, which rejects an empty dataset at
    # construction, so at least one batch exists here.)
    x, y = next(iter(train_loader))
    print("X shape:", x.shape)
    print("Y values:", y[:10])

    num_classes = len(ds_train.label_encoder.classes_)
    hparams = {
        'drop_path_rate': 0.2,
        'drop_rate': 0.5,
        'learning_rate': 3e-4,
        'wd': 0.2,
    }

    backbone = mobilenetv3.mobilenetv3(
        model_size='mobilenetv3_small_050',
        num_classes=num_classes,
        drop_rate=hparams['drop_rate'],
        drop_path_rate=hparams['drop_path_rate'],
    )
    model = RFClassifier(
        backbone,
        num_classes=num_classes,
        steps_per_epoch=len(train_loader),
        learning_rate=hparams['learning_rate'],
        weight_decay=hparams['wd'],
    )

    # `filename` is a name relative to `dirpath`. Previously an absolute path
    # was passed as `filename`, which only worked because os.path.join drops
    # the directory in that case.
    checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=checkpoint_name,
        save_top_k=1,
        verbose=True,
        monitor='val_acc',
        mode='max',
        enable_version_counter=False,  # overwrite rather than write -v1, -v2, ...
    )
    trainer = L.Trainer(
        max_epochs=epochs,
        callbacks=[checkpoint_callback],
        accelerator='gpu',
        devices=1,
        benchmark=True,
        precision='bf16-mixed',
        logger=False,
    )
    if train_flag:
        trainer.fit(model, train_loader, val_loader)
if __name__ == '__main__':
    # Run training only when executed as a script. The guard also matters
    # for the DataLoader worker processes: with the 'spawn' start method,
    # workers re-import this module and must not start training themselves.
    train_model()