modrec-workflow/scripts/dataset_building/produce_dataset.py
import os, h5py, numpy as np
from typing import List
from utils.io import from_npy
from split_dataset import split, split_recording
from helpers.app_settings import DataSetConfig, get_app_settings

meta_dtype = np.dtype(
    [
        ("rec_id", "S256"),
        ("snippet_idx", np.int32),
        ("modulation", "S32"),
        ("snr", np.int32),
        ("beta", np.float32),
        ("sps", np.int32),
    ]
)

info_dtype = np.dtype(
    [
        ("num_records", np.int32),
        ("dataset_name", "S64"),  # up to 64-byte UTF-8 strings
        ("creator", "S64"),
    ]
)
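
# Illustrative sketch: a single row of `meta_dtype` would hold values such as
# (b"rec_0003", 12, b"qpsk", 10, 0.35, 8). The rec_id and modulation shown here
# are made-up placeholders, not names taken from the actual recordings.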


def write_hdf5_file(records: List, output_path: str, dataset_name: str = "data") -> str:
    """
    Writes a list of records to an HDF5 file.

    Parameters:
        records (list): List of records to be written to the file
        output_path (str): Path to the output HDF5 file
        dataset_name (str): Name of the dataset in the HDF5 file (default: "data")

    Returns:
        str: Path to the created HDF5 file
    """
    meta_arr = np.empty(len(records), dtype=meta_dtype)
    for i, (_, md) in enumerate(records):
        meta_arr[i] = (
            md["rec_id"].encode("utf-8"),
            md["snippet_idx"],
            md["modulation"].encode("utf-8"),
            int(md["snr"]),
            float(md["beta"]),
            int(md["sps"]),
        )

    first_rec, _ = records[0]  # records[0] is a tuple of (data, md)
    sample = first_rec
    shape, dtype = sample.shape, sample.dtype

    with h5py.File(output_path, "w") as hf:
        data_arr = np.stack([rec[0] for rec in records])
        dset = hf.create_dataset(dataset_name, data=data_arr, compression="gzip")
        mg = hf.create_group("metadata")
        mg.create_dataset("metadata", data=meta_arr, compression="gzip")
        print(dset.shape, f"snippets created in {dataset_name}")
        info_arr = np.array(
            [
                (
                    len(records),
                    dataset_name.encode("utf-8"),
                    b"generate_dataset.py",  # already bytes
                )
            ],
            dtype=info_dtype,
        )
        mg.create_dataset("dataset_info", data=info_arr)
    return output_path
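
# Hedged usage sketch (comment only, not executed): reading back a file that
# write_hdf5_file produced, assuming the layout written above -- one top-level
# dataset named `dataset_name` plus a "metadata" group holding "metadata" and
# "dataset_info". The path and dataset name below match what generate_datasets
# writes further down in this script.
#
#   with h5py.File("data/dataset/train.h5", "r") as hf:
#       snippets = hf["training_data"][:]        # (num_snippets, 2, N) float32
#       meta = hf["metadata"]["metadata"][:]     # structured array of meta_dtype
#       info = hf["metadata"]["dataset_info"][0]
#       mods = [m.decode("utf-8") for m in meta["modulation"]]
#       print(info["num_records"], snippets.shape, mods[:3])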


def complex_to_channel(data: np.ndarray) -> np.ndarray:
    """
    Converts complex-valued IQ data of shape (1, N) to a 2-channel real array of shape (2, N).

    Parameters:
        data (np.ndarray): Complex-valued array of shape (1, N)

    Returns:
        np.ndarray: Real-valued array of shape (2, N) with separate real and imaginary channels
    """
    assert np.iscomplexobj(data)  # check if the data is in the form a+bi
    real = np.real(data[0])  # (N,)
    imag = np.imag(data[0])  # (N,)
    stacked = np.stack([real, imag], axis=0)  # shape (2, N)
    return stacked.astype(np.float32)
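
# Quick illustration (placeholder values): a length-3 complex recording becomes
# two real channels, real part in row 0 and imaginary part in row 1.
#
#   >>> iq = np.array([[1 + 2j, 3 - 4j, 0 + 1j]])
#   >>> complex_to_channel(iq)
#   array([[ 1.,  3.,  0.],
#          [ 2., -4.,  1.]], dtype=float32)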


def generate_datasets(cfg: DataSetConfig) -> tuple:
    """
    Generates training and validation datasets from a folder of .npy recordings
    and saves them to HDF5 files.

    Parameters:
        cfg (DataSetConfig): Dataset configuration loaded from app.yaml

    Returns:
        tuple: Paths to the created training and validation HDF5 files
    """
    os.makedirs("data/dataset", exist_ok=True)  # ensure the output directory exists

    # we assume the recordings are in .npy format
    files = os.listdir("data/recordings")
    if not files:
        raise ValueError("No files found in the specified directory.")

    records = []
    for fname in files:
        rec = from_npy(os.path.join("data/recordings", fname))
        data = rec.data  # here data is a numpy array with the shape (1, N)
        data = complex_to_channel(data)  # convert to 2-channel real array
        md = rec.metadata  # pull metadata from the recording
        md.setdefault("rec_id", str(len(records)))  # fall back to a sequential id
        records.append((data, md))

    # split each recording into <num_slices> snippets each
    records = split_recording(records, cfg.num_slices)
    train_records, val_records = split(records, cfg.train_split, cfg.seed)

    train_path = os.path.join("data/dataset", "train.h5")
    val_path = os.path.join("data/dataset", "val.h5")
    write_hdf5_file(train_records, train_path, "training_data")
    write_hdf5_file(val_records, val_path, "validation_data")
    return train_path, val_path
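
# Assumed on-disk layout (sketch; "data/recordings" and "data/dataset" are the
# paths hard-coded above, the .npy file names are placeholders):
#
#   data/
#     recordings/    # input recordings readable by utils.io.from_npy
#       rec_0000.npy
#       rec_0001.npy
#       ...
#     dataset/       # train.h5 and val.h5 are written here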


def main():
    settings = get_app_settings()
    dataset_cfg = settings.dataset

    print("📦 Generating training and validation datasets...")
    print(f" ➤ Slicing each recording into {dataset_cfg.num_slices} snippets")
    print(
        f" ➤ Train/Val split: {int(dataset_cfg.train_split * 100)}% / {int((1 - dataset_cfg.train_split) * 100)}%"
    )
    print(f" ➤ Output directory: data/dataset\n")

    train_path, val_path = generate_datasets(dataset_cfg)

    # Count number of samples in each file
    with h5py.File(train_path, "r") as f:
        num_train = f["training_data"].shape[0]
    with h5py.File(val_path, "r") as f:
        num_val = f["validation_data"].shape[0]

    print("✅ Dataset generation complete!")
    print(f" 🔹 Training samples saved to: {train_path} ({num_train} samples)")
    print(f" 🔸 Validation samples saved to: {val_path} ({num_val} samples)")


if __name__ == "__main__":
    main()