Fixed the training progress bar so that it only shows a summary of results at the end of each epoch.
All checks were successful
RIA Hub Workflow Demo / ria-demo (push) Successful in 37m46s

This commit is contained in:
Liyu Xiao 2025-05-23 10:47:27 -04:00
parent 5f0ac995c8
commit 7e255a704d

View File

@ -2,7 +2,6 @@ import sys, os
os.environ["NNPACK"] = "0" os.environ["NNPACK"] = "0"
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.abspath(os.path.join(script_dir, "..")) data_dir = os.path.abspath(os.path.join(script_dir, ".."))
project_root = os.path.abspath(os.path.join(script_dir, "../..")) project_root = os.path.abspath(os.path.join(script_dir, "../.."))
@ -17,7 +16,7 @@ if project_root not in sys.path:
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
import lightning as L import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.callbacks import ModelCheckpoint, Callback
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -28,27 +27,25 @@ from modulation_dataset import ModulationH5Dataset
import mobilenetv3 import mobilenetv3
from lightning.pytorch.callbacks import TQDMProgressBar class CleanProgressCallback(Callback):
"""Clean progress callback that only shows epoch summaries"""
class EpochOnlyProgressBar(TQDMProgressBar):
def __init__(self):
super().__init__()
self.enable = True
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
# Skip batch-level updates
pass
def on_train_epoch_end(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module):
epoch = trainer.current_epoch + 1 epoch = trainer.current_epoch + 1
# Get metrics
train_loss = trainer.callback_metrics.get("train_loss") train_loss = trainer.callback_metrics.get("train_loss")
val_loss = trainer.callback_metrics.get("val_loss") val_loss = trainer.callback_metrics.get("val_loss")
val_acc = trainer.callback_metrics.get("val_acc") val_acc = trainer.callback_metrics.get("val_acc")
print(f"\nEpoch {epoch}:") # Print clean output
if train_loss: print(f" Train Loss: {train_loss:.4f}") print(f"Epoch {epoch}:")
if val_loss: print(f" Val Loss: {val_loss:.4f}") if train_loss is not None:
if val_acc: print(f" Val Acc: {val_acc:.4f}") print(f" Train Loss: {train_loss:.4f}")
if val_loss is not None:
print(f" Val Loss: {val_loss:.4f}")
if val_acc is not None:
print(f" Val Acc: {val_acc:.4f}")
print("-" * 30) print("-" * 30)
@ -64,12 +61,8 @@ def train_model():
checkpoint_dir = training_cfg.checkpoint_dir checkpoint_dir = training_cfg.checkpoint_dir
checkpoint_filename = training_cfg.checkpoint_filename checkpoint_filename = training_cfg.checkpoint_filename
train_data = ( train_data = f"{dataset_cfg.output_dir}/train.h5"
f"{dataset_cfg.output_dir}/train.h5" val_data = f"{dataset_cfg.output_dir}/val.h5"
)
val_data = (
f"{dataset_cfg.output_dir}/val.h5"
)
dataset_name = "Modulation Inference - Initial Model" dataset_name = "Modulation Inference - Initial Model"
metadata_names = "Modulation" metadata_names = "Modulation"
@ -138,7 +131,7 @@ def train_model():
x, y = batch x, y = batch
y_hat = self(x) y_hat = self(x)
loss = F.cross_entropy(y_hat, y) loss = F.cross_entropy(y_hat, y)
self.log("train_loss", loss, on_epoch=True, prog_bar=True) self.log("train_loss", loss, on_epoch=True, prog_bar=False)
return loss return loss
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
@ -146,8 +139,8 @@ def train_model():
y_hat = self(x) y_hat = self(x)
loss = F.cross_entropy(y_hat, y) loss = F.cross_entropy(y_hat, y)
self.accuracy(y_hat, y) self.accuracy(y_hat, y)
self.log("val_loss", loss, prog_bar=True) self.log("val_loss", loss, prog_bar=False)
self.log("val_acc", self.accuracy, prog_bar=True) self.log("val_acc", self.accuracy, prog_bar=False)
model = RFClassifier( model = RFClassifier(
mobilenetv3.mobilenetv3( mobilenetv3.mobilenetv3(
@ -161,24 +154,25 @@ def train_model():
checkpoint_callback = ModelCheckpoint( checkpoint_callback = ModelCheckpoint(
dirpath=checkpoint_dir, dirpath=checkpoint_dir,
filename=checkpoint_filename, filename=checkpoint_filename,
save_top_k=True, save_top_k=1,
verbose=True, verbose=False, # Disable checkpoint verbose output
monitor="val_acc", monitor="val_acc",
mode="max", mode="max",
enable_version_counter=False, enable_version_counter=False,
) )
progress_bar = EpochOnlyProgressBar() clean_progress = CleanProgressCallback()
trainer = L.Trainer( trainer = L.Trainer(
max_epochs=epochs, max_epochs=epochs,
callbacks=[checkpoint_callback,progress_bar], callbacks=[checkpoint_callback, clean_progress],
accelerator="gpu", accelerator="gpu",
devices=1, devices=1,
benchmark=True, benchmark=True,
precision="bf16-mixed", precision="bf16-mixed",
logger=False, logger=False,
enable_progress_bar=True, enable_progress_bar=False, # Disable all progress bars
enable_model_summary=False, # Disable model summary
) )
if train_flag: if train_flag:
@ -186,4 +180,4 @@ def train_model():
if __name__ == "__main__": if __name__ == "__main__":
train_model() train_model()