reformats and campaign additions

This commit is contained in:
ben 2026-03-11 10:27:18 -04:00
parent b1e3ebf74f
commit 019b0c6f4b
38 changed files with 4230 additions and 1059 deletions

2143
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -127,5 +127,8 @@ exclude = '''
)/
'''
[tool.pytest.ini_options]
pythonpath = ["src"]
[tool.isort]
profile = "black"

View File

@ -956,8 +956,10 @@ def get_result_sizes( # noqa: C901 # TODO: Simplify function
# Check that each class that will be augmented does not already suffice target_size
for cls_name, target_size_value in zip(classes_to_augment, target_size):
if class_sizes[cls_name] >= target_size_value:
raise ValueError(f"""target_size of {target_size_value} is already sufficed for current size of
{class_sizes[cls_name]} for class: {cls_name}""")
raise ValueError(
f"""target_size of {target_size_value} is already sufficed for current size of
{class_sizes[cls_name]} for class: {cls_name}"""
)
for index, class_name in enumerate(classes_to_augment):
result_sizes[class_name] = target_size[index]

View File

@ -316,6 +316,8 @@ def to_sigmf(
meta_dict = sigMF_metafile.ordered_metadata()
meta_dict["ria"] = metadata
if overwrite and os.path.isfile(meta_file_path):
os.remove(meta_file_path)
sigMF_metafile.tofile(meta_file_path)

View File

@ -0,0 +1,26 @@
"""Orchestration layer for automated RF capture campaigns."""
from .campaign import (
CampaignConfig,
CaptureStep,
QAConfig,
RecorderConfig,
TransmitterConfig,
)
from .executor import CampaignExecutor, CampaignResult, StepResult
from .labeler import label_recording
from .qa import QAResult, check_recording
__all__ = [
"CampaignConfig",
"CaptureStep",
"QAConfig",
"RecorderConfig",
"TransmitterConfig",
"CampaignExecutor",
"CampaignResult",
"StepResult",
"label_recording",
"QAResult",
"check_recording",
]

View File

@ -0,0 +1,446 @@
"""Campaign configuration schema and YAML parser for orchestrated RF captures."""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import yaml
# ---------------------------------------------------------------------------
# Parsing helpers
# ---------------------------------------------------------------------------
def parse_duration(value: str | float | int) -> float:
"""Parse a duration string to seconds.
Accepts:
"30s" 30.0
"1.5m" or "1.5min" 90.0
"2h" 7200.0
30 (numeric) 30.0
"""
if isinstance(value, (int, float)):
return float(value)
value = str(value).strip()
match = re.fullmatch(r"([\d.]+)\s*(s|sec|m|min|h|hr)?", value, re.IGNORECASE)
if not match:
raise ValueError(f"Cannot parse duration: '{value}'")
amount = float(match.group(1))
unit = (match.group(2) or "s").lower()
if unit in ("h", "hr"):
return amount * 3600
if unit in ("m", "min"):
return amount * 60
return amount
def parse_frequency(value: str | float | int) -> float:
"""Parse a frequency string to Hz.
Accepts:
"2.45GHz" 2_450_000_000.0
"40MHz" 40_000_000.0
"915e6" 915_000_000.0
2.45e9 (numeric) 2_450_000_000.0
"""
if isinstance(value, (int, float)):
return float(value)
value = str(value).strip()
# Try bare numeric first (handles scientific notation like "915e6")
try:
return float(value)
except ValueError:
pass
# Handle suffix notation: "2.45GHz", "40MHz", "40M", "433k"
match = re.fullmatch(r"([\d.]+)\s*(k|M|G)(?:\s*Hz?)?", value, re.IGNORECASE)
if match:
amount = float(match.group(1))
suffix = match.group(2).upper()
return amount * {"K": 1e3, "M": 1e6, "G": 1e9}[suffix]
raise ValueError(f"Cannot parse frequency: '{value}'")
def parse_gain(value: str | float | int) -> float | str:
"""Parse a gain string.
Accepts:
"40dB" or "40 dB" 40.0
"auto" "auto"
40 (numeric) 40.0
"""
if isinstance(value, (int, float)):
return float(value)
value = str(value).strip()
if value.lower() == "auto":
return "auto"
match = re.fullmatch(r"([\d.+\-]+)\s*dB?", value, re.IGNORECASE)
if not match:
raise ValueError(f"Cannot parse gain: '{value}'")
return float(match.group(1))
def parse_bandwidth_mhz(value: str | float | int | None) -> Optional[float]:
"""Parse a bandwidth string to MHz.
Accepts:
"20MHz" 20.0
"40MHz" 40.0
20 (numeric, assumed MHz) 20.0
None None
"""
if value is None:
return None
if isinstance(value, (int, float)):
return float(value)
value = str(value).strip()
match = re.fullmatch(r"([\d.]+)\s*MHz?", value, re.IGNORECASE)
if match:
return float(match.group(1))
match = re.fullmatch(r"([\d.]+)", value)
if match:
return float(match.group(1))
raise ValueError(f"Cannot parse bandwidth: '{value}'")
# ---------------------------------------------------------------------------
# Config dataclasses
# ---------------------------------------------------------------------------
@dataclass
class RecorderConfig:
"""SDR recorder configuration."""
device: str
center_freq: float # Hz
sample_rate: float # Hz
gain: float | str # dB float, or "auto"
bandwidth: Optional[float] = None # Hz, None = match sample_rate
@classmethod
def from_dict(cls, d: dict) -> "RecorderConfig":
gain = parse_gain(d.get("gain", "auto"))
bandwidth_raw = d.get("bandwidth") or d.get("bandwidth_hz")
bandwidth = parse_frequency(bandwidth_raw) if bandwidth_raw else None
return cls(
device=str(d["device"]),
center_freq=parse_frequency(d["center_freq"]),
sample_rate=parse_frequency(d["sample_rate"]),
gain=gain,
bandwidth=bandwidth,
)
@dataclass
class CaptureStep:
"""A single timed capture within a transmitter schedule."""
duration: float # seconds
label: str # used as filename component
# WiFi-specific
channel: Optional[int] = None
bandwidth_mhz: Optional[float] = None # MHz
traffic: Optional[str] = None
# Bluetooth-specific
connection_interval_ms: Optional[float] = None
# Power (dBm), optional
power_dbm: Optional[float] = None
@classmethod
def from_dict(cls, d: dict, auto_label: bool = True) -> "CaptureStep":
duration = parse_duration(d["duration"])
label = d.get("label", "")
if not label and auto_label:
parts = []
if d.get("channel"):
parts.append(f"ch{d['channel']:02d}")
if d.get("bandwidth"):
bw = parse_bandwidth_mhz(d["bandwidth"])
parts.append(f"{int(bw)}mhz")
if d.get("traffic"):
parts.append(str(d["traffic"]).replace(" ", "_"))
label = "_".join(parts) if parts else "capture"
return cls(
duration=duration,
label=label,
channel=d.get("channel"),
bandwidth_mhz=parse_bandwidth_mhz(d.get("bandwidth")),
traffic=d.get("traffic"),
connection_interval_ms=d.get("connection_interval_ms"),
power_dbm=float(d["power"].rstrip("dBm").strip()) if d.get("power") else None,
)
@dataclass
class TransmitterConfig:
"""Configuration for a single transmitter device in the campaign."""
id: str
type: str # "wifi", "bluetooth", "sdr", "external"
control_method: str # "external_script" | "sdr"
schedule: list[CaptureStep]
# For external_script control
script: Optional[str] = None # path to control script
device: Optional[str] = None # e.g. "/dev/wlan0"
@classmethod
def from_dict(cls, d: dict) -> "TransmitterConfig":
schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])]
return cls(
id=str(d["id"]),
type=str(d["type"]),
control_method=str(d.get("control_method", "external_script")),
schedule=schedule,
script=d.get("script"),
device=d.get("device"),
)
@dataclass
class QAConfig:
"""Quality assurance thresholds."""
snr_threshold_db: float = 10.0
min_duration_s: float = 25.0
flag_for_review: bool = True
@classmethod
def from_dict(cls, d: dict) -> "QAConfig":
return cls(
snr_threshold_db=float(str(d.get("snr_threshold", "10")).rstrip("dB").strip()),
min_duration_s=parse_duration(d.get("min_duration", "25s")),
flag_for_review=bool(d.get("flag_for_review", True)),
)
@dataclass
class OutputConfig:
"""Where to save captured recordings."""
format: str = "sigmf"
path: str = "recordings"
device_id: Optional[str] = None # for device-profile campaigns
repo: Optional[str] = None
@classmethod
def from_dict(cls, d: dict) -> "OutputConfig":
return cls(
format=str(d.get("format", "sigmf")),
path=str(d.get("path", "recordings")),
device_id=d.get("device_id"),
repo=d.get("repo"),
)
@dataclass
class CampaignConfig:
"""Full campaign configuration parsed from YAML."""
name: str
recorder: RecorderConfig
transmitters: list[TransmitterConfig]
qa: QAConfig = field(default_factory=QAConfig)
output: OutputConfig = field(default_factory=OutputConfig)
mode: str = "controlled_testbed"
# ---------------------------------------------------------------------------
# Loaders
# ---------------------------------------------------------------------------
@classmethod
def from_dict(cls, raw: dict) -> "CampaignConfig":
"""Build a CampaignConfig from a parsed dictionary.
Accepts the same structure as the campaign YAML, already loaded into
a Python dict (e.g. from a JSON HTTP request body).
Raises:
ValueError: If required fields are missing or malformed.
KeyError: If ``recorder`` key is absent.
"""
campaign_meta = raw.get("campaign", {})
transmitters = [TransmitterConfig.from_dict(t) for t in raw.get("transmitters", [])]
if not transmitters:
raise ValueError("Campaign config must define at least one transmitter")
return cls(
name=str(campaign_meta.get("name", "unnamed")),
mode=str(campaign_meta.get("mode", "controlled_testbed")),
recorder=RecorderConfig.from_dict(raw["recorder"]),
transmitters=transmitters,
qa=QAConfig.from_dict(raw.get("qa", {})),
output=OutputConfig.from_dict(raw.get("output", {})),
)
@classmethod
def from_yaml(cls, path: str | Path) -> "CampaignConfig":
"""Load a full campaign config YAML.
Expected format::
campaign:
name: "wifi_capture_001"
mode: "controlled_testbed"
transmitters:
- id: "laptop_wifi"
type: "wifi"
control_method: "external_script"
script: "./scripts/wifi_control.sh"
device: "/dev/wlan0"
schedule:
- channel: 6
bandwidth: "20MHz"
traffic: "iperf_udp"
duration: "30s"
recorder:
device: "usrp_b210"
center_freq: "2.45GHz"
sample_rate: "40MHz"
gain: "40dB"
qa:
snr_threshold: "10dB"
min_duration: "25s"
flag_for_review: true
output:
format: "sigmf"
path: "./recordings"
"""
path = Path(path)
try:
with open(path) as f:
raw = yaml.safe_load(f)
except FileNotFoundError:
raise FileNotFoundError(f"Campaign config not found: {path}")
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {path}: {e}")
campaign_meta = raw.get("campaign", {})
transmitters = [TransmitterConfig.from_dict(t) for t in raw.get("transmitters", [])]
if not transmitters:
raise ValueError("Campaign config must define at least one transmitter")
return cls(
name=str(campaign_meta.get("name", path.stem)),
mode=str(campaign_meta.get("mode", "controlled_testbed")),
recorder=RecorderConfig.from_dict(raw["recorder"]),
transmitters=transmitters,
qa=QAConfig.from_dict(raw.get("qa", {})),
output=OutputConfig.from_dict(raw.get("output", {})),
)
@classmethod
def from_device_profile(cls, path: str | Path) -> "CampaignConfig":
"""Build a campaign config from an App 1 device profile YAML.
Expected format::
device:
name: "iPhone_13_WiFi"
type: "wifi"
protocol: "wifi_24ghz"
capture:
channels: [1, 6, 11] # WiFi only
bandwidth: "20MHz" # WiFi only
traffic_patterns: ["idle", "ping", "iperf_udp"]
duration_per_config: "30s"
recorder:
device: "usrp_b210"
center_freq: "2.45GHz"
sample_rate: "40MHz"
gain: "auto"
output:
path: "./recordings"
device_id: "iphone13_wifi_001"
For WiFi devices, schedule is expanded as channels × traffic_patterns.
For Bluetooth devices (no channels), schedule is traffic_patterns only.
"""
path = Path(path)
try:
with open(path) as f:
raw = yaml.safe_load(f)
except FileNotFoundError:
raise FileNotFoundError(f"Device profile not found: {path}")
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {path}: {e}")
device = raw.get("device", {})
capture = raw.get("capture", {})
device_type = str(device.get("type", "wifi")).lower()
device_name = str(device.get("name", path.stem))
duration = parse_duration(capture.get("duration_per_config", "30s"))
traffic_patterns = capture.get("traffic_patterns", ["idle"])
# Build capture schedule
schedule: list[CaptureStep] = []
if device_type in ("wifi", "wifi_24ghz", "wifi_5ghz"):
channels = capture.get("channels", [6])
bw_str = capture.get("bandwidth", "20MHz")
bw_mhz = parse_bandwidth_mhz(bw_str)
for ch in channels:
for traffic in traffic_patterns:
label = f"ch{ch:02d}_{int(bw_mhz)}mhz_{traffic}"
schedule.append(
CaptureStep(
duration=duration,
label=label,
channel=ch,
bandwidth_mhz=bw_mhz,
traffic=traffic,
)
)
else:
# Bluetooth / generic — no channels
for traffic in traffic_patterns:
schedule.append(
CaptureStep(
duration=duration,
label=traffic,
traffic=traffic,
)
)
device_id = raw.get("output", {}).get("device_id", device_name.lower().replace(" ", "_"))
transmitter = TransmitterConfig(
id=device_id,
type=device_type,
control_method=str(capture.get("control_method", "external_script")),
schedule=schedule,
script=capture.get("script"),
device=capture.get("device"),
)
return cls(
name=f"enroll_{device_id}",
mode="controlled_testbed",
recorder=RecorderConfig.from_dict(raw["recorder"]),
transmitters=[transmitter],
qa=QAConfig.from_dict(raw.get("qa", {})),
output=OutputConfig.from_dict(raw.get("output", {})),
)
def total_capture_time_s(self) -> float:
"""Sum of all step durations across all transmitters."""
return sum(step.duration for tx in self.transmitters for step in tx.schedule)
def total_steps(self) -> int:
"""Total number of capture steps across all transmitters."""
return sum(len(tx.schedule) for tx in self.transmitters)

