Push Tracker
icc-28/scripts/adapt_dataset.py

128 lines
5.3 KiB
Python
Raw Permalink Normal View History

"""Convert RIA Hub HDF5 datasets to WavesFM training format.
Usage:
python scripts/adapt_dataset.py input.h5 output.h5
Input format (RIA Hub):
/data complex64 (N, C, T) or (N, T)
/metadata/metadata structured array with "label" column
file attrs: data_modality (optional)
Output format (WavesFM):
/sample float32 (N, 2, C, T) -- split I/Q, z-score normalized
/label int64 (N,) -- class indices
file attrs: mu (JSON), std (JSON), labels (JSON str), class_weights (np array), sample_len (int)
WavesFM reads labels via json.loads(attrs["labels"]) and class_weights via
torch.as_tensor(attrs["class_weights"], dtype=torch.float32). These two attributes
use different storage formats intentionally.
"""
import argparse
import json
import logging
import os
import sys

import h5py
import numpy as np

logger = logging.getLogger(__name__)
def adapt_iq(src, dst):
    """Convert complex IQ to float32 split I/Q with per-channel z-score normalization.

    Writes to *dst*:
      * ``sample`` -- float32 (N, 2, C, T); axis 1 is [real, imag]. 2-D
        ``(N, T)`` input is treated as a single channel (C = 1).
      * ``label`` plus the ``labels`` / ``class_weights`` attributes
        (via ``_write_labels``).
      * ``mu`` / ``std`` / ``sample_len`` attributes (via ``_write_stats``).

    The shape and the label metadata are validated before the IQ payload is
    read, so a malformed file fails fast without loading the full dataset.

    Note: loads the full dataset into memory. For typical RIA Hub datasets
    (< 2 GB) this is fine; chunked I/O can be added if larger inputs arise.
    Normalization is applied in place to avoid extra full-size temporaries.

    Args:
        src: open input HDF5 file containing ``/data``.
        dst: open, writable output HDF5 file.

    Raises:
        ValueError: if ``/data`` is not 2-D/3-D, has a zero-length axis, or the
            label metadata is missing or inconsistent.
    """
    shape = src["/data"].shape
    if len(shape) not in (2, 3):
        raise ValueError(f"Expected 2D (N, T) or 3D (N, C, T) data, got {len(shape)}D shape {shape}")
    if shape[0] == 0:
        raise ValueError("Dataset is empty (N=0)")
    if 0 in shape:
        # C == 0 or T == 0 would make every mean/std NaN and write NaN into the JSON stats.
        raise ValueError(f"Dataset has a zero-length axis: shape {shape}")

    # Labels only need the sample count, so validate/write them before loading the payload.
    _write_labels(src, dst)

    data = src["/data"][:]  # complex (N, C, T) or (N, T)
    if data.ndim == 2:
        data = data[:, np.newaxis, :]  # (N, 1, T)
    # Split complex -> (N, 2, C, T) float32
    # Axis 1 = [real, imag], confirmed by Ahmed: "instead of T complex numbers, 2T real numbers"
    iq = np.stack([data.real, data.imag], axis=1).astype(np.float32, copy=False)
    del data  # release the complex array before computing stats

    # Per-channel z-score: stats over (N, T) for each of 2 I/Q components x C channels.
    # NumPy reduces float32 input in float32 by default, which can lose accuracy over
    # large N*T; accumulate in float64 and store the stats as float32.
    mu = iq.mean(axis=(0, 3), keepdims=True, dtype=np.float64).astype(np.float32)  # (1, 2, C, 1)
    std = iq.std(axis=(0, 3), keepdims=True, dtype=np.float64).astype(np.float32)
    std[std < 1e-6] = 1.0  # floor near-constant channels so we never divide by ~0
    iq -= mu
    iq /= std

    dst.create_dataset("sample", data=iq, compression="gzip", compression_opts=1, shuffle=True)
    _write_stats(dst, mu, std, iq.shape[-1])
