diff --git a/data/training/train.py b/data/training/train.py
index eaa718a..0646bb3 100644
--- a/data/training/train.py
+++ b/data/training/train.py
@@ -26,6 +26,57 @@
 from modulation_dataset import ModulationH5Dataset
 import mobilenetv3
 
+from lightning.pytorch.callbacks import TQDMProgressBar
+
+
+class EpochOnlyProgressBar(TQDMProgressBar):
+    """Progress bar that stays quiet during an epoch and reports at its end.
+
+    Per-batch refreshes are suppressed; when a training epoch finishes the
+    bar is filled and closed, then ``train_loss``, ``val_loss`` and
+    ``val_acc`` are read from ``trainer.callback_metrics`` and printed.
+
+    Pass an instance through ``Trainer(callbacks=[...])`` and leave
+    ``enable_progress_bar`` at its default: Lightning rejects a progress bar
+    callback combined with ``enable_progress_bar=False``.
+    """
+
+    def on_train_batch_end(
+        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
+    ):
+        """Do nothing, so the bar is not advanced after every batch.
+
+        ``dataloader_idx`` is optional because Lightning 2.x calls this hook
+        without it; as a required parameter it would raise ``TypeError``.
+        """
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        """Fill and close the training bar, then print the epoch summary."""
+        bar = self.train_progress_bar
+        # ``total`` is None when the number of batches is unknown, in which
+        # case there is no remainder to fill in.
+        if bar.total is not None:
+            bar.update(bar.total - bar.n)
+        # NOTE(review): tqdm disables a bar once it is closed, so later epochs
+        # may render no bar at all -- confirm this is the intended display.
+        bar.close()
+
+        epoch = trainer.current_epoch + 1  # current_epoch is zero-based
+        metrics = trainer.callback_metrics
+
+        print(f"Epoch {epoch} Summary:")
+        for label, key in (
+            ("Train Loss", "train_loss"),
+            ("Val Loss", "val_loss"),
+            ("Val Acc", "val_acc"),
+        ):
+            value = metrics.get(key)
+            # A metric that was never logged is reported as N/A.
+            shown = "N/A" if value is None else f"{value:.4f}"
+            print(f" {label}: {shown}")
+        print("-" * 30)
+
+
 def train_model():
     settings = get_app_settings()
 
@@ -142,15 +193,19 @@ def train_model():
         mode="max",
         enable_version_counter=False,
     )
-    
+
+    # The custom bar is passed as a callback, so enable_progress_bar must stay
+    # at its default; Lightning raises if a bar callback is combined with False.
+    progress_bar = EpochOnlyProgressBar(refresh_rate=1)
+
     trainer = L.Trainer(
         max_epochs=epochs,
-        callbacks=[checkpoint_callback],
+        callbacks=[checkpoint_callback, progress_bar],
         accelerator="gpu",
         devices=1,
         benchmark=True,
         precision="bf16-mixed",
         logger=False,
     )
 
     if train_flag: