"""Convert RIA Hub HDF5 datasets to WavesFM training format.

Usage:
    python scripts/adapt_dataset.py input.h5 output.h5

Input format (RIA Hub):
    /data               complex64 (N, C, T) or (N, T)
    /metadata/metadata  structured array with a "label" column
    file attrs:         data_modality (optional)

Output format (WavesFM):
    /sample   float32 (N, 2, C, T) -- split I/Q, z-score normalized
    /label    int64 (N,)           -- class indices
    file attrs:
        mu, std        JSON lists of the 2*C per-(I/Q, channel) stats, flattened
                       from shape (1, 2, C, 1): all I channels, then all Q channels
        labels         JSON list of class names; class index i <-> labels[i]
        class_weights  float32 numpy array (inverse class frequency, sums to 1)
        sample_len     int (T)

WavesFM reads labels via json.loads(attrs["labels"]) and class_weights via
torch.as_tensor(attrs["class_weights"], dtype=torch.float32). These two
attributes use different storage formats intentionally.
"""

import argparse
import contextlib
import json
import logging
import os
import sys

import h5py
import numpy as np

logger = logging.getLogger(__name__)


def adapt_iq(src, dst):
    """Convert complex IQ to float32 split I/Q with per-channel z-score normalization.

    Writes the ``sample`` and ``label`` datasets plus the ``labels``,
    ``class_weights``, ``mu``, ``std`` and ``sample_len`` attributes to ``dst``.

    Note: loads the full dataset into memory. For typical RIA Hub datasets
    (< 2 GB) this is fine; chunked I/O can be added if larger inputs arise.

    Args:
        src: h5py.File opened for reading, RIA Hub layout (must contain /data).
        dst: h5py.File opened for writing.

    Raises:
        ValueError: if /data is not 2D/3D or is empty, or if the label
            metadata is missing or does not match the sample count.
    """
    src_data = src["/data"]
    # Validate from HDF5 metadata before paying for the full read.
    if src_data.ndim not in (2, 3):
        raise ValueError(f"Expected 2D (N, T) or 3D (N, C, T) data, got {src_data.ndim}D shape {src_data.shape}")
    if src_data.shape[0] == 0:
        raise ValueError("Dataset is empty (N=0)")

    # Labels are cheap to read and validate; handle them before the expensive
    # normalization + compressed write so bad metadata fails fast.
    _write_labels(src, dst)

    data = src_data[:]  # complex (N, C, T) or (N, T)
    if data.ndim == 2:
        data = data[:, np.newaxis, :]  # (N, 1, T)

    # Split complex -> (N, 2, C, T) float32.
    # Axis 1 = [real, imag], confirmed by Ahmed: "instead of T complex numbers, 2T real numbers"
    iq = np.stack([data.real, data.imag], axis=1).astype(np.float32, copy=False)
    del data  # release the complex copy before allocating more temporaries

    # Per-channel z-score: mean/std over (N, T) for each of 2 I/Q components x C channels.
    # Sums are accumulated in float64 (dtype=) so large N does not lose precision
    # in float32; results are stored as float32. Centering and scaling are done in
    # place to avoid a second full-size copy of the data.
    mu = iq.mean(axis=(0, 3), keepdims=True, dtype=np.float64).astype(np.float32)  # (1, 2, C, 1)
    iq -= mu
    # Population variance (ddof=0, same as np.std's default) of the centered data.
    var = np.square(iq).mean(axis=(0, 3), keepdims=True, dtype=np.float64)
    std = np.sqrt(var).astype(np.float32)
    std = np.where(std < 1e-6, np.float32(1.0), std)  # floor near-constant channels, preserving float32
    iq /= std

    dst.create_dataset("sample", data=iq, compression="gzip", compression_opts=1, shuffle=True)
    _write_stats(dst, mu, std, iq.shape[-1])


