fixed progress bar for training
Some checks failed
RIA Hub Workflow Demo / ria-demo (push) Failing after 45s

This commit is contained in:
Liyu Xiao 2025-05-23 10:19:27 -04:00
parent c31a5eaf79
commit 3caf1d21f2

View File

@ -26,6 +26,33 @@ from modulation_dataset import ModulationH5Dataset
import mobilenetv3 import mobilenetv3
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.callbacks import TQDMProgressBar
class EpochOnlyProgressBar(TQDMProgressBar):
    """Progress bar callback that reports once per epoch instead of per batch.

    Per-batch updates are suppressed. At the end of every training epoch the
    training bar is advanced to its total and closed, and a short plain-text
    summary of ``train_loss``, ``val_loss`` and ``val_acc`` (read from
    ``trainer.callback_metrics``) is printed.
    """

    def on_train_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        """Ignore batch-level events so the bar is not redrawn on every step.

        ``dataloader_idx`` is optional so that the hook can be called both
        with and without that trailing argument.
        """
        # Deliberately empty: no per-batch progress output.

    def on_train_epoch_end(self, trainer, pl_module):
        """Finish the epoch's bar and print a summary of the logged metrics."""
        bar = self._resolve_train_bar()
        if bar is not None:
            remaining = self._remaining_steps(bar)
            if remaining > 0:
                # Jump straight to 100% since batch updates were suppressed.
                bar.update(remaining)
            # NOTE(review): the bar is closed here and not reopened by this
            # class; confirm the parent callback recreates/resets it for the
            # next epoch if a bar is wanted on every epoch.
            bar.close()

        epoch = trainer.current_epoch + 1  # epochs start at 0 internally
        metrics = trainer.callback_metrics

        print(f"Epoch {epoch} Summary:")
        print(f" Train Loss: {self._format_metric(metrics.get('train_loss'))}")
        print(f" Val Loss: {self._format_metric(metrics.get('val_loss'))}")
        print(f" Val Acc: {self._format_metric(metrics.get('val_acc'))}")
        print("-" * 30)

    def _resolve_train_bar(self):
        """Return the training tqdm bar, or ``None`` when none is available.

        Tries ``train_progress_bar`` first and falls back to the older
        ``main_progress_bar`` attribute name.
        """
        for attr_name in ("train_progress_bar", "main_progress_bar"):
            try:
                bar = getattr(self, attr_name)
            except (AttributeError, TypeError):
                # Attribute missing, or (presumably) a property that refuses
                # access before the bar has been created -- try the next name.
                continue
            if bar is not None:
                return bar
        return None

    @staticmethod
    def _remaining_steps(bar):
        """Return how many steps are left on ``bar``; 0 if its total is unknown."""
        total = getattr(bar, "total", None)
        done = getattr(bar, "n", 0)
        if total is None or total == float("inf"):
            return 0
        return max(total - done, 0)

    @staticmethod
    def _format_metric(value):
        """Format a scalar metric with four decimals, or ``N/A`` when missing."""
        if value is None:
            return "N/A"
        return f"{float(value):.4f}"
def train_model(): def train_model():
settings = get_app_settings() settings = get_app_settings()
@ -143,14 +170,17 @@ def train_model():
enable_version_counter=False, enable_version_counter=False,
) )
progress_bar = EpochOnlyProgressBar(refresh_rate=1)
trainer = L.Trainer( trainer = L.Trainer(
max_epochs=epochs, max_epochs=epochs,
callbacks=[checkpoint_callback], callbacks=[checkpoint_callback,progress_bar],
accelerator="gpu", accelerator="gpu",
devices=1, devices=1,
benchmark=True, benchmark=True,
precision="bf16-mixed", precision="bf16-mixed",
logger=False, logger=False,
enable_progress_bar=False,
) )
if train_flag: if train_flag: