fixed training model error, was a extra argument passed into the callback class

This commit is contained in:
Liyu Xiao 2025-05-23 10:37:56 -04:00
parent 065129b4ed
commit 5f0ac995c8

View File

@ -31,29 +31,24 @@ import mobilenetv3
from lightning.pytorch.callbacks import TQDMProgressBar from lightning.pytorch.callbacks import TQDMProgressBar
class EpochOnlyProgressBar(TQDMProgressBar): class EpochOnlyProgressBar(TQDMProgressBar):
def on_train_start(self, trainer, pl_module): def __init__(self):
super().on_train_start(trainer, pl_module) super().__init__()
self.enable = True
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
# Disable batch progress updates def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
# Skip batch-level updates
pass pass
def on_train_epoch_end(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module):
# Update and close progress bar at epoch end epoch = trainer.current_epoch + 1
if self.main_progress_bar is not None:
self.main_progress_bar.update(self.main_progress_bar.total - self.main_progress_bar.n)
self.main_progress_bar.close()
epoch = trainer.current_epoch + 1 # zero-based to one-based
train_loss = trainer.callback_metrics.get("train_loss") train_loss = trainer.callback_metrics.get("train_loss")
val_loss = trainer.callback_metrics.get("val_loss") val_loss = trainer.callback_metrics.get("val_loss")
val_acc = trainer.callback_metrics.get("val_acc") val_acc = trainer.callback_metrics.get("val_acc")
print(f"Epoch {epoch} Summary:") print(f"\nEpoch {epoch}:")
print(f" Train Loss: {train_loss:.4f}" if train_loss is not None else " Train Loss: N/A") if train_loss: print(f" Train Loss: {train_loss:.4f}")
print(f" Val Loss: {val_loss:.4f}" if val_loss is not None else " Val Loss: N/A") if val_loss: print(f" Val Loss: {val_loss:.4f}")
print(f" Val Acc: {val_acc:.4f}" if val_acc is not None else " Val Acc: N/A") if val_acc: print(f" Val Acc: {val_acc:.4f}")
print("-" * 30) print("-" * 30)
@ -173,7 +168,7 @@ def train_model():
enable_version_counter=False, enable_version_counter=False,
) )
progress_bar = EpochOnlyProgressBar(refresh_rate=1) progress_bar = EpochOnlyProgressBar()
trainer = L.Trainer( trainer = L.Trainer(
max_epochs=epochs, max_epochs=epochs,