import os

import numpy as np
import torch

from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier
from helpers.app_settings import get_app_settings


def convert_to_onnx(ckpt_path, fp16=False):
    """
    Convert a PyTorch checkpoint to ONNX format.

    Parameters:
        ckpt_path (str): Path to the PyTorch checkpoint to convert.
        fp16 (bool): If True, convert the model and sample input to half precision.
    """
    settings = get_app_settings()
    inference_cfg = settings.inference
    dataset_cfg = settings.dataset

    in_channels = 2
    batch_size = 1
    slice_length = 1024 // dataset_cfg.num_slices
    num_classes = inference_cfg.num_classes

    model = RFClassifier(
        model=mobilenetv3(
            model_size="mobilenetv3_small_050",
            num_classes=num_classes,
            in_chans=in_channels,
        )
    )
    checkpoint = torch.load(
        ckpt_path, weights_only=True, map_location=torch.device("cpu")
    )
    model.load_state_dict(checkpoint["state_dict"])

    if fp16:
        model.half()
    model.eval()

    base, _ = os.path.splitext(os.path.basename(ckpt_path))
    output_path = os.path.join("onnx_scripts", f"{base}.onnx")

    # Generate random sample data matching the model's expected input shape
    if fp16:
        sample_input = torch.from_numpy(
            np.random.rand(batch_size, in_channels, slice_length).astype(np.float16)
        )
    else:
        sample_input = torch.rand(batch_size, in_channels, slice_length)

    torch.onnx.export(
        model=model,
        args=sample_input,
        f=output_path,
        export_params=True,
        opset_version=20,  # Last opset compatible with ORT v1.15.1 is 18; opset 21 not supported by torch export
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamo=False,  # The dynamo exporter requires onnxscript
    )


if __name__ == "__main__":
    model_checkpoint = "inference_recognition_model.ckpt"

    print("Converting to ONNX...")
    convert_to_onnx(
        ckpt_path=os.path.join("checkpoint_files", model_checkpoint), fp16=False
    )

    # The exported file is named after the checkpoint, not this script
    output_file = os.path.splitext(model_checkpoint)[0] + ".onnx"
    print("Conversion complete, stored at:", os.path.join("onnx_scripts", output_file))
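
    # Optional sanity check (a sketch, not part of the original script):
    # reload the exported graph with onnxruntime and run one dummy forward
    # pass. Assumes the `onnxruntime` package is installed and that the
    # fp16=False export above was used (hence the float32 dummy input).
    import onnxruntime as ort

    session = ort.InferenceSession(
        os.path.join("onnx_scripts", output_file),
        providers=["CPUExecutionProvider"],
    )
    input_meta = session.get_inputs()[0]
    # The export used a fixed sample shape, so input_meta.shape is concrete
    dummy = np.random.rand(*input_meta.shape).astype(np.float32)
    (onnx_out,) = session.run(None, {"input": dummy})
    print("ONNX sanity check passed, output shape:", onnx_out.shape)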