ria-toolkit-oss/src/ria_toolkit_oss/orchestration/executor.py
ben 50d04161b7
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 35s
Build Project / Build Project (3.10) (pull_request) Successful in 8m8s
Test with tox / Test with tox (3.11) (pull_request) Successful in 8m0s
Build Project / Build Project (3.11) (pull_request) Successful in 8m6s
Build Project / Build Project (3.12) (pull_request) Successful in 8m6s
Test with tox / Test with tox (3.12) (pull_request) Successful in 9m8s
Test with tox / Test with tox (3.10) (pull_request) Successful in 13m58s
Merge remote-tracking branch 'origin/main' into zfp-oss
2026-04-22 15:44:12 -04:00

571 lines
21 KiB
Python
Raw RIA Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Campaign executor: runs a capture campaign end-to-end."""
from __future__ import annotations
import json
import logging
import subprocess
import threading
import time
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Callable, Optional
from ria_toolkit_oss.data.recording import Recording
from ria_toolkit_oss.io.recording import to_sigmf
from .campaign import CampaignConfig, CaptureStep, TransmitterConfig
from .labeler import build_output_filename, label_recording
from .qa import QAResult, check_recording
from .tx_executor import TxExecutor
logger = logging.getLogger(__name__)
# Device name aliases: campaign YAML names → get_sdr_device() names
_DEVICE_ALIASES = {
"usrp_b210": "usrp",
"usrp_b200": "usrp",
"usrp": "usrp",
"plutosdr": "pluto",
"pluto": "pluto",
"hackrf": "hackrf",
"hackrf_one": "hackrf",
"bladerf": "bladerf",
"rtlsdr": "rtlsdr",
"rtl_sdr": "rtlsdr",
"thinkrf": "thinkrf",
# Simulated device — no hardware required
"mock": "mock",
"sim": "mock",
}
@dataclass
class StepResult:
    """Outcome of a single capture step.

    A step is considered successful (:attr:`ok`) only when no error was
    recorded *and* its QA result reports ``passed``.
    """

    transmitter_id: str
    step_label: str
    # Path of the saved data file, or None when nothing was written.
    output_path: Optional[str]
    qa: QAResult
    capture_timestamp: float
    # Human-readable error description; None means no error was recorded.
    error: Optional[str] = None

    @property
    def ok(self) -> bool:
        """Whether the step finished without error and passed QA."""
        if self.error is not None:
            return False
        return self.qa.passed

    def to_dict(self) -> dict:
        """Return a plain-dict view of the result, with QA expanded via ``qa.to_dict()``."""
        return dict(
            transmitter_id=self.transmitter_id,
            step_label=self.step_label,
            output_path=self.output_path,
            capture_timestamp=self.capture_timestamp,
            qa=self.qa.to_dict(),
            error=self.error,
        )
@dataclass
class CampaignResult:
    """Aggregate outcome of a full campaign.

    ``passed`` and ``failed`` are both derived from each step's ``ok`` flag,
    so together they always account for every recorded step.
    """

    campaign_name: str
    steps: list[StepResult] = field(default_factory=list)
    # Wall-clock start (``time.time()``), taken when the result object is created.
    start_time: float = field(default_factory=time.time)
    # Wall-clock end; stays None until the campaign has finished.
    end_time: Optional[float] = None

    @property
    def total_steps(self) -> int:
        """Number of step results recorded so far."""
        return len(self.steps)

    @property
    def passed(self) -> int:
        """Number of steps whose ``ok`` flag is true."""
        return sum(1 for s in self.steps if s.ok)

    @property
    def flagged(self) -> int:
        """Number of error-free steps whose QA result is flagged."""
        # ``is None`` (not truthiness) so an empty error string still counts
        # as an error, matching how a step decides whether it is ``ok``.
        return sum(1 for s in self.steps if s.error is None and s.qa.flagged)

    @property
    def failed(self) -> int:
        """Number of steps that are not ``ok`` (errored or failed QA)."""
        return sum(1 for s in self.steps if not s.ok)

    @property
    def duration_s(self) -> float:
        """Elapsed seconds: start to end when finished, otherwise start to now."""
        # Compare against None explicitly: 0.0 is a valid timestamp.
        if self.end_time is not None:
            return self.end_time - self.start_time
        return time.time() - self.start_time

    def to_dict(self) -> dict:
        """Return a JSON-serialisable summary including every step's dict."""
        return {
            "campaign_name": self.campaign_name,
            "total_steps": self.total_steps,
            "passed": self.passed,
            "flagged": self.flagged,
            "failed": self.failed,
            "duration_s": round(self.duration_s, 1),
            "steps": [s.to_dict() for s in self.steps],
        }

    def write_report(self, path: str | Path) -> None:
        """Write a JSON QA report to disk.

        Parent directories are created if missing.

        Args:
            path: Destination file for the report.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)
        logger.info(f"QA report written to {path}")
# ---------------------------------------------------------------------------
# External script interface
# ---------------------------------------------------------------------------
def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
"""Run an external control script and return stdout.
The script is called as::
<script> <arg1> <arg2> ...
A non-zero return code raises RuntimeError.
Args:
script: Path to executable script. Must be an absolute path to an
existing regular file. Relative paths are rejected to prevent
accidentally executing files that are not the intended script.
*args: Positional arguments forwarded to the script.
timeout: Maximum seconds to wait.
Returns:
Script stdout as a string.
"""
if not Path(script).is_absolute():
raise RuntimeError(f"Script path must be absolute: {script}")
script_path = Path(script).resolve()
if not script_path.is_file():
raise RuntimeError(f"Script not found or is not a regular file: {script}")
cmd = [str(script_path), *args]
logger.debug(f"Running script: {' '.join(cmd)}")
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=timeout,
)
except subprocess.TimeoutExpired:
raise RuntimeError(f"Script timed out after {timeout}s: {script}")
except FileNotFoundError:
raise RuntimeError(f"Script not found: {script}")
if result.returncode != 0:
raise RuntimeError(f"Script exited {result.returncode}: {result.stderr.strip() or result.stdout.strip()}")
return result.stdout.strip()
# ---------------------------------------------------------------------------
# Campaign executor
# ---------------------------------------------------------------------------
def _extract_tx_params(transmitter: TransmitterConfig) -> dict | None:
"""Build a tx_params dict from a transmitter's signal config for SigMF labeling.
For sdr_agent transmitters, returns the synthetic generation parameters
(modulation, order, symbol_rate, etc.) so recordings capture what was
transmitted. Returns None for control methods without signal-level params.
"""
sdr_agent_cfg = getattr(transmitter, "sdr_agent", None)
if not sdr_agent_cfg:
return None
# Extract known signal-level fields; ignore infra fields
_INFRA_KEYS = {"node_id", "session_code"}
return {k: v for k, v in sdr_agent_cfg.items() if k not in _INFRA_KEYS and v is not None}
class CampaignExecutor:
    """Executes a :class:`CampaignConfig` end-to-end.

    Initialises the SDR recorder once, then for each (transmitter, step):

    1. Configures/starts the transmitter (external script, remote SDR or
       local TX executor, depending on the control method)
    2. Records IQ samples
    3. Stops the transmitter
    4. Labels the recording with device/config metadata
    5. Runs QA checks
    6. Saves the recording to disk

    Args:
        config: Parsed campaign configuration.
        progress_cb: Optional callback ``(step_index, total_steps, step_result)``
            called after each step completes. Useful for status reporting.
        verbose: Enable debug logging.
        skip_local_tx: When True, ``sdr_agent`` transmitters are not started
            locally by this executor.
    """

    def __init__(
        self,
        config: CampaignConfig,
        progress_cb: Optional[Callable[[int, int, StepResult], None]] = None,
        verbose: bool = False,
        skip_local_tx: bool = False,
    ):
        self.config = config
        self.progress_cb = progress_cb
        self.skip_local_tx = skip_local_tx
        self._sdr = None
        self._remote_tx_controllers: dict = {}
        self._tx_executors: dict[str, tuple] = {}  # tx_id → (TxExecutor, stop_event, thread)
        if verbose:
            logging.basicConfig(level=logging.DEBUG)
        else:
            logging.basicConfig(level=logging.INFO)

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------
    def run(self) -> CampaignResult:
        """Execute the full campaign and return a :class:`CampaignResult`.

        Initialises the SDR and remote Tx controllers, runs all steps across
        all transmitters (repeated ``config.loops`` times), then releases
        every resource. If initialisation fails the exception propagates
        immediately (nothing is captured), but anything already opened is
        still closed.
        """
        result = CampaignResult(campaign_name=self.config.name)
        loops = self.config.loops
        logger.info(
            f"Starting campaign '{self.config.name}': "
            f"{self.config.total_steps()} steps"
            + (f" ({self.config.total_steps() // loops} × {loops} loops)" if loops > 1 else "")
            + f", ~{self.config.total_capture_time_s():.0f}s capture time"
        )
        try:
            # Initialisation lives inside the try so that a failure part-way
            # through (e.g. a remote controller that cannot connect) still
            # runs the cleanup below instead of leaking the opened SDR.
            self._init_sdr()
            self._init_remote_tx_controllers()
            total = self.config.total_steps()
            step_index = 0
            for loop_idx in range(loops):
                if loops > 1:
                    logger.info(f"Loop {loop_idx + 1}/{loops}")
                for transmitter in self.config.transmitters:
                    logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)")
                    for step in transmitter.schedule:
                        # With several loops, suffix the label so each pass gets a distinct name.
                        looped_step = replace(step, label=f"{step.label}_run{loop_idx + 1:02d}") if loops > 1 else step
                        step_result = self._execute_step(transmitter, looped_step)
                        result.steps.append(step_result)
                        step_index += 1
                        if self.progress_cb:
                            self.progress_cb(step_index, total, step_result)
                        if step_result.error:
                            logger.warning(f"Step '{looped_step.label}' error: {step_result.error}")
                        elif step_result.qa.flagged:
                            logger.warning(
                                f"Step '{looped_step.label}' flagged for review: " + "; ".join(step_result.qa.issues)
                            )
                        else:
                            logger.info(
                                f"Step '{looped_step.label}' OK "
                                f"(SNR {step_result.qa.snr_db:.1f} dB, "
                                f"{step_result.qa.duration_s:.1f}s)"
                            )
        finally:
            self._close_sdr()
            self._close_remote_tx_controllers()
            self._close_tx_executors()
            result.end_time = time.time()
        logger.info(
            f"Campaign complete: {result.passed}/{result.total_steps} passed, "
            f"{result.flagged} flagged, {result.failed} failed"
        )
        return result

    # ------------------------------------------------------------------
    # SDR management
    # ------------------------------------------------------------------
    def _init_sdr(self) -> None:
        """Initialise and configure the SDR recorder.

        The configured device name is lower-cased and mapped through
        ``_DEVICE_ALIASES``; unknown names are passed on unchanged. A gain of
        ``"auto"`` is forwarded as ``None``. RX bandwidth is applied only when
        configured and when the device exposes ``set_rx_bandwidth``.
        """
        from ria_toolkit_oss.sdr import get_sdr_device

        rec = self.config.recorder
        device_name = _DEVICE_ALIASES.get(rec.device.lower(), rec.device.lower())
        logger.info(f"Initialising SDR: {device_name} @ {rec.center_freq/1e6:.2f} MHz")
        self._sdr = get_sdr_device(device_name)
        gain = None if rec.gain == "auto" else float(rec.gain)
        self._sdr.init_rx(
            sample_rate=rec.sample_rate,
            center_frequency=rec.center_freq,
            gain=gain,
            channel=0,
        )
        if rec.bandwidth and hasattr(self._sdr, "set_rx_bandwidth"):
            self._sdr.set_rx_bandwidth(rec.bandwidth)

    def _close_sdr(self) -> None:
        """Close the SDR if one is open; close errors are logged, not raised."""
        if self._sdr is not None:
            try:
                self._sdr.close()
            except Exception as e:
                logger.warning(f"SDR close error: {e}")
            self._sdr = None

    # ------------------------------------------------------------------
    # Remote Tx controller management
    # ------------------------------------------------------------------
    def _init_remote_tx_controllers(self) -> None:
        """Create a remote controller for every ``sdr_remote`` transmitter.

        Does nothing (and does not import the remote-control module) when no
        transmitter uses the ``sdr_remote`` control method.

        Raises:
            RuntimeError: If an ``sdr_remote`` transmitter has no
                ``sdr_remote`` config.
        """
        remote_txs = [tx for tx in self.config.transmitters if tx.control_method == "sdr_remote"]
        if not remote_txs:
            return

        from ria_toolkit_oss.remote_control import RemoteTransmitterController

        for tx in remote_txs:
            cfg = tx.sdr_remote
            if not cfg:
                raise RuntimeError(f"Transmitter '{tx.id}' uses sdr_remote but has no sdr_remote config")
            logger.info(f"Connecting remote Tx controller for {tx.id} @ {cfg['host']}")
            ctrl = RemoteTransmitterController(
                host=cfg["host"],
                ssh_user=cfg["ssh_user"],
                ssh_key_path=cfg["ssh_key_path"],
                zmq_port=int(cfg.get("zmq_port", 5556)),
            )
            # Register before configuring the radio so the controller is
            # closed by _close_remote_tx_controllers() even if set_radio fails.
            self._remote_tx_controllers[tx.id] = ctrl
            ctrl.set_radio(
                device_type=cfg["device_type"],
                device_id=cfg.get("device_id", ""),
            )

    def _close_remote_tx_controllers(self) -> None:
        """Close every registered remote controller; errors are logged, not raised."""
        for tx_id, ctrl in list(self._remote_tx_controllers.items()):
            try:
                ctrl.close()
            except Exception as exc:
                logger.warning(f"Error closing remote Tx controller {tx_id}: {exc}")
        self._remote_tx_controllers.clear()

    def _close_tx_executors(self) -> None:
        """Signal every still-registered local TX thread to stop and wait up to 5 s for each."""
        for _, stop_event, thread in list(self._tx_executors.values()):
            stop_event.set()
            thread.join(timeout=5.0)
        self._tx_executors.clear()

    def _record(self, duration_s: float) -> Recording:
        """Capture ``duration_s`` seconds of IQ samples."""
        num_samples = int(duration_s * self.config.recorder.sample_rate)
        return self._sdr.record(num_samples=num_samples)

    # ------------------------------------------------------------------
    # Step execution
    # ------------------------------------------------------------------
    def _execute_step(self, transmitter: TransmitterConfig, step: CaptureStep) -> StepResult:
        """Run a single capture step.

        Starts the transmitter, records, stops the transmitter, then labels,
        QA-checks and saves the recording. Errors during start/record/stop or
        during saving are captured in the returned result instead of raised.

        Returns:
            StepResult with QA outcome and output path (or error string).
        """
        capture_timestamp = time.time()
        stop_attempted = False
        try:
            self._start_transmitter(transmitter, step)
            recording = self._record(step.duration)
            stop_attempted = True
            self._stop_transmitter(transmitter, step)
        except Exception as e:
            # Best-effort stop on error, unless the stop itself is what
            # failed; repeating it could block for another full timeout.
            if not stop_attempted:
                try:
                    self._stop_transmitter(transmitter, step)
                except Exception as stop_exc:
                    logger.warning(f"Best-effort stop failed for {transmitter.id}: {stop_exc}")
            return StepResult(
                transmitter_id=transmitter.id,
                step_label=step.label,
                output_path=None,
                qa=QAResult(passed=False, flagged=True, snr_db=0.0, duration_s=0.0, issues=[f"Capture error: {e}"]),
                capture_timestamp=capture_timestamp,
                error=str(e),
            )

        # Label recording
        recording = label_recording(
            recording=recording,
            device_id=transmitter.id,
            step=step,
            capture_timestamp=capture_timestamp,
            campaign_name=self.config.name,
            tx_params=_extract_tx_params(transmitter),
        )

        # QA
        qa_result = check_recording(recording, self.config.qa)

        # Save
        try:
            output_path = self._save(recording, transmitter.id, step)
        except Exception as e:
            return StepResult(
                transmitter_id=transmitter.id,
                step_label=step.label,
                output_path=None,
                qa=qa_result,
                capture_timestamp=capture_timestamp,
                error=f"Save failed: {e}",
            )
        return StepResult(
            transmitter_id=transmitter.id,
            step_label=step.label,
            output_path=output_path,
            qa=qa_result,
            capture_timestamp=capture_timestamp,
        )

    # ------------------------------------------------------------------
    # Transmitter control (external script interface)
    # ------------------------------------------------------------------
    def _start_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
        """Configure the transmitter for this step.

        For ``external_script`` control method the script is called as::

            <script> configure <step_params_json>

        where ``step_params_json`` is a JSON object with channel, bandwidth,
        traffic, etc. The script is responsible for applying the configuration
        and returning promptly (i.e. not blocking for the capture duration).

        For ``sdr`` nothing is started (not yet implemented).

        For ``sdr_remote`` the registered remote controller's ``init_tx`` is
        called with the recorder's centre frequency and sample rate, followed
        by ``transmit_async`` for the step duration plus one second.

        For ``sdr_agent`` a :class:`TxExecutor` with a one-step schedule
        (step duration plus one second) is run on a daemon thread and
        registered in ``_tx_executors``; this is skipped when
        ``skip_local_tx`` is set or the transmitter has no ``sdr_agent``
        config.

        Unknown control methods are logged and skipped.

        Raises:
            RuntimeError: If an ``sdr_remote`` transmitter has no registered
                controller, or the external script fails.
        """
        if transmitter.control_method == "external_script":
            if not transmitter.script:
                logger.debug(f"No script configured for {transmitter.id}, skipping configure")
                return
            params = self._step_params_json(transmitter, step)
            _run_script(transmitter.script, "configure", params)
        elif transmitter.control_method == "sdr":
            logger.debug("SDR TX not yet implemented — skipping start")
        elif transmitter.control_method == "sdr_remote":
            ctrl = self._remote_tx_controllers.get(transmitter.id)
            if ctrl is None:
                raise RuntimeError(f"No remote Tx controller found for transmitter '{transmitter.id}'")
            gain = step.power_dbm if step.power_dbm is not None else 0.0
            ctrl.init_tx(
                center_frequency=self.config.recorder.center_freq,
                sample_rate=self.config.recorder.sample_rate,
                gain=gain,
                channel=step.channel or 0,
            )
            # Start transmission in background; _record() runs concurrently
            ctrl.transmit_async(step.duration + 1.0)
        elif transmitter.control_method == "sdr_agent":
            if self.skip_local_tx:
                logger.debug(f"skip_local_tx — TX for '{transmitter.id}' delegated to TX agent node")
                return
            if not transmitter.sdr_agent:
                logger.warning(f"Transmitter '{transmitter.id}' has no sdr_agent config — skipping")
                return
            step_dict: dict = {"label": step.label, "duration": step.duration + 1.0}
            if step.power_dbm is not None:
                step_dict["power_dbm"] = step.power_dbm
            tx_config = {
                "id": transmitter.id,
                "sdr_agent": transmitter.sdr_agent,
                "schedule": [step_dict],
            }
            rec = self.config.recorder
            # Fall back to the recorder's device when the transmitter names none.
            tx_device = transmitter.device or rec.device
            sdr_device = _DEVICE_ALIASES.get(tx_device.lower(), tx_device.lower())
            stop_event = threading.Event()
            executor = TxExecutor(tx_config, sdr_device=sdr_device, stop_event=stop_event)
            t = threading.Thread(target=executor.run, daemon=True, name=f"tx-{transmitter.id}")
            self._tx_executors[transmitter.id] = (executor, stop_event, t)
            t.start()
        else:
            logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")

    def _stop_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
        """Signal the transmitter to stop.

        Calls ``<script> stop`` for external_script transmitters; a failing
        stop script is logged, not raised.
        For ``sdr_remote``, calls the controller's ``wait_transmit`` with a
        timeout of the step duration plus ten seconds.
        For ``sdr_agent``, removes the registered TX thread, sets its stop
        event and joins it with the same timeout.
        Other control methods need no stop action.
        """
        if transmitter.control_method == "external_script":
            if not transmitter.script:
                return
            try:
                _run_script(transmitter.script, "stop")
            except Exception as e:
                logger.warning(f"Script stop failed for {transmitter.id}: {e}")
        elif transmitter.control_method == "sdr_remote":
            ctrl = self._remote_tx_controllers.get(transmitter.id)
            if ctrl is not None:
                ctrl.wait_transmit(timeout=step.duration + 10.0)
        elif transmitter.control_method == "sdr_agent":
            entry = self._tx_executors.pop(transmitter.id, None)
            if entry is not None:
                _, stop_event, t = entry
                stop_event.set()
                t.join(timeout=step.duration + 10.0)

    @staticmethod
    def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
        """Serialise step parameters to a JSON string for the control script.

        ``device`` is always present (empty string when unset); ``channel``,
        ``bandwidth_mhz``, ``traffic`` and ``power_dbm`` are included only
        when they are not ``None``.
        """
        params: dict = {"device": transmitter.device or ""}
        if step.channel is not None:
            params["channel"] = step.channel
        if step.bandwidth_mhz is not None:
            params["bandwidth_mhz"] = step.bandwidth_mhz
        if step.traffic is not None:
            params["traffic"] = step.traffic
        if step.power_dbm is not None:
            params["power_dbm"] = step.power_dbm
        return json.dumps(params)

    # ------------------------------------------------------------------
    # Output
    # ------------------------------------------------------------------
    def _save(self, recording: Recording, device_id: str, step: CaptureStep) -> str:
        """Save a recording to disk and return the data file path.

        The recording is written with ``to_sigmf`` (overwriting any existing
        file) into a sub-directory of the configured output path.

        Returns:
            Path of the ``.sigmf-data`` file as a string.

        Raises:
            RuntimeError: If the resolved target directory lies outside the
                configured output directory.
        """
        out = self.config.output
        rel_filename = build_output_filename(device_id, step)
        out_dir = Path(out.path).resolve()
        # build_output_filename returns "<device_id>/<label>"
        # to_sigmf needs filename (base) and path (dir) separately
        parts = Path(rel_filename)
        subdir = (out_dir / parts.parent).resolve()
        # Prevent path traversal: the resolved subdir must stay within the configured output directory.
        try:
            subdir.relative_to(out_dir)
        except ValueError as exc:
            raise RuntimeError(
                f"Output path escape detected: '{subdir}' is outside configured output directory '{out_dir}'"
            ) from exc
        subdir.mkdir(parents=True, exist_ok=True)
        base = parts.name
        to_sigmf(recording, filename=base, path=str(subdir), overwrite=True)
        return str(subdir / f"{base}.sigmf-data")