fixed training errors inside the script
Some checks failed
RIA Hub Workflow Demo / ria-demo (push) Failing after 52s
Some checks failed
RIA Hub Workflow Demo / ria-demo (push) Failing after 52s
This commit is contained in:
parent
3caf1d21f2
commit
065129b4ed
|
@ -17,6 +17,8 @@ if project_root not in sys.path:
|
|||
sys.path.insert(0, project_root)
|
||||
|
||||
import lightning as L
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchmetrics
|
||||
|
@ -28,21 +30,22 @@ import mobilenetv3
|
|||
|
||||
from lightning.pytorch.callbacks import TQDMProgressBar
|
||||
|
||||
from lightning.pytorch.callbacks import TQDMProgressBar
|
||||
|
||||
class EpochOnlyProgressBar(TQDMProgressBar):
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
super().on_train_start(trainer, pl_module)
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
# Disable batch progress updates (no updates on each batch)
|
||||
# Disable batch progress updates
|
||||
pass
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
# Complete and close progress bar for the epoch
|
||||
self.main_progress_bar.update(self.main_progress_bar.total - self.main_progress_bar.n)
|
||||
self.main_progress_bar.close()
|
||||
# Update and close progress bar at epoch end
|
||||
if self.main_progress_bar is not None:
|
||||
self.main_progress_bar.update(self.main_progress_bar.total - self.main_progress_bar.n)
|
||||
self.main_progress_bar.close()
|
||||
|
||||
epoch = trainer.current_epoch + 1 # epochs start at 0 internally
|
||||
epoch = trainer.current_epoch + 1 # zero-based to one-based
|
||||
|
||||
# Fetch logged metrics safely
|
||||
train_loss = trainer.callback_metrics.get("train_loss")
|
||||
val_loss = trainer.callback_metrics.get("val_loss")
|
||||
val_acc = trainer.callback_metrics.get("val_acc")
|
||||
|
@ -160,7 +163,7 @@ def train_model():
|
|||
)
|
||||
)
|
||||
|
||||
checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath=checkpoint_dir,
|
||||
filename=checkpoint_filename,
|
||||
save_top_k=True,
|
||||
|
@ -180,7 +183,7 @@ def train_model():
|
|||
benchmark=True,
|
||||
precision="bf16-mixed",
|
||||
logger=False,
|
||||
enable_progress_bar=False,
|
||||
enable_progress_bar=True,
|
||||
)
|
||||
|
||||
if train_flag:
|
||||
|
|
Loading…
Reference in New Issue
Block a user