# modrec-workflow/scripts/produce_dataset.py

import os

import h5py
import numpy as np
from utils.io import from_npy
from split_dataset import split
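# from_npy is expected to return a recording object with a 2D .data array
# (channels x samples) and a .metadata dict; split() partitions a list of
# records into train/validation subsets.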
meta_dtype = np.dtype(
    [
        ("rec_id", "S256"),
        ("snippet_idx", np.int32),
        ("modulation", "S32"),
        ("snr", np.int32),
        ("beta", np.float32),
        ("sps", np.int32),
    ]
)
info_dtype = np.dtype(
    [
        ("num_records", np.int32),
        ("dataset_name", "S64"),  # up to 64-byte UTF-8 strings
        ("creator", "S64"),
    ]
)
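
# Illustrative example (hypothetical values): a single row of meta_dtype holds
#   (b"rec_0001", 3, b"qpsk", 10, 0.35, 8)
# i.e. rec_id, snippet_idx, modulation, snr, beta and sps, in that order.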
def write_hdf5_file(records, output_path, dataset_name="data"):
"""
Writes a list of records to an HDF5 file.
Parameters:
records (list): List of records to be written to the file
output_path (str): Path to the output HDF5 file
dataset_name (str): Name of the dataset in the HDF5 file (default: "data")
Returns:
str: Path to the created HDF5 file
"""
meta_arr = np.empty(len(records), dtype=meta_dtype)
for i, (_, md) in enumerate(records):
meta_arr[i] = (
md["rec_id"].encode("utf-8"),
md["snippet_idx"],
md["modulation"].encode("utf-8"),
int(md["snr"]),
float(md["beta"]),
int(md["sps"]),
)
    first_rec, _ = records[0]  # records[0] is a tuple of (data, md)
    sample = first_rec
    shape, dtype = sample.shape, sample.dtype
    with h5py.File(output_path, "w") as hf:
        dset = hf.create_dataset(
            dataset_name, shape=(len(records),) + shape, dtype=dtype, compression="gzip"
        )
        for idx, (snip, md) in enumerate(records):
            dset[idx, ...] = snip
        mg = hf.create_group("metadata")
        mg.create_dataset("metadata", data=meta_arr, compression="gzip")
        print(dset.shape, f"snippets created in {dataset_name}")
        info_arr = np.array(
            [
                (
                    len(records),
                    dataset_name.encode("utf-8"),
                    b"generate_dataset.py",  # already bytes
                )
            ],
            dtype=info_dtype,
        )
        mg.create_dataset("dataset_info", data=info_arr)
    return output_path
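

# Illustrative sketch (not part of the original workflow): one way a file
# written by write_hdf5_file() could be read back with h5py. The group/dataset
# names match what this script creates; the helper name and return layout are
# assumptions made for the example's sake.
def _example_read_back(path, dataset_name="training_data"):
    with h5py.File(path, "r") as hf:
        snippets = hf[dataset_name][...]  # (num_records, C, L) array of snippets
        meta = hf["metadata"]["metadata"][...]  # structured array of meta_dtype
        info = hf["metadata"]["dataset_info"][0]  # single info_dtype row
    # bytes fields need decoding before use, e.g. meta["modulation"][0].decode()
    return snippets, meta, info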
def split_recording(recording_list):
    """
    Splits each recording in a list into 8 equal-length snippets.

    Parameters:
        recording_list (list): List of (data, metadata) recordings to be split

    Returns:
        list: List of (snippet, metadata) tuples, one per snippet
    """
    snippet_list = []
    for data, md in recording_list:
        C, N = data.shape
        L = N // 8
        rec_id = md["rec_id"]
        for i in range(8):
            start = i * L
            end = (i + 1) * L
            snippet = data[:, start:end]
            # copy the metadata, adding a snippet index
            snippet_md = md.copy()
            snippet_md["snippet_idx"] = i
            snippet_list.append((snippet, snippet_md))
    return snippet_list
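
# For example, a recording with data of shape (2, 8192) yields 8 snippets of
# shape (2, 1024); any remainder samples beyond 8 * (N // 8) are dropped.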
def generate_datasets(path_to_recordings, output_path, dataset_name="data"):
"""
Generates a dataset from a folder of .npy files and saves it to an HDF5 file
Parameters:
path_to_recordings (str): Path to the folder containing .npy files
output_path (str): Path to the output HDF5 file
dataset_name (str): Name of the dataset in the HDF5 file (default: "data")
Returns:
dset (h5py.Dataset): The created dataset object
"""
parent = os.path.dirname(output_path)
if parent:
os.makedirs(parent, exist_ok=True)
# we assume the recordings are in .npy format
files = os.listdir(path_to_recordings)
if not files:
raise ValueError("No files found in the specified directory.")
records = []
for fname in files:
rec = from_npy(os.path.join(path_to_recordings, fname))
data = rec.data
md = rec.metadata # pull metadata from the recordinh
md.setdefault("recid", len(records))
records.append((data, md))
# split each recording into 8 snippets each
records = split_recording(records)
train_records, val_records = split(records, train_frac=0.8, seed=42)
train_path = os.path.join(output_path, "train.h5")
val_path = os.path.join(output_path, "val.h5")
write_hdf5_file(train_records, train_path, "training_data")
write_hdf5_file(val_records, val_path, "validation_data")
return train_path, val_path
if __name__ == "__main__":
print(generate_datasets("recordings", "data"))
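
# Running the script as-is assumes a ./recordings directory of .npy files and
# writes data/train.h5 and data/val.h5, printing the two output paths.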