# forked from qoherent/modrec-workflow
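"""Train an initial modulation-recognition model.

Loads train/validation splits from HDF5 via ModulationH5Dataset, fits a
MobileNetV3-based classifier with PyTorch Lightning, and keeps the best
checkpoint by validation accuracy.
"""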
import sys
import os

# Make the project root importable regardless of the current working
# directory; paths are resolved relative to this file.
script_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
project_root = os.path.abspath(os.path.join(script_dir, "../.."))

if project_root not in sys.path:
    sys.path.insert(0, project_root)

import lightning as L
import torch
import torch.nn.functional as F
import torchmetrics

from helpers.app_settings import get_app_settings
from modulation_dataset import ModulationH5Dataset

import mobilenetv3


def train_model():
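    """Build the datasets, model, and Lightning trainer, then fit.

    Paths and hyperparameters are defined inline below; the best checkpoint
    (by ``val_acc``) is written under the results directory.
    """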
    settings = get_app_settings()
    modulation_types = settings.dataset.modulation_types

    train_flag = True
    batch_size = 128
    epochs = 1

    # NOTE: machine-specific absolute paths; adjust for your environment.
    checkpoint_filename = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model'

    train_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5'
    val_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5'

    dataset_name = 'Modulation Inference - Initial Model'
    metadata_names = 'Modulation'
    label = 'modulation'

    # Allow TF32 matmul kernels where supported, trading a little precision
    # for throughput.
    torch.set_float32_matmul_precision('high')

    ds_train = ModulationH5Dataset(train_data, label, data_key="training_data")
    ds_val = ModulationH5Dataset(val_data, label, data_key="validation_data")
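
    # ModulationH5Dataset (project-local) is assumed to expose `metadata`
    # rows and a fitted `label_encoder`; both are used below to derive the
    # number of classes.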

    train_loader = torch.utils.data.DataLoader(
        dataset=ds_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
    )
    # Larger batches are fine for validation since no gradients are kept.
    val_loader = torch.utils.data.DataLoader(
        dataset=ds_val,
        batch_size=2048,
        shuffle=False,
        num_workers=8,
    )

    # Peek at a single batch to sanity-check tensor shapes and labels.
    for x, y in train_loader:
        print("X shape:", x.shape)
        print("Y values:", y[:10])
        break

    unique_labels = list({row[label].decode("utf-8") for row in ds_train.metadata})
    num_classes = len(ds_train.label_encoder.classes_)
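    # Added sanity check, assuming the encoder was fit on this training
    # split: every modulation present in the metadata should be encoded.
    assert num_classes == len(unique_labels)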

    hparams = {
        'drop_path_rate': 0.2,
        'drop_rate': 0.5,
        'learning_rate': 3e-4,
        'wd': 0.2,
    }

    class RFClassifier(L.LightningModule):
        """Lightning wrapper around the backbone: loss, metrics, optimizer."""

        def __init__(self, model):
            super().__init__()
            self.model = model
            self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)

        def forward(self, x):
            return self.model(x)

        def configure_optimizers(self):
            optimizer = torch.optim.AdamW(
                self.parameters(),
                lr=hparams['learning_rate'],
                weight_decay=hparams['wd'],
            )
            # Cosine schedule with warm restarts, stepped per batch; a T_0 of
            # len(train_loader) restarts the schedule once per epoch.
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=len(train_loader),
            )
            return {
                'optimizer': optimizer,
                'lr_scheduler': {
                    'scheduler': lr_scheduler,
                    'interval': 'step',
                },
            }

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            self.log('train_loss', loss, on_epoch=True, prog_bar=True)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            self.accuracy(y_hat, y)
            self.log('val_loss', loss, prog_bar=True)
            self.log('val_acc', self.accuracy, prog_bar=True)

    model = RFClassifier(
        mobilenetv3.mobilenetv3(
            model_size='mobilenetv3_small_050',
            num_classes=num_classes,
            drop_rate=hparams['drop_rate'],
            drop_path_rate=hparams['drop_path_rate'],
        )
    )

    # Keep only the best checkpoint by validation accuracy. ModelCheckpoint
    # expects a directory and a name pattern separately, and an int for
    # save_top_k.
    checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
        dirpath=os.path.dirname(checkpoint_filename),
        filename=os.path.basename(checkpoint_filename),
        save_top_k=1,
        verbose=True,
        monitor='val_acc',
        mode='max',
        enable_version_counter=False,
    )

    trainer = L.Trainer(
        max_epochs=epochs,
        callbacks=[checkpoint_callback],
        accelerator='gpu',
        devices=1,
        benchmark=True,
        precision='bf16-mixed',
        logger=False,
    )

    if train_flag:
        trainer.fit(model, train_loader, val_loader)


if __name__ == '__main__':
    train_model()
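
# Hypothetical inference sketch (an assumption, not part of this workflow):
# Lightning checkpoints are torch-serialized dicts whose 'state_dict' keys
# carry a 'model.' prefix from RFClassifier's attribute name, so the bare
# backbone can be restored without the LightningModule, e.g.:
#
#   ckpt = torch.load('interference_recognition_model.ckpt', map_location='cpu')
#   net = mobilenetv3.mobilenetv3(model_size='mobilenetv3_small_050',
#                                 num_classes=num_classes)  # match training
#   net.load_state_dict({k.removeprefix('model.'): v
#                        for k, v in ckpt['state_dict'].items()})
#   net.eval()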