forked from qoherent/modrec-workflow
Added Training bar
This commit is contained in:
parent
a092b92174
commit
92a0ed11e4
|
@ -15,26 +15,6 @@ if project_root not in sys.path:
|
||||||
sys.path.insert(0, project_root)
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
|
||||||
class CleanProgressCallback(Callback):
|
|
||||||
"""Clean progress callback that only shows epoch summaries"""
|
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, pl_module):
|
|
||||||
epoch = trainer.current_epoch + 1
|
|
||||||
|
|
||||||
# Get metrics
|
|
||||||
train_loss = trainer.callback_metrics.get("train_loss")
|
|
||||||
val_loss = trainer.callback_metrics.get("val_loss")
|
|
||||||
val_acc = trainer.callback_metrics.get("val_acc")
|
|
||||||
|
|
||||||
# Print clean output
|
|
||||||
print(f"Epoch {epoch}:")
|
|
||||||
if train_loss is not None:
|
|
||||||
print(f" Train Loss: {train_loss:.4f}")
|
|
||||||
if val_loss is not None:
|
|
||||||
print(f" Val Loss: {val_loss:.4f}")
|
|
||||||
if val_acc is not None:
|
|
||||||
print(f" Val Acc: {val_acc:.4f}")
|
|
||||||
print("-" * 30)
|
|
||||||
|
|
||||||
|
|
||||||
def train_model():
|
def train_model():
|
||||||
|
@ -149,12 +129,10 @@ def train_model():
|
||||||
enable_version_counter=False,
|
enable_version_counter=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
clean_progress = CleanProgressCallback()
|
|
||||||
|
|
||||||
trainer = L.Trainer(
|
trainer = L.Trainer(
|
||||||
max_epochs=epochs,
|
max_epochs=epochs,
|
||||||
callbacks=[checkpoint_callback, clean_progress],
|
callbacks=[checkpoint_callback],
|
||||||
accelerator="gpu",
|
accelerator="cpu",
|
||||||
devices=1,
|
devices=1,
|
||||||
benchmark=True,
|
benchmark=True,
|
||||||
precision="bf16-mixed",
|
precision="bf16-mixed",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user