View File

@ -0,0 +1,423 @@
"""Campaign executor: runs a capture campaign end-to-end."""
from __future__ import annotations
import json
import logging
import subprocess
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional
from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.io.recording import to_sigmf
from .campaign import CampaignConfig, CaptureStep, TransmitterConfig
from .labeler import build_output_filename, label_recording
from .qa import QAResult, check_recording
logger = logging.getLogger(__name__)
# Device name aliases: campaign YAML names → get_sdr_device() names
_DEVICE_ALIASES = {
"usrp_b210": "usrp",
"usrp_b200": "usrp",
"usrp": "usrp",
"plutosdr": "pluto",
"pluto": "pluto",
"hackrf": "hackrf",
"hackrf_one": "hackrf",
"bladerf": "bladerf",
"rtlsdr": "rtlsdr",
"rtl_sdr": "rtlsdr",
"thinkrf": "thinkrf",
}
@dataclass
class StepResult:
"""Outcome of a single capture step."""
transmitter_id: str
step_label: str
output_path: Optional[str]
qa: QAResult
capture_timestamp: float
error: Optional[str] = None
@property
def ok(self) -> bool:
return self.error is None and self.qa.passed
def to_dict(self) -> dict:
return {
"transmitter_id": self.transmitter_id,
"step_label": self.step_label,
"output_path": self.output_path,
"capture_timestamp": self.capture_timestamp,
"qa": self.qa.to_dict(),
"error": self.error,
}
@dataclass
class CampaignResult:
"""Aggregate outcome of a full campaign."""
campaign_name: str
steps: list[StepResult] = field(default_factory=list)
start_time: float = field(default_factory=time.time)
end_time: Optional[float] = None
@property
def total_steps(self) -> int:
return len(self.steps)
@property
def passed(self) -> int:
return sum(1 for s in self.steps if s.ok)
@property
def flagged(self) -> int:
return sum(1 for s in self.steps if not s.error and s.qa.flagged)
@property
def failed(self) -> int:
return sum(1 for s in self.steps if s.error or not s.qa.passed)
@property
def duration_s(self) -> float:
if self.end_time:
return self.end_time - self.start_time
return time.time() - self.start_time
def to_dict(self) -> dict:
return {
"campaign_name": self.campaign_name,
"total_steps": self.total_steps,
"passed": self.passed,
"flagged": self.flagged,
"failed": self.failed,
"duration_s": round(self.duration_s, 1),
"steps": [s.to_dict() for s in self.steps],
}
def write_report(self, path: str | Path) -> None:
"""Write a JSON QA report to disk."""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w") as f:
json.dump(self.to_dict(), f, indent=2)
logger.info(f"QA report written to {path}")
# ---------------------------------------------------------------------------
# External script interface
# ---------------------------------------------------------------------------
def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
"""Run an external control script and return stdout.
The script is called as::
<script> <arg1> <arg2> ...
A non-zero return code raises RuntimeError.
Args:
script: Path to executable script.
*args: Positional arguments forwarded to the script.
timeout: Maximum seconds to wait.
Returns:
Script stdout as a string.
"""
cmd = [script, *args]
logger.debug(f"Running script: {' '.join(cmd)}")
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=timeout,
)
except subprocess.TimeoutExpired:
raise RuntimeError(f"Script timed out after {timeout}s: {script}")
except FileNotFoundError:
raise RuntimeError(f"Script not found: {script}")
if result.returncode != 0:
raise RuntimeError(f"Script exited {result.returncode}: {result.stderr.strip() or result.stdout.strip()}")
return result.stdout.strip()
# ---------------------------------------------------------------------------
# Campaign executor
# ---------------------------------------------------------------------------
class CampaignExecutor:
"""Executes a :class:`CampaignConfig` end-to-end.
Initialises the SDR recorder once, then for each (transmitter, step):
1. Configures the transmitter (via external script or SDR TX)
2. Records IQ samples
3. Labels the recording with device/config metadata
4. Runs QA checks
5. Saves the recording to disk
6. Stops/resets the transmitter
Args:
config: Parsed campaign configuration.
progress_cb: Optional callback ``(step_index, total_steps, step_result)``
called after each step completes. Useful for status reporting.
verbose: Enable debug logging.
"""
def __init__(
self,
config: CampaignConfig,
progress_cb: Optional[Callable[[int, int, StepResult], None]] = None,
verbose: bool = False,
):
self.config = config
self.progress_cb = progress_cb
self._sdr = None
if verbose:
logging.basicConfig(level=logging.DEBUG)
else:
logging.basicConfig(level=logging.INFO)
# ------------------------------------------------------------------
# Public interface
# ------------------------------------------------------------------
def run(self) -> CampaignResult:
"""Execute the full campaign and return a :class:`CampaignResult`.
Initialises the SDR, runs all steps across all transmitters,
then closes the SDR. If SDR initialisation fails the exception
propagates immediately (nothing is captured).
"""
result = CampaignResult(campaign_name=self.config.name)
logger.info(
f"Starting campaign '{self.config.name}': "
f"{self.config.total_steps()} steps, "
f"~{self.config.total_capture_time_s():.0f}s capture time"
)
self._init_sdr()
try:
total = self.config.total_steps()
step_index = 0
for transmitter in self.config.transmitters:
logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)")
for step in transmitter.schedule:
step_result = self._execute_step(transmitter, step)
result.steps.append(step_result)
step_index += 1
if self.progress_cb:
self.progress_cb(step_index, total, step_result)
if step_result.error:
logger.warning(f"Step '{step.label}' error: {step_result.error}")
elif step_result.qa.flagged:
logger.warning(f"Step '{step.label}' flagged for review: " + "; ".join(step_result.qa.issues))
else:
logger.info(
f"Step '{step.label}' OK "
f"(SNR {step_result.qa.snr_db:.1f} dB, "
f"{step_result.qa.duration_s:.1f}s)"
)
finally:
self._close_sdr()
result.end_time = time.time()
logger.info(
f"Campaign complete: {result.passed}/{result.total_steps} passed, "
f"{result.flagged} flagged, {result.failed} failed"
)
return result
# ------------------------------------------------------------------
# SDR management
# ------------------------------------------------------------------
def _init_sdr(self) -> None:
"""Initialise and configure the SDR recorder."""
from ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
rec = self.config.recorder
device_name = _DEVICE_ALIASES.get(rec.device.lower(), rec.device.lower())
logger.info(f"Initialising SDR: {device_name} @ {rec.center_freq/1e6:.2f} MHz")
self._sdr = get_sdr_device(device_name)
gain = None if rec.gain == "auto" else float(rec.gain)
self._sdr.init_rx(
sample_rate=rec.sample_rate,
center_frequency=rec.center_freq,
gain=gain,
channel=0,
)
if rec.bandwidth and hasattr(self._sdr, "set_rx_bandwidth"):
self._sdr.set_rx_bandwidth(rec.bandwidth)
def _close_sdr(self) -> None:
if self._sdr is not None:
try:
self._sdr.close()
except Exception as e:
logger.warning(f"SDR close error: {e}")
self._sdr = None
def _record(self, duration_s: float) -> Recording:
"""Capture ``duration_s`` seconds of IQ samples."""
num_samples = int(duration_s * self.config.recorder.sample_rate)
return self._sdr.record(num_samples=num_samples)
# ------------------------------------------------------------------
# Step execution
# ------------------------------------------------------------------
def _execute_step(self, transmitter: TransmitterConfig, step: CaptureStep) -> StepResult:
"""Run a single capture step.
Returns:
StepResult with QA outcome and output path (or error string).
"""
capture_timestamp = time.time()
output_path: Optional[str] = None
try:
self._start_transmitter(transmitter, step)
recording = self._record(step.duration)
self._stop_transmitter(transmitter, step)
except Exception as e:
# Best-effort stop on error
try:
self._stop_transmitter(transmitter, step)
except Exception:
pass
return StepResult(
transmitter_id=transmitter.id,
step_label=step.label,
output_path=None,
qa=QAResult(passed=False, flagged=True, snr_db=0.0, duration_s=0.0, issues=[f"Capture error: {e}"]),
capture_timestamp=capture_timestamp,
error=str(e),
)
# Label recording
recording = label_recording(
recording=recording,
device_id=transmitter.id,
step=step,
capture_timestamp=capture_timestamp,
campaign_name=self.config.name,
)
# QA
qa_result = check_recording(recording, self.config.qa)
# Save
try:
output_path = self._save(recording, transmitter.id, step)
except Exception as e:
return StepResult(
transmitter_id=transmitter.id,
step_label=step.label,
output_path=None,
qa=qa_result,
capture_timestamp=capture_timestamp,
error=f"Save failed: {e}",
)
return StepResult(
transmitter_id=transmitter.id,
step_label=step.label,
output_path=output_path,
qa=qa_result,
capture_timestamp=capture_timestamp,
)
# ------------------------------------------------------------------
# Transmitter control (external script interface)
# ------------------------------------------------------------------
def _start_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
"""Configure the transmitter for this step.
For ``external_script`` control method the script is called as::
<script> configure <step_params_json>
where ``step_params_json`` is a JSON object with channel, bandwidth,
traffic, etc. The script is responsible for applying the configuration
and returning promptly (i.e. not blocking for the capture duration).
For SDR transmitters this is a no-op placeholder (TX not yet implemented).
"""
if transmitter.control_method == "external_script":
if not transmitter.script:
logger.debug(f"No script configured for {transmitter.id}, skipping configure")
return
params = self._step_params_json(transmitter, step)
_run_script(transmitter.script, "configure", params)
elif transmitter.control_method == "sdr":
logger.debug("SDR TX not yet implemented — skipping start")
else:
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
def _stop_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
"""Signal the transmitter to stop.
Calls ``<script> stop`` for external_script transmitters.
"""
if transmitter.control_method == "external_script":
if not transmitter.script:
return
try:
_run_script(transmitter.script, "stop")
except Exception as e:
logger.warning(f"Script stop failed for {transmitter.id}: {e}")
@staticmethod
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
"""Serialise step parameters to a JSON string for the control script."""
params: dict = {"device": transmitter.device or ""}
if step.channel is not None:
params["channel"] = step.channel
if step.bandwidth_mhz is not None:
params["bandwidth_mhz"] = step.bandwidth_mhz
if step.traffic is not None:
params["traffic"] = step.traffic
if step.power_dbm is not None:
params["power_dbm"] = step.power_dbm
return json.dumps(params)
# ------------------------------------------------------------------
# Output
# ------------------------------------------------------------------
def _save(self, recording: Recording, device_id: str, step: CaptureStep) -> str:
"""Save a recording to disk and return the data file path."""
out = self.config.output
rel_filename = build_output_filename(device_id, step)
out_dir = Path(out.path)
# build_output_filename returns "<device_id>/<label>"
# to_sigmf needs filename (base) and path (dir) separately
parts = Path(rel_filename)
subdir = out_dir / parts.parent
subdir.mkdir(parents=True, exist_ok=True)
base = parts.name
to_sigmf(recording, filename=base, path=str(subdir), overwrite=True)
return str(subdir / f"{base}.sigmf-data")

