forked from qoherent/modrec-workflow
updated so that the progress bar updates in place
This commit is contained in:
parent
1e7fdab60c
commit
2a4ecd175a
|
@ -13,8 +13,11 @@ data_dir = os.path.abspath(os.path.join(script_dir, ".."))
|
|||
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
from lightning.pytorch.callbacks import TQDMProgressBar
|
||||
|
||||
|
||||
class CustomProgressBar(TQDMProgressBar):
|
||||
def __init__(self):
|
||||
super().__init__(refresh_rate=1) # update every batch
|
||||
|
||||
|
||||
def train_model():
|
||||
|
@ -132,7 +135,7 @@ def train_model():
|
|||
|
||||
trainer = L.Trainer(
|
||||
max_epochs=epochs,
|
||||
callbacks=[checkpoint_callback],
|
||||
callbacks=[checkpoint_callback, CustomProgressBar()],
|
||||
accelerator="gpu",
|
||||
devices=1,
|
||||
benchmark=True,
|
||||
|
|
Loading…
Reference in New Issue
Block a user