forked from qoherent/modrec-workflow
formatted code, added setup.py file to install packages
parent 66d4c47cc4
commit cdc293c7ce
@@ -36,6 +36,7 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           pip install h5py numpy
+          pip install -e .
 
       - name: 1. Build HDF5 Dataset
         run: |
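The setup.py added by this commit is not shown in the diff; a minimal sketch of an editable-install setup.py consistent with this workflow (package name, version, and dependency pins are all assumptions) could be:

# Hypothetical sketch; the actual setup.py added by this commit is not shown.
from setuptools import setup, find_packages

setup(
    name="modrec-workflow",              # assumed: named after the repository
    version="0.1.0",                     # assumed version
    packages=find_packages(),
    install_requires=["h5py", "numpy"],  # mirrors the packages installed in CI
)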
@@ -21,6 +21,7 @@ info_dtype = np.dtype(
     ]
 )
 
+
 def write_hdf5_file(records, output_path, dataset_name="data"):
     """
     Writes a list of records to an HDF5 file.
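Only the signature and the first docstring line of write_hdf5_file appear in this hunk; a minimal sketch of such a writer with h5py, assuming each record is an (iq, metadata) pair and leaving out the info_dtype metadata table, might look like:

import h5py
import numpy as np

def write_hdf5_file(records, output_path, dataset_name="data"):
    """Sketch: writes a list of (iq, metadata) records to an HDF5 file."""
    # Assumption: every record holds an equal-length IQ array; the real
    # function also stores per-record metadata via info_dtype (not shown).
    data = np.stack([iq for iq, _ in records])
    with h5py.File(output_path, "w") as f:
        f.create_dataset(dataset_name, data=data)
    return output_path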
@@ -74,6 +75,7 @@ def write_hdf5_file(records, output_path, dataset_name="data"):
 
     return output_path
 
+
 def split_recording(recording_list):
     """
     Splits a list of recordings into smaller chunks.
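The body of split_recording is likewise elided; given the "8 snippets each" comment in generate_datasets below, a plausible sketch (the chunk count and the (iq, metadata) record layout are assumptions) is:

def split_recording(recording_list, n_snippets=8):
    """Sketch: splits each (iq, metadata) recording into n_snippets chunks."""
    snippets = []
    for iq, md in recording_list:
        chunk = len(iq) // n_snippets  # assumed equal-length chunks
        for k in range(n_snippets):
            snippets.append((iq[k * chunk:(k + 1) * chunk], md))
    return snippets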
@@ -135,15 +137,15 @@ def generate_datasets(path_to_recordings, output_path, dataset_name="data"):
 
     # split each recording into 8 snippets each
     records = split_recording(records)
 
     train_records, val_records = split(records, train_frac=0.8, seed=42)
 
     train_path = os.path.join(output_path, "train.h5")
     val_path = os.path.join(output_path, "val.h5")
 
     write_hdf5_file(train_records, train_path, "training_data")
     write_hdf5_file(val_records, val_path, "validation_data")
 
     return train_path, val_path
 
 
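generate_datasets ties the steps together: snippet the recordings, split them 80/20 by recording, and write train.h5 and val.h5. A hypothetical invocation (directory names are placeholders):

# Hypothetical usage; paths are placeholders.
train_path, val_path = generate_datasets(
    path_to_recordings="recordings/",
    output_path="datasets/",
)
print(train_path, val_path)  # datasets/train.h5 datasets/val.h5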
@@ -15,17 +15,15 @@ def split(dataset, train_frac=0.8, seed=42):
     """
     N = len(dataset)
     target = int(N * train_frac)
 
     by_rec = defaultdict(list)
     for i, (_, md) in enumerate(dataset):
-        by_rec[md['rec_id']].append(i)
+        by_rec[md["rec_id"]].append(i)
 
-
     rec_ids = list(by_rec.keys())
-
     random.seed(seed)
     random.shuffle(rec_ids)
 
     train_set = set()
     count = 0
     for rec_id in rec_ids:
@@ -33,17 +31,13 @@ def split(dataset, train_frac=0.8, seed=42):
         if count + len(index) <= target:
             train_set.update(index)
             count += len(index)
 
-
-
     validation_set = set(range(N)) - train_set
 
     print(f"Train set :{len(train_set)}")
     print(f"val set :{len(validation_set)}")
 
     train_records = [dataset[i] for i in sorted(train_set)]
     val_records = [dataset[i] for i in sorted(validation_set)]
 
     return train_records, val_records
-
-
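split assigns whole recordings to one side: indices are grouped by rec_id, the recording order is shuffled with a fixed seed, and recordings are added greedily until the train fraction is reached, so snippets from one recording never straddle the train/val boundary. A small usage sketch (the (iq, metadata) record layout is an assumption):

# Sketch: four recordings with two snippets each; snippets sharing a
# rec_id always land on the same side of the split.
dataset = [(None, {"rec_id": r}) for r in range(4) for _ in range(2)]
train_records, val_records = split(dataset, train_frac=0.8, seed=42)
train_ids = {md["rec_id"] for _, md in train_records}
val_ids = {md["rec_id"] for _, md in val_records}
assert not train_ids & val_ids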