forked from qoherent/modrec-workflow
fixed training model error, was a extra argument passed into the callback class
This commit is contained in:
parent
065129b4ed
commit
5f0ac995c8
|
@ -31,29 +31,24 @@ import mobilenetv3
|
||||||
from lightning.pytorch.callbacks import TQDMProgressBar
|
from lightning.pytorch.callbacks import TQDMProgressBar
|
||||||
|
|
||||||
class EpochOnlyProgressBar(TQDMProgressBar):
|
class EpochOnlyProgressBar(TQDMProgressBar):
|
||||||
def on_train_start(self, trainer, pl_module):
|
def __init__(self):
|
||||||
super().on_train_start(trainer, pl_module)
|
super().__init__()
|
||||||
|
self.enable = True
|
||||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
|
||||||
# Disable batch progress updates
|
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||||
|
# Skip batch-level updates
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, pl_module):
|
def on_train_epoch_end(self, trainer, pl_module):
|
||||||
# Update and close progress bar at epoch end
|
epoch = trainer.current_epoch + 1
|
||||||
if self.main_progress_bar is not None:
|
|
||||||
self.main_progress_bar.update(self.main_progress_bar.total - self.main_progress_bar.n)
|
|
||||||
self.main_progress_bar.close()
|
|
||||||
|
|
||||||
epoch = trainer.current_epoch + 1 # zero-based to one-based
|
|
||||||
|
|
||||||
train_loss = trainer.callback_metrics.get("train_loss")
|
train_loss = trainer.callback_metrics.get("train_loss")
|
||||||
val_loss = trainer.callback_metrics.get("val_loss")
|
val_loss = trainer.callback_metrics.get("val_loss")
|
||||||
val_acc = trainer.callback_metrics.get("val_acc")
|
val_acc = trainer.callback_metrics.get("val_acc")
|
||||||
|
|
||||||
print(f"Epoch {epoch} Summary:")
|
print(f"\nEpoch {epoch}:")
|
||||||
print(f" Train Loss: {train_loss:.4f}" if train_loss is not None else " Train Loss: N/A")
|
if train_loss: print(f" Train Loss: {train_loss:.4f}")
|
||||||
print(f" Val Loss: {val_loss:.4f}" if val_loss is not None else " Val Loss: N/A")
|
if val_loss: print(f" Val Loss: {val_loss:.4f}")
|
||||||
print(f" Val Acc: {val_acc:.4f}" if val_acc is not None else " Val Acc: N/A")
|
if val_acc: print(f" Val Acc: {val_acc:.4f}")
|
||||||
print("-" * 30)
|
print("-" * 30)
|
||||||
|
|
||||||
|
|
||||||
|
@ -173,7 +168,7 @@ def train_model():
|
||||||
enable_version_counter=False,
|
enable_version_counter=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
progress_bar = EpochOnlyProgressBar(refresh_rate=1)
|
progress_bar = EpochOnlyProgressBar()
|
||||||
|
|
||||||
trainer = L.Trainer(
|
trainer = L.Trainer(
|
||||||
max_epochs=epochs,
|
max_epochs=epochs,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user