2025-05-21 15:52:16 -04:00
|
|
|
import random
from collections import defaultdict
from typing import Any, Dict, List, Tuple

import numpy as np
|
|
|
|
|
|
|
|
def split(
    dataset: List[Tuple[np.ndarray, Dict[str, Any]]],
    train_frac: float = 0.8,
    seed: int = 42,
    label_key: str = "modulation",
) -> Tuple[
    List[Tuple[np.ndarray, Dict[str, Any]]],
    List[Tuple[np.ndarray, Dict[str, Any]]],
]:
    """
    Split a dataset of modulated IQ signal recordings into training and validation subsets.

    Samples are first grouped by their ``"recid"`` metadata entry, so every sample
    belonging to one recording lands in the same subset (no recording straddles the
    split). Recordings are then stratified by ``label_key``: for each label, its
    recording IDs are shuffled and the first ``int(n_recordings * train_frac)`` go to
    training, the remainder to validation.

    Parameters:
        dataset (list): List of tuples where each tuple contains:
            - np.ndarray: 2xN real array (channels x samples)
            - dict: Metadata for the sample; must contain ``"recid"``, and the first
              sample of each recording must contain ``label_key``
        train_frac (float): Fraction of each label's recordings to use for training,
            in [0, 1] (default: 0.8)
        seed (int): Random seed for reproducibility (default: 42)
        label_key (str): Metadata key to stratify by during splitting
            (default: "modulation"). Byte-string labels are decoded as UTF-8.

    Returns:
        tuple: Two lists of (np.ndarray, dict) pairs — (train_records, val_records),
        each ordered by first appearance of its recording in ``dataset``.

    Raises:
        ValueError: If ``train_frac`` is not within [0, 1].
        KeyError: If a sample lacks ``"recid"`` or a recording's first sample lacks
            ``label_key``.

    Note:
        This calls ``random.seed(seed)``, i.e. it reseeds Python's global ``random``
        module state as a side effect.
    """
    # Out-of-range fractions would otherwise slice silently into nonsense
    # (e.g. a negative split index drops only the last recording).
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(f"train_frac must be in [0, 1], got {train_frac!r}")

    # Group samples by recording so a recording is never split across subsets.
    rec_buckets = defaultdict(list)
    for data, md in dataset:
        rec_buckets[md["recid"]].append((data, md))

    # One label per recording, taken from its first sample's metadata.
    rec_labels = {}
    for rec_id, group in rec_buckets.items():
        label = group[0][1][label_key]
        if isinstance(label, bytes):
            # Normalise so b"x" and "x" fall into the same stratum.
            label = label.decode("utf-8")
        rec_labels[rec_id] = label

    # Invert: label -> list of recording IDs carrying that label.
    label_rec_ids = defaultdict(list)
    for rec_id, label in rec_labels.items():
        label_rec_ids[label].append(rec_id)

    # NOTE(review): this reseeds the process-wide RNG. A local random.Random(seed)
    # would yield the same split without that side effect, but it is kept because
    # callers may rely on the global RNG being seeded here.
    random.seed(seed)

    train_recs, val_recs = set(), set()
    for rec_ids in label_rec_ids.values():
        random.shuffle(rec_ids)
        split_idx = int(len(rec_ids) * train_frac)
        # Stratified per label: each label contributes train_frac of its recordings
        # to training. A label where int(n * train_frac) == 0 (e.g. only one
        # recording) goes entirely to validation, so training is NOT guaranteed to
        # contain every label.
        train_recs.update(rec_ids[:split_idx])
        val_recs.update(rec_ids[split_idx:])

    # Emit the assigned recordings' samples, preserving dataset order.
    train_dataset, val_dataset = [], []
    for rec_id, group in rec_buckets.items():
        if rec_id in train_recs:
            train_dataset.extend(group)
        elif rec_id in val_recs:
            val_dataset.extend(group)

    return train_dataset, val_dataset
|
|
|
|
|
2025-05-22 14:12:36 -04:00
|
|
|
|
2025-06-17 14:16:16 -04:00
|
|
|
def split_recording(
    recording_list: List[Tuple[np.ndarray, Dict[str, Any]]],
    num_snippets: int,
) -> List[Tuple[np.ndarray, Dict[str, Any]]]:
    """
    Split each full recording into ``num_snippets`` contiguous, equal-length snippets.

    Each recording is a tuple of:
        - data (np.ndarray): A 2xN real-valued array representing I/Q signal data
          (must be 2-D: channels x samples).
        - metadata (dict): Metadata describing the recording (e.g., modulation, SNR, etc.)

    The split is along the time axis (axis=1): each (C, N) array is divided into
    ``num_snippets`` contiguous chunks of shape (C, N // num_snippets). The trailing
    ``N % num_snippets`` samples are discarded; if ``N < num_snippets`` every snippet
    has length 0.

    Snippets are basic NumPy slices of the original array, i.e. views that share
    memory with it. Each snippet's metadata is a shallow copy of the recording's
    metadata with an added ``"snippet_idx"`` key (0-based position in the recording).

    Parameters:
        recording_list (List[Tuple[np.ndarray, dict]]):
            List of (data, metadata) tuples to be split.
        num_snippets (int):
            Number of equal-length segments to divide each recording into (>= 1).

    Returns:
        List[Tuple[np.ndarray, dict]]:
            A flat list containing all resulting (snippet, metadata) pairs, in
            recording order and then snippet order.

    Raises:
        ValueError: If ``num_snippets < 1``, or if a data array is not 2-D.
    """
    # 0 would raise ZeroDivisionError and a negative count would silently drop
    # every recording (empty range), so reject both up front.
    if num_snippets < 1:
        raise ValueError(f"num_snippets must be >= 1, got {num_snippets!r}")

    snippet_list = []
    for data, md in recording_list:
        # Unpacking enforces a 2-D array, matching the original contract.
        _, n_samples = data.shape
        snippet_len = n_samples // num_snippets

        for i in range(num_snippets):
            start = i * snippet_len
            snippet = data[:, start:start + snippet_len]

            # Copy the metadata so snippets don't share (or mutate) the original dict.
            snippet_md = md.copy()
            snippet_md["snippet_idx"] = i
            snippet_list.append((snippet, snippet_md))

    return snippet_list
|