def _write_labels(src, dst):
    """Extract labels from metadata, encode them as int64 indices, and write class weights.

    Reads the ``label`` field of ``/metadata/metadata`` in *src* and writes to *dst*:
      * ``label`` dataset -- int64 class index per sample.
      * ``labels`` attribute -- JSON list of class names; entry i names class i.
      * ``class_weights`` attribute -- float32 numpy array (NOT JSON) of
        inverse-frequency weights normalized to sum to 1.

    Class indices follow the sorted order of the unique labels. Numeric and
    boolean labels are sorted by value (so 2 precedes 10); string labels are
    sorted lexicographically.

    Raises:
        ValueError: if ``/metadata/metadata`` or its ``label`` column is missing,
            or the number of labels differs from the number of samples in ``/data``.
    """
    if "/metadata/metadata" not in src:
        raise ValueError("Required dataset '/metadata/metadata' not found in input file")
    meta = src["/metadata/metadata"]
    # getattr chain: a group (no dtype) or a non-structured dataset (names is None)
    # both report the same missing-column error instead of an AttributeError.
    field_names = getattr(getattr(meta, "dtype", None), "names", None)
    if field_names is None or "label" not in field_names:
        raise ValueError("Required 'label' column not found in '/metadata/metadata'")
    raw = meta["label"][:]
    expected = src["/data"].shape[0]
    if len(raw) != expected:
        raise ValueError(f"Label count {len(raw)} does not match sample count {expected}")

    if raw.dtype.kind in ("S", "U", "O"):
        # Decode bytes/strings to Python strings, then sorted unique -> 0-based indices.
        decoded = [v.decode("utf-8", errors="replace") if isinstance(v, bytes) else str(v) for v in raw]
        unique_labels = sorted(set(decoded))
        label_map = {name: idx for idx, name in enumerate(unique_labels)}
        indices = np.array([label_map[v] for v in decoded], dtype=np.int64)
    else:
        # Numeric/bool labels: sort by value, not by string form, so [1, 2, 10]
        # maps to 0, 1, 2 rather than "1" < "10" < "2".
        unique_values, inverse = np.unique(raw, return_inverse=True)
        unique_labels = [str(v) for v in unique_values]
        indices = np.asarray(inverse, dtype=np.int64).reshape(-1)
    dst.create_dataset("label", data=indices)
    # WavesFM reads: json.loads(h5.attrs["labels"])
    dst.attrs["labels"] = json.dumps(unique_labels)

    # Inverse-frequency class weights as numpy array (NOT JSON).
    # WavesFM reads: torch.as_tensor(h5.attrs["class_weights"], dtype=torch.float32)
    counts = np.bincount(indices, minlength=len(unique_labels)).astype(np.float64)
    weights = np.where(counts > 0, 1.0 / np.maximum(counts, 1.0), 0.0)
    if weights.sum() > 0:
        weights = weights / weights.sum()
    dst.attrs["class_weights"] = weights.astype(np.float32)
def _write_stats(dst, mu, std, sample_len):
    """Record normalization statistics and the sample length as file attributes.

    ``mu`` and ``std`` are flattened and stored as JSON-encoded lists;
    ``sample_len`` is stored as a plain Python int. These attributes are
    metadata only -- the WavesFM loader does not consume them.
    """
    attrs = dst.attrs
    for key, stat in (("mu", mu), ("std", std)):
        attrs[key] = json.dumps(stat.ravel().tolist())
    attrs["sample_len"] = int(sample_len)
def adapt(src_path, dst_path):
    """Main entry point: read RIA Hub HDF5, write WavesFM HDF5.

    Logs an error and calls ``sys.exit(1)`` if *src_path* has no ``/data``
    dataset or ``/data`` is not complex-valued.

    If conversion fails after *dst_path* has been opened for writing, the
    partially written output is deleted before the exception propagates. This
    covers, for example, a ``ValueError`` from shape/label validation or a
    KeyboardInterrupt. As a result, no half-converted file is left behind
    that could be mistaken for a valid dataset.
    """
    with h5py.File(src_path, "r") as src:
        if "/data" not in src:
            logger.error("Missing required '/data' dataset in source file: %s", src_path)
            sys.exit(1)
        if src["/data"].dtype.kind != "c":
            logger.error("Expected complex64 IQ data, got dtype=%s", src["/data"].dtype)
            sys.exit(1)
        # Opened outside the try: if creating the output fails, we have not
        # truncated anything, so there is nothing of ours to delete.
        dst = h5py.File(dst_path, "w")
        try:
            with dst:
                adapt_iq(src, dst)
        except BaseException:
            try:
                os.remove(dst_path)
            except OSError as cleanup_err:
                logger.warning("Could not remove partial output %s: %s", dst_path, cleanup_err)
            else:
                logger.warning("Removed partial output file: %s", dst_path)
            raise
    logger.info("Adapted %s -> %s", src_path, dst_path)
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    parser = argparse.ArgumentParser(description="Convert RIA Hub HDF5 to WavesFM format")
    parser.add_argument("input", help="Source HDF5 (RIA Hub format)")
    parser.add_argument("output", help="Destination HDF5 (WavesFM format)")
    args = parser.parse_args()
    try:
        adapt(args.input, args.output)
    except (ValueError, OSError) as exc:
        # Report validation and file errors the way adapt() reports its own
        # input checks: one log line and exit status 1 instead of a traceback.
        logger.error("%s", exc)
        sys.exit(1)