import os, h5py, numpy as np
from utils.io import from_npy
from split_dataset import split
from helpers.app_settings import get_app_settings

meta_dtype = np.dtype(
    [
        ("rec_id", "S256"),
        ("snippet_idx", np.int32),
        ("modulation", "S32"),
        ("snr", np.int32),
        ("beta", np.float32),
        ("sps", np.int32),
    ]
)

info_dtype = np.dtype(
    [
        ("num_records", np.int32),
        ("dataset_name", "S64"),  # up to 64-byte UTF-8 strings
        ("creator", "S64"),
    ]
)
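
# Example of the per-snippet metadata dict this module expects (one entry per
# meta_dtype field). The concrete values are illustrative only, not taken from
# any real recording:
#   {"rec_id": "rec_0001", "snippet_idx": 0, "modulation": "QPSK",
#    "snr": 10, "beta": 0.35, "sps": 8}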


def write_hdf5_file(records, output_path, dataset_name="data"):
    """
    Writes a list of records to an HDF5 file.

    Parameters:
        records (list): List of (snippet, metadata) tuples to be written to the file
        output_path (str): Path to the output HDF5 file
        dataset_name (str): Name of the dataset in the HDF5 file (default: "data")

    Returns:
        str: Path to the created HDF5 file
    """
    meta_arr = np.empty(len(records), dtype=meta_dtype)
    for i, (_, md) in enumerate(records):
        meta_arr[i] = (
            md["rec_id"].encode("utf-8"),
            md["snippet_idx"],
            md["modulation"].encode("utf-8"),
            int(md["snr"]),
            float(md["beta"]),
            int(md["sps"]),
        )

    first_rec, _ = records[0]  # records[0] is a tuple of (data, md)
    sample = first_rec
    shape, dtype = sample.shape, sample.dtype

    with h5py.File(output_path, "w") as hf:
        dset = hf.create_dataset(
            dataset_name, shape=(len(records),) + shape, dtype=dtype, compression="gzip"
        )

        for idx, (snip, md) in enumerate(records):
            dset[idx, ...] = snip

        mg = hf.create_group("metadata")
        mg.create_dataset("metadata", data=meta_arr, compression="gzip")

        print(dset.shape, f"snippets created in {dataset_name}")

        info_arr = np.array(
            [
                (
                    len(records),
                    dataset_name.encode("utf-8"),
                    b"generate_dataset.py",  # already bytes
                )
            ],
            dtype=info_dtype,
        )

        mg.create_dataset("dataset_info", data=info_arr)

    return output_path
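
# A minimal sketch of how a consumer might read back a file written by
# write_hdf5_file (this mirrors the layout created above; the file name and
# variable names are illustrative):
#
#   with h5py.File("train.h5", "r") as hf:
#       snippets = hf["training_data"][:]         # array of shape (num_records, *snippet_shape)
#       meta = hf["metadata"]["metadata"][:]      # structured array with meta_dtype fields
#       info = hf["metadata"]["dataset_info"][0]  # (num_records, dataset_name, creator)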


def split_recording(recording_list, num_snippets):
    """
    Splits each recording into a number of equal-length snippets.

    Parameters:
        recording_list (list): List of (data, metadata) tuples to be split
        num_snippets (int): Number of snippets to cut each recording into

    Returns:
        list: List of (snippet, metadata) tuples
    """
    snippet_list = []

    for data, md in recording_list:
        C, N = data.shape
        L = N // num_snippets
        for i in range(num_snippets):
            start = i * L
            end = (i + 1) * L
            snippet = data[:, start:end]
            # copy the metadata, adding a snippet index
            snippet_md = md.copy()
            snippet_md["snippet_idx"] = i
            snippet_list.append((snippet, snippet_md))
    return snippet_list
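
# For example (illustrative shapes only): a recording with data of shape
# (2, 8000) split with num_snippets=8 yields eight snippets of shape (2, 1000),
# each carrying a copy of the metadata with "snippet_idx" set to 0..7.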


def generate_datasets(cfg):
    """
    Generates train/validation datasets from a folder of .npy recordings and
    saves them to HDF5 files.

    Parameters:
        cfg: Dataset configuration with input_dir, output_dir, num_slices,
             train_split, and seed attributes

    Returns:
        tuple: Paths to the created train and validation HDF5 files
    """

    # make sure the output directory exists before writing train.h5 / val.h5 into it
    os.makedirs(cfg.output_dir, exist_ok=True)

    # we assume the recordings are in .npy format
    files = [f for f in os.listdir(cfg.input_dir) if f.endswith(".npy")]
    if not files:
        raise ValueError("No .npy files found in the specified directory.")

    records = []
    for fname in files:
        rec = from_npy(os.path.join(cfg.input_dir, fname))

        data = rec.data

        md = rec.metadata  # pull metadata from the recording
        # fall back to a sequential id; the key must match meta_dtype's "rec_id"
        # and be a string so it can be UTF-8 encoded when written out
        md.setdefault("rec_id", str(len(records)))
        records.append((data, md))

    # split each recording into cfg.num_slices snippets
    records = split_recording(records, cfg.num_slices)

    train_records, val_records = split(records, cfg.train_split, cfg.seed)

    train_path = os.path.join(cfg.output_dir, "train.h5")
    val_path = os.path.join(cfg.output_dir, "val.h5")

    write_hdf5_file(train_records, train_path, "training_data")
    write_hdf5_file(val_records, val_path, "validation_data")

    return train_path, val_path
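
# generate_datasets only relies on the cfg attributes accessed above. A minimal
# sketch of a compatible config (illustrative values; the real one comes from
# get_app_settings().dataset):
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(input_dir="recordings/", output_dir="datasets/",
#                         num_slices=8, train_split=0.8, seed=42)
#   train_path, val_path = generate_datasets(cfg)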


def main():
    settings = get_app_settings()
    dataset_cfg = settings.dataset
    train_path, val_path = generate_datasets(dataset_cfg)
    print(f"✅ Train: {train_path}\n✅ Val: {val_path}")


if __name__ == "__main__":
    main()