diff --git a/.riahub/workflows/workflow.yaml b/.riahub/workflows/workflow.yaml index d9f7c27..6fd12f9 100644 --- a/.riahub/workflows/workflow.yaml +++ b/.riahub/workflows/workflow.yaml @@ -76,7 +76,7 @@ jobs: uses: actions/upload-artifact@v3 with: name: checkpoints - path: checkpoint_files/inference_recognition_model.ckpt + path: checkpoint_files/* - name: 4. Convert to ONNX file run: | diff --git a/scripts/training/train.py b/scripts/training/train.py index 2107b87..e6194ce 100644 --- a/scripts/training/train.py +++ b/scripts/training/train.py @@ -1,7 +1,7 @@ import sys, os os.environ["NNPACK"] = "0" import lightning as L -from lightning.pytorch.callbacks import ModelCheckpoint, Callback +from lightning.pytorch.callbacks import ModelCheckpoint import torch import torch.nn.functional as F import torchmetrics @@ -23,8 +23,6 @@ class CustomProgressBar(TQDMProgressBar): def train_model(): settings = get_app_settings() training_cfg = settings.training - dataset_cfg = settings.dataset - train_flag = True batch_size = training_cfg.batch_size epochs = training_cfg.epochs