diff --git a/.riahub/workflows/workflow.yaml b/.riahub/workflows/workflow.yaml index 4c71765..8cb64a2 100644 --- a/.riahub/workflows/workflow.yaml +++ b/.riahub/workflows/workflow.yaml @@ -46,7 +46,6 @@ jobs: - name: 2. Train Model run: | - mkdir -p data/dataset PYTHONPATH=. python data/training/train.py echo "training model" diff --git a/conf/app.yaml b/conf/app.yaml index 9010e0c..3c5cb2c 100644 --- a/conf/app.yaml +++ b/conf/app.yaml @@ -15,7 +15,8 @@ training: batch_size: 64 epochs: 50 learning_rate: 0.001 - checkpoint_path: checkpoints/inference_recognition_model.ckpt + checkpoint_dir: checkpoints + checkpoint_filename: inference_recognition_model use_gpu: true inference: diff --git a/data/training/train.py b/data/training/train.py index 9184cd4..9943aba 100644 --- a/data/training/train.py +++ b/data/training/train.py @@ -33,7 +33,8 @@ def train_model(): batch_size = 128 epochs = 1 - checkpoint_filename = f"{training_cfg.checkpoint_path}" + checkpoint_dir = training_cfg.checkpoint_dir + checkpoint_filename = training_cfg.checkpoint_filename train_data = ( f"{dataset_cfg.output_dir}/train.h5" @@ -130,6 +131,7 @@ def train_model(): ) checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint( + dirpath=checkpoint_dir, filename=checkpoint_filename, save_top_k=True, verbose=True, diff --git a/helpers/app_settings.py b/helpers/app_settings.py index 4b593b0..ee52c47 100644 --- a/helpers/app_settings.py +++ b/helpers/app_settings.py @@ -26,7 +26,8 @@ class TrainingConfig: batch_size: int epochs: int learning_rate: float - checkpoint_path: str + checkpoint_dir: str + checkpoint_filename: str use_gpu: bool