reformats and campaign additions
This commit is contained in:
parent
b1e3ebf74f
commit
019b0c6f4b
2143
poetry.lock
generated
2143
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
|
|
@ -127,5 +127,8 @@ exclude = '''
|
||||||
)/
|
)/
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
pythonpath = ["src"]
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
profile = "black"
|
profile = "black"
|
||||||
|
|
|
||||||
|
|
@ -956,8 +956,10 @@ def get_result_sizes( # noqa: C901 # TODO: Simplify function
|
||||||
# Check that each class that will be augmented does not already suffice target_size
|
# Check that each class that will be augmented does not already suffice target_size
|
||||||
for cls_name, target_size_value in zip(classes_to_augment, target_size):
|
for cls_name, target_size_value in zip(classes_to_augment, target_size):
|
||||||
if class_sizes[cls_name] >= target_size_value:
|
if class_sizes[cls_name] >= target_size_value:
|
||||||
raise ValueError(f"""target_size of {target_size_value} is already sufficed for current size of
|
raise ValueError(
|
||||||
{class_sizes[cls_name]} for class: {cls_name}""")
|
f"""target_size of {target_size_value} is already sufficed for current size of
|
||||||
|
{class_sizes[cls_name]} for class: {cls_name}"""
|
||||||
|
)
|
||||||
|
|
||||||
for index, class_name in enumerate(classes_to_augment):
|
for index, class_name in enumerate(classes_to_augment):
|
||||||
result_sizes[class_name] = target_size[index]
|
result_sizes[class_name] = target_size[index]
|
||||||
|
|
|
||||||
|
|
@ -316,6 +316,8 @@ def to_sigmf(
|
||||||
meta_dict = sigMF_metafile.ordered_metadata()
|
meta_dict = sigMF_metafile.ordered_metadata()
|
||||||
meta_dict["ria"] = metadata
|
meta_dict["ria"] = metadata
|
||||||
|
|
||||||
|
if overwrite and os.path.isfile(meta_file_path):
|
||||||
|
os.remove(meta_file_path)
|
||||||
sigMF_metafile.tofile(meta_file_path)
|
sigMF_metafile.tofile(meta_file_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
26
src/ria_toolkit_oss/orchestration/__init__.py
Normal file
26
src/ria_toolkit_oss/orchestration/__init__.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""Orchestration layer for automated RF capture campaigns."""
|
||||||
|
|
||||||
|
from .campaign import (
|
||||||
|
CampaignConfig,
|
||||||
|
CaptureStep,
|
||||||
|
QAConfig,
|
||||||
|
RecorderConfig,
|
||||||
|
TransmitterConfig,
|
||||||
|
)
|
||||||
|
from .executor import CampaignExecutor, CampaignResult, StepResult
|
||||||
|
from .labeler import label_recording
|
||||||
|
from .qa import QAResult, check_recording
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CampaignConfig",
|
||||||
|
"CaptureStep",
|
||||||
|
"QAConfig",
|
||||||
|
"RecorderConfig",
|
||||||
|
"TransmitterConfig",
|
||||||
|
"CampaignExecutor",
|
||||||
|
"CampaignResult",
|
||||||
|
"StepResult",
|
||||||
|
"label_recording",
|
||||||
|
"QAResult",
|
||||||
|
"check_recording",
|
||||||
|
]
|
||||||
446
src/ria_toolkit_oss/orchestration/campaign.py
Normal file
446
src/ria_toolkit_oss/orchestration/campaign.py
Normal file
|
|
@ -0,0 +1,446 @@
|
||||||
|
"""Campaign configuration schema and YAML parser for orchestrated RF captures."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Parsing helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def parse_duration(value: str | float | int) -> float:
|
||||||
|
"""Parse a duration string to seconds.
|
||||||
|
|
||||||
|
Accepts:
|
||||||
|
"30s" → 30.0
|
||||||
|
"1.5m" or "1.5min" → 90.0
|
||||||
|
"2h" → 7200.0
|
||||||
|
30 (numeric) → 30.0
|
||||||
|
"""
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return float(value)
|
||||||
|
value = str(value).strip()
|
||||||
|
match = re.fullmatch(r"([\d.]+)\s*(s|sec|m|min|h|hr)?", value, re.IGNORECASE)
|
||||||
|
if not match:
|
||||||
|
raise ValueError(f"Cannot parse duration: '{value}'")
|
||||||
|
amount = float(match.group(1))
|
||||||
|
unit = (match.group(2) or "s").lower()
|
||||||
|
if unit in ("h", "hr"):
|
||||||
|
return amount * 3600
|
||||||
|
if unit in ("m", "min"):
|
||||||
|
return amount * 60
|
||||||
|
return amount
|
||||||
|
|
||||||
|
|
||||||
|
def parse_frequency(value: str | float | int) -> float:
|
||||||
|
"""Parse a frequency string to Hz.
|
||||||
|
|
||||||
|
Accepts:
|
||||||
|
"2.45GHz" → 2_450_000_000.0
|
||||||
|
"40MHz" → 40_000_000.0
|
||||||
|
"915e6" → 915_000_000.0
|
||||||
|
2.45e9 (numeric) → 2_450_000_000.0
|
||||||
|
"""
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return float(value)
|
||||||
|
value = str(value).strip()
|
||||||
|
|
||||||
|
# Try bare numeric first (handles scientific notation like "915e6")
|
||||||
|
try:
|
||||||
|
return float(value)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Handle suffix notation: "2.45GHz", "40MHz", "40M", "433k"
|
||||||
|
match = re.fullmatch(r"([\d.]+)\s*(k|M|G)(?:\s*Hz?)?", value, re.IGNORECASE)
|
||||||
|
if match:
|
||||||
|
amount = float(match.group(1))
|
||||||
|
suffix = match.group(2).upper()
|
||||||
|
return amount * {"K": 1e3, "M": 1e6, "G": 1e9}[suffix]
|
||||||
|
|
||||||
|
raise ValueError(f"Cannot parse frequency: '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_gain(value: str | float | int) -> float | str:
|
||||||
|
"""Parse a gain string.
|
||||||
|
|
||||||
|
Accepts:
|
||||||
|
"40dB" or "40 dB" → 40.0
|
||||||
|
"auto" → "auto"
|
||||||
|
40 (numeric) → 40.0
|
||||||
|
"""
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return float(value)
|
||||||
|
value = str(value).strip()
|
||||||
|
if value.lower() == "auto":
|
||||||
|
return "auto"
|
||||||
|
match = re.fullmatch(r"([\d.+\-]+)\s*dB?", value, re.IGNORECASE)
|
||||||
|
if not match:
|
||||||
|
raise ValueError(f"Cannot parse gain: '{value}'")
|
||||||
|
return float(match.group(1))
|
||||||
|
|
||||||
|
|
||||||
|
def parse_bandwidth_mhz(value: str | float | int | None) -> Optional[float]:
|
||||||
|
"""Parse a bandwidth string to MHz.
|
||||||
|
|
||||||
|
Accepts:
|
||||||
|
"20MHz" → 20.0
|
||||||
|
"40MHz" → 40.0
|
||||||
|
20 (numeric, assumed MHz) → 20.0
|
||||||
|
None → None
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return float(value)
|
||||||
|
value = str(value).strip()
|
||||||
|
match = re.fullmatch(r"([\d.]+)\s*MHz?", value, re.IGNORECASE)
|
||||||
|
if match:
|
||||||
|
return float(match.group(1))
|
||||||
|
match = re.fullmatch(r"([\d.]+)", value)
|
||||||
|
if match:
|
||||||
|
return float(match.group(1))
|
||||||
|
raise ValueError(f"Cannot parse bandwidth: '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Config dataclasses
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RecorderConfig:
|
||||||
|
"""SDR recorder configuration."""
|
||||||
|
|
||||||
|
device: str
|
||||||
|
center_freq: float # Hz
|
||||||
|
sample_rate: float # Hz
|
||||||
|
gain: float | str # dB float, or "auto"
|
||||||
|
bandwidth: Optional[float] = None # Hz, None = match sample_rate
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict) -> "RecorderConfig":
|
||||||
|
gain = parse_gain(d.get("gain", "auto"))
|
||||||
|
bandwidth_raw = d.get("bandwidth") or d.get("bandwidth_hz")
|
||||||
|
bandwidth = parse_frequency(bandwidth_raw) if bandwidth_raw else None
|
||||||
|
return cls(
|
||||||
|
device=str(d["device"]),
|
||||||
|
center_freq=parse_frequency(d["center_freq"]),
|
||||||
|
sample_rate=parse_frequency(d["sample_rate"]),
|
||||||
|
gain=gain,
|
||||||
|
bandwidth=bandwidth,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CaptureStep:
|
||||||
|
"""A single timed capture within a transmitter schedule."""
|
||||||
|
|
||||||
|
duration: float # seconds
|
||||||
|
label: str # used as filename component
|
||||||
|
|
||||||
|
# WiFi-specific
|
||||||
|
channel: Optional[int] = None
|
||||||
|
bandwidth_mhz: Optional[float] = None # MHz
|
||||||
|
traffic: Optional[str] = None
|
||||||
|
|
||||||
|
# Bluetooth-specific
|
||||||
|
connection_interval_ms: Optional[float] = None
|
||||||
|
|
||||||
|
# Power (dBm), optional
|
||||||
|
power_dbm: Optional[float] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict, auto_label: bool = True) -> "CaptureStep":
|
||||||
|
duration = parse_duration(d["duration"])
|
||||||
|
label = d.get("label", "")
|
||||||
|
if not label and auto_label:
|
||||||
|
parts = []
|
||||||
|
if d.get("channel"):
|
||||||
|
parts.append(f"ch{d['channel']:02d}")
|
||||||
|
if d.get("bandwidth"):
|
||||||
|
bw = parse_bandwidth_mhz(d["bandwidth"])
|
||||||
|
parts.append(f"{int(bw)}mhz")
|
||||||
|
if d.get("traffic"):
|
||||||
|
parts.append(str(d["traffic"]).replace(" ", "_"))
|
||||||
|
label = "_".join(parts) if parts else "capture"
|
||||||
|
return cls(
|
||||||
|
duration=duration,
|
||||||
|
label=label,
|
||||||
|
channel=d.get("channel"),
|
||||||
|
bandwidth_mhz=parse_bandwidth_mhz(d.get("bandwidth")),
|
||||||
|
traffic=d.get("traffic"),
|
||||||
|
connection_interval_ms=d.get("connection_interval_ms"),
|
||||||
|
power_dbm=float(d["power"].rstrip("dBm").strip()) if d.get("power") else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TransmitterConfig:
|
||||||
|
"""Configuration for a single transmitter device in the campaign."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
type: str # "wifi", "bluetooth", "sdr", "external"
|
||||||
|
control_method: str # "external_script" | "sdr"
|
||||||
|
schedule: list[CaptureStep]
|
||||||
|
|
||||||
|
# For external_script control
|
||||||
|
script: Optional[str] = None # path to control script
|
||||||
|
device: Optional[str] = None # e.g. "/dev/wlan0"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict) -> "TransmitterConfig":
|
||||||
|
schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])]
|
||||||
|
return cls(
|
||||||
|
id=str(d["id"]),
|
||||||
|
type=str(d["type"]),
|
||||||
|
control_method=str(d.get("control_method", "external_script")),
|
||||||
|
schedule=schedule,
|
||||||
|
script=d.get("script"),
|
||||||
|
device=d.get("device"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QAConfig:
|
||||||
|
"""Quality assurance thresholds."""
|
||||||
|
|
||||||
|
snr_threshold_db: float = 10.0
|
||||||
|
min_duration_s: float = 25.0
|
||||||
|
flag_for_review: bool = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict) -> "QAConfig":
|
||||||
|
return cls(
|
||||||
|
snr_threshold_db=float(str(d.get("snr_threshold", "10")).rstrip("dB").strip()),
|
||||||
|
min_duration_s=parse_duration(d.get("min_duration", "25s")),
|
||||||
|
flag_for_review=bool(d.get("flag_for_review", True)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OutputConfig:
|
||||||
|
"""Where to save captured recordings."""
|
||||||
|
|
||||||
|
format: str = "sigmf"
|
||||||
|
path: str = "recordings"
|
||||||
|
device_id: Optional[str] = None # for device-profile campaigns
|
||||||
|
repo: Optional[str] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict) -> "OutputConfig":
|
||||||
|
return cls(
|
||||||
|
format=str(d.get("format", "sigmf")),
|
||||||
|
path=str(d.get("path", "recordings")),
|
||||||
|
device_id=d.get("device_id"),
|
||||||
|
repo=d.get("repo"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CampaignConfig:
|
||||||
|
"""Full campaign configuration parsed from YAML."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
recorder: RecorderConfig
|
||||||
|
transmitters: list[TransmitterConfig]
|
||||||
|
qa: QAConfig = field(default_factory=QAConfig)
|
||||||
|
output: OutputConfig = field(default_factory=OutputConfig)
|
||||||
|
mode: str = "controlled_testbed"
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Loaders
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, raw: dict) -> "CampaignConfig":
|
||||||
|
"""Build a CampaignConfig from a parsed dictionary.
|
||||||
|
|
||||||
|
Accepts the same structure as the campaign YAML, already loaded into
|
||||||
|
a Python dict (e.g. from a JSON HTTP request body).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If required fields are missing or malformed.
|
||||||
|
KeyError: If ``recorder`` key is absent.
|
||||||
|
"""
|
||||||
|
campaign_meta = raw.get("campaign", {})
|
||||||
|
transmitters = [TransmitterConfig.from_dict(t) for t in raw.get("transmitters", [])]
|
||||||
|
if not transmitters:
|
||||||
|
raise ValueError("Campaign config must define at least one transmitter")
|
||||||
|
return cls(
|
||||||
|
name=str(campaign_meta.get("name", "unnamed")),
|
||||||
|
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
||||||
|
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
||||||
|
transmitters=transmitters,
|
||||||
|
qa=QAConfig.from_dict(raw.get("qa", {})),
|
||||||
|
output=OutputConfig.from_dict(raw.get("output", {})),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_yaml(cls, path: str | Path) -> "CampaignConfig":
|
||||||
|
"""Load a full campaign config YAML.
|
||||||
|
|
||||||
|
Expected format::
|
||||||
|
|
||||||
|
campaign:
|
||||||
|
name: "wifi_capture_001"
|
||||||
|
mode: "controlled_testbed"
|
||||||
|
|
||||||
|
transmitters:
|
||||||
|
- id: "laptop_wifi"
|
||||||
|
type: "wifi"
|
||||||
|
control_method: "external_script"
|
||||||
|
script: "./scripts/wifi_control.sh"
|
||||||
|
device: "/dev/wlan0"
|
||||||
|
schedule:
|
||||||
|
- channel: 6
|
||||||
|
bandwidth: "20MHz"
|
||||||
|
traffic: "iperf_udp"
|
||||||
|
duration: "30s"
|
||||||
|
|
||||||
|
recorder:
|
||||||
|
device: "usrp_b210"
|
||||||
|
center_freq: "2.45GHz"
|
||||||
|
sample_rate: "40MHz"
|
||||||
|
gain: "40dB"
|
||||||
|
|
||||||
|
qa:
|
||||||
|
snr_threshold: "10dB"
|
||||||
|
min_duration: "25s"
|
||||||
|
flag_for_review: true
|
||||||
|
|
||||||
|
output:
|
||||||
|
format: "sigmf"
|
||||||
|
path: "./recordings"
|
||||||
|
"""
|
||||||
|
path = Path(path)
|
||||||
|
try:
|
||||||
|
with open(path) as f:
|
||||||
|
raw = yaml.safe_load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise FileNotFoundError(f"Campaign config not found: {path}")
|
||||||
|
except yaml.YAMLError as e:
|
||||||
|
raise ValueError(f"Invalid YAML in {path}: {e}")
|
||||||
|
|
||||||
|
campaign_meta = raw.get("campaign", {})
|
||||||
|
transmitters = [TransmitterConfig.from_dict(t) for t in raw.get("transmitters", [])]
|
||||||
|
if not transmitters:
|
||||||
|
raise ValueError("Campaign config must define at least one transmitter")
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
name=str(campaign_meta.get("name", path.stem)),
|
||||||
|
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
||||||
|
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
||||||
|
transmitters=transmitters,
|
||||||
|
qa=QAConfig.from_dict(raw.get("qa", {})),
|
||||||
|
output=OutputConfig.from_dict(raw.get("output", {})),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_device_profile(cls, path: str | Path) -> "CampaignConfig":
|
||||||
|
"""Build a campaign config from an App 1 device profile YAML.
|
||||||
|
|
||||||
|
Expected format::
|
||||||
|
|
||||||
|
device:
|
||||||
|
name: "iPhone_13_WiFi"
|
||||||
|
type: "wifi"
|
||||||
|
protocol: "wifi_24ghz"
|
||||||
|
|
||||||
|
capture:
|
||||||
|
channels: [1, 6, 11] # WiFi only
|
||||||
|
bandwidth: "20MHz" # WiFi only
|
||||||
|
traffic_patterns: ["idle", "ping", "iperf_udp"]
|
||||||
|
duration_per_config: "30s"
|
||||||
|
|
||||||
|
recorder:
|
||||||
|
device: "usrp_b210"
|
||||||
|
center_freq: "2.45GHz"
|
||||||
|
sample_rate: "40MHz"
|
||||||
|
gain: "auto"
|
||||||
|
|
||||||
|
output:
|
||||||
|
path: "./recordings"
|
||||||
|
device_id: "iphone13_wifi_001"
|
||||||
|
|
||||||
|
For WiFi devices, schedule is expanded as channels × traffic_patterns.
|
||||||
|
For Bluetooth devices (no channels), schedule is traffic_patterns only.
|
||||||
|
"""
|
||||||
|
path = Path(path)
|
||||||
|
try:
|
||||||
|
with open(path) as f:
|
||||||
|
raw = yaml.safe_load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise FileNotFoundError(f"Device profile not found: {path}")
|
||||||
|
except yaml.YAMLError as e:
|
||||||
|
raise ValueError(f"Invalid YAML in {path}: {e}")
|
||||||
|
|
||||||
|
device = raw.get("device", {})
|
||||||
|
capture = raw.get("capture", {})
|
||||||
|
device_type = str(device.get("type", "wifi")).lower()
|
||||||
|
device_name = str(device.get("name", path.stem))
|
||||||
|
duration = parse_duration(capture.get("duration_per_config", "30s"))
|
||||||
|
traffic_patterns = capture.get("traffic_patterns", ["idle"])
|
||||||
|
|
||||||
|
# Build capture schedule
|
||||||
|
schedule: list[CaptureStep] = []
|
||||||
|
|
||||||
|
if device_type in ("wifi", "wifi_24ghz", "wifi_5ghz"):
|
||||||
|
channels = capture.get("channels", [6])
|
||||||
|
bw_str = capture.get("bandwidth", "20MHz")
|
||||||
|
bw_mhz = parse_bandwidth_mhz(bw_str)
|
||||||
|
for ch in channels:
|
||||||
|
for traffic in traffic_patterns:
|
||||||
|
label = f"ch{ch:02d}_{int(bw_mhz)}mhz_{traffic}"
|
||||||
|
schedule.append(
|
||||||
|
CaptureStep(
|
||||||
|
duration=duration,
|
||||||
|
label=label,
|
||||||
|
channel=ch,
|
||||||
|
bandwidth_mhz=bw_mhz,
|
||||||
|
traffic=traffic,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Bluetooth / generic — no channels
|
||||||
|
for traffic in traffic_patterns:
|
||||||
|
schedule.append(
|
||||||
|
CaptureStep(
|
||||||
|
duration=duration,
|
||||||
|
label=traffic,
|
||||||
|
traffic=traffic,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
device_id = raw.get("output", {}).get("device_id", device_name.lower().replace(" ", "_"))
|
||||||
|
transmitter = TransmitterConfig(
|
||||||
|
id=device_id,
|
||||||
|
type=device_type,
|
||||||
|
control_method=str(capture.get("control_method", "external_script")),
|
||||||
|
schedule=schedule,
|
||||||
|
script=capture.get("script"),
|
||||||
|
device=capture.get("device"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
name=f"enroll_{device_id}",
|
||||||
|
mode="controlled_testbed",
|
||||||
|
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
||||||
|
transmitters=[transmitter],
|
||||||
|
qa=QAConfig.from_dict(raw.get("qa", {})),
|
||||||
|
output=OutputConfig.from_dict(raw.get("output", {})),
|
||||||
|
)
|
||||||
|
|
||||||
|
def total_capture_time_s(self) -> float:
|
||||||
|
"""Sum of all step durations across all transmitters."""
|
||||||
|
return sum(step.duration for tx in self.transmitters for step in tx.schedule)
|
||||||
|
|
||||||
|
def total_steps(self) -> int:
|
||||||
|
"""Total number of capture steps across all transmitters."""
|
||||||
|
return sum(len(tx.schedule) for tx in self.transmitters)
|
||||||
423
src/ria_toolkit_oss/orchestration/executor.py
Normal file
423
src/ria_toolkit_oss/orchestration/executor.py
Normal file
|
|
@ -0,0 +1,423 @@
|
||||||
|
"""Campaign executor: runs a capture campaign end-to-end."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from ria_toolkit_oss.datatypes.recording import Recording
|
||||||
|
from ria_toolkit_oss.io.recording import to_sigmf
|
||||||
|
|
||||||
|
from .campaign import CampaignConfig, CaptureStep, TransmitterConfig
|
||||||
|
from .labeler import build_output_filename, label_recording
|
||||||
|
from .qa import QAResult, check_recording
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Device name aliases: campaign YAML names → get_sdr_device() names
|
||||||
|
_DEVICE_ALIASES = {
|
||||||
|
"usrp_b210": "usrp",
|
||||||
|
"usrp_b200": "usrp",
|
||||||
|
"usrp": "usrp",
|
||||||
|
"plutosdr": "pluto",
|
||||||
|
"pluto": "pluto",
|
||||||
|
"hackrf": "hackrf",
|
||||||
|
"hackrf_one": "hackrf",
|
||||||
|
"bladerf": "bladerf",
|
||||||
|
"rtlsdr": "rtlsdr",
|
||||||
|
"rtl_sdr": "rtlsdr",
|
||||||
|
"thinkrf": "thinkrf",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StepResult:
|
||||||
|
"""Outcome of a single capture step."""
|
||||||
|
|
||||||
|
transmitter_id: str
|
||||||
|
step_label: str
|
||||||
|
output_path: Optional[str]
|
||||||
|
qa: QAResult
|
||||||
|
capture_timestamp: float
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ok(self) -> bool:
|
||||||
|
return self.error is None and self.qa.passed
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"transmitter_id": self.transmitter_id,
|
||||||
|
"step_label": self.step_label,
|
||||||
|
"output_path": self.output_path,
|
||||||
|
"capture_timestamp": self.capture_timestamp,
|
||||||
|
"qa": self.qa.to_dict(),
|
||||||
|
"error": self.error,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CampaignResult:
|
||||||
|
"""Aggregate outcome of a full campaign."""
|
||||||
|
|
||||||
|
campaign_name: str
|
||||||
|
steps: list[StepResult] = field(default_factory=list)
|
||||||
|
start_time: float = field(default_factory=time.time)
|
||||||
|
end_time: Optional[float] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_steps(self) -> int:
|
||||||
|
return len(self.steps)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def passed(self) -> int:
|
||||||
|
return sum(1 for s in self.steps if s.ok)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def flagged(self) -> int:
|
||||||
|
return sum(1 for s in self.steps if not s.error and s.qa.flagged)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def failed(self) -> int:
|
||||||
|
return sum(1 for s in self.steps if s.error or not s.qa.passed)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def duration_s(self) -> float:
|
||||||
|
if self.end_time:
|
||||||
|
return self.end_time - self.start_time
|
||||||
|
return time.time() - self.start_time
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"campaign_name": self.campaign_name,
|
||||||
|
"total_steps": self.total_steps,
|
||||||
|
"passed": self.passed,
|
||||||
|
"flagged": self.flagged,
|
||||||
|
"failed": self.failed,
|
||||||
|
"duration_s": round(self.duration_s, 1),
|
||||||
|
"steps": [s.to_dict() for s in self.steps],
|
||||||
|
}
|
||||||
|
|
||||||
|
def write_report(self, path: str | Path) -> None:
|
||||||
|
"""Write a JSON QA report to disk."""
|
||||||
|
path = Path(path)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(self.to_dict(), f, indent=2)
|
||||||
|
logger.info(f"QA report written to {path}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# External script interface
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
|
||||||
|
"""Run an external control script and return stdout.
|
||||||
|
|
||||||
|
The script is called as::
|
||||||
|
|
||||||
|
<script> <arg1> <arg2> ...
|
||||||
|
|
||||||
|
A non-zero return code raises RuntimeError.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
script: Path to executable script.
|
||||||
|
*args: Positional arguments forwarded to the script.
|
||||||
|
timeout: Maximum seconds to wait.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Script stdout as a string.
|
||||||
|
"""
|
||||||
|
cmd = [script, *args]
|
||||||
|
logger.debug(f"Running script: {' '.join(cmd)}")
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
raise RuntimeError(f"Script timed out after {timeout}s: {script}")
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise RuntimeError(f"Script not found: {script}")
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(f"Script exited {result.returncode}: {result.stderr.strip() or result.stdout.strip()}")
|
||||||
|
return result.stdout.strip()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Campaign executor
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class CampaignExecutor:
|
||||||
|
"""Executes a :class:`CampaignConfig` end-to-end.
|
||||||
|
|
||||||
|
Initialises the SDR recorder once, then for each (transmitter, step):
|
||||||
|
1. Configures the transmitter (via external script or SDR TX)
|
||||||
|
2. Records IQ samples
|
||||||
|
3. Labels the recording with device/config metadata
|
||||||
|
4. Runs QA checks
|
||||||
|
5. Saves the recording to disk
|
||||||
|
6. Stops/resets the transmitter
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Parsed campaign configuration.
|
||||||
|
progress_cb: Optional callback ``(step_index, total_steps, step_result)``
|
||||||
|
called after each step completes. Useful for status reporting.
|
||||||
|
verbose: Enable debug logging.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: CampaignConfig,
|
||||||
|
progress_cb: Optional[Callable[[int, int, StepResult], None]] = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
):
|
||||||
|
self.config = config
|
||||||
|
self.progress_cb = progress_cb
|
||||||
|
self._sdr = None
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
else:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public interface
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def run(self) -> CampaignResult:
|
||||||
|
"""Execute the full campaign and return a :class:`CampaignResult`.
|
||||||
|
|
||||||
|
Initialises the SDR, runs all steps across all transmitters,
|
||||||
|
then closes the SDR. If SDR initialisation fails the exception
|
||||||
|
propagates immediately (nothing is captured).
|
||||||
|
"""
|
||||||
|
result = CampaignResult(campaign_name=self.config.name)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Starting campaign '{self.config.name}': "
|
||||||
|
f"{self.config.total_steps()} steps, "
|
||||||
|
f"~{self.config.total_capture_time_s():.0f}s capture time"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._init_sdr()
|
||||||
|
try:
|
||||||
|
total = self.config.total_steps()
|
||||||
|
step_index = 0
|
||||||
|
|
||||||
|
for transmitter in self.config.transmitters:
|
||||||
|
logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)")
|
||||||
|
for step in transmitter.schedule:
|
||||||
|
step_result = self._execute_step(transmitter, step)
|
||||||
|
result.steps.append(step_result)
|
||||||
|
step_index += 1
|
||||||
|
|
||||||
|
if self.progress_cb:
|
||||||
|
self.progress_cb(step_index, total, step_result)
|
||||||
|
|
||||||
|
if step_result.error:
|
||||||
|
logger.warning(f"Step '{step.label}' error: {step_result.error}")
|
||||||
|
elif step_result.qa.flagged:
|
||||||
|
logger.warning(f"Step '{step.label}' flagged for review: " + "; ".join(step_result.qa.issues))
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Step '{step.label}' OK "
|
||||||
|
f"(SNR {step_result.qa.snr_db:.1f} dB, "
|
||||||
|
f"{step_result.qa.duration_s:.1f}s)"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._close_sdr()
|
||||||
|
|
||||||
|
result.end_time = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"Campaign complete: {result.passed}/{result.total_steps} passed, "
|
||||||
|
f"{result.flagged} flagged, {result.failed} failed"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# SDR management
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _init_sdr(self) -> None:
|
||||||
|
"""Initialise and configure the SDR recorder."""
|
||||||
|
from ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
|
||||||
|
|
||||||
|
rec = self.config.recorder
|
||||||
|
device_name = _DEVICE_ALIASES.get(rec.device.lower(), rec.device.lower())
|
||||||
|
logger.info(f"Initialising SDR: {device_name} @ {rec.center_freq/1e6:.2f} MHz")
|
||||||
|
|
||||||
|
self._sdr = get_sdr_device(device_name)
|
||||||
|
gain = None if rec.gain == "auto" else float(rec.gain)
|
||||||
|
self._sdr.init_rx(
|
||||||
|
sample_rate=rec.sample_rate,
|
||||||
|
center_frequency=rec.center_freq,
|
||||||
|
gain=gain,
|
||||||
|
channel=0,
|
||||||
|
)
|
||||||
|
if rec.bandwidth and hasattr(self._sdr, "set_rx_bandwidth"):
|
||||||
|
self._sdr.set_rx_bandwidth(rec.bandwidth)
|
||||||
|
|
||||||
|
def _close_sdr(self) -> None:
|
||||||
|
if self._sdr is not None:
|
||||||
|
try:
|
||||||
|
self._sdr.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"SDR close error: {e}")
|
||||||
|
self._sdr = None
|
||||||
|
|
||||||
|
def _record(self, duration_s: float) -> Recording:
|
||||||
|
"""Capture ``duration_s`` seconds of IQ samples."""
|
||||||
|
num_samples = int(duration_s * self.config.recorder.sample_rate)
|
||||||
|
return self._sdr.record(num_samples=num_samples)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Step execution
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _execute_step(self, transmitter: TransmitterConfig, step: CaptureStep) -> StepResult:
|
||||||
|
"""Run a single capture step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StepResult with QA outcome and output path (or error string).
|
||||||
|
"""
|
||||||
|
capture_timestamp = time.time()
|
||||||
|
output_path: Optional[str] = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._start_transmitter(transmitter, step)
|
||||||
|
recording = self._record(step.duration)
|
||||||
|
self._stop_transmitter(transmitter, step)
|
||||||
|
except Exception as e:
|
||||||
|
# Best-effort stop on error
|
||||||
|
try:
|
||||||
|
self._stop_transmitter(transmitter, step)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return StepResult(
|
||||||
|
transmitter_id=transmitter.id,
|
||||||
|
step_label=step.label,
|
||||||
|
output_path=None,
|
||||||
|
qa=QAResult(passed=False, flagged=True, snr_db=0.0, duration_s=0.0, issues=[f"Capture error: {e}"]),
|
||||||
|
capture_timestamp=capture_timestamp,
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Label recording
|
||||||
|
recording = label_recording(
|
||||||
|
recording=recording,
|
||||||
|
device_id=transmitter.id,
|
||||||
|
step=step,
|
||||||
|
capture_timestamp=capture_timestamp,
|
||||||
|
campaign_name=self.config.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# QA
|
||||||
|
qa_result = check_recording(recording, self.config.qa)
|
||||||
|
|
||||||
|
# Save
|
||||||
|
try:
|
||||||
|
output_path = self._save(recording, transmitter.id, step)
|
||||||
|
except Exception as e:
|
||||||
|
return StepResult(
|
||||||
|
transmitter_id=transmitter.id,
|
||||||
|
step_label=step.label,
|
||||||
|
output_path=None,
|
||||||
|
qa=qa_result,
|
||||||
|
capture_timestamp=capture_timestamp,
|
||||||
|
error=f"Save failed: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return StepResult(
|
||||||
|
transmitter_id=transmitter.id,
|
||||||
|
step_label=step.label,
|
||||||
|
output_path=output_path,
|
||||||
|
qa=qa_result,
|
||||||
|
capture_timestamp=capture_timestamp,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Transmitter control (external script interface)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _start_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
|
||||||
|
"""Configure the transmitter for this step.
|
||||||
|
|
||||||
|
For ``external_script`` control method the script is called as::
|
||||||
|
|
||||||
|
<script> configure <step_params_json>
|
||||||
|
|
||||||
|
where ``step_params_json`` is a JSON object with channel, bandwidth,
|
||||||
|
traffic, etc. The script is responsible for applying the configuration
|
||||||
|
and returning promptly (i.e. not blocking for the capture duration).
|
||||||
|
|
||||||
|
For SDR transmitters this is a no-op placeholder (TX not yet implemented).
|
||||||
|
"""
|
||||||
|
if transmitter.control_method == "external_script":
|
||||||
|
if not transmitter.script:
|
||||||
|
logger.debug(f"No script configured for {transmitter.id}, skipping configure")
|
||||||
|
return
|
||||||
|
params = self._step_params_json(transmitter, step)
|
||||||
|
_run_script(transmitter.script, "configure", params)
|
||||||
|
|
||||||
|
elif transmitter.control_method == "sdr":
|
||||||
|
logger.debug("SDR TX not yet implemented — skipping start")
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
|
||||||
|
|
||||||
|
def _stop_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
|
||||||
|
"""Signal the transmitter to stop.
|
||||||
|
|
||||||
|
Calls ``<script> stop`` for external_script transmitters.
|
||||||
|
"""
|
||||||
|
if transmitter.control_method == "external_script":
|
||||||
|
if not transmitter.script:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
_run_script(transmitter.script, "stop")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Script stop failed for {transmitter.id}: {e}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
|
||||||
|
"""Serialise step parameters to a JSON string for the control script."""
|
||||||
|
params: dict = {"device": transmitter.device or ""}
|
||||||
|
if step.channel is not None:
|
||||||
|
params["channel"] = step.channel
|
||||||
|
if step.bandwidth_mhz is not None:
|
||||||
|
params["bandwidth_mhz"] = step.bandwidth_mhz
|
||||||
|
if step.traffic is not None:
|
||||||
|
params["traffic"] = step.traffic
|
||||||
|
if step.power_dbm is not None:
|
||||||
|
params["power_dbm"] = step.power_dbm
|
||||||
|
return json.dumps(params)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Output
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _save(self, recording: Recording, device_id: str, step: CaptureStep) -> str:
|
||||||
|
"""Save a recording to disk and return the data file path."""
|
||||||
|
out = self.config.output
|
||||||
|
rel_filename = build_output_filename(device_id, step)
|
||||||
|
out_dir = Path(out.path)
|
||||||
|
# build_output_filename returns "<device_id>/<label>"
|
||||||
|
# to_sigmf needs filename (base) and path (dir) separately
|
||||||
|
parts = Path(rel_filename)
|
||||||
|
subdir = out_dir / parts.parent
|
||||||
|
subdir.mkdir(parents=True, exist_ok=True)
|
||||||
|
base = parts.name
|
||||||
|
|
||||||
|
to_sigmf(recording, filename=base, path=str(subdir), overwrite=True)
|
||||||
|
return str(subdir / f"{base}.sigmf-data")
|
||||||
77
src/ria_toolkit_oss/orchestration/labeler.py
Normal file
77
src/ria_toolkit_oss/orchestration/labeler.py
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
"""Timestamp-based labeling for captured recordings."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ria_toolkit_oss.datatypes.recording import Recording
|
||||||
|
|
||||||
|
from .campaign import CaptureStep
|
||||||
|
|
||||||
|
|
||||||
|
def label_recording(
|
||||||
|
recording: Recording,
|
||||||
|
device_id: str,
|
||||||
|
step: CaptureStep,
|
||||||
|
capture_timestamp: float,
|
||||||
|
campaign_name: Optional[str] = None,
|
||||||
|
) -> Recording:
|
||||||
|
"""Apply device identity and capture configuration labels to a recording's metadata.
|
||||||
|
|
||||||
|
Labels are stored in the ``ria:*`` namespace when the recording is saved
|
||||||
|
as SigMF, via the existing ``update_metadata`` mechanism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
recording: The recording to label.
|
||||||
|
device_id: Identifier for the transmitting device (e.g. "iphone13_wifi_001").
|
||||||
|
step: The capture step that was active during this recording.
|
||||||
|
capture_timestamp: Unix timestamp (float) of when capture started.
|
||||||
|
campaign_name: Optional campaign name for cross-recording reference.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same recording with updated metadata.
|
||||||
|
"""
|
||||||
|
recording.update_metadata("device_id", device_id)
|
||||||
|
recording.update_metadata("capture_timestamp", capture_timestamp)
|
||||||
|
recording.update_metadata("step_label", step.label)
|
||||||
|
recording.update_metadata("step_duration_s", step.duration)
|
||||||
|
|
||||||
|
if campaign_name:
|
||||||
|
recording.update_metadata("campaign", campaign_name)
|
||||||
|
|
||||||
|
# WiFi-specific labels
|
||||||
|
if step.channel is not None:
|
||||||
|
recording.update_metadata("wifi_channel", step.channel)
|
||||||
|
if step.bandwidth_mhz is not None:
|
||||||
|
recording.update_metadata("wifi_bandwidth_mhz", step.bandwidth_mhz)
|
||||||
|
|
||||||
|
# Bluetooth-specific labels
|
||||||
|
if step.connection_interval_ms is not None:
|
||||||
|
recording.update_metadata("bt_connection_interval_ms", step.connection_interval_ms)
|
||||||
|
|
||||||
|
# Traffic pattern (WiFi + BT)
|
||||||
|
if step.traffic is not None:
|
||||||
|
recording.update_metadata("traffic_pattern", step.traffic)
|
||||||
|
|
||||||
|
# TX power
|
||||||
|
if step.power_dbm is not None:
|
||||||
|
recording.update_metadata("tx_power_dbm", step.power_dbm)
|
||||||
|
|
||||||
|
return recording
|
||||||
|
|
||||||
|
|
||||||
|
def build_output_filename(device_id: str, step: CaptureStep) -> str:
|
||||||
|
"""Generate a deterministic filename for a labeled recording.
|
||||||
|
|
||||||
|
Format: ``<device_id>/<step_label>``
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device_id: Device identifier string.
|
||||||
|
step: Capture step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Relative path string (no extension) to use as ``filename`` in ``to_sigmf()``.
|
||||||
|
"""
|
||||||
|
safe_id = device_id.replace("/", "_").replace(" ", "_")
|
||||||
|
safe_label = step.label.replace("/", "_").replace(" ", "_")
|
||||||
|
return f"{safe_id}/{safe_label}"
|
||||||
109
src/ria_toolkit_oss/orchestration/qa.py
Normal file
109
src/ria_toolkit_oss/orchestration/qa.py
Normal file
|
|
@ -0,0 +1,109 @@
|
||||||
|
"""QA metrics for captured RF recordings."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ria_toolkit_oss.datatypes.recording import Recording
|
||||||
|
|
||||||
|
from .campaign import QAConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QAResult:
|
||||||
|
"""Result of QA checks on a single recording."""
|
||||||
|
|
||||||
|
passed: bool
|
||||||
|
flagged: bool # True if any metric is below threshold (but not hard-failed)
|
||||||
|
snr_db: float
|
||||||
|
duration_s: float
|
||||||
|
issues: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"passed": self.passed,
|
||||||
|
"flagged": self.flagged,
|
||||||
|
"snr_db": round(self.snr_db, 2),
|
||||||
|
"duration_s": round(self.duration_s, 3),
|
||||||
|
"issues": self.issues,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_snr_db(samples: np.ndarray, signal_fraction: float = 0.7) -> float:
|
||||||
|
"""Estimate SNR from IQ samples using PSD-based signal/noise separation.
|
||||||
|
|
||||||
|
Computes an FFT of the samples and assumes the top ``signal_fraction``
|
||||||
|
of power bins are signal and the remainder are noise. This is a
|
||||||
|
heuristic appropriate for a controlled testbed where a single dominant
|
||||||
|
signal is expected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
samples: 1-D complex array of IQ samples.
|
||||||
|
signal_fraction: Fraction of PSD bins to treat as signal (0–1).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated SNR in dB, or 0.0 if the noise floor is zero.
|
||||||
|
"""
|
||||||
|
n_fft = min(4096, len(samples))
|
||||||
|
window = np.hanning(n_fft)
|
||||||
|
psd = np.abs(np.fft.fft(samples[:n_fft] * window)) ** 2
|
||||||
|
|
||||||
|
psd_sorted = np.sort(psd)[::-1]
|
||||||
|
n_signal = max(1, int(n_fft * signal_fraction))
|
||||||
|
signal_power = psd_sorted[:n_signal].mean()
|
||||||
|
noise_power = psd_sorted[n_signal:].mean()
|
||||||
|
|
||||||
|
if noise_power <= 0.0:
|
||||||
|
return 0.0
|
||||||
|
return float(10.0 * np.log10(signal_power / noise_power))
|
||||||
|
|
||||||
|
|
||||||
|
def check_recording(recording: Recording, config: QAConfig) -> QAResult:
|
||||||
|
"""Run QA checks on a recording against the campaign QA config.
|
||||||
|
|
||||||
|
Checks performed:
|
||||||
|
- Duration: number of samples / sample_rate >= min_duration_s
|
||||||
|
- SNR: estimated SNR >= snr_threshold_db
|
||||||
|
|
||||||
|
Args:
|
||||||
|
recording: Recording to evaluate.
|
||||||
|
config: QA thresholds from the campaign config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QAResult with pass/flag status and per-metric details.
|
||||||
|
"""
|
||||||
|
issues: list[str] = []
|
||||||
|
flagged = False
|
||||||
|
|
||||||
|
# --- Duration check ---
|
||||||
|
sample_rate = recording.metadata.get("sample_rate", 1.0)
|
||||||
|
n_samples = recording.data.shape[-1]
|
||||||
|
duration_s = n_samples / sample_rate if sample_rate else 0.0
|
||||||
|
|
||||||
|
if duration_s < config.min_duration_s:
|
||||||
|
issues.append(f"Duration too short: {duration_s:.1f}s < {config.min_duration_s:.1f}s threshold")
|
||||||
|
flagged = True
|
||||||
|
|
||||||
|
# --- SNR check ---
|
||||||
|
samples = recording.data[0] if recording.data.ndim > 1 else recording.data
|
||||||
|
snr_db = estimate_snr_db(samples)
|
||||||
|
|
||||||
|
if snr_db < config.snr_threshold_db:
|
||||||
|
issues.append(f"SNR below threshold: {snr_db:.1f} dB < {config.snr_threshold_db:.1f} dB")
|
||||||
|
flagged = True
|
||||||
|
|
||||||
|
# In flag_for_review mode: flag but don't hard-fail
|
||||||
|
if config.flag_for_review:
|
||||||
|
passed = True # always accept; human reviews flagged recordings
|
||||||
|
else:
|
||||||
|
passed = not flagged
|
||||||
|
|
||||||
|
return QAResult(
|
||||||
|
passed=passed,
|
||||||
|
flagged=flagged,
|
||||||
|
snr_db=snr_db,
|
||||||
|
duration_s=duration_s,
|
||||||
|
issues=issues,
|
||||||
|
)
|
||||||
|
|
@ -474,8 +474,10 @@ class Blade(SDR):
|
||||||
|
|
||||||
if gain_mode == "relative":
|
if gain_mode == "relative":
|
||||||
if gain > 0:
|
if gain > 0:
|
||||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets \
|
raise SDRParameterError(
|
||||||
the gain relative to the maximum possible gain.")
|
"When gain_mode = 'relative', gain must be < 0. This sets \
|
||||||
|
the gain relative to the maximum possible gain."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
abs_gain = rx_gain_max + gain
|
abs_gain = rx_gain_max + gain
|
||||||
else:
|
else:
|
||||||
|
|
@ -548,8 +550,10 @@ class Blade(SDR):
|
||||||
|
|
||||||
if gain_mode == "relative":
|
if gain_mode == "relative":
|
||||||
if gain > 0:
|
if gain > 0:
|
||||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
|
raise SDRParameterError(
|
||||||
the gain relative to the maximum possible gain.")
|
"When gain_mode = 'relative', gain must be < 0. This sets\
|
||||||
|
the gain relative to the maximum possible gain."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
abs_gain = tx_gain_max + gain
|
abs_gain = tx_gain_max + gain
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -172,8 +172,10 @@ class HackRF(SDR):
|
||||||
tx_gain_max = 47
|
tx_gain_max = 47
|
||||||
if gain_mode == "relative":
|
if gain_mode == "relative":
|
||||||
if gain > 0:
|
if gain > 0:
|
||||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This \
|
raise SDRParameterError(
|
||||||
sets the gain relative to the maximum possible gain.")
|
"When gain_mode = 'relative', gain must be < 0. This \
|
||||||
|
sets the gain relative to the maximum possible gain."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
abs_gain = tx_gain_max + gain
|
abs_gain = tx_gain_max + gain
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -274,16 +274,20 @@ class Pluto(SDR):
|
||||||
data = [self._convert_tx_samples(samples), self._convert_tx_samples(samples)]
|
data = [self._convert_tx_samples(samples), self._convert_tx_samples(samples)]
|
||||||
else:
|
else:
|
||||||
if len(recording) > 2:
|
if len(recording) > 2:
|
||||||
warnings.warn("More recordings were provided than channels in the Pluto. \
|
warnings.warn(
|
||||||
Only the first two recordings will be used")
|
"More recordings were provided than channels in the Pluto. \
|
||||||
|
Only the first two recordings will be used"
|
||||||
|
)
|
||||||
sample0 = self._convert_tx_samples(recording.data[0])
|
sample0 = self._convert_tx_samples(recording.data[0])
|
||||||
sample1 = self._convert_tx_samples(recording.data[1])
|
sample1 = self._convert_tx_samples(recording.data[1])
|
||||||
data = [sample0, sample1]
|
data = [sample0, sample1]
|
||||||
|
|
||||||
elif isinstance(recording, list):
|
elif isinstance(recording, list):
|
||||||
if len(recording) > 2:
|
if len(recording) > 2:
|
||||||
warnings.warn("More recordings were provided than channels in the Pluto. \
|
warnings.warn(
|
||||||
Only the first two recordings will be used")
|
"More recordings were provided than channels in the Pluto. \
|
||||||
|
Only the first two recordings will be used"
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(recording[0], np.ndarray):
|
if isinstance(recording[0], np.ndarray):
|
||||||
data = [self._convert_tx_samples(recording[0]), self._convert_tx_samples(recording[1])]
|
data = [self._convert_tx_samples(recording[0]), self._convert_tx_samples(recording[1])]
|
||||||
|
|
@ -423,8 +427,10 @@ class Pluto(SDR):
|
||||||
|
|
||||||
if gain_mode == "relative":
|
if gain_mode == "relative":
|
||||||
if gain > 0:
|
if gain > 0:
|
||||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets \
|
raise SDRParameterError(
|
||||||
the gain relative to the maximum possible gain.")
|
"When gain_mode = 'relative', gain must be < 0. This sets \
|
||||||
|
the gain relative to the maximum possible gain."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
abs_gain = rx_gain_max + gain
|
abs_gain = rx_gain_max + gain
|
||||||
else:
|
else:
|
||||||
|
|
@ -534,8 +540,10 @@ class Pluto(SDR):
|
||||||
|
|
||||||
if gain_mode == "relative":
|
if gain_mode == "relative":
|
||||||
if gain > 0:
|
if gain > 0:
|
||||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
|
raise SDRParameterError(
|
||||||
the gain relative to the maximum possible gain.")
|
"When gain_mode = 'relative', gain must be < 0. This sets\
|
||||||
|
the gain relative to the maximum possible gain."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
abs_gain = tx_gain_max + gain
|
abs_gain = tx_gain_max + gain
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -131,15 +131,19 @@ class RTLSDR(SDR):
|
||||||
|
|
||||||
if gain_mode == "relative":
|
if gain_mode == "relative":
|
||||||
if gain > 0:
|
if gain > 0:
|
||||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
|
raise SDRParameterError(
|
||||||
the gain relative to the maximum possible gain.")
|
"When gain_mode = 'relative', gain must be < 0. This sets\
|
||||||
|
the gain relative to the maximum possible gain."
|
||||||
|
)
|
||||||
target_gain = max_gain + gain
|
target_gain = max_gain + gain
|
||||||
else:
|
else:
|
||||||
target_gain = gain
|
target_gain = gain
|
||||||
|
|
||||||
if target_gain < min_gain or target_gain > max_gain:
|
if target_gain < min_gain or target_gain > max_gain:
|
||||||
print(f"Requested gain {target_gain} dB out of range;\
|
print(
|
||||||
clamping to valid span {min_gain}-{max_gain} dB.")
|
f"Requested gain {target_gain} dB out of range;\
|
||||||
|
clamping to valid span {min_gain}-{max_gain} dB."
|
||||||
|
)
|
||||||
target_gain = min(max(target_gain, min_gain), max_gain)
|
target_gain = min(max(target_gain, min_gain), max_gain)
|
||||||
|
|
||||||
target_gain = min(available_gains, key=lambda g: abs(g - target_gain))
|
target_gain = min(available_gains, key=lambda g: abs(g - target_gain))
|
||||||
|
|
|
||||||
|
|
@ -392,8 +392,10 @@ class ThinkRF(SDR):
|
||||||
actual_sample_rate = self.BASE_SAMPLE_RATE / decimation
|
actual_sample_rate = self.BASE_SAMPLE_RATE / decimation
|
||||||
|
|
||||||
if abs(actual_sample_rate - requested_sample_rate) > 1e3: # More than 1 kHz difference
|
if abs(actual_sample_rate - requested_sample_rate) > 1e3: # More than 1 kHz difference
|
||||||
print(f"ThinkRF: Requested {requested_sample_rate/1e6:.2f} MS/s → \
|
print(
|
||||||
Using decimation={decimation} ({actual_sample_rate/1e6:.2f} MS/s)")
|
f"ThinkRF: Requested {requested_sample_rate/1e6:.2f} MS/s → \
|
||||||
|
Using decimation={decimation} ({actual_sample_rate/1e6:.2f} MS/s)"
|
||||||
|
)
|
||||||
|
|
||||||
return decimation, actual_sample_rate
|
return decimation, actual_sample_rate
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -148,8 +148,10 @@ class USRP(SDR):
|
||||||
gain_range = self.usrp.get_rx_gain_range()
|
gain_range = self.usrp.get_rx_gain_range()
|
||||||
if gain_mode == "relative":
|
if gain_mode == "relative":
|
||||||
if gain > 0:
|
if gain > 0:
|
||||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
|
raise SDRParameterError(
|
||||||
the gain relative to the maximum possible gain.")
|
"When gain_mode = 'relative', gain must be < 0. This sets\
|
||||||
|
the gain relative to the maximum possible gain."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# set gain relative to max
|
# set gain relative to max
|
||||||
abs_gain = gain_range.stop() + gain
|
abs_gain = gain_range.stop() + gain
|
||||||
|
|
@ -354,8 +356,10 @@ class USRP(SDR):
|
||||||
gain_range = self.usrp.get_tx_gain_range()
|
gain_range = self.usrp.get_tx_gain_range()
|
||||||
if gain_mode == "relative":
|
if gain_mode == "relative":
|
||||||
if gain > 0:
|
if gain > 0:
|
||||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
|
raise SDRParameterError(
|
||||||
the gain relative to the maximum possible gain.")
|
"When gain_mode = 'relative', gain must be < 0. This sets\
|
||||||
|
the gain relative to the maximum possible gain."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# set gain relative to max
|
# set gain relative to max
|
||||||
abs_gain = gain_range.stop() + gain
|
abs_gain = gain_range.stop() + gain
|
||||||
|
|
|
||||||
5
src/ria_toolkit_oss/server/__init__.py
Normal file
5
src/ria_toolkit_oss/server/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""RT-OSS HTTP server for RIA Hub integration."""
|
||||||
|
|
||||||
|
from .app import create_app
|
||||||
|
|
||||||
|
__all__ = ["create_app"]
|
||||||
48
src/ria_toolkit_oss/server/app.py
Normal file
48
src/ria_toolkit_oss/server/app.py
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
"""FastAPI application factory for the RT-OSS HTTP server."""
|
||||||
|
|
||||||
|
from fastapi import Depends, FastAPI
|
||||||
|
|
||||||
|
from .auth import require_api_key
|
||||||
|
from .routers import inference, orchestrator
|
||||||
|
|
||||||
|
|
||||||
|
def create_app(api_key: str = "") -> FastAPI:
|
||||||
|
"""Create and configure the RT-OSS FastAPI application.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: Secret key required in the ``X-API-Key`` request header.
|
||||||
|
Pass an empty string to disable authentication (development only).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured FastAPI application instance.
|
||||||
|
"""
|
||||||
|
app = FastAPI(
|
||||||
|
title="RIA Toolkit OSS Server",
|
||||||
|
version="0.1.0",
|
||||||
|
description=(
|
||||||
|
"HTTP API for RT-OSS campaign orchestration and RF zone inference. "
|
||||||
|
"All endpoints (except /health) require the X-API-Key header when "
|
||||||
|
"an API key is configured."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
app.state.api_key = api_key
|
||||||
|
|
||||||
|
app.include_router(
|
||||||
|
orchestrator.router,
|
||||||
|
prefix="/orchestrator",
|
||||||
|
tags=["Orchestrator"],
|
||||||
|
dependencies=[Depends(require_api_key)],
|
||||||
|
)
|
||||||
|
app.include_router(
|
||||||
|
inference.router,
|
||||||
|
prefix="/inference",
|
||||||
|
tags=["Inference"],
|
||||||
|
dependencies=[Depends(require_api_key)],
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/health", tags=["Health"])
|
||||||
|
async def health():
|
||||||
|
"""Health check — always returns 200."""
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
return app
|
||||||
25
src/ria_toolkit_oss/server/auth.py
Normal file
25
src/ria_toolkit_oss/server/auth.py
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
"""API key authentication dependency."""
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Request, status
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
|
||||||
|
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def require_api_key(
|
||||||
|
request: Request,
|
||||||
|
api_key: str | None = Depends(_api_key_header),
|
||||||
|
) -> None:
|
||||||
|
"""FastAPI dependency that enforces X-API-Key header authentication.
|
||||||
|
|
||||||
|
If no API key is configured on the server (empty string), all requests
|
||||||
|
are allowed — this is intended for local development only.
|
||||||
|
"""
|
||||||
|
expected: str = request.app.state.api_key
|
||||||
|
if not expected:
|
||||||
|
return # dev mode: no key set, allow all
|
||||||
|
if api_key != expected:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Invalid or missing API key",
|
||||||
|
)
|
||||||
77
src/ria_toolkit_oss/server/models.py
Normal file
77
src/ria_toolkit_oss/server/models.py
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
"""Pydantic request and response models for the RT-OSS HTTP server."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Orchestrator
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class DeployRequest(BaseModel):
|
||||||
|
config: dict
|
||||||
|
|
||||||
|
|
||||||
|
class DeployResponse(BaseModel):
|
||||||
|
campaign_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class CampaignStatusResponse(BaseModel):
|
||||||
|
campaign_id: str
|
||||||
|
status: str
|
||||||
|
config_name: str
|
||||||
|
progress: int
|
||||||
|
total_steps: int
|
||||||
|
started_at: float
|
||||||
|
ended_at: float | None = None
|
||||||
|
result: dict | None = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CancelResponse(BaseModel):
|
||||||
|
campaign_id: str
|
||||||
|
cancelled: bool
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Inference
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class SdrConfig(BaseModel):
|
||||||
|
device: str
|
||||||
|
center_freq: float
|
||||||
|
sample_rate: float
|
||||||
|
gain: float | str = "auto"
|
||||||
|
|
||||||
|
|
||||||
|
class LoadModelRequest(BaseModel):
|
||||||
|
model_path: str
|
||||||
|
label_map: dict[str, int] # class_name -> class_index
|
||||||
|
|
||||||
|
|
||||||
|
class LoadModelResponse(BaseModel):
|
||||||
|
loaded: bool
|
||||||
|
model_path: str
|
||||||
|
num_classes: int
|
||||||
|
|
||||||
|
|
||||||
|
class StartInferenceRequest(BaseModel):
|
||||||
|
sdr_config: SdrConfig
|
||||||
|
|
||||||
|
|
||||||
|
class StartInferenceResponse(BaseModel):
|
||||||
|
running: bool
|
||||||
|
|
||||||
|
|
||||||
|
class StopInferenceResponse(BaseModel):
|
||||||
|
stopped: bool
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceStatusResponse(BaseModel):
|
||||||
|
timestamp: float
|
||||||
|
prediction: str
|
||||||
|
confidence: float
|
||||||
|
snr_db: float
|
||||||
|
zone: str | None = None
|
||||||
0
src/ria_toolkit_oss/server/routers/__init__.py
Normal file
0
src/ria_toolkit_oss/server/routers/__init__.py
Normal file
183
src/ria_toolkit_oss/server/routers/inference.py
Normal file
183
src/ria_toolkit_oss/server/routers/inference.py
Normal file
|
|
@ -0,0 +1,183 @@
|
||||||
|
"""Inference routes: model loading, inference loop control, and status polling."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
from scipy.special import softmax
|
||||||
|
|
||||||
|
from ..models import (
|
||||||
|
InferenceStatusResponse,
|
||||||
|
LoadModelRequest,
|
||||||
|
LoadModelResponse,
|
||||||
|
StartInferenceRequest,
|
||||||
|
StartInferenceResponse,
|
||||||
|
StopInferenceResponse,
|
||||||
|
)
|
||||||
|
from ..state import InferenceState, get_inference, set_inference
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_INFERENCE_NUM_SAMPLES = 4096
|
||||||
|
|
||||||
|
|
||||||
|
def _load_onnx_session(model_path: str):
|
||||||
|
try:
|
||||||
|
import onnxruntime as ort
|
||||||
|
except ImportError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="onnxruntime not installed. Install with: pip install ria-toolkit-oss[server]",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
return ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"Failed to load ONNX model: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _preprocess_samples(samples: np.ndarray, expected_shape: tuple) -> np.ndarray:
|
||||||
|
"""Reshape complex IQ samples to float32 matching the model's expected input.
|
||||||
|
|
||||||
|
Supports ``(batch, 2*N)`` interleaved and ``(batch, 2, N)`` two-channel conventions.
|
||||||
|
"""
|
||||||
|
iq = samples.astype(np.complex64)
|
||||||
|
i_ch, q_ch = iq.real, iq.imag
|
||||||
|
|
||||||
|
if len(expected_shape) == 2:
|
||||||
|
n = expected_shape[1] // 2
|
||||||
|
interleaved = np.empty(expected_shape[1], dtype=np.float32)
|
||||||
|
interleaved[0::2] = i_ch[:n]
|
||||||
|
interleaved[1::2] = q_ch[:n]
|
||||||
|
return interleaved.reshape(1, -1)
|
||||||
|
elif len(expected_shape) == 3:
|
||||||
|
n = expected_shape[2]
|
||||||
|
return np.stack([i_ch[:n], q_ch[:n]], axis=0).astype(np.float32).reshape(1, 2, n)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported model input shape: {expected_shape}")
|
||||||
|
|
||||||
|
|
||||||
|
def _stop_current_inference(state: InferenceState, timeout: float = 5.0) -> None:
|
||||||
|
state.stop_event.set()
|
||||||
|
if state.thread and state.thread.is_alive():
|
||||||
|
state.thread.join(timeout=timeout)
|
||||||
|
if state.thread.is_alive():
|
||||||
|
logger.warning("Inference thread did not stop within %.1fs; SDR resources may not be released", timeout)
|
||||||
|
|
||||||
|
|
||||||
|
def _inference_loop(state: InferenceState, sdr) -> None:
|
||||||
|
from ria_toolkit_oss.orchestration.qa import estimate_snr_db
|
||||||
|
|
||||||
|
session = state.session
|
||||||
|
input_name = session.get_inputs()[0].name
|
||||||
|
expected_shape = tuple(
|
||||||
|
d if isinstance(d, int) and d > 0 else _INFERENCE_NUM_SAMPLES for d in session.get_inputs()[0].shape
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while not state.stop_event.is_set():
|
||||||
|
recording = sdr.record(num_samples=_INFERENCE_NUM_SAMPLES)
|
||||||
|
samples = recording.data[0] if recording.data.ndim > 1 else recording.data
|
||||||
|
snr_db = estimate_snr_db(samples)
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_input = _preprocess_samples(samples, expected_shape)
|
||||||
|
logits = session.run(None, {input_name: model_input})[0][0].astype(np.float32)
|
||||||
|
probs = softmax(logits)
|
||||||
|
pred_idx = int(np.argmax(probs))
|
||||||
|
prediction = state.index_to_label.get(pred_idx, str(pred_idx))
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
state.set_latest(
|
||||||
|
{
|
||||||
|
"timestamp": time.time(),
|
||||||
|
"prediction": prediction,
|
||||||
|
"confidence": round(float(probs[pred_idx]), 4),
|
||||||
|
"snr_db": round(snr_db, 2),
|
||||||
|
"zone": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
sdr.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
state.running = False
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/load", response_model=LoadModelResponse)
|
||||||
|
async def load_model(request: LoadModelRequest):
|
||||||
|
"""Load an ONNX model. Stops any running inference first.
|
||||||
|
``label_map`` maps class names to integer indices (e.g. ``{"zone_a": 0}``).
|
||||||
|
"""
|
||||||
|
existing = get_inference()
|
||||||
|
if existing and existing.running:
|
||||||
|
_stop_current_inference(existing)
|
||||||
|
|
||||||
|
session = _load_onnx_session(request.model_path)
|
||||||
|
set_inference(
|
||||||
|
InferenceState(
|
||||||
|
model_path=request.model_path,
|
||||||
|
label_map=request.label_map,
|
||||||
|
index_to_label={v: k for k, v in request.label_map.items()},
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return LoadModelResponse(loaded=True, model_path=request.model_path, num_classes=len(request.label_map))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/start", response_model=StartInferenceResponse)
|
||||||
|
async def start_inference(request: StartInferenceRequest):
|
||||||
|
"""Start continuous inference. Requires a model to be loaded first."""
|
||||||
|
state = get_inference()
|
||||||
|
if not state:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT, detail="No model loaded. Call POST /inference/load first."
|
||||||
|
)
|
||||||
|
if state.running:
|
||||||
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Inference is already running.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ria_toolkit_oss.orchestration.executor import _DEVICE_ALIASES
|
||||||
|
from ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
|
||||||
|
except ImportError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"SDR import failed: {e}")
|
||||||
|
|
||||||
|
sdr_cfg = request.sdr_config
|
||||||
|
try:
|
||||||
|
sdr = get_sdr_device(_DEVICE_ALIASES.get(sdr_cfg.device.lower(), sdr_cfg.device.lower()))
|
||||||
|
gain = None if sdr_cfg.gain == "auto" else float(sdr_cfg.gain)
|
||||||
|
sdr.init_rx(sample_rate=sdr_cfg.sample_rate, center_frequency=sdr_cfg.center_freq, gain=gain, channel=0)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"SDR initialisation failed: {e}")
|
||||||
|
|
||||||
|
state.stop_event.clear()
|
||||||
|
state.running = True
|
||||||
|
state.thread = threading.Thread(target=_inference_loop, args=(state, sdr), daemon=True)
|
||||||
|
state.thread.start()
|
||||||
|
return StartInferenceResponse(running=True)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/stop", response_model=StopInferenceResponse)
|
||||||
|
async def stop_inference():
|
||||||
|
"""Stop the running inference loop."""
|
||||||
|
state = get_inference()
|
||||||
|
if not state or not state.running:
|
||||||
|
return StopInferenceResponse(stopped=False)
|
||||||
|
_stop_current_inference(state)
|
||||||
|
return StopInferenceResponse(stopped=True)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/status", response_model=InferenceStatusResponse | None)
|
||||||
|
async def inference_status():
|
||||||
|
"""Return the latest inference result, or null if none available yet."""
|
||||||
|
state = get_inference()
|
||||||
|
if not state:
|
||||||
|
return None
|
||||||
|
latest = state.get_latest()
|
||||||
|
return InferenceStatusResponse(**latest) if latest else None
|
||||||
102
src/ria_toolkit_oss/server/routers/orchestrator.py
Normal file
102
src/ria_toolkit_oss/server/routers/orchestrator.py
Normal file
|
|
@ -0,0 +1,102 @@
|
||||||
|
"""Orchestrator routes: campaign deployment, status, and cancellation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
|
||||||
|
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
|
||||||
|
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
|
||||||
|
|
||||||
|
from ..models import (
|
||||||
|
CampaignStatusResponse,
|
||||||
|
CancelResponse,
|
||||||
|
DeployRequest,
|
||||||
|
DeployResponse,
|
||||||
|
)
|
||||||
|
from ..state import (
|
||||||
|
CampaignCancelledError,
|
||||||
|
CampaignState,
|
||||||
|
get_campaign,
|
||||||
|
set_campaign,
|
||||||
|
update_campaign,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_progress_cb(campaign_id: str, cancel_event: threading.Event):
|
||||||
|
def cb(step_index: int, total_steps: int, step_result: Any) -> None:
|
||||||
|
update_campaign(campaign_id, progress=step_index)
|
||||||
|
if cancel_event.is_set():
|
||||||
|
raise CampaignCancelledError(f"Cancelled at step {step_index}/{total_steps}")
|
||||||
|
|
||||||
|
return cb
|
||||||
|
|
||||||
|
|
||||||
|
def _run_campaign_thread(campaign_id: str, cfg: CampaignConfig) -> None:
|
||||||
|
state = get_campaign(campaign_id)
|
||||||
|
try:
|
||||||
|
result = CampaignExecutor(
|
||||||
|
config=cfg,
|
||||||
|
progress_cb=_make_progress_cb(campaign_id, state.cancel_event),
|
||||||
|
).run()
|
||||||
|
update_campaign(
|
||||||
|
campaign_id, status="completed", progress=cfg.total_steps(), result=result.to_dict(), ended_at=time.time()
|
||||||
|
)
|
||||||
|
except CampaignCancelledError:
|
||||||
|
update_campaign(campaign_id, status="cancelled", ended_at=time.time())
|
||||||
|
except Exception as e:
|
||||||
|
update_campaign(campaign_id, status="failed", error=str(e), ended_at=time.time())
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/deploy", response_model=DeployResponse)
|
||||||
|
async def deploy(request: DeployRequest):
|
||||||
|
"""Deploy a campaign config and start execution. Returns a ``campaign_id`` for polling.
|
||||||
|
Cancellation takes effect at step boundaries, not mid-capture.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_dict(request.config)
|
||||||
|
except (ValueError, KeyError) as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))
|
||||||
|
|
||||||
|
campaign_id = str(uuid.uuid4())
|
||||||
|
cancel_event = threading.Event()
|
||||||
|
thread = threading.Thread(target=_run_campaign_thread, args=(campaign_id, cfg), daemon=True)
|
||||||
|
set_campaign(
|
||||||
|
CampaignState(
|
||||||
|
campaign_id=campaign_id,
|
||||||
|
status="running",
|
||||||
|
config_name=cfg.name,
|
||||||
|
cancel_event=cancel_event,
|
||||||
|
thread=thread,
|
||||||
|
total_steps=cfg.total_steps(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
thread.start()
|
||||||
|
return DeployResponse(campaign_id=campaign_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/status/{campaign_id}", response_model=CampaignStatusResponse)
|
||||||
|
async def get_status(campaign_id: str):
|
||||||
|
"""Get the status and progress of a deployed campaign."""
|
||||||
|
state = get_campaign(campaign_id)
|
||||||
|
if not state:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Campaign {campaign_id!r} not found")
|
||||||
|
return CampaignStatusResponse(**state.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/cancel/{campaign_id}", response_model=CancelResponse)
|
||||||
|
async def cancel(campaign_id: str):
|
||||||
|
"""Request cancellation. Takes effect at the next step boundary."""
|
||||||
|
state = get_campaign(campaign_id)
|
||||||
|
if not state:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Campaign {campaign_id!r} not found")
|
||||||
|
if state.status != "running":
|
||||||
|
return CancelResponse(campaign_id=campaign_id, cancelled=False)
|
||||||
|
state.cancel_event.set()
|
||||||
|
return CancelResponse(campaign_id=campaign_id, cancelled=True)
|
||||||
101
src/ria_toolkit_oss/server/state.py
Normal file
101
src/ria_toolkit_oss/server/state.py
Normal file
|
|
@ -0,0 +1,101 @@
|
||||||
|
"""In-memory state for running campaigns and inference sessions."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class CampaignCancelledError(Exception):
|
||||||
|
"""Raised by the progress callback when a cancel is requested."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CampaignState:
|
||||||
|
campaign_id: str
|
||||||
|
status: str # "running" | "completed" | "failed" | "cancelled"
|
||||||
|
config_name: str
|
||||||
|
cancel_event: threading.Event
|
||||||
|
thread: threading.Thread
|
||||||
|
total_steps: int = 0
|
||||||
|
progress: int = 0
|
||||||
|
result: Optional[dict] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
started_at: float = field(default_factory=time.time)
|
||||||
|
ended_at: Optional[float] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"campaign_id": self.campaign_id,
|
||||||
|
"status": self.status,
|
||||||
|
"config_name": self.config_name,
|
||||||
|
"progress": self.progress,
|
||||||
|
"total_steps": self.total_steps,
|
||||||
|
"result": self.result,
|
||||||
|
"error": self.error,
|
||||||
|
"started_at": self.started_at,
|
||||||
|
"ended_at": self.ended_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InferenceState:
|
||||||
|
model_path: str
|
||||||
|
label_map: dict[str, int] # class_name -> class_index
|
||||||
|
index_to_label: dict[int, str] # reverse: class_index -> class_name
|
||||||
|
session: Any # onnxruntime.InferenceSession
|
||||||
|
stop_event: threading.Event = field(default_factory=threading.Event)
|
||||||
|
thread: Optional[threading.Thread] = None
|
||||||
|
running: bool = False
|
||||||
|
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
|
||||||
|
_latest: Optional[dict] = field(default=None, repr=False)
|
||||||
|
|
||||||
|
def set_latest(self, result: dict) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._latest = result
|
||||||
|
|
||||||
|
def get_latest(self) -> Optional[dict]:
|
||||||
|
with self._lock:
|
||||||
|
return self._latest
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Module-level stores
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_campaigns: dict[str, CampaignState] = {}
|
||||||
|
_campaigns_lock = threading.Lock()
|
||||||
|
|
||||||
|
_inference: Optional[InferenceState] = None
|
||||||
|
_inference_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def get_campaign(campaign_id: str) -> Optional[CampaignState]:
|
||||||
|
with _campaigns_lock:
|
||||||
|
return _campaigns.get(campaign_id)
|
||||||
|
|
||||||
|
|
||||||
|
def set_campaign(state: CampaignState) -> None:
|
||||||
|
with _campaigns_lock:
|
||||||
|
_campaigns[state.campaign_id] = state
|
||||||
|
|
||||||
|
|
||||||
|
def update_campaign(campaign_id: str, **kwargs) -> None:
|
||||||
|
with _campaigns_lock:
|
||||||
|
state = _campaigns.get(campaign_id)
|
||||||
|
if state:
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
setattr(state, k, v)
|
||||||
|
|
||||||
|
|
||||||
|
def get_inference() -> Optional[InferenceState]:
|
||||||
|
with _inference_lock:
|
||||||
|
return _inference
|
||||||
|
|
||||||
|
|
||||||
|
def set_inference(state: Optional[InferenceState]) -> None:
|
||||||
|
global _inference
|
||||||
|
with _inference_lock:
|
||||||
|
_inference = state
|
||||||
|
|
@ -37,8 +37,10 @@ class Add(RecordableBlock, ProcessBlock):
|
||||||
|
|
||||||
samples = block.get_samples(num_samples)
|
samples = block.get_samples(num_samples)
|
||||||
if len(samples) != num_samples:
|
if len(samples) != num_samples:
|
||||||
raise ValueError(f"Block {self.__class__.__name__} requested {num_samples} \
|
raise ValueError(
|
||||||
from block {block.__class__.__name__} but got {len(samples)}.")
|
f"Block {self.__class__.__name__} requested {num_samples} \
|
||||||
|
from block {block.__class__.__name__} but got {len(samples)}."
|
||||||
|
)
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,9 +23,11 @@ class ProcessBlock(Block, ABC):
|
||||||
)
|
)
|
||||||
|
|
||||||
elif not all(isinstance(item, Block) for item in input):
|
elif not all(isinstance(item, Block) for item in input):
|
||||||
raise ValueError(f"Invalid input to block '{self.__class__.__name__}'. \
|
raise ValueError(
|
||||||
|
f"Invalid input to block '{self.__class__.__name__}'. \
|
||||||
Expected a list of Block objects but got \
|
Expected a list of Block objects but got \
|
||||||
{'[' + ',' .join(f'{item.__class__.__name__}({repr(item)})' for item in input) + ']'}")
|
{'[' + ',' .join(f'{item.__class__.__name__}({repr(item)})' for item in input) + ']'}"
|
||||||
|
)
|
||||||
|
|
||||||
elif len(input) != len(self.input_type):
|
elif len(input) != len(self.input_type):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
||||||
|
|
@ -20,8 +20,10 @@ class RecordableBlock(Block, Recordable):
|
||||||
:raises ValueError: If the number of samples is incorrect."""
|
:raises ValueError: If the number of samples is incorrect."""
|
||||||
samples = self.get_samples(num_samples)
|
samples = self.get_samples(num_samples)
|
||||||
if len(samples) != num_samples:
|
if len(samples) != num_samples:
|
||||||
raise ValueError(f"Error in block {self.__class__.__name__} record(). \
|
raise ValueError(
|
||||||
Requested {num_samples} samples but got {len(samples)}")
|
f"Error in block {self.__class__.__name__} record(). \
|
||||||
|
Requested {num_samples} samples but got {len(samples)}"
|
||||||
|
)
|
||||||
metadata = self._get_metadata()
|
metadata = self._get_metadata()
|
||||||
return Recording(data=samples, metadata=metadata)
|
return Recording(data=samples, metadata=metadata)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,9 @@ class RecordingSource(SourceBlock, RecordableBlock):
|
||||||
:raises ValueError: If num_samples is greater than the recording length.
|
:raises ValueError: If num_samples is greater than the recording length.
|
||||||
"""
|
"""
|
||||||
if num_samples - 1 >= self.recording.data.shape[1]:
|
if num_samples - 1 >= self.recording.data.shape[1]:
|
||||||
raise ValueError(f"{num_samples} samples requested from recording source with \
|
raise ValueError(
|
||||||
{self.recording.data.shape[1]} samples available.")
|
f"{num_samples} samples requested from recording source with \
|
||||||
|
{self.recording.data.shape[1]} samples available."
|
||||||
|
)
|
||||||
|
|
||||||
return self.recording.data[0, 0:num_samples]
|
return self.recording.data[0, 0:num_samples]
|
||||||
|
|
|
||||||
|
|
@ -610,8 +610,10 @@ def cut_out( # noqa: C901 # TODO: Simplify function
|
||||||
raise ValueError("signal must be CxN complex.")
|
raise ValueError("signal must be CxN complex.")
|
||||||
|
|
||||||
if fill_type not in {"zeros", "ones", "low-snr", "avg-snr", "high-snr"}:
|
if fill_type not in {"zeros", "ones", "low-snr", "avg-snr", "high-snr"}:
|
||||||
raise UserWarning("""fill_type must be "zeros", "ones", "low-snr", "avg-snr", or "high-snr",
|
raise UserWarning(
|
||||||
"ones" has been selected by default""")
|
"""fill_type must be "zeros", "ones", "low-snr", "avg-snr", or "high-snr",
|
||||||
|
"ones" has been selected by default"""
|
||||||
|
)
|
||||||
|
|
||||||
if max_section_size < 1 or max_section_size >= n:
|
if max_section_size < 1 or max_section_size >= n:
|
||||||
raise ValueError("max_section_size must be at least 1 and must be less than the length of signal.")
|
raise ValueError("max_section_size must be at least 1 and must be less than the length of signal.")
|
||||||
|
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 90 KiB After Width: | Height: | Size: 130 B |
Binary file not shown.
|
Before Width: | Height: | Size: 19 KiB After Width: | Height: | Size: 130 B |
258
src/ria_toolkit_oss_cli/ria_toolkit_oss/campaign.py
Normal file
258
src/ria_toolkit_oss_cli/ria_toolkit_oss/campaign.py
Normal file
|
|
@ -0,0 +1,258 @@
|
||||||
|
"""Campaign orchestration CLI commands.
|
||||||
|
|
||||||
|
Usage examples::
|
||||||
|
|
||||||
|
# Enroll a single device using a device profile YAML (App 1 workflow)
|
||||||
|
ria campaign enroll --config iphone13.yml
|
||||||
|
|
||||||
|
# Run a full custom campaign config
|
||||||
|
ria campaign run --config my_campaign.yml
|
||||||
|
|
||||||
|
# Validate a config file without running it
|
||||||
|
ria campaign validate --config my_campaign.yml
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
|
||||||
|
from ria_toolkit_oss.orchestration.executor import CampaignExecutor, StepResult
|
||||||
|
|
||||||
|
|
||||||
|
@click.group()
|
||||||
|
def campaign():
|
||||||
|
"""Orchestrate automated RF capture campaigns."""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ria campaign validate
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@campaign.command()
|
||||||
|
@click.option(
|
||||||
|
"--config",
|
||||||
|
"-c",
|
||||||
|
required=True,
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help="Campaign YAML config or device profile YAML.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--profile",
|
||||||
|
is_flag=True,
|
||||||
|
default=False,
|
||||||
|
help="Parse as a device profile (App 1 style) rather than a full campaign config.",
|
||||||
|
)
|
||||||
|
def validate(config, profile):
|
||||||
|
"""Validate a campaign or device profile YAML without running it.
|
||||||
|
|
||||||
|
\b
|
||||||
|
Examples:
|
||||||
|
ria campaign validate --config iphone13.yml --profile
|
||||||
|
ria campaign validate --config campaign.yml
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if profile:
|
||||||
|
cfg = CampaignConfig.from_device_profile(config)
|
||||||
|
else:
|
||||||
|
cfg = CampaignConfig.from_yaml(config)
|
||||||
|
except (FileNotFoundError, ValueError, KeyError) as e:
|
||||||
|
raise click.ClickException(str(e))
|
||||||
|
|
||||||
|
click.echo(click.style("✓ Config valid", fg="green", bold=True))
|
||||||
|
click.echo(f" Campaign name : {cfg.name}")
|
||||||
|
click.echo(f" Mode : {cfg.mode}")
|
||||||
|
click.echo(f" Transmitters : {len(cfg.transmitters)}")
|
||||||
|
click.echo(f" Total steps : {cfg.total_steps()}")
|
||||||
|
click.echo(f" Capture time : {cfg.total_capture_time_s():.0f}s")
|
||||||
|
click.echo(f" Recorder : {cfg.recorder.device} @ {cfg.recorder.center_freq/1e6:.2f} MHz")
|
||||||
|
click.echo(f" Sample rate : {cfg.recorder.sample_rate/1e6:.1f} MS/s")
|
||||||
|
click.echo(f" Output path : {cfg.output.path}")
|
||||||
|
click.echo()
|
||||||
|
|
||||||
|
for tx in cfg.transmitters:
|
||||||
|
click.echo(f" Transmitter: {tx.id} ({tx.type}, {tx.control_method}, {len(tx.schedule)} steps)")
|
||||||
|
for step in tx.schedule:
|
||||||
|
extras = []
|
||||||
|
if step.channel is not None:
|
||||||
|
extras.append(f"ch={step.channel}")
|
||||||
|
if step.bandwidth_mhz is not None:
|
||||||
|
extras.append(f"{int(step.bandwidth_mhz)}MHz")
|
||||||
|
if step.traffic:
|
||||||
|
extras.append(step.traffic)
|
||||||
|
suffix = f" [{', '.join(extras)}]" if extras else ""
|
||||||
|
click.echo(f" [{step.duration:.0f}s] {step.label}{suffix}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ria campaign enroll
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@campaign.command()
|
||||||
|
@click.option(
|
||||||
|
"--config",
|
||||||
|
"-c",
|
||||||
|
required=True,
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help="Device profile YAML (App 1 enrollment format).",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--output",
|
||||||
|
"-o",
|
||||||
|
default=None,
|
||||||
|
help="Override output directory from config.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--report",
|
||||||
|
default="qa_report.json",
|
||||||
|
show_default=True,
|
||||||
|
help="Path for the JSON QA report.",
|
||||||
|
)
|
||||||
|
@click.option("--verbose", "-v", is_flag=True, help="Verbose output.")
|
||||||
|
@click.option("--dry-run", is_flag=True, help="Parse and validate config, then exit.")
|
||||||
|
def enroll(config, output, report, verbose, dry_run):
|
||||||
|
"""Enroll a single device by running its capture profile.
|
||||||
|
|
||||||
|
Parses a device profile YAML (App 1 format), generates a capture
|
||||||
|
campaign, and runs it. Outputs labelled SigMF recordings and a
|
||||||
|
JSON QA report.
|
||||||
|
|
||||||
|
\b
|
||||||
|
Examples:
|
||||||
|
ria campaign enroll --config iphone13.yml
|
||||||
|
ria campaign enroll --config airpods.yml --output ./my_recordings
|
||||||
|
ria campaign enroll --config iphone13.yml --dry-run
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_device_profile(config)
|
||||||
|
except (FileNotFoundError, ValueError, KeyError) as e:
|
||||||
|
raise click.ClickException(str(e))
|
||||||
|
|
||||||
|
if output:
|
||||||
|
cfg.output.path = output
|
||||||
|
|
||||||
|
_print_campaign_summary(cfg)
|
||||||
|
|
||||||
|
if dry_run:
|
||||||
|
click.echo(click.style("Dry run — exiting before capture.", fg="yellow"))
|
||||||
|
return
|
||||||
|
|
||||||
|
result = _run_campaign(cfg, verbose=verbose)
|
||||||
|
result.write_report(report)
|
||||||
|
|
||||||
|
_print_result_summary(result, report)
|
||||||
|
sys.exit(0 if result.failed == 0 else 1)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ria campaign run
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@campaign.command()
|
||||||
|
@click.option(
|
||||||
|
"--config",
|
||||||
|
"-c",
|
||||||
|
required=True,
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help="Full campaign YAML config.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--output",
|
||||||
|
"-o",
|
||||||
|
default=None,
|
||||||
|
help="Override output directory from config.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--report",
|
||||||
|
default="qa_report.json",
|
||||||
|
show_default=True,
|
||||||
|
help="Path for the JSON QA report.",
|
||||||
|
)
|
||||||
|
@click.option("--verbose", "-v", is_flag=True, help="Verbose output.")
|
||||||
|
@click.option("--dry-run", is_flag=True, help="Parse and validate config, then exit.")
|
||||||
|
def run(config, output, report, verbose, dry_run):
|
||||||
|
"""Run a full campaign from a campaign config YAML.
|
||||||
|
|
||||||
|
\b
|
||||||
|
Examples:
|
||||||
|
ria campaign run --config wifi_capture.yml
|
||||||
|
ria campaign run --config campaign.yml --output ./data --dry-run
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_yaml(config)
|
||||||
|
except (FileNotFoundError, ValueError, KeyError) as e:
|
||||||
|
raise click.ClickException(str(e))
|
||||||
|
|
||||||
|
if output:
|
||||||
|
cfg.output.path = output
|
||||||
|
|
||||||
|
_print_campaign_summary(cfg)
|
||||||
|
|
||||||
|
if dry_run:
|
||||||
|
click.echo(click.style("Dry run — exiting before capture.", fg="yellow"))
|
||||||
|
return
|
||||||
|
|
||||||
|
result = _run_campaign(cfg, verbose=verbose)
|
||||||
|
result.write_report(report)
|
||||||
|
|
||||||
|
_print_result_summary(result, report)
|
||||||
|
sys.exit(0 if result.failed == 0 else 1)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _print_campaign_summary(cfg: CampaignConfig) -> None:
|
||||||
|
click.echo()
|
||||||
|
click.echo(click.style(f"Campaign: {cfg.name}", bold=True))
|
||||||
|
click.echo(f" Transmitters : {len(cfg.transmitters)}")
|
||||||
|
click.echo(f" Total steps : {cfg.total_steps()}")
|
||||||
|
click.echo(f" Capture time : ~{cfg.total_capture_time_s():.0f}s")
|
||||||
|
click.echo(f" Output : {cfg.output.path}")
|
||||||
|
click.echo()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_progress_cb(total: int):
|
||||||
|
"""Return a progress callback that prints step results to stderr."""
|
||||||
|
|
||||||
|
def cb(idx: int, _total: int, step: StepResult) -> None:
|
||||||
|
status = (
|
||||||
|
click.style("✓", fg="green")
|
||||||
|
if step.ok
|
||||||
|
else (click.style("⚑", fg="yellow") if step.qa.flagged else click.style("✗", fg="red"))
|
||||||
|
)
|
||||||
|
snr_str = f"SNR {step.qa.snr_db:.1f} dB" if not step.error else f"ERROR: {step.error}"
|
||||||
|
click.echo(
|
||||||
|
f" [{idx:>3}/{_total}] {status} {step.transmitter_id}/{step.step_label} — {snr_str}",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return cb
|
||||||
|
|
||||||
|
|
||||||
|
def _run_campaign(cfg: CampaignConfig, verbose: bool = False):
|
||||||
|
executor = CampaignExecutor(
|
||||||
|
config=cfg,
|
||||||
|
progress_cb=_make_progress_cb(cfg.total_steps()),
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
return executor.run()
|
||||||
|
|
||||||
|
|
||||||
|
def _print_result_summary(result, report_path: str) -> None:
|
||||||
|
click.echo()
|
||||||
|
click.echo(click.style("Campaign complete", bold=True))
|
||||||
|
click.echo(f" Steps : {result.total_steps}")
|
||||||
|
click.echo(f" Passed : {click.style(str(result.passed), fg='green')}")
|
||||||
|
if result.flagged:
|
||||||
|
click.echo(f" Flagged : {click.style(str(result.flagged), fg='yellow')} (review required)")
|
||||||
|
if result.failed:
|
||||||
|
click.echo(f" Failed : {click.style(str(result.failed), fg='red')}")
|
||||||
|
click.echo(f" Duration: {result.duration_s:.0f}s")
|
||||||
|
click.echo(f" Report : {report_path}")
|
||||||
|
click.echo()
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
This module contains all the CLI bindings for the ria package.
|
This module contains all the CLI bindings for the ria package.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .campaign import campaign
|
||||||
from .capture import capture
|
from .capture import capture
|
||||||
from .combine import combine
|
from .combine import combine
|
||||||
from .convert import convert
|
from .convert import convert
|
||||||
|
|
@ -13,6 +14,7 @@ from .generate import generate
|
||||||
|
|
||||||
# from .generate import generate
|
# from .generate import generate
|
||||||
from .init import init
|
from .init import init
|
||||||
|
from .serve import serve
|
||||||
from .split import split
|
from .split import split
|
||||||
from .transform import transform
|
from .transform import transform
|
||||||
from .transmit import transmit
|
from .transmit import transmit
|
||||||
|
|
|
||||||
51
src/ria_toolkit_oss_cli/ria_toolkit_oss/serve.py
Normal file
51
src/ria_toolkit_oss_cli/ria_toolkit_oss/serve.py
Normal file
|
|
@ -0,0 +1,51 @@
|
||||||
|
"""``ria serve`` — start the RT-OSS HTTP server for RIA Hub integration."""
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
@click.option("--host", default="0.0.0.0", show_default=True, help="Bind host.")
|
||||||
|
@click.option("--port", default=8080, show_default=True, type=int, help="Bind port.")
|
||||||
|
@click.option(
|
||||||
|
"--api-key",
|
||||||
|
envvar="RT_OSS_API_KEY",
|
||||||
|
default="",
|
||||||
|
help="Required X-API-Key value. Also reads RT_OSS_API_KEY. Empty = no auth (dev only).",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--log-level",
|
||||||
|
default="info",
|
||||||
|
show_default=True,
|
||||||
|
type=click.Choice(["debug", "info", "warning", "error"], case_sensitive=False),
|
||||||
|
)
|
||||||
|
def serve(host: str, port: int, api_key: str, log_level: str):
|
||||||
|
"""Start the RT-OSS HTTP server.
|
||||||
|
|
||||||
|
\b
|
||||||
|
Endpoints:
|
||||||
|
POST /orchestrator/deploy
|
||||||
|
GET /orchestrator/status/{campaign_id}
|
||||||
|
POST /orchestrator/cancel/{campaign_id}
|
||||||
|
POST /inference/load
|
||||||
|
POST /inference/start
|
||||||
|
POST /inference/stop
|
||||||
|
GET /inference/status
|
||||||
|
GET /health
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
from ria_toolkit_oss.server.app import create_app
|
||||||
|
except ImportError as e:
|
||||||
|
raise click.ClickException(
|
||||||
|
f"Server dependencies missing: {e}\nInstall with: pip install ria-toolkit-oss[server]"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
click.echo(
|
||||||
|
click.style("Warning: ", fg="yellow", bold=True) + "no API key set — all requests unauthenticated.",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
click.echo(f"Starting RT-OSS server on http://{host}:{port}")
|
||||||
|
uvicorn.run(create_app(api_key=api_key), host=host, port=port, log_level=log_level.lower())
|
||||||
0
tests/orchestration/__init__.py
Normal file
0
tests/orchestration/__init__.py
Normal file
489
tests/orchestration/test_campaign.py
Normal file
489
tests/orchestration/test_campaign.py
Normal file
|
|
@ -0,0 +1,489 @@
|
||||||
|
"""Tests for orchestration campaign schema and YAML parsing."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from ria_toolkit_oss.orchestration.campaign import (
|
||||||
|
CampaignConfig,
|
||||||
|
CaptureStep,
|
||||||
|
QAConfig,
|
||||||
|
RecorderConfig,
|
||||||
|
parse_bandwidth_mhz,
|
||||||
|
parse_duration,
|
||||||
|
parse_frequency,
|
||||||
|
parse_gain,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# parse_duration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseDuration:
|
||||||
|
def test_seconds_suffix(self):
|
||||||
|
assert parse_duration("30s") == 30.0
|
||||||
|
|
||||||
|
def test_seconds_suffix_long(self):
|
||||||
|
assert parse_duration("30sec") == 30.0
|
||||||
|
|
||||||
|
def test_minutes_suffix(self):
|
||||||
|
assert parse_duration("1.5m") == 90.0
|
||||||
|
|
||||||
|
def test_minutes_suffix_long(self):
|
||||||
|
assert parse_duration("2min") == 120.0
|
||||||
|
|
||||||
|
def test_hours_suffix(self):
|
||||||
|
assert parse_duration("2h") == 7200.0
|
||||||
|
|
||||||
|
def test_hours_suffix_long(self):
|
||||||
|
assert parse_duration("1hr") == 3600.0
|
||||||
|
|
||||||
|
def test_numeric_int(self):
|
||||||
|
assert parse_duration(45) == 45.0
|
||||||
|
|
||||||
|
def test_numeric_float(self):
|
||||||
|
assert parse_duration(1.5) == 1.5
|
||||||
|
|
||||||
|
def test_bare_number_string(self):
|
||||||
|
# No unit → treated as seconds
|
||||||
|
assert parse_duration("60") == 60.0
|
||||||
|
|
||||||
|
def test_invalid_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
parse_duration("two minutes")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# parse_frequency
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseFrequency:
|
||||||
|
def test_ghz(self):
|
||||||
|
assert parse_frequency("2.45GHz") == pytest.approx(2.45e9)
|
||||||
|
|
||||||
|
def test_mhz(self):
|
||||||
|
assert parse_frequency("40MHz") == pytest.approx(40e6)
|
||||||
|
|
||||||
|
def test_khz(self):
|
||||||
|
assert parse_frequency("433k") == pytest.approx(433e3)
|
||||||
|
|
||||||
|
def test_scientific_notation_string(self):
|
||||||
|
assert parse_frequency("915e6") == pytest.approx(915e6)
|
||||||
|
|
||||||
|
def test_numeric_float(self):
|
||||||
|
assert parse_frequency(2.45e9) == pytest.approx(2.45e9)
|
||||||
|
|
||||||
|
def test_numeric_int(self):
|
||||||
|
assert parse_frequency(1000000) == pytest.approx(1e6)
|
||||||
|
|
||||||
|
def test_hz_suffix_optional(self):
|
||||||
|
# "40M" and "40MHz" should both work
|
||||||
|
assert parse_frequency("40M") == pytest.approx(40e6)
|
||||||
|
assert parse_frequency("40MHz") == pytest.approx(40e6)
|
||||||
|
|
||||||
|
def test_invalid_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
parse_frequency("two point four gigs")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# parse_gain
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseGain:
|
||||||
|
def test_db_suffix(self):
|
||||||
|
assert parse_gain("40dB") == pytest.approx(40.0)
|
||||||
|
|
||||||
|
def test_db_suffix_lowercase(self):
|
||||||
|
assert parse_gain("32db") == pytest.approx(32.0)
|
||||||
|
|
||||||
|
def test_auto(self):
|
||||||
|
assert parse_gain("auto") == "auto"
|
||||||
|
|
||||||
|
def test_auto_case_insensitive(self):
|
||||||
|
assert parse_gain("AUTO") == "auto"
|
||||||
|
|
||||||
|
def test_numeric_int(self):
|
||||||
|
assert parse_gain(32) == pytest.approx(32.0)
|
||||||
|
|
||||||
|
def test_numeric_float(self):
|
||||||
|
assert parse_gain(32.5) == pytest.approx(32.5)
|
||||||
|
|
||||||
|
def test_invalid_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
parse_gain("high")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# parse_bandwidth_mhz
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseBandwidthMhz:
|
||||||
|
def test_mhz_suffix(self):
|
||||||
|
assert parse_bandwidth_mhz("20MHz") == pytest.approx(20.0)
|
||||||
|
|
||||||
|
def test_numeric(self):
|
||||||
|
assert parse_bandwidth_mhz(40) == pytest.approx(40.0)
|
||||||
|
|
||||||
|
def test_none(self):
|
||||||
|
assert parse_bandwidth_mhz(None) is None
|
||||||
|
|
||||||
|
def test_invalid_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
parse_bandwidth_mhz("wide")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CaptureStep.from_dict
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCaptureStep:
|
||||||
|
def test_wifi_step_auto_label(self):
|
||||||
|
d = {"channel": 6, "bandwidth": "20MHz", "traffic": "iperf_udp", "duration": "30s"}
|
||||||
|
step = CaptureStep.from_dict(d)
|
||||||
|
assert step.duration == 30.0
|
||||||
|
assert step.channel == 6
|
||||||
|
assert step.bandwidth_mhz == 20.0
|
||||||
|
assert step.traffic == "iperf_udp"
|
||||||
|
assert step.label == "ch06_20mhz_iperf_udp"
|
||||||
|
|
||||||
|
def test_explicit_label(self):
|
||||||
|
d = {"channel": 1, "bandwidth": "20MHz", "traffic": "idle", "duration": "30s", "label": "my_label"}
|
||||||
|
step = CaptureStep.from_dict(d)
|
||||||
|
assert step.label == "my_label"
|
||||||
|
|
||||||
|
def test_fallback_label(self):
|
||||||
|
# No channel/bandwidth/traffic → label falls back to "capture"
|
||||||
|
d = {"duration": "10s"}
|
||||||
|
step = CaptureStep.from_dict(d)
|
||||||
|
assert step.label == "capture"
|
||||||
|
|
||||||
|
def test_power_parsed(self):
|
||||||
|
d = {"channel": 6, "bandwidth": "20MHz", "traffic": "idle", "duration": "30s", "power": "15dBm"}
|
||||||
|
step = CaptureStep.from_dict(d)
|
||||||
|
assert step.power_dbm == pytest.approx(15.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RecorderConfig.from_dict
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRecorderConfig:
|
||||||
|
def test_basic(self):
|
||||||
|
d = {"device": "usrp_b210", "center_freq": "2.45GHz", "sample_rate": "40MHz", "gain": "40dB"}
|
||||||
|
rec = RecorderConfig.from_dict(d)
|
||||||
|
assert rec.device == "usrp_b210"
|
||||||
|
assert rec.center_freq == pytest.approx(2.45e9)
|
||||||
|
assert rec.sample_rate == pytest.approx(40e6)
|
||||||
|
assert rec.gain == pytest.approx(40.0)
|
||||||
|
assert rec.bandwidth is None
|
||||||
|
|
||||||
|
def test_auto_gain(self):
|
||||||
|
d = {"device": "pluto", "center_freq": "2.45GHz", "sample_rate": "20MHz", "gain": "auto"}
|
||||||
|
rec = RecorderConfig.from_dict(d)
|
||||||
|
assert rec.gain == "auto"
|
||||||
|
|
||||||
|
def test_bandwidth_set(self):
|
||||||
|
d = {"device": "pluto", "center_freq": "2.45GHz", "sample_rate": "20MHz", "gain": 32, "bandwidth": "20MHz"}
|
||||||
|
rec = RecorderConfig.from_dict(d)
|
||||||
|
assert rec.bandwidth == pytest.approx(20e6)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# QAConfig.from_dict
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestQAConfig:
|
||||||
|
def test_defaults(self):
|
||||||
|
qa = QAConfig.from_dict({})
|
||||||
|
assert qa.snr_threshold_db == pytest.approx(10.0)
|
||||||
|
assert qa.min_duration_s == pytest.approx(25.0)
|
||||||
|
assert qa.flag_for_review is True
|
||||||
|
|
||||||
|
def test_custom_values(self):
|
||||||
|
d = {"snr_threshold": "15dB", "min_duration": "28s", "flag_for_review": False}
|
||||||
|
qa = QAConfig.from_dict(d)
|
||||||
|
assert qa.snr_threshold_db == pytest.approx(15.0)
|
||||||
|
assert qa.min_duration_s == pytest.approx(28.0)
|
||||||
|
assert qa.flag_for_review is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CampaignConfig.from_device_profile
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _write_device_profile(d: dict) -> str:
|
||||||
|
"""Write a dict as YAML to a temp file and return the path."""
|
||||||
|
f = tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False)
|
||||||
|
yaml.dump(d, f)
|
||||||
|
f.close()
|
||||||
|
return f.name
|
||||||
|
|
||||||
|
|
||||||
|
WIFI_PROFILE = {
|
||||||
|
"device": {"name": "iPhone_13_WiFi", "type": "wifi"},
|
||||||
|
"capture": {
|
||||||
|
"channels": [1, 6, 11],
|
||||||
|
"bandwidth": "20MHz",
|
||||||
|
"traffic_patterns": ["idle", "ping", "iperf_udp"],
|
||||||
|
"duration_per_config": "30s",
|
||||||
|
"script": "./scripts/wifi_control.sh",
|
||||||
|
},
|
||||||
|
"recorder": {
|
||||||
|
"device": "usrp_b210",
|
||||||
|
"center_freq": "2.45GHz",
|
||||||
|
"sample_rate": "40MHz",
|
||||||
|
"gain": "auto",
|
||||||
|
},
|
||||||
|
"output": {"path": "/tmp/test_recordings", "device_id": "iphone13_wifi_001"},
|
||||||
|
}
|
||||||
|
|
||||||
|
BT_PROFILE = {
|
||||||
|
"device": {"name": "AirPods_Pro", "type": "bluetooth"},
|
||||||
|
"capture": {
|
||||||
|
"traffic_patterns": ["idle", "audio_stream", "data_transfer"],
|
||||||
|
"duration_per_config": "30s",
|
||||||
|
},
|
||||||
|
"recorder": {
|
||||||
|
"device": "usrp_b210",
|
||||||
|
"center_freq": "2.45GHz",
|
||||||
|
"sample_rate": "40MHz",
|
||||||
|
"gain": "auto",
|
||||||
|
},
|
||||||
|
"output": {"path": "/tmp/test_recordings", "device_id": "airpods_pro_bt_001"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeviceProfileParsing:
|
||||||
|
def test_wifi_schedule_count(self):
|
||||||
|
"""WiFi: 3 channels × 3 traffic = 9 steps."""
|
||||||
|
path = _write_device_profile(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_device_profile(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
assert len(cfg.transmitters) == 1
|
||||||
|
assert len(cfg.transmitters[0].schedule) == 9
|
||||||
|
|
||||||
|
def test_wifi_campaign_name(self):
|
||||||
|
path = _write_device_profile(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_device_profile(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
assert cfg.name == "enroll_iphone13_wifi_001"
|
||||||
|
|
||||||
|
def test_wifi_step_labels(self):
|
||||||
|
path = _write_device_profile(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_device_profile(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
labels = [s.label for s in cfg.transmitters[0].schedule]
|
||||||
|
assert "ch01_20mhz_idle" in labels
|
||||||
|
assert "ch06_20mhz_ping" in labels
|
||||||
|
assert "ch11_20mhz_iperf_udp" in labels
|
||||||
|
|
||||||
|
def test_wifi_step_ordering(self):
|
||||||
|
"""Steps iterate channels first, then traffic."""
|
||||||
|
path = _write_device_profile(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_device_profile(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
steps = cfg.transmitters[0].schedule
|
||||||
|
assert steps[0].channel == 1 and steps[0].traffic == "idle"
|
||||||
|
assert steps[1].channel == 1 and steps[1].traffic == "ping"
|
||||||
|
assert steps[3].channel == 6 and steps[3].traffic == "idle"
|
||||||
|
assert steps[8].channel == 11 and steps[8].traffic == "iperf_udp"
|
||||||
|
|
||||||
|
def test_wifi_step_duration(self):
|
||||||
|
path = _write_device_profile(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_device_profile(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
for step in cfg.transmitters[0].schedule:
|
||||||
|
assert step.duration == pytest.approx(30.0)
|
||||||
|
|
||||||
|
def test_wifi_bandwidth(self):
|
||||||
|
path = _write_device_profile(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_device_profile(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
for step in cfg.transmitters[0].schedule:
|
||||||
|
assert step.bandwidth_mhz == pytest.approx(20.0)
|
||||||
|
|
||||||
|
def test_wifi_recorder(self):
|
||||||
|
path = _write_device_profile(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_device_profile(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
assert cfg.recorder.device == "usrp_b210"
|
||||||
|
assert cfg.recorder.center_freq == pytest.approx(2.45e9)
|
||||||
|
assert cfg.recorder.sample_rate == pytest.approx(40e6)
|
||||||
|
assert cfg.recorder.gain == "auto"
|
||||||
|
|
||||||
|
def test_wifi_total_capture_time(self):
|
||||||
|
path = _write_device_profile(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_device_profile(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
assert cfg.total_capture_time_s() == pytest.approx(270.0) # 9 × 30s
|
||||||
|
|
||||||
|
def test_bt_schedule_count(self):
|
||||||
|
"""BT: no channels, 3 traffic patterns = 3 steps."""
|
||||||
|
path = _write_device_profile(BT_PROFILE)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_device_profile(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
assert len(cfg.transmitters[0].schedule) == 3
|
||||||
|
|
||||||
|
def test_bt_no_channel(self):
|
||||||
|
path = _write_device_profile(BT_PROFILE)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_device_profile(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
for step in cfg.transmitters[0].schedule:
|
||||||
|
assert step.channel is None
|
||||||
|
|
||||||
|
def test_bt_step_labels(self):
|
||||||
|
path = _write_device_profile(BT_PROFILE)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_device_profile(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
labels = [s.label for s in cfg.transmitters[0].schedule]
|
||||||
|
assert labels == ["idle", "audio_stream", "data_transfer"]
|
||||||
|
|
||||||
|
def test_missing_file_raises(self):
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
CampaignConfig.from_device_profile("/nonexistent/path/profile.yml")
|
||||||
|
|
||||||
|
def test_invalid_yaml_raises(self):
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f:
|
||||||
|
f.write(": bad: yaml: [\n")
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
with pytest.raises(ValueError, match="Invalid YAML"):
|
||||||
|
CampaignConfig.from_device_profile(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CampaignConfig.from_yaml (full campaign format)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
FULL_CAMPAIGN = {
|
||||||
|
"campaign": {"name": "wifi_capture_001", "mode": "controlled_testbed"},
|
||||||
|
"transmitters": [
|
||||||
|
{
|
||||||
|
"id": "laptop_wifi",
|
||||||
|
"type": "wifi",
|
||||||
|
"control_method": "external_script",
|
||||||
|
"script": "./scripts/wifi_control.sh",
|
||||||
|
"device": "/dev/wlan0",
|
||||||
|
"schedule": [
|
||||||
|
{"channel": 6, "bandwidth": "20MHz", "traffic": "iperf_tcp", "duration": "30s"},
|
||||||
|
{"channel": 36, "bandwidth": "40MHz", "traffic": "ping_flood", "duration": "30s"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"recorder": {
|
||||||
|
"device": "usrp_b210",
|
||||||
|
"center_freq": "2.45GHz",
|
||||||
|
"sample_rate": "20MHz",
|
||||||
|
"gain": "40dB",
|
||||||
|
},
|
||||||
|
"qa": {"snr_threshold": "10dB", "min_duration": "25s", "flag_for_review": True},
|
||||||
|
"output": {"format": "sigmf", "path": "./recordings"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestFullCampaignParsing:
|
||||||
|
def test_name(self):
|
||||||
|
path = _write_device_profile(FULL_CAMPAIGN)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_yaml(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
assert cfg.name == "wifi_capture_001"
|
||||||
|
|
||||||
|
def test_mode(self):
|
||||||
|
path = _write_device_profile(FULL_CAMPAIGN)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_yaml(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
assert cfg.mode == "controlled_testbed"
|
||||||
|
|
||||||
|
def test_transmitter_id(self):
|
||||||
|
path = _write_device_profile(FULL_CAMPAIGN)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_yaml(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
assert cfg.transmitters[0].id == "laptop_wifi"
|
||||||
|
assert cfg.transmitters[0].control_method == "external_script"
|
||||||
|
assert cfg.transmitters[0].script == "./scripts/wifi_control.sh"
|
||||||
|
|
||||||
|
def test_schedule_count(self):
|
||||||
|
path = _write_device_profile(FULL_CAMPAIGN)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_yaml(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
assert len(cfg.transmitters[0].schedule) == 2
|
||||||
|
|
||||||
|
def test_qa_config(self):
|
||||||
|
path = _write_device_profile(FULL_CAMPAIGN)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_yaml(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
assert cfg.qa.snr_threshold_db == pytest.approx(10.0)
|
||||||
|
assert cfg.qa.min_duration_s == pytest.approx(25.0)
|
||||||
|
assert cfg.qa.flag_for_review is True
|
||||||
|
|
||||||
|
def test_total_steps(self):
|
||||||
|
path = _write_device_profile(FULL_CAMPAIGN)
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_yaml(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
assert cfg.total_steps() == 2
|
||||||
|
|
||||||
|
def test_no_transmitters_raises(self):
|
||||||
|
bad = dict(FULL_CAMPAIGN)
|
||||||
|
bad["transmitters"] = []
|
||||||
|
path = _write_device_profile(bad)
|
||||||
|
try:
|
||||||
|
with pytest.raises(ValueError, match="at least one transmitter"):
|
||||||
|
CampaignConfig.from_yaml(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_missing_recorder_raises(self):
|
||||||
|
bad = {k: v for k, v in FULL_CAMPAIGN.items() if k != "recorder"}
|
||||||
|
path = _write_device_profile(bad)
|
||||||
|
try:
|
||||||
|
with pytest.raises((KeyError, ValueError)):
|
||||||
|
CampaignConfig.from_yaml(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
274
tests/orchestration/test_campaign_cli.py
Normal file
274
tests/orchestration/test_campaign_cli.py
Normal file
|
|
@ -0,0 +1,274 @@
|
||||||
|
"""Tests for the `ria campaign` CLI commands."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
from ria_toolkit_oss_cli.cli import cli
|
||||||
|
|
||||||
|
|
||||||
|
def _write_yaml(d: dict, suffix=".yml") -> str:
|
||||||
|
f = tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False)
|
||||||
|
yaml.dump(d, f)
|
||||||
|
f.close()
|
||||||
|
return f.name
|
||||||
|
|
||||||
|
|
||||||
|
WIFI_PROFILE = {
|
||||||
|
"device": {"name": "iPhone_13_WiFi", "type": "wifi"},
|
||||||
|
"capture": {
|
||||||
|
"channels": [1, 6, 11],
|
||||||
|
"bandwidth": "20MHz",
|
||||||
|
"traffic_patterns": ["idle", "ping", "iperf_udp"],
|
||||||
|
"duration_per_config": "30s",
|
||||||
|
},
|
||||||
|
"recorder": {
|
||||||
|
"device": "usrp_b210",
|
||||||
|
"center_freq": "2.45GHz",
|
||||||
|
"sample_rate": "40MHz",
|
||||||
|
"gain": "auto",
|
||||||
|
},
|
||||||
|
"output": {"path": "/tmp/test_enroll", "device_id": "iphone13_wifi_001"},
|
||||||
|
}
|
||||||
|
|
||||||
|
FULL_CAMPAIGN = {
|
||||||
|
"campaign": {"name": "wifi_capture_001", "mode": "controlled_testbed"},
|
||||||
|
"transmitters": [
|
||||||
|
{
|
||||||
|
"id": "laptop_wifi",
|
||||||
|
"type": "wifi",
|
||||||
|
"control_method": "external_script",
|
||||||
|
"schedule": [
|
||||||
|
{"channel": 6, "bandwidth": "20MHz", "traffic": "iperf_udp", "duration": "30s"},
|
||||||
|
{"channel": 11, "bandwidth": "20MHz", "traffic": "idle", "duration": "30s"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"recorder": {
|
||||||
|
"device": "usrp_b210",
|
||||||
|
"center_freq": "2.45GHz",
|
||||||
|
"sample_rate": "20MHz",
|
||||||
|
"gain": "40dB",
|
||||||
|
},
|
||||||
|
"qa": {"snr_threshold": "10dB", "min_duration": "25s", "flag_for_review": True},
|
||||||
|
"output": {"format": "sigmf", "path": "/tmp/test_campaign"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ria campaign --help
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCampaignHelp:
|
||||||
|
def test_campaign_help(self):
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "campaign" in result.output.lower()
|
||||||
|
|
||||||
|
def test_subcommands_listed(self):
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
for sub in ("validate", "enroll", "run"):
|
||||||
|
assert sub in result.output
|
||||||
|
|
||||||
|
def test_validate_help(self):
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "validate", "--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
def test_enroll_help(self):
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "enroll", "--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
def test_run_help(self):
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "run", "--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ria campaign validate
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCampaignValidate:
|
||||||
|
def test_validate_device_profile(self):
|
||||||
|
path = _write_yaml(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "✓" in result.output or "valid" in result.output.lower()
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_validate_shows_campaign_name(self):
|
||||||
|
path = _write_yaml(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
|
||||||
|
assert "enroll_iphone13_wifi_001" in result.output
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_validate_shows_step_count(self):
|
||||||
|
path = _write_yaml(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
|
||||||
|
assert "9" in result.output # 9 total steps
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_validate_shows_capture_time(self):
|
||||||
|
path = _write_yaml(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
|
||||||
|
assert "270" in result.output # 270s total
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_validate_full_campaign(self):
|
||||||
|
path = _write_yaml(FULL_CAMPAIGN)
|
||||||
|
try:
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "validate", "--config", path])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "wifi_capture_001" in result.output
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_validate_shows_steps(self):
|
||||||
|
path = _write_yaml(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
|
||||||
|
assert "ch01_20mhz_idle" in result.output
|
||||||
|
assert "ch06_20mhz_ping" in result.output
|
||||||
|
assert "ch11_20mhz_iperf_udp" in result.output
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_validate_missing_file(self):
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "validate", "--config", "/nonexistent/file.yml"])
|
||||||
|
assert result.exit_code != 0
|
||||||
|
|
||||||
|
def test_validate_bad_yaml(self):
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f:
|
||||||
|
f.write(": broken yaml [\n")
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "validate", "--config", path])
|
||||||
|
assert result.exit_code != 0
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ria campaign enroll --dry-run
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCampaignEnrollDryRun:
|
||||||
|
def test_dry_run_exits_cleanly(self):
|
||||||
|
path = _write_yaml(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "enroll", "--config", path, "--dry-run"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_dry_run_shows_campaign_info(self):
|
||||||
|
path = _write_yaml(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "enroll", "--config", path, "--dry-run"])
|
||||||
|
assert "enroll_iphone13_wifi_001" in result.output
|
||||||
|
assert "9" in result.output
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_dry_run_does_not_capture(self):
|
||||||
|
"""Dry run should not create any output files."""
|
||||||
|
path = _write_yaml(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
runner = CliRunner()
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
runner.invoke(
|
||||||
|
cli,
|
||||||
|
["campaign", "enroll", "--config", path, "--output", tmpdir, "--dry-run"],
|
||||||
|
)
|
||||||
|
# No .sigmf-data files should have been created
|
||||||
|
sigmf_files = list(os.walk(tmpdir))
|
||||||
|
all_files = [f for _, _, files in sigmf_files for f in files]
|
||||||
|
assert not any(f.endswith(".sigmf-data") for f in all_files)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_dry_run_output_override(self):
|
||||||
|
path = _write_yaml(WIFI_PROFILE)
|
||||||
|
try:
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(
|
||||||
|
cli,
|
||||||
|
["campaign", "enroll", "--config", path, "--output", "/tmp/custom_out", "--dry-run"],
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Dry run" in result.output
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ria campaign run --dry-run
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCampaignRunDryRun:
|
||||||
|
def test_dry_run_exits_cleanly(self):
|
||||||
|
path = _write_yaml(FULL_CAMPAIGN)
|
||||||
|
try:
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "run", "--config", path, "--dry-run"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_dry_run_shows_campaign_name(self):
|
||||||
|
path = _write_yaml(FULL_CAMPAIGN)
|
||||||
|
try:
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "run", "--config", path, "--dry-run"])
|
||||||
|
assert "wifi_capture_001" in result.output
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_dry_run_does_not_create_report(self):
|
||||||
|
path = _write_yaml(FULL_CAMPAIGN)
|
||||||
|
try:
|
||||||
|
runner = CliRunner()
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
report_path = os.path.join(tmpdir, "qa_report.json")
|
||||||
|
result = runner.invoke(
|
||||||
|
cli,
|
||||||
|
["campaign", "run", "--config", path, "--dry-run", "--report", report_path],
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert not os.path.exists(report_path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_missing_config_fails(self):
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["campaign", "run", "--config", "/nonexistent.yml"])
|
||||||
|
assert result.exit_code != 0
|
||||||
145
tests/orchestration/test_labeler.py
Normal file
145
tests/orchestration/test_labeler.py
Normal file
|
|
@ -0,0 +1,145 @@
|
||||||
|
"""Tests for orchestration labeler."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ria_toolkit_oss.datatypes.recording import Recording
|
||||||
|
from ria_toolkit_oss.orchestration.campaign import CaptureStep
|
||||||
|
from ria_toolkit_oss.orchestration.labeler import build_output_filename, label_recording
|
||||||
|
|
||||||
|
|
||||||
|
def _simple_recording() -> Recording:
|
||||||
|
sr = 1e6
|
||||||
|
n = 1000
|
||||||
|
data = np.ones(n, dtype=np.complex64)
|
||||||
|
return Recording(data, metadata={"sample_rate": sr, "center_frequency": 2.45e9})
|
||||||
|
|
||||||
|
|
||||||
|
def _wifi_step() -> CaptureStep:
|
||||||
|
return CaptureStep(
|
||||||
|
duration=30.0,
|
||||||
|
label="ch06_20mhz_idle",
|
||||||
|
channel=6,
|
||||||
|
bandwidth_mhz=20.0,
|
||||||
|
traffic="idle",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _bt_step() -> CaptureStep:
|
||||||
|
return CaptureStep(
|
||||||
|
duration=30.0,
|
||||||
|
label="audio_stream",
|
||||||
|
traffic="audio_stream",
|
||||||
|
connection_interval_ms=7.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# label_recording
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLabelRecording:
|
||||||
|
def test_device_id_set(self):
|
||||||
|
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||||
|
assert rec.metadata["device_id"] == "iphone13_001"
|
||||||
|
|
||||||
|
def test_capture_timestamp_set(self):
|
||||||
|
ts = time.time()
|
||||||
|
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), ts)
|
||||||
|
assert rec.metadata["capture_timestamp"] == pytest.approx(ts, abs=1.0)
|
||||||
|
|
||||||
|
def test_step_label_set(self):
|
||||||
|
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||||
|
assert rec.metadata["step_label"] == "ch06_20mhz_idle"
|
||||||
|
|
||||||
|
def test_step_duration_set(self):
|
||||||
|
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||||
|
assert rec.metadata["step_duration_s"] == pytest.approx(30.0)
|
||||||
|
|
||||||
|
def test_campaign_name_optional(self):
|
||||||
|
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||||
|
assert "campaign" not in rec.metadata
|
||||||
|
|
||||||
|
def test_campaign_name_when_provided(self):
|
||||||
|
rec = label_recording(
|
||||||
|
_simple_recording(), "iphone13_001", _wifi_step(), time.time(), campaign_name="test_campaign"
|
||||||
|
)
|
||||||
|
assert rec.metadata["campaign"] == "test_campaign"
|
||||||
|
|
||||||
|
def test_wifi_channel_set(self):
|
||||||
|
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||||
|
assert rec.metadata["wifi_channel"] == 6
|
||||||
|
|
||||||
|
def test_wifi_bandwidth_set(self):
|
||||||
|
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||||
|
assert rec.metadata["wifi_bandwidth_mhz"] == pytest.approx(20.0)
|
||||||
|
|
||||||
|
def test_traffic_pattern_set(self):
|
||||||
|
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||||
|
assert rec.metadata["traffic_pattern"] == "idle"
|
||||||
|
|
||||||
|
def test_bt_connection_interval_set(self):
|
||||||
|
rec = label_recording(_simple_recording(), "airpods_001", _bt_step(), time.time())
|
||||||
|
assert rec.metadata["bt_connection_interval_ms"] == pytest.approx(7.5)
|
||||||
|
|
||||||
|
def test_no_channel_key_for_bt(self):
|
||||||
|
"""BT steps with no channel should not add wifi_channel to metadata."""
|
||||||
|
rec = label_recording(_simple_recording(), "airpods_001", _bt_step(), time.time())
|
||||||
|
assert "wifi_channel" not in rec.metadata
|
||||||
|
|
||||||
|
def test_no_bandwidth_key_for_bt(self):
|
||||||
|
rec = label_recording(_simple_recording(), "airpods_001", _bt_step(), time.time())
|
||||||
|
assert "wifi_bandwidth_mhz" not in rec.metadata
|
||||||
|
|
||||||
|
def test_power_dbm_set(self):
|
||||||
|
step = CaptureStep(duration=30.0, label="test", traffic="idle", power_dbm=15.0)
|
||||||
|
rec = label_recording(_simple_recording(), "dev_001", step, time.time())
|
||||||
|
assert rec.metadata["tx_power_dbm"] == pytest.approx(15.0)
|
||||||
|
|
||||||
|
def test_no_power_key_when_unset(self):
|
||||||
|
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||||
|
assert "tx_power_dbm" not in rec.metadata
|
||||||
|
|
||||||
|
def test_returns_same_recording(self):
|
||||||
|
"""label_recording should mutate and return the same Recording object."""
|
||||||
|
rec = _simple_recording()
|
||||||
|
result = label_recording(rec, "iphone13_001", _wifi_step(), time.time())
|
||||||
|
assert result is rec
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# build_output_filename
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildOutputFilename:
|
||||||
|
def test_basic_wifi(self):
|
||||||
|
step = CaptureStep(duration=30.0, label="ch06_20mhz_idle")
|
||||||
|
fn = build_output_filename("iphone13_wifi_001", step)
|
||||||
|
assert fn == "iphone13_wifi_001/ch06_20mhz_idle"
|
||||||
|
|
||||||
|
def test_bt_step(self):
|
||||||
|
step = CaptureStep(duration=30.0, label="audio_stream")
|
||||||
|
fn = build_output_filename("airpods_pro_bt_001", step)
|
||||||
|
assert fn == "airpods_pro_bt_001/audio_stream"
|
||||||
|
|
||||||
|
def test_spaces_in_device_id_replaced(self):
|
||||||
|
step = CaptureStep(duration=30.0, label="idle")
|
||||||
|
fn = build_output_filename("my device", step)
|
||||||
|
assert " " not in fn
|
||||||
|
assert fn == "my_device/idle"
|
||||||
|
|
||||||
|
def test_slashes_in_label_replaced(self):
|
||||||
|
step = CaptureStep(duration=30.0, label="ch/6/idle")
|
||||||
|
fn = build_output_filename("dev_001", step)
|
||||||
|
assert "/" not in fn.split("/", 1)[1] # only the separator slash should remain
|
||||||
|
|
||||||
|
def test_path_structure(self):
|
||||||
|
"""Filename should be exactly '<device_id>/<label>' (one level of nesting)."""
|
||||||
|
step = CaptureStep(duration=30.0, label="idle")
|
||||||
|
fn = build_output_filename("dev_001", step)
|
||||||
|
parts = fn.split("/")
|
||||||
|
assert len(parts) == 2
|
||||||
192
tests/orchestration/test_qa.py
Normal file
192
tests/orchestration/test_qa.py
Normal file
|
|
@ -0,0 +1,192 @@
|
||||||
|
"""Tests for orchestration QA metrics."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ria_toolkit_oss.datatypes.recording import Recording
|
||||||
|
from ria_toolkit_oss.orchestration.campaign import QAConfig
|
||||||
|
from ria_toolkit_oss.orchestration.qa import QAResult, check_recording, estimate_snr_db
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_recording(n_samples: int, sample_rate: float, signal: np.ndarray) -> Recording:
|
||||||
|
return Recording(
|
||||||
|
signal.astype(np.complex64),
|
||||||
|
metadata={"sample_rate": sample_rate, "center_frequency": 2.45e9},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _tone(n: int, sr: float, freq_hz: float = 100e3, amplitude: float = 0.5) -> np.ndarray:
|
||||||
|
t = np.arange(n) / sr
|
||||||
|
return (np.exp(1j * 2 * np.pi * freq_hz * t) * amplitude).astype(np.complex64)
|
||||||
|
|
||||||
|
|
||||||
|
def _noise(n: int, amplitude: float = 0.001) -> np.ndarray:
|
||||||
|
rng = np.random.default_rng(42)
|
||||||
|
return ((rng.standard_normal(n) + 1j * rng.standard_normal(n)) * amplitude).astype(np.complex64)
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_QA = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# estimate_snr_db
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEstimateSnrDb:
|
||||||
|
def test_high_snr_tone(self):
|
||||||
|
sr = 1e6
|
||||||
|
samples = _tone(int(sr * 1), sr)
|
||||||
|
snr = estimate_snr_db(samples)
|
||||||
|
assert snr > 20.0, f"Expected high SNR for clean tone, got {snr:.1f} dB"
|
||||||
|
|
||||||
|
def test_pure_noise_low_snr(self):
|
||||||
|
sr = 1e6
|
||||||
|
rng = np.random.default_rng(0)
|
||||||
|
samples = (rng.standard_normal(int(sr)) + 1j * rng.standard_normal(int(sr))).astype(np.complex64)
|
||||||
|
snr = estimate_snr_db(samples)
|
||||||
|
# Pure noise should yield a low (possibly negative) SNR
|
||||||
|
assert snr < 15.0, f"Expected low SNR for noise, got {snr:.1f} dB"
|
||||||
|
|
||||||
|
def test_snr_increases_with_amplitude(self):
|
||||||
|
sr = 1e6
|
||||||
|
n = int(sr)
|
||||||
|
rng = np.random.default_rng(1)
|
||||||
|
noise = (rng.standard_normal(n) + 1j * rng.standard_normal(n)).astype(np.complex64) * 0.01
|
||||||
|
t = np.arange(n) / sr
|
||||||
|
tone = np.exp(1j * 2 * np.pi * 100e3 * t).astype(np.complex64)
|
||||||
|
|
||||||
|
low_snr = estimate_snr_db(noise + tone * 0.1)
|
||||||
|
high_snr = estimate_snr_db(noise + tone * 1.0)
|
||||||
|
assert high_snr > low_snr
|
||||||
|
|
||||||
|
def test_short_input_still_works(self):
|
||||||
|
# Input shorter than n_fft=4096 should not raise
|
||||||
|
samples = _tone(512, 1e6)
|
||||||
|
snr = estimate_snr_db(samples)
|
||||||
|
assert np.isfinite(snr)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# check_recording — pass cases
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckRecordingPass:
|
||||||
|
def test_clean_tone_passes(self):
|
||||||
|
sr = 1e6
|
||||||
|
rec = _make_recording(int(sr * 30), sr, _tone(int(sr * 30), sr))
|
||||||
|
result = check_recording(rec, DEFAULT_QA)
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.flagged is False
|
||||||
|
assert result.snr_db > 10.0
|
||||||
|
assert abs(result.duration_s - 30.0) < 0.1
|
||||||
|
|
||||||
|
def test_duration_exactly_at_threshold(self):
|
||||||
|
sr = 1e6
|
||||||
|
n = int(sr * 25) # exactly at min_duration_s
|
||||||
|
rec = _make_recording(n, sr, _tone(n, sr))
|
||||||
|
result = check_recording(rec, DEFAULT_QA)
|
||||||
|
assert result.flagged is False
|
||||||
|
|
||||||
|
def test_issues_empty_when_passing(self):
|
||||||
|
sr = 1e6
|
||||||
|
rec = _make_recording(int(sr * 30), sr, _tone(int(sr * 30), sr))
|
||||||
|
result = check_recording(rec, DEFAULT_QA)
|
||||||
|
assert result.issues == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# check_recording — flag cases
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckRecordingFlag:
|
||||||
|
def test_short_recording_flagged(self):
|
||||||
|
sr = 1e6
|
||||||
|
n = int(sr * 10) # shorter than 25s min
|
||||||
|
rec = _make_recording(n, sr, _tone(n, sr))
|
||||||
|
result = check_recording(rec, DEFAULT_QA)
|
||||||
|
assert result.flagged is True
|
||||||
|
assert any("Duration" in issue for issue in result.issues)
|
||||||
|
|
||||||
|
def test_low_snr_flagged(self):
|
||||||
|
sr = 1e6
|
||||||
|
n = int(sr * 30)
|
||||||
|
rec = _make_recording(n, sr, _noise(n, amplitude=0.001))
|
||||||
|
result = check_recording(rec, DEFAULT_QA)
|
||||||
|
assert result.flagged is True
|
||||||
|
assert any("SNR" in issue for issue in result.issues)
|
||||||
|
|
||||||
|
def test_flag_for_review_still_passes(self):
|
||||||
|
"""With flag_for_review=True, flagged recordings are still marked passed."""
|
||||||
|
sr = 1e6
|
||||||
|
n = int(sr * 10) # short → will be flagged
|
||||||
|
rec = _make_recording(n, sr, _tone(n, sr))
|
||||||
|
qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=True)
|
||||||
|
result = check_recording(rec, qa)
|
||||||
|
assert result.flagged is True
|
||||||
|
assert result.passed is True # human review, not auto-reject
|
||||||
|
|
||||||
|
def test_flag_for_review_false_fails(self):
|
||||||
|
"""With flag_for_review=False, a flagged recording is also marked failed."""
|
||||||
|
sr = 1e6
|
||||||
|
n = int(sr * 10)
|
||||||
|
rec = _make_recording(n, sr, _tone(n, sr))
|
||||||
|
qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=False)
|
||||||
|
result = check_recording(rec, qa)
|
||||||
|
assert result.flagged is True
|
||||||
|
assert result.passed is False
|
||||||
|
|
||||||
|
def test_multiple_issues_reported(self):
|
||||||
|
"""Both short duration AND low SNR should both appear in issues list."""
|
||||||
|
sr = 1e6
|
||||||
|
n = int(sr * 5) # very short
|
||||||
|
rec = _make_recording(n, sr, _noise(n, amplitude=0.0001))
|
||||||
|
result = check_recording(rec, DEFAULT_QA)
|
||||||
|
assert result.flagged is True
|
||||||
|
assert len(result.issues) >= 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# check_recording — multichannel input
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckRecordingMultichannel:
|
||||||
|
def test_multichannel_recording(self):
|
||||||
|
"""2-channel recording should evaluate channel 0 without error."""
|
||||||
|
sr = 1e6
|
||||||
|
n = int(sr * 30)
|
||||||
|
ch0 = _tone(n, sr)
|
||||||
|
ch1 = _tone(n, sr, freq_hz=200e3)
|
||||||
|
data = np.stack([ch0, ch1]) # shape (2, N)
|
||||||
|
rec = Recording(data, metadata={"sample_rate": sr, "center_frequency": 2.45e9})
|
||||||
|
result = check_recording(rec, DEFAULT_QA)
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.flagged is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# QAResult.to_dict
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestQAResultToDict:
|
||||||
|
def test_to_dict_keys(self):
|
||||||
|
r = QAResult(passed=True, flagged=False, snr_db=18.3, duration_s=30.0)
|
||||||
|
d = r.to_dict()
|
||||||
|
assert set(d.keys()) == {"passed", "flagged", "snr_db", "duration_s", "issues"}
|
||||||
|
|
||||||
|
def test_to_dict_values(self):
|
||||||
|
r = QAResult(passed=False, flagged=True, snr_db=7.5, duration_s=10.2, issues=["SNR below threshold"])
|
||||||
|
d = r.to_dict()
|
||||||
|
assert d["passed"] is False
|
||||||
|
assert d["flagged"] is True
|
||||||
|
assert d["snr_db"] == pytest.approx(7.5, abs=0.01)
|
||||||
|
assert d["duration_s"] == pytest.approx(10.2, abs=0.01)
|
||||||
|
assert d["issues"] == ["SNR below threshold"]
|
||||||
Loading…
Reference in New Issue
Block a user