forked from qoherent/modrec-workflow
removed fp16, fp32 from the file names
This commit is contained in:
parent
c6ad74b5df
commit
123cb82334
|
@ -48,12 +48,12 @@ def convert_to_onnx(ckpt_path, fp16=False):
|
|||
# Generate random sample data
|
||||
base, ext = os.path.splitext(os.path.basename(ckpt_path))
|
||||
if fp16:
|
||||
output_path = os.path.join(ONNX_DIR, f"{base}_fp16.onnx")
|
||||
output_path = os.path.join(ONNX_DIR, f"{base}.onnx")
|
||||
sample_input = torch.from_numpy(
|
||||
np.random.rand(batch_size, in_channels, slice_length).astype(np.float16)
|
||||
)
|
||||
else:
|
||||
output_path = os.path.join(ONNX_DIR, f"{base}_fp32.onnx")
|
||||
output_path = os.path.join(ONNX_DIR, f"{base}.onnx")
|
||||
sample_input = torch.rand(batch_size, in_channels, slice_length)
|
||||
|
||||
torch.onnx.export(
|
||||
|
|
Loading…
Reference in New Issue
Block a user