forked from qoherent/modrec-workflow
fixed training errors inside the script
This commit is contained in:
parent
3caf1d21f2
commit
065129b4ed
|
@ -17,6 +17,8 @@ if project_root not in sys.path:
|
||||||
sys.path.insert(0, project_root)
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
|
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
|
@ -28,21 +30,22 @@ import mobilenetv3
|
||||||
|
|
||||||
from lightning.pytorch.callbacks import TQDMProgressBar
|
from lightning.pytorch.callbacks import TQDMProgressBar
|
||||||
|
|
||||||
from lightning.pytorch.callbacks import TQDMProgressBar
|
|
||||||
|
|
||||||
class EpochOnlyProgressBar(TQDMProgressBar):
|
class EpochOnlyProgressBar(TQDMProgressBar):
|
||||||
|
def on_train_start(self, trainer, pl_module):
|
||||||
|
super().on_train_start(trainer, pl_module)
|
||||||
|
|
||||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||||
# Disable batch progress updates (no updates on each batch)
|
# Disable batch progress updates
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, pl_module):
|
def on_train_epoch_end(self, trainer, pl_module):
|
||||||
# Complete and close progress bar for the epoch
|
# Update and close progress bar at epoch end
|
||||||
self.main_progress_bar.update(self.main_progress_bar.total - self.main_progress_bar.n)
|
if self.main_progress_bar is not None:
|
||||||
self.main_progress_bar.close()
|
self.main_progress_bar.update(self.main_progress_bar.total - self.main_progress_bar.n)
|
||||||
|
self.main_progress_bar.close()
|
||||||
|
|
||||||
epoch = trainer.current_epoch + 1 # epochs start at 0 internally
|
epoch = trainer.current_epoch + 1 # zero-based to one-based
|
||||||
|
|
||||||
# Fetch logged metrics safely
|
|
||||||
train_loss = trainer.callback_metrics.get("train_loss")
|
train_loss = trainer.callback_metrics.get("train_loss")
|
||||||
val_loss = trainer.callback_metrics.get("val_loss")
|
val_loss = trainer.callback_metrics.get("val_loss")
|
||||||
val_acc = trainer.callback_metrics.get("val_acc")
|
val_acc = trainer.callback_metrics.get("val_acc")
|
||||||
|
@ -160,7 +163,7 @@ def train_model():
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
|
checkpoint_callback = ModelCheckpoint(
|
||||||
dirpath=checkpoint_dir,
|
dirpath=checkpoint_dir,
|
||||||
filename=checkpoint_filename,
|
filename=checkpoint_filename,
|
||||||
save_top_k=True,
|
save_top_k=True,
|
||||||
|
@ -180,7 +183,7 @@ def train_model():
|
||||||
benchmark=True,
|
benchmark=True,
|
||||||
precision="bf16-mixed",
|
precision="bf16-mixed",
|
||||||
logger=False,
|
logger=False,
|
||||||
enable_progress_bar=False,
|
enable_progress_bar=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if train_flag:
|
if train_flag:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user