"""RT-OSS Node Agent — connects to RIA Hub and dispatches work to local hardware.

The agent runs on any machine with an SDR attached and connects **outbound** to
RIA Hub.  No inbound ports need to be opened on the user's machine, and the
connection works identically through NAT, corporate firewalls, or a Pi on a
cellular link.

Usage::

    ria-agent \\
        --hub https://riahub.company.com \\
        --key <api-key> \\
        --name lab-bench-1 \\
        [--device plutosdr] \\
        [--insecure]

The agent:

1. Registers with RIA Hub and receives a ``node_id``.
2. Sends a heartbeat every 30 s so the hub knows it is online.
3. Long-polls ``GET /orchestrator/nodes/{id}/commands`` (30 s timeout).
4. Executes received campaigns via :class:`ria_toolkit_oss.orchestration.executor.CampaignExecutor`.
5. Uploads recordings to the hub.  Files up to 90 MB are sent in a single
   POST; larger files are sent via chunked POST in 50 MB pieces, so every
   request stays below Cloudflare's 100 MB limit without needing the bypass
   subdomain.
6. Deregisters cleanly on SIGINT / SIGTERM.
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
import math
|
|||
|
|
import os
|
|||
|
|
import signal
|
|||
|
|
import sys
|
|||
|
|
import threading
|
|||
|
|
import time
|
|||
|
|
import uuid
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
logger = logging.getLogger("ria_agent")

# ---------------------------------------------------------------------------
# Tuneable constants
# ---------------------------------------------------------------------------

# Timing — all values are in seconds.
_HEARTBEAT_INTERVAL = 30  # gap between consecutive heartbeats
_POLL_TIMEOUT = 30  # server-side long-poll duration
_POLL_CLIENT_TIMEOUT = 40  # client read timeout — slightly longer than the server's
_RECONNECT_PAUSE = 5  # wait after a poll error before retrying

# Upload sizing — all values are in bytes (``n << 20`` is n MB).
_CHUNK_SIZE = 50 << 20  # 50 MB — well below Cloudflare's 100 MB limit
_DIRECT_THRESHOLD = 90 << 20  # files above this use chunked upload
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# Agent
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
|
|||
|
|
class NodeAgent:
    """Outbound-connecting agent that bridges RIA Hub to local SDR hardware.

    All network I/O is initiated by the agent (outbound).  RIA Hub never opens
    a connection back to the agent's machine.

    Two threads talk to the hub: the thread that calls :meth:`run` long-polls
    for commands, and a daemon thread sends heartbeats.  Either one may learn
    (HTTP 404) that the hub no longer knows this node, so re-registration is
    serialised through :meth:`_reregister`.
    """

    def __init__(
        self,
        hub_url: str,
        api_key: str,
        name: str,
        sdr_device: str = "unknown",
        insecure: bool = False,
    ) -> None:
        """Store connection settings; no network traffic happens here.

        Args:
            hub_url: Base URL of RIA Hub.  A trailing ``/`` is stripped.
            api_key: Value sent in the ``X-API-Key`` header of every request.
            name: Node name reported to the hub at registration.
            sdr_device: SDR device type reported to the hub at registration.
            insecure: When true, TLS certificate verification is disabled.
        """
        self.hub_url = hub_url.rstrip("/")
        self.api_key = api_key
        self.name = name
        self.sdr_device = sdr_device
        self.insecure = insecure

        self.node_id: str | None = None
        self._stop = threading.Event()
        # Guards re-registration so the heartbeat thread and the command loop
        # cannot both register after seeing a 404 for the same node_id.
        self._register_lock = threading.Lock()

        try:
            import ria_toolkit_oss

            self._ria_version: str = getattr(ria_toolkit_oss, "__version__", "unknown")
        except Exception:
            # The toolkit is optional at construction time; campaigns report
            # their own import error later if it is really missing.
            self._ria_version = "unknown"

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def run(self) -> None:
        """Register, start the heartbeat thread, and enter the command loop.

        Blocks until SIGINT or SIGTERM is received.  Must be called from the
        main thread because it installs signal handlers.

        Raises:
            Exception: Whatever the initial registration request raises
                (connection errors, non-2xx status, missing ``node_id``).
        """
        self._register()

        def _shutdown(sig: int, _frame: Any) -> None:
            logger.info("Shutdown signal received — stopping agent")
            self._stop.set()

        signal.signal(signal.SIGINT, _shutdown)
        signal.signal(signal.SIGTERM, _shutdown)

        hb = threading.Thread(target=self._heartbeat_loop, daemon=True, name="ria-agent-heartbeat")
        hb.start()

        logger.info("Agent %r online (node_id=%s, hub=%s)", self.name, self.node_id, self.hub_url)

        try:
            self._command_loop()
        finally:
            # Make sure the heartbeat thread stops even if the loop crashed.
            self._stop.set()
            self._deregister()

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    def _register(self) -> None:
        """Register with the hub and store the ``node_id`` it returns.

        Raises:
            Exception: On connection failure, a non-2xx response, or a
                response body without a ``node_id`` key.
        """
        resp = self._post(
            "/orchestrator/nodes/register",
            json={
                "name": self.name,
                "sdr_device": self.sdr_device,
                "ria_toolkit_version": self._ria_version,
                "capabilities": ["inference", "campaign"],
            },
            timeout=15,
        )
        resp.raise_for_status()
        self.node_id = resp.json()["node_id"]
        logger.info("Registered as %r (node_id=%s)", self.name, self.node_id)

    def _reregister(self, stale_node_id: str | None) -> None:
        """Register again unless another thread already replaced *stale_node_id*.

        Args:
            stale_node_id: The ``node_id`` that was used for the request that
                came back 404.
        """
        with self._register_lock:
            if self.node_id != stale_node_id:
                # The other thread re-registered while we were waiting.
                return
            self._register()

    def _deregister(self) -> None:
        """Best-effort removal of this node from the hub; never raises."""
        if not self.node_id:
            return
        try:
            resp = self._delete(f"/orchestrator/nodes/{self.node_id}", timeout=10)
            # Without this an error status would be logged as a success.
            resp.raise_for_status()
            logger.info("Deregistered %s", self.node_id)
        except Exception as exc:
            logger.debug("Deregister failed (ignored on shutdown): %s", exc)

    # ------------------------------------------------------------------
    # Heartbeat thread
    # ------------------------------------------------------------------

    def _heartbeat_loop(self) -> None:
        """POST a heartbeat every ``_HEARTBEAT_INTERVAL`` seconds until stopped."""
        # Event.wait returns True once the stop flag is set, ending the loop.
        while not self._stop.wait(_HEARTBEAT_INTERVAL):
            node_id = self.node_id
            try:
                resp = self._post(f"/orchestrator/nodes/{node_id}/heartbeat", timeout=10)
                if resp.status_code == 404:
                    logger.warning("Heartbeat got 404 — hub lost registration, re-registering")
                    self._reregister(node_id)
                elif not resp.ok:
                    logger.warning("Heartbeat failed: HTTP %s", resp.status_code)
            except Exception as exc:
                logger.warning("Heartbeat failed: %s", exc)

    # ------------------------------------------------------------------
    # Command poll loop
    # ------------------------------------------------------------------

    def _command_loop(self) -> None:
        """Long-poll the hub for commands and dispatch them until stopped."""
        while not self._stop.is_set():
            node_id = self.node_id
            try:
                resp = self._get(
                    f"/orchestrator/nodes/{node_id}/commands",
                    timeout=_POLL_CLIENT_TIMEOUT,
                )
                if resp.status_code == 204:
                    # No command within the timeout window — loop immediately.
                    continue
                if resp.status_code == 404:
                    logger.warning("Command poll got 404 — re-registering")
                    self._reregister(node_id)
                    continue
                resp.raise_for_status()
                cmd = resp.json()
                if isinstance(cmd, dict):
                    logger.info("Received command: %s", cmd.get("command"))
                self._dispatch(cmd)
            except Exception as exc:
                if not self._stop.is_set():
                    logger.warning("Command poll error: %s — retrying in %ds", exc, _RECONNECT_PAUSE)
                    # Wait on the stop event rather than sleeping so a shutdown
                    # request ends the back-off immediately.
                    self._stop.wait(_RECONNECT_PAUSE)

    # ------------------------------------------------------------------
    # Command dispatch
    # ------------------------------------------------------------------

    def _dispatch(self, cmd: dict) -> None:
        """Route one decoded command from the hub.

        ``run_campaign`` starts :meth:`_run_campaign` on a daemon thread so the
        poll loop is not blocked.  Anything else is logged and ignored.

        Args:
            cmd: Decoded JSON command.  Keys read: ``command``,
                ``campaign_id`` and ``payload``.
        """
        if not isinstance(cmd, dict):
            logger.warning("Malformed command %r — ignored", cmd)
            return

        command = cmd.get("command")
        if command != "run_campaign":
            logger.warning("Unknown command %r — ignored", command)
            return

        # The id is sliced for thread and log names, so force it to a string
        # even if the hub sent a number.
        campaign_id = str(cmd.get("campaign_id") or uuid.uuid4())
        config_dict: dict = cmd.get("payload") or {}
        threading.Thread(
            target=self._run_campaign,
            args=(campaign_id, config_dict),
            daemon=True,
            name=f"campaign-{campaign_id[:8]}",
        ).start()

    # ------------------------------------------------------------------
    # Campaign execution
    # ------------------------------------------------------------------

    def _run_campaign(self, campaign_id: str, config_dict: dict) -> None:
        """Execute one campaign and upload its recordings; never raises.

        Args:
            campaign_id: Identifier used in log lines and upload metadata.
            config_dict: Campaign configuration passed to
                ``CampaignConfig.from_dict``.
        """
        try:
            from ria_toolkit_oss.orchestration.campaign import CampaignConfig
            from ria_toolkit_oss.orchestration.executor import CampaignExecutor
        except ImportError as exc:
            logger.error(
                "Campaign %s cannot start — ria_toolkit_oss not fully installed: %s",
                campaign_id[:8],
                exc,
            )
            return

        logger.info("Campaign %s starting", campaign_id[:8])
        try:
            config = CampaignConfig.from_dict(config_dict)
            executor = CampaignExecutor(config)
            result = executor.run()
            logger.info("Campaign %s completed — uploading recordings", campaign_id[:8])
            self._upload_recordings(campaign_id, config, result)
        except Exception as exc:
            # logger.exception keeps the traceback; this runs on a worker
            # thread, so nothing else would ever surface it.
            logger.exception("Campaign %s failed: %s", campaign_id[:8], exc)

    # ------------------------------------------------------------------
    # Recording upload (chunked for large files)
    # ------------------------------------------------------------------

    def _upload_recordings(self, campaign_id: str, config: Any, result: Any) -> None:
        """Upload every recording file produced by a finished campaign.

        A failure on one file is logged and the remaining files are still
        attempted.

        Args:
            campaign_id: Identifier attached to each upload.
            config: Campaign config; ``config.output.repo`` must be an
                ``owner/name`` string or the upload is skipped.
            result: Campaign result; ``result.steps`` is iterated and each
                step's ``output_path`` and ``transmitter_id`` are read.
        """
        output_repo: str | None = getattr(getattr(config, "output", None), "repo", None)
        if not output_repo or "/" not in output_repo:
            logger.warning("Campaign %s: no output.repo — skipping upload", campaign_id[:8])
            return

        repo_owner, repo_name = output_repo.split("/", 1)
        base_url = f"{self.hub_url}/datasets/upload"

        for step in getattr(result, "steps", None) or []:
            output_path: str | None = getattr(step, "output_path", None)
            if not output_path:
                continue
            device_id: str = getattr(step, "transmitter_id", "") or ""
            for fpath in _sigmf_files(output_path):
                filename = os.path.basename(fpath)
                metadata = {
                    "filename": filename,
                    "repo_owner": repo_owner,
                    "repo_name": repo_name,
                    "device_id": device_id,
                    "campaign_id": campaign_id,
                }
                try:
                    resp_data = self._upload_file(base_url, fpath, metadata)
                    logger.info(
                        "Campaign %s: uploaded %s (oid=%s)",
                        campaign_id[:8],
                        filename,
                        resp_data.get("oid", "?"),
                    )
                except Exception as exc:
                    logger.warning("Campaign %s: upload of %s failed: %s", campaign_id[:8], filename, exc)

    def _upload_file(self, base_url: str, file_path: str, metadata: dict) -> dict:
        """Upload *file_path*, choosing chunked or direct path based on file size.

        Files of at most ``_DIRECT_THRESHOLD`` bytes go to *base_url* in one
        POST.  Larger files are split into ``_CHUNK_SIZE`` pieces and posted
        in order to ``base_url + "/chunk"``, all sharing one ``upload_id``.

        Args:
            base_url: Full URL of the upload endpoint.
            file_path: Local file to send.
            metadata: Form fields sent with every request.

        Returns:
            The decoded JSON body of the direct upload, or of the last chunk.

        Raises:
            OSError: If the file cannot be stat'ed or opened.
            RuntimeError: If any chunk request returns a non-2xx status.
            Exception: Whatever ``requests`` raises for transport errors, a
                non-2xx direct upload, or a non-JSON response body.
        """
        import requests as _requests

        size = os.path.getsize(file_path)
        filename = os.path.basename(file_path)
        headers = {"X-API-Key": self.api_key}
        verify = not self.insecure

        # Small files: single POST (unchanged endpoint, no assembly needed server-side).
        if size <= _DIRECT_THRESHOLD:
            with open(file_path, "rb") as fh:
                resp = _requests.post(
                    base_url,
                    headers=headers,
                    files={"file": (filename, fh)},
                    data=metadata,
                    timeout=300,
                    verify=verify,
                )
            resp.raise_for_status()
            return resp.json()

        # Large files: chunked upload — each chunk carries at most _CHUNK_SIZE bytes.
        total_chunks = math.ceil(size / _CHUNK_SIZE)
        upload_id = str(uuid.uuid4())
        chunk_url = base_url + "/chunk"

        logger.info(
            "Chunked upload: %s (%d bytes, %d × %d MB chunks)",
            filename,
            size,
            total_chunks,
            _CHUNK_SIZE // (1024 * 1024),
        )

        resp_data: dict = {}
        with open(file_path, "rb") as fh:
            for i in range(total_chunks):
                chunk = fh.read(_CHUNK_SIZE)
                resp = _requests.post(
                    chunk_url,
                    headers=headers,
                    files={"file": (filename, chunk, "application/octet-stream")},
                    data={
                        **metadata,
                        "upload_id": upload_id,
                        "chunk_index": i,
                        "total_chunks": total_chunks,
                    },
                    timeout=120,
                    verify=verify,
                )
                if not resp.ok:
                    raise RuntimeError(
                        f"Chunk {i + 1}/{total_chunks} failed: HTTP {resp.status_code}: {resp.text[:300]}"
                    )
                resp_data = resp.json()
                logger.debug("Chunk %d/%d uploaded", i + 1, total_chunks)

        return resp_data

    # ------------------------------------------------------------------
    # HTTP helpers
    # ------------------------------------------------------------------

    def _request(self, method: str, path: str, **kwargs: Any):
        """Send an authenticated request to ``hub_url + path`` and return the response.

        ``requests`` is imported here rather than at module level, so the
        module can be imported without it.
        """
        import requests as _requests

        return _requests.request(
            method,
            f"{self.hub_url}{path}",
            headers={"X-API-Key": self.api_key},
            verify=not self.insecure,
            **kwargs,
        )

    def _get(self, path: str, **kwargs: Any):
        """Authenticated GET; extra keyword arguments go to ``requests``."""
        return self._request("get", path, **kwargs)

    def _post(self, path: str, **kwargs: Any):
        """Authenticated POST; extra keyword arguments go to ``requests``."""
        return self._request("post", path, **kwargs)

    def _delete(self, path: str, **kwargs: Any):
        """Authenticated DELETE; extra keyword arguments go to ``requests``."""
        return self._request("delete", path, **kwargs)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# Helpers
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _sigmf_files(data_path: str | os.PathLike[str]) -> list[str]:
    """Return the existing SigMF files that belong to one recording.

    Args:
        data_path: Path of the recording, as a string or any ``os.PathLike``
            (for example ``pathlib.Path``).

    Returns:
        *data_path* itself followed, when it ends in ``.sigmf-data``, by the
        sibling ``.sigmf-meta`` path.  Paths that do not exist on disk are
        left out, so the list may be empty.  Entries are always ``str``.
    """
    # Normalise first: PathLike objects have no ``endswith`` and cannot be sliced.
    data_file = os.fspath(data_path)
    data_suffix = ".sigmf-data"

    candidates = [data_file]
    if data_file.endswith(data_suffix):
        candidates.append(data_file[: -len(data_suffix)] + ".sigmf-meta")
    return [p for p in candidates if os.path.exists(p)]
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# CLI entry point
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _build_parser():
    """Construct the ``ria-agent`` command-line parser."""
    import argparse

    cli = argparse.ArgumentParser(
        prog="ria-agent",
        description=(
            "RT-OSS Node Agent — connects outbound to RIA Hub and executes "
            "campaigns / inference on local SDR hardware."
        ),
    )

    # Required connection settings.
    cli.add_argument("--hub", required=True, metavar="URL", help="RIA Hub base URL, e.g. https://riahub.company.com")
    cli.add_argument(
        "--key",
        required=True,
        metavar="API_KEY",
        help="Shared API key (must match [wac] API_KEY in the hub's app.ini)",
    )
    cli.add_argument(
        "--name",
        required=True,
        metavar="NAME",
        help='Human-readable name shown in the Target Node dropdown, e.g. "lab-bench-1"',
    )

    # Optional settings.
    cli.add_argument(
        "--device",
        default="unknown",
        metavar="SDR",
        help=(
            "SDR device type reported to the hub (informational only). "
            "Examples: plutosdr, usrp_b210, rtlsdr, mock. Default: unknown"
        ),
    )
    cli.add_argument(
        "--insecure",
        action="store_true",
        help="Disable TLS certificate verification (dev/self-signed certs only)",
    )
    cli.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging verbosity (default: INFO)",
    )
    return cli


def main() -> None:
    """Command-line entry point: parse arguments, set up logging, run the agent.

    Blocks inside ``NodeAgent.run`` until the agent is stopped.  Exits through
    ``argparse`` (``SystemExit``) on ``--help`` or invalid arguments.
    """
    opts = _build_parser().parse_args()

    logging.basicConfig(
        level=getattr(logging, opts.log_level),
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        stream=sys.stderr,
    )

    # Warn loudly if --insecure is used outside of development.
    if opts.insecure:
        logger.warning(
            "--insecure disables TLS certificate verification. "
            "Only use this for local development with self-signed certs."
        )

    NodeAgent(
        hub_url=opts.hub,
        api_key=opts.key,
        name=opts.name,
        sdr_device=opts.device,
        insecure=opts.insecure,
    ).run()


if __name__ == "__main__":
    main()
|