fixed training errors inside the script

Liyu Xiao 2025-05-23 10:31:34 -04:00
parent 3caf1d21f2
commit 065129b4ed


@@ -17,6 +17,8 @@ if project_root not in sys.path:
sys.path.insert(0, project_root)
import lightning as L
+from lightning.pytorch.callbacks import ModelCheckpoint
import torch
import torch.nn.functional as F
import torchmetrics
@@ -28,21 +30,22 @@ import mobilenetv3
from lightning.pytorch.callbacks import TQDMProgressBar
class EpochOnlyProgressBar(TQDMProgressBar):
def on_train_start(self, trainer, pl_module):
super().on_train_start(trainer, pl_module)
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-# Disable batch progress updates (no updates on each batch)
+# Disable batch progress updates
pass
def on_train_epoch_end(self, trainer, pl_module):
-# Complete and close progress bar for the epoch
-self.main_progress_bar.update(self.main_progress_bar.total - self.main_progress_bar.n)
-self.main_progress_bar.close()
+# Update and close progress bar at epoch end
+if self.main_progress_bar is not None:
+    self.main_progress_bar.update(self.main_progress_bar.total - self.main_progress_bar.n)
+    self.main_progress_bar.close()
-epoch = trainer.current_epoch + 1  # epochs start at 0 internally
+epoch = trainer.current_epoch + 1  # zero-based to one-based
# Fetch logged metrics safely
train_loss = trainer.callback_metrics.get("train_loss")
val_loss = trainer.callback_metrics.get("val_loss")
val_acc = trainer.callback_metrics.get("val_acc")
@@ -160,7 +163,7 @@ def train_model():
)
)
-checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
+checkpoint_callback = ModelCheckpoint(
dirpath=checkpoint_dir,
filename=checkpoint_filename,
save_top_k=True,
@@ -180,7 +183,7 @@ def train_model():
benchmark=True,
precision="bf16-mixed",
logger=False,
-enable_progress_bar=False,
+enable_progress_bar=True,
)
if train_flag:
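
For reference, below is a minimal, self-contained sketch of how the pieces touched by this commit fit together, assuming Lightning 2.x hook signatures (where on_train_batch_end takes no dataloader_idx argument). The dirpath, filename pattern, monitor key, and max_epochs values are illustrative placeholders, not the script's actual configuration; the metric keys mirror the ones read from trainer.callback_metrics in the diff above.

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar


class EpochOnlyProgressBar(TQDMProgressBar):
    """Progress bar that stays silent per batch and prints one summary line per epoch."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Skip per-batch bar updates (Lightning 2.x passes no dataloader_idx here).
        pass

    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch + 1  # zero-based to one-based
        # callback_metrics may be missing keys on early epochs, so fetch defensively.
        train_loss = trainer.callback_metrics.get("train_loss")
        val_loss = trainer.callback_metrics.get("val_loss")
        val_acc = trainer.callback_metrics.get("val_acc")
        print(f"epoch {epoch}: train_loss={train_loss}, val_loss={val_loss}, val_acc={val_acc}")


# Placeholder checkpoint/trainer wiring; paths, filename pattern, and monitor key are illustrative only.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",           # placeholder directory
    filename="mobilenetv3-{epoch}",  # placeholder filename pattern
    save_top_k=1,                    # keep only the best checkpoint
    monitor="val_loss",
)

trainer = L.Trainer(
    max_epochs=10,                   # placeholder
    callbacks=[EpochOnlyProgressBar(), checkpoint_callback],
    logger=False,
    enable_progress_bar=True,        # needed for the custom progress bar callback to take effect
)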