forked from qoherent/modrec-workflow
80 lines
2.4 KiB
Python
80 lines
2.4 KiB
Python
![]() |
import random
|
||
|
from collections import defaultdict
|
||
|
|
||
|
|
||
|
def split(dataset, train_frac=0.8, seed=42, label_key="modulation"):
    """
    Split a dataset into training and validation sets, stratified by label.

    Examples are grouped by recording (``md["recid"]``) and a recording is
    always assigned to exactly one split, so examples cut from the same
    recording never end up on both sides. Recordings are shuffled and divided
    separately for each label so the label balance is kept in both splits.

    Parameters:
        dataset (iterable): ``(data, metadata)`` pairs. Every metadata mapping
            must contain ``"recid"`` and ``label_key``.
        train_frac (float): Fraction of each label's recordings assigned to
            the training set. Must lie in ``[0, 1]``.
        seed (int): Seed for the shuffle. A private generator is used, so the
            global ``random`` module state is left untouched.
        label_key (str): Metadata key holding the class label. ``bytes``
            labels are decoded as UTF-8 before grouping.

    Returns:
        tuple: ``(train_dataset, val_dataset)``, two lists of
        ``(data, metadata)`` pairs in the order the recordings first appeared.

    Raises:
        ValueError: If ``train_frac`` is outside ``[0, 1]``.
    """
    if not 0 <= train_frac <= 1:
        raise ValueError(
            f"train_frac must be between 0 and 1, got {train_frac!r}"
        )

    # Bucket examples by recording so a recording is never split in two.
    rec_buckets = defaultdict(list)
    for data, md in dataset:
        rec_buckets[md["recid"]].append((data, md))

    # Group recording ids by label; the first example of a recording
    # supplies the label for the whole recording.
    label_rec_ids = defaultdict(list)
    for rec_id, group in rec_buckets.items():
        label = group[0][1][label_key]
        if isinstance(label, bytes):
            label = label.decode("utf-8")
        label_rec_ids[label].append(rec_id)

    # A dedicated generator yields the same shuffle order as seeding the
    # module-level RNG, without clobbering the caller's random state.
    rng = random.Random(seed)
    train_recs = set()
    for rec_ids in label_rec_ids.values():
        rng.shuffle(rec_ids)
        split_idx = int(len(rec_ids) * train_frac)
        if split_idx == 0 and train_frac > 0:
            # Truncation would leave this label with no training recording
            # (e.g. a label that has a single recording); keep one so every
            # label is represented in the training set.
            split_idx = 1
        train_recs.update(rec_ids[:split_idx])

    # Every recording not chosen for training belongs to validation.
    train_dataset, val_dataset = [], []
    for rec_id, group in rec_buckets.items():
        target = train_dataset if rec_id in train_recs else val_dataset
        target.extend(group)

    return train_dataset, val_dataset
|
||
|
|
||
|
def split_recording(recording_list, num_snippets):
    """
    Split each recording into equal-length, non-overlapping snippets.

    Every recording is cut along its second axis into ``num_snippets``
    consecutive pieces of ``N // num_snippets`` samples each. Samples left
    over after the last full snippet are dropped.

    Parameters:
        recording_list (iterable): ``(data, metadata)`` pairs. ``data`` must
            expose a two-element ``.shape`` and support ``data[:, a:b]``
            slicing (presumably a 2-D NumPy array of channels x samples);
            ``metadata`` must provide ``.copy()``.
        num_snippets (int): Number of snippets to cut from each recording.
            Must be at least 1.

    Returns:
        list: ``(snippet, snippet_metadata)`` pairs. Each metadata is a
        shallow copy of the recording's metadata with an added
        ``"snippet_idx"`` entry giving the snippet's position in the
        recording.

    Raises:
        ValueError: If ``num_snippets`` is less than 1, or if a recording has
            fewer samples than ``num_snippets`` (snippets would be empty).
    """
    if num_snippets < 1:
        raise ValueError(
            f"num_snippets must be a positive integer, got {num_snippets!r}"
        )

    snippet_list = []

    for data, md in recording_list:
        _, n_samples = data.shape
        snippet_len = n_samples // num_snippets
        if snippet_len == 0:
            raise ValueError(
                f"recording with {n_samples} samples cannot be split into "
                f"{num_snippets} non-empty snippets"
            )

        for i in range(num_snippets):
            start = i * snippet_len
            snippet = data[:, start:start + snippet_len]

            # Copy the metadata so the caller's dict is not mutated, and tag
            # the copy with the snippet's index.
            snippet_md = md.copy()
            snippet_md["snippet_idx"] = i
            snippet_list.append((snippet, snippet_md))

    return snippet_list
|