updated so that the progress bar updates in place

This commit is contained in:
Liyu Xiao 2025-05-26 14:57:34 -04:00
parent 1e7fdab60c
commit 2a4ecd175a

View File

@ -13,8 +13,11 @@ data_dir = os.path.abspath(os.path.join(script_dir, ".."))
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from lightning.pytorch.callbacks import TQDMProgressBar
class CustomProgressBar(TQDMProgressBar):
def __init__(self):
super().__init__(refresh_rate=1) # update every batch
def train_model():
@ -132,7 +135,7 @@ def train_model():
trainer = L.Trainer(
max_epochs=epochs,
callbacks=[checkpoint_callback],
callbacks=[checkpoint_callback, CustomProgressBar()],
accelerator="gpu",
devices=1,
benchmark=True,