removed fp16, fp32 from the file names
All checks were successful
RIA Hub Workflow Demo / ria-demo (push) Successful in 41s

This commit is contained in:
Liyu Xiao 2025-05-22 15:08:31 -04:00
parent c6ad74b5df
commit 123cb82334

View File

@ -48,12 +48,12 @@ def convert_to_onnx(ckpt_path, fp16=False):
# Generate random sample data # Generate random sample data
base, ext = os.path.splitext(os.path.basename(ckpt_path)) base, ext = os.path.splitext(os.path.basename(ckpt_path))
if fp16: if fp16:
output_path = os.path.join(ONNX_DIR, f"{base}_fp16.onnx") output_path = os.path.join(ONNX_DIR, f"{base}.onnx")
sample_input = torch.from_numpy( sample_input = torch.from_numpy(
np.random.rand(batch_size, in_channels, slice_length).astype(np.float16) np.random.rand(batch_size, in_channels, slice_length).astype(np.float16)
) )
else: else:
output_path = os.path.join(ONNX_DIR, f"{base}_fp32.onnx") output_path = os.path.join(ONNX_DIR, f"{base}.onnx")
sample_input = torch.rand(batch_size, in_channels, slice_length) sample_input = torch.rand(batch_size, in_channels, slice_length)
torch.onnx.export( torch.onnx.export(