diff --git a/.riahub/workflows/workflow.yaml b/.riahub/workflows/workflow.yaml index 6672b1b..a5cf268 100644 --- a/.riahub/workflows/workflow.yaml +++ b/.riahub/workflows/workflow.yaml @@ -36,6 +36,7 @@ jobs: run: | python -m pip install --upgrade pip pip install h5py numpy + pip install -e . - name: 1. Build HDF5 Dataset run: | diff --git a/scripts/produce_dataset.py b/scripts/produce_dataset.py index 794a4c9..22f2205 100644 --- a/scripts/produce_dataset.py +++ b/scripts/produce_dataset.py @@ -21,6 +21,7 @@ info_dtype = np.dtype( ] ) + def write_hdf5_file(records, output_path, dataset_name="data"): """ Writes a list of records to an HDF5 file. @@ -74,6 +75,7 @@ def write_hdf5_file(records, output_path, dataset_name="data"): return output_path + def split_recording(recording_list): """ Splits a list of recordings into smaller chunks. @@ -135,15 +137,15 @@ def generate_datasets(path_to_recordings, output_path, dataset_name="data"): # split each recording into 8 snippets each records = split_recording(records) - + train_records, val_records = split(records, train_frac=0.8, seed=42) - + train_path = os.path.join(output_path, "train.h5") val_path = os.path.join(output_path, "val.h5") - + write_hdf5_file(train_records, train_path, "training_data") - write_hdf5_file(val_records, val_path, "validation_data") - + write_hdf5_file(val_records, val_path, "validation_data") + return train_path, val_path diff --git a/scripts/split_dataset.py b/scripts/split_dataset.py index 0a0c8c8..894512a 100644 --- a/scripts/split_dataset.py +++ b/scripts/split_dataset.py @@ -15,17 +15,15 @@ def split(dataset, train_frac=0.8, seed=42): """ N = len(dataset) target = int(N * train_frac) - + by_rec = defaultdict(list) for i, (_, md) in enumerate(dataset): - by_rec[md['rec_id']].append(i) - - - rec_ids = list(by_rec.keys()) + by_rec[md["rec_id"]].append(i) + + rec_ids = list(by_rec.keys()) random.seed(seed) random.shuffle(rec_ids) - - + train_set = set() count = 0 for rec_id in rec_ids: @@ -33,17
+31,13 @@ def split(dataset, train_frac=0.8, seed=42): if count + len(index) <= target: train_set.update(index) count += len(index) - - validation_set = set(range(N)) - train_set - + print(f"Train set :{len(train_set)}") print(f"val set :{len(validation_set)}") - + train_records = [dataset[i] for i in sorted(train_set)] - val_records = [dataset[i] for i in sorted(validation_set)] + val_records = [dataset[i] for i in sorted(validation_set)] return train_records, val_records - - diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..755c1c1 --- /dev/null +++ b/setup.py @@ -0,0 +1,10 @@ +from setuptools import setup, find_packages + +setup( + name="modrec_workflow", + version="0.1", + packages=find_packages(), # this will pick up `utils/` (so utils/__init__.py must exist) + install_requires=[ + # runtime dependencies go here (if any) + ], +)