diff --git a/.riahub/workflows/workflow.yaml b/.riahub/workflows/workflow.yaml index 623c344..d9f7c27 100644 --- a/.riahub/workflows/workflow.yaml +++ b/.riahub/workflows/workflow.yaml @@ -59,7 +59,7 @@ jobs: - name: Upload Dataset Artifacts uses: actions/upload-artifact@v3 with: - name: ria-dataset + name: dataset path: data/dataset/** @@ -75,7 +75,7 @@ jobs: - name: Upload Checkpoints uses: actions/upload-artifact@v3 with: - name: ria-checkpoints + name: checkpoints path: checkpoint_files/inference_recognition_model.ckpt - name: 4. Convert to ONNX file diff --git a/onnx_scripts/convert_to_onnx.py b/onnx_scripts/convert_to_onnx.py index 4dfa4f5..0458b4f 100644 --- a/onnx_scripts/convert_to_onnx.py +++ b/onnx_scripts/convert_to_onnx.py @@ -3,8 +3,7 @@ import os import numpy as np import torch -from data.training.mobilenetv3 import mobilenetv3, RFClassifier -from onnx_files import ONNX_DIR +from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier from helpers.app_settings import get_app_settings @@ -48,12 +47,12 @@ def convert_to_onnx(ckpt_path, fp16=False): # Generate random sample data base, ext = os.path.splitext(os.path.basename(ckpt_path)) if fp16: - output_path = os.path.join(ONNX_DIR, f"{base}.onnx") + output_path = os.path.join("onnx_scripts", f"{base}.onnx") sample_input = torch.from_numpy( np.random.rand(batch_size, in_channels, slice_length).astype(np.float16) ) else: - output_path = os.path.join(ONNX_DIR, f"{base}.onnx") + output_path = os.path.join("onnx_scripts", f"{base}.onnx") sample_input = torch.rand(batch_size, in_channels, slice_length) torch.onnx.export( @@ -70,18 +69,17 @@ def convert_to_onnx(ckpt_path, fp16=False): if __name__ == "__main__": - from checkpoint_files import CHECKPOINTS_DIR settings = get_app_settings() - model_checkpoint = settings.training.checkpoint_filename + ".ckpt" + model_checkpoint = "inference_recognition_model.ckpt" print("Converting to ONNX...") convert_to_onnx( - ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False + ckpt_path=os.path.join("checkpoint_files", model_checkpoint), fp16=False ) - output_file = settings.inference.onnx_model_filename + ".onnx" + output_file = "convert_to_onnx.py" + ".onnx" - print("Conversion complete stored at: ", os.path.join(ONNX_DIR, output_file)) + print("Conversion complete stored at: ", os.path.join("onnx_scripts", output_file))