forked from qoherent/modrec-workflow
fixed config issue
This commit is contained in:
parent
85d6afd976
commit
c2a71605f8
|
@ -46,7 +46,6 @@ jobs:
|
|||
|
||||
- name: 2. Train Model
|
||||
run: |
|
||||
mkdir -p data/dataset
|
||||
PYTHONPATH=. python data/training/train.py
|
||||
echo "training model"
|
||||
|
||||
|
|
|
@ -15,7 +15,8 @@ training:
|
|||
batch_size: 64
|
||||
epochs: 50
|
||||
learning_rate: 0.001
|
||||
checkpoint_path: checkpoints/inference_recognition_model.ckpt
|
||||
checkpoint_dir: checkpoints
|
||||
checkpoint_filename: inference_recognition_model
|
||||
use_gpu: true
|
||||
|
||||
inference:
|
||||
|
|
|
@ -33,7 +33,8 @@ def train_model():
|
|||
batch_size = 128
|
||||
epochs = 1
|
||||
|
||||
checkpoint_filename = f"{training_cfg.checkpoint_path}"
|
||||
checkpoint_dir = training_cfg.checkpoint_dir
|
||||
checkpoint_filename = training_cfg.checkpoint_filename
|
||||
|
||||
train_data = (
|
||||
f"{dataset_cfg.output_dir}/train.h5"
|
||||
|
@ -130,6 +131,7 @@ def train_model():
|
|||
)
|
||||
|
||||
checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
|
||||
dirpath=checkpoint_dir,
|
||||
filename=checkpoint_filename,
|
||||
save_top_k=True,
|
||||
verbose=True,
|
||||
|
|
|
@ -26,7 +26,8 @@ class TrainingConfig:
|
|||
batch_size: int
|
||||
epochs: int
|
||||
learning_rate: float
|
||||
checkpoint_path: str
|
||||
checkpoint_dir: str
|
||||
checkpoint_filename: str
|
||||
use_gpu: bool
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user