# Forked from qoherent/modrec-workflow.
import sys, os
|
|
os.environ["NNPACK"] = "0"
|
|
import lightning as L
|
|
from lightning.pytorch.callbacks import ModelCheckpoint, Callback
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torchmetrics
|
|
from helpers.app_settings import get_app_settings
|
|
from modulation_dataset import ModulationH5Dataset
|
|
import mobilenetv3
|
|
def _parent_of(path):
    """Return the absolute, normalized parent directory of *path*."""
    return os.path.abspath(os.path.join(path, ".."))


# Absolute directory that contains this script file.
script_dir = os.path.dirname(os.path.abspath(__file__))

# One directory level above the script directory.
data_dir = _parent_of(script_dir)

# NOTE(review): this is derived from the *current working directory*, not
# from script_dir, so its value depends on where the script is launched from.
# It is also prepended to sys.path only after the project imports at the top
# of the file have already executed — confirm this is intended.
project_root = _parent_of(os.getcwd())

if project_root not in sys.path:
    # Prepend so that modules under project_root win over other entries.
    sys.path.insert(0, project_root)
|
|
from lightning.pytorch.callbacks import TQDMProgressBar
|
|
|
|
class CustomProgressBar(TQDMProgressBar):
    """TQDM progress bar that is redrawn every ``refresh_rate`` batches.

    Args:
        refresh_rate: Number of batches between progress-bar updates.
            Defaults to 128, i.e. the bar is NOT refreshed on every batch.
    """

    def __init__(self, refresh_rate: int = 128):
        super().__init__(refresh_rate=refresh_rate)
|
|
|
|
|
|
def train_model():
    """Train the MobileNetV3 modulation classifier on the configured HDF5 splits.

    Reads the ``training`` and ``dataset`` sections of the app settings and
    builds train/validation loaders over ``train.h5`` / ``val.h5`` located in
    ``dataset_cfg.output_dir``. It then wraps a ``mobilenetv3_small_050`` model
    in a Lightning module and fits it on a single GPU with 16-bit mixed
    precision. The single best checkpoint by ``val_acc`` is written to
    ``training_cfg.checkpoint_dir`` / ``training_cfg.checkpoint_filename``.

    Raises:
        ValueError: if the training split yields no batches.
    """
    settings = get_app_settings()
    training_cfg = settings.training
    dataset_cfg = settings.dataset

    # Set to False to build data, model and trainer without running fit().
    train_flag = True

    batch_size = training_cfg.batch_size
    epochs = training_cfg.epochs

    checkpoint_dir = training_cfg.checkpoint_dir
    checkpoint_filename = training_cfg.checkpoint_filename

    train_data = f"{dataset_cfg.output_dir}/train.h5"
    val_data = f"{dataset_cfg.output_dir}/val.h5"

    # Metadata field used as the classification target.
    label = "modulation"

    # Let float32 matmuls use faster reduced-precision internal kernels
    # (e.g. TF32) where the hardware supports them.
    torch.set_float32_matmul_precision("high")

    ds_train = ModulationH5Dataset(train_data, label, data_key="training_data")
    ds_val = ModulationH5Dataset(val_data, label, data_key="validation_data")

    train_loader = torch.utils.data.DataLoader(
        dataset=ds_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=ds_val,
        batch_size=2048,
        shuffle=False,
        num_workers=8,
    )

    # Fail fast with a clear message: the LR scheduler below uses
    # T_0=len(train_loader), which must be a positive integer, so an empty
    # split would otherwise only fail later inside trainer.fit().
    if len(train_loader) == 0:
        raise ValueError(f"No training batches could be built from {train_data!r}.")

    # Sanity check: show the shape of one training batch and its first labels.
    x, y = next(iter(train_loader))
    print("X shape:", x.shape)
    print("Y values:", y[:10])

    num_classes = len(ds_train.label_encoder.classes_)

    hparams = {
        "drop_path_rate": 0.2,
        "drop_rate": 0.5,
        "learning_rate": float(training_cfg.learning_rate),
        "wd": 0.01,
    }

    class RFClassifier(L.LightningModule):
        """Lightning wrapper: cross-entropy loss, AdamW, per-step cosine warm restarts."""

        def __init__(self, model):
            super().__init__()
            self.model = model
            # Multiclass accuracy over num_classes labels; logged as "val_acc".
            self.accuracy = torchmetrics.Accuracy(
                task="multiclass", num_classes=num_classes
            )

        def forward(self, x):
            return self.model(x)

        def configure_optimizers(self):
            optimizer = torch.optim.AdamW(
                self.parameters(),
                lr=hparams["learning_rate"],
                weight_decay=hparams["wd"],
            )
            # T_0 equals the number of training batches and the scheduler is
            # stepped per optimizer step, so the LR cycle restarts each epoch.
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=len(train_loader),
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
            }

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            self.log("train_loss", loss, on_epoch=True, prog_bar=False)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            self.accuracy(y_hat, y)
            self.log("val_loss", loss, prog_bar=False)
            self.log("val_acc", self.accuracy, prog_bar=False)

    model = RFClassifier(
        mobilenetv3.mobilenetv3(
            model_size="mobilenetv3_small_050",
            num_classes=num_classes,
            drop_rate=hparams["drop_rate"],
            drop_path_rate=hparams["drop_path_rate"],
        )
    )

    # Keep only the single best checkpoint by validation accuracy.
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=checkpoint_filename,
        save_top_k=1,
        verbose=True,  # verbose checkpoint logging is ON
        monitor="val_acc",
        mode="max",
        enable_version_counter=False,  # overwrite instead of adding -v1, -v2 suffixes
    )

    trainer = L.Trainer(
        max_epochs=epochs,
        callbacks=[checkpoint_callback, CustomProgressBar()],
        accelerator="gpu",
        devices=1,
        benchmark=True,
        precision="16-mixed",
        logger=False,
        enable_progress_bar=True,  # rendered by the CustomProgressBar callback
        enable_model_summary=False,  # skip the model summary table
    )

    if train_flag:
        trainer.fit(model, train_loader, val_loader)
|
|
|
|
|
|
# Script entry point: training starts only when this file is executed
# directly, never when it is imported as a module.
if __name__ == "__main__":
    train_model()
|