Add adapt_dataset.py + re-trigger training canary
Some checks are pending
WavesFM Fine-Tuning / WavesFM-Training (push) Waiting to run
Some checks are pending
WavesFM Fine-Tuning / WavesFM-Training (push) Waiting to run
Previous run (2577) was waiting; missing scripts/adapt_dataset.py. Adds the WAC-side adapter so the workflow's actions/checkout@v5 sparse-checkout finds the file at /scripts/adapt_dataset.py.
This commit is contained in:
parent
b8fd7d73d9
commit
501a8c5e1c
|
|
@ -371,4 +371,4 @@ jobs:
|
|||
${{ env.WAVESFM_OUTPUT_DIR }}/best.pth
|
||||
${{ env.WAVESFM_OUTPUT_DIR }}/log.txt
|
||||
if-no-files-found: warn
|
||||
# committed at 2026-05-28T05:41:59.835552+00:00
|
||||
# re-committed at 2026-05-28T05:48:03+00:00 with adapter script
|
||||
|
|
|
|||
127
scripts/adapt_dataset.py
Normal file
127
scripts/adapt_dataset.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
"""Convert RIA Hub HDF5 datasets to WavesFM training format.
|
||||
|
||||
Usage:
|
||||
python scripts/adapt_dataset.py input.h5 output.h5
|
||||
|
||||
Input format (RIA Hub):
|
||||
/data complex64 (N, C, T) or (N, T)
|
||||
/metadata/metadata structured array with "label" column
|
||||
file attrs: data_modality (optional)
|
||||
|
||||
Output format (WavesFM):
|
||||
/sample float32 (N, 2, C, T) -- split I/Q, z-score normalized
|
||||
/label int64 (N,) -- class indices
|
||||
file attrs: mu (JSON), std (JSON), labels (JSON str), class_weights (np array), sample_len (int)
|
||||
|
||||
WavesFM reads labels via json.loads(attrs["labels"]) and class_weights via
|
||||
torch.as_tensor(attrs["class_weights"], dtype=torch.float32). These two attributes
|
||||
use different storage formats intentionally.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def adapt_iq(src, dst):
|
||||
"""Convert complex64 IQ to float32 split I/Q with z-score normalization.
|
||||
|
||||
Note: loads the full dataset into memory. For typical RIA Hub datasets
|
||||
(< 2 GB) this is fine; chunked I/O can be added if larger inputs arise.
|
||||
"""
|
||||
data = src["/data"][:] # complex64 (N, C, T) or (N, T)
|
||||
if data.ndim not in (2, 3):
|
||||
raise ValueError(f"Expected 2D (N, T) or 3D (N, C, T) data, got {data.ndim}D shape {data.shape}")
|
||||
if data.shape[0] == 0:
|
||||
raise ValueError("Dataset is empty (N=0)")
|
||||
if data.ndim == 2:
|
||||
data = data[:, np.newaxis, :] # (N, 1, T)
|
||||
|
||||
# Split complex -> (N, 2, C, T) float32
|
||||
# Axis 1 = [real, imag], confirmed by Ahmed: "instead of T complex numbers, 2T real numbers"
|
||||
iq = np.stack([data.real, data.imag], axis=1).astype(np.float32)
|
||||
|
||||
# Per-channel z-score: mean/std over (N, T) for each of 2 I/Q components x C channels.
|
||||
# Explicit float32 casts because numpy < 2.0 returns float64 from .mean()/.std()
|
||||
# on float32 inputs.
|
||||
mu = iq.mean(axis=(0, 3), keepdims=True).astype(np.float32) # (1, 2, C, 1)
|
||||
std = iq.std(axis=(0, 3), keepdims=True).astype(np.float32)
|
||||
std = np.where(std < 1e-6, np.float32(1.0), std) # floor preserving float32
|
||||
iq_normed = (iq - mu) / std
|
||||
|
||||
dst.create_dataset("sample", data=iq_normed, compression="gzip", compression_opts=1, shuffle=True)
|
||||
_write_labels(src, dst)
|
||||
_write_stats(dst, mu, std, iq_normed.shape[-1])
|
||||
|
||||
|
||||
def _write_labels(src, dst):
|
||||
"""Extract labels from metadata, encode to int64 indices, write class_weights."""
|
||||
if "/metadata/metadata" not in src:
|
||||
raise ValueError("Required dataset '/metadata/metadata' not found in input file")
|
||||
meta = src["/metadata/metadata"]
|
||||
if not hasattr(meta.dtype, "names") or meta.dtype.names is None or "label" not in meta.dtype.names:
|
||||
raise ValueError("Required 'label' column not found in '/metadata/metadata'")
|
||||
raw = meta["label"][:]
|
||||
expected = src["/data"].shape[0]
|
||||
if len(raw) != expected:
|
||||
raise ValueError(f"Label count {len(raw)} does not match sample count {expected}")
|
||||
|
||||
# Decode bytes/strings to Python strings
|
||||
if raw.dtype.kind in ("S", "U", "O"):
|
||||
decoded = [v.decode("utf-8", errors="replace") if isinstance(v, bytes) else str(v) for v in raw]
|
||||
else:
|
||||
decoded = [str(v) for v in raw]
|
||||
|
||||
# Label encoding: sorted unique -> 0-based contiguous indices
|
||||
unique_labels = sorted(set(decoded))
|
||||
label_map = {name: idx for idx, name in enumerate(unique_labels)}
|
||||
indices = np.array([label_map[v] for v in decoded], dtype=np.int64)
|
||||
|
||||
dst.create_dataset("label", data=indices)
|
||||
# WavesFM reads: json.loads(h5.attrs["labels"])
|
||||
dst.attrs["labels"] = json.dumps(unique_labels)
|
||||
|
||||
# Inverse-frequency class weights as numpy array (NOT JSON).
|
||||
# WavesFM reads: torch.as_tensor(h5.attrs["class_weights"], dtype=torch.float32)
|
||||
counts = np.bincount(indices, minlength=len(unique_labels)).astype(np.float64)
|
||||
weights = np.where(counts > 0, 1.0 / counts, 0.0)
|
||||
if weights.sum() > 0:
|
||||
weights = weights / weights.sum()
|
||||
dst.attrs["class_weights"] = weights.astype(np.float32)
|
||||
|
||||
|
||||
def _write_stats(dst, mu, std, sample_len):
|
||||
"""Write normalization stats as file attributes (metadata, not consumed by WavesFM loader)."""
|
||||
dst.attrs["mu"] = json.dumps(mu.flatten().tolist())
|
||||
dst.attrs["std"] = json.dumps(std.flatten().tolist())
|
||||
dst.attrs["sample_len"] = int(sample_len)
|
||||
|
||||
|
||||
def adapt(src_path, dst_path):
|
||||
"""Main entry point: read RIA Hub HDF5, write WavesFM HDF5."""
|
||||
with h5py.File(src_path, "r") as src:
|
||||
if "/data" not in src:
|
||||
logger.error("Missing required '/data' dataset in source file: %s", src_path)
|
||||
sys.exit(1)
|
||||
if src["/data"].dtype.kind != "c":
|
||||
logger.error("Expected complex64 IQ data, got dtype=%s", src["/data"].dtype)
|
||||
sys.exit(1)
|
||||
with h5py.File(dst_path, "w") as dst:
|
||||
adapt_iq(src, dst)
|
||||
|
||||
logger.info("Adapted %s -> %s", src_path, dst_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
parser = argparse.ArgumentParser(description="Convert RIA Hub HDF5 to WavesFM format")
|
||||
parser.add_argument("input", help="Source HDF5 (RIA Hub format)")
|
||||
parser.add_argument("output", help="Destination HDF5 (WavesFM format)")
|
||||
args = parser.parse_args()
|
||||
adapt(args.input, args.output)
|
||||
Loading…
Reference in New Issue
Block a user