View File

@ -0,0 +1,77 @@
"""Timestamp-based labeling for captured recordings."""
from __future__ import annotations
from typing import Optional
from ria_toolkit_oss.datatypes.recording import Recording
from .campaign import CaptureStep
def label_recording(
recording: Recording,
device_id: str,
step: CaptureStep,
capture_timestamp: float,
campaign_name: Optional[str] = None,
) -> Recording:
"""Apply device identity and capture configuration labels to a recording's metadata.
Labels are stored in the ``ria:*`` namespace when the recording is saved
as SigMF, via the existing ``update_metadata`` mechanism.
Args:
recording: The recording to label.
device_id: Identifier for the transmitting device (e.g. "iphone13_wifi_001").
step: The capture step that was active during this recording.
capture_timestamp: Unix timestamp (float) of when capture started.
campaign_name: Optional campaign name for cross-recording reference.
Returns:
The same recording with updated metadata.
"""
recording.update_metadata("device_id", device_id)
recording.update_metadata("capture_timestamp", capture_timestamp)
recording.update_metadata("step_label", step.label)
recording.update_metadata("step_duration_s", step.duration)
if campaign_name:
recording.update_metadata("campaign", campaign_name)
# WiFi-specific labels
if step.channel is not None:
recording.update_metadata("wifi_channel", step.channel)
if step.bandwidth_mhz is not None:
recording.update_metadata("wifi_bandwidth_mhz", step.bandwidth_mhz)
# Bluetooth-specific labels
if step.connection_interval_ms is not None:
recording.update_metadata("bt_connection_interval_ms", step.connection_interval_ms)
# Traffic pattern (WiFi + BT)
if step.traffic is not None:
recording.update_metadata("traffic_pattern", step.traffic)
# TX power
if step.power_dbm is not None:
recording.update_metadata("tx_power_dbm", step.power_dbm)
return recording
def build_output_filename(device_id: str, step: CaptureStep) -> str:
"""Generate a deterministic filename for a labeled recording.
Format: ``<device_id>/<step_label>``
Args:
device_id: Device identifier string.
step: Capture step.
Returns:
Relative path string (no extension) to use as ``filename`` in ``to_sigmf()``.
"""
safe_id = device_id.replace("/", "_").replace(" ", "_")
safe_label = step.label.replace("/", "_").replace(" ", "_")
return f"{safe_id}/{safe_label}"

View File

@ -0,0 +1,109 @@
"""QA metrics for captured RF recordings."""
from __future__ import annotations
from dataclasses import dataclass, field
import numpy as np
from ria_toolkit_oss.datatypes.recording import Recording
from .campaign import QAConfig
@dataclass
class QAResult:
"""Result of QA checks on a single recording."""
passed: bool
flagged: bool # True if any metric is below threshold (but not hard-failed)
snr_db: float
duration_s: float
issues: list[str] = field(default_factory=list)
def to_dict(self) -> dict:
return {
"passed": self.passed,
"flagged": self.flagged,
"snr_db": round(self.snr_db, 2),
"duration_s": round(self.duration_s, 3),
"issues": self.issues,
}
def estimate_snr_db(samples: np.ndarray, signal_fraction: float = 0.7) -> float:
"""Estimate SNR from IQ samples using PSD-based signal/noise separation.
Computes an FFT of the samples and assumes the top ``signal_fraction``
of power bins are signal and the remainder are noise. This is a
heuristic appropriate for a controlled testbed where a single dominant
signal is expected.
Args:
samples: 1-D complex array of IQ samples.
signal_fraction: Fraction of PSD bins to treat as signal (01).
Returns:
Estimated SNR in dB, or 0.0 if the noise floor is zero.
"""
n_fft = min(4096, len(samples))
window = np.hanning(n_fft)
psd = np.abs(np.fft.fft(samples[:n_fft] * window)) ** 2
psd_sorted = np.sort(psd)[::-1]
n_signal = max(1, int(n_fft * signal_fraction))
signal_power = psd_sorted[:n_signal].mean()
noise_power = psd_sorted[n_signal:].mean()
if noise_power <= 0.0:
return 0.0
return float(10.0 * np.log10(signal_power / noise_power))
def check_recording(recording: Recording, config: QAConfig) -> QAResult:
"""Run QA checks on a recording against the campaign QA config.
Checks performed:
- Duration: number of samples / sample_rate >= min_duration_s
- SNR: estimated SNR >= snr_threshold_db
Args:
recording: Recording to evaluate.
config: QA thresholds from the campaign config.
Returns:
QAResult with pass/flag status and per-metric details.
"""
issues: list[str] = []
flagged = False
# --- Duration check ---
sample_rate = recording.metadata.get("sample_rate", 1.0)
n_samples = recording.data.shape[-1]
duration_s = n_samples / sample_rate if sample_rate else 0.0
if duration_s < config.min_duration_s:
issues.append(f"Duration too short: {duration_s:.1f}s < {config.min_duration_s:.1f}s threshold")
flagged = True
# --- SNR check ---
samples = recording.data[0] if recording.data.ndim > 1 else recording.data
snr_db = estimate_snr_db(samples)
if snr_db < config.snr_threshold_db:
issues.append(f"SNR below threshold: {snr_db:.1f} dB < {config.snr_threshold_db:.1f} dB")
flagged = True
# In flag_for_review mode: flag but don't hard-fail
if config.flag_for_review:
passed = True # always accept; human reviews flagged recordings
else:
passed = not flagged
return QAResult(
passed=passed,
flagged=flagged,
snr_db=snr_db,
duration_s=duration_s,
issues=issues,
)

View File

@ -474,8 +474,10 @@ class Blade(SDR):
if gain_mode == "relative":
if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets \
the gain relative to the maximum possible gain.")
raise SDRParameterError(
"When gain_mode = 'relative', gain must be < 0. This sets \
the gain relative to the maximum possible gain."
)
else:
abs_gain = rx_gain_max + gain
else:
@ -548,8 +550,10 @@ class Blade(SDR):
if gain_mode == "relative":
if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain.")
raise SDRParameterError(
"When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain."
)
else:
abs_gain = tx_gain_max + gain
else:

View File

@ -172,8 +172,10 @@ class HackRF(SDR):
tx_gain_max = 47
if gain_mode == "relative":
if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This \
sets the gain relative to the maximum possible gain.")
raise SDRParameterError(
"When gain_mode = 'relative', gain must be < 0. This \
sets the gain relative to the maximum possible gain."
)
else:
abs_gain = tx_gain_max + gain
else:

View File

@ -274,16 +274,20 @@ class Pluto(SDR):
data = [self._convert_tx_samples(samples), self._convert_tx_samples(samples)]
else:
if len(recording) > 2:
warnings.warn("More recordings were provided than channels in the Pluto. \
Only the first two recordings will be used")
warnings.warn(
"More recordings were provided than channels in the Pluto. \
Only the first two recordings will be used"
)
sample0 = self._convert_tx_samples(recording.data[0])
sample1 = self._convert_tx_samples(recording.data[1])
data = [sample0, sample1]
elif isinstance(recording, list):
if len(recording) > 2:
warnings.warn("More recordings were provided than channels in the Pluto. \
Only the first two recordings will be used")
warnings.warn(
"More recordings were provided than channels in the Pluto. \
Only the first two recordings will be used"
)
if isinstance(recording[0], np.ndarray):
data = [self._convert_tx_samples(recording[0]), self._convert_tx_samples(recording[1])]
@ -423,8 +427,10 @@ class Pluto(SDR):
if gain_mode == "relative":
if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets \
the gain relative to the maximum possible gain.")
raise SDRParameterError(
"When gain_mode = 'relative', gain must be < 0. This sets \
the gain relative to the maximum possible gain."
)
else:
abs_gain = rx_gain_max + gain
else:
@ -534,8 +540,10 @@ class Pluto(SDR):
if gain_mode == "relative":
if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain.")
raise SDRParameterError(
"When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain."
)
else:
abs_gain = tx_gain_max + gain
else:

View File

@ -131,15 +131,19 @@ class RTLSDR(SDR):
if gain_mode == "relative":
if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain.")
raise SDRParameterError(
"When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain."
)
target_gain = max_gain + gain
else:
target_gain = gain
if target_gain < min_gain or target_gain > max_gain:
print(f"Requested gain {target_gain} dB out of range;\
clamping to valid span {min_gain}-{max_gain} dB.")
print(
f"Requested gain {target_gain} dB out of range;\
clamping to valid span {min_gain}-{max_gain} dB."
)
target_gain = min(max(target_gain, min_gain), max_gain)
target_gain = min(available_gains, key=lambda g: abs(g - target_gain))

View File

@ -392,8 +392,10 @@ class ThinkRF(SDR):
actual_sample_rate = self.BASE_SAMPLE_RATE / decimation
if abs(actual_sample_rate - requested_sample_rate) > 1e3: # More than 1 kHz difference
print(f"ThinkRF: Requested {requested_sample_rate/1e6:.2f} MS/s → \
Using decimation={decimation} ({actual_sample_rate/1e6:.2f} MS/s)")
print(
f"ThinkRF: Requested {requested_sample_rate/1e6:.2f} MS/s → \
Using decimation={decimation} ({actual_sample_rate/1e6:.2f} MS/s)"
)
return decimation, actual_sample_rate

