reformats and campaign additions

This commit is contained in:
ben 2026-03-11 10:27:18 -04:00
parent b1e3ebf74f
commit 019b0c6f4b
38 changed files with 4230 additions and 1059 deletions

2143
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -127,5 +127,8 @@ exclude = '''
)/ )/
''' '''
[tool.pytest.ini_options]
pythonpath = ["src"]
[tool.isort] [tool.isort]
profile = "black" profile = "black"

View File

@ -956,8 +956,10 @@ def get_result_sizes( # noqa: C901 # TODO: Simplify function
# Check that each class that will be augmented does not already suffice target_size # Check that each class that will be augmented does not already suffice target_size
for cls_name, target_size_value in zip(classes_to_augment, target_size): for cls_name, target_size_value in zip(classes_to_augment, target_size):
if class_sizes[cls_name] >= target_size_value: if class_sizes[cls_name] >= target_size_value:
raise ValueError(f"""target_size of {target_size_value} is already sufficed for current size of raise ValueError(
{class_sizes[cls_name]} for class: {cls_name}""") f"""target_size of {target_size_value} is already sufficed for current size of
{class_sizes[cls_name]} for class: {cls_name}"""
)
for index, class_name in enumerate(classes_to_augment): for index, class_name in enumerate(classes_to_augment):
result_sizes[class_name] = target_size[index] result_sizes[class_name] = target_size[index]

View File

@ -316,6 +316,8 @@ def to_sigmf(
meta_dict = sigMF_metafile.ordered_metadata() meta_dict = sigMF_metafile.ordered_metadata()
meta_dict["ria"] = metadata meta_dict["ria"] = metadata
if overwrite and os.path.isfile(meta_file_path):
os.remove(meta_file_path)
sigMF_metafile.tofile(meta_file_path) sigMF_metafile.tofile(meta_file_path)

View File

@ -0,0 +1,26 @@
"""Orchestration layer for automated RF capture campaigns."""
from .campaign import (
CampaignConfig,
CaptureStep,
QAConfig,
RecorderConfig,
TransmitterConfig,
)
from .executor import CampaignExecutor, CampaignResult, StepResult
from .labeler import label_recording
from .qa import QAResult, check_recording
__all__ = [
"CampaignConfig",
"CaptureStep",
"QAConfig",
"RecorderConfig",
"TransmitterConfig",
"CampaignExecutor",
"CampaignResult",
"StepResult",
"label_recording",
"QAResult",
"check_recording",
]

View File

@ -0,0 +1,446 @@
"""Campaign configuration schema and YAML parser for orchestrated RF captures."""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import yaml
# ---------------------------------------------------------------------------
# Parsing helpers
# ---------------------------------------------------------------------------
def parse_duration(value: str | float | int) -> float:
"""Parse a duration string to seconds.
Accepts:
"30s" 30.0
"1.5m" or "1.5min" 90.0
"2h" 7200.0
30 (numeric) 30.0
"""
if isinstance(value, (int, float)):
return float(value)
value = str(value).strip()
match = re.fullmatch(r"([\d.]+)\s*(s|sec|m|min|h|hr)?", value, re.IGNORECASE)
if not match:
raise ValueError(f"Cannot parse duration: '{value}'")
amount = float(match.group(1))
unit = (match.group(2) or "s").lower()
if unit in ("h", "hr"):
return amount * 3600
if unit in ("m", "min"):
return amount * 60
return amount
def parse_frequency(value: str | float | int) -> float:
"""Parse a frequency string to Hz.
Accepts:
"2.45GHz" 2_450_000_000.0
"40MHz" 40_000_000.0
"915e6" 915_000_000.0
2.45e9 (numeric) 2_450_000_000.0
"""
if isinstance(value, (int, float)):
return float(value)
value = str(value).strip()
# Try bare numeric first (handles scientific notation like "915e6")
try:
return float(value)
except ValueError:
pass
# Handle suffix notation: "2.45GHz", "40MHz", "40M", "433k"
match = re.fullmatch(r"([\d.]+)\s*(k|M|G)(?:\s*Hz?)?", value, re.IGNORECASE)
if match:
amount = float(match.group(1))
suffix = match.group(2).upper()
return amount * {"K": 1e3, "M": 1e6, "G": 1e9}[suffix]
raise ValueError(f"Cannot parse frequency: '{value}'")
def parse_gain(value: str | float | int) -> float | str:
"""Parse a gain string.
Accepts:
"40dB" or "40 dB" 40.0
"auto" "auto"
40 (numeric) 40.0
"""
if isinstance(value, (int, float)):
return float(value)
value = str(value).strip()
if value.lower() == "auto":
return "auto"
match = re.fullmatch(r"([\d.+\-]+)\s*dB?", value, re.IGNORECASE)
if not match:
raise ValueError(f"Cannot parse gain: '{value}'")
return float(match.group(1))
def parse_bandwidth_mhz(value: str | float | int | None) -> Optional[float]:
"""Parse a bandwidth string to MHz.
Accepts:
"20MHz" 20.0
"40MHz" 40.0
20 (numeric, assumed MHz) 20.0
None None
"""
if value is None:
return None
if isinstance(value, (int, float)):
return float(value)
value = str(value).strip()
match = re.fullmatch(r"([\d.]+)\s*MHz?", value, re.IGNORECASE)
if match:
return float(match.group(1))
match = re.fullmatch(r"([\d.]+)", value)
if match:
return float(match.group(1))
raise ValueError(f"Cannot parse bandwidth: '{value}'")
# ---------------------------------------------------------------------------
# Config dataclasses
# ---------------------------------------------------------------------------
@dataclass
class RecorderConfig:
"""SDR recorder configuration."""
device: str
center_freq: float # Hz
sample_rate: float # Hz
gain: float | str # dB float, or "auto"
bandwidth: Optional[float] = None # Hz, None = match sample_rate
@classmethod
def from_dict(cls, d: dict) -> "RecorderConfig":
gain = parse_gain(d.get("gain", "auto"))
bandwidth_raw = d.get("bandwidth") or d.get("bandwidth_hz")
bandwidth = parse_frequency(bandwidth_raw) if bandwidth_raw else None
return cls(
device=str(d["device"]),
center_freq=parse_frequency(d["center_freq"]),
sample_rate=parse_frequency(d["sample_rate"]),
gain=gain,
bandwidth=bandwidth,
)
@dataclass
class CaptureStep:
"""A single timed capture within a transmitter schedule."""
duration: float # seconds
label: str # used as filename component
# WiFi-specific
channel: Optional[int] = None
bandwidth_mhz: Optional[float] = None # MHz
traffic: Optional[str] = None
# Bluetooth-specific
connection_interval_ms: Optional[float] = None
# Power (dBm), optional
power_dbm: Optional[float] = None
@classmethod
def from_dict(cls, d: dict, auto_label: bool = True) -> "CaptureStep":
duration = parse_duration(d["duration"])
label = d.get("label", "")
if not label and auto_label:
parts = []
if d.get("channel"):
parts.append(f"ch{d['channel']:02d}")
if d.get("bandwidth"):
bw = parse_bandwidth_mhz(d["bandwidth"])
parts.append(f"{int(bw)}mhz")
if d.get("traffic"):
parts.append(str(d["traffic"]).replace(" ", "_"))
label = "_".join(parts) if parts else "capture"
return cls(
duration=duration,
label=label,
channel=d.get("channel"),
bandwidth_mhz=parse_bandwidth_mhz(d.get("bandwidth")),
traffic=d.get("traffic"),
connection_interval_ms=d.get("connection_interval_ms"),
power_dbm=float(d["power"].rstrip("dBm").strip()) if d.get("power") else None,
)
@dataclass
class TransmitterConfig:
"""Configuration for a single transmitter device in the campaign."""
id: str
type: str # "wifi", "bluetooth", "sdr", "external"
control_method: str # "external_script" | "sdr"
schedule: list[CaptureStep]
# For external_script control
script: Optional[str] = None # path to control script
device: Optional[str] = None # e.g. "/dev/wlan0"
@classmethod
def from_dict(cls, d: dict) -> "TransmitterConfig":
schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])]
return cls(
id=str(d["id"]),
type=str(d["type"]),
control_method=str(d.get("control_method", "external_script")),
schedule=schedule,
script=d.get("script"),
device=d.get("device"),
)
@dataclass
class QAConfig:
"""Quality assurance thresholds."""
snr_threshold_db: float = 10.0
min_duration_s: float = 25.0
flag_for_review: bool = True
@classmethod
def from_dict(cls, d: dict) -> "QAConfig":
return cls(
snr_threshold_db=float(str(d.get("snr_threshold", "10")).rstrip("dB").strip()),
min_duration_s=parse_duration(d.get("min_duration", "25s")),
flag_for_review=bool(d.get("flag_for_review", True)),
)
@dataclass
class OutputConfig:
"""Where to save captured recordings."""
format: str = "sigmf"
path: str = "recordings"
device_id: Optional[str] = None # for device-profile campaigns
repo: Optional[str] = None
@classmethod
def from_dict(cls, d: dict) -> "OutputConfig":
return cls(
format=str(d.get("format", "sigmf")),
path=str(d.get("path", "recordings")),
device_id=d.get("device_id"),
repo=d.get("repo"),
)
@dataclass
class CampaignConfig:
"""Full campaign configuration parsed from YAML."""
name: str
recorder: RecorderConfig
transmitters: list[TransmitterConfig]
qa: QAConfig = field(default_factory=QAConfig)
output: OutputConfig = field(default_factory=OutputConfig)
mode: str = "controlled_testbed"
# ---------------------------------------------------------------------------
# Loaders
# ---------------------------------------------------------------------------
@classmethod
def from_dict(cls, raw: dict) -> "CampaignConfig":
"""Build a CampaignConfig from a parsed dictionary.
Accepts the same structure as the campaign YAML, already loaded into
a Python dict (e.g. from a JSON HTTP request body).
Raises:
ValueError: If required fields are missing or malformed.
KeyError: If ``recorder`` key is absent.
"""
campaign_meta = raw.get("campaign", {})
transmitters = [TransmitterConfig.from_dict(t) for t in raw.get("transmitters", [])]
if not transmitters:
raise ValueError("Campaign config must define at least one transmitter")
return cls(
name=str(campaign_meta.get("name", "unnamed")),
mode=str(campaign_meta.get("mode", "controlled_testbed")),
recorder=RecorderConfig.from_dict(raw["recorder"]),
transmitters=transmitters,
qa=QAConfig.from_dict(raw.get("qa", {})),
output=OutputConfig.from_dict(raw.get("output", {})),
)
@classmethod
def from_yaml(cls, path: str | Path) -> "CampaignConfig":
"""Load a full campaign config YAML.
Expected format::
campaign:
name: "wifi_capture_001"
mode: "controlled_testbed"
transmitters:
- id: "laptop_wifi"
type: "wifi"
control_method: "external_script"
script: "./scripts/wifi_control.sh"
device: "/dev/wlan0"
schedule:
- channel: 6
bandwidth: "20MHz"
traffic: "iperf_udp"
duration: "30s"
recorder:
device: "usrp_b210"
center_freq: "2.45GHz"
sample_rate: "40MHz"
gain: "40dB"
qa:
snr_threshold: "10dB"
min_duration: "25s"
flag_for_review: true
output:
format: "sigmf"
path: "./recordings"
"""
path = Path(path)
try:
with open(path) as f:
raw = yaml.safe_load(f)
except FileNotFoundError:
raise FileNotFoundError(f"Campaign config not found: {path}")
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {path}: {e}")
campaign_meta = raw.get("campaign", {})
transmitters = [TransmitterConfig.from_dict(t) for t in raw.get("transmitters", [])]
if not transmitters:
raise ValueError("Campaign config must define at least one transmitter")
return cls(
name=str(campaign_meta.get("name", path.stem)),
mode=str(campaign_meta.get("mode", "controlled_testbed")),
recorder=RecorderConfig.from_dict(raw["recorder"]),
transmitters=transmitters,
qa=QAConfig.from_dict(raw.get("qa", {})),
output=OutputConfig.from_dict(raw.get("output", {})),
)
@classmethod
def from_device_profile(cls, path: str | Path) -> "CampaignConfig":
"""Build a campaign config from an App 1 device profile YAML.
Expected format::
device:
name: "iPhone_13_WiFi"
type: "wifi"
protocol: "wifi_24ghz"
capture:
channels: [1, 6, 11] # WiFi only
bandwidth: "20MHz" # WiFi only
traffic_patterns: ["idle", "ping", "iperf_udp"]
duration_per_config: "30s"
recorder:
device: "usrp_b210"
center_freq: "2.45GHz"
sample_rate: "40MHz"
gain: "auto"
output:
path: "./recordings"
device_id: "iphone13_wifi_001"
For WiFi devices, schedule is expanded as channels × traffic_patterns.
For Bluetooth devices (no channels), schedule is traffic_patterns only.
"""
path = Path(path)
try:
with open(path) as f:
raw = yaml.safe_load(f)
except FileNotFoundError:
raise FileNotFoundError(f"Device profile not found: {path}")
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {path}: {e}")
device = raw.get("device", {})
capture = raw.get("capture", {})
device_type = str(device.get("type", "wifi")).lower()
device_name = str(device.get("name", path.stem))
duration = parse_duration(capture.get("duration_per_config", "30s"))
traffic_patterns = capture.get("traffic_patterns", ["idle"])
# Build capture schedule
schedule: list[CaptureStep] = []
if device_type in ("wifi", "wifi_24ghz", "wifi_5ghz"):
channels = capture.get("channels", [6])
bw_str = capture.get("bandwidth", "20MHz")
bw_mhz = parse_bandwidth_mhz(bw_str)
for ch in channels:
for traffic in traffic_patterns:
label = f"ch{ch:02d}_{int(bw_mhz)}mhz_{traffic}"
schedule.append(
CaptureStep(
duration=duration,
label=label,
channel=ch,
bandwidth_mhz=bw_mhz,
traffic=traffic,
)
)
else:
# Bluetooth / generic — no channels
for traffic in traffic_patterns:
schedule.append(
CaptureStep(
duration=duration,
label=traffic,
traffic=traffic,
)
)
device_id = raw.get("output", {}).get("device_id", device_name.lower().replace(" ", "_"))
transmitter = TransmitterConfig(
id=device_id,
type=device_type,
control_method=str(capture.get("control_method", "external_script")),
schedule=schedule,
script=capture.get("script"),
device=capture.get("device"),
)
return cls(
name=f"enroll_{device_id}",
mode="controlled_testbed",
recorder=RecorderConfig.from_dict(raw["recorder"]),
transmitters=[transmitter],
qa=QAConfig.from_dict(raw.get("qa", {})),
output=OutputConfig.from_dict(raw.get("output", {})),
)
def total_capture_time_s(self) -> float:
"""Sum of all step durations across all transmitters."""
return sum(step.duration for tx in self.transmitters for step in tx.schedule)
def total_steps(self) -> int:
"""Total number of capture steps across all transmitters."""
return sum(len(tx.schedule) for tx in self.transmitters)

View File

@ -0,0 +1,423 @@
"""Campaign executor: runs a capture campaign end-to-end."""
from __future__ import annotations
import json
import logging
import subprocess
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional
from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.io.recording import to_sigmf
from .campaign import CampaignConfig, CaptureStep, TransmitterConfig
from .labeler import build_output_filename, label_recording
from .qa import QAResult, check_recording
logger = logging.getLogger(__name__)
# Device name aliases: campaign YAML names → get_sdr_device() names
_DEVICE_ALIASES = {
"usrp_b210": "usrp",
"usrp_b200": "usrp",
"usrp": "usrp",
"plutosdr": "pluto",
"pluto": "pluto",
"hackrf": "hackrf",
"hackrf_one": "hackrf",
"bladerf": "bladerf",
"rtlsdr": "rtlsdr",
"rtl_sdr": "rtlsdr",
"thinkrf": "thinkrf",
}
@dataclass
class StepResult:
"""Outcome of a single capture step."""
transmitter_id: str
step_label: str
output_path: Optional[str]
qa: QAResult
capture_timestamp: float
error: Optional[str] = None
@property
def ok(self) -> bool:
return self.error is None and self.qa.passed
def to_dict(self) -> dict:
return {
"transmitter_id": self.transmitter_id,
"step_label": self.step_label,
"output_path": self.output_path,
"capture_timestamp": self.capture_timestamp,
"qa": self.qa.to_dict(),
"error": self.error,
}
@dataclass
class CampaignResult:
"""Aggregate outcome of a full campaign."""
campaign_name: str
steps: list[StepResult] = field(default_factory=list)
start_time: float = field(default_factory=time.time)
end_time: Optional[float] = None
@property
def total_steps(self) -> int:
return len(self.steps)
@property
def passed(self) -> int:
return sum(1 for s in self.steps if s.ok)
@property
def flagged(self) -> int:
return sum(1 for s in self.steps if not s.error and s.qa.flagged)
@property
def failed(self) -> int:
return sum(1 for s in self.steps if s.error or not s.qa.passed)
@property
def duration_s(self) -> float:
if self.end_time:
return self.end_time - self.start_time
return time.time() - self.start_time
def to_dict(self) -> dict:
return {
"campaign_name": self.campaign_name,
"total_steps": self.total_steps,
"passed": self.passed,
"flagged": self.flagged,
"failed": self.failed,
"duration_s": round(self.duration_s, 1),
"steps": [s.to_dict() for s in self.steps],
}
def write_report(self, path: str | Path) -> None:
"""Write a JSON QA report to disk."""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w") as f:
json.dump(self.to_dict(), f, indent=2)
logger.info(f"QA report written to {path}")
# ---------------------------------------------------------------------------
# External script interface
# ---------------------------------------------------------------------------
def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
"""Run an external control script and return stdout.
The script is called as::
<script> <arg1> <arg2> ...
A non-zero return code raises RuntimeError.
Args:
script: Path to executable script.
*args: Positional arguments forwarded to the script.
timeout: Maximum seconds to wait.
Returns:
Script stdout as a string.
"""
cmd = [script, *args]
logger.debug(f"Running script: {' '.join(cmd)}")
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=timeout,
)
except subprocess.TimeoutExpired:
raise RuntimeError(f"Script timed out after {timeout}s: {script}")
except FileNotFoundError:
raise RuntimeError(f"Script not found: {script}")
if result.returncode != 0:
raise RuntimeError(f"Script exited {result.returncode}: {result.stderr.strip() or result.stdout.strip()}")
return result.stdout.strip()
# ---------------------------------------------------------------------------
# Campaign executor
# ---------------------------------------------------------------------------
class CampaignExecutor:
"""Executes a :class:`CampaignConfig` end-to-end.
Initialises the SDR recorder once, then for each (transmitter, step):
1. Configures the transmitter (via external script or SDR TX)
2. Records IQ samples
3. Labels the recording with device/config metadata
4. Runs QA checks
5. Saves the recording to disk
6. Stops/resets the transmitter
Args:
config: Parsed campaign configuration.
progress_cb: Optional callback ``(step_index, total_steps, step_result)``
called after each step completes. Useful for status reporting.
verbose: Enable debug logging.
"""
def __init__(
self,
config: CampaignConfig,
progress_cb: Optional[Callable[[int, int, StepResult], None]] = None,
verbose: bool = False,
):
self.config = config
self.progress_cb = progress_cb
self._sdr = None
if verbose:
logging.basicConfig(level=logging.DEBUG)
else:
logging.basicConfig(level=logging.INFO)
# ------------------------------------------------------------------
# Public interface
# ------------------------------------------------------------------
def run(self) -> CampaignResult:
"""Execute the full campaign and return a :class:`CampaignResult`.
Initialises the SDR, runs all steps across all transmitters,
then closes the SDR. If SDR initialisation fails the exception
propagates immediately (nothing is captured).
"""
result = CampaignResult(campaign_name=self.config.name)
logger.info(
f"Starting campaign '{self.config.name}': "
f"{self.config.total_steps()} steps, "
f"~{self.config.total_capture_time_s():.0f}s capture time"
)
self._init_sdr()
try:
total = self.config.total_steps()
step_index = 0
for transmitter in self.config.transmitters:
logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)")
for step in transmitter.schedule:
step_result = self._execute_step(transmitter, step)
result.steps.append(step_result)
step_index += 1
if self.progress_cb:
self.progress_cb(step_index, total, step_result)
if step_result.error:
logger.warning(f"Step '{step.label}' error: {step_result.error}")
elif step_result.qa.flagged:
logger.warning(f"Step '{step.label}' flagged for review: " + "; ".join(step_result.qa.issues))
else:
logger.info(
f"Step '{step.label}' OK "
f"(SNR {step_result.qa.snr_db:.1f} dB, "
f"{step_result.qa.duration_s:.1f}s)"
)
finally:
self._close_sdr()
result.end_time = time.time()
logger.info(
f"Campaign complete: {result.passed}/{result.total_steps} passed, "
f"{result.flagged} flagged, {result.failed} failed"
)
return result
# ------------------------------------------------------------------
# SDR management
# ------------------------------------------------------------------
def _init_sdr(self) -> None:
"""Initialise and configure the SDR recorder."""
from ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
rec = self.config.recorder
device_name = _DEVICE_ALIASES.get(rec.device.lower(), rec.device.lower())
logger.info(f"Initialising SDR: {device_name} @ {rec.center_freq/1e6:.2f} MHz")
self._sdr = get_sdr_device(device_name)
gain = None if rec.gain == "auto" else float(rec.gain)
self._sdr.init_rx(
sample_rate=rec.sample_rate,
center_frequency=rec.center_freq,
gain=gain,
channel=0,
)
if rec.bandwidth and hasattr(self._sdr, "set_rx_bandwidth"):
self._sdr.set_rx_bandwidth(rec.bandwidth)
def _close_sdr(self) -> None:
if self._sdr is not None:
try:
self._sdr.close()
except Exception as e:
logger.warning(f"SDR close error: {e}")
self._sdr = None
def _record(self, duration_s: float) -> Recording:
"""Capture ``duration_s`` seconds of IQ samples."""
num_samples = int(duration_s * self.config.recorder.sample_rate)
return self._sdr.record(num_samples=num_samples)
# ------------------------------------------------------------------
# Step execution
# ------------------------------------------------------------------
def _execute_step(self, transmitter: TransmitterConfig, step: CaptureStep) -> StepResult:
"""Run a single capture step.
Returns:
StepResult with QA outcome and output path (or error string).
"""
capture_timestamp = time.time()
output_path: Optional[str] = None
try:
self._start_transmitter(transmitter, step)
recording = self._record(step.duration)
self._stop_transmitter(transmitter, step)
except Exception as e:
# Best-effort stop on error
try:
self._stop_transmitter(transmitter, step)
except Exception:
pass
return StepResult(
transmitter_id=transmitter.id,
step_label=step.label,
output_path=None,
qa=QAResult(passed=False, flagged=True, snr_db=0.0, duration_s=0.0, issues=[f"Capture error: {e}"]),
capture_timestamp=capture_timestamp,
error=str(e),
)
# Label recording
recording = label_recording(
recording=recording,
device_id=transmitter.id,
step=step,
capture_timestamp=capture_timestamp,
campaign_name=self.config.name,
)
# QA
qa_result = check_recording(recording, self.config.qa)
# Save
try:
output_path = self._save(recording, transmitter.id, step)
except Exception as e:
return StepResult(
transmitter_id=transmitter.id,
step_label=step.label,
output_path=None,
qa=qa_result,
capture_timestamp=capture_timestamp,
error=f"Save failed: {e}",
)
return StepResult(
transmitter_id=transmitter.id,
step_label=step.label,
output_path=output_path,
qa=qa_result,
capture_timestamp=capture_timestamp,
)
# ------------------------------------------------------------------
# Transmitter control (external script interface)
# ------------------------------------------------------------------
def _start_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
"""Configure the transmitter for this step.
For ``external_script`` control method the script is called as::
<script> configure <step_params_json>
where ``step_params_json`` is a JSON object with channel, bandwidth,
traffic, etc. The script is responsible for applying the configuration
and returning promptly (i.e. not blocking for the capture duration).
For SDR transmitters this is a no-op placeholder (TX not yet implemented).
"""
if transmitter.control_method == "external_script":
if not transmitter.script:
logger.debug(f"No script configured for {transmitter.id}, skipping configure")
return
params = self._step_params_json(transmitter, step)
_run_script(transmitter.script, "configure", params)
elif transmitter.control_method == "sdr":
logger.debug("SDR TX not yet implemented — skipping start")
else:
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
def _stop_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
"""Signal the transmitter to stop.
Calls ``<script> stop`` for external_script transmitters.
"""
if transmitter.control_method == "external_script":
if not transmitter.script:
return
try:
_run_script(transmitter.script, "stop")
except Exception as e:
logger.warning(f"Script stop failed for {transmitter.id}: {e}")
@staticmethod
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
"""Serialise step parameters to a JSON string for the control script."""
params: dict = {"device": transmitter.device or ""}
if step.channel is not None:
params["channel"] = step.channel
if step.bandwidth_mhz is not None:
params["bandwidth_mhz"] = step.bandwidth_mhz
if step.traffic is not None:
params["traffic"] = step.traffic
if step.power_dbm is not None:
params["power_dbm"] = step.power_dbm
return json.dumps(params)
# ------------------------------------------------------------------
# Output
# ------------------------------------------------------------------
def _save(self, recording: Recording, device_id: str, step: CaptureStep) -> str:
"""Save a recording to disk and return the data file path."""
out = self.config.output
rel_filename = build_output_filename(device_id, step)
out_dir = Path(out.path)
# build_output_filename returns "<device_id>/<label>"
# to_sigmf needs filename (base) and path (dir) separately
parts = Path(rel_filename)
subdir = out_dir / parts.parent
subdir.mkdir(parents=True, exist_ok=True)
base = parts.name
to_sigmf(recording, filename=base, path=str(subdir), overwrite=True)
return str(subdir / f"{base}.sigmf-data")

View File

@ -0,0 +1,77 @@
"""Timestamp-based labeling for captured recordings."""
from __future__ import annotations
from typing import Optional
from ria_toolkit_oss.datatypes.recording import Recording
from .campaign import CaptureStep
def label_recording(
recording: Recording,
device_id: str,
step: CaptureStep,
capture_timestamp: float,
campaign_name: Optional[str] = None,
) -> Recording:
"""Apply device identity and capture configuration labels to a recording's metadata.
Labels are stored in the ``ria:*`` namespace when the recording is saved
as SigMF, via the existing ``update_metadata`` mechanism.
Args:
recording: The recording to label.
device_id: Identifier for the transmitting device (e.g. "iphone13_wifi_001").
step: The capture step that was active during this recording.
capture_timestamp: Unix timestamp (float) of when capture started.
campaign_name: Optional campaign name for cross-recording reference.
Returns:
The same recording with updated metadata.
"""
recording.update_metadata("device_id", device_id)
recording.update_metadata("capture_timestamp", capture_timestamp)
recording.update_metadata("step_label", step.label)
recording.update_metadata("step_duration_s", step.duration)
if campaign_name:
recording.update_metadata("campaign", campaign_name)
# WiFi-specific labels
if step.channel is not None:
recording.update_metadata("wifi_channel", step.channel)
if step.bandwidth_mhz is not None:
recording.update_metadata("wifi_bandwidth_mhz", step.bandwidth_mhz)
# Bluetooth-specific labels
if step.connection_interval_ms is not None:
recording.update_metadata("bt_connection_interval_ms", step.connection_interval_ms)
# Traffic pattern (WiFi + BT)
if step.traffic is not None:
recording.update_metadata("traffic_pattern", step.traffic)
# TX power
if step.power_dbm is not None:
recording.update_metadata("tx_power_dbm", step.power_dbm)
return recording
def build_output_filename(device_id: str, step: CaptureStep) -> str:
"""Generate a deterministic filename for a labeled recording.
Format: ``<device_id>/<step_label>``
Args:
device_id: Device identifier string.
step: Capture step.
Returns:
Relative path string (no extension) to use as ``filename`` in ``to_sigmf()``.
"""
safe_id = device_id.replace("/", "_").replace(" ", "_")
safe_label = step.label.replace("/", "_").replace(" ", "_")
return f"{safe_id}/{safe_label}"

View File

@ -0,0 +1,109 @@
"""QA metrics for captured RF recordings."""
from __future__ import annotations
from dataclasses import dataclass, field
import numpy as np
from ria_toolkit_oss.datatypes.recording import Recording
from .campaign import QAConfig
@dataclass
class QAResult:
"""Result of QA checks on a single recording."""
passed: bool
flagged: bool # True if any metric is below threshold (but not hard-failed)
snr_db: float
duration_s: float
issues: list[str] = field(default_factory=list)
def to_dict(self) -> dict:
return {
"passed": self.passed,
"flagged": self.flagged,
"snr_db": round(self.snr_db, 2),
"duration_s": round(self.duration_s, 3),
"issues": self.issues,
}
def estimate_snr_db(samples: np.ndarray, signal_fraction: float = 0.7) -> float:
"""Estimate SNR from IQ samples using PSD-based signal/noise separation.
Computes an FFT of the samples and assumes the top ``signal_fraction``
of power bins are signal and the remainder are noise. This is a
heuristic appropriate for a controlled testbed where a single dominant
signal is expected.
Args:
samples: 1-D complex array of IQ samples.
signal_fraction: Fraction of PSD bins to treat as signal (01).
Returns:
Estimated SNR in dB, or 0.0 if the noise floor is zero.
"""
n_fft = min(4096, len(samples))
window = np.hanning(n_fft)
psd = np.abs(np.fft.fft(samples[:n_fft] * window)) ** 2
psd_sorted = np.sort(psd)[::-1]
n_signal = max(1, int(n_fft * signal_fraction))
signal_power = psd_sorted[:n_signal].mean()
noise_power = psd_sorted[n_signal:].mean()
if noise_power <= 0.0:
return 0.0
return float(10.0 * np.log10(signal_power / noise_power))
def check_recording(recording: Recording, config: QAConfig) -> QAResult:
"""Run QA checks on a recording against the campaign QA config.
Checks performed:
- Duration: number of samples / sample_rate >= min_duration_s
- SNR: estimated SNR >= snr_threshold_db
Args:
recording: Recording to evaluate.
config: QA thresholds from the campaign config.
Returns:
QAResult with pass/flag status and per-metric details.
"""
issues: list[str] = []
flagged = False
# --- Duration check ---
sample_rate = recording.metadata.get("sample_rate", 1.0)
n_samples = recording.data.shape[-1]
duration_s = n_samples / sample_rate if sample_rate else 0.0
if duration_s < config.min_duration_s:
issues.append(f"Duration too short: {duration_s:.1f}s < {config.min_duration_s:.1f}s threshold")
flagged = True
# --- SNR check ---
samples = recording.data[0] if recording.data.ndim > 1 else recording.data
snr_db = estimate_snr_db(samples)
if snr_db < config.snr_threshold_db:
issues.append(f"SNR below threshold: {snr_db:.1f} dB < {config.snr_threshold_db:.1f} dB")
flagged = True
# In flag_for_review mode: flag but don't hard-fail
if config.flag_for_review:
passed = True # always accept; human reviews flagged recordings
else:
passed = not flagged
return QAResult(
passed=passed,
flagged=flagged,
snr_db=snr_db,
duration_s=duration_s,
issues=issues,
)

View File

@ -474,8 +474,10 @@ class Blade(SDR):
if gain_mode == "relative": if gain_mode == "relative":
if gain > 0: if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets \ raise SDRParameterError(
the gain relative to the maximum possible gain.") "When gain_mode = 'relative', gain must be < 0. This sets \
the gain relative to the maximum possible gain."
)
else: else:
abs_gain = rx_gain_max + gain abs_gain = rx_gain_max + gain
else: else:
@ -548,8 +550,10 @@ class Blade(SDR):
if gain_mode == "relative": if gain_mode == "relative":
if gain > 0: if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\ raise SDRParameterError(
the gain relative to the maximum possible gain.") "When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain."
)
else: else:
abs_gain = tx_gain_max + gain abs_gain = tx_gain_max + gain
else: else:

View File

@ -172,8 +172,10 @@ class HackRF(SDR):
tx_gain_max = 47 tx_gain_max = 47
if gain_mode == "relative": if gain_mode == "relative":
if gain > 0: if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This \ raise SDRParameterError(
sets the gain relative to the maximum possible gain.") "When gain_mode = 'relative', gain must be < 0. This \
sets the gain relative to the maximum possible gain."
)
else: else:
abs_gain = tx_gain_max + gain abs_gain = tx_gain_max + gain
else: else:

View File

@ -274,16 +274,20 @@ class Pluto(SDR):
data = [self._convert_tx_samples(samples), self._convert_tx_samples(samples)] data = [self._convert_tx_samples(samples), self._convert_tx_samples(samples)]
else: else:
if len(recording) > 2: if len(recording) > 2:
warnings.warn("More recordings were provided than channels in the Pluto. \ warnings.warn(
Only the first two recordings will be used") "More recordings were provided than channels in the Pluto. \
Only the first two recordings will be used"
)
sample0 = self._convert_tx_samples(recording.data[0]) sample0 = self._convert_tx_samples(recording.data[0])
sample1 = self._convert_tx_samples(recording.data[1]) sample1 = self._convert_tx_samples(recording.data[1])
data = [sample0, sample1] data = [sample0, sample1]
elif isinstance(recording, list): elif isinstance(recording, list):
if len(recording) > 2: if len(recording) > 2:
warnings.warn("More recordings were provided than channels in the Pluto. \ warnings.warn(
Only the first two recordings will be used") "More recordings were provided than channels in the Pluto. \
Only the first two recordings will be used"
)
if isinstance(recording[0], np.ndarray): if isinstance(recording[0], np.ndarray):
data = [self._convert_tx_samples(recording[0]), self._convert_tx_samples(recording[1])] data = [self._convert_tx_samples(recording[0]), self._convert_tx_samples(recording[1])]
@ -423,8 +427,10 @@ class Pluto(SDR):
if gain_mode == "relative": if gain_mode == "relative":
if gain > 0: if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets \ raise SDRParameterError(
the gain relative to the maximum possible gain.") "When gain_mode = 'relative', gain must be < 0. This sets \
the gain relative to the maximum possible gain."
)
else: else:
abs_gain = rx_gain_max + gain abs_gain = rx_gain_max + gain
else: else:
@ -534,8 +540,10 @@ class Pluto(SDR):
if gain_mode == "relative": if gain_mode == "relative":
if gain > 0: if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\ raise SDRParameterError(
the gain relative to the maximum possible gain.") "When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain."
)
else: else:
abs_gain = tx_gain_max + gain abs_gain = tx_gain_max + gain
else: else:

View File

@ -131,15 +131,19 @@ class RTLSDR(SDR):
if gain_mode == "relative": if gain_mode == "relative":
if gain > 0: if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\ raise SDRParameterError(
the gain relative to the maximum possible gain.") "When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain."
)
target_gain = max_gain + gain target_gain = max_gain + gain
else: else:
target_gain = gain target_gain = gain
if target_gain < min_gain or target_gain > max_gain: if target_gain < min_gain or target_gain > max_gain:
print(f"Requested gain {target_gain} dB out of range;\ print(
clamping to valid span {min_gain}-{max_gain} dB.") f"Requested gain {target_gain} dB out of range;\
clamping to valid span {min_gain}-{max_gain} dB."
)
target_gain = min(max(target_gain, min_gain), max_gain) target_gain = min(max(target_gain, min_gain), max_gain)
target_gain = min(available_gains, key=lambda g: abs(g - target_gain)) target_gain = min(available_gains, key=lambda g: abs(g - target_gain))

View File

@ -392,8 +392,10 @@ class ThinkRF(SDR):
actual_sample_rate = self.BASE_SAMPLE_RATE / decimation actual_sample_rate = self.BASE_SAMPLE_RATE / decimation
if abs(actual_sample_rate - requested_sample_rate) > 1e3: # More than 1 kHz difference if abs(actual_sample_rate - requested_sample_rate) > 1e3: # More than 1 kHz difference
print(f"ThinkRF: Requested {requested_sample_rate/1e6:.2f} MS/s → \ print(
Using decimation={decimation} ({actual_sample_rate/1e6:.2f} MS/s)") f"ThinkRF: Requested {requested_sample_rate/1e6:.2f} MS/s → \
Using decimation={decimation} ({actual_sample_rate/1e6:.2f} MS/s)"
)
return decimation, actual_sample_rate return decimation, actual_sample_rate

View File

@ -148,8 +148,10 @@ class USRP(SDR):
gain_range = self.usrp.get_rx_gain_range() gain_range = self.usrp.get_rx_gain_range()
if gain_mode == "relative": if gain_mode == "relative":
if gain > 0: if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\ raise SDRParameterError(
the gain relative to the maximum possible gain.") "When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain."
)
else: else:
# set gain relative to max # set gain relative to max
abs_gain = gain_range.stop() + gain abs_gain = gain_range.stop() + gain
@ -354,8 +356,10 @@ class USRP(SDR):
gain_range = self.usrp.get_tx_gain_range() gain_range = self.usrp.get_tx_gain_range()
if gain_mode == "relative": if gain_mode == "relative":
if gain > 0: if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\ raise SDRParameterError(
the gain relative to the maximum possible gain.") "When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain."
)
else: else:
# set gain relative to max # set gain relative to max
abs_gain = gain_range.stop() + gain abs_gain = gain_range.stop() + gain

View File

@ -0,0 +1,5 @@
"""RT-OSS HTTP server for RIA Hub integration."""
from .app import create_app
__all__ = ["create_app"]

View File

@ -0,0 +1,48 @@
"""FastAPI application factory for the RT-OSS HTTP server."""
from fastapi import Depends, FastAPI
from .auth import require_api_key
from .routers import inference, orchestrator
def create_app(api_key: str = "") -> FastAPI:
"""Create and configure the RT-OSS FastAPI application.
Args:
api_key: Secret key required in the ``X-API-Key`` request header.
Pass an empty string to disable authentication (development only).
Returns:
Configured FastAPI application instance.
"""
app = FastAPI(
title="RIA Toolkit OSS Server",
version="0.1.0",
description=(
"HTTP API for RT-OSS campaign orchestration and RF zone inference. "
"All endpoints (except /health) require the X-API-Key header when "
"an API key is configured."
),
)
app.state.api_key = api_key
app.include_router(
orchestrator.router,
prefix="/orchestrator",
tags=["Orchestrator"],
dependencies=[Depends(require_api_key)],
)
app.include_router(
inference.router,
prefix="/inference",
tags=["Inference"],
dependencies=[Depends(require_api_key)],
)
@app.get("/health", tags=["Health"])
async def health():
"""Health check — always returns 200."""
return {"status": "ok"}
return app

View File

@ -0,0 +1,25 @@
"""API key authentication dependency."""
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import APIKeyHeader
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def require_api_key(
request: Request,
api_key: str | None = Depends(_api_key_header),
) -> None:
"""FastAPI dependency that enforces X-API-Key header authentication.
If no API key is configured on the server (empty string), all requests
are allowed this is intended for local development only.
"""
expected: str = request.app.state.api_key
if not expected:
return # dev mode: no key set, allow all
if api_key != expected:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid or missing API key",
)

View File

@ -0,0 +1,77 @@
"""Pydantic request and response models for the RT-OSS HTTP server."""
from __future__ import annotations
from pydantic import BaseModel
# ---------------------------------------------------------------------------
# Orchestrator
# ---------------------------------------------------------------------------
class DeployRequest(BaseModel):
config: dict
class DeployResponse(BaseModel):
campaign_id: str
class CampaignStatusResponse(BaseModel):
campaign_id: str
status: str
config_name: str
progress: int
total_steps: int
started_at: float
ended_at: float | None = None
result: dict | None = None
error: str | None = None
class CancelResponse(BaseModel):
campaign_id: str
cancelled: bool
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
class SdrConfig(BaseModel):
device: str
center_freq: float
sample_rate: float
gain: float | str = "auto"
class LoadModelRequest(BaseModel):
model_path: str
label_map: dict[str, int] # class_name -> class_index
class LoadModelResponse(BaseModel):
loaded: bool
model_path: str
num_classes: int
class StartInferenceRequest(BaseModel):
sdr_config: SdrConfig
class StartInferenceResponse(BaseModel):
running: bool
class StopInferenceResponse(BaseModel):
stopped: bool
class InferenceStatusResponse(BaseModel):
timestamp: float
prediction: str
confidence: float
snr_db: float
zone: str | None = None

View File

@ -0,0 +1,183 @@
"""Inference routes: model loading, inference loop control, and status polling."""
from __future__ import annotations
import logging
import threading
import time
import numpy as np
from fastapi import APIRouter, HTTPException, status
from scipy.special import softmax
from ..models import (
InferenceStatusResponse,
LoadModelRequest,
LoadModelResponse,
StartInferenceRequest,
StartInferenceResponse,
StopInferenceResponse,
)
from ..state import InferenceState, get_inference, set_inference
router = APIRouter()
logger = logging.getLogger(__name__)
_INFERENCE_NUM_SAMPLES = 4096
def _load_onnx_session(model_path: str):
try:
import onnxruntime as ort
except ImportError:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="onnxruntime not installed. Install with: pip install ria-toolkit-oss[server]",
)
try:
return ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
except Exception as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"Failed to load ONNX model: {e}")
def _preprocess_samples(samples: np.ndarray, expected_shape: tuple) -> np.ndarray:
"""Reshape complex IQ samples to float32 matching the model's expected input.
Supports ``(batch, 2*N)`` interleaved and ``(batch, 2, N)`` two-channel conventions.
"""
iq = samples.astype(np.complex64)
i_ch, q_ch = iq.real, iq.imag
if len(expected_shape) == 2:
n = expected_shape[1] // 2
interleaved = np.empty(expected_shape[1], dtype=np.float32)
interleaved[0::2] = i_ch[:n]
interleaved[1::2] = q_ch[:n]
return interleaved.reshape(1, -1)
elif len(expected_shape) == 3:
n = expected_shape[2]
return np.stack([i_ch[:n], q_ch[:n]], axis=0).astype(np.float32).reshape(1, 2, n)
else:
raise ValueError(f"Unsupported model input shape: {expected_shape}")
def _stop_current_inference(state: InferenceState, timeout: float = 5.0) -> None:
state.stop_event.set()
if state.thread and state.thread.is_alive():
state.thread.join(timeout=timeout)
if state.thread.is_alive():
logger.warning("Inference thread did not stop within %.1fs; SDR resources may not be released", timeout)
def _inference_loop(state: InferenceState, sdr) -> None:
from ria_toolkit_oss.orchestration.qa import estimate_snr_db
session = state.session
input_name = session.get_inputs()[0].name
expected_shape = tuple(
d if isinstance(d, int) and d > 0 else _INFERENCE_NUM_SAMPLES for d in session.get_inputs()[0].shape
)
try:
while not state.stop_event.is_set():
recording = sdr.record(num_samples=_INFERENCE_NUM_SAMPLES)
samples = recording.data[0] if recording.data.ndim > 1 else recording.data
snr_db = estimate_snr_db(samples)
try:
model_input = _preprocess_samples(samples, expected_shape)
logits = session.run(None, {input_name: model_input})[0][0].astype(np.float32)
probs = softmax(logits)
pred_idx = int(np.argmax(probs))
prediction = state.index_to_label.get(pred_idx, str(pred_idx))
except Exception:
continue
state.set_latest(
{
"timestamp": time.time(),
"prediction": prediction,
"confidence": round(float(probs[pred_idx]), 4),
"snr_db": round(snr_db, 2),
"zone": None,
}
)
finally:
try:
sdr.close()
except Exception:
pass
state.running = False
@router.post("/load", response_model=LoadModelResponse)
async def load_model(request: LoadModelRequest):
"""Load an ONNX model. Stops any running inference first.
``label_map`` maps class names to integer indices (e.g. ``{"zone_a": 0}``).
"""
existing = get_inference()
if existing and existing.running:
_stop_current_inference(existing)
session = _load_onnx_session(request.model_path)
set_inference(
InferenceState(
model_path=request.model_path,
label_map=request.label_map,
index_to_label={v: k for k, v in request.label_map.items()},
session=session,
)
)
return LoadModelResponse(loaded=True, model_path=request.model_path, num_classes=len(request.label_map))
@router.post("/start", response_model=StartInferenceResponse)
async def start_inference(request: StartInferenceRequest):
"""Start continuous inference. Requires a model to be loaded first."""
state = get_inference()
if not state:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail="No model loaded. Call POST /inference/load first."
)
if state.running:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Inference is already running.")
try:
from ria_toolkit_oss.orchestration.executor import _DEVICE_ALIASES
from ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
except ImportError as e:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"SDR import failed: {e}")
sdr_cfg = request.sdr_config
try:
sdr = get_sdr_device(_DEVICE_ALIASES.get(sdr_cfg.device.lower(), sdr_cfg.device.lower()))
gain = None if sdr_cfg.gain == "auto" else float(sdr_cfg.gain)
sdr.init_rx(sample_rate=sdr_cfg.sample_rate, center_frequency=sdr_cfg.center_freq, gain=gain, channel=0)
except Exception as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"SDR initialisation failed: {e}")
state.stop_event.clear()
state.running = True
state.thread = threading.Thread(target=_inference_loop, args=(state, sdr), daemon=True)
state.thread.start()
return StartInferenceResponse(running=True)
@router.post("/stop", response_model=StopInferenceResponse)
async def stop_inference():
"""Stop the running inference loop."""
state = get_inference()
if not state or not state.running:
return StopInferenceResponse(stopped=False)
_stop_current_inference(state)
return StopInferenceResponse(stopped=True)
@router.get("/status", response_model=InferenceStatusResponse | None)
async def inference_status():
"""Return the latest inference result, or null if none available yet."""
state = get_inference()
if not state:
return None
latest = state.get_latest()
return InferenceStatusResponse(**latest) if latest else None

View File

@ -0,0 +1,102 @@
"""Orchestrator routes: campaign deployment, status, and cancellation."""
from __future__ import annotations
import threading
import time
import uuid
from typing import Any
from fastapi import APIRouter, HTTPException, status
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
from ..models import (
CampaignStatusResponse,
CancelResponse,
DeployRequest,
DeployResponse,
)
from ..state import (
CampaignCancelledError,
CampaignState,
get_campaign,
set_campaign,
update_campaign,
)
router = APIRouter()
def _make_progress_cb(campaign_id: str, cancel_event: threading.Event):
def cb(step_index: int, total_steps: int, step_result: Any) -> None:
update_campaign(campaign_id, progress=step_index)
if cancel_event.is_set():
raise CampaignCancelledError(f"Cancelled at step {step_index}/{total_steps}")
return cb
def _run_campaign_thread(campaign_id: str, cfg: CampaignConfig) -> None:
state = get_campaign(campaign_id)
try:
result = CampaignExecutor(
config=cfg,
progress_cb=_make_progress_cb(campaign_id, state.cancel_event),
).run()
update_campaign(
campaign_id, status="completed", progress=cfg.total_steps(), result=result.to_dict(), ended_at=time.time()
)
except CampaignCancelledError:
update_campaign(campaign_id, status="cancelled", ended_at=time.time())
except Exception as e:
update_campaign(campaign_id, status="failed", error=str(e), ended_at=time.time())
@router.post("/deploy", response_model=DeployResponse)
async def deploy(request: DeployRequest):
"""Deploy a campaign config and start execution. Returns a ``campaign_id`` for polling.
Cancellation takes effect at step boundaries, not mid-capture.
"""
try:
cfg = CampaignConfig.from_dict(request.config)
except (ValueError, KeyError) as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))
campaign_id = str(uuid.uuid4())
cancel_event = threading.Event()
thread = threading.Thread(target=_run_campaign_thread, args=(campaign_id, cfg), daemon=True)
set_campaign(
CampaignState(
campaign_id=campaign_id,
status="running",
config_name=cfg.name,
cancel_event=cancel_event,
thread=thread,
total_steps=cfg.total_steps(),
)
)
thread.start()
return DeployResponse(campaign_id=campaign_id)
@router.get("/status/{campaign_id}", response_model=CampaignStatusResponse)
async def get_status(campaign_id: str):
"""Get the status and progress of a deployed campaign."""
state = get_campaign(campaign_id)
if not state:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Campaign {campaign_id!r} not found")
return CampaignStatusResponse(**state.to_dict())
@router.post("/cancel/{campaign_id}", response_model=CancelResponse)
async def cancel(campaign_id: str):
"""Request cancellation. Takes effect at the next step boundary."""
state = get_campaign(campaign_id)
if not state:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Campaign {campaign_id!r} not found")
if state.status != "running":
return CancelResponse(campaign_id=campaign_id, cancelled=False)
state.cancel_event.set()
return CancelResponse(campaign_id=campaign_id, cancelled=True)

View File

@ -0,0 +1,101 @@
"""In-memory state for running campaigns and inference sessions."""
from __future__ import annotations
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Optional
class CampaignCancelledError(Exception):
"""Raised by the progress callback when a cancel is requested."""
@dataclass
class CampaignState:
campaign_id: str
status: str # "running" | "completed" | "failed" | "cancelled"
config_name: str
cancel_event: threading.Event
thread: threading.Thread
total_steps: int = 0
progress: int = 0
result: Optional[dict] = None
error: Optional[str] = None
started_at: float = field(default_factory=time.time)
ended_at: Optional[float] = None
def to_dict(self) -> dict:
return {
"campaign_id": self.campaign_id,
"status": self.status,
"config_name": self.config_name,
"progress": self.progress,
"total_steps": self.total_steps,
"result": self.result,
"error": self.error,
"started_at": self.started_at,
"ended_at": self.ended_at,
}
@dataclass
class InferenceState:
model_path: str
label_map: dict[str, int] # class_name -> class_index
index_to_label: dict[int, str] # reverse: class_index -> class_name
session: Any # onnxruntime.InferenceSession
stop_event: threading.Event = field(default_factory=threading.Event)
thread: Optional[threading.Thread] = None
running: bool = False
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
_latest: Optional[dict] = field(default=None, repr=False)
def set_latest(self, result: dict) -> None:
with self._lock:
self._latest = result
def get_latest(self) -> Optional[dict]:
with self._lock:
return self._latest
# ---------------------------------------------------------------------------
# Module-level stores
# ---------------------------------------------------------------------------
_campaigns: dict[str, CampaignState] = {}
_campaigns_lock = threading.Lock()
_inference: Optional[InferenceState] = None
_inference_lock = threading.Lock()
def get_campaign(campaign_id: str) -> Optional[CampaignState]:
with _campaigns_lock:
return _campaigns.get(campaign_id)
def set_campaign(state: CampaignState) -> None:
with _campaigns_lock:
_campaigns[state.campaign_id] = state
def update_campaign(campaign_id: str, **kwargs) -> None:
with _campaigns_lock:
state = _campaigns.get(campaign_id)
if state:
for k, v in kwargs.items():
setattr(state, k, v)
def get_inference() -> Optional[InferenceState]:
with _inference_lock:
return _inference
def set_inference(state: Optional[InferenceState]) -> None:
global _inference
with _inference_lock:
_inference = state

View File

@ -37,8 +37,10 @@ class Add(RecordableBlock, ProcessBlock):
samples = block.get_samples(num_samples) samples = block.get_samples(num_samples)
if len(samples) != num_samples: if len(samples) != num_samples:
raise ValueError(f"Block {self.__class__.__name__} requested {num_samples} \ raise ValueError(
from block {block.__class__.__name__} but got {len(samples)}.") f"Block {self.__class__.__name__} requested {num_samples} \
from block {block.__class__.__name__} but got {len(samples)}."
)
return samples return samples

View File

@ -23,9 +23,11 @@ class ProcessBlock(Block, ABC):
) )
elif not all(isinstance(item, Block) for item in input): elif not all(isinstance(item, Block) for item in input):
raise ValueError(f"Invalid input to block '{self.__class__.__name__}'. \ raise ValueError(
f"Invalid input to block '{self.__class__.__name__}'. \
Expected a list of Block objects but got \ Expected a list of Block objects but got \
{'[' + ',' .join(f'{item.__class__.__name__}({repr(item)})' for item in input) + ']'}") {'[' + ',' .join(f'{item.__class__.__name__}({repr(item)})' for item in input) + ']'}"
)
elif len(input) != len(self.input_type): elif len(input) != len(self.input_type):
raise ValueError( raise ValueError(

View File

@ -20,8 +20,10 @@ class RecordableBlock(Block, Recordable):
:raises ValueError: If the number of samples is incorrect.""" :raises ValueError: If the number of samples is incorrect."""
samples = self.get_samples(num_samples) samples = self.get_samples(num_samples)
if len(samples) != num_samples: if len(samples) != num_samples:
raise ValueError(f"Error in block {self.__class__.__name__} record(). \ raise ValueError(
Requested {num_samples} samples but got {len(samples)}") f"Error in block {self.__class__.__name__} record(). \
Requested {num_samples} samples but got {len(samples)}"
)
metadata = self._get_metadata() metadata = self._get_metadata()
return Recording(data=samples, metadata=metadata) return Recording(data=samples, metadata=metadata)

View File

@ -39,7 +39,9 @@ class RecordingSource(SourceBlock, RecordableBlock):
:raises ValueError: If num_samples is greater than the recording length. :raises ValueError: If num_samples is greater than the recording length.
""" """
if num_samples - 1 >= self.recording.data.shape[1]: if num_samples - 1 >= self.recording.data.shape[1]:
raise ValueError(f"{num_samples} samples requested from recording source with \ raise ValueError(
{self.recording.data.shape[1]} samples available.") f"{num_samples} samples requested from recording source with \
{self.recording.data.shape[1]} samples available."
)
return self.recording.data[0, 0:num_samples] return self.recording.data[0, 0:num_samples]

View File

@ -610,8 +610,10 @@ def cut_out( # noqa: C901 # TODO: Simplify function
raise ValueError("signal must be CxN complex.") raise ValueError("signal must be CxN complex.")
if fill_type not in {"zeros", "ones", "low-snr", "avg-snr", "high-snr"}: if fill_type not in {"zeros", "ones", "low-snr", "avg-snr", "high-snr"}:
raise UserWarning("""fill_type must be "zeros", "ones", "low-snr", "avg-snr", or "high-snr", raise UserWarning(
"ones" has been selected by default""") """fill_type must be "zeros", "ones", "low-snr", "avg-snr", or "high-snr",
"ones" has been selected by default"""
)
if max_section_size < 1 or max_section_size >= n: if max_section_size < 1 or max_section_size >= n:
raise ValueError("max_section_size must be at least 1 and must be less than the length of signal.") raise ValueError("max_section_size must be at least 1 and must be less than the length of signal.")

Binary file not shown.

Before

Width:  |  Height:  |  Size: 90 KiB

After

Width:  |  Height:  |  Size: 130 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 19 KiB

After

Width:  |  Height:  |  Size: 130 B

View File

@ -0,0 +1,258 @@
"""Campaign orchestration CLI commands.
Usage examples::
# Enroll a single device using a device profile YAML (App 1 workflow)
ria campaign enroll --config iphone13.yml
# Run a full custom campaign config
ria campaign run --config my_campaign.yml
# Validate a config file without running it
ria campaign validate --config my_campaign.yml
"""
import sys
import click
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
from ria_toolkit_oss.orchestration.executor import CampaignExecutor, StepResult
@click.group()
def campaign():
"""Orchestrate automated RF capture campaigns."""
# ---------------------------------------------------------------------------
# ria campaign validate
# ---------------------------------------------------------------------------
@campaign.command()
@click.option(
"--config",
"-c",
required=True,
type=click.Path(exists=True),
help="Campaign YAML config or device profile YAML.",
)
@click.option(
"--profile",
is_flag=True,
default=False,
help="Parse as a device profile (App 1 style) rather than a full campaign config.",
)
def validate(config, profile):
"""Validate a campaign or device profile YAML without running it.
\b
Examples:
ria campaign validate --config iphone13.yml --profile
ria campaign validate --config campaign.yml
"""
try:
if profile:
cfg = CampaignConfig.from_device_profile(config)
else:
cfg = CampaignConfig.from_yaml(config)
except (FileNotFoundError, ValueError, KeyError) as e:
raise click.ClickException(str(e))
click.echo(click.style("✓ Config valid", fg="green", bold=True))
click.echo(f" Campaign name : {cfg.name}")
click.echo(f" Mode : {cfg.mode}")
click.echo(f" Transmitters : {len(cfg.transmitters)}")
click.echo(f" Total steps : {cfg.total_steps()}")
click.echo(f" Capture time : {cfg.total_capture_time_s():.0f}s")
click.echo(f" Recorder : {cfg.recorder.device} @ {cfg.recorder.center_freq/1e6:.2f} MHz")
click.echo(f" Sample rate : {cfg.recorder.sample_rate/1e6:.1f} MS/s")
click.echo(f" Output path : {cfg.output.path}")
click.echo()
for tx in cfg.transmitters:
click.echo(f" Transmitter: {tx.id} ({tx.type}, {tx.control_method}, {len(tx.schedule)} steps)")
for step in tx.schedule:
extras = []
if step.channel is not None:
extras.append(f"ch={step.channel}")
if step.bandwidth_mhz is not None:
extras.append(f"{int(step.bandwidth_mhz)}MHz")
if step.traffic:
extras.append(step.traffic)
suffix = f" [{', '.join(extras)}]" if extras else ""
click.echo(f" [{step.duration:.0f}s] {step.label}{suffix}")
# ---------------------------------------------------------------------------
# ria campaign enroll
# ---------------------------------------------------------------------------
@campaign.command()
@click.option(
"--config",
"-c",
required=True,
type=click.Path(exists=True),
help="Device profile YAML (App 1 enrollment format).",
)
@click.option(
"--output",
"-o",
default=None,
help="Override output directory from config.",
)
@click.option(
"--report",
default="qa_report.json",
show_default=True,
help="Path for the JSON QA report.",
)
@click.option("--verbose", "-v", is_flag=True, help="Verbose output.")
@click.option("--dry-run", is_flag=True, help="Parse and validate config, then exit.")
def enroll(config, output, report, verbose, dry_run):
"""Enroll a single device by running its capture profile.
Parses a device profile YAML (App 1 format), generates a capture
campaign, and runs it. Outputs labelled SigMF recordings and a
JSON QA report.
\b
Examples:
ria campaign enroll --config iphone13.yml
ria campaign enroll --config airpods.yml --output ./my_recordings
ria campaign enroll --config iphone13.yml --dry-run
"""
try:
cfg = CampaignConfig.from_device_profile(config)
except (FileNotFoundError, ValueError, KeyError) as e:
raise click.ClickException(str(e))
if output:
cfg.output.path = output
_print_campaign_summary(cfg)
if dry_run:
click.echo(click.style("Dry run — exiting before capture.", fg="yellow"))
return
result = _run_campaign(cfg, verbose=verbose)
result.write_report(report)
_print_result_summary(result, report)
sys.exit(0 if result.failed == 0 else 1)
# ---------------------------------------------------------------------------
# ria campaign run
# ---------------------------------------------------------------------------
@campaign.command()
@click.option(
"--config",
"-c",
required=True,
type=click.Path(exists=True),
help="Full campaign YAML config.",
)
@click.option(
"--output",
"-o",
default=None,
help="Override output directory from config.",
)
@click.option(
"--report",
default="qa_report.json",
show_default=True,
help="Path for the JSON QA report.",
)
@click.option("--verbose", "-v", is_flag=True, help="Verbose output.")
@click.option("--dry-run", is_flag=True, help="Parse and validate config, then exit.")
def run(config, output, report, verbose, dry_run):
"""Run a full campaign from a campaign config YAML.
\b
Examples:
ria campaign run --config wifi_capture.yml
ria campaign run --config campaign.yml --output ./data --dry-run
"""
try:
cfg = CampaignConfig.from_yaml(config)
except (FileNotFoundError, ValueError, KeyError) as e:
raise click.ClickException(str(e))
if output:
cfg.output.path = output
_print_campaign_summary(cfg)
if dry_run:
click.echo(click.style("Dry run — exiting before capture.", fg="yellow"))
return
result = _run_campaign(cfg, verbose=verbose)
result.write_report(report)
_print_result_summary(result, report)
sys.exit(0 if result.failed == 0 else 1)
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def _print_campaign_summary(cfg: CampaignConfig) -> None:
click.echo()
click.echo(click.style(f"Campaign: {cfg.name}", bold=True))
click.echo(f" Transmitters : {len(cfg.transmitters)}")
click.echo(f" Total steps : {cfg.total_steps()}")
click.echo(f" Capture time : ~{cfg.total_capture_time_s():.0f}s")
click.echo(f" Output : {cfg.output.path}")
click.echo()
def _make_progress_cb(total: int):
"""Return a progress callback that prints step results to stderr."""
def cb(idx: int, _total: int, step: StepResult) -> None:
status = (
click.style("", fg="green")
if step.ok
else (click.style("", fg="yellow") if step.qa.flagged else click.style("", fg="red"))
)
snr_str = f"SNR {step.qa.snr_db:.1f} dB" if not step.error else f"ERROR: {step.error}"
click.echo(
f" [{idx:>3}/{_total}] {status} {step.transmitter_id}/{step.step_label}{snr_str}",
err=True,
)
return cb
def _run_campaign(cfg: CampaignConfig, verbose: bool = False):
executor = CampaignExecutor(
config=cfg,
progress_cb=_make_progress_cb(cfg.total_steps()),
verbose=verbose,
)
return executor.run()
def _print_result_summary(result, report_path: str) -> None:
click.echo()
click.echo(click.style("Campaign complete", bold=True))
click.echo(f" Steps : {result.total_steps}")
click.echo(f" Passed : {click.style(str(result.passed), fg='green')}")
if result.flagged:
click.echo(f" Flagged : {click.style(str(result.flagged), fg='yellow')} (review required)")
if result.failed:
click.echo(f" Failed : {click.style(str(result.failed), fg='red')}")
click.echo(f" Duration: {result.duration_s:.0f}s")
click.echo(f" Report : {report_path}")
click.echo()

View File

@ -3,6 +3,7 @@
This module contains all the CLI bindings for the ria package. This module contains all the CLI bindings for the ria package.
""" """
from .campaign import campaign
from .capture import capture from .capture import capture
from .combine import combine from .combine import combine
from .convert import convert from .convert import convert
@ -13,6 +14,7 @@ from .generate import generate
# from .generate import generate # from .generate import generate
from .init import init from .init import init
from .serve import serve
from .split import split from .split import split
from .transform import transform from .transform import transform
from .transmit import transmit from .transmit import transmit

View File

@ -0,0 +1,51 @@
"""``ria serve`` — start the RT-OSS HTTP server for RIA Hub integration."""
import click
@click.command()
@click.option("--host", default="0.0.0.0", show_default=True, help="Bind host.")
@click.option("--port", default=8080, show_default=True, type=int, help="Bind port.")
@click.option(
"--api-key",
envvar="RT_OSS_API_KEY",
default="",
help="Required X-API-Key value. Also reads RT_OSS_API_KEY. Empty = no auth (dev only).",
)
@click.option(
"--log-level",
default="info",
show_default=True,
type=click.Choice(["debug", "info", "warning", "error"], case_sensitive=False),
)
def serve(host: str, port: int, api_key: str, log_level: str):
"""Start the RT-OSS HTTP server.
\b
Endpoints:
POST /orchestrator/deploy
GET /orchestrator/status/{campaign_id}
POST /orchestrator/cancel/{campaign_id}
POST /inference/load
POST /inference/start
POST /inference/stop
GET /inference/status
GET /health
"""
try:
import uvicorn
from ria_toolkit_oss.server.app import create_app
except ImportError as e:
raise click.ClickException(
f"Server dependencies missing: {e}\nInstall with: pip install ria-toolkit-oss[server]"
)
if not api_key:
click.echo(
click.style("Warning: ", fg="yellow", bold=True) + "no API key set — all requests unauthenticated.",
err=True,
)
click.echo(f"Starting RT-OSS server on http://{host}:{port}")
uvicorn.run(create_app(api_key=api_key), host=host, port=port, log_level=log_level.lower())

View File

View File

@ -0,0 +1,489 @@
"""Tests for orchestration campaign schema and YAML parsing."""
import os
import tempfile
import pytest
import yaml
from ria_toolkit_oss.orchestration.campaign import (
CampaignConfig,
CaptureStep,
QAConfig,
RecorderConfig,
parse_bandwidth_mhz,
parse_duration,
parse_frequency,
parse_gain,
)
# ---------------------------------------------------------------------------
# parse_duration
# ---------------------------------------------------------------------------
class TestParseDuration:
def test_seconds_suffix(self):
assert parse_duration("30s") == 30.0
def test_seconds_suffix_long(self):
assert parse_duration("30sec") == 30.0
def test_minutes_suffix(self):
assert parse_duration("1.5m") == 90.0
def test_minutes_suffix_long(self):
assert parse_duration("2min") == 120.0
def test_hours_suffix(self):
assert parse_duration("2h") == 7200.0
def test_hours_suffix_long(self):
assert parse_duration("1hr") == 3600.0
def test_numeric_int(self):
assert parse_duration(45) == 45.0
def test_numeric_float(self):
assert parse_duration(1.5) == 1.5
def test_bare_number_string(self):
# No unit → treated as seconds
assert parse_duration("60") == 60.0
def test_invalid_raises(self):
with pytest.raises(ValueError):
parse_duration("two minutes")
# ---------------------------------------------------------------------------
# parse_frequency
# ---------------------------------------------------------------------------
class TestParseFrequency:
def test_ghz(self):
assert parse_frequency("2.45GHz") == pytest.approx(2.45e9)
def test_mhz(self):
assert parse_frequency("40MHz") == pytest.approx(40e6)
def test_khz(self):
assert parse_frequency("433k") == pytest.approx(433e3)
def test_scientific_notation_string(self):
assert parse_frequency("915e6") == pytest.approx(915e6)
def test_numeric_float(self):
assert parse_frequency(2.45e9) == pytest.approx(2.45e9)
def test_numeric_int(self):
assert parse_frequency(1000000) == pytest.approx(1e6)
def test_hz_suffix_optional(self):
# "40M" and "40MHz" should both work
assert parse_frequency("40M") == pytest.approx(40e6)
assert parse_frequency("40MHz") == pytest.approx(40e6)
def test_invalid_raises(self):
with pytest.raises(ValueError):
parse_frequency("two point four gigs")
# ---------------------------------------------------------------------------
# parse_gain
# ---------------------------------------------------------------------------
class TestParseGain:
def test_db_suffix(self):
assert parse_gain("40dB") == pytest.approx(40.0)
def test_db_suffix_lowercase(self):
assert parse_gain("32db") == pytest.approx(32.0)
def test_auto(self):
assert parse_gain("auto") == "auto"
def test_auto_case_insensitive(self):
assert parse_gain("AUTO") == "auto"
def test_numeric_int(self):
assert parse_gain(32) == pytest.approx(32.0)
def test_numeric_float(self):
assert parse_gain(32.5) == pytest.approx(32.5)
def test_invalid_raises(self):
with pytest.raises(ValueError):
parse_gain("high")
# ---------------------------------------------------------------------------
# parse_bandwidth_mhz
# ---------------------------------------------------------------------------
class TestParseBandwidthMhz:
def test_mhz_suffix(self):
assert parse_bandwidth_mhz("20MHz") == pytest.approx(20.0)
def test_numeric(self):
assert parse_bandwidth_mhz(40) == pytest.approx(40.0)
def test_none(self):
assert parse_bandwidth_mhz(None) is None
def test_invalid_raises(self):
with pytest.raises(ValueError):
parse_bandwidth_mhz("wide")
# ---------------------------------------------------------------------------
# CaptureStep.from_dict
# ---------------------------------------------------------------------------
class TestCaptureStep:
def test_wifi_step_auto_label(self):
d = {"channel": 6, "bandwidth": "20MHz", "traffic": "iperf_udp", "duration": "30s"}
step = CaptureStep.from_dict(d)
assert step.duration == 30.0
assert step.channel == 6
assert step.bandwidth_mhz == 20.0
assert step.traffic == "iperf_udp"
assert step.label == "ch06_20mhz_iperf_udp"
def test_explicit_label(self):
d = {"channel": 1, "bandwidth": "20MHz", "traffic": "idle", "duration": "30s", "label": "my_label"}
step = CaptureStep.from_dict(d)
assert step.label == "my_label"
def test_fallback_label(self):
# No channel/bandwidth/traffic → label falls back to "capture"
d = {"duration": "10s"}
step = CaptureStep.from_dict(d)
assert step.label == "capture"
def test_power_parsed(self):
d = {"channel": 6, "bandwidth": "20MHz", "traffic": "idle", "duration": "30s", "power": "15dBm"}
step = CaptureStep.from_dict(d)
assert step.power_dbm == pytest.approx(15.0)
# ---------------------------------------------------------------------------
# RecorderConfig.from_dict
# ---------------------------------------------------------------------------
class TestRecorderConfig:
def test_basic(self):
d = {"device": "usrp_b210", "center_freq": "2.45GHz", "sample_rate": "40MHz", "gain": "40dB"}
rec = RecorderConfig.from_dict(d)
assert rec.device == "usrp_b210"
assert rec.center_freq == pytest.approx(2.45e9)
assert rec.sample_rate == pytest.approx(40e6)
assert rec.gain == pytest.approx(40.0)
assert rec.bandwidth is None
def test_auto_gain(self):
d = {"device": "pluto", "center_freq": "2.45GHz", "sample_rate": "20MHz", "gain": "auto"}
rec = RecorderConfig.from_dict(d)
assert rec.gain == "auto"
def test_bandwidth_set(self):
d = {"device": "pluto", "center_freq": "2.45GHz", "sample_rate": "20MHz", "gain": 32, "bandwidth": "20MHz"}
rec = RecorderConfig.from_dict(d)
assert rec.bandwidth == pytest.approx(20e6)
# ---------------------------------------------------------------------------
# QAConfig.from_dict
# ---------------------------------------------------------------------------
class TestQAConfig:
def test_defaults(self):
qa = QAConfig.from_dict({})
assert qa.snr_threshold_db == pytest.approx(10.0)
assert qa.min_duration_s == pytest.approx(25.0)
assert qa.flag_for_review is True
def test_custom_values(self):
d = {"snr_threshold": "15dB", "min_duration": "28s", "flag_for_review": False}
qa = QAConfig.from_dict(d)
assert qa.snr_threshold_db == pytest.approx(15.0)
assert qa.min_duration_s == pytest.approx(28.0)
assert qa.flag_for_review is False
# ---------------------------------------------------------------------------
# CampaignConfig.from_device_profile
# ---------------------------------------------------------------------------
def _write_device_profile(d: dict) -> str:
"""Write a dict as YAML to a temp file and return the path."""
f = tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False)
yaml.dump(d, f)
f.close()
return f.name
WIFI_PROFILE = {
"device": {"name": "iPhone_13_WiFi", "type": "wifi"},
"capture": {
"channels": [1, 6, 11],
"bandwidth": "20MHz",
"traffic_patterns": ["idle", "ping", "iperf_udp"],
"duration_per_config": "30s",
"script": "./scripts/wifi_control.sh",
},
"recorder": {
"device": "usrp_b210",
"center_freq": "2.45GHz",
"sample_rate": "40MHz",
"gain": "auto",
},
"output": {"path": "/tmp/test_recordings", "device_id": "iphone13_wifi_001"},
}
BT_PROFILE = {
"device": {"name": "AirPods_Pro", "type": "bluetooth"},
"capture": {
"traffic_patterns": ["idle", "audio_stream", "data_transfer"],
"duration_per_config": "30s",
},
"recorder": {
"device": "usrp_b210",
"center_freq": "2.45GHz",
"sample_rate": "40MHz",
"gain": "auto",
},
"output": {"path": "/tmp/test_recordings", "device_id": "airpods_pro_bt_001"},
}
class TestDeviceProfileParsing:
def test_wifi_schedule_count(self):
"""WiFi: 3 channels × 3 traffic = 9 steps."""
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
assert len(cfg.transmitters) == 1
assert len(cfg.transmitters[0].schedule) == 9
def test_wifi_campaign_name(self):
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
assert cfg.name == "enroll_iphone13_wifi_001"
def test_wifi_step_labels(self):
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
labels = [s.label for s in cfg.transmitters[0].schedule]
assert "ch01_20mhz_idle" in labels
assert "ch06_20mhz_ping" in labels
assert "ch11_20mhz_iperf_udp" in labels
def test_wifi_step_ordering(self):
"""Steps iterate channels first, then traffic."""
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
steps = cfg.transmitters[0].schedule
assert steps[0].channel == 1 and steps[0].traffic == "idle"
assert steps[1].channel == 1 and steps[1].traffic == "ping"
assert steps[3].channel == 6 and steps[3].traffic == "idle"
assert steps[8].channel == 11 and steps[8].traffic == "iperf_udp"
def test_wifi_step_duration(self):
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
for step in cfg.transmitters[0].schedule:
assert step.duration == pytest.approx(30.0)
def test_wifi_bandwidth(self):
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
for step in cfg.transmitters[0].schedule:
assert step.bandwidth_mhz == pytest.approx(20.0)
def test_wifi_recorder(self):
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
assert cfg.recorder.device == "usrp_b210"
assert cfg.recorder.center_freq == pytest.approx(2.45e9)
assert cfg.recorder.sample_rate == pytest.approx(40e6)
assert cfg.recorder.gain == "auto"
def test_wifi_total_capture_time(self):
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
assert cfg.total_capture_time_s() == pytest.approx(270.0) # 9 × 30s
def test_bt_schedule_count(self):
"""BT: no channels, 3 traffic patterns = 3 steps."""
path = _write_device_profile(BT_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
assert len(cfg.transmitters[0].schedule) == 3
def test_bt_no_channel(self):
path = _write_device_profile(BT_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
for step in cfg.transmitters[0].schedule:
assert step.channel is None
def test_bt_step_labels(self):
path = _write_device_profile(BT_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
labels = [s.label for s in cfg.transmitters[0].schedule]
assert labels == ["idle", "audio_stream", "data_transfer"]
def test_missing_file_raises(self):
with pytest.raises(FileNotFoundError):
CampaignConfig.from_device_profile("/nonexistent/path/profile.yml")
def test_invalid_yaml_raises(self):
with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f:
f.write(": bad: yaml: [\n")
path = f.name
try:
with pytest.raises(ValueError, match="Invalid YAML"):
CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
# ---------------------------------------------------------------------------
# CampaignConfig.from_yaml (full campaign format)
# ---------------------------------------------------------------------------
FULL_CAMPAIGN = {
"campaign": {"name": "wifi_capture_001", "mode": "controlled_testbed"},
"transmitters": [
{
"id": "laptop_wifi",
"type": "wifi",
"control_method": "external_script",
"script": "./scripts/wifi_control.sh",
"device": "/dev/wlan0",
"schedule": [
{"channel": 6, "bandwidth": "20MHz", "traffic": "iperf_tcp", "duration": "30s"},
{"channel": 36, "bandwidth": "40MHz", "traffic": "ping_flood", "duration": "30s"},
],
}
],
"recorder": {
"device": "usrp_b210",
"center_freq": "2.45GHz",
"sample_rate": "20MHz",
"gain": "40dB",
},
"qa": {"snr_threshold": "10dB", "min_duration": "25s", "flag_for_review": True},
"output": {"format": "sigmf", "path": "./recordings"},
}
class TestFullCampaignParsing:
def test_name(self):
path = _write_device_profile(FULL_CAMPAIGN)
try:
cfg = CampaignConfig.from_yaml(path)
finally:
os.unlink(path)
assert cfg.name == "wifi_capture_001"
def test_mode(self):
path = _write_device_profile(FULL_CAMPAIGN)
try:
cfg = CampaignConfig.from_yaml(path)
finally:
os.unlink(path)
assert cfg.mode == "controlled_testbed"
def test_transmitter_id(self):
path = _write_device_profile(FULL_CAMPAIGN)
try:
cfg = CampaignConfig.from_yaml(path)
finally:
os.unlink(path)
assert cfg.transmitters[0].id == "laptop_wifi"
assert cfg.transmitters[0].control_method == "external_script"
assert cfg.transmitters[0].script == "./scripts/wifi_control.sh"
def test_schedule_count(self):
path = _write_device_profile(FULL_CAMPAIGN)
try:
cfg = CampaignConfig.from_yaml(path)
finally:
os.unlink(path)
assert len(cfg.transmitters[0].schedule) == 2
def test_qa_config(self):
path = _write_device_profile(FULL_CAMPAIGN)
try:
cfg = CampaignConfig.from_yaml(path)
finally:
os.unlink(path)
assert cfg.qa.snr_threshold_db == pytest.approx(10.0)
assert cfg.qa.min_duration_s == pytest.approx(25.0)
assert cfg.qa.flag_for_review is True
def test_total_steps(self):
path = _write_device_profile(FULL_CAMPAIGN)
try:
cfg = CampaignConfig.from_yaml(path)
finally:
os.unlink(path)
assert cfg.total_steps() == 2
def test_no_transmitters_raises(self):
bad = dict(FULL_CAMPAIGN)
bad["transmitters"] = []
path = _write_device_profile(bad)
try:
with pytest.raises(ValueError, match="at least one transmitter"):
CampaignConfig.from_yaml(path)
finally:
os.unlink(path)
def test_missing_recorder_raises(self):
bad = {k: v for k, v in FULL_CAMPAIGN.items() if k != "recorder"}
path = _write_device_profile(bad)
try:
with pytest.raises((KeyError, ValueError)):
CampaignConfig.from_yaml(path)
finally:
os.unlink(path)

View File

@ -0,0 +1,274 @@
"""Tests for the `ria campaign` CLI commands."""
import os
import tempfile
import yaml
from click.testing import CliRunner
from ria_toolkit_oss_cli.cli import cli
def _write_yaml(d: dict, suffix=".yml") -> str:
f = tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False)
yaml.dump(d, f)
f.close()
return f.name
WIFI_PROFILE = {
"device": {"name": "iPhone_13_WiFi", "type": "wifi"},
"capture": {
"channels": [1, 6, 11],
"bandwidth": "20MHz",
"traffic_patterns": ["idle", "ping", "iperf_udp"],
"duration_per_config": "30s",
},
"recorder": {
"device": "usrp_b210",
"center_freq": "2.45GHz",
"sample_rate": "40MHz",
"gain": "auto",
},
"output": {"path": "/tmp/test_enroll", "device_id": "iphone13_wifi_001"},
}
FULL_CAMPAIGN = {
"campaign": {"name": "wifi_capture_001", "mode": "controlled_testbed"},
"transmitters": [
{
"id": "laptop_wifi",
"type": "wifi",
"control_method": "external_script",
"schedule": [
{"channel": 6, "bandwidth": "20MHz", "traffic": "iperf_udp", "duration": "30s"},
{"channel": 11, "bandwidth": "20MHz", "traffic": "idle", "duration": "30s"},
],
}
],
"recorder": {
"device": "usrp_b210",
"center_freq": "2.45GHz",
"sample_rate": "20MHz",
"gain": "40dB",
},
"qa": {"snr_threshold": "10dB", "min_duration": "25s", "flag_for_review": True},
"output": {"format": "sigmf", "path": "/tmp/test_campaign"},
}
# ---------------------------------------------------------------------------
# ria campaign --help
# ---------------------------------------------------------------------------
class TestCampaignHelp:
def test_campaign_help(self):
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "--help"])
assert result.exit_code == 0
assert "campaign" in result.output.lower()
def test_subcommands_listed(self):
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "--help"])
assert result.exit_code == 0
for sub in ("validate", "enroll", "run"):
assert sub in result.output
def test_validate_help(self):
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--help"])
assert result.exit_code == 0
def test_enroll_help(self):
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "enroll", "--help"])
assert result.exit_code == 0
def test_run_help(self):
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "run", "--help"])
assert result.exit_code == 0
# ---------------------------------------------------------------------------
# ria campaign validate
# ---------------------------------------------------------------------------
class TestCampaignValidate:
def test_validate_device_profile(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
assert result.exit_code == 0
assert "" in result.output or "valid" in result.output.lower()
finally:
os.unlink(path)
def test_validate_shows_campaign_name(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
assert "enroll_iphone13_wifi_001" in result.output
finally:
os.unlink(path)
def test_validate_shows_step_count(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
assert "9" in result.output # 9 total steps
finally:
os.unlink(path)
def test_validate_shows_capture_time(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
assert "270" in result.output # 270s total
finally:
os.unlink(path)
def test_validate_full_campaign(self):
path = _write_yaml(FULL_CAMPAIGN)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", path])
assert result.exit_code == 0
assert "wifi_capture_001" in result.output
finally:
os.unlink(path)
def test_validate_shows_steps(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
assert "ch01_20mhz_idle" in result.output
assert "ch06_20mhz_ping" in result.output
assert "ch11_20mhz_iperf_udp" in result.output
finally:
os.unlink(path)
def test_validate_missing_file(self):
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", "/nonexistent/file.yml"])
assert result.exit_code != 0
def test_validate_bad_yaml(self):
with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f:
f.write(": broken yaml [\n")
path = f.name
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", path])
assert result.exit_code != 0
finally:
os.unlink(path)
# ---------------------------------------------------------------------------
# ria campaign enroll --dry-run
# ---------------------------------------------------------------------------
class TestCampaignEnrollDryRun:
def test_dry_run_exits_cleanly(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "enroll", "--config", path, "--dry-run"])
assert result.exit_code == 0
finally:
os.unlink(path)
def test_dry_run_shows_campaign_info(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "enroll", "--config", path, "--dry-run"])
assert "enroll_iphone13_wifi_001" in result.output
assert "9" in result.output
finally:
os.unlink(path)
def test_dry_run_does_not_capture(self):
"""Dry run should not create any output files."""
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
with tempfile.TemporaryDirectory() as tmpdir:
runner.invoke(
cli,
["campaign", "enroll", "--config", path, "--output", tmpdir, "--dry-run"],
)
# No .sigmf-data files should have been created
sigmf_files = list(os.walk(tmpdir))
all_files = [f for _, _, files in sigmf_files for f in files]
assert not any(f.endswith(".sigmf-data") for f in all_files)
finally:
os.unlink(path)
def test_dry_run_output_override(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(
cli,
["campaign", "enroll", "--config", path, "--output", "/tmp/custom_out", "--dry-run"],
)
assert result.exit_code == 0
assert "Dry run" in result.output
finally:
os.unlink(path)
# ---------------------------------------------------------------------------
# ria campaign run --dry-run
# ---------------------------------------------------------------------------
class TestCampaignRunDryRun:
def test_dry_run_exits_cleanly(self):
path = _write_yaml(FULL_CAMPAIGN)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "run", "--config", path, "--dry-run"])
assert result.exit_code == 0
finally:
os.unlink(path)
def test_dry_run_shows_campaign_name(self):
path = _write_yaml(FULL_CAMPAIGN)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "run", "--config", path, "--dry-run"])
assert "wifi_capture_001" in result.output
finally:
os.unlink(path)
def test_dry_run_does_not_create_report(self):
path = _write_yaml(FULL_CAMPAIGN)
try:
runner = CliRunner()
with tempfile.TemporaryDirectory() as tmpdir:
report_path = os.path.join(tmpdir, "qa_report.json")
result = runner.invoke(
cli,
["campaign", "run", "--config", path, "--dry-run", "--report", report_path],
)
assert result.exit_code == 0
assert not os.path.exists(report_path)
finally:
os.unlink(path)
def test_missing_config_fails(self):
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "run", "--config", "/nonexistent.yml"])
assert result.exit_code != 0

View File

@ -0,0 +1,145 @@
"""Tests for orchestration labeler."""
import time
import numpy as np
import pytest
from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.orchestration.campaign import CaptureStep
from ria_toolkit_oss.orchestration.labeler import build_output_filename, label_recording
def _simple_recording() -> Recording:
sr = 1e6
n = 1000
data = np.ones(n, dtype=np.complex64)
return Recording(data, metadata={"sample_rate": sr, "center_frequency": 2.45e9})
def _wifi_step() -> CaptureStep:
return CaptureStep(
duration=30.0,
label="ch06_20mhz_idle",
channel=6,
bandwidth_mhz=20.0,
traffic="idle",
)
def _bt_step() -> CaptureStep:
return CaptureStep(
duration=30.0,
label="audio_stream",
traffic="audio_stream",
connection_interval_ms=7.5,
)
# ---------------------------------------------------------------------------
# label_recording
# ---------------------------------------------------------------------------
class TestLabelRecording:
def test_device_id_set(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert rec.metadata["device_id"] == "iphone13_001"
def test_capture_timestamp_set(self):
ts = time.time()
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), ts)
assert rec.metadata["capture_timestamp"] == pytest.approx(ts, abs=1.0)
def test_step_label_set(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert rec.metadata["step_label"] == "ch06_20mhz_idle"
def test_step_duration_set(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert rec.metadata["step_duration_s"] == pytest.approx(30.0)
def test_campaign_name_optional(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert "campaign" not in rec.metadata
def test_campaign_name_when_provided(self):
rec = label_recording(
_simple_recording(), "iphone13_001", _wifi_step(), time.time(), campaign_name="test_campaign"
)
assert rec.metadata["campaign"] == "test_campaign"
def test_wifi_channel_set(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert rec.metadata["wifi_channel"] == 6
def test_wifi_bandwidth_set(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert rec.metadata["wifi_bandwidth_mhz"] == pytest.approx(20.0)
def test_traffic_pattern_set(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert rec.metadata["traffic_pattern"] == "idle"
def test_bt_connection_interval_set(self):
rec = label_recording(_simple_recording(), "airpods_001", _bt_step(), time.time())
assert rec.metadata["bt_connection_interval_ms"] == pytest.approx(7.5)
def test_no_channel_key_for_bt(self):
"""BT steps with no channel should not add wifi_channel to metadata."""
rec = label_recording(_simple_recording(), "airpods_001", _bt_step(), time.time())
assert "wifi_channel" not in rec.metadata
def test_no_bandwidth_key_for_bt(self):
rec = label_recording(_simple_recording(), "airpods_001", _bt_step(), time.time())
assert "wifi_bandwidth_mhz" not in rec.metadata
def test_power_dbm_set(self):
step = CaptureStep(duration=30.0, label="test", traffic="idle", power_dbm=15.0)
rec = label_recording(_simple_recording(), "dev_001", step, time.time())
assert rec.metadata["tx_power_dbm"] == pytest.approx(15.0)
def test_no_power_key_when_unset(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert "tx_power_dbm" not in rec.metadata
def test_returns_same_recording(self):
"""label_recording should mutate and return the same Recording object."""
rec = _simple_recording()
result = label_recording(rec, "iphone13_001", _wifi_step(), time.time())
assert result is rec
# ---------------------------------------------------------------------------
# build_output_filename
# ---------------------------------------------------------------------------
class TestBuildOutputFilename:
def test_basic_wifi(self):
step = CaptureStep(duration=30.0, label="ch06_20mhz_idle")
fn = build_output_filename("iphone13_wifi_001", step)
assert fn == "iphone13_wifi_001/ch06_20mhz_idle"
def test_bt_step(self):
step = CaptureStep(duration=30.0, label="audio_stream")
fn = build_output_filename("airpods_pro_bt_001", step)
assert fn == "airpods_pro_bt_001/audio_stream"
def test_spaces_in_device_id_replaced(self):
step = CaptureStep(duration=30.0, label="idle")
fn = build_output_filename("my device", step)
assert " " not in fn
assert fn == "my_device/idle"
def test_slashes_in_label_replaced(self):
step = CaptureStep(duration=30.0, label="ch/6/idle")
fn = build_output_filename("dev_001", step)
assert "/" not in fn.split("/", 1)[1] # only the separator slash should remain
def test_path_structure(self):
"""Filename should be exactly '<device_id>/<label>' (one level of nesting)."""
step = CaptureStep(duration=30.0, label="idle")
fn = build_output_filename("dev_001", step)
parts = fn.split("/")
assert len(parts) == 2

View File

@ -0,0 +1,192 @@
"""Tests for orchestration QA metrics."""
import numpy as np
import pytest
from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.orchestration.campaign import QAConfig
from ria_toolkit_oss.orchestration.qa import QAResult, check_recording, estimate_snr_db
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _make_recording(n_samples: int, sample_rate: float, signal: np.ndarray) -> Recording:
return Recording(
signal.astype(np.complex64),
metadata={"sample_rate": sample_rate, "center_frequency": 2.45e9},
)
def _tone(n: int, sr: float, freq_hz: float = 100e3, amplitude: float = 0.5) -> np.ndarray:
t = np.arange(n) / sr
return (np.exp(1j * 2 * np.pi * freq_hz * t) * amplitude).astype(np.complex64)
def _noise(n: int, amplitude: float = 0.001) -> np.ndarray:
rng = np.random.default_rng(42)
return ((rng.standard_normal(n) + 1j * rng.standard_normal(n)) * amplitude).astype(np.complex64)
DEFAULT_QA = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=True)
# ---------------------------------------------------------------------------
# estimate_snr_db
# ---------------------------------------------------------------------------
class TestEstimateSnrDb:
def test_high_snr_tone(self):
sr = 1e6
samples = _tone(int(sr * 1), sr)
snr = estimate_snr_db(samples)
assert snr > 20.0, f"Expected high SNR for clean tone, got {snr:.1f} dB"
def test_pure_noise_low_snr(self):
sr = 1e6
rng = np.random.default_rng(0)
samples = (rng.standard_normal(int(sr)) + 1j * rng.standard_normal(int(sr))).astype(np.complex64)
snr = estimate_snr_db(samples)
# Pure noise should yield a low (possibly negative) SNR
assert snr < 15.0, f"Expected low SNR for noise, got {snr:.1f} dB"
def test_snr_increases_with_amplitude(self):
sr = 1e6
n = int(sr)
rng = np.random.default_rng(1)
noise = (rng.standard_normal(n) + 1j * rng.standard_normal(n)).astype(np.complex64) * 0.01
t = np.arange(n) / sr
tone = np.exp(1j * 2 * np.pi * 100e3 * t).astype(np.complex64)
low_snr = estimate_snr_db(noise + tone * 0.1)
high_snr = estimate_snr_db(noise + tone * 1.0)
assert high_snr > low_snr
def test_short_input_still_works(self):
# Input shorter than n_fft=4096 should not raise
samples = _tone(512, 1e6)
snr = estimate_snr_db(samples)
assert np.isfinite(snr)
# ---------------------------------------------------------------------------
# check_recording — pass cases
# ---------------------------------------------------------------------------
class TestCheckRecordingPass:
def test_clean_tone_passes(self):
sr = 1e6
rec = _make_recording(int(sr * 30), sr, _tone(int(sr * 30), sr))
result = check_recording(rec, DEFAULT_QA)
assert result.passed is True
assert result.flagged is False
assert result.snr_db > 10.0
assert abs(result.duration_s - 30.0) < 0.1
def test_duration_exactly_at_threshold(self):
sr = 1e6
n = int(sr * 25) # exactly at min_duration_s
rec = _make_recording(n, sr, _tone(n, sr))
result = check_recording(rec, DEFAULT_QA)
assert result.flagged is False
def test_issues_empty_when_passing(self):
sr = 1e6
rec = _make_recording(int(sr * 30), sr, _tone(int(sr * 30), sr))
result = check_recording(rec, DEFAULT_QA)
assert result.issues == []
# ---------------------------------------------------------------------------
# check_recording — flag cases
# ---------------------------------------------------------------------------
class TestCheckRecordingFlag:
def test_short_recording_flagged(self):
sr = 1e6
n = int(sr * 10) # shorter than 25s min
rec = _make_recording(n, sr, _tone(n, sr))
result = check_recording(rec, DEFAULT_QA)
assert result.flagged is True
assert any("Duration" in issue for issue in result.issues)
def test_low_snr_flagged(self):
sr = 1e6
n = int(sr * 30)
rec = _make_recording(n, sr, _noise(n, amplitude=0.001))
result = check_recording(rec, DEFAULT_QA)
assert result.flagged is True
assert any("SNR" in issue for issue in result.issues)
def test_flag_for_review_still_passes(self):
"""With flag_for_review=True, flagged recordings are still marked passed."""
sr = 1e6
n = int(sr * 10) # short → will be flagged
rec = _make_recording(n, sr, _tone(n, sr))
qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=True)
result = check_recording(rec, qa)
assert result.flagged is True
assert result.passed is True # human review, not auto-reject
def test_flag_for_review_false_fails(self):
"""With flag_for_review=False, a flagged recording is also marked failed."""
sr = 1e6
n = int(sr * 10)
rec = _make_recording(n, sr, _tone(n, sr))
qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=False)
result = check_recording(rec, qa)
assert result.flagged is True
assert result.passed is False
def test_multiple_issues_reported(self):
"""Both short duration AND low SNR should both appear in issues list."""
sr = 1e6
n = int(sr * 5) # very short
rec = _make_recording(n, sr, _noise(n, amplitude=0.0001))
result = check_recording(rec, DEFAULT_QA)
assert result.flagged is True
assert len(result.issues) >= 2
# ---------------------------------------------------------------------------
# check_recording — multichannel input
# ---------------------------------------------------------------------------
class TestCheckRecordingMultichannel:
def test_multichannel_recording(self):
"""2-channel recording should evaluate channel 0 without error."""
sr = 1e6
n = int(sr * 30)
ch0 = _tone(n, sr)
ch1 = _tone(n, sr, freq_hz=200e3)
data = np.stack([ch0, ch1]) # shape (2, N)
rec = Recording(data, metadata={"sample_rate": sr, "center_frequency": 2.45e9})
result = check_recording(rec, DEFAULT_QA)
assert result.passed is True
assert result.flagged is False
# ---------------------------------------------------------------------------
# QAResult.to_dict
# ---------------------------------------------------------------------------
class TestQAResultToDict:
def test_to_dict_keys(self):
r = QAResult(passed=True, flagged=False, snr_db=18.3, duration_s=30.0)
d = r.to_dict()
assert set(d.keys()) == {"passed", "flagged", "snr_db", "duration_s", "issues"}
def test_to_dict_values(self):
r = QAResult(passed=False, flagged=True, snr_db=7.5, duration_s=10.2, issues=["SNR below threshold"])
d = r.to_dict()
assert d["passed"] is False
assert d["flagged"] is True
assert d["snr_db"] == pytest.approx(7.5, abs=0.01)
assert d["duration_s"] == pytest.approx(10.2, abs=0.01)
assert d["issues"] == ["SNR below threshold"]