diff --git a/convert_to_onnx.py b/convert_to_onnx.py index 4616290..ab21541 100644 --- a/convert_to_onnx.py +++ b/convert_to_onnx.py @@ -48,12 +48,12 @@ def convert_to_onnx(ckpt_path, fp16=False): # Generate random sample data base, ext = os.path.splitext(os.path.basename(ckpt_path)) if fp16: - output_path = os.path.join(ONNX_DIR, f"{base}_fp16.onnx") + output_path = os.path.join(ONNX_DIR, f"{base}.onnx") sample_input = torch.from_numpy( np.random.rand(batch_size, in_channels, slice_length).astype(np.float16) ) else: - output_path = os.path.join(ONNX_DIR, f"{base}_fp32.onnx") + output_path = os.path.join(ONNX_DIR, f"{base}.onnx") sample_input = torch.rand(batch_size, in_channels, slice_length) torch.onnx.export(