ria-toolkit-oss/src/ria_toolkit_oss/agent/legacy_executor.py

1001 lines
38 KiB
Python
Raw Permalink Normal View History

2026-03-31 13:51:10 -04:00
"""RT-OSS Node Agent — connects to RIA Hub and dispatches work to local hardware.
The agent runs on any machine with an SDR attached and connects **outbound** to
RIA Hub. No inbound ports need to be opened on the user's machine, and the
connection works identically through NAT, corporate firewalls, or a Pi on a
cellular link.
Usage::

    ria-agent \\
        --hub https://riahub.company.com \\
        --key <api-key> \\
        --name lab-bench-1 \\
        [--device plutosdr] \\
        [--role general|rx|tx] \\
        [--session-code amber-peak-transmit] \\
        [--insecure]

    # Or store credentials in a config file and omit them from the command line:
    ria-agent --config ~/.config/ria-agent/config.json --name lab-bench-1

The agent:

1. Registers with RIA Hub and receives a ``node_id``.
2. Sends a heartbeat every 30 s so the hub knows it is online.
3. Long-polls ``GET /composer/nodes/{id}/commands`` (30 s timeout).
4. Dispatches received commands:

   - ``run_campaign``: executes via CampaignExecutor, uploads recordings.
   - ``load_model``: loads an ONNX fingerprint or detector model.
   - ``start_inference``: opens the SDR, runs the inference loop, posts
     detection events to the hub for SSE fan-out to browsers.
   - ``stop_inference``: gracefully stops the inference loop.
   - ``configure_inference``: queues an SDR parameter update (applied at the
     next capture boundary without restarting the loop).
   - ``start_transmit`` / ``stop_transmit``: start or stop a synthetic
     transmit campaign via TxExecutor (the ``transmit`` capability is only
     advertised when running with ``--role tx``).

5. Deregisters cleanly on SIGINT / SIGTERM.

Config file (JSON, optional)::

    {
        "hub": "https://riahub.company.com",
        "key": "secret",
        "name": "lab-bench-1",
        "device": "plutosdr",
        "insecure": false,
        "log_level": "INFO",
        "role": "general",
        "session_code": null
    }

CLI arguments always override config file values.
"""
from __future__ import annotations
2026-04-10 16:39:11 -04:00
import json
2026-03-31 13:51:10 -04:00
import logging
import math
import os
import signal
import sys
import threading
import time
import uuid
from typing import Any
logger = logging.getLogger("ria_agent")
# ---------------------------------------------------------------------------
# Tuneable constants
# ---------------------------------------------------------------------------
_HEARTBEAT_INTERVAL = 30 # seconds between heartbeats
_POLL_TIMEOUT = 30 # server-side long-poll duration
_POLL_CLIENT_TIMEOUT = 40 # client read timeout — slightly longer than server
_RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying
2026-04-21 16:40:49 -04:00
_CHUNK_SIZE = 10 * 1024 * 1024 # 10 MB per chunk — fast enough for git-LFS to process within timeout
2026-03-31 13:51:10 -04:00
_DIRECT_THRESHOLD = 90 * 1024 * 1024 # files above this use chunked upload
2026-04-10 16:39:11 -04:00
_CAPTURE_SAMPLES = 4096 # IQ samples per inference window
_IDLE_LABELS = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"})
2026-03-31 13:51:10 -04:00
# ---------------------------------------------------------------------------
# Agent
# ---------------------------------------------------------------------------
class NodeAgent:
    """Outbound-connecting agent that bridges RIA Hub to local SDR hardware.

    All network I/O is initiated by the agent (outbound). RIA Hub never opens
    a connection back to the agent's machine.

    Threading model:

    * :meth:`run` runs the command long-poll loop on the calling thread.
    * A daemon ``ria-agent-heartbeat`` thread posts heartbeats.
    * ``run_campaign``, ``load_model``, ``start_inference`` and
      ``start_transmit`` are each handled on their own daemon thread.
    * At most one inference loop and one TX campaign run at a time.
    """

    def __init__(
        self,
        hub_url: str,
        api_key: str,
        name: str,
        sdr_device: str = "unknown",
        insecure: bool = False,
        role: str = "general",
        session_code: str | None = None,
    ) -> None:
        """Store connection settings and probe optional dependencies.

        Args:
            hub_url: RIA Hub base URL; a trailing ``/`` is stripped.
            api_key: Shared secret sent as the ``X-API-Key`` header.
            name: Human-readable node name reported at registration.
            sdr_device: SDR type reported to the hub and used for inference.
            insecure: Disable TLS certificate verification for all requests.
            role: Node role reported to the hub; ``"tx"`` additionally
                advertises the ``transmit`` capability.
            session_code: Optional pairing code sent at registration.
        """
        self.hub_url = hub_url.rstrip("/")
        self.api_key = api_key
        self.name = name
        self.sdr_device = sdr_device
        self.insecure = insecure
        self.role = role
        self.session_code = session_code
        self.node_id: str | None = None
        self._stop = threading.Event()
        # Serialises re-registration: the heartbeat thread and the command
        # loop can both see a 404 when the hub forgets this node, and only
        # one of them should register a replacement.
        self._register_lock = threading.Lock()

        # ── TX state ────────────────────────────────────────────────────────
        # _tx_lock makes "is a TX campaign already running?" and starting the
        # new TX thread a single atomic step in _start_transmit.
        self._tx_lock = threading.Lock()
        self._tx_stop = threading.Event()
        self._tx_thread: threading.Thread | None = None

        # ── Inference state ─────────────────────────────────────────────────
        # Protected by _inf_lock for cross-thread model swaps.
        self._inf_lock = threading.Lock()
        self._inf_session: Any = None  # primary fingerprint ONNX session
        self._inf_index_to_label: dict[int, str] = {}
        self._inf_detector_session: Any = None  # optional protocol-detector session
        self._inf_detector_index_to_label: dict[int, str] = {}
        self._inf_detector_threshold: float = 0.7
        self._inf_pending_config: dict = {}  # queued SDR attribute updates
        self._inf_stop = threading.Event()
        self._inf_thread: threading.Thread | None = None

        # Detect optional dependencies once at startup so capability
        # advertising is accurate from the first registration. Any import
        # failure (not only ImportError — a broken native library can raise
        # OSError) just disables inference instead of crashing the agent.
        try:
            import onnxruntime as _ort_mod
        except Exception:
            self._ort: Any = None
            self._ort_available = False
        else:
            self._ort = _ort_mod
            self._ort_available = True

        try:
            import ria_toolkit_oss

            self._ria_version: str = getattr(ria_toolkit_oss, "__version__", "unknown")
        except Exception:
            self._ria_version = "unknown"

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def run(self) -> None:
        """Register, start the heartbeat thread, and enter the command loop.

        Blocks until SIGINT or SIGTERM is received. If a long-poll is in
        flight when the signal arrives, shutdown may wait for it to return
        (up to ``_POLL_CLIENT_TIMEOUT`` seconds). On exit the inference loop
        and any TX campaign are stopped before the node deregisters.
        """
        self._register()

        def _shutdown(sig: int, _frame: Any) -> None:
            logger.info("Shutdown signal received — stopping agent")
            self._stop.set()

        signal.signal(signal.SIGINT, _shutdown)
        signal.signal(signal.SIGTERM, _shutdown)

        hb = threading.Thread(target=self._heartbeat_loop, daemon=True, name="ria-agent-heartbeat")
        hb.start()

        logger.info("Agent %r online (node_id=%s, hub=%s)", self.name, self.node_id, self.hub_url)
        try:
            self._command_loop()
        finally:
            self._stop.set()
            self._stop_inference()
            # Stop TX as well, so the radio is not left transmitting on a
            # daemon thread while the node deregisters and the process exits.
            self._stop_transmit()
            self._deregister()

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    def _register(self) -> None:
        """Register with the hub and store the returned ``node_id``.

        Raises:
            requests.HTTPError: the hub rejected the registration.
        """
        capabilities = ["campaign"]
        if self._ort_available:
            capabilities.append("inference")
        if self.role == "tx":
            capabilities.append("transmit")
        payload: dict = {
            "name": self.name,
            "sdr_device": self.sdr_device,
            "ria_toolkit_version": self._ria_version,
            "capabilities": capabilities,
            "role": self.role,
        }
        if self.session_code:
            payload["session_code"] = self.session_code
        resp = self._post("/composer/nodes/register", json=payload, timeout=15)
        resp.raise_for_status()
        self.node_id = resp.json()["node_id"]
        logger.info(
            "Registered as %r (node_id=%s, role=%s%s)",
            self.name,
            self.node_id,
            self.role,
            f", session_code={self.session_code!r}" if self.session_code else "",
        )

    def _reregister(self, stale_node_id: str | None) -> None:
        """Re-register after the hub answered 404 for *stale_node_id*.

        Only the first caller re-registers; a caller arriving later sees that
        ``node_id`` has already moved on and returns without registering a
        second, duplicate node.
        """
        with self._register_lock:
            if self.node_id != stale_node_id:
                return
            self._register()

    def _deregister(self) -> None:
        """Best-effort DELETE of this node's registration; errors are ignored."""
        if not self.node_id:
            return
        try:
            self._delete(f"/composer/nodes/{self.node_id}", timeout=10)
            logger.info("Deregistered %s", self.node_id)
        except Exception as exc:
            logger.debug("Deregister failed (ignored on shutdown): %s", exc)

    # ------------------------------------------------------------------
    # Heartbeat thread
    # ------------------------------------------------------------------

    def _heartbeat_loop(self) -> None:
        """POST a heartbeat every ``_HEARTBEAT_INTERVAL`` s until stopped.

        A 404 means the hub no longer knows this node, so it re-registers.
        Other failures are logged and the loop keeps going.
        """
        while not self._stop.wait(_HEARTBEAT_INTERVAL):
            node_id = self.node_id  # the id this heartbeat is actually for
            try:
                resp = self._post(f"/composer/nodes/{node_id}/heartbeat", timeout=10)
                if resp.status_code == 404:
                    logger.warning("Heartbeat got 404 — hub lost registration, re-registering")
                    self._reregister(node_id)
                elif not resp.ok:
                    logger.warning("Heartbeat returned HTTP %d", resp.status_code)
            except Exception as exc:
                logger.warning("Heartbeat failed: %s", exc)

    # ------------------------------------------------------------------
    # Command poll loop
    # ------------------------------------------------------------------

    def _command_loop(self) -> None:
        """Long-poll the hub for commands and dispatch them until stopped."""
        while not self._stop.is_set():
            node_id = self.node_id  # the id this poll is actually for
            try:
                resp = self._get(
                    f"/composer/nodes/{node_id}/commands",
                    timeout=_POLL_CLIENT_TIMEOUT,
                )
                if resp.status_code == 204:
                    # No command within the timeout window — loop immediately.
                    continue
                if resp.status_code == 404:
                    logger.warning("Command poll got 404 — re-registering")
                    self._reregister(node_id)
                    continue
                resp.raise_for_status()
                cmd = resp.json()
            except Exception as exc:
                if not self._stop.is_set():
                    logger.warning("Command poll error: %s — retrying in %ds", exc, _RECONNECT_PAUSE)
                    # Event.wait rather than time.sleep so a shutdown signal
                    # ends the back-off immediately.
                    self._stop.wait(_RECONNECT_PAUSE)
                continue

            if not isinstance(cmd, dict):
                logger.warning("Ignoring malformed command payload: %r", cmd)
                continue
            logger.info("Received command: %s", cmd.get("command"))
            # A failing command is not a poll failure: log it with a traceback
            # and go straight back to polling instead of backing off.
            try:
                self._dispatch(cmd)
            except Exception:
                logger.exception("Command %r failed", cmd.get("command"))

    # ------------------------------------------------------------------
    # Command dispatch
    # ------------------------------------------------------------------

    def _dispatch(self, cmd: dict) -> None:
        """Route one hub command to its handler.

        Long-running handlers get their own daemon thread so the poll loop
        keeps receiving commands (notably ``stop_*``) while they run.
        """
        command = cmd.get("command")
        if command == "run_campaign":
            campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4())
            config_dict: dict = cmd.get("payload") or {}
            skip_local_tx: bool = bool(cmd.get("skip_local_tx", False))
            threading.Thread(
                target=self._run_campaign,
                args=(campaign_id, config_dict, skip_local_tx),
                daemon=True,
                name=f"campaign-{campaign_id[:8]}",
            ).start()
        elif command == "load_model":
            threading.Thread(
                target=self._load_model,
                args=(cmd,),
                daemon=True,
                name="ria-load-model",
            ).start()
        elif command == "start_inference":
            threading.Thread(
                target=self._start_inference,
                args=(cmd,),
                daemon=True,
                name="ria-start-inf",
            ).start()
        elif command == "stop_inference":
            self._stop_inference()
        elif command == "configure_inference":
            self._queue_sdr_config(cmd)
        elif command == "start_transmit":
            threading.Thread(
                target=self._start_transmit,
                args=(cmd,),
                daemon=True,
                name="ria-start-tx",
            ).start()
        elif command == "stop_transmit":
            self._stop_transmit()
        elif command == "configure_transmit":
            # Nothing forwards this update to a running TxExecutor, so say so
            # instead of implying it will be applied.
            logger.warning("configure_transmit received — runtime TX reconfiguration is not supported; ignored")
        else:
            logger.warning("Unknown command %r — ignored", command)

    # ------------------------------------------------------------------
    # Campaign execution
    # ------------------------------------------------------------------

    def _run_campaign(self, campaign_id: str, config_dict: dict, skip_local_tx: bool = False) -> None:
        """Run a campaign locally, upload its recordings, and report status."""
        try:
            from ria_toolkit_oss.orchestration.campaign import CampaignConfig
            from ria_toolkit_oss.orchestration.executor import CampaignExecutor
        except ImportError as exc:
            logger.error(
                "Campaign %s cannot start — ria_toolkit_oss not fully installed: %s",
                campaign_id[:8],
                exc,
            )
            return

        logger.info("Campaign %s starting (skip_local_tx=%s)", campaign_id[:8], skip_local_tx)
        try:
            config = CampaignConfig.from_dict(config_dict)
            executor = CampaignExecutor(config, skip_local_tx=skip_local_tx)
            result = executor.run()
            logger.info("Campaign %s completed — uploading recordings", campaign_id[:8])
            self._upload_recordings(campaign_id, config, result)
            result_dict = result.to_dict() if hasattr(result, "to_dict") else None
            self._report_campaign_status(campaign_id, "completed", result=result_dict)
        except Exception as exc:
            # logger.exception keeps the traceback, which str(exc) alone loses.
            logger.exception("Campaign %s failed: %s", campaign_id[:8], exc)
            self._report_campaign_status(campaign_id, "failed", error=str(exc))

    # ------------------------------------------------------------------
    # TX execution
    # ------------------------------------------------------------------

    def _start_transmit(self, cmd: dict) -> None:
        """Execute a synthetic transmit campaign using TxExecutor.

        The command payload mirrors a TransmitterConfig dict with an optional
        ``schedule`` of steps. Each step synthesises a signal and transmits it
        via the local SDR in TX mode. A second ``start_transmit`` while one is
        running is ignored. If the executor cannot be constructed, the
        campaign is reported to the hub as failed.
        """
        try:
            from ria_toolkit_oss.orchestration.tx_executor import TxExecutor
        except ImportError as exc:
            logger.error("start_transmit: TxExecutor not available: %s", exc)
            return

        campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4())
        failure: str | None = None
        with self._tx_lock:
            if self._tx_thread is not None and self._tx_thread.is_alive():
                logger.warning("start_transmit: TX already running — ignoring duplicate command")
                return
            self._tx_stop.clear()
            try:
                executor = TxExecutor(
                    config=cmd,
                    sdr_device=self.sdr_device,
                    stop_event=self._tx_stop,
                )
            except Exception as exc:
                logger.exception("TX campaign %s could not be set up: %s", campaign_id[:8], exc)
                failure = str(exc)
            else:
                self._tx_thread = threading.Thread(
                    target=self._run_tx_campaign,
                    args=(executor, campaign_id),
                    daemon=True,
                    name=f"tx-campaign-{campaign_id[:8]}",
                )
                self._tx_thread.start()
        if failure is not None:
            # Reported outside the lock because it is a network call.
            self._report_campaign_status(campaign_id, "failed", error=failure)

    def _run_tx_campaign(self, executor: Any, campaign_id: str) -> None:
        """Thread body: run *executor* and report the outcome to the hub."""
        try:
            executor.run()
            logger.info("TX campaign %s completed", campaign_id[:8])
            self._report_campaign_status(campaign_id, "completed")
        except Exception as exc:
            logger.exception("TX campaign %s failed: %s", campaign_id[:8], exc)
            self._report_campaign_status(campaign_id, "failed", error=str(exc))

    def _stop_transmit(self) -> None:
        """Signal the TX loop to stop and wait up to 5 s for it to exit."""
        self._tx_stop.set()
        thread = self._tx_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout=5.0)
            if thread.is_alive():
                logger.warning("TX thread did not exit within 5 s")
        logger.info("TX stopped")

    # ------------------------------------------------------------------
    # Inference — model loading
    # ------------------------------------------------------------------

    def _load_model(self, cmd: dict) -> None:
        """Load an ONNX model into the fingerprint or detector slot.

        Command fields:
            model_path: Local filesystem path or ``http(s)://`` URL. URLs are
                downloaded to a temporary file, which is removed once the
                session has been created.
            label_map: ``{label: class_index}``. Indices are coerced to
                ``int`` so that maps whose indices arrive as strings still
                resolve; entries with non-numeric indices are skipped.
            stage: ``"detector"`` selects the detector slot; anything else
                (default ``"fingerprint"``) selects the fingerprint slot.
            detector_threshold: Detector gate confidence (default 0.7). An
                explicit ``0`` is honoured.

        A running inference loop keeps the session it snapshotted at start;
        the new model takes effect on the next start.
        """
        if not self._ort_available:
            logger.error("load_model: onnxruntime is not installed — cannot load model")
            return

        source: str = cmd.get("model_path") or ""
        if not source:
            logger.error("load_model: no model_path given")
            return
        label_map: dict = cmd.get("label_map") or {}
        stage: str = cmd.get("stage", "fingerprint")
        raw_threshold = cmd.get("detector_threshold")
        try:
            # Only a missing value falls back to 0.7 — ``x or 0.7`` would also
            # discard a deliberate threshold of 0.
            detector_threshold = 0.7 if raw_threshold in (None, "") else float(raw_threshold)
        except (TypeError, ValueError):
            logger.error("load_model: invalid detector_threshold %r", raw_threshold)
            return

        model_path = source
        downloaded: str | None = None
        if source.startswith(("http://", "https://")):
            downloaded = self._download_model(source)
            if downloaded is None:
                return
            model_path = downloaded

        try:
            session = self._ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        except Exception as exc:
            logger.error("Failed to load model %r: %s", source, exc)
            return
        finally:
            # The temp copy is not needed once InferenceSession() has returned
            # (or failed); removing it stops every URL load leaking a file.
            if downloaded is not None:
                try:
                    os.remove(downloaded)
                except OSError as exc:
                    logger.debug("Could not remove temporary model file %s: %s", downloaded, exc)

        index_to_label: dict[int, str] = {}
        for label, index in label_map.items():
            try:
                index_to_label[int(index)] = label
            except (TypeError, ValueError):
                logger.warning("load_model: ignoring label %r with non-integer index %r", label, index)

        with self._inf_lock:
            if stage == "detector":
                self._inf_detector_session = session
                self._inf_detector_index_to_label = index_to_label
                self._inf_detector_threshold = detector_threshold
                logger.info(
                    "Detector model loaded: path=%s classes=%d threshold=%.2f",
                    source,
                    len(index_to_label),
                    detector_threshold,
                )
            else:
                self._inf_session = session
                self._inf_index_to_label = index_to_label
                logger.info(
                    "Fingerprint model loaded: path=%s classes=%d",
                    source,
                    len(index_to_label),
                )

    def _download_model(self, url: str) -> str | None:
        """Download a model from *url* to a temp ``.onnx`` file.

        The body is streamed to disk in 1 MiB pieces rather than held in
        memory. Returns the local path (the caller owns the file), or
        ``None`` after logging on any failure; a partially written file is
        removed in that case.
        """
        import tempfile

        import requests as _requests

        path: str | None = None
        try:
            logger.info("Downloading model from %s", url)
            # NOTE(review): the hub API key is sent to whatever host *url*
            # names. That is fine for hub-served models; confirm the hub never
            # hands out third-party URLs, or the key would leak to them.
            with _requests.get(
                url,
                headers={"X-API-Key": self.api_key},
                verify=not self.insecure,
                timeout=120,
                stream=True,
            ) as resp:
                resp.raise_for_status()
                written = 0
                with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as fh:
                    path = fh.name
                    for piece in resp.iter_content(chunk_size=1024 * 1024):
                        fh.write(piece)
                        written += len(piece)
            logger.info("Model downloaded to %s (%d bytes)", path, written)
            return path
        except Exception as exc:
            logger.error("Model download from %s failed: %s", url, exc)
            if path is not None:
                try:
                    os.remove(path)
                except OSError as cleanup_exc:
                    logger.debug("Could not remove partial download %s: %s", path, cleanup_exc)
            return None

    # ------------------------------------------------------------------
    # Inference — loop lifecycle
    # ------------------------------------------------------------------

    def _start_inference(self, cmd: dict) -> None:
        """Start the SDR capture + ONNX inference loop.

        Optional command fields: ``center_freq`` (default 2.4e9),
        ``sample_rate`` (default 10e6), ``gain`` (default ``"auto"``) and
        ``device`` (default: this agent's ``sdr_device``). Missing or
        ``null`` values fall back to those defaults. SDR updates queued by
        ``configure_inference`` before this start are discarded so they
        cannot override the parameters given here.
        """
        if not self._ort_available:
            logger.error("start_inference: onnxruntime is not installed")
            return

        def _number(key: str, default: float) -> float:
            value = cmd.get(key)
            return default if value is None else float(value)

        try:
            center_freq = _number("center_freq", 2.4e9)
            sample_rate = _number("sample_rate", 10e6)
        except (TypeError, ValueError) as exc:
            logger.error("start_inference: invalid center_freq/sample_rate: %s", exc)
            return
        gain: float | str = cmd.get("gain")
        if gain is None:
            gain = "auto"
        device_type: str = cmd.get("device") or self.sdr_device

        with self._inf_lock:
            if self._inf_session is None:
                logger.error("start_inference: no fingerprint model loaded — call load_model first")
                return
            if self._inf_thread is not None and self._inf_thread.is_alive():
                logger.warning("start_inference: inference loop is already running — ignoring")
                return
            # Drop reconfiguration left over from a previous run; otherwise it
            # would be applied on the first capture and override the values
            # requested by this command.
            self._inf_pending_config.clear()
            self._inf_stop.clear()
            self._inf_thread = threading.Thread(
                target=self._inference_loop,
                args=(device_type, center_freq, sample_rate, gain),
                daemon=True,
                name="ria-agent-inference",
            )
            self._inf_thread.start()
        logger.info(
            "Inference started (device=%s freq=%.3f MHz rate=%.1f MHz)",
            device_type,
            center_freq / 1e6,
            sample_rate / 1e6,
        )

    def _stop_inference(self) -> None:
        """Signal the inference loop to stop and wait up to 5 s for it to exit."""
        self._inf_stop.set()
        thread = self._inf_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout=5.0)
            if thread.is_alive():
                logger.warning("Inference thread did not exit within 5 s")
        logger.info("Inference stopped")

    def _queue_sdr_config(self, cmd: dict) -> None:
        """Merge SDR parameter updates into the pending-config dict.

        The inference loop checks this at each capture boundary and applies
        the updates without restarting.
        """
        cfg = {k: v for k, v in cmd.items() if k != "command" and v is not None}
        with self._inf_lock:
            self._inf_pending_config.update(cfg)
        logger.debug("SDR reconfiguration queued: %s", cfg)

    # ------------------------------------------------------------------
    # Inference — main loop
    # ------------------------------------------------------------------

    def _inference_loop(
        self,
        device_type: str,
        center_freq: float,
        sample_rate: float,
        gain: float | str,
    ) -> None:
        """Continuous SDR capture → ONNX inference → POST events to hub.

        Mirrors the two-stage pipeline in the hub's ``_inference_loop``:
        an optional protocol-detector gates the fingerprint model so the
        fingerprint model only runs when an active transmission is detected.
        """
        try:
            from ria_toolkit_oss.sdr import get_sdr_device
        except ImportError as exc:
            logger.error("inference_loop: ria_toolkit_oss not installed: %s", exc)
            return

        try:
            sdr = get_sdr_device(device_type)
            _apply_sdr_config(sdr, {"center_freq": center_freq, "sample_rate": sample_rate, "gain": gain})
        except Exception as exc:
            logger.error("SDR initialisation failed: %s", exc)
            return

        try:
            import numpy as np

            try:
                from ria_toolkit_oss.orchestration.qa import estimate_snr_db
            except ImportError:
                estimate_snr_db = None

            # Snapshot model state once at loop start. If the hub sends a
            # new load_model command while the loop is running, the new session
            # will be picked up on the next loop restart (stop + start).
            with self._inf_lock:
                session = self._inf_session
                index_to_label = dict(self._inf_index_to_label)
                det_session = self._inf_detector_session
                det_threshold = self._inf_detector_threshold

            input_name = session.get_inputs()[0].name
            det_input_name = det_session.get_inputs()[0].name if det_session else None

            while not self._inf_stop.is_set():
                # Apply any queued SDR configuration changes.
                with self._inf_lock:
                    pending = self._inf_pending_config.copy()
                    self._inf_pending_config.clear()
                if pending:
                    _apply_sdr_config(sdr, pending)

                try:
                    samples = sdr.rx(_CAPTURE_SAMPLES)
                except Exception as exc:
                    logger.warning("SDR capture error: %s", exc)
                    # Avoid a tight spin when the SDR is in a persistent error
                    # state (e.g. physically disconnected).
                    self._inf_stop.wait(timeout=0.5)
                    continue

                samples = np.array(samples, dtype=np.complex64)
                snr_db = float(estimate_snr_db(samples)) if estimate_snr_db is not None else 0.0
                iq = np.stack([samples.real, samples.imag], axis=0).astype(np.float32)

                # Stage 1: protocol detector gate (optional).
                if det_session is not None:
                    det_out = _run_onnx_session(det_session, det_input_name, iq)
                    det_probs = _softmax(det_out[0][0])
                    det_confidence = float(det_probs.max())
                    if det_confidence < det_threshold:
                        # No active protocol detected — report idle and skip
                        # the fingerprint model for this window.
                        self._post_event(device_id=None, confidence=det_confidence, snr_db=snr_db)
                        continue

                # Stage 2: fingerprint model.
                out = _run_onnx_session(session, input_name, iq)
                probs = _softmax(out[0][0])
                pred_idx = int(probs.argmax())
                confidence = float(probs[pred_idx])
                device_id = index_to_label.get(pred_idx)
                idle = (device_id in _IDLE_LABELS) if device_id else True
                self._post_event(
                    device_id=None if idle else device_id,
                    confidence=confidence,
                    snr_db=snr_db,
                )
        except Exception as exc:
            logger.exception("Inference loop terminated unexpectedly: %s", exc)
        finally:
            try:
                sdr.close()
            except Exception:
                pass
            logger.info("Inference loop exited")

    def _post_event(self, device_id: str | None, confidence: float, snr_db: float) -> None:
        """POST a single detection event to ``POST /composer/nodes/{id}/events``.

        Failures are logged at DEBUG level and silently swallowed so that a
        transient network blip does not crash the inference loop.
        """
        from datetime import datetime, timezone

        payload = {
            "type": "detection",
            "device_id": device_id,
            "confidence": round(confidence, 6),
            "snr_db": round(snr_db, 2),
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        try:
            resp = self._post(
                f"/composer/nodes/{self.node_id}/events",
                json=payload,
                timeout=5,
            )
            if resp.status_code not in (200, 204):
                logger.debug("Event POST returned HTTP %d", resp.status_code)
        except Exception as exc:
            logger.debug("Event POST failed (will retry next inference cycle): %s", exc)

    # ------------------------------------------------------------------
    # Recording upload (chunked for large files)
    # ------------------------------------------------------------------

    @staticmethod
    def _field(obj: Any, name: str, default: Any = None) -> Any:
        """Read *name* from *obj*, whether it is a dict or an attribute object."""
        if isinstance(obj, dict):
            return obj.get(name, default)
        return getattr(obj, name, default)

    def _upload_recordings(self, campaign_id: str, config: Any, result: Any) -> None:
        """Upload every SigMF file produced by the campaign's steps.

        Files land in ``config.output.repo`` (``owner/name``) under
        ``<folder or campaign name>/<transmitter_id>/<basename>``. Steps may
        be dicts or attribute objects. A failed file is logged and the rest
        are still attempted.
        """
        output_repo: str | None = getattr(getattr(config, "output", None), "repo", None)
        if not output_repo or "/" not in output_repo:
            logger.warning("Campaign %s: no output.repo — skipping upload", campaign_id[:8])
            return

        repo_owner, repo_name = output_repo.split("/", 1)
        base_url = f"{self.hub_url}/datasets/upload"
        steps = self._field(result, "steps") or []

        output_obj = getattr(config, "output", None)
        folder = getattr(output_obj, "folder", None)
        # An explicit folder (even "") wins over the campaign name.
        campaign_name: str = folder if folder is not None else (getattr(config, "name", None) or "")

        for step in steps:
            output_path: str | None = self._field(step, "output_path")
            if not output_path:
                continue
            device_id: str = self._field(step, "transmitter_id", "") or ""
            for fpath in _sigmf_files(output_path):
                basename = os.path.basename(fpath)
                path_parts = [p for p in (campaign_name, device_id) if p]
                filename = "/".join(path_parts + [basename])
                metadata = {
                    "filename": filename,
                    "repo_owner": repo_owner,
                    "repo_name": repo_name,
                    "device_id": device_id,
                    "campaign_id": campaign_id,
                }
                try:
                    resp_data = self._upload_file(base_url, fpath, metadata)
                    logger.info(
                        "Campaign %s: uploaded %s (oid=%s)",
                        campaign_id[:8],
                        filename,
                        resp_data.get("oid", "?"),
                    )
                except Exception as exc:
                    logger.warning("Campaign %s: upload of %s failed: %s", campaign_id[:8], filename, exc)

    def _report_campaign_status(
        self,
        campaign_id: str,
        status: str,
        result: "dict | None" = None,
        error: "str | None" = None,
    ) -> None:
        """POST campaign completion/failure back to the hub so GET /status/{id} resolves.

        Best-effort: failures are logged as warnings and not raised.
        """
        payload: dict = {"campaign_id": campaign_id, "status": status}
        if result is not None:
            payload["result"] = result
        if error is not None:
            payload["error"] = error
        try:
            resp = self._post(
                f"/composer/nodes/{self.node_id}/campaign-status",
                json=payload,
                timeout=15,
            )
            resp.raise_for_status()
            logger.info("Campaign %s: reported status=%s to hub", campaign_id[:8], status)
        except Exception as exc:
            logger.warning("Campaign %s: failed to report status to hub: %s", campaign_id[:8], exc)

    def _upload_file(self, base_url: str, file_path: str, metadata: dict) -> dict:
        """Upload *file_path*, choosing chunked or direct path based on file size.

        Files up to ``_DIRECT_THRESHOLD`` bytes go in one multipart POST to
        *base_url*. Larger files go to ``<base_url>/chunk`` in
        ``_CHUNK_SIZE`` pieces sharing one ``upload_id``; the reply to the
        last chunk is returned.

        Raises:
            requests.HTTPError: the direct upload was rejected.
            RuntimeError: a chunk was rejected.
        """
        import requests as _requests

        size = os.path.getsize(file_path)
        filename = os.path.basename(file_path)
        headers = {"X-API-Key": self.api_key}
        verify = not self.insecure

        if size <= _DIRECT_THRESHOLD:
            with open(file_path, "rb") as fh:
                resp = _requests.post(
                    base_url,
                    headers=headers,
                    files={"file": (filename, fh)},
                    data=metadata,
                    timeout=300,
                    verify=verify,
                )
            resp.raise_for_status()
            return resp.json()

        total_chunks = math.ceil(size / _CHUNK_SIZE)
        upload_id = str(uuid.uuid4())
        chunk_url = base_url + "/chunk"
        logger.info(
            "Chunked upload: %s (%d bytes, %d × %d MB chunks)",
            filename,
            size,
            total_chunks,
            _CHUNK_SIZE // (1024 * 1024),
        )
        resp_data: dict = {}
        with open(file_path, "rb") as fh:
            for i in range(total_chunks):
                chunk = fh.read(_CHUNK_SIZE)
                resp = _requests.post(
                    chunk_url,
                    headers=headers,
                    files={"file": (filename, chunk, "application/octet-stream")},
                    data={**metadata, "upload_id": upload_id, "chunk_index": i, "total_chunks": total_chunks},
                    timeout=(30, None),  # 30s connect, no read timeout — server may take minutes on final chunk
                    verify=verify,
                )
                if not resp.ok:
                    raise RuntimeError(
                        f"Chunk {i + 1}/{total_chunks} failed: HTTP {resp.status_code}: {resp.text[:300]}"
                    )
                resp_data = resp.json()
                logger.debug("Chunk %d/%d uploaded", i + 1, total_chunks)
        return resp_data

    # ------------------------------------------------------------------
    # HTTP helpers
    # ------------------------------------------------------------------

    def _request(self, method: str, path: str, **kwargs: Any):
        """Send an authenticated *method* request to ``hub_url + path``.

        Every hub call carries the ``X-API-Key`` header and honours
        ``insecure``; *kwargs* are passed through to ``requests.request``.
        """
        import requests as _requests

        return _requests.request(
            method,
            f"{self.hub_url}{path}",
            headers={"X-API-Key": self.api_key},
            verify=not self.insecure,
            **kwargs,
        )

    def _get(self, path: str, **kwargs: Any):
        """Authenticated GET to the hub."""
        return self._request("GET", path, **kwargs)

    def _post(self, path: str, **kwargs: Any):
        """Authenticated POST to the hub."""
        return self._request("POST", path, **kwargs)

    def _delete(self, path: str, **kwargs: Any):
        """Authenticated DELETE to the hub."""
        return self._request("DELETE", path, **kwargs)
# ---------------------------------------------------------------------------
# Module-level helpers used by NodeAgent (inference loop and recording upload)
# ---------------------------------------------------------------------------
2026-04-10 16:39:11 -04:00
def _run_onnx_session(session: Any, input_name: str, iq: Any) -> list:
"""Run an ONNX session on an IQ array (2, N).
Tries channel-first layout (1, 2, N) first; falls back to interleaved flat
(1, 2*N) when the model expects a flattened input.
"""
import numpy as np
x = iq[np.newaxis] # (1, 2, N)
try:
return session.run(None, {input_name: x})
except Exception:
return session.run(None, {input_name: iq.flatten()[np.newaxis]})
def _softmax(x: Any) -> Any:
import numpy as np
e = np.exp(x - x.max())
return e / e.sum()
def _apply_sdr_config(sdr: Any, cfg: dict) -> None:
for attr in ("center_freq", "sample_rate", "gain"):
if attr in cfg:
try:
setattr(sdr, attr, cfg[attr])
except Exception as exc:
logger.warning("SDR config %s=%r failed: %s", attr, cfg[attr], exc)
2026-03-31 13:51:10 -04:00
def _sigmf_files(data_path: str) -> list[str]:
"""Return paths to both SigMF files (.sigmf-data and .sigmf-meta) for a recording."""
candidates = [data_path]
if data_path.endswith(".sigmf-data"):
candidates.append(data_path[: -len(".sigmf-data")] + ".sigmf-meta")
return [p for p in candidates if os.path.exists(p)]
# ---------------------------------------------------------------------------
# Config file helpers
# ---------------------------------------------------------------------------

# Per the XDG Base Directory spec, an unset *or empty* XDG_CONFIG_HOME means
# "use ~/.config". ``or`` covers both cases; ``.get(name, default)`` would
# turn an empty value into a path relative to the working directory.
_DEFAULT_CONFIG_PATH = os.path.join(
    os.environ.get("XDG_CONFIG_HOME") or os.path.expanduser("~/.config"),
    "ria-agent",
    "config.json",
)
def _load_config(path: str) -> dict:
"""Load a JSON config file, returning an empty dict if it does not exist."""
try:
with open(path) as fh:
return json.load(fh)
except FileNotFoundError:
return {}
except Exception as exc:
logger.warning("Could not read config file %s: %s", path, exc)
return {}
2026-03-31 13:51:10 -04:00
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main() -> None:
    """Command-line entry point for ``ria-agent``.

    Each setting is resolved as CLI argument → config file → built-in
    default. The merged result is validated, logging is configured, and a
    :class:`NodeAgent` runs until it is signalled to stop. Invalid values,
    whether from the CLI or the config file, are reported through
    ``parser.error`` (exit status 2).
    """
    import argparse

    log_levels = ["DEBUG", "INFO", "WARNING", "ERROR"]
    roles = ["general", "rx", "tx"]

    parser = argparse.ArgumentParser(
        prog="ria-agent",
        description=(
            "RT-OSS Node Agent — connects outbound to RIA Hub and executes "
            "campaigns / inference on local SDR hardware."
        ),
    )
    parser.add_argument(
        "--config",
        default=None,
        metavar="PATH",
        help=(
            f"Path to a JSON config file (default: {_DEFAULT_CONFIG_PATH}). "
            "CLI arguments override config file values."
        ),
    )
    parser.add_argument(
        "--hub",
        default=None,
        metavar="URL",
        help="RIA Hub base URL, e.g. https://riahub.company.com",
    )
    parser.add_argument(
        "--key",
        default=None,
        metavar="API_KEY",
        help="Shared API key (must match [wac] API_KEY in the hub's app.ini)",
    )
    parser.add_argument(
        "--name",
        default=None,
        metavar="NAME",
        help='Human-readable name shown in the Target Node dropdown, e.g. "lab-bench-1"',
    )
    parser.add_argument(
        "--device",
        default=None,
        metavar="SDR",
        help=(
            "SDR device type reported to the hub and used for inference. "
            "Examples: plutosdr, usrp_b210, rtlsdr, mock. Default: unknown"
        ),
    )
    parser.add_argument(
        "--insecure",
        action="store_true",
        # None (not False) lets us tell "flag absent" apart, so the config
        # file can still supply the value.
        default=None,
        help="Disable TLS certificate verification (dev/self-signed certs only)",
    )
    parser.add_argument(
        "--log-level",
        default=None,
        choices=log_levels,
        help="Logging verbosity (default: INFO)",
    )
    parser.add_argument(
        "--role",
        default=None,
        choices=roles,
        help=("Node role reported to the hub. " "'tx' enables synthetic transmission commands. " "Default: general"),
    )
    parser.add_argument(
        "--session-code",
        default=None,
        metavar="CODE",
        help=(
            "3-word session code to pair this TX agent with a waiting campaign, "
            "e.g. 'amber-peak-transmit'. Supplied by the campaign UI."
        ),
    )
    args = parser.parse_args()

    def _config_flag(value: Any, key: str) -> bool:
        """Interpret a config-file boolean; a string like ``"false"`` must not be truthy."""
        if isinstance(value, str):
            lowered = value.strip().lower()
            if lowered in ("1", "true", "yes", "on"):
                return True
            if lowered in ("", "0", "false", "no", "off"):
                return False
            parser.error(f"invalid boolean {value!r} for {key!r} in the config file")
        return bool(value)

    # Merge: config file → CLI args (CLI wins).
    config_path = args.config or _DEFAULT_CONFIG_PATH
    cfg = _load_config(config_path)
    hub = args.hub or cfg.get("hub")
    key = args.key or cfg.get("key")
    name = args.name or cfg.get("name")
    device = args.device or cfg.get("device") or "unknown"
    if args.insecure is not None:
        insecure = args.insecure
    else:
        insecure = _config_flag(cfg.get("insecure", False), "insecure")
    # argparse already restricts the CLI values; normalise and check config ones.
    log_level = str(args.log_level or cfg.get("log_level") or "INFO").upper()
    role = args.role or cfg.get("role") or "general"
    session_code = args.session_code or cfg.get("session_code")

    if not hub:
        parser.error("--hub is required (or set 'hub' in the config file)")
    if not key:
        parser.error("--key is required (or set 'key' in the config file)")
    if not name:
        parser.error("--name is required (or set 'name' in the config file)")
    if log_level not in log_levels:
        parser.error(f"invalid log_level {log_level!r} in the config file (choose from {', '.join(log_levels)})")
    if role not in roles:
        parser.error(f"invalid role {role!r} in the config file (choose from {', '.join(roles)})")

    logging.basicConfig(
        level=getattr(logging, log_level),
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        stream=sys.stderr,
    )

    if insecure:
        logger.warning(
            "--insecure disables TLS certificate verification. "
            "Only use this for local development with self-signed certs."
        )

    agent = NodeAgent(
        hub_url=hub,
        api_key=key,
        name=name,
        sdr_device=device,
        insecure=insecure,
        role=role,
        session_code=session_code,
    )
    agent.run()


if __name__ == "__main__":
    main()