From 3557f854e8c59fefc64775069077f601d01e923e Mon Sep 17 00:00:00 2001 From: liyuxiao2 Date: Thu, 22 May 2025 14:12:10 -0400 Subject: [PATCH] completed converting the pytorch model -> ONNX --- checkpoint_files/__init__.py | 6 +++ convert_to_onnx.py | 84 ++++++++++++++++++++++++++++++++++++ onnx_files/__init__.py | 6 +++ 3 files changed, 96 insertions(+) create mode 100644 checkpoint_files/__init__.py create mode 100644 convert_to_onnx.py create mode 100644 onnx_files/__init__.py diff --git a/checkpoint_files/__init__.py b/checkpoint_files/__init__.py new file mode 100644 index 0000000..0123ec0 --- /dev/null +++ b/checkpoint_files/__init__.py @@ -0,0 +1,6 @@ +import os + +CHECKPOINTS_DIR = os.path.dirname(os.path.abspath(__file__)) + +if __name__ == "__main__": + print(CHECKPOINTS_DIR) \ No newline at end of file diff --git a/convert_to_onnx.py b/convert_to_onnx.py new file mode 100644 index 0000000..dc6b6b9 --- /dev/null +++ b/convert_to_onnx.py @@ -0,0 +1,84 @@ +import os + +import numpy as np +import torch + +from data.training.mobilenetv3 import mobilenetv3, RFClassifier +from onnx_files import ONNX_DIR +from helpers.app_settings import get_app_settings + + + +def convert_to_onnx(ckpt_path, fp16=False): + """ + Convert a PyTorch model to ONNX format. + + Parameters: + model (torch.nn.Module): The PyTorch model to convert. + input_shape (tuple): The shape of the input tensor. + output_path (str): The path to save the converted ONNX model. + """ + settings = get_app_settings() + + inference_cfg = settings.inference + dataset_cfg = settings.dataset + + + in_channels = 2 + batch_size = 1 + slice_length = int(1024/dataset_cfg.num_slices) + num_classes = inference_cfg.num_classes + + model = RFClassifier( + model=mobilenetv3( + model_size="mobilenetv3_small_050", + num_classes=num_classes, + in_chans=in_channels, + ) + ) + + checkpoint = torch.load( + ckpt_path, weights_only = True, map_location=torch.device("cpu") + ) + model.load_state_dict(checkpoint["state_dict"]) + + if fp16: + model.half() + + model.eval() + + + # Generate random sample data + base, ext = os.path.splitext(os.path.basename(ckpt_path)) + if fp16: + output_path = os.path.join(ONNX_DIR, f"{base}_fp16.onnx") + sample_input = torch.from_numpy( + np.random.rand(batch_size, in_channels, slice_length).astype(np.float16) + ) + else: + output_path = os.path.join(ONNX_DIR, f"{base}_fp32.onnx") + sample_input = torch.rand(batch_size, in_channels, slice_length) + + + torch.onnx.export( + model=model, + args=sample_input, + f=output_path, + export_params=True, + opset_version=20, # Last compatible with ORT v1.15.1 is 18, 21 not supported by torch export + do_constant_folding=True, + input_names=["input"], + output_names=["output"], + dynamo=False, # Requires onnxscript + ) + +if __name__ == "__main__": + from checkpoint_files import CHECKPOINTS_DIR + + model_checkpoint = "interference_recognition_model.ckpt" + + print("Converting to ONNX...") + + convert_to_onnx(ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False) + + print("Conversion complete stored at: ", os.path.join(ONNX_DIR, model_checkpoint)) \ No newline at end of file diff --git a/onnx_files/__init__.py b/onnx_files/__init__.py new file mode 100644 index 0000000..0fdddc6 --- /dev/null +++ b/onnx_files/__init__.py @@ -0,0 +1,6 @@ +import os + +ONNX_DIR = os.path.dirname(os.path.abspath(__file__)) + +if __name__ == "__main__": + print(ONNX_DIR)