forked from qoherent/modrec-workflow
fixed progress bar for training
This commit is contained in:
parent
c31a5eaf79
commit
3caf1d21f2
|
@ -26,6 +26,33 @@ from modulation_dataset import ModulationH5Dataset
|
|||
|
||||
import mobilenetv3
|
||||
|
||||
from lightning.pytorch.callbacks import TQDMProgressBar
|
||||
|
||||
from lightning.pytorch.callbacks import TQDMProgressBar
|
||||
|
||||
class EpochOnlyProgressBar(TQDMProgressBar):
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
# Disable batch progress updates (no updates on each batch)
|
||||
pass
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
# Complete and close progress bar for the epoch
|
||||
self.main_progress_bar.update(self.main_progress_bar.total - self.main_progress_bar.n)
|
||||
self.main_progress_bar.close()
|
||||
|
||||
epoch = trainer.current_epoch + 1 # epochs start at 0 internally
|
||||
|
||||
# Fetch logged metrics safely
|
||||
train_loss = trainer.callback_metrics.get("train_loss")
|
||||
val_loss = trainer.callback_metrics.get("val_loss")
|
||||
val_acc = trainer.callback_metrics.get("val_acc")
|
||||
|
||||
print(f"Epoch {epoch} Summary:")
|
||||
print(f" Train Loss: {train_loss:.4f}" if train_loss is not None else " Train Loss: N/A")
|
||||
print(f" Val Loss: {val_loss:.4f}" if val_loss is not None else " Val Loss: N/A")
|
||||
print(f" Val Acc: {val_acc:.4f}" if val_acc is not None else " Val Acc: N/A")
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
def train_model():
|
||||
settings = get_app_settings()
|
||||
|
@ -142,15 +169,18 @@ def train_model():
|
|||
mode="max",
|
||||
enable_version_counter=False,
|
||||
)
|
||||
|
||||
|
||||
progress_bar = EpochOnlyProgressBar(refresh_rate=1)
|
||||
|
||||
trainer = L.Trainer(
|
||||
max_epochs=epochs,
|
||||
callbacks=[checkpoint_callback],
|
||||
callbacks=[checkpoint_callback,progress_bar],
|
||||
accelerator="gpu",
|
||||
devices=1,
|
||||
benchmark=True,
|
||||
precision="bf16-mixed",
|
||||
logger=False,
|
||||
enable_progress_bar=False,
|
||||
)
|
||||
|
||||
if train_flag:
|
||||
|
|
Loading…
Reference in New Issue
Block a user