diff --git a/.riahub/workflows/train.yaml b/.riahub/workflows/train.yaml
index 5dd78ed..603ad9c 100644
--- a/.riahub/workflows/train.yaml
+++ b/.riahub/workflows/train.yaml
@@ -371,4 +371,4 @@ jobs:
             ${{ env.WAVESFM_OUTPUT_DIR }}/best.pth
             ${{ env.WAVESFM_OUTPUT_DIR }}/log.txt
           if-no-files-found: warn
-# committed at 2026-05-28T05:41:59.835552+00:00
+# re-committed at 2026-05-28T05:48:03+00:00 with adapter script
diff --git a/scripts/adapt_dataset.py b/scripts/adapt_dataset.py
new file mode 100644
index 0000000..43d3aa6
--- /dev/null
+++ b/scripts/adapt_dataset.py
@@ -0,0 +1,165 @@
+"""Convert RIA Hub HDF5 datasets to WavesFM training format.
+
+Usage:
+    python scripts/adapt_dataset.py input.h5 output.h5
+
+Input format (RIA Hub):
+    /data               complex64 (N, C, T) or (N, T)
+    /metadata/metadata  structured array with "label" column
+    file attrs:         data_modality (optional)
+
+Output format (WavesFM):
+    /sample  float32 (N, 2, C, T)  -- split I/Q, z-score normalized
+    /label   int64 (N,)            -- class indices
+    file attrs: mu (JSON), std (JSON), labels (JSON str), class_weights (np array), sample_len (int)
+
+WavesFM reads labels via json.loads(attrs["labels"]) and class_weights via
+torch.as_tensor(attrs["class_weights"], dtype=torch.float32). These two attributes
+use different storage formats intentionally.
+"""
+
+import argparse
+import json
+import logging
+import os
+import sys
+
+import h5py
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+
+def adapt_iq(src, dst):
+    """Convert complex64 IQ to float32 split I/Q with z-score normalization.
+
+    Reads ``/data`` from *src*, writes the normalized ``sample`` dataset to
+    *dst*, then writes labels, class weights and normalization stats.
+
+    Note: loads the full dataset into memory. For typical RIA Hub datasets
+    (< 2 GB) this is fine; chunked I/O can be added if larger inputs arise.
+
+    Raises:
+        ValueError: if ``/data`` is not 2D/3D, is empty, or the label
+            metadata is missing or inconsistent with the data.
+    """
+    data = src["/data"][:]  # complex64 (N, C, T) or (N, T)
+    if data.ndim not in (2, 3):
+        raise ValueError(f"Expected 2D (N, T) or 3D (N, C, T) data, got {data.ndim}D shape {data.shape}")
+    if data.shape[0] == 0:
+        raise ValueError("Dataset is empty (N=0)")
+    if data.ndim == 2:
+        data = data[:, np.newaxis, :]  # (N, 1, T)
+
+    # Split complex -> (N, 2, C, T) float32
+    # Axis 1 = [real, imag], confirmed by Ahmed: "instead of T complex numbers, 2T real numbers"
+    # copy=False: for complex64 input the stacked array is already float32, so
+    # skip the redundant full-size copy a plain astype() would make.
+    iq = np.stack([data.real, data.imag], axis=1).astype(np.float32, copy=False)
+
+    # Per-channel z-score: mean/std over (N, T) for each of 2 I/Q components x C channels.
+    # NumPy already reduces float32 input to float32 by default; the
+    # explicit casts simply guarantee the stats (and the output) stay float32.
+    mu = iq.mean(axis=(0, 3), keepdims=True).astype(np.float32)  # (1, 2, C, 1)
+    std = iq.std(axis=(0, 3), keepdims=True).astype(np.float32)
+    std = np.where(std < 1e-6, np.float32(1.0), std)  # floor preserving float32
+    iq_normed = (iq - mu) / std
+
+    dst.create_dataset("sample", data=iq_normed, compression="gzip", compression_opts=1, shuffle=True)
+    _write_labels(src, dst)
+    _write_stats(dst, mu, std, iq_normed.shape[-1])
+
+
+def _write_labels(src, dst):
+    """Extract labels from metadata, encode to int64 indices, write class_weights.
+
+    Writes the ``label`` dataset plus the ``labels`` (JSON list of class
+    names) and ``class_weights`` (float32 array) attributes to *dst*.
+
+    Raises:
+        ValueError: if ``/metadata/metadata`` or its ``label`` column is
+            missing, or the label count differs from the sample count.
+    """
+    if "/metadata/metadata" not in src:
+        raise ValueError("Required dataset '/metadata/metadata' not found in input file")
+    meta = src["/metadata/metadata"]
+    if not hasattr(meta.dtype, "names") or meta.dtype.names is None or "label" not in meta.dtype.names:
+        raise ValueError("Required 'label' column not found in '/metadata/metadata'")
+    raw = meta["label"][:]
+    expected = src["/data"].shape[0]
+    if len(raw) != expected:
+        raise ValueError(f"Label count {len(raw)} does not match sample count {expected}")
+
+    # Decode bytes/strings to Python strings
+    if raw.dtype.kind in ("S", "U", "O"):
+        decoded = [v.decode("utf-8", errors="replace") if isinstance(v, bytes) else str(v) for v in raw]
+    else:
+        decoded = [str(v) for v in raw]
+
+    # Label encoding: sorted unique -> 0-based contiguous indices.
+    # Integer labels are ordered numerically; a plain string sort would put
+    # "10" before "2" and silently scramble the class indices.
+    sort_key = int if raw.dtype.kind in ("i", "u") else None
+    unique_labels = sorted(set(decoded), key=sort_key)
+    label_map = {name: idx for idx, name in enumerate(unique_labels)}
+    indices = np.array([label_map[v] for v in decoded], dtype=np.int64)
+
+    dst.create_dataset("label", data=indices)
+    # WavesFM reads: json.loads(h5.attrs["labels"])
+    dst.attrs["labels"] = json.dumps(unique_labels)
+
+    # Inverse-frequency class weights as numpy array (NOT JSON).
+    # WavesFM reads: torch.as_tensor(h5.attrs["class_weights"], dtype=torch.float32)
+    counts = np.bincount(indices, minlength=len(unique_labels)).astype(np.float64)
+    weights = np.where(counts > 0, 1.0 / counts, 0.0)
+    if weights.sum() > 0:
+        weights = weights / weights.sum()
+    dst.attrs["class_weights"] = weights.astype(np.float32)
+
+
+def _write_stats(dst, mu, std, sample_len):
+    """Write normalization stats as file attributes (metadata, not consumed by WavesFM loader)."""
+    dst.attrs["mu"] = json.dumps(mu.flatten().tolist())
+    dst.attrs["std"] = json.dumps(std.flatten().tolist())
+    dst.attrs["sample_len"] = int(sample_len)
+
+
+def adapt(src_path, dst_path):
+    """Main entry point: read RIA Hub HDF5, write WavesFM HDF5.
+
+    The output is written to ``<dst_path>.tmp`` and moved into place only
+    after the conversion succeeds, so a failed run never leaves a partial
+    file at *dst_path*.
+
+    Exits with status 1 if ``/data`` is missing or not complex; conversion
+    errors propagate as ``ValueError``.
+    """
+    with h5py.File(src_path, "r") as src:
+        if "/data" not in src:
+            logger.error("Missing required '/data' dataset in source file: %s", src_path)
+            sys.exit(1)
+        if src["/data"].dtype.kind != "c":
+            logger.error("Expected complex64 IQ data, got dtype=%s", src["/data"].dtype)
+            sys.exit(1)
+
+        tmp_path = os.fspath(dst_path) + ".tmp"
+        # Open outside the try: if this fails there is nothing of ours to remove.
+        dst = h5py.File(tmp_path, "w")
+        try:
+            with dst:
+                adapt_iq(src, dst)
+        except BaseException:
+            os.remove(tmp_path)
+            raise
+        os.replace(tmp_path, dst_path)
+
+    logger.info("Adapted %s -> %s", src_path, dst_path)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
+    parser = argparse.ArgumentParser(description="Convert RIA Hub HDF5 to WavesFM format")
+    parser.add_argument("input", help="Source HDF5 (RIA Hub format)")
+    parser.add_argument("output", help="Destination HDF5 (WavesFM format)")
+    args = parser.parse_args()
+    adapt(args.input, args.output)