ria-toolkit-oss/src/ria_toolkit_oss/agent.py
ben 9a960e2f29
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Failing after 1s
Build Project / Build Project (3.10) (pull_request) Successful in 57s
Build Project / Build Project (3.11) (pull_request) Successful in 1m7s
Build Project / Build Project (3.12) (pull_request) Successful in 56s
Test with tox / Test with tox (3.12) (pull_request) Failing after 5m13s
Test with tox / Test with tox (3.11) (pull_request) Failing after 5m48s
Test with tox / Test with tox (3.10) (pull_request) Failing after 8m46s
zfp functionality and servers
2026-03-31 13:51:10 -04:00

463 lines
16 KiB
Python
Raw RIA Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""RT-OSS Node Agent — connects to RIA Hub and dispatches work to local hardware.
The agent runs on any machine with an SDR attached and connects **outbound** to
RIA Hub. No inbound ports need to be opened on the user's machine, and the
connection works identically through NAT, corporate firewalls, or a Pi on a
cellular link.
Usage::
ria-agent \\
--hub https://riahub.company.com \\
--key <api-key> \\
--name lab-bench-1 \\
[--device plutosdr] \\
[--insecure]
The agent:
1. Registers with RIA Hub and receives a ``node_id``.
2. Sends a heartbeat every 30 s so the hub knows it is online.
3. Long-polls ``GET /orchestrator/nodes/{id}/commands`` (30 s timeout).
4. Executes received campaigns via :class:`ria_toolkit_oss.orchestration.executor.CampaignExecutor`.
5. Uploads recordings to the hub. Files larger than 90 MB are sent via chunked
POST in 50 MB pieces; smaller files go up in a single request, so every
request stays below Cloudflare's 100 MB limit without needing the bypass
subdomain.
6. Deregisters cleanly on SIGINT / SIGTERM.
"""
from __future__ import annotations
import logging
import math
import os
import signal
import sys
import threading
import time
import uuid
from typing import Any
# Module-level logger used by the agent and the CLI; main() configures output via logging.basicConfig.
logger = logging.getLogger("ria_agent")
# ---------------------------------------------------------------------------
# Tuneable constants
# ---------------------------------------------------------------------------
_HEARTBEAT_INTERVAL = 30 # seconds between heartbeats (wait period of the heartbeat thread)
# NOTE(review): _POLL_TIMEOUT is not referenced by any code visible in this file; presumably it
# only documents the hub-side long-poll window — confirm before removing or changing it.
_POLL_TIMEOUT = 30 # server-side long-poll duration (seconds)
_POLL_CLIENT_TIMEOUT = 40 # client timeout (seconds) for the command long-poll — slightly longer than server
_RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying
_CHUNK_SIZE = 50 * 1024 * 1024 # 50 MB per chunk request — well below Cloudflare's 100 MB limit
_DIRECT_THRESHOLD = 90 * 1024 * 1024 # 90 MB — files above this use chunked upload; others go in one POST
# ---------------------------------------------------------------------------
# Agent
# ---------------------------------------------------------------------------
class NodeAgent:
    """Outbound-connecting agent that bridges RIA Hub to local SDR hardware.

    All network I/O is initiated by the agent (outbound). RIA Hub never opens
    a connection back to the agent's machine.

    Lifecycle (see :meth:`run`): register with the hub, start a background
    heartbeat thread, long-poll for commands on the calling thread, and
    deregister on shutdown. Each campaign runs on its own daemon thread.

    Args:
        hub_url: Base URL of RIA Hub; a trailing ``/`` is stripped.
        api_key: Value sent in the ``X-API-Key`` header of every request.
        name: Node name reported to the hub at registration.
        sdr_device: SDR device type reported to the hub at registration.
        insecure: When true, TLS certificate verification is disabled.
    """

    def __init__(
        self,
        hub_url: str,
        api_key: str,
        name: str,
        sdr_device: str = "unknown",
        insecure: bool = False,
    ) -> None:
        self.hub_url = hub_url.rstrip("/")
        self.api_key = api_key
        self.name = name
        self.sdr_device = sdr_device
        self.insecure = insecure
        # Assigned by _register(); None until the hub has accepted the node.
        self.node_id: str | None = None
        # Set on shutdown; both the heartbeat thread and the command loop watch it.
        self._stop = threading.Event()
        # The toolkit version is informational only, so any import problem
        # degrades to "unknown" instead of preventing the agent from starting.
        try:
            import ria_toolkit_oss

            self._ria_version: str = getattr(ria_toolkit_oss, "__version__", "unknown")
        except Exception:
            self._ria_version = "unknown"

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------
    def run(self) -> None:
        """Register, start the heartbeat thread, and enter the command loop.

        Blocks until SIGINT or SIGTERM is received. The node is deregistered
        on the way out, including when the command loop raises.

        Raises:
            Exception: Whatever the initial :meth:`_register` call raises
                (for example an HTTP error status); nothing is retried here.
        """
        self._register()

        def _shutdown(sig: int, _frame: Any) -> None:
            logger.info("Shutdown signal received — stopping agent")
            self._stop.set()

        signal.signal(signal.SIGINT, _shutdown)
        signal.signal(signal.SIGTERM, _shutdown)

        hb = threading.Thread(target=self._heartbeat_loop, daemon=True, name="ria-agent-heartbeat")
        hb.start()
        logger.info("Agent %r online (node_id=%s, hub=%s)", self.name, self.node_id, self.hub_url)
        try:
            self._command_loop()
        finally:
            # Make sure the heartbeat thread stops even if the loop raised.
            self._stop.set()
            self._deregister()

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------
    def _register(self) -> None:
        """Register this node with the hub and store the returned ``node_id``.

        Raises:
            Exception: On transport errors, on an HTTP error status
                (``raise_for_status``), or when the response has no ``node_id``.
        """
        resp = self._post(
            "/orchestrator/nodes/register",
            json={
                "name": self.name,
                "sdr_device": self.sdr_device,
                "ria_toolkit_version": self._ria_version,
                "capabilities": ["inference", "campaign"],
            },
            timeout=15,
        )
        resp.raise_for_status()
        self.node_id = resp.json()["node_id"]
        logger.info("Registered as %r (node_id=%s)", self.name, self.node_id)

    def _deregister(self) -> None:
        """Best-effort removal of this node from the hub; never raises."""
        if not self.node_id:
            return
        try:
            self._delete(f"/orchestrator/nodes/{self.node_id}", timeout=10)
            logger.info("Deregistered %s", self.node_id)
        except Exception as exc:
            # Shutdown must not fail just because the hub is unreachable.
            logger.debug("Deregister failed (ignored on shutdown): %s", exc)

    # ------------------------------------------------------------------
    # Heartbeat thread
    # ------------------------------------------------------------------
    def _heartbeat_loop(self) -> None:
        """Send a heartbeat every ``_HEARTBEAT_INTERVAL`` seconds until stopped.

        A 404 triggers re-registration; any other error status is logged so
        that a rejected heartbeat (bad key, hub error) does not go unnoticed.
        """
        # Event.wait() returns False on timeout and True once stop is set, so
        # the loop both paces itself and exits promptly on shutdown.
        while not self._stop.wait(_HEARTBEAT_INTERVAL):
            try:
                resp = self._post(f"/orchestrator/nodes/{self.node_id}/heartbeat", timeout=10)
                if resp.status_code == 404:
                    # NOTE(review): the command loop re-registers on 404 as well, so
                    # both threads may register after a hub restart — confirm the hub
                    # tolerates duplicate registrations for the same node name.
                    logger.warning("Heartbeat got 404 — hub lost registration, re-registering")
                    self._register()
                elif not resp.ok:
                    logger.warning("Heartbeat rejected by hub: HTTP %s", resp.status_code)
            except Exception as exc:
                logger.warning("Heartbeat failed: %s", exc)

    # ------------------------------------------------------------------
    # Command poll loop
    # ------------------------------------------------------------------
    def _command_loop(self) -> None:
        """Long-poll the hub for commands and dispatch them until stopped.

        Status handling: 204 means no command arrived and the poll is repeated
        immediately; 404 triggers re-registration; any other error status or
        exception is logged and retried after ``_RECONNECT_PAUSE`` seconds.
        """
        while not self._stop.is_set():
            try:
                resp = self._get(
                    f"/orchestrator/nodes/{self.node_id}/commands",
                    timeout=_POLL_CLIENT_TIMEOUT,
                )
                if resp.status_code == 204:
                    # No command within the timeout window — loop immediately.
                    continue
                if resp.status_code == 404:
                    logger.warning("Command poll got 404 — re-registering")
                    self._register()
                    continue
                resp.raise_for_status()
                cmd = resp.json()
                logger.info("Received command: %s", cmd.get("command"))
                self._dispatch(cmd)
            except Exception as exc:
                if self._stop.is_set():
                    break
                logger.warning("Command poll error: %s — retrying in %ds", exc, _RECONNECT_PAUSE)
                # Wait on the stop event rather than sleeping so a shutdown
                # signal ends the back-off immediately.
                self._stop.wait(_RECONNECT_PAUSE)

    # ------------------------------------------------------------------
    # Command dispatch
    # ------------------------------------------------------------------
    def _dispatch(self, cmd: dict) -> None:
        """Start the handler for one hub command.

        ``run_campaign`` is executed on a new daemon thread so the command loop
        keeps polling; any other command is logged and ignored.

        Args:
            cmd: Decoded command object; ``command``, ``campaign_id`` and
                ``payload`` are the keys read here.
        """
        command = cmd.get("command")
        if command == "run_campaign":
            # The ID is sliced for log and thread names, so force it to a
            # string in case the hub sends a non-string JSON value.
            campaign_id: str = str(cmd.get("campaign_id") or uuid.uuid4())
            config_dict: dict = cmd.get("payload") or {}
            threading.Thread(
                target=self._run_campaign,
                args=(campaign_id, config_dict),
                daemon=True,
                name=f"campaign-{campaign_id[:8]}",
            ).start()
        else:
            logger.warning("Unknown command %r — ignored", command)

    # ------------------------------------------------------------------
    # Campaign execution
    # ------------------------------------------------------------------
    def _run_campaign(self, campaign_id: str, config_dict: dict) -> None:
        """Run one campaign and upload its recordings; never raises.

        Args:
            campaign_id: Identifier used in log messages and upload metadata.
            config_dict: Campaign configuration passed to ``CampaignConfig.from_dict``.
        """
        # Imported lazily so the agent can start without the orchestration
        # modules being importable.
        try:
            from ria_toolkit_oss.orchestration.campaign import CampaignConfig
            from ria_toolkit_oss.orchestration.executor import CampaignExecutor
        except ImportError as exc:
            logger.error(
                "Campaign %s cannot start — ria_toolkit_oss not fully installed: %s",
                campaign_id[:8],
                exc,
            )
            return
        logger.info("Campaign %s starting", campaign_id[:8])
        try:
            config = CampaignConfig.from_dict(config_dict)
            executor = CampaignExecutor(config)
            result = executor.run()
            logger.info("Campaign %s completed — uploading recordings", campaign_id[:8])
            self._upload_recordings(campaign_id, config, result)
        except Exception as exc:
            # Keep the traceback: this runs on a worker thread, so the log
            # record is the only place the failure is visible.
            logger.error("Campaign %s failed: %s", campaign_id[:8], exc, exc_info=True)

    # ------------------------------------------------------------------
    # Recording upload (chunked for large files)
    # ------------------------------------------------------------------
    def _upload_recordings(self, campaign_id: str, config: Any, result: Any) -> None:
        """Upload every recording produced by a campaign to the hub.

        The destination is read from ``config.output.repo`` and must look like
        ``owner/name``; otherwise the upload is skipped with a warning. A
        failure to upload one file is logged and does not stop the others.

        Args:
            campaign_id: Identifier attached to each upload.
            config: Campaign configuration; only ``output.repo`` is read.
            result: Campaign result; only ``steps`` is read, and from each
                step ``output_path`` and ``transmitter_id``.
        """
        output_repo: str | None = getattr(getattr(config, "output", None), "repo", None)
        if not output_repo or "/" not in output_repo:
            logger.warning("Campaign %s: no output.repo — skipping upload", campaign_id[:8])
            return
        repo_owner, repo_name = output_repo.split("/", 1)
        base_url = f"{self.hub_url}/datasets/upload"
        steps = getattr(result, "steps", None) or []
        for step in steps:
            output_path: str | None = getattr(step, "output_path", None)
            if not output_path:
                continue
            device_id: str = getattr(step, "transmitter_id", "") or ""
            for fpath in _sigmf_files(output_path):
                filename = os.path.basename(fpath)
                metadata = {
                    "filename": filename,
                    "repo_owner": repo_owner,
                    "repo_name": repo_name,
                    "device_id": device_id,
                    "campaign_id": campaign_id,
                }
                try:
                    resp_data = self._upload_file(base_url, fpath, metadata)
                    logger.info(
                        "Campaign %s: uploaded %s (oid=%s)",
                        campaign_id[:8],
                        filename,
                        resp_data.get("oid", "?"),
                    )
                except Exception as exc:
                    logger.warning("Campaign %s: upload of %s failed: %s", campaign_id[:8], filename, exc)

    def _upload_file(self, base_url: str, file_path: str, metadata: dict) -> dict:
        """Upload *file_path*, choosing chunked or direct path based on file size.

        Files up to ``_DIRECT_THRESHOLD`` bytes are sent in one multipart POST
        to *base_url*. Larger files are split into ``_CHUNK_SIZE`` pieces and
        posted to ``base_url + "/chunk"`` with a shared ``upload_id``.

        Args:
            base_url: Upload endpoint URL.
            file_path: Path of the file to send.
            metadata: Form fields sent alongside the file (and every chunk).

        Returns:
            The decoded JSON body of the response (of the last chunk when
            chunked).

        Raises:
            RuntimeError: A chunk request returned a non-success status.
            Exception: Transport errors, or an HTTP error status on the direct
                path (``raise_for_status``).
        """
        import requests as _requests

        size = os.path.getsize(file_path)
        filename = os.path.basename(file_path)
        headers = {"X-API-Key": self.api_key}
        verify = not self.insecure
        # Small files: single POST (unchanged endpoint, no assembly needed server-side).
        if size <= _DIRECT_THRESHOLD:
            with open(file_path, "rb") as fh:
                resp = _requests.post(
                    base_url,
                    headers=headers,
                    files={"file": (filename, fh)},
                    data=metadata,
                    timeout=300,
                    verify=verify,
                )
            resp.raise_for_status()
            return resp.json()
        # Large files: chunked upload — each request is ≤ 50 MB.
        total_chunks = math.ceil(size / _CHUNK_SIZE)
        # One ID ties all chunks of this file together on the hub side.
        upload_id = str(uuid.uuid4())
        chunk_url = base_url + "/chunk"
        logger.info(
            "Chunked upload: %s (%d bytes, %d × %d MB chunks)",
            filename,
            size,
            total_chunks,
            _CHUNK_SIZE // (1024 * 1024),
        )
        resp_data: dict = {}
        with open(file_path, "rb") as fh:
            for i in range(total_chunks):
                chunk = fh.read(_CHUNK_SIZE)
                resp = _requests.post(
                    chunk_url,
                    headers=headers,
                    files={"file": (filename, chunk, "application/octet-stream")},
                    data={
                        **metadata,
                        "upload_id": upload_id,
                        "chunk_index": i,
                        "total_chunks": total_chunks,
                    },
                    timeout=120,
                    verify=verify,
                )
                if not resp.ok:
                    # Include the start of the body: it is the only hint at
                    # why the hub rejected the chunk.
                    raise RuntimeError(
                        f"Chunk {i + 1}/{total_chunks} failed: HTTP {resp.status_code}: {resp.text[:300]}"
                    )
                resp_data = resp.json()
                logger.debug("Chunk %d/%d uploaded", i + 1, total_chunks)
        return resp_data

    # ------------------------------------------------------------------
    # HTTP helpers
    # ------------------------------------------------------------------
    # Each helper prefixes *path* with the hub URL, adds the API-key header and
    # applies the TLS-verification setting; extra keyword arguments go straight
    # to ``requests``. ``requests`` is imported lazily inside each helper.
    def _get(self, path: str, **kwargs: Any):
        """Send an authenticated GET to ``hub_url + path`` and return the response."""
        import requests as _requests

        return _requests.get(
            f"{self.hub_url}{path}",
            headers={"X-API-Key": self.api_key},
            verify=not self.insecure,
            **kwargs,
        )

    def _post(self, path: str, **kwargs: Any):
        """Send an authenticated POST to ``hub_url + path`` and return the response."""
        import requests as _requests

        return _requests.post(
            f"{self.hub_url}{path}",
            headers={"X-API-Key": self.api_key},
            verify=not self.insecure,
            **kwargs,
        )

    def _delete(self, path: str, **kwargs: Any):
        """Send an authenticated DELETE to ``hub_url + path`` and return the response."""
        import requests as _requests

        return _requests.delete(
            f"{self.hub_url}{path}",
            headers={"X-API-Key": self.api_key},
            verify=not self.insecure,
            **kwargs,
        )
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _sigmf_files(data_path: str) -> list[str]:
    """List the files to upload for a recording, keeping only those that exist.

    *data_path* itself is always considered. When it ends in ``.sigmf-data``
    the sibling ``.sigmf-meta`` path (same stem) is considered as well, after
    the data file.
    """
    data_suffix = ".sigmf-data"
    wanted = [data_path]
    if data_path.endswith(data_suffix):
        stem = data_path[: len(data_path) - len(data_suffix)]
        wanted.append(stem + ".sigmf-meta")
    return list(filter(os.path.exists, wanted))
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main() -> None:
    """Command-line entry point: parse options, set up logging, run the agent.

    Builds a :class:`NodeAgent` from ``--hub``, ``--key``, ``--name``,
    ``--device`` and ``--insecure`` and then blocks in its ``run()`` method.
    Log records are written to stderr at the level given by ``--log-level``.
    """
    import argparse

    cli = argparse.ArgumentParser(
        prog="ria-agent",
        description=(
            "RT-OSS Node Agent — connects outbound to RIA Hub and executes "
            "campaigns / inference on local SDR hardware."
        ),
    )

    # The three mandatory options differ only in flag, placeholder and help
    # text, so they are registered from a table (order matters for --help).
    mandatory = (
        ("--hub", "URL", "RIA Hub base URL, e.g. https://riahub.company.com"),
        ("--key", "API_KEY", "Shared API key (must match [wac] API_KEY in the hub's app.ini)"),
        ("--name", "NAME", 'Human-readable name shown in the Target Node dropdown, e.g. "lab-bench-1"'),
    )
    for flag, placeholder, description in mandatory:
        cli.add_argument(flag, required=True, metavar=placeholder, help=description)

    cli.add_argument(
        "--device",
        metavar="SDR",
        default="unknown",
        help=(
            "SDR device type reported to the hub (informational only). "
            "Examples: plutosdr, usrp_b210, rtlsdr, mock. Default: unknown"
        ),
    )
    cli.add_argument(
        "--insecure",
        action="store_true",
        help="Disable TLS certificate verification (dev/self-signed certs only)",
    )
    cli.add_argument(
        "--log-level",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default="INFO",
        help="Logging verbosity (default: INFO)",
    )
    opts = cli.parse_args()

    logging.basicConfig(
        stream=sys.stderr,
        level=getattr(logging, opts.log_level),
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Warn loudly if --insecure is used outside of development.
    if opts.insecure:
        logger.warning(
            "--insecure disables TLS certificate verification. "
            "Only use this for local development with self-signed certs."
        )

    NodeAgent(
        hub_url=opts.hub,
        api_key=opts.key,
        name=opts.name,
        sdr_device=opts.device,
        insecure=opts.insecure,
    ).run()


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()