reformats and campaign additions
This commit is contained in:
parent
b1e3ebf74f
commit
019b0c6f4b
2143
poetry.lock
generated
2143
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
|
|
@ -127,5 +127,8 @@ exclude = '''
|
|||
)/
|
||||
'''
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = ["src"]
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
|
|
|||
|
|
@ -956,8 +956,10 @@ def get_result_sizes( # noqa: C901 # TODO: Simplify function
|
|||
# Check that each class that will be augmented does not already suffice target_size
|
||||
for cls_name, target_size_value in zip(classes_to_augment, target_size):
|
||||
if class_sizes[cls_name] >= target_size_value:
|
||||
raise ValueError(f"""target_size of {target_size_value} is already sufficed for current size of
|
||||
{class_sizes[cls_name]} for class: {cls_name}""")
|
||||
raise ValueError(
|
||||
f"""target_size of {target_size_value} is already sufficed for current size of
|
||||
{class_sizes[cls_name]} for class: {cls_name}"""
|
||||
)
|
||||
|
||||
for index, class_name in enumerate(classes_to_augment):
|
||||
result_sizes[class_name] = target_size[index]
|
||||
|
|
|
|||
|
|
@ -316,6 +316,8 @@ def to_sigmf(
|
|||
meta_dict = sigMF_metafile.ordered_metadata()
|
||||
meta_dict["ria"] = metadata
|
||||
|
||||
if overwrite and os.path.isfile(meta_file_path):
|
||||
os.remove(meta_file_path)
|
||||
sigMF_metafile.tofile(meta_file_path)
|
||||
|
||||
|
||||
|
|
|
|||
26
src/ria_toolkit_oss/orchestration/__init__.py
Normal file
26
src/ria_toolkit_oss/orchestration/__init__.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
"""Orchestration layer for automated RF capture campaigns."""
|
||||
|
||||
from .campaign import (
|
||||
CampaignConfig,
|
||||
CaptureStep,
|
||||
QAConfig,
|
||||
RecorderConfig,
|
||||
TransmitterConfig,
|
||||
)
|
||||
from .executor import CampaignExecutor, CampaignResult, StepResult
|
||||
from .labeler import label_recording
|
||||
from .qa import QAResult, check_recording
|
||||
|
||||
__all__ = [
|
||||
"CampaignConfig",
|
||||
"CaptureStep",
|
||||
"QAConfig",
|
||||
"RecorderConfig",
|
||||
"TransmitterConfig",
|
||||
"CampaignExecutor",
|
||||
"CampaignResult",
|
||||
"StepResult",
|
||||
"label_recording",
|
||||
"QAResult",
|
||||
"check_recording",
|
||||
]
|
||||
446
src/ria_toolkit_oss/orchestration/campaign.py
Normal file
446
src/ria_toolkit_oss/orchestration/campaign.py
Normal file
|
|
@ -0,0 +1,446 @@
|
|||
"""Campaign configuration schema and YAML parser for orchestrated RF captures."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parsing helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_duration(value: str | float | int) -> float:
|
||||
"""Parse a duration string to seconds.
|
||||
|
||||
Accepts:
|
||||
"30s" → 30.0
|
||||
"1.5m" or "1.5min" → 90.0
|
||||
"2h" → 7200.0
|
||||
30 (numeric) → 30.0
|
||||
"""
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
value = str(value).strip()
|
||||
match = re.fullmatch(r"([\d.]+)\s*(s|sec|m|min|h|hr)?", value, re.IGNORECASE)
|
||||
if not match:
|
||||
raise ValueError(f"Cannot parse duration: '{value}'")
|
||||
amount = float(match.group(1))
|
||||
unit = (match.group(2) or "s").lower()
|
||||
if unit in ("h", "hr"):
|
||||
return amount * 3600
|
||||
if unit in ("m", "min"):
|
||||
return amount * 60
|
||||
return amount
|
||||
|
||||
|
||||
def parse_frequency(value: str | float | int) -> float:
|
||||
"""Parse a frequency string to Hz.
|
||||
|
||||
Accepts:
|
||||
"2.45GHz" → 2_450_000_000.0
|
||||
"40MHz" → 40_000_000.0
|
||||
"915e6" → 915_000_000.0
|
||||
2.45e9 (numeric) → 2_450_000_000.0
|
||||
"""
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
value = str(value).strip()
|
||||
|
||||
# Try bare numeric first (handles scientific notation like "915e6")
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Handle suffix notation: "2.45GHz", "40MHz", "40M", "433k"
|
||||
match = re.fullmatch(r"([\d.]+)\s*(k|M|G)(?:\s*Hz?)?", value, re.IGNORECASE)
|
||||
if match:
|
||||
amount = float(match.group(1))
|
||||
suffix = match.group(2).upper()
|
||||
return amount * {"K": 1e3, "M": 1e6, "G": 1e9}[suffix]
|
||||
|
||||
raise ValueError(f"Cannot parse frequency: '{value}'")
|
||||
|
||||
|
||||
def parse_gain(value: str | float | int) -> float | str:
|
||||
"""Parse a gain string.
|
||||
|
||||
Accepts:
|
||||
"40dB" or "40 dB" → 40.0
|
||||
"auto" → "auto"
|
||||
40 (numeric) → 40.0
|
||||
"""
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
value = str(value).strip()
|
||||
if value.lower() == "auto":
|
||||
return "auto"
|
||||
match = re.fullmatch(r"([\d.+\-]+)\s*dB?", value, re.IGNORECASE)
|
||||
if not match:
|
||||
raise ValueError(f"Cannot parse gain: '{value}'")
|
||||
return float(match.group(1))
|
||||
|
||||
|
||||
def parse_bandwidth_mhz(value: str | float | int | None) -> Optional[float]:
|
||||
"""Parse a bandwidth string to MHz.
|
||||
|
||||
Accepts:
|
||||
"20MHz" → 20.0
|
||||
"40MHz" → 40.0
|
||||
20 (numeric, assumed MHz) → 20.0
|
||||
None → None
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
value = str(value).strip()
|
||||
match = re.fullmatch(r"([\d.]+)\s*MHz?", value, re.IGNORECASE)
|
||||
if match:
|
||||
return float(match.group(1))
|
||||
match = re.fullmatch(r"([\d.]+)", value)
|
||||
if match:
|
||||
return float(match.group(1))
|
||||
raise ValueError(f"Cannot parse bandwidth: '{value}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecorderConfig:
|
||||
"""SDR recorder configuration."""
|
||||
|
||||
device: str
|
||||
center_freq: float # Hz
|
||||
sample_rate: float # Hz
|
||||
gain: float | str # dB float, or "auto"
|
||||
bandwidth: Optional[float] = None # Hz, None = match sample_rate
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "RecorderConfig":
|
||||
gain = parse_gain(d.get("gain", "auto"))
|
||||
bandwidth_raw = d.get("bandwidth") or d.get("bandwidth_hz")
|
||||
bandwidth = parse_frequency(bandwidth_raw) if bandwidth_raw else None
|
||||
return cls(
|
||||
device=str(d["device"]),
|
||||
center_freq=parse_frequency(d["center_freq"]),
|
||||
sample_rate=parse_frequency(d["sample_rate"]),
|
||||
gain=gain,
|
||||
bandwidth=bandwidth,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CaptureStep:
|
||||
"""A single timed capture within a transmitter schedule."""
|
||||
|
||||
duration: float # seconds
|
||||
label: str # used as filename component
|
||||
|
||||
# WiFi-specific
|
||||
channel: Optional[int] = None
|
||||
bandwidth_mhz: Optional[float] = None # MHz
|
||||
traffic: Optional[str] = None
|
||||
|
||||
# Bluetooth-specific
|
||||
connection_interval_ms: Optional[float] = None
|
||||
|
||||
# Power (dBm), optional
|
||||
power_dbm: Optional[float] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict, auto_label: bool = True) -> "CaptureStep":
|
||||
duration = parse_duration(d["duration"])
|
||||
label = d.get("label", "")
|
||||
if not label and auto_label:
|
||||
parts = []
|
||||
if d.get("channel"):
|
||||
parts.append(f"ch{d['channel']:02d}")
|
||||
if d.get("bandwidth"):
|
||||
bw = parse_bandwidth_mhz(d["bandwidth"])
|
||||
parts.append(f"{int(bw)}mhz")
|
||||
if d.get("traffic"):
|
||||
parts.append(str(d["traffic"]).replace(" ", "_"))
|
||||
label = "_".join(parts) if parts else "capture"
|
||||
return cls(
|
||||
duration=duration,
|
||||
label=label,
|
||||
channel=d.get("channel"),
|
||||
bandwidth_mhz=parse_bandwidth_mhz(d.get("bandwidth")),
|
||||
traffic=d.get("traffic"),
|
||||
connection_interval_ms=d.get("connection_interval_ms"),
|
||||
power_dbm=float(d["power"].rstrip("dBm").strip()) if d.get("power") else None,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransmitterConfig:
|
||||
"""Configuration for a single transmitter device in the campaign."""
|
||||
|
||||
id: str
|
||||
type: str # "wifi", "bluetooth", "sdr", "external"
|
||||
control_method: str # "external_script" | "sdr"
|
||||
schedule: list[CaptureStep]
|
||||
|
||||
# For external_script control
|
||||
script: Optional[str] = None # path to control script
|
||||
device: Optional[str] = None # e.g. "/dev/wlan0"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "TransmitterConfig":
|
||||
schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])]
|
||||
return cls(
|
||||
id=str(d["id"]),
|
||||
type=str(d["type"]),
|
||||
control_method=str(d.get("control_method", "external_script")),
|
||||
schedule=schedule,
|
||||
script=d.get("script"),
|
||||
device=d.get("device"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QAConfig:
|
||||
"""Quality assurance thresholds."""
|
||||
|
||||
snr_threshold_db: float = 10.0
|
||||
min_duration_s: float = 25.0
|
||||
flag_for_review: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "QAConfig":
|
||||
return cls(
|
||||
snr_threshold_db=float(str(d.get("snr_threshold", "10")).rstrip("dB").strip()),
|
||||
min_duration_s=parse_duration(d.get("min_duration", "25s")),
|
||||
flag_for_review=bool(d.get("flag_for_review", True)),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputConfig:
|
||||
"""Where to save captured recordings."""
|
||||
|
||||
format: str = "sigmf"
|
||||
path: str = "recordings"
|
||||
device_id: Optional[str] = None # for device-profile campaigns
|
||||
repo: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "OutputConfig":
|
||||
return cls(
|
||||
format=str(d.get("format", "sigmf")),
|
||||
path=str(d.get("path", "recordings")),
|
||||
device_id=d.get("device_id"),
|
||||
repo=d.get("repo"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CampaignConfig:
|
||||
"""Full campaign configuration parsed from YAML."""
|
||||
|
||||
name: str
|
||||
recorder: RecorderConfig
|
||||
transmitters: list[TransmitterConfig]
|
||||
qa: QAConfig = field(default_factory=QAConfig)
|
||||
output: OutputConfig = field(default_factory=OutputConfig)
|
||||
mode: str = "controlled_testbed"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Loaders
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, raw: dict) -> "CampaignConfig":
|
||||
"""Build a CampaignConfig from a parsed dictionary.
|
||||
|
||||
Accepts the same structure as the campaign YAML, already loaded into
|
||||
a Python dict (e.g. from a JSON HTTP request body).
|
||||
|
||||
Raises:
|
||||
ValueError: If required fields are missing or malformed.
|
||||
KeyError: If ``recorder`` key is absent.
|
||||
"""
|
||||
campaign_meta = raw.get("campaign", {})
|
||||
transmitters = [TransmitterConfig.from_dict(t) for t in raw.get("transmitters", [])]
|
||||
if not transmitters:
|
||||
raise ValueError("Campaign config must define at least one transmitter")
|
||||
return cls(
|
||||
name=str(campaign_meta.get("name", "unnamed")),
|
||||
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
||||
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
||||
transmitters=transmitters,
|
||||
qa=QAConfig.from_dict(raw.get("qa", {})),
|
||||
output=OutputConfig.from_dict(raw.get("output", {})),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> "CampaignConfig":
|
||||
"""Load a full campaign config YAML.
|
||||
|
||||
Expected format::
|
||||
|
||||
campaign:
|
||||
name: "wifi_capture_001"
|
||||
mode: "controlled_testbed"
|
||||
|
||||
transmitters:
|
||||
- id: "laptop_wifi"
|
||||
type: "wifi"
|
||||
control_method: "external_script"
|
||||
script: "./scripts/wifi_control.sh"
|
||||
device: "/dev/wlan0"
|
||||
schedule:
|
||||
- channel: 6
|
||||
bandwidth: "20MHz"
|
||||
traffic: "iperf_udp"
|
||||
duration: "30s"
|
||||
|
||||
recorder:
|
||||
device: "usrp_b210"
|
||||
center_freq: "2.45GHz"
|
||||
sample_rate: "40MHz"
|
||||
gain: "40dB"
|
||||
|
||||
qa:
|
||||
snr_threshold: "10dB"
|
||||
min_duration: "25s"
|
||||
flag_for_review: true
|
||||
|
||||
output:
|
||||
format: "sigmf"
|
||||
path: "./recordings"
|
||||
"""
|
||||
path = Path(path)
|
||||
try:
|
||||
with open(path) as f:
|
||||
raw = yaml.safe_load(f)
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"Campaign config not found: {path}")
|
||||
except yaml.YAMLError as e:
|
||||
raise ValueError(f"Invalid YAML in {path}: {e}")
|
||||
|
||||
campaign_meta = raw.get("campaign", {})
|
||||
transmitters = [TransmitterConfig.from_dict(t) for t in raw.get("transmitters", [])]
|
||||
if not transmitters:
|
||||
raise ValueError("Campaign config must define at least one transmitter")
|
||||
|
||||
return cls(
|
||||
name=str(campaign_meta.get("name", path.stem)),
|
||||
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
||||
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
||||
transmitters=transmitters,
|
||||
qa=QAConfig.from_dict(raw.get("qa", {})),
|
||||
output=OutputConfig.from_dict(raw.get("output", {})),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_device_profile(cls, path: str | Path) -> "CampaignConfig":
|
||||
"""Build a campaign config from an App 1 device profile YAML.
|
||||
|
||||
Expected format::
|
||||
|
||||
device:
|
||||
name: "iPhone_13_WiFi"
|
||||
type: "wifi"
|
||||
protocol: "wifi_24ghz"
|
||||
|
||||
capture:
|
||||
channels: [1, 6, 11] # WiFi only
|
||||
bandwidth: "20MHz" # WiFi only
|
||||
traffic_patterns: ["idle", "ping", "iperf_udp"]
|
||||
duration_per_config: "30s"
|
||||
|
||||
recorder:
|
||||
device: "usrp_b210"
|
||||
center_freq: "2.45GHz"
|
||||
sample_rate: "40MHz"
|
||||
gain: "auto"
|
||||
|
||||
output:
|
||||
path: "./recordings"
|
||||
device_id: "iphone13_wifi_001"
|
||||
|
||||
For WiFi devices, schedule is expanded as channels × traffic_patterns.
|
||||
For Bluetooth devices (no channels), schedule is traffic_patterns only.
|
||||
"""
|
||||
path = Path(path)
|
||||
try:
|
||||
with open(path) as f:
|
||||
raw = yaml.safe_load(f)
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"Device profile not found: {path}")
|
||||
except yaml.YAMLError as e:
|
||||
raise ValueError(f"Invalid YAML in {path}: {e}")
|
||||
|
||||
device = raw.get("device", {})
|
||||
capture = raw.get("capture", {})
|
||||
device_type = str(device.get("type", "wifi")).lower()
|
||||
device_name = str(device.get("name", path.stem))
|
||||
duration = parse_duration(capture.get("duration_per_config", "30s"))
|
||||
traffic_patterns = capture.get("traffic_patterns", ["idle"])
|
||||
|
||||
# Build capture schedule
|
||||
schedule: list[CaptureStep] = []
|
||||
|
||||
if device_type in ("wifi", "wifi_24ghz", "wifi_5ghz"):
|
||||
channels = capture.get("channels", [6])
|
||||
bw_str = capture.get("bandwidth", "20MHz")
|
||||
bw_mhz = parse_bandwidth_mhz(bw_str)
|
||||
for ch in channels:
|
||||
for traffic in traffic_patterns:
|
||||
label = f"ch{ch:02d}_{int(bw_mhz)}mhz_{traffic}"
|
||||
schedule.append(
|
||||
CaptureStep(
|
||||
duration=duration,
|
||||
label=label,
|
||||
channel=ch,
|
||||
bandwidth_mhz=bw_mhz,
|
||||
traffic=traffic,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Bluetooth / generic — no channels
|
||||
for traffic in traffic_patterns:
|
||||
schedule.append(
|
||||
CaptureStep(
|
||||
duration=duration,
|
||||
label=traffic,
|
||||
traffic=traffic,
|
||||
)
|
||||
)
|
||||
|
||||
device_id = raw.get("output", {}).get("device_id", device_name.lower().replace(" ", "_"))
|
||||
transmitter = TransmitterConfig(
|
||||
id=device_id,
|
||||
type=device_type,
|
||||
control_method=str(capture.get("control_method", "external_script")),
|
||||
schedule=schedule,
|
||||
script=capture.get("script"),
|
||||
device=capture.get("device"),
|
||||
)
|
||||
|
||||
return cls(
|
||||
name=f"enroll_{device_id}",
|
||||
mode="controlled_testbed",
|
||||
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
||||
transmitters=[transmitter],
|
||||
qa=QAConfig.from_dict(raw.get("qa", {})),
|
||||
output=OutputConfig.from_dict(raw.get("output", {})),
|
||||
)
|
||||
|
||||
def total_capture_time_s(self) -> float:
|
||||
"""Sum of all step durations across all transmitters."""
|
||||
return sum(step.duration for tx in self.transmitters for step in tx.schedule)
|
||||
|
||||
def total_steps(self) -> int:
|
||||
"""Total number of capture steps across all transmitters."""
|
||||
return sum(len(tx.schedule) for tx in self.transmitters)
|
||||
423
src/ria_toolkit_oss/orchestration/executor.py
Normal file
423
src/ria_toolkit_oss/orchestration/executor.py
Normal file
|
|
@ -0,0 +1,423 @@
|
|||
"""Campaign executor: runs a capture campaign end-to-end."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
from ria_toolkit_oss.datatypes.recording import Recording
|
||||
from ria_toolkit_oss.io.recording import to_sigmf
|
||||
|
||||
from .campaign import CampaignConfig, CaptureStep, TransmitterConfig
|
||||
from .labeler import build_output_filename, label_recording
|
||||
from .qa import QAResult, check_recording
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Device name aliases: campaign YAML names → get_sdr_device() names
|
||||
_DEVICE_ALIASES = {
|
||||
"usrp_b210": "usrp",
|
||||
"usrp_b200": "usrp",
|
||||
"usrp": "usrp",
|
||||
"plutosdr": "pluto",
|
||||
"pluto": "pluto",
|
||||
"hackrf": "hackrf",
|
||||
"hackrf_one": "hackrf",
|
||||
"bladerf": "bladerf",
|
||||
"rtlsdr": "rtlsdr",
|
||||
"rtl_sdr": "rtlsdr",
|
||||
"thinkrf": "thinkrf",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepResult:
|
||||
"""Outcome of a single capture step."""
|
||||
|
||||
transmitter_id: str
|
||||
step_label: str
|
||||
output_path: Optional[str]
|
||||
qa: QAResult
|
||||
capture_timestamp: float
|
||||
error: Optional[str] = None
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return self.error is None and self.qa.passed
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"transmitter_id": self.transmitter_id,
|
||||
"step_label": self.step_label,
|
||||
"output_path": self.output_path,
|
||||
"capture_timestamp": self.capture_timestamp,
|
||||
"qa": self.qa.to_dict(),
|
||||
"error": self.error,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class CampaignResult:
|
||||
"""Aggregate outcome of a full campaign."""
|
||||
|
||||
campaign_name: str
|
||||
steps: list[StepResult] = field(default_factory=list)
|
||||
start_time: float = field(default_factory=time.time)
|
||||
end_time: Optional[float] = None
|
||||
|
||||
@property
|
||||
def total_steps(self) -> int:
|
||||
return len(self.steps)
|
||||
|
||||
@property
|
||||
def passed(self) -> int:
|
||||
return sum(1 for s in self.steps if s.ok)
|
||||
|
||||
@property
|
||||
def flagged(self) -> int:
|
||||
return sum(1 for s in self.steps if not s.error and s.qa.flagged)
|
||||
|
||||
@property
|
||||
def failed(self) -> int:
|
||||
return sum(1 for s in self.steps if s.error or not s.qa.passed)
|
||||
|
||||
@property
|
||||
def duration_s(self) -> float:
|
||||
if self.end_time:
|
||||
return self.end_time - self.start_time
|
||||
return time.time() - self.start_time
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"campaign_name": self.campaign_name,
|
||||
"total_steps": self.total_steps,
|
||||
"passed": self.passed,
|
||||
"flagged": self.flagged,
|
||||
"failed": self.failed,
|
||||
"duration_s": round(self.duration_s, 1),
|
||||
"steps": [s.to_dict() for s in self.steps],
|
||||
}
|
||||
|
||||
def write_report(self, path: str | Path) -> None:
|
||||
"""Write a JSON QA report to disk."""
|
||||
path = Path(path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "w") as f:
|
||||
json.dump(self.to_dict(), f, indent=2)
|
||||
logger.info(f"QA report written to {path}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# External script interface
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
|
||||
"""Run an external control script and return stdout.
|
||||
|
||||
The script is called as::
|
||||
|
||||
<script> <arg1> <arg2> ...
|
||||
|
||||
A non-zero return code raises RuntimeError.
|
||||
|
||||
Args:
|
||||
script: Path to executable script.
|
||||
*args: Positional arguments forwarded to the script.
|
||||
timeout: Maximum seconds to wait.
|
||||
|
||||
Returns:
|
||||
Script stdout as a string.
|
||||
"""
|
||||
cmd = [script, *args]
|
||||
logger.debug(f"Running script: {' '.join(cmd)}")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(f"Script timed out after {timeout}s: {script}")
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(f"Script not found: {script}")
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Script exited {result.returncode}: {result.stderr.strip() or result.stdout.strip()}")
|
||||
return result.stdout.strip()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Campaign executor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CampaignExecutor:
|
||||
"""Executes a :class:`CampaignConfig` end-to-end.
|
||||
|
||||
Initialises the SDR recorder once, then for each (transmitter, step):
|
||||
1. Configures the transmitter (via external script or SDR TX)
|
||||
2. Records IQ samples
|
||||
3. Labels the recording with device/config metadata
|
||||
4. Runs QA checks
|
||||
5. Saves the recording to disk
|
||||
6. Stops/resets the transmitter
|
||||
|
||||
Args:
|
||||
config: Parsed campaign configuration.
|
||||
progress_cb: Optional callback ``(step_index, total_steps, step_result)``
|
||||
called after each step completes. Useful for status reporting.
|
||||
verbose: Enable debug logging.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CampaignConfig,
|
||||
progress_cb: Optional[Callable[[int, int, StepResult], None]] = None,
|
||||
verbose: bool = False,
|
||||
):
|
||||
self.config = config
|
||||
self.progress_cb = progress_cb
|
||||
self._sdr = None
|
||||
|
||||
if verbose:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def run(self) -> CampaignResult:
|
||||
"""Execute the full campaign and return a :class:`CampaignResult`.
|
||||
|
||||
Initialises the SDR, runs all steps across all transmitters,
|
||||
then closes the SDR. If SDR initialisation fails the exception
|
||||
propagates immediately (nothing is captured).
|
||||
"""
|
||||
result = CampaignResult(campaign_name=self.config.name)
|
||||
|
||||
logger.info(
|
||||
f"Starting campaign '{self.config.name}': "
|
||||
f"{self.config.total_steps()} steps, "
|
||||
f"~{self.config.total_capture_time_s():.0f}s capture time"
|
||||
)
|
||||
|
||||
self._init_sdr()
|
||||
try:
|
||||
total = self.config.total_steps()
|
||||
step_index = 0
|
||||
|
||||
for transmitter in self.config.transmitters:
|
||||
logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)")
|
||||
for step in transmitter.schedule:
|
||||
step_result = self._execute_step(transmitter, step)
|
||||
result.steps.append(step_result)
|
||||
step_index += 1
|
||||
|
||||
if self.progress_cb:
|
||||
self.progress_cb(step_index, total, step_result)
|
||||
|
||||
if step_result.error:
|
||||
logger.warning(f"Step '{step.label}' error: {step_result.error}")
|
||||
elif step_result.qa.flagged:
|
||||
logger.warning(f"Step '{step.label}' flagged for review: " + "; ".join(step_result.qa.issues))
|
||||
else:
|
||||
logger.info(
|
||||
f"Step '{step.label}' OK "
|
||||
f"(SNR {step_result.qa.snr_db:.1f} dB, "
|
||||
f"{step_result.qa.duration_s:.1f}s)"
|
||||
)
|
||||
finally:
|
||||
self._close_sdr()
|
||||
|
||||
result.end_time = time.time()
|
||||
logger.info(
|
||||
f"Campaign complete: {result.passed}/{result.total_steps} passed, "
|
||||
f"{result.flagged} flagged, {result.failed} failed"
|
||||
)
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SDR management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_sdr(self) -> None:
|
||||
"""Initialise and configure the SDR recorder."""
|
||||
from ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
|
||||
|
||||
rec = self.config.recorder
|
||||
device_name = _DEVICE_ALIASES.get(rec.device.lower(), rec.device.lower())
|
||||
logger.info(f"Initialising SDR: {device_name} @ {rec.center_freq/1e6:.2f} MHz")
|
||||
|
||||
self._sdr = get_sdr_device(device_name)
|
||||
gain = None if rec.gain == "auto" else float(rec.gain)
|
||||
self._sdr.init_rx(
|
||||
sample_rate=rec.sample_rate,
|
||||
center_frequency=rec.center_freq,
|
||||
gain=gain,
|
||||
channel=0,
|
||||
)
|
||||
if rec.bandwidth and hasattr(self._sdr, "set_rx_bandwidth"):
|
||||
self._sdr.set_rx_bandwidth(rec.bandwidth)
|
||||
|
||||
def _close_sdr(self) -> None:
|
||||
if self._sdr is not None:
|
||||
try:
|
||||
self._sdr.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"SDR close error: {e}")
|
||||
self._sdr = None
|
||||
|
||||
def _record(self, duration_s: float) -> Recording:
|
||||
"""Capture ``duration_s`` seconds of IQ samples."""
|
||||
num_samples = int(duration_s * self.config.recorder.sample_rate)
|
||||
return self._sdr.record(num_samples=num_samples)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _execute_step(self, transmitter: TransmitterConfig, step: CaptureStep) -> StepResult:
|
||||
"""Run a single capture step.
|
||||
|
||||
Returns:
|
||||
StepResult with QA outcome and output path (or error string).
|
||||
"""
|
||||
capture_timestamp = time.time()
|
||||
output_path: Optional[str] = None
|
||||
|
||||
try:
|
||||
self._start_transmitter(transmitter, step)
|
||||
recording = self._record(step.duration)
|
||||
self._stop_transmitter(transmitter, step)
|
||||
except Exception as e:
|
||||
# Best-effort stop on error
|
||||
try:
|
||||
self._stop_transmitter(transmitter, step)
|
||||
except Exception:
|
||||
pass
|
||||
return StepResult(
|
||||
transmitter_id=transmitter.id,
|
||||
step_label=step.label,
|
||||
output_path=None,
|
||||
qa=QAResult(passed=False, flagged=True, snr_db=0.0, duration_s=0.0, issues=[f"Capture error: {e}"]),
|
||||
capture_timestamp=capture_timestamp,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
# Label recording
|
||||
recording = label_recording(
|
||||
recording=recording,
|
||||
device_id=transmitter.id,
|
||||
step=step,
|
||||
capture_timestamp=capture_timestamp,
|
||||
campaign_name=self.config.name,
|
||||
)
|
||||
|
||||
# QA
|
||||
qa_result = check_recording(recording, self.config.qa)
|
||||
|
||||
# Save
|
||||
try:
|
||||
output_path = self._save(recording, transmitter.id, step)
|
||||
except Exception as e:
|
||||
return StepResult(
|
||||
transmitter_id=transmitter.id,
|
||||
step_label=step.label,
|
||||
output_path=None,
|
||||
qa=qa_result,
|
||||
capture_timestamp=capture_timestamp,
|
||||
error=f"Save failed: {e}",
|
||||
)
|
||||
|
||||
return StepResult(
|
||||
transmitter_id=transmitter.id,
|
||||
step_label=step.label,
|
||||
output_path=output_path,
|
||||
qa=qa_result,
|
||||
capture_timestamp=capture_timestamp,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Transmitter control (external script interface)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _start_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
|
||||
"""Configure the transmitter for this step.
|
||||
|
||||
For ``external_script`` control method the script is called as::
|
||||
|
||||
<script> configure <step_params_json>
|
||||
|
||||
where ``step_params_json`` is a JSON object with channel, bandwidth,
|
||||
traffic, etc. The script is responsible for applying the configuration
|
||||
and returning promptly (i.e. not blocking for the capture duration).
|
||||
|
||||
For SDR transmitters this is a no-op placeholder (TX not yet implemented).
|
||||
"""
|
||||
if transmitter.control_method == "external_script":
|
||||
if not transmitter.script:
|
||||
logger.debug(f"No script configured for {transmitter.id}, skipping configure")
|
||||
return
|
||||
params = self._step_params_json(transmitter, step)
|
||||
_run_script(transmitter.script, "configure", params)
|
||||
|
||||
elif transmitter.control_method == "sdr":
|
||||
logger.debug("SDR TX not yet implemented — skipping start")
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
|
||||
|
||||
def _stop_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
|
||||
"""Signal the transmitter to stop.
|
||||
|
||||
Calls ``<script> stop`` for external_script transmitters.
|
||||
"""
|
||||
if transmitter.control_method == "external_script":
|
||||
if not transmitter.script:
|
||||
return
|
||||
try:
|
||||
_run_script(transmitter.script, "stop")
|
||||
except Exception as e:
|
||||
logger.warning(f"Script stop failed for {transmitter.id}: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
|
||||
"""Serialise step parameters to a JSON string for the control script."""
|
||||
params: dict = {"device": transmitter.device or ""}
|
||||
if step.channel is not None:
|
||||
params["channel"] = step.channel
|
||||
if step.bandwidth_mhz is not None:
|
||||
params["bandwidth_mhz"] = step.bandwidth_mhz
|
||||
if step.traffic is not None:
|
||||
params["traffic"] = step.traffic
|
||||
if step.power_dbm is not None:
|
||||
params["power_dbm"] = step.power_dbm
|
||||
return json.dumps(params)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Output
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _save(self, recording: Recording, device_id: str, step: CaptureStep) -> str:
|
||||
"""Save a recording to disk and return the data file path."""
|
||||
out = self.config.output
|
||||
rel_filename = build_output_filename(device_id, step)
|
||||
out_dir = Path(out.path)
|
||||
# build_output_filename returns "<device_id>/<label>"
|
||||
# to_sigmf needs filename (base) and path (dir) separately
|
||||
parts = Path(rel_filename)
|
||||
subdir = out_dir / parts.parent
|
||||
subdir.mkdir(parents=True, exist_ok=True)
|
||||
base = parts.name
|
||||
|
||||
to_sigmf(recording, filename=base, path=str(subdir), overwrite=True)
|
||||
return str(subdir / f"{base}.sigmf-data")
|
||||
77
src/ria_toolkit_oss/orchestration/labeler.py
Normal file
77
src/ria_toolkit_oss/orchestration/labeler.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
"""Timestamp-based labeling for captured recordings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ria_toolkit_oss.datatypes.recording import Recording
|
||||
|
||||
from .campaign import CaptureStep
|
||||
|
||||
|
||||
def label_recording(
|
||||
recording: Recording,
|
||||
device_id: str,
|
||||
step: CaptureStep,
|
||||
capture_timestamp: float,
|
||||
campaign_name: Optional[str] = None,
|
||||
) -> Recording:
|
||||
"""Apply device identity and capture configuration labels to a recording's metadata.
|
||||
|
||||
Labels are stored in the ``ria:*`` namespace when the recording is saved
|
||||
as SigMF, via the existing ``update_metadata`` mechanism.
|
||||
|
||||
Args:
|
||||
recording: The recording to label.
|
||||
device_id: Identifier for the transmitting device (e.g. "iphone13_wifi_001").
|
||||
step: The capture step that was active during this recording.
|
||||
capture_timestamp: Unix timestamp (float) of when capture started.
|
||||
campaign_name: Optional campaign name for cross-recording reference.
|
||||
|
||||
Returns:
|
||||
The same recording with updated metadata.
|
||||
"""
|
||||
recording.update_metadata("device_id", device_id)
|
||||
recording.update_metadata("capture_timestamp", capture_timestamp)
|
||||
recording.update_metadata("step_label", step.label)
|
||||
recording.update_metadata("step_duration_s", step.duration)
|
||||
|
||||
if campaign_name:
|
||||
recording.update_metadata("campaign", campaign_name)
|
||||
|
||||
# WiFi-specific labels
|
||||
if step.channel is not None:
|
||||
recording.update_metadata("wifi_channel", step.channel)
|
||||
if step.bandwidth_mhz is not None:
|
||||
recording.update_metadata("wifi_bandwidth_mhz", step.bandwidth_mhz)
|
||||
|
||||
# Bluetooth-specific labels
|
||||
if step.connection_interval_ms is not None:
|
||||
recording.update_metadata("bt_connection_interval_ms", step.connection_interval_ms)
|
||||
|
||||
# Traffic pattern (WiFi + BT)
|
||||
if step.traffic is not None:
|
||||
recording.update_metadata("traffic_pattern", step.traffic)
|
||||
|
||||
# TX power
|
||||
if step.power_dbm is not None:
|
||||
recording.update_metadata("tx_power_dbm", step.power_dbm)
|
||||
|
||||
return recording
|
||||
|
||||
|
||||
def build_output_filename(device_id: str, step: CaptureStep) -> str:
|
||||
"""Generate a deterministic filename for a labeled recording.
|
||||
|
||||
Format: ``<device_id>/<step_label>``
|
||||
|
||||
Args:
|
||||
device_id: Device identifier string.
|
||||
step: Capture step.
|
||||
|
||||
Returns:
|
||||
Relative path string (no extension) to use as ``filename`` in ``to_sigmf()``.
|
||||
"""
|
||||
safe_id = device_id.replace("/", "_").replace(" ", "_")
|
||||
safe_label = step.label.replace("/", "_").replace(" ", "_")
|
||||
return f"{safe_id}/{safe_label}"
|
||||
109
src/ria_toolkit_oss/orchestration/qa.py
Normal file
109
src/ria_toolkit_oss/orchestration/qa.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
"""QA metrics for captured RF recordings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ria_toolkit_oss.datatypes.recording import Recording
|
||||
|
||||
from .campaign import QAConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class QAResult:
|
||||
"""Result of QA checks on a single recording."""
|
||||
|
||||
passed: bool
|
||||
flagged: bool # True if any metric is below threshold (but not hard-failed)
|
||||
snr_db: float
|
||||
duration_s: float
|
||||
issues: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"passed": self.passed,
|
||||
"flagged": self.flagged,
|
||||
"snr_db": round(self.snr_db, 2),
|
||||
"duration_s": round(self.duration_s, 3),
|
||||
"issues": self.issues,
|
||||
}
|
||||
|
||||
|
||||
def estimate_snr_db(samples: np.ndarray, signal_fraction: float = 0.7) -> float:
|
||||
"""Estimate SNR from IQ samples using PSD-based signal/noise separation.
|
||||
|
||||
Computes an FFT of the samples and assumes the top ``signal_fraction``
|
||||
of power bins are signal and the remainder are noise. This is a
|
||||
heuristic appropriate for a controlled testbed where a single dominant
|
||||
signal is expected.
|
||||
|
||||
Args:
|
||||
samples: 1-D complex array of IQ samples.
|
||||
signal_fraction: Fraction of PSD bins to treat as signal (0–1).
|
||||
|
||||
Returns:
|
||||
Estimated SNR in dB, or 0.0 if the noise floor is zero.
|
||||
"""
|
||||
n_fft = min(4096, len(samples))
|
||||
window = np.hanning(n_fft)
|
||||
psd = np.abs(np.fft.fft(samples[:n_fft] * window)) ** 2
|
||||
|
||||
psd_sorted = np.sort(psd)[::-1]
|
||||
n_signal = max(1, int(n_fft * signal_fraction))
|
||||
signal_power = psd_sorted[:n_signal].mean()
|
||||
noise_power = psd_sorted[n_signal:].mean()
|
||||
|
||||
if noise_power <= 0.0:
|
||||
return 0.0
|
||||
return float(10.0 * np.log10(signal_power / noise_power))
|
||||
|
||||
|
||||
def check_recording(recording: Recording, config: QAConfig) -> QAResult:
|
||||
"""Run QA checks on a recording against the campaign QA config.
|
||||
|
||||
Checks performed:
|
||||
- Duration: number of samples / sample_rate >= min_duration_s
|
||||
- SNR: estimated SNR >= snr_threshold_db
|
||||
|
||||
Args:
|
||||
recording: Recording to evaluate.
|
||||
config: QA thresholds from the campaign config.
|
||||
|
||||
Returns:
|
||||
QAResult with pass/flag status and per-metric details.
|
||||
"""
|
||||
issues: list[str] = []
|
||||
flagged = False
|
||||
|
||||
# --- Duration check ---
|
||||
sample_rate = recording.metadata.get("sample_rate", 1.0)
|
||||
n_samples = recording.data.shape[-1]
|
||||
duration_s = n_samples / sample_rate if sample_rate else 0.0
|
||||
|
||||
if duration_s < config.min_duration_s:
|
||||
issues.append(f"Duration too short: {duration_s:.1f}s < {config.min_duration_s:.1f}s threshold")
|
||||
flagged = True
|
||||
|
||||
# --- SNR check ---
|
||||
samples = recording.data[0] if recording.data.ndim > 1 else recording.data
|
||||
snr_db = estimate_snr_db(samples)
|
||||
|
||||
if snr_db < config.snr_threshold_db:
|
||||
issues.append(f"SNR below threshold: {snr_db:.1f} dB < {config.snr_threshold_db:.1f} dB")
|
||||
flagged = True
|
||||
|
||||
# In flag_for_review mode: flag but don't hard-fail
|
||||
if config.flag_for_review:
|
||||
passed = True # always accept; human reviews flagged recordings
|
||||
else:
|
||||
passed = not flagged
|
||||
|
||||
return QAResult(
|
||||
passed=passed,
|
||||
flagged=flagged,
|
||||
snr_db=snr_db,
|
||||
duration_s=duration_s,
|
||||
issues=issues,
|
||||
)
|
||||
|
|
@ -474,8 +474,10 @@ class Blade(SDR):
|
|||
|
||||
if gain_mode == "relative":
|
||||
if gain > 0:
|
||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets \
|
||||
the gain relative to the maximum possible gain.")
|
||||
raise SDRParameterError(
|
||||
"When gain_mode = 'relative', gain must be < 0. This sets \
|
||||
the gain relative to the maximum possible gain."
|
||||
)
|
||||
else:
|
||||
abs_gain = rx_gain_max + gain
|
||||
else:
|
||||
|
|
@ -548,8 +550,10 @@ class Blade(SDR):
|
|||
|
||||
if gain_mode == "relative":
|
||||
if gain > 0:
|
||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
|
||||
the gain relative to the maximum possible gain.")
|
||||
raise SDRParameterError(
|
||||
"When gain_mode = 'relative', gain must be < 0. This sets\
|
||||
the gain relative to the maximum possible gain."
|
||||
)
|
||||
else:
|
||||
abs_gain = tx_gain_max + gain
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -172,8 +172,10 @@ class HackRF(SDR):
|
|||
tx_gain_max = 47
|
||||
if gain_mode == "relative":
|
||||
if gain > 0:
|
||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This \
|
||||
sets the gain relative to the maximum possible gain.")
|
||||
raise SDRParameterError(
|
||||
"When gain_mode = 'relative', gain must be < 0. This \
|
||||
sets the gain relative to the maximum possible gain."
|
||||
)
|
||||
else:
|
||||
abs_gain = tx_gain_max + gain
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -274,16 +274,20 @@ class Pluto(SDR):
|
|||
data = [self._convert_tx_samples(samples), self._convert_tx_samples(samples)]
|
||||
else:
|
||||
if len(recording) > 2:
|
||||
warnings.warn("More recordings were provided than channels in the Pluto. \
|
||||
Only the first two recordings will be used")
|
||||
warnings.warn(
|
||||
"More recordings were provided than channels in the Pluto. \
|
||||
Only the first two recordings will be used"
|
||||
)
|
||||
sample0 = self._convert_tx_samples(recording.data[0])
|
||||
sample1 = self._convert_tx_samples(recording.data[1])
|
||||
data = [sample0, sample1]
|
||||
|
||||
elif isinstance(recording, list):
|
||||
if len(recording) > 2:
|
||||
warnings.warn("More recordings were provided than channels in the Pluto. \
|
||||
Only the first two recordings will be used")
|
||||
warnings.warn(
|
||||
"More recordings were provided than channels in the Pluto. \
|
||||
Only the first two recordings will be used"
|
||||
)
|
||||
|
||||
if isinstance(recording[0], np.ndarray):
|
||||
data = [self._convert_tx_samples(recording[0]), self._convert_tx_samples(recording[1])]
|
||||
|
|
@ -423,8 +427,10 @@ class Pluto(SDR):
|
|||
|
||||
if gain_mode == "relative":
|
||||
if gain > 0:
|
||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets \
|
||||
the gain relative to the maximum possible gain.")
|
||||
raise SDRParameterError(
|
||||
"When gain_mode = 'relative', gain must be < 0. This sets \
|
||||
the gain relative to the maximum possible gain."
|
||||
)
|
||||
else:
|
||||
abs_gain = rx_gain_max + gain
|
||||
else:
|
||||
|
|
@ -534,8 +540,10 @@ class Pluto(SDR):
|
|||
|
||||
if gain_mode == "relative":
|
||||
if gain > 0:
|
||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
|
||||
the gain relative to the maximum possible gain.")
|
||||
raise SDRParameterError(
|
||||
"When gain_mode = 'relative', gain must be < 0. This sets\
|
||||
the gain relative to the maximum possible gain."
|
||||
)
|
||||
else:
|
||||
abs_gain = tx_gain_max + gain
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -131,15 +131,19 @@ class RTLSDR(SDR):
|
|||
|
||||
if gain_mode == "relative":
|
||||
if gain > 0:
|
||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
|
||||
the gain relative to the maximum possible gain.")
|
||||
raise SDRParameterError(
|
||||
"When gain_mode = 'relative', gain must be < 0. This sets\
|
||||
the gain relative to the maximum possible gain."
|
||||
)
|
||||
target_gain = max_gain + gain
|
||||
else:
|
||||
target_gain = gain
|
||||
|
||||
if target_gain < min_gain or target_gain > max_gain:
|
||||
print(f"Requested gain {target_gain} dB out of range;\
|
||||
clamping to valid span {min_gain}-{max_gain} dB.")
|
||||
print(
|
||||
f"Requested gain {target_gain} dB out of range;\
|
||||
clamping to valid span {min_gain}-{max_gain} dB."
|
||||
)
|
||||
target_gain = min(max(target_gain, min_gain), max_gain)
|
||||
|
||||
target_gain = min(available_gains, key=lambda g: abs(g - target_gain))
|
||||
|
|
|
|||
|
|
@ -392,8 +392,10 @@ class ThinkRF(SDR):
|
|||
actual_sample_rate = self.BASE_SAMPLE_RATE / decimation
|
||||
|
||||
if abs(actual_sample_rate - requested_sample_rate) > 1e3: # More than 1 kHz difference
|
||||
print(f"ThinkRF: Requested {requested_sample_rate/1e6:.2f} MS/s → \
|
||||
Using decimation={decimation} ({actual_sample_rate/1e6:.2f} MS/s)")
|
||||
print(
|
||||
f"ThinkRF: Requested {requested_sample_rate/1e6:.2f} MS/s → \
|
||||
Using decimation={decimation} ({actual_sample_rate/1e6:.2f} MS/s)"
|
||||
)
|
||||
|
||||
return decimation, actual_sample_rate
|
||||
|
||||
|
|
|
|||
|
|
@ -148,8 +148,10 @@ class USRP(SDR):
|
|||
gain_range = self.usrp.get_rx_gain_range()
|
||||
if gain_mode == "relative":
|
||||
if gain > 0:
|
||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
|
||||
the gain relative to the maximum possible gain.")
|
||||
raise SDRParameterError(
|
||||
"When gain_mode = 'relative', gain must be < 0. This sets\
|
||||
the gain relative to the maximum possible gain."
|
||||
)
|
||||
else:
|
||||
# set gain relative to max
|
||||
abs_gain = gain_range.stop() + gain
|
||||
|
|
@ -354,8 +356,10 @@ class USRP(SDR):
|
|||
gain_range = self.usrp.get_tx_gain_range()
|
||||
if gain_mode == "relative":
|
||||
if gain > 0:
|
||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
|
||||
the gain relative to the maximum possible gain.")
|
||||
raise SDRParameterError(
|
||||
"When gain_mode = 'relative', gain must be < 0. This sets\
|
||||
the gain relative to the maximum possible gain."
|
||||
)
|
||||
else:
|
||||
# set gain relative to max
|
||||
abs_gain = gain_range.stop() + gain
|
||||
|
|
|
|||
5
src/ria_toolkit_oss/server/__init__.py
Normal file
5
src/ria_toolkit_oss/server/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
"""RT-OSS HTTP server for RIA Hub integration."""
|
||||
|
||||
from .app import create_app
|
||||
|
||||
__all__ = ["create_app"]
|
||||
48
src/ria_toolkit_oss/server/app.py
Normal file
48
src/ria_toolkit_oss/server/app.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""FastAPI application factory for the RT-OSS HTTP server."""
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
|
||||
from .auth import require_api_key
|
||||
from .routers import inference, orchestrator
|
||||
|
||||
|
||||
def create_app(api_key: str = "") -> FastAPI:
|
||||
"""Create and configure the RT-OSS FastAPI application.
|
||||
|
||||
Args:
|
||||
api_key: Secret key required in the ``X-API-Key`` request header.
|
||||
Pass an empty string to disable authentication (development only).
|
||||
|
||||
Returns:
|
||||
Configured FastAPI application instance.
|
||||
"""
|
||||
app = FastAPI(
|
||||
title="RIA Toolkit OSS Server",
|
||||
version="0.1.0",
|
||||
description=(
|
||||
"HTTP API for RT-OSS campaign orchestration and RF zone inference. "
|
||||
"All endpoints (except /health) require the X-API-Key header when "
|
||||
"an API key is configured."
|
||||
),
|
||||
)
|
||||
app.state.api_key = api_key
|
||||
|
||||
app.include_router(
|
||||
orchestrator.router,
|
||||
prefix="/orchestrator",
|
||||
tags=["Orchestrator"],
|
||||
dependencies=[Depends(require_api_key)],
|
||||
)
|
||||
app.include_router(
|
||||
inference.router,
|
||||
prefix="/inference",
|
||||
tags=["Inference"],
|
||||
dependencies=[Depends(require_api_key)],
|
||||
)
|
||||
|
||||
@app.get("/health", tags=["Health"])
|
||||
async def health():
|
||||
"""Health check — always returns 200."""
|
||||
return {"status": "ok"}
|
||||
|
||||
return app
|
||||
25
src/ria_toolkit_oss/server/auth.py
Normal file
25
src/ria_toolkit_oss/server/auth.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
"""API key authentication dependency."""
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
|
||||
async def require_api_key(
|
||||
request: Request,
|
||||
api_key: str | None = Depends(_api_key_header),
|
||||
) -> None:
|
||||
"""FastAPI dependency that enforces X-API-Key header authentication.
|
||||
|
||||
If no API key is configured on the server (empty string), all requests
|
||||
are allowed — this is intended for local development only.
|
||||
"""
|
||||
expected: str = request.app.state.api_key
|
||||
if not expected:
|
||||
return # dev mode: no key set, allow all
|
||||
if api_key != expected:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid or missing API key",
|
||||
)
|
||||
77
src/ria_toolkit_oss/server/models.py
Normal file
77
src/ria_toolkit_oss/server/models.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
"""Pydantic request and response models for the RT-OSS HTTP server."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Orchestrator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DeployRequest(BaseModel):
|
||||
config: dict
|
||||
|
||||
|
||||
class DeployResponse(BaseModel):
|
||||
campaign_id: str
|
||||
|
||||
|
||||
class CampaignStatusResponse(BaseModel):
|
||||
campaign_id: str
|
||||
status: str
|
||||
config_name: str
|
||||
progress: int
|
||||
total_steps: int
|
||||
started_at: float
|
||||
ended_at: float | None = None
|
||||
result: dict | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class CancelResponse(BaseModel):
|
||||
campaign_id: str
|
||||
cancelled: bool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inference
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SdrConfig(BaseModel):
|
||||
device: str
|
||||
center_freq: float
|
||||
sample_rate: float
|
||||
gain: float | str = "auto"
|
||||
|
||||
|
||||
class LoadModelRequest(BaseModel):
|
||||
model_path: str
|
||||
label_map: dict[str, int] # class_name -> class_index
|
||||
|
||||
|
||||
class LoadModelResponse(BaseModel):
|
||||
loaded: bool
|
||||
model_path: str
|
||||
num_classes: int
|
||||
|
||||
|
||||
class StartInferenceRequest(BaseModel):
|
||||
sdr_config: SdrConfig
|
||||
|
||||
|
||||
class StartInferenceResponse(BaseModel):
|
||||
running: bool
|
||||
|
||||
|
||||
class StopInferenceResponse(BaseModel):
|
||||
stopped: bool
|
||||
|
||||
|
||||
class InferenceStatusResponse(BaseModel):
|
||||
timestamp: float
|
||||
prediction: str
|
||||
confidence: float
|
||||
snr_db: float
|
||||
zone: str | None = None
|
||||
0
src/ria_toolkit_oss/server/routers/__init__.py
Normal file
0
src/ria_toolkit_oss/server/routers/__init__.py
Normal file
183
src/ria_toolkit_oss/server/routers/inference.py
Normal file
183
src/ria_toolkit_oss/server/routers/inference.py
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
"""Inference routes: model loading, inference loop control, and status polling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from scipy.special import softmax
|
||||
|
||||
from ..models import (
|
||||
InferenceStatusResponse,
|
||||
LoadModelRequest,
|
||||
LoadModelResponse,
|
||||
StartInferenceRequest,
|
||||
StartInferenceResponse,
|
||||
StopInferenceResponse,
|
||||
)
|
||||
from ..state import InferenceState, get_inference, set_inference
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_INFERENCE_NUM_SAMPLES = 4096
|
||||
|
||||
|
||||
def _load_onnx_session(model_path: str):
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
except ImportError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="onnxruntime not installed. Install with: pip install ria-toolkit-oss[server]",
|
||||
)
|
||||
try:
|
||||
return ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"Failed to load ONNX model: {e}")
|
||||
|
||||
|
||||
def _preprocess_samples(samples: np.ndarray, expected_shape: tuple) -> np.ndarray:
|
||||
"""Reshape complex IQ samples to float32 matching the model's expected input.
|
||||
|
||||
Supports ``(batch, 2*N)`` interleaved and ``(batch, 2, N)`` two-channel conventions.
|
||||
"""
|
||||
iq = samples.astype(np.complex64)
|
||||
i_ch, q_ch = iq.real, iq.imag
|
||||
|
||||
if len(expected_shape) == 2:
|
||||
n = expected_shape[1] // 2
|
||||
interleaved = np.empty(expected_shape[1], dtype=np.float32)
|
||||
interleaved[0::2] = i_ch[:n]
|
||||
interleaved[1::2] = q_ch[:n]
|
||||
return interleaved.reshape(1, -1)
|
||||
elif len(expected_shape) == 3:
|
||||
n = expected_shape[2]
|
||||
return np.stack([i_ch[:n], q_ch[:n]], axis=0).astype(np.float32).reshape(1, 2, n)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model input shape: {expected_shape}")
|
||||
|
||||
|
||||
def _stop_current_inference(state: InferenceState, timeout: float = 5.0) -> None:
|
||||
state.stop_event.set()
|
||||
if state.thread and state.thread.is_alive():
|
||||
state.thread.join(timeout=timeout)
|
||||
if state.thread.is_alive():
|
||||
logger.warning("Inference thread did not stop within %.1fs; SDR resources may not be released", timeout)
|
||||
|
||||
|
||||
def _inference_loop(state: InferenceState, sdr) -> None:
|
||||
from ria_toolkit_oss.orchestration.qa import estimate_snr_db
|
||||
|
||||
session = state.session
|
||||
input_name = session.get_inputs()[0].name
|
||||
expected_shape = tuple(
|
||||
d if isinstance(d, int) and d > 0 else _INFERENCE_NUM_SAMPLES for d in session.get_inputs()[0].shape
|
||||
)
|
||||
|
||||
try:
|
||||
while not state.stop_event.is_set():
|
||||
recording = sdr.record(num_samples=_INFERENCE_NUM_SAMPLES)
|
||||
samples = recording.data[0] if recording.data.ndim > 1 else recording.data
|
||||
snr_db = estimate_snr_db(samples)
|
||||
|
||||
try:
|
||||
model_input = _preprocess_samples(samples, expected_shape)
|
||||
logits = session.run(None, {input_name: model_input})[0][0].astype(np.float32)
|
||||
probs = softmax(logits)
|
||||
pred_idx = int(np.argmax(probs))
|
||||
prediction = state.index_to_label.get(pred_idx, str(pred_idx))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
state.set_latest(
|
||||
{
|
||||
"timestamp": time.time(),
|
||||
"prediction": prediction,
|
||||
"confidence": round(float(probs[pred_idx]), 4),
|
||||
"snr_db": round(snr_db, 2),
|
||||
"zone": None,
|
||||
}
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
sdr.close()
|
||||
except Exception:
|
||||
pass
|
||||
state.running = False
|
||||
|
||||
|
||||
@router.post("/load", response_model=LoadModelResponse)
|
||||
async def load_model(request: LoadModelRequest):
|
||||
"""Load an ONNX model. Stops any running inference first.
|
||||
``label_map`` maps class names to integer indices (e.g. ``{"zone_a": 0}``).
|
||||
"""
|
||||
existing = get_inference()
|
||||
if existing and existing.running:
|
||||
_stop_current_inference(existing)
|
||||
|
||||
session = _load_onnx_session(request.model_path)
|
||||
set_inference(
|
||||
InferenceState(
|
||||
model_path=request.model_path,
|
||||
label_map=request.label_map,
|
||||
index_to_label={v: k for k, v in request.label_map.items()},
|
||||
session=session,
|
||||
)
|
||||
)
|
||||
return LoadModelResponse(loaded=True, model_path=request.model_path, num_classes=len(request.label_map))
|
||||
|
||||
|
||||
@router.post("/start", response_model=StartInferenceResponse)
|
||||
async def start_inference(request: StartInferenceRequest):
|
||||
"""Start continuous inference. Requires a model to be loaded first."""
|
||||
state = get_inference()
|
||||
if not state:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT, detail="No model loaded. Call POST /inference/load first."
|
||||
)
|
||||
if state.running:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Inference is already running.")
|
||||
|
||||
try:
|
||||
from ria_toolkit_oss.orchestration.executor import _DEVICE_ALIASES
|
||||
from ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
|
||||
except ImportError as e:
|
||||
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"SDR import failed: {e}")
|
||||
|
||||
sdr_cfg = request.sdr_config
|
||||
try:
|
||||
sdr = get_sdr_device(_DEVICE_ALIASES.get(sdr_cfg.device.lower(), sdr_cfg.device.lower()))
|
||||
gain = None if sdr_cfg.gain == "auto" else float(sdr_cfg.gain)
|
||||
sdr.init_rx(sample_rate=sdr_cfg.sample_rate, center_frequency=sdr_cfg.center_freq, gain=gain, channel=0)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"SDR initialisation failed: {e}")
|
||||
|
||||
state.stop_event.clear()
|
||||
state.running = True
|
||||
state.thread = threading.Thread(target=_inference_loop, args=(state, sdr), daemon=True)
|
||||
state.thread.start()
|
||||
return StartInferenceResponse(running=True)
|
||||
|
||||
|
||||
@router.post("/stop", response_model=StopInferenceResponse)
|
||||
async def stop_inference():
|
||||
"""Stop the running inference loop."""
|
||||
state = get_inference()
|
||||
if not state or not state.running:
|
||||
return StopInferenceResponse(stopped=False)
|
||||
_stop_current_inference(state)
|
||||
return StopInferenceResponse(stopped=True)
|
||||
|
||||
|
||||
@router.get("/status", response_model=InferenceStatusResponse | None)
|
||||
async def inference_status():
|
||||
"""Return the latest inference result, or null if none available yet."""
|
||||
state = get_inference()
|
||||
if not state:
|
||||
return None
|
||||
latest = state.get_latest()
|
||||
return InferenceStatusResponse(**latest) if latest else None
|
||||
102
src/ria_toolkit_oss/server/routers/orchestrator.py
Normal file
102
src/ria_toolkit_oss/server/routers/orchestrator.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""Orchestrator routes: campaign deployment, status, and cancellation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
|
||||
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
|
||||
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
|
||||
|
||||
from ..models import (
|
||||
CampaignStatusResponse,
|
||||
CancelResponse,
|
||||
DeployRequest,
|
||||
DeployResponse,
|
||||
)
|
||||
from ..state import (
|
||||
CampaignCancelledError,
|
||||
CampaignState,
|
||||
get_campaign,
|
||||
set_campaign,
|
||||
update_campaign,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _make_progress_cb(campaign_id: str, cancel_event: threading.Event):
|
||||
def cb(step_index: int, total_steps: int, step_result: Any) -> None:
|
||||
update_campaign(campaign_id, progress=step_index)
|
||||
if cancel_event.is_set():
|
||||
raise CampaignCancelledError(f"Cancelled at step {step_index}/{total_steps}")
|
||||
|
||||
return cb
|
||||
|
||||
|
||||
def _run_campaign_thread(campaign_id: str, cfg: CampaignConfig) -> None:
|
||||
state = get_campaign(campaign_id)
|
||||
try:
|
||||
result = CampaignExecutor(
|
||||
config=cfg,
|
||||
progress_cb=_make_progress_cb(campaign_id, state.cancel_event),
|
||||
).run()
|
||||
update_campaign(
|
||||
campaign_id, status="completed", progress=cfg.total_steps(), result=result.to_dict(), ended_at=time.time()
|
||||
)
|
||||
except CampaignCancelledError:
|
||||
update_campaign(campaign_id, status="cancelled", ended_at=time.time())
|
||||
except Exception as e:
|
||||
update_campaign(campaign_id, status="failed", error=str(e), ended_at=time.time())
|
||||
|
||||
|
||||
@router.post("/deploy", response_model=DeployResponse)
|
||||
async def deploy(request: DeployRequest):
|
||||
"""Deploy a campaign config and start execution. Returns a ``campaign_id`` for polling.
|
||||
Cancellation takes effect at step boundaries, not mid-capture.
|
||||
"""
|
||||
try:
|
||||
cfg = CampaignConfig.from_dict(request.config)
|
||||
except (ValueError, KeyError) as e:
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))
|
||||
|
||||
campaign_id = str(uuid.uuid4())
|
||||
cancel_event = threading.Event()
|
||||
thread = threading.Thread(target=_run_campaign_thread, args=(campaign_id, cfg), daemon=True)
|
||||
set_campaign(
|
||||
CampaignState(
|
||||
campaign_id=campaign_id,
|
||||
status="running",
|
||||
config_name=cfg.name,
|
||||
cancel_event=cancel_event,
|
||||
thread=thread,
|
||||
total_steps=cfg.total_steps(),
|
||||
)
|
||||
)
|
||||
thread.start()
|
||||
return DeployResponse(campaign_id=campaign_id)
|
||||
|
||||
|
||||
@router.get("/status/{campaign_id}", response_model=CampaignStatusResponse)
|
||||
async def get_status(campaign_id: str):
|
||||
"""Get the status and progress of a deployed campaign."""
|
||||
state = get_campaign(campaign_id)
|
||||
if not state:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Campaign {campaign_id!r} not found")
|
||||
return CampaignStatusResponse(**state.to_dict())
|
||||
|
||||
|
||||
@router.post("/cancel/{campaign_id}", response_model=CancelResponse)
|
||||
async def cancel(campaign_id: str):
|
||||
"""Request cancellation. Takes effect at the next step boundary."""
|
||||
state = get_campaign(campaign_id)
|
||||
if not state:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Campaign {campaign_id!r} not found")
|
||||
if state.status != "running":
|
||||
return CancelResponse(campaign_id=campaign_id, cancelled=False)
|
||||
state.cancel_event.set()
|
||||
return CancelResponse(campaign_id=campaign_id, cancelled=True)
|
||||
101
src/ria_toolkit_oss/server/state.py
Normal file
101
src/ria_toolkit_oss/server/state.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
"""In-memory state for running campaigns and inference sessions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class CampaignCancelledError(Exception):
|
||||
"""Raised by the progress callback when a cancel is requested."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CampaignState:
|
||||
campaign_id: str
|
||||
status: str # "running" | "completed" | "failed" | "cancelled"
|
||||
config_name: str
|
||||
cancel_event: threading.Event
|
||||
thread: threading.Thread
|
||||
total_steps: int = 0
|
||||
progress: int = 0
|
||||
result: Optional[dict] = None
|
||||
error: Optional[str] = None
|
||||
started_at: float = field(default_factory=time.time)
|
||||
ended_at: Optional[float] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"campaign_id": self.campaign_id,
|
||||
"status": self.status,
|
||||
"config_name": self.config_name,
|
||||
"progress": self.progress,
|
||||
"total_steps": self.total_steps,
|
||||
"result": self.result,
|
||||
"error": self.error,
|
||||
"started_at": self.started_at,
|
||||
"ended_at": self.ended_at,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceState:
|
||||
model_path: str
|
||||
label_map: dict[str, int] # class_name -> class_index
|
||||
index_to_label: dict[int, str] # reverse: class_index -> class_name
|
||||
session: Any # onnxruntime.InferenceSession
|
||||
stop_event: threading.Event = field(default_factory=threading.Event)
|
||||
thread: Optional[threading.Thread] = None
|
||||
running: bool = False
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
|
||||
_latest: Optional[dict] = field(default=None, repr=False)
|
||||
|
||||
def set_latest(self, result: dict) -> None:
|
||||
with self._lock:
|
||||
self._latest = result
|
||||
|
||||
def get_latest(self) -> Optional[dict]:
|
||||
with self._lock:
|
||||
return self._latest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level stores
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_campaigns: dict[str, CampaignState] = {}
|
||||
_campaigns_lock = threading.Lock()
|
||||
|
||||
_inference: Optional[InferenceState] = None
|
||||
_inference_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_campaign(campaign_id: str) -> Optional[CampaignState]:
|
||||
with _campaigns_lock:
|
||||
return _campaigns.get(campaign_id)
|
||||
|
||||
|
||||
def set_campaign(state: CampaignState) -> None:
|
||||
with _campaigns_lock:
|
||||
_campaigns[state.campaign_id] = state
|
||||
|
||||
|
||||
def update_campaign(campaign_id: str, **kwargs) -> None:
|
||||
with _campaigns_lock:
|
||||
state = _campaigns.get(campaign_id)
|
||||
if state:
|
||||
for k, v in kwargs.items():
|
||||
setattr(state, k, v)
|
||||
|
||||
|
||||
def get_inference() -> Optional[InferenceState]:
|
||||
with _inference_lock:
|
||||
return _inference
|
||||
|
||||
|
||||
def set_inference(state: Optional[InferenceState]) -> None:
|
||||
global _inference
|
||||
with _inference_lock:
|
||||
_inference = state
|
||||
|
|
@ -37,8 +37,10 @@ class Add(RecordableBlock, ProcessBlock):
|
|||
|
||||
samples = block.get_samples(num_samples)
|
||||
if len(samples) != num_samples:
|
||||
raise ValueError(f"Block {self.__class__.__name__} requested {num_samples} \
|
||||
from block {block.__class__.__name__} but got {len(samples)}.")
|
||||
raise ValueError(
|
||||
f"Block {self.__class__.__name__} requested {num_samples} \
|
||||
from block {block.__class__.__name__} but got {len(samples)}."
|
||||
)
|
||||
|
||||
return samples
|
||||
|
||||
|
|
|
|||
|
|
@ -23,9 +23,11 @@ class ProcessBlock(Block, ABC):
|
|||
)
|
||||
|
||||
elif not all(isinstance(item, Block) for item in input):
|
||||
raise ValueError(f"Invalid input to block '{self.__class__.__name__}'. \
|
||||
raise ValueError(
|
||||
f"Invalid input to block '{self.__class__.__name__}'. \
|
||||
Expected a list of Block objects but got \
|
||||
{'[' + ',' .join(f'{item.__class__.__name__}({repr(item)})' for item in input) + ']'}")
|
||||
{'[' + ',' .join(f'{item.__class__.__name__}({repr(item)})' for item in input) + ']'}"
|
||||
)
|
||||
|
||||
elif len(input) != len(self.input_type):
|
||||
raise ValueError(
|
||||
|
|
|
|||
|
|
@ -20,8 +20,10 @@ class RecordableBlock(Block, Recordable):
|
|||
:raises ValueError: If the number of samples is incorrect."""
|
||||
samples = self.get_samples(num_samples)
|
||||
if len(samples) != num_samples:
|
||||
raise ValueError(f"Error in block {self.__class__.__name__} record(). \
|
||||
Requested {num_samples} samples but got {len(samples)}")
|
||||
raise ValueError(
|
||||
f"Error in block {self.__class__.__name__} record(). \
|
||||
Requested {num_samples} samples but got {len(samples)}"
|
||||
)
|
||||
metadata = self._get_metadata()
|
||||
return Recording(data=samples, metadata=metadata)
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,9 @@ class RecordingSource(SourceBlock, RecordableBlock):
|
|||
:raises ValueError: If num_samples is greater than the recording length.
|
||||
"""
|
||||
if num_samples - 1 >= self.recording.data.shape[1]:
|
||||
raise ValueError(f"{num_samples} samples requested from recording source with \
|
||||
{self.recording.data.shape[1]} samples available.")
|
||||
raise ValueError(
|
||||
f"{num_samples} samples requested from recording source with \
|
||||
{self.recording.data.shape[1]} samples available."
|
||||
)
|
||||
|
||||
return self.recording.data[0, 0:num_samples]
|
||||
|
|
|
|||
|
|
@ -610,8 +610,10 @@ def cut_out( # noqa: C901 # TODO: Simplify function
|
|||
raise ValueError("signal must be CxN complex.")
|
||||
|
||||
if fill_type not in {"zeros", "ones", "low-snr", "avg-snr", "high-snr"}:
|
||||
raise UserWarning("""fill_type must be "zeros", "ones", "low-snr", "avg-snr", or "high-snr",
|
||||
"ones" has been selected by default""")
|
||||
raise UserWarning(
|
||||
"""fill_type must be "zeros", "ones", "low-snr", "avg-snr", or "high-snr",
|
||||
"ones" has been selected by default"""
|
||||
)
|
||||
|
||||
if max_section_size < 1 or max_section_size >= n:
|
||||
raise ValueError("max_section_size must be at least 1 and must be less than the length of signal.")
|
||||
|
|
|
|||
Binary file not shown.
|
Before Width: | Height: | Size: 90 KiB After Width: | Height: | Size: 130 B |
Binary file not shown.
|
Before Width: | Height: | Size: 19 KiB After Width: | Height: | Size: 130 B |
258
src/ria_toolkit_oss_cli/ria_toolkit_oss/campaign.py
Normal file
258
src/ria_toolkit_oss_cli/ria_toolkit_oss/campaign.py
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
"""Campaign orchestration CLI commands.
|
||||
|
||||
Usage examples::
|
||||
|
||||
# Enroll a single device using a device profile YAML (App 1 workflow)
|
||||
ria campaign enroll --config iphone13.yml
|
||||
|
||||
# Run a full custom campaign config
|
||||
ria campaign run --config my_campaign.yml
|
||||
|
||||
# Validate a config file without running it
|
||||
ria campaign validate --config my_campaign.yml
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
import click
|
||||
|
||||
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
|
||||
from ria_toolkit_oss.orchestration.executor import CampaignExecutor, StepResult
|
||||
|
||||
|
||||
@click.group()
|
||||
def campaign():
|
||||
"""Orchestrate automated RF capture campaigns."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ria campaign validate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@campaign.command()
|
||||
@click.option(
|
||||
"--config",
|
||||
"-c",
|
||||
required=True,
|
||||
type=click.Path(exists=True),
|
||||
help="Campaign YAML config or device profile YAML.",
|
||||
)
|
||||
@click.option(
|
||||
"--profile",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Parse as a device profile (App 1 style) rather than a full campaign config.",
|
||||
)
|
||||
def validate(config, profile):
|
||||
"""Validate a campaign or device profile YAML without running it.
|
||||
|
||||
\b
|
||||
Examples:
|
||||
ria campaign validate --config iphone13.yml --profile
|
||||
ria campaign validate --config campaign.yml
|
||||
"""
|
||||
try:
|
||||
if profile:
|
||||
cfg = CampaignConfig.from_device_profile(config)
|
||||
else:
|
||||
cfg = CampaignConfig.from_yaml(config)
|
||||
except (FileNotFoundError, ValueError, KeyError) as e:
|
||||
raise click.ClickException(str(e))
|
||||
|
||||
click.echo(click.style("✓ Config valid", fg="green", bold=True))
|
||||
click.echo(f" Campaign name : {cfg.name}")
|
||||
click.echo(f" Mode : {cfg.mode}")
|
||||
click.echo(f" Transmitters : {len(cfg.transmitters)}")
|
||||
click.echo(f" Total steps : {cfg.total_steps()}")
|
||||
click.echo(f" Capture time : {cfg.total_capture_time_s():.0f}s")
|
||||
click.echo(f" Recorder : {cfg.recorder.device} @ {cfg.recorder.center_freq/1e6:.2f} MHz")
|
||||
click.echo(f" Sample rate : {cfg.recorder.sample_rate/1e6:.1f} MS/s")
|
||||
click.echo(f" Output path : {cfg.output.path}")
|
||||
click.echo()
|
||||
|
||||
for tx in cfg.transmitters:
|
||||
click.echo(f" Transmitter: {tx.id} ({tx.type}, {tx.control_method}, {len(tx.schedule)} steps)")
|
||||
for step in tx.schedule:
|
||||
extras = []
|
||||
if step.channel is not None:
|
||||
extras.append(f"ch={step.channel}")
|
||||
if step.bandwidth_mhz is not None:
|
||||
extras.append(f"{int(step.bandwidth_mhz)}MHz")
|
||||
if step.traffic:
|
||||
extras.append(step.traffic)
|
||||
suffix = f" [{', '.join(extras)}]" if extras else ""
|
||||
click.echo(f" [{step.duration:.0f}s] {step.label}{suffix}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ria campaign enroll
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@campaign.command()
|
||||
@click.option(
|
||||
"--config",
|
||||
"-c",
|
||||
required=True,
|
||||
type=click.Path(exists=True),
|
||||
help="Device profile YAML (App 1 enrollment format).",
|
||||
)
|
||||
@click.option(
|
||||
"--output",
|
||||
"-o",
|
||||
default=None,
|
||||
help="Override output directory from config.",
|
||||
)
|
||||
@click.option(
|
||||
"--report",
|
||||
default="qa_report.json",
|
||||
show_default=True,
|
||||
help="Path for the JSON QA report.",
|
||||
)
|
||||
@click.option("--verbose", "-v", is_flag=True, help="Verbose output.")
|
||||
@click.option("--dry-run", is_flag=True, help="Parse and validate config, then exit.")
|
||||
def enroll(config, output, report, verbose, dry_run):
|
||||
"""Enroll a single device by running its capture profile.
|
||||
|
||||
Parses a device profile YAML (App 1 format), generates a capture
|
||||
campaign, and runs it. Outputs labelled SigMF recordings and a
|
||||
JSON QA report.
|
||||
|
||||
\b
|
||||
Examples:
|
||||
ria campaign enroll --config iphone13.yml
|
||||
ria campaign enroll --config airpods.yml --output ./my_recordings
|
||||
ria campaign enroll --config iphone13.yml --dry-run
|
||||
"""
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(config)
|
||||
except (FileNotFoundError, ValueError, KeyError) as e:
|
||||
raise click.ClickException(str(e))
|
||||
|
||||
if output:
|
||||
cfg.output.path = output
|
||||
|
||||
_print_campaign_summary(cfg)
|
||||
|
||||
if dry_run:
|
||||
click.echo(click.style("Dry run — exiting before capture.", fg="yellow"))
|
||||
return
|
||||
|
||||
result = _run_campaign(cfg, verbose=verbose)
|
||||
result.write_report(report)
|
||||
|
||||
_print_result_summary(result, report)
|
||||
sys.exit(0 if result.failed == 0 else 1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ria campaign run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@campaign.command()
|
||||
@click.option(
|
||||
"--config",
|
||||
"-c",
|
||||
required=True,
|
||||
type=click.Path(exists=True),
|
||||
help="Full campaign YAML config.",
|
||||
)
|
||||
@click.option(
|
||||
"--output",
|
||||
"-o",
|
||||
default=None,
|
||||
help="Override output directory from config.",
|
||||
)
|
||||
@click.option(
|
||||
"--report",
|
||||
default="qa_report.json",
|
||||
show_default=True,
|
||||
help="Path for the JSON QA report.",
|
||||
)
|
||||
@click.option("--verbose", "-v", is_flag=True, help="Verbose output.")
|
||||
@click.option("--dry-run", is_flag=True, help="Parse and validate config, then exit.")
|
||||
def run(config, output, report, verbose, dry_run):
|
||||
"""Run a full campaign from a campaign config YAML.
|
||||
|
||||
\b
|
||||
Examples:
|
||||
ria campaign run --config wifi_capture.yml
|
||||
ria campaign run --config campaign.yml --output ./data --dry-run
|
||||
"""
|
||||
try:
|
||||
cfg = CampaignConfig.from_yaml(config)
|
||||
except (FileNotFoundError, ValueError, KeyError) as e:
|
||||
raise click.ClickException(str(e))
|
||||
|
||||
if output:
|
||||
cfg.output.path = output
|
||||
|
||||
_print_campaign_summary(cfg)
|
||||
|
||||
if dry_run:
|
||||
click.echo(click.style("Dry run — exiting before capture.", fg="yellow"))
|
||||
return
|
||||
|
||||
result = _run_campaign(cfg, verbose=verbose)
|
||||
result.write_report(report)
|
||||
|
||||
_print_result_summary(result, report)
|
||||
sys.exit(0 if result.failed == 0 else 1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _print_campaign_summary(cfg: CampaignConfig) -> None:
|
||||
click.echo()
|
||||
click.echo(click.style(f"Campaign: {cfg.name}", bold=True))
|
||||
click.echo(f" Transmitters : {len(cfg.transmitters)}")
|
||||
click.echo(f" Total steps : {cfg.total_steps()}")
|
||||
click.echo(f" Capture time : ~{cfg.total_capture_time_s():.0f}s")
|
||||
click.echo(f" Output : {cfg.output.path}")
|
||||
click.echo()
|
||||
|
||||
|
||||
def _make_progress_cb(total: int):
|
||||
"""Return a progress callback that prints step results to stderr."""
|
||||
|
||||
def cb(idx: int, _total: int, step: StepResult) -> None:
|
||||
status = (
|
||||
click.style("✓", fg="green")
|
||||
if step.ok
|
||||
else (click.style("⚑", fg="yellow") if step.qa.flagged else click.style("✗", fg="red"))
|
||||
)
|
||||
snr_str = f"SNR {step.qa.snr_db:.1f} dB" if not step.error else f"ERROR: {step.error}"
|
||||
click.echo(
|
||||
f" [{idx:>3}/{_total}] {status} {step.transmitter_id}/{step.step_label} — {snr_str}",
|
||||
err=True,
|
||||
)
|
||||
|
||||
return cb
|
||||
|
||||
|
||||
def _run_campaign(cfg: CampaignConfig, verbose: bool = False):
|
||||
executor = CampaignExecutor(
|
||||
config=cfg,
|
||||
progress_cb=_make_progress_cb(cfg.total_steps()),
|
||||
verbose=verbose,
|
||||
)
|
||||
return executor.run()
|
||||
|
||||
|
||||
def _print_result_summary(result, report_path: str) -> None:
|
||||
click.echo()
|
||||
click.echo(click.style("Campaign complete", bold=True))
|
||||
click.echo(f" Steps : {result.total_steps}")
|
||||
click.echo(f" Passed : {click.style(str(result.passed), fg='green')}")
|
||||
if result.flagged:
|
||||
click.echo(f" Flagged : {click.style(str(result.flagged), fg='yellow')} (review required)")
|
||||
if result.failed:
|
||||
click.echo(f" Failed : {click.style(str(result.failed), fg='red')}")
|
||||
click.echo(f" Duration: {result.duration_s:.0f}s")
|
||||
click.echo(f" Report : {report_path}")
|
||||
click.echo()
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
This module contains all the CLI bindings for the ria package.
|
||||
"""
|
||||
|
||||
from .campaign import campaign
|
||||
from .capture import capture
|
||||
from .combine import combine
|
||||
from .convert import convert
|
||||
|
|
@ -13,6 +14,7 @@ from .generate import generate
|
|||
|
||||
# from .generate import generate
|
||||
from .init import init
|
||||
from .serve import serve
|
||||
from .split import split
|
||||
from .transform import transform
|
||||
from .transmit import transmit
|
||||
|
|
|
|||
51
src/ria_toolkit_oss_cli/ria_toolkit_oss/serve.py
Normal file
51
src/ria_toolkit_oss_cli/ria_toolkit_oss/serve.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
"""``ria serve`` — start the RT-OSS HTTP server for RIA Hub integration."""
|
||||
|
||||
import click
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--host", default="0.0.0.0", show_default=True, help="Bind host.")
|
||||
@click.option("--port", default=8080, show_default=True, type=int, help="Bind port.")
|
||||
@click.option(
|
||||
"--api-key",
|
||||
envvar="RT_OSS_API_KEY",
|
||||
default="",
|
||||
help="Required X-API-Key value. Also reads RT_OSS_API_KEY. Empty = no auth (dev only).",
|
||||
)
|
||||
@click.option(
|
||||
"--log-level",
|
||||
default="info",
|
||||
show_default=True,
|
||||
type=click.Choice(["debug", "info", "warning", "error"], case_sensitive=False),
|
||||
)
|
||||
def serve(host: str, port: int, api_key: str, log_level: str):
|
||||
"""Start the RT-OSS HTTP server.
|
||||
|
||||
\b
|
||||
Endpoints:
|
||||
POST /orchestrator/deploy
|
||||
GET /orchestrator/status/{campaign_id}
|
||||
POST /orchestrator/cancel/{campaign_id}
|
||||
POST /inference/load
|
||||
POST /inference/start
|
||||
POST /inference/stop
|
||||
GET /inference/status
|
||||
GET /health
|
||||
"""
|
||||
try:
|
||||
import uvicorn
|
||||
|
||||
from ria_toolkit_oss.server.app import create_app
|
||||
except ImportError as e:
|
||||
raise click.ClickException(
|
||||
f"Server dependencies missing: {e}\nInstall with: pip install ria-toolkit-oss[server]"
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
click.echo(
|
||||
click.style("Warning: ", fg="yellow", bold=True) + "no API key set — all requests unauthenticated.",
|
||||
err=True,
|
||||
)
|
||||
|
||||
click.echo(f"Starting RT-OSS server on http://{host}:{port}")
|
||||
uvicorn.run(create_app(api_key=api_key), host=host, port=port, log_level=log_level.lower())
|
||||
0
tests/orchestration/__init__.py
Normal file
0
tests/orchestration/__init__.py
Normal file
489
tests/orchestration/test_campaign.py
Normal file
489
tests/orchestration/test_campaign.py
Normal file
|
|
@ -0,0 +1,489 @@
|
|||
"""Tests for orchestration campaign schema and YAML parsing."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from ria_toolkit_oss.orchestration.campaign import (
|
||||
CampaignConfig,
|
||||
CaptureStep,
|
||||
QAConfig,
|
||||
RecorderConfig,
|
||||
parse_bandwidth_mhz,
|
||||
parse_duration,
|
||||
parse_frequency,
|
||||
parse_gain,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_duration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseDuration:
|
||||
def test_seconds_suffix(self):
|
||||
assert parse_duration("30s") == 30.0
|
||||
|
||||
def test_seconds_suffix_long(self):
|
||||
assert parse_duration("30sec") == 30.0
|
||||
|
||||
def test_minutes_suffix(self):
|
||||
assert parse_duration("1.5m") == 90.0
|
||||
|
||||
def test_minutes_suffix_long(self):
|
||||
assert parse_duration("2min") == 120.0
|
||||
|
||||
def test_hours_suffix(self):
|
||||
assert parse_duration("2h") == 7200.0
|
||||
|
||||
def test_hours_suffix_long(self):
|
||||
assert parse_duration("1hr") == 3600.0
|
||||
|
||||
def test_numeric_int(self):
|
||||
assert parse_duration(45) == 45.0
|
||||
|
||||
def test_numeric_float(self):
|
||||
assert parse_duration(1.5) == 1.5
|
||||
|
||||
def test_bare_number_string(self):
|
||||
# No unit → treated as seconds
|
||||
assert parse_duration("60") == 60.0
|
||||
|
||||
def test_invalid_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
parse_duration("two minutes")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_frequency
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseFrequency:
|
||||
def test_ghz(self):
|
||||
assert parse_frequency("2.45GHz") == pytest.approx(2.45e9)
|
||||
|
||||
def test_mhz(self):
|
||||
assert parse_frequency("40MHz") == pytest.approx(40e6)
|
||||
|
||||
def test_khz(self):
|
||||
assert parse_frequency("433k") == pytest.approx(433e3)
|
||||
|
||||
def test_scientific_notation_string(self):
|
||||
assert parse_frequency("915e6") == pytest.approx(915e6)
|
||||
|
||||
def test_numeric_float(self):
|
||||
assert parse_frequency(2.45e9) == pytest.approx(2.45e9)
|
||||
|
||||
def test_numeric_int(self):
|
||||
assert parse_frequency(1000000) == pytest.approx(1e6)
|
||||
|
||||
def test_hz_suffix_optional(self):
|
||||
# "40M" and "40MHz" should both work
|
||||
assert parse_frequency("40M") == pytest.approx(40e6)
|
||||
assert parse_frequency("40MHz") == pytest.approx(40e6)
|
||||
|
||||
def test_invalid_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
parse_frequency("two point four gigs")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_gain
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseGain:
|
||||
def test_db_suffix(self):
|
||||
assert parse_gain("40dB") == pytest.approx(40.0)
|
||||
|
||||
def test_db_suffix_lowercase(self):
|
||||
assert parse_gain("32db") == pytest.approx(32.0)
|
||||
|
||||
def test_auto(self):
|
||||
assert parse_gain("auto") == "auto"
|
||||
|
||||
def test_auto_case_insensitive(self):
|
||||
assert parse_gain("AUTO") == "auto"
|
||||
|
||||
def test_numeric_int(self):
|
||||
assert parse_gain(32) == pytest.approx(32.0)
|
||||
|
||||
def test_numeric_float(self):
|
||||
assert parse_gain(32.5) == pytest.approx(32.5)
|
||||
|
||||
def test_invalid_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
parse_gain("high")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_bandwidth_mhz
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseBandwidthMhz:
|
||||
def test_mhz_suffix(self):
|
||||
assert parse_bandwidth_mhz("20MHz") == pytest.approx(20.0)
|
||||
|
||||
def test_numeric(self):
|
||||
assert parse_bandwidth_mhz(40) == pytest.approx(40.0)
|
||||
|
||||
def test_none(self):
|
||||
assert parse_bandwidth_mhz(None) is None
|
||||
|
||||
def test_invalid_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
parse_bandwidth_mhz("wide")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CaptureStep.from_dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCaptureStep:
|
||||
def test_wifi_step_auto_label(self):
|
||||
d = {"channel": 6, "bandwidth": "20MHz", "traffic": "iperf_udp", "duration": "30s"}
|
||||
step = CaptureStep.from_dict(d)
|
||||
assert step.duration == 30.0
|
||||
assert step.channel == 6
|
||||
assert step.bandwidth_mhz == 20.0
|
||||
assert step.traffic == "iperf_udp"
|
||||
assert step.label == "ch06_20mhz_iperf_udp"
|
||||
|
||||
def test_explicit_label(self):
|
||||
d = {"channel": 1, "bandwidth": "20MHz", "traffic": "idle", "duration": "30s", "label": "my_label"}
|
||||
step = CaptureStep.from_dict(d)
|
||||
assert step.label == "my_label"
|
||||
|
||||
def test_fallback_label(self):
|
||||
# No channel/bandwidth/traffic → label falls back to "capture"
|
||||
d = {"duration": "10s"}
|
||||
step = CaptureStep.from_dict(d)
|
||||
assert step.label == "capture"
|
||||
|
||||
def test_power_parsed(self):
|
||||
d = {"channel": 6, "bandwidth": "20MHz", "traffic": "idle", "duration": "30s", "power": "15dBm"}
|
||||
step = CaptureStep.from_dict(d)
|
||||
assert step.power_dbm == pytest.approx(15.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RecorderConfig.from_dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecorderConfig:
|
||||
def test_basic(self):
|
||||
d = {"device": "usrp_b210", "center_freq": "2.45GHz", "sample_rate": "40MHz", "gain": "40dB"}
|
||||
rec = RecorderConfig.from_dict(d)
|
||||
assert rec.device == "usrp_b210"
|
||||
assert rec.center_freq == pytest.approx(2.45e9)
|
||||
assert rec.sample_rate == pytest.approx(40e6)
|
||||
assert rec.gain == pytest.approx(40.0)
|
||||
assert rec.bandwidth is None
|
||||
|
||||
def test_auto_gain(self):
|
||||
d = {"device": "pluto", "center_freq": "2.45GHz", "sample_rate": "20MHz", "gain": "auto"}
|
||||
rec = RecorderConfig.from_dict(d)
|
||||
assert rec.gain == "auto"
|
||||
|
||||
def test_bandwidth_set(self):
|
||||
d = {"device": "pluto", "center_freq": "2.45GHz", "sample_rate": "20MHz", "gain": 32, "bandwidth": "20MHz"}
|
||||
rec = RecorderConfig.from_dict(d)
|
||||
assert rec.bandwidth == pytest.approx(20e6)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QAConfig.from_dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQAConfig:
|
||||
def test_defaults(self):
|
||||
qa = QAConfig.from_dict({})
|
||||
assert qa.snr_threshold_db == pytest.approx(10.0)
|
||||
assert qa.min_duration_s == pytest.approx(25.0)
|
||||
assert qa.flag_for_review is True
|
||||
|
||||
def test_custom_values(self):
|
||||
d = {"snr_threshold": "15dB", "min_duration": "28s", "flag_for_review": False}
|
||||
qa = QAConfig.from_dict(d)
|
||||
assert qa.snr_threshold_db == pytest.approx(15.0)
|
||||
assert qa.min_duration_s == pytest.approx(28.0)
|
||||
assert qa.flag_for_review is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CampaignConfig.from_device_profile
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _write_device_profile(d: dict) -> str:
|
||||
"""Write a dict as YAML to a temp file and return the path."""
|
||||
f = tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False)
|
||||
yaml.dump(d, f)
|
||||
f.close()
|
||||
return f.name
|
||||
|
||||
|
||||
WIFI_PROFILE = {
|
||||
"device": {"name": "iPhone_13_WiFi", "type": "wifi"},
|
||||
"capture": {
|
||||
"channels": [1, 6, 11],
|
||||
"bandwidth": "20MHz",
|
||||
"traffic_patterns": ["idle", "ping", "iperf_udp"],
|
||||
"duration_per_config": "30s",
|
||||
"script": "./scripts/wifi_control.sh",
|
||||
},
|
||||
"recorder": {
|
||||
"device": "usrp_b210",
|
||||
"center_freq": "2.45GHz",
|
||||
"sample_rate": "40MHz",
|
||||
"gain": "auto",
|
||||
},
|
||||
"output": {"path": "/tmp/test_recordings", "device_id": "iphone13_wifi_001"},
|
||||
}
|
||||
|
||||
BT_PROFILE = {
|
||||
"device": {"name": "AirPods_Pro", "type": "bluetooth"},
|
||||
"capture": {
|
||||
"traffic_patterns": ["idle", "audio_stream", "data_transfer"],
|
||||
"duration_per_config": "30s",
|
||||
},
|
||||
"recorder": {
|
||||
"device": "usrp_b210",
|
||||
"center_freq": "2.45GHz",
|
||||
"sample_rate": "40MHz",
|
||||
"gain": "auto",
|
||||
},
|
||||
"output": {"path": "/tmp/test_recordings", "device_id": "airpods_pro_bt_001"},
|
||||
}
|
||||
|
||||
|
||||
class TestDeviceProfileParsing:
|
||||
def test_wifi_schedule_count(self):
|
||||
"""WiFi: 3 channels × 3 traffic = 9 steps."""
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert len(cfg.transmitters) == 1
|
||||
assert len(cfg.transmitters[0].schedule) == 9
|
||||
|
||||
def test_wifi_campaign_name(self):
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert cfg.name == "enroll_iphone13_wifi_001"
|
||||
|
||||
def test_wifi_step_labels(self):
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
labels = [s.label for s in cfg.transmitters[0].schedule]
|
||||
assert "ch01_20mhz_idle" in labels
|
||||
assert "ch06_20mhz_ping" in labels
|
||||
assert "ch11_20mhz_iperf_udp" in labels
|
||||
|
||||
def test_wifi_step_ordering(self):
|
||||
"""Steps iterate channels first, then traffic."""
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
steps = cfg.transmitters[0].schedule
|
||||
assert steps[0].channel == 1 and steps[0].traffic == "idle"
|
||||
assert steps[1].channel == 1 and steps[1].traffic == "ping"
|
||||
assert steps[3].channel == 6 and steps[3].traffic == "idle"
|
||||
assert steps[8].channel == 11 and steps[8].traffic == "iperf_udp"
|
||||
|
||||
def test_wifi_step_duration(self):
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
for step in cfg.transmitters[0].schedule:
|
||||
assert step.duration == pytest.approx(30.0)
|
||||
|
||||
def test_wifi_bandwidth(self):
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
for step in cfg.transmitters[0].schedule:
|
||||
assert step.bandwidth_mhz == pytest.approx(20.0)
|
||||
|
||||
def test_wifi_recorder(self):
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert cfg.recorder.device == "usrp_b210"
|
||||
assert cfg.recorder.center_freq == pytest.approx(2.45e9)
|
||||
assert cfg.recorder.sample_rate == pytest.approx(40e6)
|
||||
assert cfg.recorder.gain == "auto"
|
||||
|
||||
def test_wifi_total_capture_time(self):
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert cfg.total_capture_time_s() == pytest.approx(270.0) # 9 × 30s
|
||||
|
||||
def test_bt_schedule_count(self):
|
||||
"""BT: no channels, 3 traffic patterns = 3 steps."""
|
||||
path = _write_device_profile(BT_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert len(cfg.transmitters[0].schedule) == 3
|
||||
|
||||
def test_bt_no_channel(self):
|
||||
path = _write_device_profile(BT_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
for step in cfg.transmitters[0].schedule:
|
||||
assert step.channel is None
|
||||
|
||||
def test_bt_step_labels(self):
|
||||
path = _write_device_profile(BT_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
labels = [s.label for s in cfg.transmitters[0].schedule]
|
||||
assert labels == ["idle", "audio_stream", "data_transfer"]
|
||||
|
||||
def test_missing_file_raises(self):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
CampaignConfig.from_device_profile("/nonexistent/path/profile.yml")
|
||||
|
||||
def test_invalid_yaml_raises(self):
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f:
|
||||
f.write(": bad: yaml: [\n")
|
||||
path = f.name
|
||||
try:
|
||||
with pytest.raises(ValueError, match="Invalid YAML"):
|
||||
CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CampaignConfig.from_yaml (full campaign format)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
FULL_CAMPAIGN = {
|
||||
"campaign": {"name": "wifi_capture_001", "mode": "controlled_testbed"},
|
||||
"transmitters": [
|
||||
{
|
||||
"id": "laptop_wifi",
|
||||
"type": "wifi",
|
||||
"control_method": "external_script",
|
||||
"script": "./scripts/wifi_control.sh",
|
||||
"device": "/dev/wlan0",
|
||||
"schedule": [
|
||||
{"channel": 6, "bandwidth": "20MHz", "traffic": "iperf_tcp", "duration": "30s"},
|
||||
{"channel": 36, "bandwidth": "40MHz", "traffic": "ping_flood", "duration": "30s"},
|
||||
],
|
||||
}
|
||||
],
|
||||
"recorder": {
|
||||
"device": "usrp_b210",
|
||||
"center_freq": "2.45GHz",
|
||||
"sample_rate": "20MHz",
|
||||
"gain": "40dB",
|
||||
},
|
||||
"qa": {"snr_threshold": "10dB", "min_duration": "25s", "flag_for_review": True},
|
||||
"output": {"format": "sigmf", "path": "./recordings"},
|
||||
}
|
||||
|
||||
|
||||
class TestFullCampaignParsing:
|
||||
def test_name(self):
|
||||
path = _write_device_profile(FULL_CAMPAIGN)
|
||||
try:
|
||||
cfg = CampaignConfig.from_yaml(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert cfg.name == "wifi_capture_001"
|
||||
|
||||
def test_mode(self):
|
||||
path = _write_device_profile(FULL_CAMPAIGN)
|
||||
try:
|
||||
cfg = CampaignConfig.from_yaml(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert cfg.mode == "controlled_testbed"
|
||||
|
||||
def test_transmitter_id(self):
|
||||
path = _write_device_profile(FULL_CAMPAIGN)
|
||||
try:
|
||||
cfg = CampaignConfig.from_yaml(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert cfg.transmitters[0].id == "laptop_wifi"
|
||||
assert cfg.transmitters[0].control_method == "external_script"
|
||||
assert cfg.transmitters[0].script == "./scripts/wifi_control.sh"
|
||||
|
||||
def test_schedule_count(self):
|
||||
path = _write_device_profile(FULL_CAMPAIGN)
|
||||
try:
|
||||
cfg = CampaignConfig.from_yaml(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert len(cfg.transmitters[0].schedule) == 2
|
||||
|
||||
def test_qa_config(self):
|
||||
path = _write_device_profile(FULL_CAMPAIGN)
|
||||
try:
|
||||
cfg = CampaignConfig.from_yaml(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert cfg.qa.snr_threshold_db == pytest.approx(10.0)
|
||||
assert cfg.qa.min_duration_s == pytest.approx(25.0)
|
||||
assert cfg.qa.flag_for_review is True
|
||||
|
||||
def test_total_steps(self):
|
||||
path = _write_device_profile(FULL_CAMPAIGN)
|
||||
try:
|
||||
cfg = CampaignConfig.from_yaml(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert cfg.total_steps() == 2
|
||||
|
||||
def test_no_transmitters_raises(self):
|
||||
bad = dict(FULL_CAMPAIGN)
|
||||
bad["transmitters"] = []
|
||||
path = _write_device_profile(bad)
|
||||
try:
|
||||
with pytest.raises(ValueError, match="at least one transmitter"):
|
||||
CampaignConfig.from_yaml(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_missing_recorder_raises(self):
|
||||
bad = {k: v for k, v in FULL_CAMPAIGN.items() if k != "recorder"}
|
||||
path = _write_device_profile(bad)
|
||||
try:
|
||||
with pytest.raises((KeyError, ValueError)):
|
||||
CampaignConfig.from_yaml(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
274
tests/orchestration/test_campaign_cli.py
Normal file
274
tests/orchestration/test_campaign_cli.py
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
"""Tests for the `ria campaign` CLI commands."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import yaml
|
||||
from click.testing import CliRunner
|
||||
|
||||
from ria_toolkit_oss_cli.cli import cli
|
||||
|
||||
|
||||
def _write_yaml(d: dict, suffix=".yml") -> str:
|
||||
f = tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False)
|
||||
yaml.dump(d, f)
|
||||
f.close()
|
||||
return f.name
|
||||
|
||||
|
||||
WIFI_PROFILE = {
|
||||
"device": {"name": "iPhone_13_WiFi", "type": "wifi"},
|
||||
"capture": {
|
||||
"channels": [1, 6, 11],
|
||||
"bandwidth": "20MHz",
|
||||
"traffic_patterns": ["idle", "ping", "iperf_udp"],
|
||||
"duration_per_config": "30s",
|
||||
},
|
||||
"recorder": {
|
||||
"device": "usrp_b210",
|
||||
"center_freq": "2.45GHz",
|
||||
"sample_rate": "40MHz",
|
||||
"gain": "auto",
|
||||
},
|
||||
"output": {"path": "/tmp/test_enroll", "device_id": "iphone13_wifi_001"},
|
||||
}
|
||||
|
||||
FULL_CAMPAIGN = {
|
||||
"campaign": {"name": "wifi_capture_001", "mode": "controlled_testbed"},
|
||||
"transmitters": [
|
||||
{
|
||||
"id": "laptop_wifi",
|
||||
"type": "wifi",
|
||||
"control_method": "external_script",
|
||||
"schedule": [
|
||||
{"channel": 6, "bandwidth": "20MHz", "traffic": "iperf_udp", "duration": "30s"},
|
||||
{"channel": 11, "bandwidth": "20MHz", "traffic": "idle", "duration": "30s"},
|
||||
],
|
||||
}
|
||||
],
|
||||
"recorder": {
|
||||
"device": "usrp_b210",
|
||||
"center_freq": "2.45GHz",
|
||||
"sample_rate": "20MHz",
|
||||
"gain": "40dB",
|
||||
},
|
||||
"qa": {"snr_threshold": "10dB", "min_duration": "25s", "flag_for_review": True},
|
||||
"output": {"format": "sigmf", "path": "/tmp/test_campaign"},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ria campaign --help
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCampaignHelp:
|
||||
def test_campaign_help(self):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "campaign" in result.output.lower()
|
||||
|
||||
def test_subcommands_listed(self):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "--help"])
|
||||
assert result.exit_code == 0
|
||||
for sub in ("validate", "enroll", "run"):
|
||||
assert sub in result.output
|
||||
|
||||
def test_validate_help(self):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "validate", "--help"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_enroll_help(self):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "enroll", "--help"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_run_help(self):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "run", "--help"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ria campaign validate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCampaignValidate:
|
||||
def test_validate_device_profile(self):
|
||||
path = _write_yaml(WIFI_PROFILE)
|
||||
try:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
|
||||
assert result.exit_code == 0
|
||||
assert "✓" in result.output or "valid" in result.output.lower()
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_validate_shows_campaign_name(self):
|
||||
path = _write_yaml(WIFI_PROFILE)
|
||||
try:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
|
||||
assert "enroll_iphone13_wifi_001" in result.output
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_validate_shows_step_count(self):
|
||||
path = _write_yaml(WIFI_PROFILE)
|
||||
try:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
|
||||
assert "9" in result.output # 9 total steps
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_validate_shows_capture_time(self):
|
||||
path = _write_yaml(WIFI_PROFILE)
|
||||
try:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
|
||||
assert "270" in result.output # 270s total
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_validate_full_campaign(self):
|
||||
path = _write_yaml(FULL_CAMPAIGN)
|
||||
try:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "validate", "--config", path])
|
||||
assert result.exit_code == 0
|
||||
assert "wifi_capture_001" in result.output
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_validate_shows_steps(self):
|
||||
path = _write_yaml(WIFI_PROFILE)
|
||||
try:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
|
||||
assert "ch01_20mhz_idle" in result.output
|
||||
assert "ch06_20mhz_ping" in result.output
|
||||
assert "ch11_20mhz_iperf_udp" in result.output
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_validate_missing_file(self):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "validate", "--config", "/nonexistent/file.yml"])
|
||||
assert result.exit_code != 0
|
||||
|
||||
def test_validate_bad_yaml(self):
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f:
|
||||
f.write(": broken yaml [\n")
|
||||
path = f.name
|
||||
try:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "validate", "--config", path])
|
||||
assert result.exit_code != 0
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ria campaign enroll --dry-run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCampaignEnrollDryRun:
|
||||
def test_dry_run_exits_cleanly(self):
|
||||
path = _write_yaml(WIFI_PROFILE)
|
||||
try:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "enroll", "--config", path, "--dry-run"])
|
||||
assert result.exit_code == 0
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_dry_run_shows_campaign_info(self):
|
||||
path = _write_yaml(WIFI_PROFILE)
|
||||
try:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "enroll", "--config", path, "--dry-run"])
|
||||
assert "enroll_iphone13_wifi_001" in result.output
|
||||
assert "9" in result.output
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_dry_run_does_not_capture(self):
|
||||
"""Dry run should not create any output files."""
|
||||
path = _write_yaml(WIFI_PROFILE)
|
||||
try:
|
||||
runner = CliRunner()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
runner.invoke(
|
||||
cli,
|
||||
["campaign", "enroll", "--config", path, "--output", tmpdir, "--dry-run"],
|
||||
)
|
||||
# No .sigmf-data files should have been created
|
||||
sigmf_files = list(os.walk(tmpdir))
|
||||
all_files = [f for _, _, files in sigmf_files for f in files]
|
||||
assert not any(f.endswith(".sigmf-data") for f in all_files)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_dry_run_output_override(self):
|
||||
path = _write_yaml(WIFI_PROFILE)
|
||||
try:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
["campaign", "enroll", "--config", path, "--output", "/tmp/custom_out", "--dry-run"],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert "Dry run" in result.output
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ria campaign run --dry-run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCampaignRunDryRun:
|
||||
def test_dry_run_exits_cleanly(self):
|
||||
path = _write_yaml(FULL_CAMPAIGN)
|
||||
try:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "run", "--config", path, "--dry-run"])
|
||||
assert result.exit_code == 0
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_dry_run_shows_campaign_name(self):
|
||||
path = _write_yaml(FULL_CAMPAIGN)
|
||||
try:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "run", "--config", path, "--dry-run"])
|
||||
assert "wifi_capture_001" in result.output
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_dry_run_does_not_create_report(self):
|
||||
path = _write_yaml(FULL_CAMPAIGN)
|
||||
try:
|
||||
runner = CliRunner()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
report_path = os.path.join(tmpdir, "qa_report.json")
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
["campaign", "run", "--config", path, "--dry-run", "--report", report_path],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert not os.path.exists(report_path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_missing_config_fails(self):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["campaign", "run", "--config", "/nonexistent.yml"])
|
||||
assert result.exit_code != 0
|
||||
145
tests/orchestration/test_labeler.py
Normal file
145
tests/orchestration/test_labeler.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
"""Tests for orchestration labeler."""
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from ria_toolkit_oss.datatypes.recording import Recording
|
||||
from ria_toolkit_oss.orchestration.campaign import CaptureStep
|
||||
from ria_toolkit_oss.orchestration.labeler import build_output_filename, label_recording
|
||||
|
||||
|
||||
def _simple_recording() -> Recording:
|
||||
sr = 1e6
|
||||
n = 1000
|
||||
data = np.ones(n, dtype=np.complex64)
|
||||
return Recording(data, metadata={"sample_rate": sr, "center_frequency": 2.45e9})
|
||||
|
||||
|
||||
def _wifi_step() -> CaptureStep:
|
||||
return CaptureStep(
|
||||
duration=30.0,
|
||||
label="ch06_20mhz_idle",
|
||||
channel=6,
|
||||
bandwidth_mhz=20.0,
|
||||
traffic="idle",
|
||||
)
|
||||
|
||||
|
||||
def _bt_step() -> CaptureStep:
|
||||
return CaptureStep(
|
||||
duration=30.0,
|
||||
label="audio_stream",
|
||||
traffic="audio_stream",
|
||||
connection_interval_ms=7.5,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# label_recording
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLabelRecording:
|
||||
def test_device_id_set(self):
|
||||
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||
assert rec.metadata["device_id"] == "iphone13_001"
|
||||
|
||||
def test_capture_timestamp_set(self):
|
||||
ts = time.time()
|
||||
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), ts)
|
||||
assert rec.metadata["capture_timestamp"] == pytest.approx(ts, abs=1.0)
|
||||
|
||||
def test_step_label_set(self):
|
||||
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||
assert rec.metadata["step_label"] == "ch06_20mhz_idle"
|
||||
|
||||
def test_step_duration_set(self):
|
||||
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||
assert rec.metadata["step_duration_s"] == pytest.approx(30.0)
|
||||
|
||||
def test_campaign_name_optional(self):
|
||||
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||
assert "campaign" not in rec.metadata
|
||||
|
||||
def test_campaign_name_when_provided(self):
|
||||
rec = label_recording(
|
||||
_simple_recording(), "iphone13_001", _wifi_step(), time.time(), campaign_name="test_campaign"
|
||||
)
|
||||
assert rec.metadata["campaign"] == "test_campaign"
|
||||
|
||||
def test_wifi_channel_set(self):
|
||||
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||
assert rec.metadata["wifi_channel"] == 6
|
||||
|
||||
def test_wifi_bandwidth_set(self):
|
||||
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||
assert rec.metadata["wifi_bandwidth_mhz"] == pytest.approx(20.0)
|
||||
|
||||
def test_traffic_pattern_set(self):
|
||||
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||
assert rec.metadata["traffic_pattern"] == "idle"
|
||||
|
||||
def test_bt_connection_interval_set(self):
|
||||
rec = label_recording(_simple_recording(), "airpods_001", _bt_step(), time.time())
|
||||
assert rec.metadata["bt_connection_interval_ms"] == pytest.approx(7.5)
|
||||
|
||||
def test_no_channel_key_for_bt(self):
|
||||
"""BT steps with no channel should not add wifi_channel to metadata."""
|
||||
rec = label_recording(_simple_recording(), "airpods_001", _bt_step(), time.time())
|
||||
assert "wifi_channel" not in rec.metadata
|
||||
|
||||
def test_no_bandwidth_key_for_bt(self):
|
||||
rec = label_recording(_simple_recording(), "airpods_001", _bt_step(), time.time())
|
||||
assert "wifi_bandwidth_mhz" not in rec.metadata
|
||||
|
||||
def test_power_dbm_set(self):
|
||||
step = CaptureStep(duration=30.0, label="test", traffic="idle", power_dbm=15.0)
|
||||
rec = label_recording(_simple_recording(), "dev_001", step, time.time())
|
||||
assert rec.metadata["tx_power_dbm"] == pytest.approx(15.0)
|
||||
|
||||
def test_no_power_key_when_unset(self):
|
||||
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||
assert "tx_power_dbm" not in rec.metadata
|
||||
|
||||
def test_returns_same_recording(self):
|
||||
"""label_recording should mutate and return the same Recording object."""
|
||||
rec = _simple_recording()
|
||||
result = label_recording(rec, "iphone13_001", _wifi_step(), time.time())
|
||||
assert result is rec
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_output_filename
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildOutputFilename:
|
||||
def test_basic_wifi(self):
|
||||
step = CaptureStep(duration=30.0, label="ch06_20mhz_idle")
|
||||
fn = build_output_filename("iphone13_wifi_001", step)
|
||||
assert fn == "iphone13_wifi_001/ch06_20mhz_idle"
|
||||
|
||||
def test_bt_step(self):
|
||||
step = CaptureStep(duration=30.0, label="audio_stream")
|
||||
fn = build_output_filename("airpods_pro_bt_001", step)
|
||||
assert fn == "airpods_pro_bt_001/audio_stream"
|
||||
|
||||
def test_spaces_in_device_id_replaced(self):
|
||||
step = CaptureStep(duration=30.0, label="idle")
|
||||
fn = build_output_filename("my device", step)
|
||||
assert " " not in fn
|
||||
assert fn == "my_device/idle"
|
||||
|
||||
def test_slashes_in_label_replaced(self):
|
||||
step = CaptureStep(duration=30.0, label="ch/6/idle")
|
||||
fn = build_output_filename("dev_001", step)
|
||||
assert "/" not in fn.split("/", 1)[1] # only the separator slash should remain
|
||||
|
||||
def test_path_structure(self):
|
||||
"""Filename should be exactly '<device_id>/<label>' (one level of nesting)."""
|
||||
step = CaptureStep(duration=30.0, label="idle")
|
||||
fn = build_output_filename("dev_001", step)
|
||||
parts = fn.split("/")
|
||||
assert len(parts) == 2
|
||||
192
tests/orchestration/test_qa.py
Normal file
192
tests/orchestration/test_qa.py
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
"""Tests for orchestration QA metrics."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from ria_toolkit_oss.datatypes.recording import Recording
|
||||
from ria_toolkit_oss.orchestration.campaign import QAConfig
|
||||
from ria_toolkit_oss.orchestration.qa import QAResult, check_recording, estimate_snr_db
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_recording(n_samples: int, sample_rate: float, signal: np.ndarray) -> Recording:
|
||||
return Recording(
|
||||
signal.astype(np.complex64),
|
||||
metadata={"sample_rate": sample_rate, "center_frequency": 2.45e9},
|
||||
)
|
||||
|
||||
|
||||
def _tone(n: int, sr: float, freq_hz: float = 100e3, amplitude: float = 0.5) -> np.ndarray:
|
||||
t = np.arange(n) / sr
|
||||
return (np.exp(1j * 2 * np.pi * freq_hz * t) * amplitude).astype(np.complex64)
|
||||
|
||||
|
||||
def _noise(n: int, amplitude: float = 0.001) -> np.ndarray:
|
||||
rng = np.random.default_rng(42)
|
||||
return ((rng.standard_normal(n) + 1j * rng.standard_normal(n)) * amplitude).astype(np.complex64)
|
||||
|
||||
|
||||
DEFAULT_QA = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# estimate_snr_db
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEstimateSnrDb:
|
||||
def test_high_snr_tone(self):
|
||||
sr = 1e6
|
||||
samples = _tone(int(sr * 1), sr)
|
||||
snr = estimate_snr_db(samples)
|
||||
assert snr > 20.0, f"Expected high SNR for clean tone, got {snr:.1f} dB"
|
||||
|
||||
def test_pure_noise_low_snr(self):
|
||||
sr = 1e6
|
||||
rng = np.random.default_rng(0)
|
||||
samples = (rng.standard_normal(int(sr)) + 1j * rng.standard_normal(int(sr))).astype(np.complex64)
|
||||
snr = estimate_snr_db(samples)
|
||||
# Pure noise should yield a low (possibly negative) SNR
|
||||
assert snr < 15.0, f"Expected low SNR for noise, got {snr:.1f} dB"
|
||||
|
||||
def test_snr_increases_with_amplitude(self):
|
||||
sr = 1e6
|
||||
n = int(sr)
|
||||
rng = np.random.default_rng(1)
|
||||
noise = (rng.standard_normal(n) + 1j * rng.standard_normal(n)).astype(np.complex64) * 0.01
|
||||
t = np.arange(n) / sr
|
||||
tone = np.exp(1j * 2 * np.pi * 100e3 * t).astype(np.complex64)
|
||||
|
||||
low_snr = estimate_snr_db(noise + tone * 0.1)
|
||||
high_snr = estimate_snr_db(noise + tone * 1.0)
|
||||
assert high_snr > low_snr
|
||||
|
||||
def test_short_input_still_works(self):
|
||||
# Input shorter than n_fft=4096 should not raise
|
||||
samples = _tone(512, 1e6)
|
||||
snr = estimate_snr_db(samples)
|
||||
assert np.isfinite(snr)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_recording — pass cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRecordingPass:
|
||||
def test_clean_tone_passes(self):
|
||||
sr = 1e6
|
||||
rec = _make_recording(int(sr * 30), sr, _tone(int(sr * 30), sr))
|
||||
result = check_recording(rec, DEFAULT_QA)
|
||||
assert result.passed is True
|
||||
assert result.flagged is False
|
||||
assert result.snr_db > 10.0
|
||||
assert abs(result.duration_s - 30.0) < 0.1
|
||||
|
||||
def test_duration_exactly_at_threshold(self):
|
||||
sr = 1e6
|
||||
n = int(sr * 25) # exactly at min_duration_s
|
||||
rec = _make_recording(n, sr, _tone(n, sr))
|
||||
result = check_recording(rec, DEFAULT_QA)
|
||||
assert result.flagged is False
|
||||
|
||||
def test_issues_empty_when_passing(self):
|
||||
sr = 1e6
|
||||
rec = _make_recording(int(sr * 30), sr, _tone(int(sr * 30), sr))
|
||||
result = check_recording(rec, DEFAULT_QA)
|
||||
assert result.issues == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_recording — flag cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRecordingFlag:
|
||||
def test_short_recording_flagged(self):
|
||||
sr = 1e6
|
||||
n = int(sr * 10) # shorter than 25s min
|
||||
rec = _make_recording(n, sr, _tone(n, sr))
|
||||
result = check_recording(rec, DEFAULT_QA)
|
||||
assert result.flagged is True
|
||||
assert any("Duration" in issue for issue in result.issues)
|
||||
|
||||
def test_low_snr_flagged(self):
|
||||
sr = 1e6
|
||||
n = int(sr * 30)
|
||||
rec = _make_recording(n, sr, _noise(n, amplitude=0.001))
|
||||
result = check_recording(rec, DEFAULT_QA)
|
||||
assert result.flagged is True
|
||||
assert any("SNR" in issue for issue in result.issues)
|
||||
|
||||
def test_flag_for_review_still_passes(self):
|
||||
"""With flag_for_review=True, flagged recordings are still marked passed."""
|
||||
sr = 1e6
|
||||
n = int(sr * 10) # short → will be flagged
|
||||
rec = _make_recording(n, sr, _tone(n, sr))
|
||||
qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=True)
|
||||
result = check_recording(rec, qa)
|
||||
assert result.flagged is True
|
||||
assert result.passed is True # human review, not auto-reject
|
||||
|
||||
def test_flag_for_review_false_fails(self):
|
||||
"""With flag_for_review=False, a flagged recording is also marked failed."""
|
||||
sr = 1e6
|
||||
n = int(sr * 10)
|
||||
rec = _make_recording(n, sr, _tone(n, sr))
|
||||
qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=False)
|
||||
result = check_recording(rec, qa)
|
||||
assert result.flagged is True
|
||||
assert result.passed is False
|
||||
|
||||
def test_multiple_issues_reported(self):
|
||||
"""Both short duration AND low SNR should both appear in issues list."""
|
||||
sr = 1e6
|
||||
n = int(sr * 5) # very short
|
||||
rec = _make_recording(n, sr, _noise(n, amplitude=0.0001))
|
||||
result = check_recording(rec, DEFAULT_QA)
|
||||
assert result.flagged is True
|
||||
assert len(result.issues) >= 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_recording — multichannel input
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRecordingMultichannel:
|
||||
def test_multichannel_recording(self):
|
||||
"""2-channel recording should evaluate channel 0 without error."""
|
||||
sr = 1e6
|
||||
n = int(sr * 30)
|
||||
ch0 = _tone(n, sr)
|
||||
ch1 = _tone(n, sr, freq_hz=200e3)
|
||||
data = np.stack([ch0, ch1]) # shape (2, N)
|
||||
rec = Recording(data, metadata={"sample_rate": sr, "center_frequency": 2.45e9})
|
||||
result = check_recording(rec, DEFAULT_QA)
|
||||
assert result.passed is True
|
||||
assert result.flagged is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QAResult.to_dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQAResultToDict:
|
||||
def test_to_dict_keys(self):
|
||||
r = QAResult(passed=True, flagged=False, snr_db=18.3, duration_s=30.0)
|
||||
d = r.to_dict()
|
||||
assert set(d.keys()) == {"passed", "flagged", "snr_db", "duration_s", "issues"}
|
||||
|
||||
def test_to_dict_values(self):
|
||||
r = QAResult(passed=False, flagged=True, snr_db=7.5, duration_s=10.2, issues=["SNR below threshold"])
|
||||
d = r.to_dict()
|
||||
assert d["passed"] is False
|
||||
assert d["flagged"] is True
|
||||
assert d["snr_db"] == pytest.approx(7.5, abs=0.01)
|
||||
assert d["duration_s"] == pytest.approx(10.2, abs=0.01)
|
||||
assert d["issues"] == ["SNR below threshold"]
|
||||
Loading…
Reference in New Issue
Block a user