# modrec-workflow/scripts/dataset_building/produce_dataset.py

import os

import h5py
import numpy as np

from utils.io import from_npy
from split_dataset import split, split_recording
from helpers.app_settings import get_app_settings

meta_dtype = np.dtype(
    [
        ("rec_id", "S256"),
        ("snippet_idx", np.int32),
        ("modulation", "S32"),
        ("snr", np.int32),
        ("beta", np.float32),
        ("sps", np.int32),
    ]
)

info_dtype = np.dtype(
    [
        ("num_records", np.int32),
        ("dataset_name", "S64"),  # up to 64-byte UTF-8 strings
        ("creator", "S64"),
    ]
)

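# Illustrative only (field values below are made up, not from a real
# recording): one metadata row as it would be stored in the structured array:
#   np.array([(b"rec_0001", 0, b"qpsk", 10, 0.35, 8)], dtype=meta_dtype)
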
def write_hdf5_file(records, output_path, dataset_name="data"):
    """
    Writes a list of (data, metadata) records to an HDF5 file.

    Parameters:
        records (list): List of (data, metadata) tuples to be written
        output_path (str): Path to the output HDF5 file
        dataset_name (str): Name of the dataset in the HDF5 file (default: "data")

    Returns:
        str: Path to the created HDF5 file
    """
    meta_arr = np.empty(len(records), dtype=meta_dtype)
    for i, (_, md) in enumerate(records):
        meta_arr[i] = (
            md["rec_id"].encode("utf-8"),
            md["snippet_idx"],
            md["modulation"].encode("utf-8"),
            int(md["snr"]),
            float(md["beta"]),
            int(md["sps"]),
        )

    with h5py.File(output_path, "w") as hf:
        # Stack the per-snippet arrays into one (num_records, 2, N) dataset.
        data_arr = np.stack([rec[0] for rec in records])
        dset = hf.create_dataset(dataset_name, data=data_arr, compression="gzip")

        mg = hf.create_group("metadata")
        mg.create_dataset("metadata", data=meta_arr, compression="gzip")
        print(f"{dset.shape[0]} snippets created in {dataset_name}")

        info_arr = np.array(
            [
                (
                    len(records),
                    dataset_name.encode("utf-8"),
                    b"produce_dataset.py",  # already bytes
                )
            ],
            dtype=info_dtype,
        )
        mg.create_dataset("dataset_info", data=info_arr)
    return output_path

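# Sketch of the inverse operation (not called by this script; included only to
# document the file layout written above). Decoding the byte-string fields is
# an assumption about downstream use.
def read_hdf5_file(input_path, dataset_name="data"):
    """Load snippets and metadata written by write_hdf5_file."""
    with h5py.File(input_path, "r") as hf:
        data = hf[dataset_name][:]  # (num_records, 2, N) float32 snippets
        meta = hf["metadata/metadata"][:]  # structured array of meta_dtype
        info = hf["metadata/dataset_info"][0]  # single info_dtype record
    modulations = [m.decode("utf-8") for m in meta["modulation"]]
    print(f"{info['num_records']} records loaded from {dataset_name}")
    return data, meta, modulations
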
def complex_to_channel(data):
    """
    Convert complex-valued IQ data of shape (1, N) to a 2-channel real array
    of shape (2, N).
    """
    assert np.iscomplexobj(data)  # check that the data is in the form a+bi
    real = np.real(data[0])  # (N,)
    imag = np.imag(data[0])  # (N,)
    stacked = np.stack([real, imag], axis=0)  # shape (2, N)
    return stacked.astype(np.float32)

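# Example (illustrative): a (1, 2) complex recording becomes a (2, 2) float32
# array, row 0 holding I (real) and row 1 holding Q (imaginary) samples:
#   complex_to_channel(np.array([[1 + 2j, 3 + 4j]]))  ->  [[1., 3.], [2., 4.]]
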
def generate_datasets(cfg):
    """
    Generates train/validation datasets from a folder of .npy recordings and
    saves them to HDF5 files.

    Parameters:
        cfg: Dataset configuration providing input_dir, output_dir,
             num_slices, train_split, and seed

    Returns:
        tuple: Paths to the created train and validation HDF5 files
    """
    os.makedirs(cfg.output_dir, exist_ok=True)

    # we assume the recordings are in .npy format
    files = [f for f in os.listdir(cfg.input_dir) if f.endswith(".npy")]
    if not files:
        raise ValueError("No .npy files found in the specified directory.")

    records = []
    for fname in files:
        rec = from_npy(os.path.join(cfg.input_dir, fname))
        data = rec.data  # a complex numpy array with shape (1, N)
        data = complex_to_channel(data)  # convert to 2-channel real array
        md = rec.metadata  # pull metadata from the recording
        md.setdefault("rec_id", str(len(records)))  # ensure a string rec_id exists
        records.append((data, md))

    # split each recording into <num_slices> snippets each
    records = split_recording(records, cfg.num_slices)
    train_records, val_records = split(records, cfg.train_split, cfg.seed)

    train_path = os.path.join(cfg.output_dir, "train.h5")
    val_path = os.path.join(cfg.output_dir, "val.h5")
    write_hdf5_file(train_records, train_path, "training_data")
    write_hdf5_file(val_records, val_path, "validation_data")
    return train_path, val_path

def main():
    settings = get_app_settings()
    dataset_cfg = settings.dataset
    train_path, val_path = generate_datasets(dataset_cfg)
    print(f"✅ Train: {train_path}\n✅ Val: {val_path}")


if __name__ == "__main__":
    main()
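
# Usage sketch (assumes app settings expose a `dataset` section providing
# input_dir, output_dir, num_slices, train_split, and seed, as used above):
#   python modrec-workflow/scripts/dataset_building/produce_dataset.py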