import random
from collections import defaultdict
from typing import Any, Dict, List, Tuple

import numpy as np


def split(
    dataset: List[Tuple[np.ndarray, Dict[str, Any]]],
    train_frac: float = 0.8,
    seed: int = 42,
    label_key: str = "modulation",
) -> Tuple[List[Tuple[np.ndarray, Dict[str, Any]]], List[Tuple[np.ndarray, Dict[str, Any]]]]:
    """
    Splits a dataset of modulated IQ signal recordings into training and validation subsets.

    Parameters:
        dataset (list): List of tuples where each tuple contains:
            - np.ndarray: 2xN real array (channels x samples)
            - dict: Metadata for the sample
        train_frac (float): Fraction of the dataset to use for training (default: 0.8)
        seed (int): Random seed for reproducibility (default: 42)
        label_key (str): Metadata key to group by during splitting (default: "modulation")

    Returns:
        tuple: (train_records, val_records), two lists of (np.ndarray, dict) pairs
    """
    # Group samples by recording ID so that no recording spans both splits
    rec_buckets = defaultdict(list)
    for data, md in dataset:
        rec_buckets[md["recid"]].append((data, md))

    # Store the label for each recording
    rec_labels = {}
    for rec_id, group in rec_buckets.items():
        label = group[0][1][label_key]
        if isinstance(label, bytes):  # decode byte-string labels
            label = label.decode("utf-8")
        rec_labels[rec_id] = label

    # Group recording IDs by label
    label_rec_ids = defaultdict(list)
    for rec_id, label in rec_labels.items():
        label_rec_ids[label].append(rec_id)

    random.seed(seed)
    train_recs, val_recs = set(), set()
    for label, rec_ids in label_rec_ids.items():
        random.shuffle(rec_ids)
        split_idx = int(len(rec_ids) * train_frac)
        # Take train_frac of the rec_ids per label, so every modulation is
        # represented in both subsets
        train_recs.update(rec_ids[:split_idx])
        val_recs.update(rec_ids[split_idx:])

    # Assign the recordings to the train and val datasets
    train_dataset, val_dataset = [], []
    for rec_id, group in rec_buckets.items():
        if rec_id in train_recs:
            train_dataset.extend(group)
        elif rec_id in val_recs:
            val_dataset.extend(group)

    return train_dataset, val_dataset


def split_recording(
    recording_list: List[Tuple[np.ndarray, Dict[str, Any]]],
    num_snippets: int,
) -> List[Tuple[np.ndarray, Dict[str, Any]]]:
    """
    Splits each full recording into a specified number of smaller snippets.

    Each recording is a tuple of:
        - data (np.ndarray): A 2xN real-valued array representing I/Q signal data.
        - metadata (dict): Metadata describing the recording (e.g., modulation, SNR, etc.)

    The split is done along the time axis (axis=1), dividing each (2, N) array into
    `num_snippets` contiguous chunks of shape (2, N // num_snippets). Any trailing
    samples left over by the integer division are discarded.

    Parameters:
        recording_list (List[Tuple[np.ndarray, dict]]): List of (data, metadata) tuples to be split.
        num_snippets (int): Number of equal-length segments to divide each recording into.

    Returns:
        List[Tuple[np.ndarray, dict]]: A flat list containing all resulting (snippet, metadata) pairs.
    """
    snippet_list = []
    for data, md in recording_list:
        C, N = data.shape
        L = N // num_snippets
        for i in range(num_snippets):
            start = i * L
            end = (i + 1) * L
            snippet = data[:, start:end]
            # Copy the metadata, adding a snippet index
            snippet_md = md.copy()
            snippet_md["snippet_idx"] = i
            snippet_list.append((snippet, snippet_md))

    return snippet_list
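

# Example usage: a minimal sketch with a small synthetic dataset. It assumes only
# that each metadata dict carries the "recid" and "modulation" keys the functions
# above expect; the shapes and label names here are illustrative, not prescribed.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    demo = []
    for rec_id in range(6):
        data = rng.standard_normal((2, 1024))  # 2xN real array (I and Q channels)
        md = {"recid": rec_id, "modulation": "bpsk" if rec_id % 2 else "qpsk"}
        demo.append((data, md))

    # Per-label split at the recording level, then chop each recording into snippets
    train, val = split(demo, train_frac=0.8, seed=42)
    train_snips = split_recording(train, num_snippets=4)  # yields (2, 256) snippets
    print(len(train), len(val), len(train_snips))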