50 lines
1.1 KiB
Python
50 lines
1.1 KiB
Python
import random
|
|
from collections import defaultdict
|
|
|
|
|
|
def split(dataset, train_frac=0.8, seed=42):
    """Split a dataset into train/validation sets, keeping recordings intact.

    Items sharing the same ``rec_id`` (read from each item's metadata dict)
    are always assigned to the same split, so no recording leaks across the
    train/validation boundary. Recordings are shuffled deterministically and
    greedily packed into the train set until adding a whole recording would
    exceed the target item count; everything left over becomes validation.

    Parameters:
        dataset (list): Sequence of ``(data, metadata)`` pairs where
            ``metadata`` is a dict containing a ``'rec_id'`` key.
        train_frac (float): Target fraction of items in the train split.
        seed (int): Seed for the deterministic shuffle of recordings.

    Returns:
        tuple: ``(train_records, val_records)`` — two lists that together
        contain every item of ``dataset``; within each list, items keep
        their original dataset order.
    """
    n_items = len(dataset)
    target = int(n_items * train_frac)

    # Group item indices by recording so each recording stays in one split.
    by_rec = defaultdict(list)
    for i, (_, md) in enumerate(dataset):
        by_rec[md['rec_id']].append(i)

    # Shuffle with a private RNG instance: same deterministic sequence as
    # random.seed(seed) + random.shuffle, but without mutating global state.
    rec_ids = list(by_rec.keys())
    random.Random(seed).shuffle(rec_ids)

    # Greedily pack whole recordings into the train set without overshooting
    # the target; recordings that don't fit are skipped (they go to val).
    train_set = set()
    count = 0
    for rec_id in rec_ids:
        indices = by_rec[rec_id]
        if count + len(indices) <= target:
            train_set.update(indices)
            count += len(indices)

    validation_set = set(range(n_items)) - train_set

    print(f"Train set :{len(train_set)}")
    print(f"val set :{len(validation_set)}")

    train_records = [dataset[i] for i in sorted(train_set)]
    val_records = [dataset[i] for i in sorted(validation_set)]

    return train_records, val_records