View File

@ -148,8 +148,10 @@ class USRP(SDR):
gain_range = self.usrp.get_rx_gain_range()
if gain_mode == "relative":
if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain.")
raise SDRParameterError(
"When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain."
)
else:
# set gain relative to max
abs_gain = gain_range.stop() + gain
@ -354,8 +356,10 @@ class USRP(SDR):
gain_range = self.usrp.get_tx_gain_range()
if gain_mode == "relative":
if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain.")
raise SDRParameterError(
"When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain."
)
else:
# set gain relative to max
abs_gain = gain_range.stop() + gain

View File

@ -0,0 +1,5 @@
"""RT-OSS HTTP server for RIA Hub integration."""
from .app import create_app
__all__ = ["create_app"]

View File

@ -0,0 +1,48 @@
"""FastAPI application factory for the RT-OSS HTTP server."""
from fastapi import Depends, FastAPI
from .auth import require_api_key
from .routers import inference, orchestrator
def create_app(api_key: str = "") -> FastAPI:
"""Create and configure the RT-OSS FastAPI application.
Args:
api_key: Secret key required in the ``X-API-Key`` request header.
Pass an empty string to disable authentication (development only).
Returns:
Configured FastAPI application instance.
"""
app = FastAPI(
title="RIA Toolkit OSS Server",
version="0.1.0",
description=(
"HTTP API for RT-OSS campaign orchestration and RF zone inference. "
"All endpoints (except /health) require the X-API-Key header when "
"an API key is configured."
),
)
app.state.api_key = api_key
app.include_router(
orchestrator.router,
prefix="/orchestrator",
tags=["Orchestrator"],
dependencies=[Depends(require_api_key)],
)
app.include_router(
inference.router,
prefix="/inference",
tags=["Inference"],
dependencies=[Depends(require_api_key)],
)
@app.get("/health", tags=["Health"])
async def health():
"""Health check — always returns 200."""
return {"status": "ok"}
return app

View File

@ -0,0 +1,25 @@
"""API key authentication dependency."""
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import APIKeyHeader
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def require_api_key(
request: Request,
api_key: str | None = Depends(_api_key_header),
) -> None:
"""FastAPI dependency that enforces X-API-Key header authentication.
If no API key is configured on the server (empty string), all requests
are allowed this is intended for local development only.
"""
expected: str = request.app.state.api_key
if not expected:
return # dev mode: no key set, allow all
if api_key != expected:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid or missing API key",
)

View File

@ -0,0 +1,77 @@
"""Pydantic request and response models for the RT-OSS HTTP server."""
from __future__ import annotations
from pydantic import BaseModel
# ---------------------------------------------------------------------------
# Orchestrator
# ---------------------------------------------------------------------------
class DeployRequest(BaseModel):
config: dict
class DeployResponse(BaseModel):
campaign_id: str
class CampaignStatusResponse(BaseModel):
campaign_id: str
status: str
config_name: str
progress: int
total_steps: int
started_at: float
ended_at: float | None = None
result: dict | None = None
error: str | None = None
class CancelResponse(BaseModel):
campaign_id: str
cancelled: bool
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
class SdrConfig(BaseModel):
device: str
center_freq: float
sample_rate: float
gain: float | str = "auto"
class LoadModelRequest(BaseModel):
model_path: str
label_map: dict[str, int] # class_name -> class_index
class LoadModelResponse(BaseModel):
loaded: bool
model_path: str
num_classes: int
class StartInferenceRequest(BaseModel):
sdr_config: SdrConfig
class StartInferenceResponse(BaseModel):
running: bool
class StopInferenceResponse(BaseModel):
stopped: bool
class InferenceStatusResponse(BaseModel):
timestamp: float
prediction: str
confidence: float
snr_db: float
zone: str | None = None

View File

@ -0,0 +1,183 @@
"""Inference routes: model loading, inference loop control, and status polling."""
from __future__ import annotations
import logging
import threading
import time
import numpy as np
from fastapi import APIRouter, HTTPException, status
from scipy.special import softmax
from ..models import (
InferenceStatusResponse,
LoadModelRequest,
LoadModelResponse,
StartInferenceRequest,
StartInferenceResponse,
StopInferenceResponse,
)
from ..state import InferenceState, get_inference, set_inference
router = APIRouter()
logger = logging.getLogger(__name__)
_INFERENCE_NUM_SAMPLES = 4096
def _load_onnx_session(model_path: str):
try:
import onnxruntime as ort
except ImportError:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="onnxruntime not installed. Install with: pip install ria-toolkit-oss[server]",
)
try:
return ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
except Exception as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"Failed to load ONNX model: {e}")
def _preprocess_samples(samples: np.ndarray, expected_shape: tuple) -> np.ndarray:
"""Reshape complex IQ samples to float32 matching the model's expected input.
Supports ``(batch, 2*N)`` interleaved and ``(batch, 2, N)`` two-channel conventions.
"""
iq = samples.astype(np.complex64)
i_ch, q_ch = iq.real, iq.imag
if len(expected_shape) == 2:
n = expected_shape[1] // 2
interleaved = np.empty(expected_shape[1], dtype=np.float32)
interleaved[0::2] = i_ch[:n]
interleaved[1::2] = q_ch[:n]
return interleaved.reshape(1, -1)
elif len(expected_shape) == 3:
n = expected_shape[2]
return np.stack([i_ch[:n], q_ch[:n]], axis=0).astype(np.float32).reshape(1, 2, n)
else:
raise ValueError(f"Unsupported model input shape: {expected_shape}")
def _stop_current_inference(state: InferenceState, timeout: float = 5.0) -> None:
state.stop_event.set()
if state.thread and state.thread.is_alive():
state.thread.join(timeout=timeout)
if state.thread.is_alive():
logger.warning("Inference thread did not stop within %.1fs; SDR resources may not be released", timeout)
def _inference_loop(state: InferenceState, sdr) -> None:
from ria_toolkit_oss.orchestration.qa import estimate_snr_db
session = state.session
input_name = session.get_inputs()[0].name
expected_shape = tuple(
d if isinstance(d, int) and d > 0 else _INFERENCE_NUM_SAMPLES for d in session.get_inputs()[0].shape
)
try:
while not state.stop_event.is_set():
recording = sdr.record(num_samples=_INFERENCE_NUM_SAMPLES)
samples = recording.data[0] if recording.data.ndim > 1 else recording.data
snr_db = estimate_snr_db(samples)
try:
model_input = _preprocess_samples(samples, expected_shape)
logits = session.run(None, {input_name: model_input})[0][0].astype(np.float32)
probs = softmax(logits)
pred_idx = int(np.argmax(probs))
prediction = state.index_to_label.get(pred_idx, str(pred_idx))
except Exception:
continue
state.set_latest(
{
"timestamp": time.time(),
"prediction": prediction,
"confidence": round(float(probs[pred_idx]), 4),
"snr_db": round(snr_db, 2),
"zone": None,
}
)
finally:
try:
sdr.close()
except Exception:
pass
state.running = False
@router.post("/load", response_model=LoadModelResponse)
async def load_model(request: LoadModelRequest):
"""Load an ONNX model. Stops any running inference first.
``label_map`` maps class names to integer indices (e.g. ``{"zone_a": 0}``).
"""
existing = get_inference()
if existing and existing.running:
_stop_current_inference(existing)
session = _load_onnx_session(request.model_path)
set_inference(
InferenceState(
model_path=request.model_path,
label_map=request.label_map,
index_to_label={v: k for k, v in request.label_map.items()},
session=session,
)
)
return LoadModelResponse(loaded=True, model_path=request.model_path, num_classes=len(request.label_map))
@router.post("/start", response_model=StartInferenceResponse)
async def start_inference(request: StartInferenceRequest):
"""Start continuous inference. Requires a model to be loaded first."""
state = get_inference()
if not state:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail="No model loaded. Call POST /inference/load first."
)
if state.running:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Inference is already running.")
try:
from ria_toolkit_oss.orchestration.executor import _DEVICE_ALIASES
from ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
except ImportError as e:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"SDR import failed: {e}")
sdr_cfg = request.sdr_config
try:
sdr = get_sdr_device(_DEVICE_ALIASES.get(sdr_cfg.device.lower(), sdr_cfg.device.lower()))
gain = None if sdr_cfg.gain == "auto" else float(sdr_cfg.gain)
sdr.init_rx(sample_rate=sdr_cfg.sample_rate, center_frequency=sdr_cfg.center_freq, gain=gain, channel=0)
except Exception as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"SDR initialisation failed: {e}")
state.stop_event.clear()
state.running = True
state.thread = threading.Thread(target=_inference_loop, args=(state, sdr), daemon=True)
state.thread.start()
return StartInferenceResponse(running=True)
@router.post("/stop", response_model=StopInferenceResponse)
async def stop_inference():
"""Stop the running inference loop."""
state = get_inference()
if not state or not state.running:
return StopInferenceResponse(stopped=False)
_stop_current_inference(state)
return StopInferenceResponse(stopped=True)
@router.get("/status", response_model=InferenceStatusResponse | None)
async def inference_status():
"""Return the latest inference result, or null if none available yet."""
state = get_inference()
if not state:
return None
latest = state.get_latest()
return InferenceStatusResponse(**latest) if latest else None

View File

@ -0,0 +1,102 @@
"""Orchestrator routes: campaign deployment, status, and cancellation."""
from __future__ import annotations
import threading
import time
import uuid
from typing import Any
from fastapi import APIRouter, HTTPException, status
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
from ..models import (
CampaignStatusResponse,
CancelResponse,
DeployRequest,
DeployResponse,
)
from ..state import (
CampaignCancelledError,
CampaignState,
get_campaign,
set_campaign,
update_campaign,
)
router = APIRouter()
def _make_progress_cb(campaign_id: str, cancel_event: threading.Event):
def cb(step_index: int, total_steps: int, step_result: Any) -> None:
update_campaign(campaign_id, progress=step_index)
if cancel_event.is_set():
raise CampaignCancelledError(f"Cancelled at step {step_index}/{total_steps}")
return cb
def _run_campaign_thread(campaign_id: str, cfg: CampaignConfig) -> None:
state = get_campaign(campaign_id)
try:
result = CampaignExecutor(
config=cfg,
progress_cb=_make_progress_cb(campaign_id, state.cancel_event),
).run()
update_campaign(
campaign_id, status="completed", progress=cfg.total_steps(), result=result.to_dict(), ended_at=time.time()
)
except CampaignCancelledError:
update_campaign(campaign_id, status="cancelled", ended_at=time.time())
except Exception as e:
update_campaign(campaign_id, status="failed", error=str(e), ended_at=time.time())
@router.post("/deploy", response_model=DeployResponse)
async def deploy(request: DeployRequest):
"""Deploy a campaign config and start execution. Returns a ``campaign_id`` for polling.
Cancellation takes effect at step boundaries, not mid-capture.
"""
try:
cfg = CampaignConfig.from_dict(request.config)
except (ValueError, KeyError) as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))
campaign_id = str(uuid.uuid4())
cancel_event = threading.Event()
thread = threading.Thread(target=_run_campaign_thread, args=(campaign_id, cfg), daemon=True)
set_campaign(
CampaignState(
campaign_id=campaign_id,
status="running",
config_name=cfg.name,
cancel_event=cancel_event,
thread=thread,
total_steps=cfg.total_steps(),
)
)
thread.start()
return DeployResponse(campaign_id=campaign_id)
@router.get("/status/{campaign_id}", response_model=CampaignStatusResponse)
async def get_status(campaign_id: str):
"""Get the status and progress of a deployed campaign."""
state = get_campaign(campaign_id)
if not state:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Campaign {campaign_id!r} not found")
return CampaignStatusResponse(**state.to_dict())
@router.post("/cancel/{campaign_id}", response_model=CancelResponse)
async def cancel(campaign_id: str):
"""Request cancellation. Takes effect at the next step boundary."""
state = get_campaign(campaign_id)
if not state:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Campaign {campaign_id!r} not found")
if state.status != "running":
return CancelResponse(campaign_id=campaign_id, cancelled=False)
state.cancel_event.set()
return CancelResponse(campaign_id=campaign_id, cancelled=True)

