From 54b9bd4fc899b64f70d58d9eb7dd9159d47e3b17 Mon Sep 17 00:00:00 2001 From: ben Date: Fri, 10 Apr 2026 16:39:11 -0400 Subject: [PATCH 01/13] Agent Error fix --- src/ria_toolkit_oss/agent.py | 469 ++++++++++++++++++++++++++++++++--- 1 file changed, 438 insertions(+), 31 deletions(-) diff --git a/src/ria_toolkit_oss/agent.py b/src/ria_toolkit_oss/agent.py index bd4b3fc..274ff55 100644 --- a/src/ria_toolkit_oss/agent.py +++ b/src/ria_toolkit_oss/agent.py @@ -14,19 +14,40 @@ Usage:: [--device plutosdr] \\ [--insecure] + # Or store credentials in a config file and omit them from the command line: + ria-agent --config ~/.config/ria-agent/config.json --name lab-bench-1 + The agent: 1. Registers with RIA Hub and receives a ``node_id``. 2. Sends a heartbeat every 30 s so the hub knows it is online. 3. Long-polls ``GET /orchestrator/nodes/{id}/commands`` (30 s timeout). - 4. Executes received campaigns via :class:`ria_toolkit_oss.orchestration.executor.CampaignExecutor`. - 5. Uploads recordings to the hub via chunked POST, keeping each request - under 50 MB so it passes through Cloudflare without needing the bypass - subdomain. - 6. Deregisters cleanly on SIGINT / SIGTERM. + 4. Dispatches received commands: + - ``run_campaign``: executes via CampaignExecutor, uploads recordings. + - ``load_model``: loads an ONNX fingerprint or detector model. + - ``start_inference``: opens the SDR, runs the inference loop, posts + detection events to the hub for SSE fan-out to browsers. + - ``stop_inference``: gracefully stops the inference loop. + - ``configure_inference``: queues an SDR parameter update (applied at the + next capture boundary without restarting the loop). + 5. Deregisters cleanly on SIGINT / SIGTERM. + +Config file (JSON, optional):: + + { + "hub": "https://riahub.company.com", + "key": "secret", + "name": "lab-bench-1", + "device": "plutosdr", + "insecure": false, + "log_level": "INFO" + } + +CLI arguments always override config file values. """ from __future__ import annotations +import json import logging import math import os @@ -49,6 +70,8 @@ _POLL_CLIENT_TIMEOUT = 40 # client read timeout — slightly longer than server _RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying _CHUNK_SIZE = 50 * 1024 * 1024 # 50 MB — well below Cloudflare's 100 MB limit _DIRECT_THRESHOLD = 90 * 1024 * 1024 # files above this use chunked upload +_CAPTURE_SAMPLES = 4096 # IQ samples per inference window +_IDLE_LABELS = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"}) # --------------------------------------------------------------------------- @@ -80,6 +103,30 @@ class NodeAgent: self.node_id: str | None = None self._stop = threading.Event() + # ── Inference state ───────────────────────────────────────────────── + # Protected by _inf_lock for cross-thread model swaps. + self._inf_lock = threading.Lock() + self._inf_session: Any = None # primary fingerprint ONNX session + self._inf_index_to_label: dict[int, str] = {} + self._inf_detector_session: Any = None # optional protocol-detector session + self._inf_detector_index_to_label: dict[int, str] = {} + self._inf_detector_threshold: float = 0.7 + self._inf_pending_config: dict = {} # queued SDR attribute updates + + self._inf_stop = threading.Event() + self._inf_thread: threading.Thread | None = None + + # Detect optional dependencies once at startup so capability + # advertising is accurate from the first registration. + try: + import onnxruntime as _ort_mod + + self._ort: Any = _ort_mod + self._ort_available = True + except ImportError: + self._ort = None + self._ort_available = False + try: import ria_toolkit_oss @@ -114,6 +161,7 @@ class NodeAgent: self._command_loop() finally: self._stop.set() + self._stop_inference() self._deregister() # ------------------------------------------------------------------ @@ -121,13 +169,16 @@ class NodeAgent: # ------------------------------------------------------------------ def _register(self) -> None: + capabilities = ["campaign"] + if self._ort_available: + capabilities.append("inference") resp = self._post( "/orchestrator/nodes/register", json={ "name": self.name, "sdr_device": self.sdr_device, "ria_toolkit_version": self._ria_version, - "capabilities": ["inference", "campaign"], + "capabilities": capabilities, }, timeout=15, ) @@ -200,6 +251,24 @@ class NodeAgent: daemon=True, name=f"campaign-{campaign_id[:8]}", ).start() + elif command == "load_model": + threading.Thread( + target=self._load_model, + args=(cmd,), + daemon=True, + name="ria-load-model", + ).start() + elif command == "start_inference": + threading.Thread( + target=self._start_inference, + args=(cmd,), + daemon=True, + name="ria-start-inf", + ).start() + elif command == "stop_inference": + self._stop_inference() + elif command == "configure_inference": + self._queue_sdr_config(cmd) else: logger.warning("Unknown command %r — ignored", command) @@ -232,6 +301,270 @@ class NodeAgent: logger.error("Campaign %s failed: %s", campaign_id[:8], exc) self._report_campaign_status(campaign_id, "failed", error=str(exc)) + # ------------------------------------------------------------------ + # Inference — model loading + # ------------------------------------------------------------------ + + def _load_model(self, cmd: dict) -> None: + """Load an ONNX model into the fingerprint or detector slot. + + The ``model_path`` field may be either a local filesystem path or an + ``http(s)://`` URL; in the latter case the file is downloaded first. + """ + if not self._ort_available: + logger.error("load_model: onnxruntime is not installed — cannot load model") + return + + model_path: str = cmd.get("model_path", "") + label_map: dict[str, int] = cmd.get("label_map") or {} + stage: str = cmd.get("stage", "fingerprint") + detector_threshold: float = float(cmd.get("detector_threshold") or 0.7) + + if model_path.startswith(("http://", "https://")): + model_path = self._download_model(model_path) + if model_path is None: + return + + try: + session = self._ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]) + except Exception as exc: + logger.error("Failed to load model %r: %s", model_path, exc) + return + + index_to_label = {v: k for k, v in label_map.items()} + with self._inf_lock: + if stage == "detector": + self._inf_detector_session = session + self._inf_detector_index_to_label = index_to_label + self._inf_detector_threshold = detector_threshold + logger.info( + "Detector model loaded: path=%s classes=%d threshold=%.2f", + model_path, + len(label_map), + detector_threshold, + ) + else: + self._inf_session = session + self._inf_index_to_label = index_to_label + logger.info( + "Fingerprint model loaded: path=%s classes=%d", + model_path, + len(label_map), + ) + + def _download_model(self, url: str) -> str | None: + """Download a model from *url* to a temp file and return the local path.""" + import tempfile + + import requests as _requests + + try: + logger.info("Downloading model from %s", url) + resp = _requests.get( + url, + headers={"X-API-Key": self.api_key}, + verify=not self.insecure, + timeout=120, + ) + resp.raise_for_status() + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as fh: + fh.write(resp.content) + path = fh.name + logger.info("Model downloaded to %s (%d bytes)", path, len(resp.content)) + return path + except Exception as exc: + logger.error("Model download from %s failed: %s", url, exc) + return None + + # ------------------------------------------------------------------ + # Inference — loop lifecycle + # ------------------------------------------------------------------ + + def _start_inference(self, cmd: dict) -> None: + """Start the SDR capture + ONNX inference loop.""" + if not self._ort_available: + logger.error("start_inference: onnxruntime is not installed") + return + + with self._inf_lock: + if self._inf_session is None: + logger.error("start_inference: no fingerprint model loaded — call load_model first") + return + + if self._inf_thread is not None and self._inf_thread.is_alive(): + logger.warning("start_inference: inference loop is already running — ignoring") + return + + center_freq: float = float(cmd.get("center_freq", 2.4e9)) + sample_rate: float = float(cmd.get("sample_rate", 10e6)) + gain: float | str = cmd.get("gain", "auto") + device_type: str = cmd.get("device") or self.sdr_device + + self._inf_stop.clear() + self._inf_thread = threading.Thread( + target=self._inference_loop, + args=(device_type, center_freq, sample_rate, gain), + daemon=True, + name="ria-agent-inference", + ) + self._inf_thread.start() + logger.info( + "Inference started (device=%s freq=%.3f MHz rate=%.1f MHz)", + device_type, + center_freq / 1e6, + sample_rate / 1e6, + ) + + def _stop_inference(self) -> None: + """Signal the inference loop to stop and wait up to 5 s for it to exit.""" + self._inf_stop.set() + if self._inf_thread is not None and self._inf_thread.is_alive(): + self._inf_thread.join(timeout=5.0) + if self._inf_thread.is_alive(): + logger.warning("Inference thread did not exit within 5 s") + logger.info("Inference stopped") + + def _queue_sdr_config(self, cmd: dict) -> None: + """Merge SDR parameter updates into the pending-config dict. + + The inference loop checks this at each capture boundary and applies + the updates without restarting. + """ + cfg = {k: v for k, v in cmd.items() if k != "command" and v is not None} + with self._inf_lock: + self._inf_pending_config.update(cfg) + logger.debug("SDR reconfiguration queued: %s", cfg) + + # ------------------------------------------------------------------ + # Inference — main loop + # ------------------------------------------------------------------ + + def _inference_loop( + self, + device_type: str, + center_freq: float, + sample_rate: float, + gain: float | str, + ) -> None: + """Continuous SDR capture → ONNX inference → POST events to hub. + + Mirrors the two-stage pipeline in the hub's ``_inference_loop``: + an optional protocol-detector gates the fingerprint model so the + fingerprint model only runs when an active transmission is detected. + """ + try: + from ria_toolkit_oss.sdr import get_sdr_device + except ImportError as exc: + logger.error("inference_loop: ria_toolkit_oss not installed: %s", exc) + return + + try: + sdr = get_sdr_device(device_type) + _apply_sdr_config(sdr, {"center_freq": center_freq, "sample_rate": sample_rate, "gain": gain}) + except Exception as exc: + logger.error("SDR initialisation failed: %s", exc) + return + + try: + import numpy as np + + try: + from ria_toolkit_oss.orchestration.qa import estimate_snr_db + except ImportError: + estimate_snr_db = None + + # Snapshot model state once at loop start. If the hub sends a + # new load_model command while the loop is running, the new session + # will be picked up on the next loop restart (stop + start). + with self._inf_lock: + session = self._inf_session + index_to_label = dict(self._inf_index_to_label) + det_session = self._inf_detector_session + det_threshold = self._inf_detector_threshold + + input_name = session.get_inputs()[0].name + det_input_name = det_session.get_inputs()[0].name if det_session else None + + while not self._inf_stop.is_set(): + # Apply any queued SDR configuration changes. + with self._inf_lock: + pending = self._inf_pending_config.copy() + self._inf_pending_config.clear() + if pending: + _apply_sdr_config(sdr, pending) + + try: + samples = sdr.rx(_CAPTURE_SAMPLES) + except Exception as exc: + logger.warning("SDR capture error: %s", exc) + # Avoid a tight spin when the SDR is in a persistent error + # state (e.g. physically disconnected). + self._inf_stop.wait(timeout=0.5) + continue + + samples = np.array(samples, dtype=np.complex64) + snr_db = float(estimate_snr_db(samples)) if estimate_snr_db is not None else 0.0 + iq = np.stack([samples.real, samples.imag], axis=0).astype(np.float32) + + # Stage 1: protocol detector gate (optional). + if det_session is not None: + det_out = _run_onnx_session(det_session, det_input_name, iq) + det_probs = _softmax(det_out[0][0]) + det_confidence = float(det_probs.max()) + if det_confidence < det_threshold: + # No active protocol detected — report idle and skip + # the fingerprint model for this window. + self._post_event(device_id=None, confidence=det_confidence, snr_db=snr_db) + continue + + # Stage 2: fingerprint model. + out = _run_onnx_session(session, input_name, iq) + probs = _softmax(out[0][0]) + pred_idx = int(probs.argmax()) + confidence = float(probs[pred_idx]) + device_id = index_to_label.get(pred_idx) + idle = (device_id in _IDLE_LABELS) if device_id else True + self._post_event( + device_id=None if idle else device_id, + confidence=confidence, + snr_db=snr_db, + ) + + except Exception as exc: + logger.exception("Inference loop terminated unexpectedly: %s", exc) + finally: + try: + sdr.close() + except Exception: + pass + logger.info("Inference loop exited") + + def _post_event(self, device_id: str | None, confidence: float, snr_db: float) -> None: + """POST a single detection event to ``POST /orchestrator/nodes/{id}/events``. + + Failures are logged at DEBUG level and silently swallowed so that a + transient network blip does not crash the inference loop. + """ + from datetime import datetime, timezone + + payload = { + "type": "detection", + "device_id": device_id, + "confidence": round(confidence, 6), + "snr_db": round(snr_db, 2), + "timestamp": datetime.now(timezone.utc).isoformat(), + } + try: + resp = self._post( + f"/orchestrator/nodes/{self.node_id}/events", + json=payload, + timeout=5, + ) + if resp.status_code not in (200, 204): + logger.debug("Event POST returned HTTP %d", resp.status_code) + except Exception as exc: + logger.debug("Event POST failed (will retry next inference cycle): %s", exc) + # ------------------------------------------------------------------ # Recording upload (chunked for large files) # ------------------------------------------------------------------ @@ -244,7 +577,7 @@ class NodeAgent: repo_owner, repo_name = output_repo.split("/", 1) base_url = f"{self.hub_url}/datasets/upload" - steps = getattr(result, "steps", None) or [] + steps = (result.get("steps") if isinstance(result, dict) else getattr(result, "steps", None)) or [] for step in steps: output_path: str | None = getattr(step, "output_path", None) @@ -304,7 +637,6 @@ class NodeAgent: headers = {"X-API-Key": self.api_key} verify = not self.insecure - # Small files: single POST (unchanged endpoint, no assembly needed server-side). if size <= _DIRECT_THRESHOLD: with open(file_path, "rb") as fh: resp = _requests.post( @@ -318,7 +650,6 @@ class NodeAgent: resp.raise_for_status() return resp.json() - # Large files: chunked upload — each request is ≤ 50 MB. total_chunks = math.ceil(size / _CHUNK_SIZE) upload_id = str(uuid.uuid4()) chunk_url = base_url + "/chunk" @@ -339,18 +670,13 @@ class NodeAgent: chunk_url, headers=headers, files={"file": (filename, chunk, "application/octet-stream")}, - data={ - **metadata, - "upload_id": upload_id, - "chunk_index": i, - "total_chunks": total_chunks, - }, + data={**metadata, "upload_id": upload_id, "chunk_index": i, "total_chunks": total_chunks}, timeout=120, verify=verify, ) if not resp.ok: raise RuntimeError( - f"Chunk {i + 1}/{total_chunks} failed: " f"HTTP {resp.status_code}: {resp.text[:300]}" + f"Chunk {i + 1}/{total_chunks} failed: HTTP {resp.status_code}: {resp.text[:300]}" ) resp_data = resp.json() logger.debug("Chunk %d/%d uploaded", i + 1, total_chunks) @@ -393,10 +719,41 @@ class NodeAgent: # --------------------------------------------------------------------------- -# Helpers +# Module-level helpers (shared by NodeAgent._inference_loop) # --------------------------------------------------------------------------- +def _run_onnx_session(session: Any, input_name: str, iq: Any) -> list: + """Run an ONNX session on an IQ array (2, N). + + Tries channel-first layout (1, 2, N) first; falls back to interleaved flat + (1, 2*N) when the model expects a flattened input. + """ + import numpy as np + + x = iq[np.newaxis] # (1, 2, N) + try: + return session.run(None, {input_name: x}) + except Exception: + return session.run(None, {input_name: iq.flatten()[np.newaxis]}) + + +def _softmax(x: Any) -> Any: + import numpy as np + + e = np.exp(x - x.max()) + return e / e.sum() + + +def _apply_sdr_config(sdr: Any, cfg: dict) -> None: + for attr in ("center_freq", "sample_rate", "gain"): + if attr in cfg: + try: + setattr(sdr, attr, cfg[attr]) + except Exception as exc: + logger.warning("SDR config %s=%r failed: %s", attr, cfg[attr], exc) + + def _sigmf_files(data_path: str) -> list[str]: """Return paths to both SigMF files (.sigmf-data and .sigmf-meta) for a recording.""" candidates = [data_path] @@ -405,6 +762,29 @@ def _sigmf_files(data_path: str) -> list[str]: return [p for p in candidates if os.path.exists(p)] +# --------------------------------------------------------------------------- +# Config file helpers +# --------------------------------------------------------------------------- + +_DEFAULT_CONFIG_PATH = os.path.join( + os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")), + "ria-agent", + "config.json", +) + + +def _load_config(path: str) -> dict: + """Load a JSON config file, returning an empty dict if it does not exist.""" + try: + with open(path) as fh: + return json.load(fh) + except FileNotFoundError: + return {} + except Exception as exc: + logger.warning("Could not read config file %s: %s", path, exc) + return {} + + # --------------------------------------------------------------------------- # CLI entry point # --------------------------------------------------------------------------- @@ -420,67 +800,94 @@ def main() -> None: "campaigns / inference on local SDR hardware." ), ) + parser.add_argument( + "--config", + default=None, + metavar="PATH", + help=( + f"Path to a JSON config file (default: {_DEFAULT_CONFIG_PATH}). " + "CLI arguments override config file values." + ), + ) parser.add_argument( "--hub", - required=True, + default=None, metavar="URL", help="RIA Hub base URL, e.g. https://riahub.company.com", ) parser.add_argument( "--key", - required=True, + default=None, metavar="API_KEY", help="Shared API key (must match [wac] API_KEY in the hub's app.ini)", ) parser.add_argument( "--name", - required=True, + default=None, metavar="NAME", help='Human-readable name shown in the Target Node dropdown, e.g. "lab-bench-1"', ) parser.add_argument( "--device", - default="unknown", + default=None, metavar="SDR", help=( - "SDR device type reported to the hub (informational only). " + "SDR device type reported to the hub and used for inference. " "Examples: plutosdr, usrp_b210, rtlsdr, mock. Default: unknown" ), ) parser.add_argument( "--insecure", action="store_true", + default=None, help="Disable TLS certificate verification (dev/self-signed certs only)", ) parser.add_argument( "--log-level", - default="INFO", + default=None, choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="Logging verbosity (default: INFO)", ) args = parser.parse_args() + # Merge: config file → CLI args (CLI wins). + config_path = args.config or _DEFAULT_CONFIG_PATH + cfg = _load_config(config_path) + + hub = args.hub or cfg.get("hub") + key = args.key or cfg.get("key") + name = args.name or cfg.get("name") + device = args.device or cfg.get("device", "unknown") + insecure = args.insecure if args.insecure is not None else cfg.get("insecure", False) + log_level = args.log_level or cfg.get("log_level", "INFO") + + if not hub: + parser.error("--hub is required (or set 'hub' in the config file)") + if not key: + parser.error("--key is required (or set 'key' in the config file)") + if not name: + parser.error("--name is required (or set 'name' in the config file)") + logging.basicConfig( - level=getattr(logging, args.log_level), + level=getattr(logging, log_level), format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S", stream=sys.stderr, ) - # Warn loudly if --insecure is used outside of development. - if args.insecure: + if insecure: logger.warning( "--insecure disables TLS certificate verification. " "Only use this for local development with self-signed certs." ) agent = NodeAgent( - hub_url=args.hub, - api_key=args.key, - name=args.name, - sdr_device=args.device, - insecure=args.insecure, + hub_url=hub, + api_key=key, + name=name, + sdr_device=device, + insecure=insecure, ) agent.run() From 195db4a27db894f14a70c5fc36afa8e007441026 Mon Sep 17 00:00:00 2001 From: ben Date: Tue, 14 Apr 2026 10:45:54 -0400 Subject: [PATCH 02/13] quick fix --- poetry.toml | 2 ++ src/ria_toolkit_oss/sdr/sdr.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) create mode 100644 poetry.toml diff --git a/poetry.toml b/poetry.toml new file mode 100644 index 0000000..25758d2 --- /dev/null +++ b/poetry.toml @@ -0,0 +1,2 @@ +[virtualenvs.options] +system-site-packages = true diff --git a/src/ria_toolkit_oss/sdr/sdr.py b/src/ria_toolkit_oss/sdr/sdr.py index 36e26f7..f2ea9f4 100644 --- a/src/ria_toolkit_oss/sdr/sdr.py +++ b/src/ria_toolkit_oss/sdr/sdr.py @@ -43,6 +43,13 @@ class SDR(ABC): self.tx_gain = None self._param_lock = threading.RLock() # Reentrant lock + # Pending config consumed by rx() on first call and by _apply_sdr_config + # in the agent inference loop. Subclasses that need different defaults + # (e.g. MockSDR) can overwrite these in their own __init__. + self.center_freq: float = 2.4e9 + self.sample_rate: float = 10e6 + self.gain: float = 40.0 + def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording: """ Create a radio recording of a given length. Either ``num_samples`` or ``rx_time`` must be provided. @@ -100,6 +107,32 @@ class SDR(ABC): self._num_buffers_processed = 0 return recording + def rx(self, num_samples: int) -> "np.ndarray": + """Return *num_samples* complex IQ samples as a 1-D complex64 array. + + This is the interface used by the agent inference loop. On first call, + ``init_rx()`` is invoked automatically using the values stored in + ``center_freq``, ``sample_rate``, and ``gain`` (set beforehand by + ``_apply_sdr_config``). Subsequent calls stream directly. + + Subclasses may override this for hardware-native capture APIs (e.g. + ``MockSDR`` uses AWGN generation; ``PlutoSDR`` could use + ``self.radio.rx()``). + """ + if not self._rx_initialized: + gain = self.gain if isinstance(self.gain, (int, float)) else 40.0 + self.init_rx( + sample_rate=self.sample_rate, + center_frequency=self.center_freq, + gain=gain, + channel=0, + ) + recording = self.record(num_samples=num_samples) + # Recording.data is either a list of 1-D arrays (one per channel) or a + # 2-D ndarray (channels × samples). Either way, index 0 is channel 0. + data = recording.data + return data[0] if hasattr(data, "__getitem__") else data + def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000): """ Stream iq samples as interleaved bytes via zmq. From efc09481104e9c1fddccbab27a8bc1ce9fce7ed8 Mon Sep 17 00:00:00 2001 From: ben Date: Fri, 17 Apr 2026 09:43:59 -0400 Subject: [PATCH 03/13] ria composer support --- src/ria_toolkit_oss/orchestration/campaign.py | 6 +- src/ria_toolkit_oss/orchestration/executor.py | 61 ++- .../remote_control/__init__.py | 6 + .../remote_control/remote_transmitter.py | 147 +++++++ .../remote_transmitter_controller.py | 210 ++++++++++ tests/remote_control/__init__.py | 0 .../remote_control/test_remote_transmitter.py | 266 ++++++++++++ .../test_remote_transmitter_controller.py | 294 +++++++++++++ .../test_sdr_remote_integration.py | 391 ++++++++++++++++++ 9 files changed, 1379 insertions(+), 2 deletions(-) create mode 100644 src/ria_toolkit_oss/remote_control/__init__.py create mode 100644 src/ria_toolkit_oss/remote_control/remote_transmitter.py create mode 100644 src/ria_toolkit_oss/remote_control/remote_transmitter_controller.py create mode 100644 tests/remote_control/__init__.py create mode 100644 tests/remote_control/test_remote_transmitter.py create mode 100644 tests/remote_control/test_remote_transmitter_controller.py create mode 100644 tests/remote_control/test_sdr_remote_integration.py diff --git a/src/ria_toolkit_oss/orchestration/campaign.py b/src/ria_toolkit_oss/orchestration/campaign.py index 9d96c96..027c33f 100644 --- a/src/ria_toolkit_oss/orchestration/campaign.py +++ b/src/ria_toolkit_oss/orchestration/campaign.py @@ -223,13 +223,16 @@ class TransmitterConfig: id: str type: str # "wifi", "bluetooth", "sdr", "external" - control_method: str # "external_script" | "sdr" + control_method: str # "external_script" | "sdr" | "sdr_remote" schedule: list[CaptureStep] # For external_script control script: Optional[str] = None # path to control script device: Optional[str] = None # e.g. "/dev/wlan0" + # For sdr_remote control — keys: host, ssh_user, ssh_key_path, device_type, device_id, zmq_port + sdr_remote: Optional[dict] = None + @classmethod def from_dict(cls, d: dict) -> "TransmitterConfig": schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])] @@ -240,6 +243,7 @@ class TransmitterConfig: schedule=schedule, script=d.get("script"), device=d.get("device"), + sdr_remote=d.get("sdr_remote"), ) diff --git a/src/ria_toolkit_oss/orchestration/executor.py b/src/ria_toolkit_oss/orchestration/executor.py index 629c0d8..1bdd4d8 100644 --- a/src/ria_toolkit_oss/orchestration/executor.py +++ b/src/ria_toolkit_oss/orchestration/executor.py @@ -196,6 +196,7 @@ class CampaignExecutor: self.config = config self.progress_cb = progress_cb self._sdr = None + self._remote_tx_controllers: dict = {} if verbose: logging.basicConfig(level=logging.DEBUG) @@ -222,6 +223,7 @@ class CampaignExecutor: ) self._init_sdr() + self._init_remote_tx_controllers() try: total = self.config.total_steps() step_index = 0 @@ -248,6 +250,7 @@ class CampaignExecutor: ) finally: self._close_sdr() + self._close_remote_tx_controllers() result.end_time = time.time() logger.info( @@ -287,6 +290,41 @@ class CampaignExecutor: logger.warning(f"SDR close error: {e}") self._sdr = None + # ------------------------------------------------------------------ + # Remote Tx controller management + # ------------------------------------------------------------------ + + def _init_remote_tx_controllers(self) -> None: + """Open SSH+ZMQ connections for all sdr_remote transmitters.""" + from ria_toolkit_oss.remote_control import RemoteTransmitterController + + for tx in self.config.transmitters: + if tx.control_method != "sdr_remote": + continue + cfg = tx.sdr_remote + if not cfg: + raise RuntimeError(f"Transmitter '{tx.id}' uses sdr_remote but has no sdr_remote config") + logger.info(f"Connecting remote Tx controller for {tx.id} → {cfg['host']}") + ctrl = RemoteTransmitterController( + host=cfg["host"], + ssh_user=cfg["ssh_user"], + ssh_key_path=cfg["ssh_key_path"], + zmq_port=int(cfg.get("zmq_port", 5556)), + ) + ctrl.set_radio( + device_type=cfg["device_type"], + device_id=cfg.get("device_id", ""), + ) + self._remote_tx_controllers[tx.id] = ctrl + + def _close_remote_tx_controllers(self) -> None: + for tx_id, ctrl in list(self._remote_tx_controllers.items()): + try: + ctrl.close() + except Exception as exc: + logger.warning(f"Error closing remote Tx controller {tx_id}: {exc}") + self._remote_tx_controllers.clear() + def _record(self, duration_s: float) -> Recording: """Capture ``duration_s`` seconds of IQ samples.""" num_samples = int(duration_s * self.config.recorder.sample_rate) @@ -372,7 +410,8 @@ class CampaignExecutor: traffic, etc. The script is responsible for applying the configuration and returning promptly (i.e. not blocking for the capture duration). - For SDR transmitters this is a no-op placeholder (TX not yet implemented). + For ``sdr_remote`` the remote ZMQ controller calls ``init_tx`` then + starts a background transmit thread that runs for the step duration. """ if transmitter.control_method == "external_script": if not transmitter.script: @@ -384,6 +423,20 @@ class CampaignExecutor: elif transmitter.control_method == "sdr": logger.debug("SDR TX not yet implemented — skipping start") + elif transmitter.control_method == "sdr_remote": + ctrl = self._remote_tx_controllers.get(transmitter.id) + if ctrl is None: + raise RuntimeError(f"No remote Tx controller found for transmitter '{transmitter.id}'") + gain = step.power_dbm if step.power_dbm is not None else 0.0 + ctrl.init_tx( + center_frequency=self.config.recorder.center_freq, + sample_rate=self.config.recorder.sample_rate, + gain=gain, + channel=step.channel or 0, + ) + # Start transmission in background; _record() runs concurrently + ctrl.transmit_async(step.duration + 1.0) + else: logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping") @@ -391,6 +444,7 @@ class CampaignExecutor: """Signal the transmitter to stop. Calls ``