diff --git a/data/training/train.py b/data/training/train.py
index 0646bb3..e496494 100644
--- a/data/training/train.py
+++ b/data/training/train.py
@@ -17,6 +17,8 @@ if project_root not in sys.path:
     sys.path.insert(0, project_root)
 
 import lightning as L
+from lightning.pytorch.callbacks import ModelCheckpoint
+
 import torch
 import torch.nn.functional as F
 import torchmetrics
@@ -28,21 +30,22 @@ import mobilenetv3
 
 from lightning.pytorch.callbacks import TQDMProgressBar
-from lightning.pytorch.callbacks import TQDMProgressBar
-
 
 class EpochOnlyProgressBar(TQDMProgressBar):
+    def on_train_start(self, trainer, pl_module):
+        super().on_train_start(trainer, pl_module)
+
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        # Disable batch progress updates (no updates on each batch)
+        # Disable batch progress updates
         pass
 
     def on_train_epoch_end(self, trainer, pl_module):
-        # Complete and close progress bar for the epoch
-        self.main_progress_bar.update(self.main_progress_bar.total - self.main_progress_bar.n)
-        self.main_progress_bar.close()
+        # Update and close progress bar at epoch end
+        if self.main_progress_bar is not None:
+            self.main_progress_bar.update(self.main_progress_bar.total - self.main_progress_bar.n)
+            self.main_progress_bar.close()
 
-        epoch = trainer.current_epoch + 1  # epochs start at 0 internally
+        epoch = trainer.current_epoch + 1  # zero-based to one-based
 
-        # Fetch logged metrics safely
         train_loss = trainer.callback_metrics.get("train_loss")
         val_loss = trainer.callback_metrics.get("val_loss")
         val_acc = trainer.callback_metrics.get("val_acc")
@@ -160,7 +163,7 @@ def train_model():
         )
     )
 
-    checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
+    checkpoint_callback = ModelCheckpoint(
         dirpath=checkpoint_dir,
         filename=checkpoint_filename,
         save_top_k=True,
@@ -180,7 +183,7 @@ def train_model():
         benchmark=True,
         precision="bf16-mixed",
         logger=False,
-        enable_progress_bar=False,
+        enable_progress_bar=True,
     )
 
     if train_flag:
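
For context, the pieces this diff touches wire together roughly as sketched below. This is an illustrative sketch, not code from the PR: the dirpath, filename, monitor, and max_epochs values are assumptions, and EpochOnlyProgressBar refers to the callback defined in train.py. Flipping enable_progress_bar back to True is consistent with this setup, since Lightning will not run a progress-bar callback while that flag is False.

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint

# EpochOnlyProgressBar is the callback defined in train.py above.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",             # assumed; the PR passes checkpoint_dir
    filename="{epoch}-{val_loss:.3f}",  # assumed; the PR passes checkpoint_filename
    save_top_k=1,                       # keep only the single best checkpoint
    monitor="val_loss",                 # assumed metric; logged in train.py
    mode="min",
)

trainer = L.Trainer(
    max_epochs=10,                      # assumed; not shown in the diff
    benchmark=True,
    precision="bf16-mixed",
    logger=False,
    enable_progress_bar=True,           # must stay True or the custom bar is skipped
    callbacks=[checkpoint_callback, EpochOnlyProgressBar()],
)

One side note on the untouched context line: save_top_k=True passes a bool where ModelCheckpoint expects an int. It happens to behave as 1 because bool subclasses int, but save_top_k=1 would state the intent directly.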