forked from qoherent/modrec-workflow
updated so that the progress bar updates in place
This commit is contained in:
parent
1e7fdab60c
commit
2a4ecd175a
|
@ -13,8 +13,11 @@ data_dir = os.path.abspath(os.path.join(script_dir, ".."))
|
||||||
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
|
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
|
||||||
if project_root not in sys.path:
|
if project_root not in sys.path:
|
||||||
sys.path.insert(0, project_root)
|
sys.path.insert(0, project_root)
|
||||||
|
from lightning.pytorch.callbacks import TQDMProgressBar
|
||||||
|
|
||||||
|
class CustomProgressBar(TQDMProgressBar):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(refresh_rate=1) # update every batch
|
||||||
|
|
||||||
|
|
||||||
def train_model():
|
def train_model():
|
||||||
|
@ -132,7 +135,7 @@ def train_model():
|
||||||
|
|
||||||
trainer = L.Trainer(
|
trainer = L.Trainer(
|
||||||
max_epochs=epochs,
|
max_epochs=epochs,
|
||||||
callbacks=[checkpoint_callback],
|
callbacks=[checkpoint_callback, CustomProgressBar()],
|
||||||
accelerator="gpu",
|
accelerator="gpu",
|
||||||
devices=1,
|
devices=1,
|
||||||
benchmark=True,
|
benchmark=True,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user