modrec-workflow/scripts/split_dataset.py

import random
from collections import defaultdict


def split(dataset, train_frac=0.8, seed=42):
    """
    Split a dataset into train and validation subsets, grouped by recording.

    All samples sharing the same 'rec_id' stay in the same subset, so no
    recording leaks across the train/validation boundary. Recordings are
    assigned to the training set greedily, in shuffled order, until roughly
    train_frac of the samples are covered.

    Parameters:
        dataset (list): List of (sample, metadata) pairs, where metadata is
            a dict containing a 'rec_id' key.
        train_frac (float): Target fraction of samples for the training set.
        seed (int): Random seed used when shuffling recording IDs.

    Returns:
        tuple: (train_records, val_records), each a list of dataset entries.
    """
    N = len(dataset)
    target = int(N * train_frac)

    # Group sample indices by recording ID.
    by_rec = defaultdict(list)
    for i, (_, md) in enumerate(dataset):
        by_rec[md['rec_id']].append(i)

    # Shuffle recordings deterministically, then assign whole recordings to
    # the training set until the target sample count would be exceeded.
    rec_ids = list(by_rec.keys())
    random.seed(seed)
    random.shuffle(rec_ids)

    train_set = set()
    count = 0
    for rec_id in rec_ids:
        indices = by_rec[rec_id]
        if count + len(indices) <= target:
            train_set.update(indices)
            count += len(indices)

    # Everything not assigned to training goes to validation.
    validation_set = set(range(N)) - train_set
    print(f"Train set: {len(train_set)}")
    print(f"Val set: {len(validation_set)}")

    train_records = [dataset[i] for i in sorted(train_set)]
    val_records = [dataset[i] for i in sorted(validation_set)]
    return train_records, val_records
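

# Minimal usage sketch with a toy dataset. The entry format here, a
# (sample, metadata) pair whose metadata dict carries a 'rec_id' key, is an
# assumption taken from how split() reads the data; real datasets from the
# workflow may differ.
if __name__ == "__main__":
    toy_dataset = [
        (f"sample_{i}", {"rec_id": f"rec_{i % 5}"}) for i in range(100)
    ]
    train, val = split(toy_dataset, train_frac=0.8, seed=42)

    # Sanity check: no recording ID should appear in both subsets.
    train_recs = {md["rec_id"] for _, md in train}
    val_recs = {md["rec_id"] for _, md in val}
    assert train_recs.isdisjoint(val_recs)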