This commit is contained in:
parent
95d56b90a2
commit
586acd123c
|
@ -47,12 +47,12 @@ def convert_to_onnx(ckpt_path, fp16=False):
|
||||||
# Generate random sample data
|
# Generate random sample data
|
||||||
base, ext = os.path.splitext(os.path.basename(ckpt_path))
|
base, ext = os.path.splitext(os.path.basename(ckpt_path))
|
||||||
if fp16:
|
if fp16:
|
||||||
output_path = os.path.join("onnx_scripts", f"{base}.onnx")
|
output_path = os.path.join("onnx_files", f"{base}.onnx")
|
||||||
sample_input = torch.from_numpy(
|
sample_input = torch.from_numpy(
|
||||||
np.random.rand(batch_size, in_channels, slice_length).astype(np.float16)
|
np.random.rand(batch_size, in_channels, slice_length).astype(np.float16)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output_path = os.path.join("onnx_scripts", f"{base}.onnx")
|
output_path = os.path.join("onnx_files", f"{base}.onnx")
|
||||||
print(output_path)
|
print(output_path)
|
||||||
sample_input = torch.rand(batch_size, in_channels, slice_length)
|
sample_input = torch.rand(batch_size, in_channels, slice_length)
|
||||||
|
|
||||||
|
@ -83,4 +83,4 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
output_file = "convert_to_onnx" + ".onnx"
|
output_file = "convert_to_onnx" + ".onnx"
|
||||||
|
|
||||||
print("Conversion complete stored at: ", os.path.join("onnx_scripts", output_file))
|
print("Conversion complete stored at: ", os.path.join("onnx_files", output_file))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user