diff --git a/data/training/train.py b/data/training/train.py index 1d15968..ad7015b 100644 --- a/data/training/train.py +++ b/data/training/train.py @@ -13,8 +13,11 @@ data_dir = os.path.abspath(os.path.join(script_dir, "..")) project_root = os.path.abspath(os.path.join(os.getcwd(), "..")) if project_root not in sys.path: sys.path.insert(0, project_root) +from lightning.pytorch.callbacks import TQDMProgressBar - +class CustomProgressBar(TQDMProgressBar): + def __init__(self): + super().__init__(refresh_rate=1) # update every batch def train_model(): @@ -132,7 +135,7 @@ def train_model(): trainer = L.Trainer( max_epochs=epochs, - callbacks=[checkpoint_callback], + callbacks=[checkpoint_callback, CustomProgressBar()], accelerator="gpu", devices=1, benchmark=True,