This commit is contained in:
parent
6742b0770b
commit
ce20909ebe
|
@ -59,7 +59,7 @@ jobs:
|
||||||
- name: Upload Dataset Artifacts
|
- name: Upload Dataset Artifacts
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v3
|
||||||
with:
|
with:
|
||||||
name: ria-dataset
|
name: dataset
|
||||||
path: data/dataset/**
|
path: data/dataset/**
|
||||||
|
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ jobs:
|
||||||
- name: Upload Checkpoints
|
- name: Upload Checkpoints
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v3
|
||||||
with:
|
with:
|
||||||
name: ria-checkpoints
|
name: checkpoints
|
||||||
path: checkpoint_files/inference_recognition_model.ckpt
|
path: checkpoint_files/inference_recognition_model.ckpt
|
||||||
|
|
||||||
- name: 4. Convert to ONNX file
|
- name: 4. Convert to ONNX file
|
||||||
|
|
|
@ -3,8 +3,7 @@ import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from data.training.mobilenetv3 import mobilenetv3, RFClassifier
|
from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier
|
||||||
from onnx_files import ONNX_DIR
|
|
||||||
from helpers.app_settings import get_app_settings
|
from helpers.app_settings import get_app_settings
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,12 +47,12 @@ def convert_to_onnx(ckpt_path, fp16=False):
|
||||||
# Generate random sample data
|
# Generate random sample data
|
||||||
base, ext = os.path.splitext(os.path.basename(ckpt_path))
|
base, ext = os.path.splitext(os.path.basename(ckpt_path))
|
||||||
if fp16:
|
if fp16:
|
||||||
output_path = os.path.join(ONNX_DIR, f"{base}.onnx")
|
output_path = os.path.join("onnx_scripts", f"{base}.onnx")
|
||||||
sample_input = torch.from_numpy(
|
sample_input = torch.from_numpy(
|
||||||
np.random.rand(batch_size, in_channels, slice_length).astype(np.float16)
|
np.random.rand(batch_size, in_channels, slice_length).astype(np.float16)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output_path = os.path.join(ONNX_DIR, f"{base}.onnx")
|
output_path = os.path.join("onnx_scripts", f"{base}.onnx")
|
||||||
sample_input = torch.rand(batch_size, in_channels, slice_length)
|
sample_input = torch.rand(batch_size, in_channels, slice_length)
|
||||||
|
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
|
@ -70,18 +69,17 @@ def convert_to_onnx(ckpt_path, fp16=False):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from checkpoint_files import CHECKPOINTS_DIR
|
|
||||||
|
|
||||||
settings = get_app_settings()
|
settings = get_app_settings()
|
||||||
|
|
||||||
model_checkpoint = settings.training.checkpoint_filename + ".ckpt"
|
model_checkpoint = "inference_recognition_model.ckpt"
|
||||||
|
|
||||||
print("Converting to ONNX...")
|
print("Converting to ONNX...")
|
||||||
|
|
||||||
convert_to_onnx(
|
convert_to_onnx(
|
||||||
ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False
|
ckpt_path=os.path.join("checkpoint_files", model_checkpoint), fp16=False
|
||||||
)
|
)
|
||||||
|
|
||||||
output_file = settings.inference.onnx_model_filename + ".onnx"
|
output_file = "convert_to_onnx.py" + ".onnx"
|
||||||
|
|
||||||
print("Conversion complete stored at: ", os.path.join(ONNX_DIR, output_file))
|
print("Conversion complete stored at: ", os.path.join("onnx_scripts", output_file))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user