import random
from collections import defaultdict


def split(dataset, train_frac=0.8, seed=42):
    """Split a dataset into train and validation subsets, grouped by record.

    Every item of ``dataset`` is expected to be a 2-element pair whose second
    element is a mapping containing a ``'rec_id'`` key. All items sharing a
    ``rec_id`` are kept together, so no record is divided between the two
    subsets.

    Record groups are visited in a seeded random order and greedily added to
    the training subset as long as the training subset stays within
    ``int(len(dataset) * train_frac)`` items. A group that does not fit is
    skipped, but later (smaller) groups may still be added. Everything not
    selected for training goes to the validation subset.

    Parameters:
        dataset: Indexable, sized collection of ``(item, metadata)`` pairs.
        train_frac (float): Target fraction of items for the training subset.
            The actual training size never exceeds the target but may be
            smaller because whole record groups are kept together.
        seed (int): Seed for the record shuffle; the same seed yields the
            same split.

    Returns:
        tuple: ``(train_records, val_records)``, two lists of dataset items,
        each in original dataset order.
    """
    total = len(dataset)
    target = int(total * train_frac)

    # Group item indices by the record they belong to.
    by_rec = defaultdict(list)
    for position, (_, metadata) in enumerate(dataset):
        by_rec[metadata['rec_id']].append(position)

    # Use a private generator so the process-wide ``random`` state is left
    # untouched; Random(seed).shuffle gives the same order that
    # random.seed(seed) followed by random.shuffle would.
    rec_ids = list(by_rec)
    rng = random.Random(seed)
    rng.shuffle(rec_ids)

    train_set = set()
    count = 0
    for rec_id in rec_ids:
        members = by_rec[rec_id]
        # Only take the whole group if it still fits under the target.
        if count + len(members) <= target:
            train_set.update(members)
            count += len(members)

    train_indices = sorted(train_set)
    val_indices = [i for i in range(total) if i not in train_set]

    print(f"Train set :{len(train_indices)}")
    print(f"val set :{len(val_indices)}")

    train_records = [dataset[i] for i in train_indices]
    val_records = [dataset[i] for i in val_indices]
    return train_records, val_records