2025-05-22 14:12:10 -04:00
|
|
|
import os
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
2025-07-08 10:50:41 -04:00
|
|
|
from scripts.model_builder.mobilenetv3 import RFClassifier, mobilenetv3
|
2025-05-22 14:12:10 -04:00
|
|
|
|
|
|
|
from helpers.app_settings import get_app_settings
|
|
|
|
|
|
|
|
|
2025-06-18 13:44:29 -04:00
|
|
|
def convert_to_onnx(
    ckpt_path: str,
    fp16: bool = False,
    output_dir: str = "onnx_files",
    opset_version: int = 20,
) -> str:
    """
    Convert a PyTorch model checkpoint to ONNX format.

    Builds an ``RFClassifier`` wrapping a ``mobilenetv3_small_050`` backbone
    sized from the app's dataset settings. It then loads the weights stored
    under the checkpoint's ``"state_dict"`` key. Finally, it exports the model
    with ``torch.onnx.export``, tracing it with a single random sample.

    Parameters:
        ckpt_path (str): Path of the PyTorch checkpoint to load. The ONNX file
            is named after this file's basename, with the extension replaced
            by ``.onnx``.
        fp16 (bool): If True, cast the model and the sample input to 16-bit
            floating point before exporting.
        output_dir (str): Directory the ONNX file is written to. It is created
            if it does not exist. Defaults to ``"onnx_files"``.
        opset_version (int): ONNX opset to target. Defaults to 20.

    Returns:
        str: Path of the written ONNX file.
    """
    dataset_cfg = get_app_settings().dataset

    in_channels = 2
    batch_size = 1
    # NOTE(review): 1024 looks like the total sample length split into
    # ``num_slices`` equal slices; confirm against the dataset pipeline.
    slice_length = int(1024 / dataset_cfg.num_slices)
    num_classes = len(dataset_cfg.modulation_types)

    model = RFClassifier(
        model=mobilenetv3(
            model_size="mobilenetv3_small_050",
            num_classes=num_classes,
            in_chans=in_channels,
        )
    )

    # The model is constructed on the CPU and never moved, so load the weights
    # straight onto the CPU. Mapping them to CUDA first would only add a GPU
    # allocation that load_state_dict then copies back to host memory.
    checkpoint = torch.load(ckpt_path, weights_only=True, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])

    # Random tracing input; the values are irrelevant, only shape/dtype matter.
    sample_input = torch.rand(batch_size, in_channels, slice_length)
    if fp16:
        model.half()
        sample_input = sample_input.half()

    model.eval()

    os.makedirs(output_dir, exist_ok=True)
    base = os.path.splitext(os.path.basename(ckpt_path))[0]
    output_path = os.path.join(output_dir, f"{base}.onnx")
    print(output_path)

    torch.onnx.export(
        model=model,
        args=sample_input,
        f=output_path,
        export_params=True,
        # NOTE(review): an earlier note says opset 18 is the last one compatible
        # with ONNX Runtime 1.15.1 and that torch export does not support 21.
        # Pass opset_version=18 when targeting ORT 1.15.1.
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        # Use the TorchScript-based exporter; dynamo=True requires onnxscript.
        dynamo=False,
    )
    return output_path
|
2025-05-22 14:12:36 -04:00
|
|
|
|
|
|
|
|
2025-05-22 14:12:10 -04:00
|
|
|
if __name__ == "__main__":
    model_checkpoint = "inference_recognition_model.ckpt"

    print("Converting to ONNX...")
    convert_to_onnx(
        ckpt_path=os.path.join("checkpoint_files", model_checkpoint), fp16=False
    )

    # convert_to_onnx writes "<checkpoint basename without extension>.onnx"
    # into onnx_files/. Derive the same name here rather than hard-coding a
    # name that does not match the file actually written.
    output_file = os.path.splitext(model_checkpoint)[0] + ".onnx"
    print("Conversion complete stored at: ", os.path.join("onnx_files", output_file))
|