def _write_labels(src, dst):
    """Encode the metadata "label" column to int64 indices and write class info.

    Writes dataset ``label`` and attributes ``labels`` (JSON list, index i is
    ``labels[i]``) and ``class_weights`` (float32 inverse-frequency weights,
    normalized to sum to 1).

    Class order: string/bytes labels are sorted lexicographically; numeric
    labels are sorted by numeric value (so 2 precedes 10) and stored as their
    ``str()`` form.

    Raises:
        ValueError: if '/metadata/metadata' or its "label" column is missing,
            or the number of labels differs from the number of samples.
    """
    if "/metadata/metadata" not in src:
        raise ValueError("Required dataset '/metadata/metadata' not found in input file")
    meta = src["/metadata/metadata"]
    field_names = getattr(meta.dtype, "names", None)
    if field_names is None or "label" not in field_names:
        raise ValueError("Required 'label' column not found in '/metadata/metadata'")

    raw = meta["label"][:]
    expected = src["/data"].shape[0]
    if len(raw) != expected:
        raise ValueError(f"Label count {len(raw)} does not match sample count {expected}")

    if raw.dtype.kind in ("S", "U", "O"):
        # Decode bytes/strings to Python strings.
        decoded = [v.decode("utf-8", errors="replace") if isinstance(v, bytes) else str(v) for v in raw]
        unique_labels = sorted(set(decoded))
    else:
        decoded = [str(v) for v in raw]
        # Order numeric classes by value rather than by string form ("10" < "2").
        # Every raw value that renders to a given string is an equal sort key.
        sort_key = dict(zip(decoded, raw.tolist()))
        unique_labels = sorted(sort_key, key=sort_key.__getitem__)

    # 0-based contiguous indices following the class order above.
    label_map = {name: idx for idx, name in enumerate(unique_labels)}
    indices = np.array([label_map[v] for v in decoded], dtype=np.int64)
    dst.create_dataset("label", data=indices)

    # WavesFM reads: json.loads(h5.attrs["labels"])
    dst.attrs["labels"] = json.dumps(unique_labels)

    # Inverse-frequency class weights as numpy array (NOT JSON).
    # WavesFM reads: torch.as_tensor(h5.attrs["class_weights"], dtype=torch.float32)
    counts = np.bincount(indices, minlength=len(unique_labels)).astype(np.float64)
    weights = np.divide(1.0, counts, out=np.zeros_like(counts), where=counts > 0)
    total = weights.sum()
    if total > 0:
        weights /= total
    dst.attrs["class_weights"] = weights.astype(np.float32)


def _write_stats(dst, mu, std, sample_len):
    """Write normalization stats as file attributes (metadata, not consumed by WavesFM loader).

    ``mu``/``std`` are flattened in C order to JSON lists; ``sample_len`` is
    stored as a Python int.
    """
    dst.attrs["mu"] = json.dumps(mu.flatten().tolist())
    dst.attrs["std"] = json.dumps(std.flatten().tolist())
    dst.attrs["sample_len"] = int(sample_len)


def adapt(src_path, dst_path):
    """Main entry point: read RIA Hub HDF5, write WavesFM HDF5.

    Calls ``sys.exit(1)`` (after logging) if ``/data`` is missing or not
    complex. Other errors (e.g. ``ValueError`` from validation) propagate; in
    that case the output file this call created is removed so a partially
    written file is never left looking like a valid dataset.
    """
    with h5py.File(src_path, "r") as src:
        if "/data" not in src:
            logger.error("Missing required '/data' dataset in source file: %s", src_path)
            sys.exit(1)
        if src["/data"].dtype.kind != "c":
            logger.error("Expected complex64 IQ data, got dtype=%s", src["/data"].dtype)
            sys.exit(1)

        # Open outside the try: if opening fails we must not delete a file we
        # never truncated (e.g. dst_path pointing at an existing/locked file).
        dst = h5py.File(dst_path, "w")
        try:
            with dst:
                adapt_iq(src, dst)
        except BaseException:
            with contextlib.suppress(OSError):
                os.remove(dst_path)
            raise
    logger.info("Adapted %s -> %s", src_path, dst_path)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    parser = argparse.ArgumentParser(description="Convert RIA Hub HDF5 to WavesFM format")
    parser.add_argument("input", help="Source HDF5 (RIA Hub format)")
    parser.add_argument("output", help="Destination HDF5 (WavesFM format)")
    args = parser.parse_args()
    try:
        adapt(args.input, args.output)
    except ValueError as err:
        # Input validation failures are user errors: report them the same way
        # adapt() reports its own checks instead of dumping a traceback.
        logger.error("%s", err)
        sys.exit(1)