View File

@ -0,0 +1,101 @@
"""In-memory state for running campaigns and inference sessions."""
from __future__ import annotations
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Optional
class CampaignCancelledError(Exception):
"""Raised by the progress callback when a cancel is requested."""
@dataclass
class CampaignState:
campaign_id: str
status: str # "running" | "completed" | "failed" | "cancelled"
config_name: str
cancel_event: threading.Event
thread: threading.Thread
total_steps: int = 0
progress: int = 0
result: Optional[dict] = None
error: Optional[str] = None
started_at: float = field(default_factory=time.time)
ended_at: Optional[float] = None
def to_dict(self) -> dict:
return {
"campaign_id": self.campaign_id,
"status": self.status,
"config_name": self.config_name,
"progress": self.progress,
"total_steps": self.total_steps,
"result": self.result,
"error": self.error,
"started_at": self.started_at,
"ended_at": self.ended_at,
}
@dataclass
class InferenceState:
model_path: str
label_map: dict[str, int] # class_name -> class_index
index_to_label: dict[int, str] # reverse: class_index -> class_name
session: Any # onnxruntime.InferenceSession
stop_event: threading.Event = field(default_factory=threading.Event)
thread: Optional[threading.Thread] = None
running: bool = False
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
_latest: Optional[dict] = field(default=None, repr=False)
def set_latest(self, result: dict) -> None:
with self._lock:
self._latest = result
def get_latest(self) -> Optional[dict]:
with self._lock:
return self._latest
# ---------------------------------------------------------------------------
# Module-level stores
# ---------------------------------------------------------------------------
_campaigns: dict[str, CampaignState] = {}
_campaigns_lock = threading.Lock()
_inference: Optional[InferenceState] = None
_inference_lock = threading.Lock()
def get_campaign(campaign_id: str) -> Optional[CampaignState]:
with _campaigns_lock:
return _campaigns.get(campaign_id)
def set_campaign(state: CampaignState) -> None:
with _campaigns_lock:
_campaigns[state.campaign_id] = state
def update_campaign(campaign_id: str, **kwargs) -> None:
with _campaigns_lock:
state = _campaigns.get(campaign_id)
if state:
for k, v in kwargs.items():
setattr(state, k, v)
def get_inference() -> Optional[InferenceState]:
with _inference_lock:
return _inference
def set_inference(state: Optional[InferenceState]) -> None:
global _inference
with _inference_lock:
_inference = state

View File

@ -37,8 +37,10 @@ class Add(RecordableBlock, ProcessBlock):
samples = block.get_samples(num_samples)
if len(samples) != num_samples:
raise ValueError(f"Block {self.__class__.__name__} requested {num_samples} \
from block {block.__class__.__name__} but got {len(samples)}.")
raise ValueError(
f"Block {self.__class__.__name__} requested {num_samples} \
from block {block.__class__.__name__} but got {len(samples)}."
)
return samples

View File

