modrec-workflow/scripts/application_packager/convert_to_onnx.py
Michael Luciuk 53d0552fd4
Some checks failed
Modulation Recognition Demo / ria-demo (pull_request) Failing after 22s
Removing some unused code and some shadowing
2025-07-07 12:19:34 -04:00

82 lines
2.3 KiB
Python

import os
import numpy as np
import torch
from scripts.training.mobilenetv3 import RFClassifier, mobilenetv3
from helpers.app_settings import get_app_settings
def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> str:
    """
    Convert a trained PyTorch checkpoint to ONNX format.

    The model is rebuilt as ``RFClassifier(mobilenetv3("mobilenetv3_small_050"))``
    with the number of classes taken from ``dataset.modulation_types`` and the
    input length derived from ``dataset.num_slices`` in the application settings.
    The exported file is written to
    ``onnx_files/<checkpoint basename without extension>.onnx``.

    Parameters:
        ckpt_path (str): Path of the checkpoint file to load. Its
            ``"state_dict"`` entry must hold the model weights.
        fp16 (bool): If True, convert the model and the sample input to
            16-bit floating point before exporting.

    Returns:
        str: Path of the written ONNX file.
    """
    settings = get_app_settings()
    dataset_cfg = settings.dataset

    in_channels = 2  # presumably the I and Q components of the signal — TODO confirm
    batch_size = 1
    slice_length = int(1024 / dataset_cfg.num_slices)
    num_classes = len(dataset_cfg.modulation_types)

    model = RFClassifier(
        model=mobilenetv3(
            model_size="mobilenetv3_small_050",
            num_classes=num_classes,
            in_chans=in_channels,
        )
    )

    # Keep the model, its weights and the tracing input on one device.
    # Half-precision CPU kernels may be missing for some ops in some PyTorch
    # versions, so tracing on the GPU (when present) is the safer choice.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    model.to(device)

    dtype = torch.float16 if fp16 else torch.float32
    if fp16:
        model.half()
    model.eval()

    # Random sample input, used only to trace the graph; its values do not matter.
    sample_input = torch.rand(batch_size, in_channels, slice_length).to(
        device=device, dtype=dtype
    )

    base = os.path.splitext(os.path.basename(ckpt_path))[0]
    output_dir = "onnx_files"
    os.makedirs(output_dir, exist_ok=True)  # export fails if the directory is missing
    # NOTE: fp16 and fp32 exports share the same filename, so one overwrites the other.
    output_path = os.path.join(output_dir, f"{base}.onnx")
    print(output_path)

    torch.onnx.export(
        model=model,
        args=sample_input,
        f=output_path,
        export_params=True,
        # ONNX Runtime v1.15.1 supports opsets only up to 18; opset 21 is not
        # supported by the torch exporter.
        opset_version=20,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamo=False,  # the dynamo-based exporter would require onnxscript
    )
    return output_path
if __name__ == "__main__":
    # Convert the default inference checkpoint to an fp32 ONNX model.
    model_checkpoint = "inference_recognition_model.ckpt"
    print("Converting to ONNX...")
    convert_to_onnx(
        ckpt_path=os.path.join("checkpoint_files", model_checkpoint), fp16=False
    )
    # convert_to_onnx names the output after the checkpoint file (minus its
    # extension), so report that path rather than one named after this script.
    output_file = os.path.splitext(model_checkpoint)[0] + ".onnx"
    print("Conversion complete stored at: ", os.path.join("onnx_files", output_file))