forked from qoherent/modrec-workflow
fixed syntax
This commit is contained in:
parent
9f8a583857
commit
0ca66e886a
|
@ -21,7 +21,7 @@ def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
|
|||
|
||||
in_channels = 2
|
||||
batch_size = 1
|
||||
slice_length = int(1024 / dataset_cfg.num_slices)
|
||||
slice_length = int(dataset_cfg.recording_length / dataset_cfg.num_slices)
|
||||
num_classes = len(dataset_cfg.modulation_types)
|
||||
|
||||
model = RFClassifier(
|
||||
|
@ -42,7 +42,7 @@ def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
|
|||
model.eval()
|
||||
|
||||
# Generate random sample data
|
||||
base, ext = os.path.splitext(os.path.basename(ckpt_path))
|
||||
base, _ = os.path.splitext(os.path.basename(ckpt_path))
|
||||
if fp16:
|
||||
output_path = os.path.join("onnx_files", f"{base}.onnx")
|
||||
sample_input = torch.from_numpy(
|
||||
|
|
|
@ -90,7 +90,7 @@ def split_recording(
|
|||
snippet_list = []
|
||||
|
||||
for data, md in recording_list:
|
||||
C, N = data.shape
|
||||
_, N = data.shape
|
||||
L = N // num_snippets
|
||||
for i in range(num_snippets):
|
||||
start = i * L
|
||||
|
|
Loading…
Reference in New Issue
Block a user