Compare commits

..

No commits in common. "main" and "lorne-test" have entirely different histories.

4 changed files with 5 additions and 5 deletions

View File

@@ -24,7 +24,7 @@ dataset:
   snr_step: 3
   # Number of iterations (signal recordings) per modulation and SNR combination
-  num_iterations: 10
+  num_iterations: 100
   # Modulation scheme settings; keys must match the `modulation_types` list above
   # Each entry includes:
@@ -57,7 +57,7 @@ training:
   batch_size: 256
   # Number of complete passes through the training dataset during training
-  epochs: 5
+  epochs: 30
   # Learning rate: step size for weight updates after each batch
   # Recommended range for fine-tuning: 1e-6 to 1e-4

View File

View File

@@ -21,7 +21,7 @@ def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
     in_channels = 2
     batch_size = 1
-    slice_length = int(dataset_cfg.recording_length / dataset_cfg.num_slices)
+    slice_length = int(1024 / dataset_cfg.num_slices)
     num_classes = len(dataset_cfg.modulation_types)
     model = RFClassifier(
@@ -42,7 +42,7 @@ def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
     model.eval()
     # Generate random sample data
-    base, _ = os.path.splitext(os.path.basename(ckpt_path))
+    base, ext = os.path.splitext(os.path.basename(ckpt_path))
     if fp16:
         output_path = os.path.join("onnx_files", f"{base}.onnx")
         sample_input = torch.from_numpy(

View File

@@ -90,7 +90,7 @@ def split_recording(
     snippet_list = []
     for data, md in recording_list:
-        _, N = data.shape
+        C, N = data.shape
         L = N // num_snippets
         for i in range(num_snippets):
             start = i * L