fixed progress bar for training
Some checks failed
RIA Hub Workflow Demo / ria-demo (push) Failing after 45s
Some checks failed
RIA Hub Workflow Demo / ria-demo (push) Failing after 45s
This commit is contained in:
parent
c31a5eaf79
commit
3caf1d21f2
|
@ -26,6 +26,33 @@ from modulation_dataset import ModulationH5Dataset
|
||||||
|
|
||||||
import mobilenetv3
|
import mobilenetv3
|
||||||
|
|
||||||
|
from lightning.pytorch.callbacks import TQDMProgressBar
|
||||||
|
|
||||||
|
from lightning.pytorch.callbacks import TQDMProgressBar
|
||||||
|
|
||||||
|
class EpochOnlyProgressBar(TQDMProgressBar):
    """TQDM progress bar that reports once per epoch instead of once per batch.

    Per-batch refreshes are suppressed; at the end of every training epoch the
    bar is filled to its total, closed, and a short text summary of the logged
    ``train_loss``, ``val_loss`` and ``val_acc`` metrics is printed.
    """

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Skip the per-batch bar update performed by the parent class.

        ``dataloader_idx`` is optional because the training-batch hook is
        invoked without it by Lightning releases that ship the
        ``lightning.pytorch`` namespace; a required positional parameter here
        would raise ``TypeError`` on the first batch. Callers that still pass
        it positionally keep working.
        """
        # Intentionally empty: no progress output on individual batches.

    def on_train_epoch_end(self, trainer, pl_module):
        """Finish the epoch's bar and print a summary of the logged metrics.

        Args:
            trainer: The running trainer; ``current_epoch`` and
                ``callback_metrics`` are read from it.
            pl_module: The module being trained (unused).
        """
        bar = self._resolve_train_bar()
        if bar is not None:
            # ``total`` may be None when the number of batches is unknown;
            # in that case there is nothing meaningful to fill up to.
            if bar.total is not None:
                remaining = bar.total - bar.n
                if remaining > 0:
                    bar.update(remaining)
            # NOTE(review): tqdm marks a bar as disabled once closed, so unless
            # the parent re-creates it at the next epoch start, later epochs
            # will show only the printed summary — confirm this is intended.
            bar.close()

        epoch = trainer.current_epoch + 1  # epochs start at 0 internally

        # ``.get`` returns None for metrics that were never logged.
        metrics = trainer.callback_metrics
        train_loss = self._format_metric(metrics.get("train_loss"))
        val_loss = self._format_metric(metrics.get("val_loss"))
        val_acc = self._format_metric(metrics.get("val_acc"))

        print(f"Epoch {epoch} Summary:")
        print(f" Train Loss: {train_loss}")
        print(f" Val Loss: {val_loss}")
        print(f" Val Acc: {val_acc}")
        print("-" * 30)

    def _resolve_train_bar(self):
        """Return the training tqdm bar, or None if it is not available.

        The parent class has exposed the bar under different attribute names
        across releases (``train_progress_bar`` / ``main_progress_bar``), and
        the property may raise while the bar has not been created yet, so both
        names are tried and lookup failures are tolerated.
        """
        for attr_name in ("train_progress_bar", "main_progress_bar"):
            try:
                bar = getattr(self, attr_name)
            except (AttributeError, TypeError):
                continue
            if bar is not None:
                return bar
        return None

    @staticmethod
    def _format_metric(value):
        """Format a scalar metric to four decimals, or ``N/A`` when missing."""
        if value is None:
            return "N/A"
        return f"{float(value):.4f}"
|
||||||
|
|
||||||
|
|
||||||
def train_model():
|
def train_model():
|
||||||
settings = get_app_settings()
|
settings = get_app_settings()
|
||||||
|
@ -143,14 +170,17 @@ def train_model():
|
||||||
enable_version_counter=False,
|
enable_version_counter=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
progress_bar = EpochOnlyProgressBar(refresh_rate=1)
|
||||||
|
|
||||||
trainer = L.Trainer(
|
trainer = L.Trainer(
|
||||||
max_epochs=epochs,
|
max_epochs=epochs,
|
||||||
callbacks=[checkpoint_callback],
|
callbacks=[checkpoint_callback,progress_bar],
|
||||||
accelerator="gpu",
|
accelerator="gpu",
|
||||||
devices=1,
|
devices=1,
|
||||||
benchmark=True,
|
benchmark=True,
|
||||||
precision="bf16-mixed",
|
precision="bf16-mixed",
|
||||||
logger=False,
|
logger=False,
|
||||||
|
enable_progress_bar=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if train_flag:
|
if train_flag:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user