diff --git a/data/training/train.py b/data/training/train.py index 6c94d07..6d76019 100644 --- a/data/training/train.py +++ b/data/training/train.py @@ -118,12 +118,13 @@ def train_model(): drop_path_rate=hparams["drop_path_rate"], ) ) + checkpoint_callback = ModelCheckpoint( dirpath=checkpoint_dir, filename=checkpoint_filename, save_top_k=1, - verbose=False, # Disable checkpoint verbose output + verbose=True, # Enable checkpoint verbose output monitor="val_acc", mode="max", enable_version_counter=False, @@ -132,10 +133,10 @@ def train_model(): trainer = L.Trainer( max_epochs=epochs, callbacks=[checkpoint_callback], - accelerator="cpu", + accelerator="gpu", devices=1, benchmark=True, - precision="bf16-mixed", + precision="16-mixed", logger=False, enable_progress_bar=False, # Disable all progress bars enable_model_summary=False, # Disable model summary