diff --git a/data/training/train.py b/data/training/train.py
index e496494..ac790a8 100644
--- a/data/training/train.py
+++ b/data/training/train.py
@@ -31,29 +31,24 @@ import mobilenetv3
 from lightning.pytorch.callbacks import TQDMProgressBar
 
 class EpochOnlyProgressBar(TQDMProgressBar):
-    def on_train_start(self, trainer, pl_module):
-        super().on_train_start(trainer, pl_module)
-
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        # Disable batch progress updates
+    def __init__(self):
+        super().__init__()
+        self.enable = True
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        # Skip batch-level updates
         pass
-
+
     def on_train_epoch_end(self, trainer, pl_module):
-        # Update and close progress bar at epoch end
-        if self.main_progress_bar is not None:
-            self.main_progress_bar.update(self.main_progress_bar.total - self.main_progress_bar.n)
-            self.main_progress_bar.close()
-
-        epoch = trainer.current_epoch + 1  # zero-based to one-based
-
+        epoch = trainer.current_epoch + 1
         train_loss = trainer.callback_metrics.get("train_loss")
-        val_loss = trainer.callback_metrics.get("val_loss")
+        val_loss = trainer.callback_metrics.get("val_loss")
         val_acc = trainer.callback_metrics.get("val_acc")
-
-        print(f"Epoch {epoch} Summary:")
-        print(f" Train Loss: {train_loss:.4f}" if train_loss is not None else " Train Loss: N/A")
-        print(f" Val Loss: {val_loss:.4f}" if val_loss is not None else " Val Loss: N/A")
-        print(f" Val Acc: {val_acc:.4f}" if val_acc is not None else " Val Acc: N/A")
+
+        print(f"\nEpoch {epoch}:")
+        if train_loss: print(f" Train Loss: {train_loss:.4f}")
+        if val_loss: print(f" Val Loss: {val_loss:.4f}")
+        if val_acc: print(f" Val Acc: {val_acc:.4f}")
         print("-" * 30)
 
 
@@ -173,7 +168,7 @@ def train_model():
         enable_version_counter=False,
     )
 
-    progress_bar = EpochOnlyProgressBar(refresh_rate=1)
+    progress_bar = EpochOnlyProgressBar()
 
     trainer = L.Trainer(
         max_epochs=epochs,