removed fp16, fp32 from the file names

This commit is contained in:
Liyu Xiao 2025-05-22 15:08:31 -04:00
parent c6ad74b5df
commit 123cb82334

View File

@ -48,12 +48,12 @@ def convert_to_onnx(ckpt_path, fp16=False):
# Generate random sample data
base, ext = os.path.splitext(os.path.basename(ckpt_path))
if fp16:
output_path = os.path.join(ONNX_DIR, f"{base}_fp16.onnx")
output_path = os.path.join(ONNX_DIR, f"{base}.onnx")
sample_input = torch.from_numpy(
np.random.rand(batch_size, in_channels, slice_length).astype(np.float16)
)
else:
output_path = os.path.join(ONNX_DIR, f"{base}_fp32.onnx")
output_path = os.path.join(ONNX_DIR, f"{base}.onnx")
sample_input = torch.rand(batch_size, in_channels, slice_length)
torch.onnx.export(