ria-toolkit-oss/src/ria_toolkit_oss/agent/legacy_executor.py
ben c9b19949ad
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 19m57s
Build Project / Build Project (3.10) (pull_request) Successful in 19m59s
Test with tox / Test with tox (3.10) (pull_request) Successful in 19m46s
Build Project / Build Project (3.11) (pull_request) Successful in 20m19s
Build Project / Build Project (3.12) (pull_request) Successful in 20m21s
Test with tox / Test with tox (3.11) (pull_request) Successful in 18m48s
Test with tox / Test with tox (3.12) (pull_request) Successful in 1m25s
timeout chunk improvements
2026-04-21 17:11:16 -04:00

1001 lines
38 KiB
Python
Raw RIA Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""RT-OSS Node Agent — connects to RIA Hub and dispatches work to local hardware.

The agent runs on any machine with an SDR attached and connects **outbound** to
RIA Hub. No inbound ports need to be opened on the user's machine, and the
connection works identically through NAT, corporate firewalls, or a Pi on a
cellular link.

Usage::

    ria-agent \\
        --hub https://riahub.company.com \\
        --key <api-key> \\
        --name lab-bench-1 \\
        [--device plutosdr] \\
        [--role general|rx|tx] \\
        [--session-code <code>] \\
        [--insecure]

    # Or store credentials in a config file and omit them from the command line:
    ria-agent --config ~/.config/ria-agent/config.json --name lab-bench-1

The agent:

1. Registers with RIA Hub and receives a ``node_id``. If the hub later
   answers 404 for that id, the agent registers again.
