forked from qoherent/modrec-workflow
completed converting the pytorch model -> ONNX
This commit is contained in:
parent
ba796961a3
commit
3557f854e8
6
checkpoint_files/__init__.py
Normal file
6
checkpoint_files/__init__.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
import os
|
||||
|
||||
CHECKPOINTS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(CHECKPOINTS_DIR)
|
84
convert_to_onnx.py
Normal file
84
convert_to_onnx.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from data.training.mobilenetv3 import mobilenetv3, RFClassifier
|
||||
from onnx_files import ONNX_DIR
|
||||
from helpers.app_settings import get_app_settings
|
||||
|
||||
|
||||
|
||||
def convert_to_onnx(ckpt_path, fp16=False):
|
||||
"""
|
||||
Convert a PyTorch model to ONNX format.
|
||||
|
||||
Parameters:
|
||||
model (torch.nn.Module): The PyTorch model to convert.
|
||||
input_shape (tuple): The shape of the input tensor.
|
||||
output_path (str): The path to save the converted ONNX model.
|
||||
"""
|
||||
settings = get_app_settings()
|
||||
|
||||
inference_cfg = settings.inference
|
||||
dataset_cfg = settings.dataset
|
||||
|
||||
|
||||
in_channels = 2
|
||||
batch_size = 1
|
||||
slice_length = int(1024/dataset_cfg.num_slices)
|
||||
num_classes = inference_cfg.num_classes
|
||||
|
||||
model = RFClassifier(
|
||||
model=mobilenetv3(
|
||||
model_size="mobilenetv3_small_050",
|
||||
num_classes=num_classes,
|
||||
in_chans=in_channels,
|
||||
)
|
||||
)
|
||||
|
||||
checkpoint = torch.load(
|
||||
ckpt_path, weights_only = True, map_location=torch.device("cpu")
|
||||
)
|
||||
model.load_state_dict(checkpoint["state_dict"])
|
||||
|
||||
if fp16:
|
||||
model.half()
|
||||
|
||||
model.eval()
|
||||
|
||||
|
||||
# Generate random sample data
|
||||
base, ext = os.path.splitext(os.path.basename(ckpt_path))
|
||||
if fp16:
|
||||
output_path = os.path.join(ONNX_DIR, f"{base}_fp16.onnx")
|
||||
sample_input = torch.from_numpy(
|
||||
np.random.rand(batch_size, in_channels, slice_length).astype(np.float16)
|
||||
)
|
||||
else:
|
||||
output_path = os.path.join(ONNX_DIR, f"{base}_fp32.onnx")
|
||||
sample_input = torch.rand(batch_size, in_channels, slice_length)
|
||||
|
||||
|
||||
torch.onnx.export(
|
||||
model=model,
|
||||
args=sample_input,
|
||||
f=output_path,
|
||||
export_params=True,
|
||||
opset_version=20, # Last compatible with ORT v1.15.1 is 18, 21 not supported by torch export
|
||||
do_constant_folding=True,
|
||||
input_names=["input"],
|
||||
output_names=["output"],
|
||||
dynamo=False, # Requires onnxscript
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from checkpoint_files import CHECKPOINTS_DIR
|
||||
|
||||
model_checkpoint = "interference_recognition_model.ckpt"
|
||||
|
||||
print("Converting to ONNX...")
|
||||
|
||||
convert_to_onnx(ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False)
|
||||
|
||||
print("Conversion complete stored at: ", os.path.join(ONNX_DIR, model_checkpoint))
|
6
onnx_files/__init__.py
Normal file
6
onnx_files/__init__.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
import os
|
||||
|
||||
ONNX_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(ONNX_DIR)
|
Loading…
Reference in New Issue
Block a user