@ -23,9 +23,11 @@ class ProcessBlock(Block, ABC):
)
elif not all(isinstance(item, Block) for item in input):
raise ValueError(f"Invalid input to block '{self.__class__.__name__}'. \
raise ValueError(
f"Invalid input to block '{self.__class__.__name__}'. \
Expected a list of Block objects but got \
{'[' + ',' .join(f'{item.__class__.__name__}({repr(item)})' for item in input) + ']'}")
{'[' + ',' .join(f'{item.__class__.__name__}({repr(item)})' for item in input) + ']'}"
)
elif len(input) != len(self.input_type):
raise ValueError(

View File

@ -20,8 +20,10 @@ class RecordableBlock(Block, Recordable):
:raises ValueError: If the number of samples is incorrect."""
samples = self.get_samples(num_samples)
if len(samples) != num_samples:
raise ValueError(f"Error in block {self.__class__.__name__} record(). \
Requested {num_samples} samples but got {len(samples)}")
raise ValueError(
f"Error in block {self.__class__.__name__} record(). \
Requested {num_samples} samples but got {len(samples)}"
)
metadata = self._get_metadata()
return Recording(data=samples, metadata=metadata)

View File

@ -39,7 +39,9 @@ class RecordingSource(SourceBlock, RecordableBlock):
:raises ValueError: If num_samples is greater than the recording length.
"""
if num_samples - 1 >= self.recording.data.shape[1]:
raise ValueError(f"{num_samples} samples requested from recording source with \
{self.recording.data.shape[1]} samples available.")
raise ValueError(
f"{num_samples} samples requested from recording source with \
{self.recording.data.shape[1]} samples available."
)
return self.recording.data[0, 0:num_samples]

View File

@ -610,8 +610,10 @@ def cut_out( # noqa: C901 # TODO: Simplify function
raise ValueError("signal must be CxN complex.")
if fill_type not in {"zeros", "ones", "low-snr", "avg-snr", "high-snr"}:
raise UserWarning("""fill_type must be "zeros", "ones", "low-snr", "avg-snr", or "high-snr",
"ones" has been selected by default""")
raise UserWarning(
"""fill_type must be "zeros", "ones", "low-snr", "avg-snr", or "high-snr",
"ones" has been selected by default"""
)
if max_section_size < 1 or max_section_size >= n:
raise ValueError("max_section_size must be at least 1 and must be less than the length of signal.")

Binary file not shown.

Before

Width:  |  Height:  |  Size: 90 KiB

After

Width:  |  Height:  |  Size: 130 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 19 KiB

After

Width:  |  Height:  |  Size: 130 B

View File

@ -0,0 +1,258 @@
"""Campaign orchestration CLI commands.
Usage examples::
# Enroll a single device using a device profile YAML (App 1 workflow)
ria campaign enroll --config iphone13.yml
# Run a full custom campaign config
ria campaign run --config my_campaign.yml
# Validate a config file without running it
ria campaign validate --config my_campaign.yml
"""
import sys
import click
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
from ria_toolkit_oss.orchestration.executor import CampaignExecutor, StepResult
@click.group()
def campaign():
"""Orchestrate automated RF capture campaigns."""
# ---------------------------------------------------------------------------
# ria campaign validate
# ---------------------------------------------------------------------------
@campaign.command()
@click.option(
"--config",
"-c",
required=True,
type=click.Path(exists=True),
help="Campaign YAML config or device profile YAML.",
)
@click.option(
"--profile",
is_flag=True,
default=False,
help="Parse as a device profile (App 1 style) rather than a full campaign config.",
)
def validate(config, profile):
"""Validate a campaign or device profile YAML without running it.
\b
Examples:
ria campaign validate --config iphone13.yml --profile
ria campaign validate --config campaign.yml
"""
try:
if profile:
cfg = CampaignConfig.from_device_profile(config)
else:
cfg = CampaignConfig.from_yaml(config)
except (FileNotFoundError, ValueError, KeyError) as e:
raise click.ClickException(str(e))
click.echo(click.style("✓ Config valid", fg="green", bold=True))
click.echo(f" Campaign name : {cfg.name}")
click.echo(f" Mode : {cfg.mode}")
click.echo(f" Transmitters : {len(cfg.transmitters)}")
click.echo(f" Total steps : {cfg.total_steps()}")
click.echo(f" Capture time : {cfg.total_capture_time_s():.0f}s")
click.echo(f" Recorder : {cfg.recorder.device} @ {cfg.recorder.center_freq/1e6:.2f} MHz")
click.echo(f" Sample rate : {cfg.recorder.sample_rate/1e6:.1f} MS/s")
click.echo(f" Output path : {cfg.output.path}")
click.echo()
for tx in cfg.transmitters:
click.echo(f" Transmitter: {tx.id} ({tx.type}, {tx.control_method}, {len(tx.schedule)} steps)")
for step in tx.schedule:
extras = []
if step.channel is not None:
extras.append(f"ch={step.channel}")
if step.bandwidth_mhz is not None:
extras.append(f"{int(step.bandwidth_mhz)}MHz")
if step.traffic:
extras.append(step.traffic)
suffix = f" [{', '.join(extras)}]" if extras else ""
click.echo(f" [{step.duration:.0f}s] {step.label}{suffix}")
# ---------------------------------------------------------------------------
# ria campaign enroll
# ---------------------------------------------------------------------------
@campaign.command()
@click.option(
"--config",
"-c",
required=True,
type=click.Path(exists=True),
help="Device profile YAML (App 1 enrollment format).",
)
@click.option(
"--output",
"-o",
default=None,
help="Override output directory from config.",
)
@click.option(
"--report",
default="qa_report.json",
show_default=True,
help="Path for the JSON QA report.",
)
@click.option("--verbose", "-v", is_flag=True, help="Verbose output.")
@click.option("--dry-run", is_flag=True, help="Parse and validate config, then exit.")
def enroll(config, output, report, verbose, dry_run):
"""Enroll a single device by running its capture profile.
Parses a device profile YAML (App 1 format), generates a capture
campaign, and runs it. Outputs labelled SigMF recordings and a
JSON QA report.
\b
Examples:
ria campaign enroll --config iphone13.yml
ria campaign enroll --config airpods.yml --output ./my_recordings
ria campaign enroll --config iphone13.yml --dry-run
"""
try:
cfg = CampaignConfig.from_device_profile(config)
except (FileNotFoundError, ValueError, KeyError) as e:
raise click.ClickException(str(e))
if output:
cfg.output.path = output
_print_campaign_summary(cfg)
if dry_run:
click.echo(click.style("Dry run — exiting before capture.", fg="yellow"))
return
result = _run_campaign(cfg, verbose=verbose)
result.write_report(report)
_print_result_summary(result, report)
sys.exit(0 if result.failed == 0 else 1)
# ---------------------------------------------------------------------------
# ria campaign run
# ---------------------------------------------------------------------------
@campaign.command()
@click.option(
"--config",
"-c",
required=True,
type=click.Path(exists=True),
help="Full campaign YAML config.",
)
@click.option(
"--output",
"-o",
default=None,
help="Override output directory from config.",
)
@click.option(
"--report",
default="qa_report.json",
show_default=True,
help="Path for the JSON QA report.",
)
@click.option("--verbose", "-v", is_flag=True, help="Verbose output.")
@click.option("--dry-run", is_flag=True, help="Parse and validate config, then exit.")
def run(config, output, report, verbose, dry_run):
"""Run a full campaign from a campaign config YAML.
\b
Examples:
ria campaign run --config wifi_capture.yml
ria campaign run --config campaign.yml --output ./data --dry-run
"""
try:
cfg = CampaignConfig.from_yaml(config)
except (FileNotFoundError, ValueError, KeyError) as e:
raise click.ClickException(str(e))
if output:
cfg.output.path = output
_print_campaign_summary(cfg)
if dry_run:
click.echo(click.style("Dry run — exiting before capture.", fg="yellow"))
return
result = _run_campaign(cfg, verbose=verbose)
result.write_report(report)
_print_result_summary(result, report)
sys.exit(0 if result.failed == 0 else 1)
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def _print_campaign_summary(cfg: CampaignConfig) -> None:
click.echo()
click.echo(click.style(f"Campaign: {cfg.name}", bold=True))
click.echo(f" Transmitters : {len(cfg.transmitters)}")
click.echo(f" Total steps : {cfg.total_steps()}")
click.echo(f" Capture time : ~{cfg.total_capture_time_s():.0f}s")
click.echo(f" Output : {cfg.output.path}")
click.echo()
def _make_progress_cb(total: int):
"""Return a progress callback that prints step results to stderr."""
def cb(idx: int, _total: int, step: StepResult) -> None:
status = (
click.style("", fg="green")
if step.ok
else (click.style("", fg="yellow") if step.qa.flagged else click.style("", fg="red"))
)
snr_str = f"SNR {step.qa.snr_db:.1f} dB" if not step.error else f"ERROR: {step.error}"
click.echo(
f" [{idx:>3}/{_total}] {status} {step.transmitter_id}/{step.step_label}{snr_str}",
err=True,
)
return cb
def _run_campaign(cfg: CampaignConfig, verbose: bool = False):
executor = CampaignExecutor(
config=cfg,
progress_cb=_make_progress_cb(cfg.total_steps()),
verbose=verbose,
)
return executor.run()
def _print_result_summary(result, report_path: str) -> None:
click.echo()
click.echo(click.style("Campaign complete", bold=True))
click.echo(f" Steps : {result.total_steps}")
click.echo(f" Passed : {click.style(str(result.passed), fg='green')}")
if result.flagged:
click.echo(f" Flagged : {click.style(str(result.flagged), fg='yellow')} (review required)")
if result.failed:
click.echo(f" Failed : {click.style(str(result.failed), fg='red')}")
click.echo(f" Duration: {result.duration_s:.0f}s")
click.echo(f" Report : {report_path}")
click.echo()

View File

@ -3,6 +3,7 @@
This module contains all the CLI bindings for the ria package.
"""
from .campaign import campaign
from .capture import capture
from .combine import combine
from .convert import convert
@ -13,6 +14,7 @@ from .generate import generate
# from .generate import generate
from .init import init
from .serve import serve
from .split import split
from .transform import transform
from .transmit import transmit

View File

@ -0,0 +1,51 @@
"""``ria serve`` — start the RT-OSS HTTP server for RIA Hub integration."""
import click
@click.command()
@click.option("--host", default="0.0.0.0", show_default=True, help="Bind host.")
@click.option("--port", default=8080, show_default=True, type=int, help="Bind port.")
@click.option(
"--api-key",
envvar="RT_OSS_API_KEY",
default="",
help="Required X-API-Key value. Also reads RT_OSS_API_KEY. Empty = no auth (dev only).",
)
@click.option(
"--log-level",
default="info",
show_default=True,
type=click.Choice(["debug", "info", "warning", "error"], case_sensitive=False),
)
def serve(host: str, port: int, api_key: str, log_level: str):
"""Start the RT-OSS HTTP server.
\b
Endpoints:
POST /orchestrator/deploy
GET /orchestrator/status/{campaign_id}
POST /orchestrator/cancel/{campaign_id}
POST /inference/load
POST /inference/start
POST /inference/stop
GET /inference/status
GET /health
"""
try:
import uvicorn
from ria_toolkit_oss.server.app import create_app
except ImportError as e:
raise click.ClickException(
f"Server dependencies missing: {e}\nInstall with: pip install ria-toolkit-oss[server]"
)
if not api_key:
click.echo(
click.style("Warning: ", fg="yellow", bold=True) + "no API key set — all requests unauthenticated.",
err=True,
)
click.echo(f"Starting RT-OSS server on http://{host}:{port}")
uvicorn.run(create_app(api_key=api_key), host=host, port=port, log_level=log_level.lower())

View File

View File

@ -0,0 +1,489 @@
"""Tests for orchestration campaign schema and YAML parsing."""
import os
import tempfile
import pytest
import yaml
from ria_toolkit_oss.orchestration.campaign import (
CampaignConfig,
CaptureStep,
QAConfig,
RecorderConfig,
parse_bandwidth_mhz,
parse_duration,
parse_frequency,
parse_gain,
)
# ---------------------------------------------------------------------------
# parse_duration
# ---------------------------------------------------------------------------
class TestParseDuration:
def test_seconds_suffix(self):
assert parse_duration("30s") == 30.0
def test_seconds_suffix_long(self):
assert parse_duration("30sec") == 30.0
def test_minutes_suffix(self):
assert parse_duration("1.5m") == 90.0
def test_minutes_suffix_long(self):
assert parse_duration("2min") == 120.0
def test_hours_suffix(self):
assert parse_duration("2h") == 7200.0
def test_hours_suffix_long(self):
assert parse_duration("1hr") == 3600.0
def test_numeric_int(self):
assert parse_duration(45) == 45.0
def test_numeric_float(self):
assert parse_duration(1.5) == 1.5
def test_bare_number_string(self):
# No unit → treated as seconds
assert parse_duration("60") == 60.0
def test_invalid_raises(self):
with pytest.raises(ValueError):
parse_duration("two minutes")
# ---------------------------------------------------------------------------
# parse_frequency
# ---------------------------------------------------------------------------
class TestParseFrequency:
def test_ghz(self):
assert parse_frequency("2.45GHz") == pytest.approx(2.45e9)
def test_mhz(self):
assert parse_frequency("40MHz") == pytest.approx(40e6)
def test_khz(self):
assert parse_frequency("433k") == pytest.approx(433e3)
def test_scientific_notation_string(self):
assert parse_frequency("915e6") == pytest.approx(915e6)
def test_numeric_float(self):
assert parse_frequency(2.45e9) == pytest.approx(2.45e9)
def test_numeric_int(self):
assert parse_frequency(1000000) == pytest.approx(1e6)
def test_hz_suffix_optional(self):
# "40M" and "40MHz" should both work
assert parse_frequency("40M") == pytest.approx(40e6)
assert parse_frequency("40MHz") == pytest.approx(40e6)
def test_invalid_raises(self):
with pytest.raises(ValueError):
parse_frequency("two point four gigs")
# ---------------------------------------------------------------------------
# parse_gain
# ---------------------------------------------------------------------------
class TestParseGain:
def test_db_suffix(self):
assert parse_gain("40dB") == pytest.approx(40.0)
def test_db_suffix_lowercase(self):
assert parse_gain("32db") == pytest.approx(32.0)
def test_auto(self):
assert parse_gain("auto") == "auto"
def test_auto_case_insensitive(self):
assert parse_gain("AUTO") == "auto"
def test_numeric_int(self):
assert parse_gain(32) == pytest.approx(32.0)
def test_numeric_float(self):
assert parse_gain(32.5) == pytest.approx(32.5)
def test_invalid_raises(self):
with pytest.raises(ValueError):
parse_gain("high")
# ---------------------------------------------------------------------------
# parse_bandwidth_mhz
# ---------------------------------------------------------------------------
class TestParseBandwidthMhz:
def test_mhz_suffix(self):
assert parse_bandwidth_mhz("20MHz") == pytest.approx(20.0)
def test_numeric(self):
assert parse_bandwidth_mhz(40) == pytest.approx(40.0)
def test_none(self):
assert parse_bandwidth_mhz(None) is None
def test_invalid_raises(self):
with pytest.raises(ValueError):
parse_bandwidth_mhz("wide")
# ---------------------------------------------------------------------------
# CaptureStep.from_dict
# ---------------------------------------------------------------------------
class TestCaptureStep:
def test_wifi_step_auto_label(self):
d = {"channel": 6, "bandwidth": "20MHz", "traffic": "iperf_udp", "duration": "30s"}
step = CaptureStep.from_dict(d)
assert step.duration == 30.0
assert step.channel == 6
assert step.bandwidth_mhz == 20.0
assert step.traffic == "iperf_udp"
assert step.label == "ch06_20mhz_iperf_udp"
def test_explicit_label(self):
d = {"channel": 1, "bandwidth": "20MHz", "traffic": "idle", "duration": "30s", "label": "my_label"}
step = CaptureStep.from_dict(d)
assert step.label == "my_label"
def test_fallback_label(self):
# No channel/bandwidth/traffic → label falls back to "capture"
d = {"duration": "10s"}
step = CaptureStep.from_dict(d)
assert step.label == "capture"
def test_power_parsed(self):
d = {"channel": 6, "bandwidth": "20MHz", "traffic": "idle", "duration": "30s", "power": "15dBm"}
step = CaptureStep.from_dict(d)
assert step.power_dbm == pytest.approx(15.0)
# ---------------------------------------------------------------------------
# RecorderConfig.from_dict
# ---------------------------------------------------------------------------
class TestRecorderConfig:
def test_basic(self):
d = {"device": "usrp_b210", "center_freq": "2.45GHz", "sample_rate": "40MHz", "gain": "40dB"}
rec = RecorderConfig.from_dict(d)
assert rec.device == "usrp_b210"
assert rec.center_freq == pytest.approx(2.45e9)
assert rec.sample_rate == pytest.approx(40e6)
assert rec.gain == pytest.approx(40.0)
assert rec.bandwidth is None
def test_auto_gain(self):
d = {"device": "pluto", "center_freq": "2.45GHz", "sample_rate": "20MHz", "gain": "auto"}
rec = RecorderConfig.from_dict(d)
assert rec.gain == "auto"
def test_bandwidth_set(self):
d = {"device": "pluto", "center_freq": "2.45GHz", "sample_rate": "20MHz", "gain": 32, "bandwidth": "20MHz"}
rec = RecorderConfig.from_dict(d)
assert rec.bandwidth == pytest.approx(20e6)
# ---------------------------------------------------------------------------
# QAConfig.from_dict
# ---------------------------------------------------------------------------
class TestQAConfig:
def test_defaults(self):
qa = QAConfig.from_dict({})
assert qa.snr_threshold_db == pytest.approx(10.0)
assert qa.min_duration_s == pytest.approx(25.0)
assert qa.flag_for_review is True
def test_custom_values(self):
d = {"snr_threshold": "15dB", "min_duration": "28s", "flag_for_review": False}
qa = QAConfig.from_dict(d)
assert qa.snr_threshold_db == pytest.approx(15.0)
assert qa.min_duration_s == pytest.approx(28.0)
assert qa.flag_for_review is False
# ---------------------------------------------------------------------------
# CampaignConfig.from_device_profile
# ---------------------------------------------------------------------------
def _write_device_profile(d: dict) -> str:
"""Write a dict as YAML to a temp file and return the path."""
f = tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False)
yaml.dump(d, f)
f.close()
return f.name
WIFI_PROFILE = {
"device": {"name": "iPhone_13_WiFi", "type": "wifi"},
"capture": {
"channels": [1, 6, 11],
"bandwidth": "20MHz",
"traffic_patterns": ["idle", "ping", "iperf_udp"],
"duration_per_config": "30s",
"script": "./scripts/wifi_control.sh",
},
"recorder": {
"device": "usrp_b210",
"center_freq": "2.45GHz",
"sample_rate": "40MHz",
"gain": "auto",
},
"output": {"path": "/tmp/test_recordings", "device_id": "iphone13_wifi_001"},
}
BT_PROFILE = {
"device": {"name": "AirPods_Pro", "type": "bluetooth"},
"capture": {
"traffic_patterns": ["idle", "audio_stream", "data_transfer"],
"duration_per_config": "30s",
},
"recorder": {
"device": "usrp_b210",
"center_freq": "2.45GHz",
"sample_rate": "40MHz",
"gain": "auto",
},
"output": {"path": "/tmp/test_recordings", "device_id": "airpods_pro_bt_001"},
}
class TestDeviceProfileParsing:
def test_wifi_schedule_count(self):
"""WiFi: 3 channels × 3 traffic = 9 steps."""
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
assert len(cfg.transmitters) == 1
assert len(cfg.transmitters[0].schedule) == 9
def test_wifi_campaign_name(self):
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
assert cfg.name == "enroll_iphone13_wifi_001"
def test_wifi_step_labels(self):
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
labels = [s.label for s in cfg.transmitters[0].schedule]
assert "ch01_20mhz_idle" in labels
assert "ch06_20mhz_ping" in labels
assert "ch11_20mhz_iperf_udp" in labels
def test_wifi_step_ordering(self):
"""Steps iterate channels first, then traffic."""
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
steps = cfg.transmitters[0].schedule
assert steps[0].channel == 1 and steps[0].traffic == "idle"
assert steps[1].channel == 1 and steps[1].traffic == "ping"
assert steps[3].channel == 6 and steps[3].traffic == "idle"
assert steps[8].channel == 11 and steps[8].traffic == "iperf_udp"
def test_wifi_step_duration(self):
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
for step in cfg.transmitters[0].schedule:
assert step.duration == pytest.approx(30.0)
def test_wifi_bandwidth(self):
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
for step in cfg.transmitters[0].schedule:
assert step.bandwidth_mhz == pytest.approx(20.0)
def test_wifi_recorder(self):
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
assert cfg.recorder.device == "usrp_b210"
assert cfg.recorder.center_freq == pytest.approx(2.45e9)
assert cfg.recorder.sample_rate == pytest.approx(40e6)
assert cfg.recorder.gain == "auto"
def test_wifi_total_capture_time(self):
path = _write_device_profile(WIFI_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
assert cfg.total_capture_time_s() == pytest.approx(270.0) # 9 × 30s
def test_bt_schedule_count(self):
"""BT: no channels, 3 traffic patterns = 3 steps."""
path = _write_device_profile(BT_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
assert len(cfg.transmitters[0].schedule) == 3
def test_bt_no_channel(self):
path = _write_device_profile(BT_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
for step in cfg.transmitters[0].schedule:
assert step.channel is None
def test_bt_step_labels(self):
path = _write_device_profile(BT_PROFILE)
try:
cfg = CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
labels = [s.label for s in cfg.transmitters[0].schedule]
assert labels == ["idle", "audio_stream", "data_transfer"]
def test_missing_file_raises(self):
with pytest.raises(FileNotFoundError):
CampaignConfig.from_device_profile("/nonexistent/path/profile.yml")
def test_invalid_yaml_raises(self):
with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f:
f.write(": bad: yaml: [\n")
path = f.name
try:
with pytest.raises(ValueError, match="Invalid YAML"):
CampaignConfig.from_device_profile(path)
finally:
os.unlink(path)
# ---------------------------------------------------------------------------
# CampaignConfig.from_yaml (full campaign format)
# ---------------------------------------------------------------------------
FULL_CAMPAIGN = {
"campaign": {"name": "wifi_capture_001", "mode": "controlled_testbed"},
"transmitters": [
{
"id": "laptop_wifi",
"type": "wifi",
"control_method": "external_script",
"script": "./scripts/wifi_control.sh",
"device": "/dev/wlan0",
"schedule": [
{"channel": 6, "bandwidth": "20MHz", "traffic": "iperf_tcp", "duration": "30s"},
{"channel": 36, "bandwidth": "40MHz", "traffic": "ping_flood", "duration": "30s"},
],
}
],
"recorder": {
"device": "usrp_b210",
"center_freq": "2.45GHz",
"sample_rate": "20MHz",
"gain": "40dB",
},
"qa": {"snr_threshold": "10dB", "min_duration": "25s", "flag_for_review": True},
"output": {"format": "sigmf", "path": "./recordings"},
}
class TestFullCampaignParsing:
def test_name(self):
path = _write_device_profile(FULL_CAMPAIGN)
try:
cfg = CampaignConfig.from_yaml(path)
finally:
os.unlink(path)
assert cfg.name == "wifi_capture_001"
def test_mode(self):
path = _write_device_profile(FULL_CAMPAIGN)
try:
cfg = CampaignConfig.from_yaml(path)
finally:
os.unlink(path)
assert cfg.mode == "controlled_testbed"
def test_transmitter_id(self):
path = _write_device_profile(FULL_CAMPAIGN)
try:
cfg = CampaignConfig.from_yaml(path)
finally:
os.unlink(path)
assert cfg.transmitters[0].id == "laptop_wifi"
assert cfg.transmitters[0].control_method == "external_script"
assert cfg.transmitters[0].script == "./scripts/wifi_control.sh"
def test_schedule_count(self):
path = _write_device_profile(FULL_CAMPAIGN)
try:
cfg = CampaignConfig.from_yaml(path)
finally:
os.unlink(path)
assert len(cfg.transmitters[0].schedule) == 2
def test_qa_config(self):
path = _write_device_profile(FULL_CAMPAIGN)
try:
cfg = CampaignConfig.from_yaml(path)
finally:
os.unlink(path)
assert cfg.qa.snr_threshold_db == pytest.approx(10.0)
assert cfg.qa.min_duration_s == pytest.approx(25.0)
assert cfg.qa.flag_for_review is True
def test_total_steps(self):
path = _write_device_profile(FULL_CAMPAIGN)
try:
cfg = CampaignConfig.from_yaml(path)
finally:
os.unlink(path)
assert cfg.total_steps() == 2
def test_no_transmitters_raises(self):
bad = dict(FULL_CAMPAIGN)
bad["transmitters"] = []
path = _write_device_profile(bad)
try:
with pytest.raises(ValueError, match="at least one transmitter"):
CampaignConfig.from_yaml(path)
finally:
os.unlink(path)
def test_missing_recorder_raises(self):
bad = {k: v for k, v in FULL_CAMPAIGN.items() if k != "recorder"}
path = _write_device_profile(bad)
try:
with pytest.raises((KeyError, ValueError)):
CampaignConfig.from_yaml(path)
finally:
os.unlink(path)

View File

@ -0,0 +1,274 @@
"""Tests for the `ria campaign` CLI commands."""
import os
import tempfile
import yaml
from click.testing import CliRunner
from ria_toolkit_oss_cli.cli import cli
def _write_yaml(d: dict, suffix=".yml") -> str:
f = tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False)
yaml.dump(d, f)
f.close()
return f.name
WIFI_PROFILE = {
"device": {"name": "iPhone_13_WiFi", "type": "wifi"},
"capture": {
"channels": [1, 6, 11],
"bandwidth": "20MHz",
"traffic_patterns": ["idle", "ping", "iperf_udp"],
"duration_per_config": "30s",
},
"recorder": {
"device": "usrp_b210",
"center_freq": "2.45GHz",
"sample_rate": "40MHz",
"gain": "auto",
},
"output": {"path": "/tmp/test_enroll", "device_id": "iphone13_wifi_001"},
}
FULL_CAMPAIGN = {
"campaign": {"name": "wifi_capture_001", "mode": "controlled_testbed"},
"transmitters": [
{
"id": "laptop_wifi",
"type": "wifi",
"control_method": "external_script",
"schedule": [
{"channel": 6, "bandwidth": "20MHz", "traffic": "iperf_udp", "duration": "30s"},
{"channel": 11, "bandwidth": "20MHz", "traffic": "idle", "duration": "30s"},
],
}
],
"recorder": {
"device": "usrp_b210",
"center_freq": "2.45GHz",
"sample_rate": "20MHz",
"gain": "40dB",
},
"qa": {"snr_threshold": "10dB", "min_duration": "25s", "flag_for_review": True},
"output": {"format": "sigmf", "path": "/tmp/test_campaign"},
}
# ---------------------------------------------------------------------------
# ria campaign --help
# ---------------------------------------------------------------------------
class TestCampaignHelp:
def test_campaign_help(self):
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "--help"])
assert result.exit_code == 0
assert "campaign" in result.output.lower()
def test_subcommands_listed(self):
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "--help"])
assert result.exit_code == 0
for sub in ("validate", "enroll", "run"):
assert sub in result.output
def test_validate_help(self):
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--help"])
assert result.exit_code == 0
def test_enroll_help(self):
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "enroll", "--help"])
assert result.exit_code == 0
def test_run_help(self):
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "run", "--help"])
assert result.exit_code == 0
# ---------------------------------------------------------------------------
# ria campaign validate
# ---------------------------------------------------------------------------
class TestCampaignValidate:
def test_validate_device_profile(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
assert result.exit_code == 0
assert "" in result.output or "valid" in result.output.lower()
finally:
os.unlink(path)
def test_validate_shows_campaign_name(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
assert "enroll_iphone13_wifi_001" in result.output
finally:
os.unlink(path)
def test_validate_shows_step_count(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
assert "9" in result.output # 9 total steps
finally:
os.unlink(path)
def test_validate_shows_capture_time(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
assert "270" in result.output # 270s total
finally:
os.unlink(path)
def test_validate_full_campaign(self):
path = _write_yaml(FULL_CAMPAIGN)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", path])
assert result.exit_code == 0
assert "wifi_capture_001" in result.output
finally:
os.unlink(path)
def test_validate_shows_steps(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
assert "ch01_20mhz_idle" in result.output
assert "ch06_20mhz_ping" in result.output
assert "ch11_20mhz_iperf_udp" in result.output
finally:
os.unlink(path)
def test_validate_missing_file(self):
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", "/nonexistent/file.yml"])
assert result.exit_code != 0
def test_validate_bad_yaml(self):
with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f:
f.write(": broken yaml [\n")
path = f.name
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "validate", "--config", path])
assert result.exit_code != 0
finally:
os.unlink(path)
# ---------------------------------------------------------------------------
# ria campaign enroll --dry-run
# ---------------------------------------------------------------------------
class TestCampaignEnrollDryRun:
def test_dry_run_exits_cleanly(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "enroll", "--config", path, "--dry-run"])
assert result.exit_code == 0
finally:
os.unlink(path)
def test_dry_run_shows_campaign_info(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "enroll", "--config", path, "--dry-run"])
assert "enroll_iphone13_wifi_001" in result.output
assert "9" in result.output
finally:
os.unlink(path)
def test_dry_run_does_not_capture(self):
"""Dry run should not create any output files."""
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
with tempfile.TemporaryDirectory() as tmpdir:
runner.invoke(
cli,
["campaign", "enroll", "--config", path, "--output", tmpdir, "--dry-run"],
)
# No .sigmf-data files should have been created
sigmf_files = list(os.walk(tmpdir))
all_files = [f for _, _, files in sigmf_files for f in files]
assert not any(f.endswith(".sigmf-data") for f in all_files)
finally:
os.unlink(path)
def test_dry_run_output_override(self):
path = _write_yaml(WIFI_PROFILE)
try:
runner = CliRunner()
result = runner.invoke(
cli,
["campaign", "enroll", "--config", path, "--output", "/tmp/custom_out", "--dry-run"],
)
assert result.exit_code == 0
assert "Dry run" in result.output
finally:
os.unlink(path)
# ---------------------------------------------------------------------------
# ria campaign run --dry-run
# ---------------------------------------------------------------------------
class TestCampaignRunDryRun:
def test_dry_run_exits_cleanly(self):
path = _write_yaml(FULL_CAMPAIGN)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "run", "--config", path, "--dry-run"])
assert result.exit_code == 0
finally:
os.unlink(path)
def test_dry_run_shows_campaign_name(self):
path = _write_yaml(FULL_CAMPAIGN)
try:
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "run", "--config", path, "--dry-run"])
assert "wifi_capture_001" in result.output
finally:
os.unlink(path)
def test_dry_run_does_not_create_report(self):
path = _write_yaml(FULL_CAMPAIGN)
try:
runner = CliRunner()
with tempfile.TemporaryDirectory() as tmpdir:
report_path = os.path.join(tmpdir, "qa_report.json")
result = runner.invoke(
cli,
["campaign", "run", "--config", path, "--dry-run", "--report", report_path],
)
assert result.exit_code == 0
assert not os.path.exists(report_path)
finally:
os.unlink(path)
def test_missing_config_fails(self):
runner = CliRunner()
result = runner.invoke(cli, ["campaign", "run", "--config", "/nonexistent.yml"])
assert result.exit_code != 0

View File

@ -0,0 +1,145 @@
"""Tests for orchestration labeler."""
import time
import numpy as np
import pytest
from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.orchestration.campaign import CaptureStep
from ria_toolkit_oss.orchestration.labeler import build_output_filename, label_recording
def _simple_recording() -> Recording:
sr = 1e6
n = 1000
data = np.ones(n, dtype=np.complex64)
return Recording(data, metadata={"sample_rate": sr, "center_frequency": 2.45e9})
def _wifi_step() -> CaptureStep:
return CaptureStep(
duration=30.0,
label="ch06_20mhz_idle",
channel=6,
bandwidth_mhz=20.0,
traffic="idle",
)
def _bt_step() -> CaptureStep:
return CaptureStep(
duration=30.0,
label="audio_stream",
traffic="audio_stream",
connection_interval_ms=7.5,
)
# ---------------------------------------------------------------------------
# label_recording
# ---------------------------------------------------------------------------
class TestLabelRecording:
def test_device_id_set(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert rec.metadata["device_id"] == "iphone13_001"
def test_capture_timestamp_set(self):
ts = time.time()
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), ts)
assert rec.metadata["capture_timestamp"] == pytest.approx(ts, abs=1.0)
def test_step_label_set(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert rec.metadata["step_label"] == "ch06_20mhz_idle"
def test_step_duration_set(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert rec.metadata["step_duration_s"] == pytest.approx(30.0)
def test_campaign_name_optional(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert "campaign" not in rec.metadata
def test_campaign_name_when_provided(self):
rec = label_recording(
_simple_recording(), "iphone13_001", _wifi_step(), time.time(), campaign_name="test_campaign"
)
assert rec.metadata["campaign"] == "test_campaign"
def test_wifi_channel_set(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert rec.metadata["wifi_channel"] == 6
def test_wifi_bandwidth_set(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert rec.metadata["wifi_bandwidth_mhz"] == pytest.approx(20.0)
def test_traffic_pattern_set(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert rec.metadata["traffic_pattern"] == "idle"
def test_bt_connection_interval_set(self):
rec = label_recording(_simple_recording(), "airpods_001", _bt_step(), time.time())
assert rec.metadata["bt_connection_interval_ms"] == pytest.approx(7.5)
def test_no_channel_key_for_bt(self):
"""BT steps with no channel should not add wifi_channel to metadata."""
rec = label_recording(_simple_recording(), "airpods_001", _bt_step(), time.time())
assert "wifi_channel" not in rec.metadata
def test_no_bandwidth_key_for_bt(self):
rec = label_recording(_simple_recording(), "airpods_001", _bt_step(), time.time())
assert "wifi_bandwidth_mhz" not in rec.metadata
def test_power_dbm_set(self):
step = CaptureStep(duration=30.0, label="test", traffic="idle", power_dbm=15.0)
rec = label_recording(_simple_recording(), "dev_001", step, time.time())
assert rec.metadata["tx_power_dbm"] == pytest.approx(15.0)
def test_no_power_key_when_unset(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
assert "tx_power_dbm" not in rec.metadata
def test_returns_same_recording(self):
"""label_recording should mutate and return the same Recording object."""
rec = _simple_recording()
result = label_recording(rec, "iphone13_001", _wifi_step(), time.time())
assert result is rec
# ---------------------------------------------------------------------------
# build_output_filename
# ---------------------------------------------------------------------------
class TestBuildOutputFilename:
def test_basic_wifi(self):
step = CaptureStep(duration=30.0, label="ch06_20mhz_idle")
fn = build_output_filename("iphone13_wifi_001", step)
assert fn == "iphone13_wifi_001/ch06_20mhz_idle"
def test_bt_step(self):
step = CaptureStep(duration=30.0, label="audio_stream")
fn = build_output_filename("airpods_pro_bt_001", step)
assert fn == "airpods_pro_bt_001/audio_stream"
def test_spaces_in_device_id_replaced(self):
step = CaptureStep(duration=30.0, label="idle")
fn = build_output_filename("my device", step)
assert " " not in fn
assert fn == "my_device/idle"
def test_slashes_in_label_replaced(self):
step = CaptureStep(duration=30.0, label="ch/6/idle")
fn = build_output_filename("dev_001", step)
assert "/" not in fn.split("/", 1)[1] # only the separator slash should remain
def test_path_structure(self):
"""Filename should be exactly '<device_id>/<label>' (one level of nesting)."""
step = CaptureStep(duration=30.0, label="idle")
fn = build_output_filename("dev_001", step)
parts = fn.split("/")
assert len(parts) == 2

View File

@ -0,0 +1,192 @@
"""Tests for orchestration QA metrics."""
import numpy as np
import pytest
from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.orchestration.campaign import QAConfig
from ria_toolkit_oss.orchestration.qa import QAResult, check_recording, estimate_snr_db
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _make_recording(n_samples: int, sample_rate: float, signal: np.ndarray) -> Recording:
return Recording(
signal.astype(np.complex64),
metadata={"sample_rate": sample_rate, "center_frequency": 2.45e9},
)
def _tone(n: int, sr: float, freq_hz: float = 100e3, amplitude: float = 0.5) -> np.ndarray:
t = np.arange(n) / sr
return (np.exp(1j * 2 * np.pi * freq_hz * t) * amplitude).astype(np.complex64)
def _noise(n: int, amplitude: float = 0.001) -> np.ndarray:
rng = np.random.default_rng(42)
return ((rng.standard_normal(n) + 1j * rng.standard_normal(n)) * amplitude).astype(np.complex64)
DEFAULT_QA = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=True)
# ---------------------------------------------------------------------------
# estimate_snr_db
# ---------------------------------------------------------------------------
class TestEstimateSnrDb:
def test_high_snr_tone(self):
sr = 1e6
samples = _tone(int(sr * 1), sr)
snr = estimate_snr_db(samples)
assert snr > 20.0, f"Expected high SNR for clean tone, got {snr:.1f} dB"
def test_pure_noise_low_snr(self):
sr = 1e6
rng = np.random.default_rng(0)
samples = (rng.standard_normal(int(sr)) + 1j * rng.standard_normal(int(sr))).astype(np.complex64)
snr = estimate_snr_db(samples)
# Pure noise should yield a low (possibly negative) SNR
assert snr < 15.0, f"Expected low SNR for noise, got {snr:.1f} dB"
def test_snr_increases_with_amplitude(self):
sr = 1e6
n = int(sr)
rng = np.random.default_rng(1)
noise = (rng.standard_normal(n) + 1j * rng.standard_normal(n)).astype(np.complex64) * 0.01
t = np.arange(n) / sr
tone = np.exp(1j * 2 * np.pi * 100e3 * t).astype(np.complex64)
low_snr = estimate_snr_db(noise + tone * 0.1)
high_snr = estimate_snr_db(noise + tone * 1.0)
assert high_snr > low_snr
def test_short_input_still_works(self):
# Input shorter than n_fft=4096 should not raise
samples = _tone(512, 1e6)
snr = estimate_snr_db(samples)
assert np.isfinite(snr)
# ---------------------------------------------------------------------------
# check_recording — pass cases
# ---------------------------------------------------------------------------
class TestCheckRecordingPass:
def test_clean_tone_passes(self):
sr = 1e6
rec = _make_recording(int(sr * 30), sr, _tone(int(sr * 30), sr))
result = check_recording(rec, DEFAULT_QA)
assert result.passed is True
assert result.flagged is False
assert result.snr_db > 10.0
assert abs(result.duration_s - 30.0) < 0.1
def test_duration_exactly_at_threshold(self):
sr = 1e6
n = int(sr * 25) # exactly at min_duration_s
rec = _make_recording(n, sr, _tone(n, sr))
result = check_recording(rec, DEFAULT_QA)
assert result.flagged is False
def test_issues_empty_when_passing(self):
sr = 1e6
rec = _make_recording(int(sr * 30), sr, _tone(int(sr * 30), sr))
result = check_recording(rec, DEFAULT_QA)
assert result.issues == []
# ---------------------------------------------------------------------------
# check_recording — flag cases
# ---------------------------------------------------------------------------
class TestCheckRecordingFlag:
def test_short_recording_flagged(self):
sr = 1e6
n = int(sr * 10) # shorter than 25s min
rec = _make_recording(n, sr, _tone(n, sr))
result = check_recording(rec, DEFAULT_QA)
assert result.flagged is True
assert any("Duration" in issue for issue in result.issues)
def test_low_snr_flagged(self):
sr = 1e6
n = int(sr * 30)
rec = _make_recording(n, sr, _noise(n, amplitude=0.001))
result = check_recording(rec, DEFAULT_QA)
assert result.flagged is True
assert any("SNR" in issue for issue in result.issues)
def test_flag_for_review_still_passes(self):
"""With flag_for_review=True, flagged recordings are still marked passed."""
sr = 1e6
n = int(sr * 10) # short → will be flagged
rec = _make_recording(n, sr, _tone(n, sr))
qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=True)
result = check_recording(rec, qa)
assert result.flagged is True
assert result.passed is True # human review, not auto-reject
def test_flag_for_review_false_fails(self):
"""With flag_for_review=False, a flagged recording is also marked failed."""
sr = 1e6
n = int(sr * 10)
rec = _make_recording(n, sr, _tone(n, sr))
qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=False)
result = check_recording(rec, qa)
assert result.flagged is True
assert result.passed is False
def test_multiple_issues_reported(self):
"""Both short duration AND low SNR should both appear in issues list."""
sr = 1e6
n = int(sr * 5) # very short
rec = _make_recording(n, sr, _noise(n, amplitude=0.0001))
result = check_recording(rec, DEFAULT_QA)
assert result.flagged is True
assert len(result.issues) >= 2
# ---------------------------------------------------------------------------
# check_recording — multichannel input
# ---------------------------------------------------------------------------
class TestCheckRecordingMultichannel:
def test_multichannel_recording(self):
"""2-channel recording should evaluate channel 0 without error."""
sr = 1e6
n = int(sr * 30)
ch0 = _tone(n, sr)
ch1 = _tone(n, sr, freq_hz=200e3)
data = np.stack([ch0, ch1]) # shape (2, N)
rec = Recording(data, metadata={"sample_rate": sr, "center_frequency": 2.45e9})
result = check_recording(rec, DEFAULT_QA)
assert result.passed is True
assert result.flagged is False
# ---------------------------------------------------------------------------
# QAResult.to_dict
# ---------------------------------------------------------------------------
class TestQAResultToDict:
def test_to_dict_keys(self):
r = QAResult(passed=True, flagged=False, snr_db=18.3, duration_s=30.0)
d = r.to_dict()
assert set(d.keys()) == {"passed", "flagged", "snr_db", "duration_s", "issues"}
def test_to_dict_values(self):
r = QAResult(passed=False, flagged=True, snr_db=7.5, duration_s=10.2, issues=["SNR below threshold"])
d = r.to_dict()
assert d["passed"] is False
assert d["flagged"] is True
assert d["snr_db"] == pytest.approx(7.5, abs=0.01)
assert d["duration_s"] == pytest.approx(10.2, abs=0.01)
assert d["issues"] == ["SNR below threshold"]