2. Sends a heartbeat every 30 s so the hub knows it is online.
3. Long-polls ``GET /composer/nodes/{id}/commands`` (30 s server-side timeout).
4. Dispatches received commands:

   - ``run_campaign``: executes via CampaignExecutor, uploads recordings and
     reports the campaign status back to the hub.
   - ``load_model``: loads an ONNX fingerprint or detector model (local path
     or ``http(s)://`` URL).
   - ``start_inference``: opens the SDR, runs the inference loop, posts
     detection events to the hub for SSE fan-out to browsers.
   - ``stop_inference``: gracefully stops the inference loop.
   - ``configure_inference``: queues an SDR parameter update (applied at the
     next capture boundary without restarting the loop).
   - ``start_transmit`` / ``stop_transmit``: start or stop a synthetic TX
     campaign via TxExecutor.
   - ``configure_transmit``: currently only logged; no parameters are applied.

5. On SIGINT / SIGTERM, stops the inference loop and deregisters.

Config file (JSON, optional; read from ``--config`` or by default from
``$XDG_CONFIG_HOME/ria-agent/config.json``, i.e. usually
``~/.config/ria-agent/config.json``)::

    {
        "hub": "https://riahub.company.com",
        "key": "secret",
        "name": "lab-bench-1",
        "device": "plutosdr",
        "insecure": false,
        "log_level": "INFO",
        "role": "general",
        "session_code": "amber-peak-transmit"
    }

CLI arguments always override config file values.
"""
from __future__ import annotations
import json
import logging
import math
import os
import signal
import sys
import threading
import time
import uuid
from typing import Any
# Every log record from the agent uses the logger name "ria_agent".
logger = logging.getLogger("ria_agent")
# ---------------------------------------------------------------------------
# Tuneable constants
# ---------------------------------------------------------------------------
# Seconds between heartbeats. The heartbeat thread waits this long before its
# first post as well.
_HEARTBEAT_INTERVAL = 30  # seconds between heartbeats
# NOTE(review): not referenced anywhere in this module; it only documents the
# hub's long-poll window that _POLL_CLIENT_TIMEOUT must exceed.
_POLL_TIMEOUT = 30  # server-side long-poll duration
# Client read timeout for the command poll. It is slightly longer than the
# server's window so the hub's "no command" 204 normally arrives before the
# client gives up.
_POLL_CLIENT_TIMEOUT = 40  # client read timeout — slightly longer than server
_RECONNECT_PAUSE = 5  # seconds to wait after a poll error before retrying
# Bytes per chunk in the chunked-upload path (10 MiB).
_CHUNK_SIZE = 10 * 1024 * 1024  # 10 MB per chunk — fast enough for git-LFS to process within timeout
# Files strictly larger than this many bytes use the chunked upload; smaller or
# equal files are sent in one request.
_DIRECT_THRESHOLD = 90 * 1024 * 1024  # files above this use chunked upload
# Number of IQ samples requested from the SDR per inference window.
_CAPTURE_SAMPLES = 4096  # IQ samples per inference window
# Fingerprint-model labels treated as "no device": detection events for these
# are posted with device_id=None.
_IDLE_LABELS = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"})
# ---------------------------------------------------------------------------
# Agent
# ---------------------------------------------------------------------------
class NodeAgent:
    """Outbound-connecting agent that bridges RIA Hub to local SDR hardware.

    All network I/O is initiated by the agent (outbound). RIA Hub never opens
    a connection back to the agent's machine.

    Threading model:

    * :meth:`run` executes the command long-poll loop on the calling thread.
    * A daemon ``ria-agent-heartbeat`` thread posts heartbeats.
    * Campaigns, model loads and inference/TX start-up each run on their own
      daemon thread spawned by :meth:`_dispatch`. The inference loop and the
      TX campaign then run on further dedicated daemon threads.
    """

    def __init__(
        self,
        hub_url: str,
        api_key: str,
        name: str,
        sdr_device: str = "unknown",
        insecure: bool = False,
        role: str = "general",
        session_code: str | None = None,
    ) -> None:
        """Store hub connection settings and probe optional dependencies.

        Args:
            hub_url: Base URL of RIA Hub. A trailing ``/`` is stripped.
            api_key: Sent as the ``X-API-Key`` header on every hub request.
            name: Human-readable node name sent at registration.
            sdr_device: SDR device type reported to the hub. It is also the
                default device for inference and TX.
            insecure: When true, TLS certificate verification is disabled.
            role: Node role reported to the hub. ``"tx"`` adds the
                ``transmit`` capability.
            session_code: Optional code sent at registration to pair this
                node with a waiting campaign.
        """
        self.hub_url = hub_url.rstrip("/")
        self.api_key = api_key
        self.name = name
        self.sdr_device = sdr_device
        self.insecure = insecure
        self.role = role
        self.session_code = session_code
        self.node_id: str | None = None
        self._stop = threading.Event()
        # Serialises re-registration. The heartbeat and command threads may both
        # see a 404 for the same node_id, and only one of them should register a
        # replacement node.
        self._register_lock = threading.Lock()
        # ── TX state ────────────────────────────────────────────────────────
        self._tx_stop = threading.Event()
        self._tx_thread: threading.Thread | None = None
        # Guards the "is TX running? → start it" sequence in _start_transmit.
        self._tx_lock = threading.Lock()
        # ── Inference state ─────────────────────────────────────────────────
        # Protected by _inf_lock for cross-thread model swaps and for the
        # check-and-start sequence in _start_inference.
        self._inf_lock = threading.Lock()
        self._inf_session: Any = None  # primary fingerprint ONNX session
        self._inf_index_to_label: dict[int, str] = {}
        self._inf_detector_session: Any = None  # optional protocol-detector session
        self._inf_detector_index_to_label: dict[int, str] = {}
        self._inf_detector_threshold: float = 0.7
        self._inf_pending_config: dict = {}  # queued SDR attribute updates
        self._inf_stop = threading.Event()
        self._inf_thread: threading.Thread | None = None
        # Detect optional dependencies once at startup so capability
        # advertising is accurate from the first registration.
        try:
            import onnxruntime as _ort_mod

            self._ort: Any = _ort_mod
            self._ort_available = True
        except ImportError:
            self._ort = None
            self._ort_available = False
        try:
            import ria_toolkit_oss

            self._ria_version: str = getattr(ria_toolkit_oss, "__version__", "unknown")
        except Exception:
            self._ria_version = "unknown"

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------
    def run(self) -> None:
        """Register, start the heartbeat thread, and enter the command loop.

        Blocks until SIGINT or SIGTERM is received. Installs signal handlers,
        so it must be called from the main thread. On exit, inference and any
        TX campaign are stopped and the node is deregistered.

        Raises:
            requests.HTTPError: if the initial registration is rejected.
        """
        self._register()

        def _shutdown(sig: int, _frame: Any) -> None:
            logger.info("Shutdown signal received — stopping agent")
            self._stop.set()

        signal.signal(signal.SIGINT, _shutdown)
        signal.signal(signal.SIGTERM, _shutdown)
        hb = threading.Thread(target=self._heartbeat_loop, daemon=True, name="ria-agent-heartbeat")
        hb.start()
        logger.info("Agent %r online (node_id=%s, hub=%s)", self.name, self.node_id, self.hub_url)
        try:
            self._command_loop()
        finally:
            self._stop.set()
            self._stop_inference()
            # The TX thread is a daemon and would be killed abruptly at exit.
            # Setting its stop event first gives TxExecutor a chance to finish
            # the current step cleanly.
            self._stop_transmit()
            self._deregister()

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------
    def _register(self) -> None:
        """Register this node with the hub and store the returned ``node_id``.

        Advertised capabilities:

        * ``campaign`` always.
        * ``inference`` when onnxruntime was importable.
        * ``transmit`` when ``role == "tx"``.

        Raises:
            requests.HTTPError: if the hub rejects the registration.
            KeyError: if the response JSON has no ``node_id``.
        """
        capabilities = ["campaign"]
        if self._ort_available:
            capabilities.append("inference")
        if self.role == "tx":
            capabilities.append("transmit")
        payload: dict = {
            "name": self.name,
            "sdr_device": self.sdr_device,
            "ria_toolkit_version": self._ria_version,
            "capabilities": capabilities,
            "role": self.role,
        }
        if self.session_code:
            payload["session_code"] = self.session_code
        resp = self._post("/composer/nodes/register", json=payload, timeout=15)
        resp.raise_for_status()
        self.node_id = resp.json()["node_id"]
        logger.info(
            "Registered as %r (node_id=%s, role=%s%s)",
            self.name,
            self.node_id,
            self.role,
            f", session_code={self.session_code!r}" if self.session_code else "",
        )

    def _reregister(self, stale_node_id: str | None) -> None:
        """Register again after the hub answered 404 for *stale_node_id*.

        If another thread has already replaced the stale id, ``node_id`` no
        longer equals *stale_node_id* and this call does nothing. That prevents
        two fresh registrations for a single hub restart.
        """
        with self._register_lock:
            if self.node_id == stale_node_id:
                self._register()

    def _deregister(self) -> None:
        """Best-effort ``DELETE`` of this node on shutdown; failures are logged at DEBUG."""
        if not self.node_id:
            return
        try:
            self._delete(f"/composer/nodes/{self.node_id}", timeout=10)
            logger.info("Deregistered %s", self.node_id)
        except Exception as exc:
            logger.debug("Deregister failed (ignored on shutdown): %s", exc)

    # ------------------------------------------------------------------
    # Heartbeat thread
    # ------------------------------------------------------------------
    def _heartbeat_loop(self) -> None:
        """Post a heartbeat every ``_HEARTBEAT_INTERVAL`` seconds until stopped.

        A 404 means the hub has forgotten this node, so the agent registers
        again. Any other error is logged and the next tick retries.
        """
        while not self._stop.wait(_HEARTBEAT_INTERVAL):
            # Capture the id used in the URL so a 404 re-registers only if it
            # is still current.
            node_id = self.node_id
            try:
                resp = self._post(f"/composer/nodes/{node_id}/heartbeat", timeout=10)
                if resp.status_code == 404:
                    logger.warning("Heartbeat got 404 — hub lost registration, re-registering")
                    self._reregister(node_id)
            except Exception as exc:
                logger.warning("Heartbeat failed: %s", exc)

    # ------------------------------------------------------------------
    # Command poll loop
    # ------------------------------------------------------------------
    def _command_loop(self) -> None:
        """Long-poll the hub for commands and dispatch each one until stopped.

        Response handling:

        * 204: no command arrived in the server's window; poll again at once.
        * 404: the node is unknown to the hub; register again.
        * Any other error: wait ``_RECONNECT_PAUSE`` seconds, then poll again.
        """
        while not self._stop.is_set():
            node_id = self.node_id
            try:
                resp = self._get(
                    f"/composer/nodes/{node_id}/commands",
                    timeout=_POLL_CLIENT_TIMEOUT,
                )
                if resp.status_code == 204:
                    # No command within the timeout window — loop immediately.
                    continue
                if resp.status_code == 404:
                    logger.warning("Command poll got 404 — re-registering")
                    self._reregister(node_id)
                    continue
                resp.raise_for_status()
                cmd = resp.json()
                logger.info("Received command: %s", cmd.get("command"))
                self._dispatch(cmd)
            except Exception as exc:
                if not self._stop.is_set():
                    logger.warning("Command poll error: %s — retrying in %ds", exc, _RECONNECT_PAUSE)
                    # Event.wait rather than time.sleep, so a shutdown signal
                    # ends the pause immediately.
                    self._stop.wait(_RECONNECT_PAUSE)

    # ------------------------------------------------------------------
    # Command dispatch
    # ------------------------------------------------------------------
    def _dispatch(self, cmd: dict) -> None:
        """Route one hub command to its handler.

        Long-running handlers run on their own daemon threads so the poll loop
        stays responsive. Stop and configure commands run inline on the poll
        thread; the stop commands may block for up to 5 s each.
        """
        command = cmd.get("command")
        if command == "run_campaign":
            # str() so a numeric id from the hub still supports the [:8]
            # slicing used in thread names and log messages.
            campaign_id = str(cmd.get("campaign_id") or uuid.uuid4())
            config_dict: dict = cmd.get("payload") or {}
            skip_local_tx = bool(cmd.get("skip_local_tx", False))
            self._spawn(
                self._run_campaign,
                (campaign_id, config_dict, skip_local_tx),
                f"campaign-{campaign_id[:8]}",
            )
        elif command == "load_model":
            self._spawn(self._load_model, (cmd,), "ria-load-model")
        elif command == "start_inference":
            self._spawn(self._start_inference, (cmd,), "ria-start-inf")
        elif command == "stop_inference":
            self._stop_inference()
        elif command == "configure_inference":
            self._queue_sdr_config(cmd)
        elif command == "start_transmit":
            self._spawn(self._start_transmit, (cmd,), "ria-start-tx")
        elif command == "stop_transmit":
            self._stop_transmit()
        elif command == "configure_transmit":
            # NOTE(review): nothing is stored or forwarded to the TX executor
            # here, so the log message describes intended, not actual, behavior.
            logger.info("configure_transmit received — will apply on next step boundary")
        else:
            logger.warning("Unknown command %r — ignored", command)

    @staticmethod
    def _spawn(target: Any, args: tuple, name: str) -> None:
        """Run ``target(*args)`` on a new daemon thread called *name*."""
        threading.Thread(target=target, args=args, daemon=True, name=name).start()

    # ------------------------------------------------------------------
    # Campaign execution
    # ------------------------------------------------------------------
    def _run_campaign(self, campaign_id: str, config_dict: dict, skip_local_tx: bool = False) -> None:
        """Execute a capture campaign, upload its recordings, and report the outcome.

        Reports ``completed`` (with ``result.to_dict()`` when available) or
        ``failed`` (with the error text) to the hub. If the orchestration
        package cannot be imported, the error is only logged and nothing is
        reported.
        """
        try:
            from ria_toolkit_oss.orchestration.campaign import CampaignConfig
            from ria_toolkit_oss.orchestration.executor import CampaignExecutor
        except ImportError as exc:
            logger.error(
                "Campaign %s cannot start — ria_toolkit_oss not fully installed: %s",
                campaign_id[:8],
                exc,
            )
            return
        logger.info("Campaign %s starting (skip_local_tx=%s)", campaign_id[:8], skip_local_tx)
        try:
            config = CampaignConfig.from_dict(config_dict)
            executor = CampaignExecutor(config, skip_local_tx=skip_local_tx)
            result = executor.run()
            logger.info("Campaign %s completed — uploading recordings", campaign_id[:8])
            self._upload_recordings(campaign_id, config, result)
            result_dict = result.to_dict() if hasattr(result, "to_dict") else None
            self._report_campaign_status(campaign_id, "completed", result=result_dict)
        except Exception as exc:
            # logger.exception keeps the traceback, which a bare str(exc) loses.
            logger.exception("Campaign %s failed: %s", campaign_id[:8], exc)
            self._report_campaign_status(campaign_id, "failed", error=str(exc))

    # ------------------------------------------------------------------
    # TX execution
    # ------------------------------------------------------------------
    def _start_transmit(self, cmd: dict) -> None:
        """Execute a synthetic transmit campaign using TxExecutor.

        The command payload mirrors a TransmitterConfig dict with an optional
        ``schedule`` of steps. Each step synthesises a signal and transmits it
        via the local SDR in TX mode.

        A duplicate command is ignored while a TX thread is alive. If the
        executor cannot be constructed, the failure is reported to the hub as
        campaign status ``failed``.
        """
        try:
            from ria_toolkit_oss.orchestration.tx_executor import TxExecutor
        except ImportError as exc:
            logger.error("start_transmit: TxExecutor not available: %s", exc)
            return
        campaign_id = str(cmd.get("campaign_id") or uuid.uuid4())
        init_error: Exception | None = None
        # Each start_transmit runs on its own thread, so the liveness check and
        # the thread start must be atomic to avoid two concurrent TX loops.
        with self._tx_lock:
            if self._tx_thread and self._tx_thread.is_alive():
                logger.warning("start_transmit: TX already running — ignoring duplicate command")
                return
            self._tx_stop.clear()
            try:
                executor = TxExecutor(
                    config=cmd,
                    sdr_device=self.sdr_device,
                    stop_event=self._tx_stop,
                )
            except Exception as exc:
                init_error = exc
            else:
                self._tx_thread = threading.Thread(
                    target=self._run_tx_campaign,
                    args=(executor, campaign_id),
                    daemon=True,
                    name=f"tx-campaign-{campaign_id[:8]}",
                )
                self._tx_thread.start()
        # Report outside the lock: the status POST can take up to 15 s.
        if init_error is not None:
            logger.error("TX campaign %s failed: %s", campaign_id[:8], init_error)
            self._report_campaign_status(campaign_id, "failed", error=str(init_error))

    def _run_tx_campaign(self, executor: Any, campaign_id: str) -> None:
        """Run *executor* to completion and report ``completed``/``failed`` to the hub."""
        try:
            executor.run()
            logger.info("TX campaign %s completed", campaign_id[:8])
            self._report_campaign_status(campaign_id, "completed")
        except Exception as exc:
            logger.exception("TX campaign %s failed: %s", campaign_id[:8], exc)
            self._report_campaign_status(campaign_id, "failed", error=str(exc))

    def _stop_transmit(self) -> None:
        """Signal the TX loop to stop and wait up to 5 s for its thread to exit."""
        self._tx_stop.set()
        thread = self._tx_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout=5.0)
            if thread.is_alive():
                logger.warning("TX thread did not exit within 5 s")
        logger.info("TX stopped")

    # ------------------------------------------------------------------
    # Inference — model loading
    # ------------------------------------------------------------------
    def _load_model(self, cmd: dict) -> None:
        """Load an ONNX model into the fingerprint or detector slot.

        The ``model_path`` field may be either a local filesystem path or an
        ``http(s)://`` URL. In the latter case the file is first downloaded to
        a temporary file, which is removed once the session has been built.

        Other command fields:

        * ``label_map``: label → class index; inverted to index → label.
        * ``stage``: ``"detector"`` fills the detector slot; any other value
          fills the fingerprint slot.
        * ``detector_threshold``: defaults to 0.7 when absent or empty. An
          explicit ``0`` is honoured.

        All errors are logged; nothing is raised.
        """
        if not self._ort_available:
            logger.error("load_model: onnxruntime is not installed — cannot load model")
            return
        model_path: str = cmd.get("model_path") or ""
        label_map: dict[str, int] = cmd.get("label_map") or {}
        stage: str = cmd.get("stage", "fingerprint")
        raw_threshold = cmd.get("detector_threshold")
        try:
            # `x or 0.7` would silently replace an explicit 0 with 0.7.
            detector_threshold = 0.7 if raw_threshold in (None, "") else float(raw_threshold)
        except (TypeError, ValueError):
            logger.warning("load_model: invalid detector_threshold %r — using 0.7", raw_threshold)
            detector_threshold = 0.7
        if not model_path:
            logger.error("load_model: command has no model_path")
            return
        downloaded_path: str | None = None
        if model_path.startswith(("http://", "https://")):
            downloaded_path = self._download_model(model_path)
            if downloaded_path is None:
                return
        try:
            session = self._ort.InferenceSession(
                downloaded_path or model_path, providers=["CPUExecutionProvider"]
            )
        except Exception as exc:
            logger.error("Failed to load model %r: %s", model_path, exc)
            return
        finally:
            # The session is built (or has failed) by now. Remove the temp copy
            # so repeated URL loads do not fill up the temp directory.
            if downloaded_path is not None:
                self._remove_quietly(downloaded_path)
        index_to_label = {v: k for k, v in label_map.items()}
        with self._inf_lock:
            if stage == "detector":
                self._inf_detector_session = session
                self._inf_detector_index_to_label = index_to_label
                self._inf_detector_threshold = detector_threshold
                logger.info(
                    "Detector model loaded: path=%s classes=%d threshold=%.2f",
                    model_path,
                    len(label_map),
                    detector_threshold,
                )
            else:
                self._inf_session = session
                self._inf_index_to_label = index_to_label
                logger.info(
                    "Fingerprint model loaded: path=%s classes=%d",
                    model_path,
                    len(label_map),
                )

    def _download_model(self, url: str) -> str | None:
        """Download a model from *url* to a temp file and return the local path.

        The hub API key is sent as ``X-API-Key``. The body is streamed to disk
        in 1 MiB pieces rather than held in memory.

        Returns:
            The temp file path, or ``None`` on failure. The error is logged and
            any partially written file is removed.
        """
        import tempfile

        import requests as _requests

        path: str | None = None
        try:
            logger.info("Downloading model from %s", url)
            with _requests.get(
                url,
                headers={"X-API-Key": self.api_key},
                verify=not self.insecure,
                timeout=120,
                stream=True,
            ) as resp:
                resp.raise_for_status()
                size = 0
                with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as fh:
                    path = fh.name
                    for piece in resp.iter_content(chunk_size=1024 * 1024):
                        fh.write(piece)
                        size += len(piece)
            logger.info("Model downloaded to %s (%d bytes)", path, size)
            return path
        except Exception as exc:
            logger.error("Model download from %s failed: %s", url, exc)
            if path is not None:
                self._remove_quietly(path)
            return None

    @staticmethod
    def _remove_quietly(path: str) -> None:
        """Delete *path*, logging (at DEBUG) instead of raising if that fails."""
        try:
            os.remove(path)
        except OSError as exc:
            logger.debug("Could not remove temporary model file %s: %s", path, exc)

    # ------------------------------------------------------------------
    # Inference — loop lifecycle
    # ------------------------------------------------------------------
    def _start_inference(self, cmd: dict) -> None:
        """Start the SDR capture + ONNX inference loop.

        Requires a loaded fingerprint model, and does nothing if a loop is
        already running.

        Command fields and defaults:

        * ``center_freq``: 2.4e9.
        * ``sample_rate``: 10e6.
        * ``gain``: ``"auto"``.
        * ``device``: falls back to the agent's ``sdr_device``.
        """
        if not self._ort_available:
            logger.error("start_inference: onnxruntime is not installed")
            return
        # The check and the thread start happen under one lock acquisition
        # because each start_inference command runs on its own thread. The new
        # thread only blocks on _inf_lock until this block exits;
        # Thread.start() does not wait for the target to run, so this cannot
        # deadlock.
        with self._inf_lock:
            if self._inf_session is None:
                logger.error("start_inference: no fingerprint model loaded — call load_model first")
                return
            if self._inf_thread is not None and self._inf_thread.is_alive():
                logger.warning("start_inference: inference loop is already running — ignoring")
                return
            center_freq = float(cmd.get("center_freq", 2.4e9))
            sample_rate = float(cmd.get("sample_rate", 10e6))
            gain: float | str = cmd.get("gain", "auto")
            device_type: str = cmd.get("device") or self.sdr_device
            self._inf_stop.clear()
            self._inf_thread = threading.Thread(
                target=self._inference_loop,
                args=(device_type, center_freq, sample_rate, gain),
                daemon=True,
                name="ria-agent-inference",
            )
            self._inf_thread.start()
        logger.info(
            "Inference started (device=%s freq=%.3f MHz rate=%.1f MHz)",
            device_type,
            center_freq / 1e6,
            sample_rate / 1e6,
        )

    def _stop_inference(self) -> None:
        """Signal the inference loop to stop and wait up to 5 s for it to exit."""
        self._inf_stop.set()
        thread = self._inf_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout=5.0)
            if thread.is_alive():
                logger.warning("Inference thread did not exit within 5 s")
        logger.info("Inference stopped")

    def _queue_sdr_config(self, cmd: dict) -> None:
        """Merge SDR parameter updates into the pending-config dict.

        Every non-``None`` field except ``command`` is stored. The inference
        loop drains the dict at each capture boundary, or on its first
        iteration if no loop is running yet. It applies only ``center_freq``,
        ``sample_rate`` and ``gain``.
        """
        cfg = {k: v for k, v in cmd.items() if k != "command" and v is not None}
        with self._inf_lock:
            self._inf_pending_config.update(cfg)
        logger.debug("SDR reconfiguration queued: %s", cfg)

    # ------------------------------------------------------------------
    # Inference — main loop
    # ------------------------------------------------------------------
    def _inference_loop(
        self,
        device_type: str,
        center_freq: float,
        sample_rate: float,
        gain: float | str,
    ) -> None:
        """Continuous SDR capture → ONNX inference → POST events to hub.

        Mirrors the two-stage pipeline in the hub's ``_inference_loop``. An
        optional protocol detector gates the fingerprint model: the
        fingerprint model only runs when the detector's top-class confidence
        reaches the configured threshold.

        Runs until ``_inf_stop`` is set. The SDR is always closed on exit, and
        unexpected errors end the loop with a logged traceback.
        """
        try:
            from ria_toolkit_oss.sdr import get_sdr_device
        except ImportError as exc:
            logger.error("inference_loop: ria_toolkit_oss not installed: %s", exc)
            return
        try:
            sdr = get_sdr_device(device_type)
            _apply_sdr_config(sdr, {"center_freq": center_freq, "sample_rate": sample_rate, "gain": gain})
        except Exception as exc:
            logger.error("SDR initialisation failed: %s", exc)
            return
        try:
            import numpy as np

            try:
                from ria_toolkit_oss.orchestration.qa import estimate_snr_db
            except ImportError:
                estimate_snr_db = None
            # Snapshot model state once at loop start. If the hub sends a
            # new load_model command while the loop is running, the new session
            # will be picked up on the next loop restart (stop + start).
            # NOTE(review): _inf_detector_index_to_label is not consulted, so a
            # detector confidently predicting an idle class still passes the gate.
            with self._inf_lock:
                session = self._inf_session
                index_to_label = dict(self._inf_index_to_label)
                det_session = self._inf_detector_session
                det_threshold = self._inf_detector_threshold
            input_name = session.get_inputs()[0].name
            det_input_name = det_session.get_inputs()[0].name if det_session else None
            while not self._inf_stop.is_set():
                # Apply any queued SDR configuration changes.
                with self._inf_lock:
                    pending = self._inf_pending_config.copy()
                    self._inf_pending_config.clear()
                if pending:
                    _apply_sdr_config(sdr, pending)
                try:
                    samples = sdr.rx(_CAPTURE_SAMPLES)
                except Exception as exc:
                    logger.warning("SDR capture error: %s", exc)
                    # Avoid a tight spin when the SDR is in a persistent error
                    # state (e.g. physically disconnected).
                    self._inf_stop.wait(timeout=0.5)
                    continue
                samples = np.array(samples, dtype=np.complex64)
                snr_db = float(estimate_snr_db(samples)) if estimate_snr_db is not None else 0.0
                # (2, N) float32: row 0 = I, row 1 = Q.
                iq = np.stack([samples.real, samples.imag], axis=0).astype(np.float32)
                # Stage 1: protocol detector gate (optional).
                if det_session is not None:
                    det_out = _run_onnx_session(det_session, det_input_name, iq)
                    det_probs = _softmax(det_out[0][0])
                    det_confidence = float(det_probs.max())
                    if det_confidence < det_threshold:
                        # No active protocol detected — report idle and skip
                        # the fingerprint model for this window.
                        self._post_event(device_id=None, confidence=det_confidence, snr_db=snr_db)
                        continue
                # Stage 2: fingerprint model.
                out = _run_onnx_session(session, input_name, iq)
                probs = _softmax(out[0][0])
                pred_idx = int(probs.argmax())
                confidence = float(probs[pred_idx])
                device_id = index_to_label.get(pred_idx)
                # Unknown indices and idle labels are both reported as "no device".
                idle = (device_id in _IDLE_LABELS) if device_id else True
                self._post_event(
                    device_id=None if idle else device_id,
                    confidence=confidence,
                    snr_db=snr_db,
                )
        except Exception as exc:
            logger.exception("Inference loop terminated unexpectedly: %s", exc)
        finally:
            try:
                sdr.close()
            except Exception as exc:
                logger.debug("SDR close failed (ignored): %s", exc)
            logger.info("Inference loop exited")

    def _post_event(self, device_id: str | None, confidence: float, snr_db: float) -> None:
        """POST a single detection event to ``POST /composer/nodes/{id}/events``.

        Failures are logged at DEBUG level and swallowed, so a transient
        network blip does not crash the inference loop. The event itself is
        not retried; the next inference window posts a new one.
        """
        from datetime import datetime, timezone

        payload = {
            "type": "detection",
            "device_id": device_id,
            "confidence": round(confidence, 6),
            "snr_db": round(snr_db, 2),
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        try:
            resp = self._post(
                f"/composer/nodes/{self.node_id}/events",
                json=payload,
                timeout=5,
            )
            if resp.status_code not in (200, 204):
                logger.debug("Event POST returned HTTP %d", resp.status_code)
        except Exception as exc:
            logger.debug("Event POST failed (will retry next inference cycle): %s", exc)

    # ------------------------------------------------------------------
    # Recording upload (chunked for large files)
    # ------------------------------------------------------------------
    @staticmethod
    def _field(obj: Any, name: str, default: Any = None) -> Any:
        """Read *name* from a dict (by key) or from any other object (by attribute)."""
        if isinstance(obj, dict):
            return obj.get(name, default)
        return getattr(obj, name, default)

    def _upload_recordings(self, campaign_id: str, config: Any, result: Any) -> None:
        """Upload every SigMF file produced by the campaign's steps to ``output.repo``.

        *config* and *result* may be objects or plain dicts, and the same
        applies to each entry of ``result.steps``.

        Files are stored as ``[<campaign folder or name>/][<transmitter_id>/]<basename>``.
        Upload failures are logged per file and do not abort the remaining
        uploads.
        """
        output_obj = self._field(config, "output")
        output_repo: str | None = self._field(output_obj, "repo")
        repo_owner, _, repo_name = (output_repo or "").partition("/")
        if not repo_owner or not repo_name:
            logger.warning("Campaign %s: no output.repo — skipping upload", campaign_id[:8])
            return
        base_url = f"{self.hub_url}/datasets/upload"
        steps = self._field(result, "steps") or []
        folder = self._field(output_obj, "folder")
        campaign_name: str = folder if folder is not None else (self._field(config, "name") or "")
        for step in steps:
            output_path: str | None = self._field(step, "output_path")
            if not output_path:
                continue
            device_id: str = self._field(step, "transmitter_id", "") or ""
            for fpath in _sigmf_files(output_path):
                basename = os.path.basename(fpath)
                path_parts = [p for p in (campaign_name, device_id) if p]
                filename = "/".join(path_parts + [basename])
                metadata = {
                    "filename": filename,
                    "repo_owner": repo_owner,
                    "repo_name": repo_name,
                    "device_id": device_id,
                    "campaign_id": campaign_id,
                }
                try:
                    resp_data = self._upload_file(base_url, fpath, metadata)
                    logger.info(
                        "Campaign %s: uploaded %s (oid=%s)",
                        campaign_id[:8],
                        filename,
                        resp_data.get("oid", "?"),
                    )
                except Exception as exc:
                    logger.warning("Campaign %s: upload of %s failed: %s", campaign_id[:8], filename, exc)

    def _report_campaign_status(
        self,
        campaign_id: str,
        status: str,
        result: dict | None = None,
        error: str | None = None,
    ) -> None:
        """POST campaign completion/failure back to the hub so GET /status/{id} resolves.

        Best-effort: failures are logged as warnings and never raised.
        """
        payload: dict = {"campaign_id": campaign_id, "status": status}
        if result is not None:
            payload["result"] = result
        if error is not None:
            payload["error"] = error
        try:
            resp = self._post(
                f"/composer/nodes/{self.node_id}/campaign-status",
                json=payload,
                timeout=15,
            )
            resp.raise_for_status()
            logger.info("Campaign %s: reported status=%s to hub", campaign_id[:8], status)
        except Exception as exc:
            logger.warning("Campaign %s: failed to report status to hub: %s", campaign_id[:8], exc)

    def _upload_file(self, base_url: str, file_path: str, metadata: dict) -> dict:
        """Upload *file_path* to the hub and return the hub's JSON response.

        Files of at most ``_DIRECT_THRESHOLD`` bytes are sent in one multipart
        POST to *base_url*. Larger files are split into ``_CHUNK_SIZE`` pieces
        and POSTed to ``{base_url}/chunk`` with a shared ``upload_id``; the
        response to the last chunk is returned.

        Raises:
            requests.HTTPError: the direct upload was rejected.
            RuntimeError: a chunk was rejected. The message includes the HTTP
                status and the first 300 characters of the response body.
            OSError: the file could not be read.

        Network errors from ``requests`` propagate unchanged.
        """
        import requests as _requests

        size = os.path.getsize(file_path)
        filename = os.path.basename(file_path)
        headers = {"X-API-Key": self.api_key}
        verify = not self.insecure
        if size <= _DIRECT_THRESHOLD:
            with open(file_path, "rb") as fh:
                resp = _requests.post(
                    base_url,
                    headers=headers,
                    files={"file": (filename, fh)},
                    data=metadata,
                    timeout=300,
                    verify=verify,
                )
            resp.raise_for_status()
            return resp.json()
        total_chunks = math.ceil(size / _CHUNK_SIZE)
        upload_id = str(uuid.uuid4())
        chunk_url = base_url + "/chunk"
        logger.info(
            "Chunked upload: %s (%d bytes, %d × %d MB chunks)",
            filename,
            size,
            total_chunks,
            _CHUNK_SIZE // (1024 * 1024),
        )
        resp_data: dict = {}
        with open(file_path, "rb") as fh:
            for i in range(total_chunks):
                chunk = fh.read(_CHUNK_SIZE)
                resp = _requests.post(
                    chunk_url,
                    headers=headers,
                    files={"file": (filename, chunk, "application/octet-stream")},
                    data={**metadata, "upload_id": upload_id, "chunk_index": i, "total_chunks": total_chunks},
                    timeout=(30, None),  # 30s connect, no read timeout — server may take minutes on final chunk
                    verify=verify,
                )
                if not resp.ok:
                    raise RuntimeError(
                        f"Chunk {i + 1}/{total_chunks} failed: HTTP {resp.status_code}: {resp.text[:300]}"
                    )
                resp_data = resp.json()
                logger.debug("Chunk %d/%d uploaded", i + 1, total_chunks)
        return resp_data

    # ------------------------------------------------------------------
    # HTTP helpers
    # ------------------------------------------------------------------
    def _request(self, method: str, path: str, **kwargs: Any):
        """Send an authenticated *method* request to ``{hub_url}{path}``.

        Adds the ``X-API-Key`` header and the TLS-verification setting. Extra
        keyword arguments (``json``, ``timeout``, ...) are passed straight to
        :func:`requests.request`.
        """
        import requests as _requests

        return _requests.request(
            method,
            f"{self.hub_url}{path}",
            headers={"X-API-Key": self.api_key},
            verify=not self.insecure,
            **kwargs,
        )

    def _get(self, path: str, **kwargs: Any):
        """Authenticated GET of ``{hub_url}{path}``."""
        return self._request("GET", path, **kwargs)

    def _post(self, path: str, **kwargs: Any):
        """Authenticated POST to ``{hub_url}{path}``."""
        return self._request("POST", path, **kwargs)

    def _delete(self, path: str, **kwargs: Any):
        """Authenticated DELETE of ``{hub_url}{path}``."""
        return self._request("DELETE", path, **kwargs)
# ---------------------------------------------------------------------------
# Module-level helpers (shared by NodeAgent._inference_loop)
# ---------------------------------------------------------------------------
def _run_onnx_session(session: Any, input_name: str, iq: Any) -> list:
"""Run an ONNX session on an IQ array (2, N).
Tries channel-first layout (1, 2, N) first; falls back to interleaved flat
(1, 2*N) when the model expects a flattened input.
"""
import numpy as np
x = iq[np.newaxis] # (1, 2, N)
try:
return session.run(None, {input_name: x})
except Exception:
return session.run(None, {input_name: iq.flatten()[np.newaxis]})
def _softmax(x: Any) -> Any:
import numpy as np
e = np.exp(x - x.max())
return e / e.sum()
def _apply_sdr_config(sdr: Any, cfg: dict) -> None:
for attr in ("center_freq", "sample_rate", "gain"):
if attr in cfg:
try:
setattr(sdr, attr, cfg[attr])
except Exception as exc:
logger.warning("SDR config %s=%r failed: %s", attr, cfg[attr], exc)
def _sigmf_files(data_path: str) -> list[str]:
"""Return paths to both SigMF files (.sigmf-data and .sigmf-meta) for a recording."""
candidates = [data_path]
if data_path.endswith(".sigmf-data"):
candidates.append(data_path[: -len(".sigmf-data")] + ".sigmf-meta")
return [p for p in candidates if os.path.exists(p)]
# ---------------------------------------------------------------------------
# Config file helpers
# ---------------------------------------------------------------------------
_DEFAULT_CONFIG_PATH = os.path.join(
os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")),
"ria-agent",
"config.json",
)
def _load_config(path: str) -> dict:
"""Load a JSON config file, returning an empty dict if it does not exist."""
try:
with open(path) as fh:
return json.load(fh)
except FileNotFoundError:
return {}
except Exception as exc:
logger.warning("Could not read config file %s: %s", path, exc)
return {}
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main() -> None:
    """Command-line entry point for ``ria-agent``.

    Settings are read from the JSON config file (``--config`` or the default
    path), and any CLI argument that was given overrides them.

    ``log_level`` and ``role`` taken from the config file are checked against
    the same choices the CLI accepts (``log_level`` case-insensitively). A bad
    value is therefore reported via ``parser.error`` rather than crashing
    later. A ``null`` ``device`` or ``role`` in the config falls back to the
    default. Blocks in :meth:`NodeAgent.run` until the agent is stopped.
    """
    import argparse

    log_levels = ["DEBUG", "INFO", "WARNING", "ERROR"]
    roles = ["general", "rx", "tx"]
    parser = argparse.ArgumentParser(
        prog="ria-agent",
        description=(
            "RT-OSS Node Agent — connects outbound to RIA Hub and executes "
            "campaigns / inference on local SDR hardware."
        ),
    )
    parser.add_argument(
        "--config",
        default=None,
        metavar="PATH",
        help=(
            f"Path to a JSON config file (default: {_DEFAULT_CONFIG_PATH}). "
            "CLI arguments override config file values."
        ),
    )
    parser.add_argument(
        "--hub",
        default=None,
        metavar="URL",
        help="RIA Hub base URL, e.g. https://riahub.company.com",
    )
    parser.add_argument(
        "--key",
        default=None,
        metavar="API_KEY",
        help="Shared API key (must match [wac] API_KEY in the hub's app.ini)",
    )
    parser.add_argument(
        "--name",
        default=None,
        metavar="NAME",
        help='Human-readable name shown in the Target Node dropdown, e.g. "lab-bench-1"',
    )
    parser.add_argument(
        "--device",
        default=None,
        metavar="SDR",
        help=(
            "SDR device type reported to the hub and used for inference. "
            "Examples: plutosdr, usrp_b210, rtlsdr, mock. Default: unknown"
        ),
    )
    parser.add_argument(
        "--insecure",
        action="store_true",
        default=None,  # None (not False) lets us tell "absent" from "given"
        help="Disable TLS certificate verification (dev/self-signed certs only)",
    )
    parser.add_argument(
        "--log-level",
        default=None,
        choices=log_levels,
        help="Logging verbosity (default: INFO)",
    )
    parser.add_argument(
        "--role",
        default=None,
        choices=roles,
        help=("Node role reported to the hub. " "'tx' enables synthetic transmission commands. " "Default: general"),
    )
    parser.add_argument(
        "--session-code",
        default=None,
        metavar="CODE",
        help=(
            "3-word session code to pair this TX agent with a waiting campaign, "
            "e.g. 'amber-peak-transmit'. Supplied by the campaign UI."
        ),
    )
    args = parser.parse_args()

    # Merge: config file → CLI args (CLI wins).
    config_path = args.config or _DEFAULT_CONFIG_PATH
    cfg = _load_config(config_path)
    hub = args.hub or cfg.get("hub")
    key = args.key or cfg.get("key")
    name = args.name or cfg.get("name")
    # `or` chains (instead of .get defaults) so an explicit null in the config
    # also falls back to the default.
    device = args.device or cfg.get("device") or "unknown"
    insecure = args.insecure if args.insecure is not None else cfg.get("insecure", False)
    log_level = str(args.log_level or cfg.get("log_level") or "INFO").upper()
    role = args.role or cfg.get("role") or "general"
    session_code = args.session_code or cfg.get("session_code")

    if not hub:
        parser.error("--hub is required (or set 'hub' in the config file)")
    if not key:
        parser.error("--key is required (or set 'key' in the config file)")
    if not name:
        parser.error("--name is required (or set 'name' in the config file)")
    # argparse only validates CLI values; config-file values are checked here.
    if log_level not in log_levels:
        parser.error(f"invalid log_level {log_level!r} in config file (choose from {', '.join(log_levels)})")
    if role not in roles:
        parser.error(f"invalid role {role!r} in config file (choose from {', '.join(roles)})")

    logging.basicConfig(
        level=getattr(logging, log_level),
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        stream=sys.stderr,
    )
    if insecure:
        logger.warning(
            "--insecure disables TLS certificate verification. "
            "Only use this for local development with self-signed certs."
        )
    agent = NodeAgent(
        hub_url=hub,
        api_key=key,
        name=name,
        sdr_device=device,
        insecure=insecure,
        role=role,
        session_code=session_code,
    )
    agent.run()
# Allow running this module directly as a script. The same main() backs the
# ``ria-agent`` command shown in the module docstring.
if __name__ == "__main__":
    main()