formatted code, added setup.py file to install packages

This commit is contained in:
Liyu Xiao 2025-05-16 11:31:37 -04:00
parent 66d4c47cc4
commit cdc293c7ce
4 changed files with 26 additions and 19 deletions

View File

@ -36,6 +36,7 @@ jobs:
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install h5py numpy pip install h5py numpy
pip install -e .
- name: 1. Build HDF5 Dataset - name: 1. Build HDF5 Dataset
run: | run: |

View File

@ -21,6 +21,7 @@ info_dtype = np.dtype(
] ]
) )
def write_hdf5_file(records, output_path, dataset_name="data"): def write_hdf5_file(records, output_path, dataset_name="data"):
""" """
Writes a list of records to an HDF5 file. Writes a list of records to an HDF5 file.
@ -74,6 +75,7 @@ def write_hdf5_file(records, output_path, dataset_name="data"):
return output_path return output_path
def split_recording(recording_list): def split_recording(recording_list):
""" """
Splits a list of recordings into smaller chunks. Splits a list of recordings into smaller chunks.
@ -135,15 +137,15 @@ def generate_datasets(path_to_recordings, output_path, dataset_name="data"):
# split each recording into 8 snippets each # split each recording into 8 snippets each
records = split_recording(records) records = split_recording(records)
train_records, val_records = split(records, train_frac=0.8, seed=42) train_records, val_records = split(records, train_frac=0.8, seed=42)
train_path = os.path.join(output_path, "train.h5") train_path = os.path.join(output_path, "train.h5")
val_path = os.path.join(output_path, "val.h5") val_path = os.path.join(output_path, "val.h5")
write_hdf5_file(train_records, train_path, "training_data") write_hdf5_file(train_records, train_path, "training_data")
write_hdf5_file(val_records, val_path, "validation_data") write_hdf5_file(val_records, val_path, "validation_data")
return train_path, val_path return train_path, val_path

View File

@ -15,17 +15,15 @@ def split(dataset, train_frac=0.8, seed=42):
""" """
N = len(dataset) N = len(dataset)
target = int(N * train_frac) target = int(N * train_frac)
by_rec = defaultdict(list) by_rec = defaultdict(list)
for i, (_, md) in enumerate(dataset): for i, (_, md) in enumerate(dataset):
by_rec[md['rec_id']].append(i) by_rec[md["rec_id"]].append(i)
rec_ids = list(by_rec.keys())
rec_ids = list(by_rec.keys())
random.seed(seed) random.seed(seed)
random.shuffle(rec_ids) random.shuffle(rec_ids)
train_set = set() train_set = set()
count = 0 count = 0
for rec_id in rec_ids: for rec_id in rec_ids:
@ -33,17 +31,13 @@ def split(dataset, train_frac=0.8, seed=42):
if count + len(index) <= target: if count + len(index) <= target:
train_set.update(index) train_set.update(index)
count += len(index) count += len(index)
validation_set = set(range(N)) - train_set validation_set = set(range(N)) - train_set
print(f"Train set :{len(train_set)}") print(f"Train set :{len(train_set)}")
print(f"val set :{len(validation_set)}") print(f"val set :{len(validation_set)}")
train_records = [dataset[i] for i in sorted(train_set)] train_records = [dataset[i] for i in sorted(train_set)]
val_records = [dataset[i] for i in sorted(validation_set)] val_records = [dataset[i] for i in sorted(validation_set)]
return train_records, val_records return train_records, val_records

10
setup.py Normal file
View File

@ -0,0 +1,10 @@
"""Minimal packaging script so the project can be installed with `pip install -e .`."""
from setuptools import setup, find_packages

setup(
    name="modrec_workflow",
    version="0.1",
    # Auto-discovers every importable package in the repo; `utils/` is only
    # found if it contains an __init__.py file.
    packages=find_packages(),
    # Intentionally empty for now — runtime dependencies are installed
    # separately (e.g. by the CI workflow).
    install_requires=[],
)