From b955256479f21bebd4594b373a19b31f359bffee Mon Sep 17 00:00:00 2001 From: jonny Date: Thu, 16 Apr 2026 11:13:43 -0400 Subject: [PATCH] Pluto TX streaming functionality base --- src/ria_toolkit_oss/agent/cli.py | 62 ++- src/ria_toolkit_oss/agent/config.py | 33 +- src/ria_toolkit_oss/agent/hardware.py | 26 +- src/ria_toolkit_oss/agent/streamer.py | 607 ++++++++++++++++++++++--- src/ria_toolkit_oss/agent/ws_client.py | 17 +- src/ria_toolkit_oss/app/cli.py | 37 +- tests/agent/test_cli_tx.py | 111 +++++ tests/agent/test_config.py | 30 ++ tests/agent/test_disconnect.py | 8 +- tests/agent/test_full_duplex.py | 133 ++++++ tests/agent/test_hardware.py | 17 + tests/agent/test_integration_tx.py | 144 ++++++ tests/agent/test_streamer.py | 88 +++- tests/agent/test_streamer_tx.py | 133 ++++++ tests/agent/test_tx_safety.py | 167 +++++++ tests/agent/test_tx_underrun.py | 136 ++++++ tests/agent/test_ws_client.py | 103 +++++ 17 files changed, 1752 insertions(+), 100 deletions(-) create mode 100644 tests/agent/test_cli_tx.py create mode 100644 tests/agent/test_full_duplex.py create mode 100644 tests/agent/test_integration_tx.py create mode 100644 tests/agent/test_streamer_tx.py create mode 100644 tests/agent/test_tx_safety.py create mode 100644 tests/agent/test_tx_underrun.py diff --git a/src/ria_toolkit_oss/agent/cli.py b/src/ria_toolkit_oss/agent/cli.py index 0b06e72..6b88473 100644 --- a/src/ria_toolkit_oss/agent/cli.py +++ b/src/ria_toolkit_oss/agent/cli.py @@ -5,8 +5,8 @@ Subcommands: - ``ria-agent run [legacy args]`` — legacy long-poll NodeAgent (unchanged). - ``ria-agent stream`` — new WebSocket-based IQ streamer. - ``ria-agent detect`` — print SDR drivers whose modules import cleanly. -- ``ria-agent register --url URL --token TOKEN`` — save credentials to - ``~/.ria/agent.json``. +- ``ria-agent register --hub URL --api-key KEY`` — register with the hub and + save credentials (and optional TX interlocks) to ``~/.ria/agent.json``. 
Invoking ``ria-agent`` with no subcommand falls through to the legacy long-poll behavior for back-compatibility with existing deployments. @@ -69,9 +69,27 @@ def _cmd_register(args: argparse.Namespace) -> int: if args.name: cfg.name = args.name cfg.insecure = bool(args.insecure) + cfg.tx_enabled = bool(getattr(args, "allow_tx", False)) + if (v := getattr(args, "tx_max_gain_db", None)) is not None: + cfg.tx_max_gain_db = float(v) + if (v := getattr(args, "tx_max_duration_s", None)) is not None: + cfg.tx_max_duration_s = float(v) + freq_ranges = getattr(args, "tx_freq_range", None) or [] + if freq_ranges: + cfg.tx_allowed_freq_ranges = [[float(lo), float(hi)] for lo, hi in freq_ranges] path = _config.save(cfg) print(f"Registered agent: {agent_id}") + if cfg.tx_enabled: + caps: list[str] = [] + if cfg.tx_max_gain_db is not None: + caps.append(f"gain<={cfg.tx_max_gain_db} dB") + if cfg.tx_max_duration_s is not None: + caps.append(f"duration<={cfg.tx_max_duration_s} s") + if cfg.tx_allowed_freq_ranges: + caps.append(f"freq in {cfg.tx_allowed_freq_ranges}") + tail = f" ({', '.join(caps)})" if caps else "" + print(f"TX enabled{tail}") print(f"Credentials saved to {path}") return 0 @@ -85,8 +103,10 @@ def _cmd_stream(args: argparse.Namespace) -> int: if not url: print("error: --url is required (or run `ria-agent register` first)", file=sys.stderr) return 2 + if getattr(args, "allow_tx", False): + cfg.tx_enabled = True try: - asyncio.run(run_streamer(url, token)) + asyncio.run(run_streamer(url, token, cfg=cfg)) except KeyboardInterrupt: pass return 0 @@ -123,11 +143,47 @@ def main() -> None: p_reg.add_argument("--api-key", dest="api_key", required=True, help="Hub API key") p_reg.add_argument("--name", default=None, help="Human-friendly agent name") p_reg.add_argument("--insecure", action="store_true", help="Skip TLS verification") + p_reg.add_argument( + "--allow-tx", + dest="allow_tx", + action="store_true", + help="Opt this agent in to TX (required for any transmission 
from the hub)", + ) + p_reg.add_argument( + "--tx-max-gain-db", + dest="tx_max_gain_db", + type=float, + default=None, + help="Reject tx_start frames whose tx_gain exceeds this cap (dB)", + ) + p_reg.add_argument( + "--tx-max-duration-s", + dest="tx_max_duration_s", + type=float, + default=None, + help="Auto-stop any TX session after this many seconds", + ) + p_reg.add_argument( + "--tx-freq-range", + dest="tx_freq_range", + type=float, + nargs=2, + action="append", + metavar=("LO", "HI"), + default=None, + help="Allowed TX center-frequency range in Hz (repeat for multiple bands)", + ) p_stream = sub.add_parser("stream", help="Run the WebSocket IQ streamer") p_stream.add_argument("--url", default=None, help="Override WebSocket URL") p_stream.add_argument("--token", default=None, help="Override bearer token") p_stream.add_argument("--log-level", default="INFO") + p_stream.add_argument( + "--allow-tx", + dest="allow_tx", + action="store_true", + help="Runtime override: enable TX for this process without writing config", + ) # Unknown extras are forwarded to the legacy CLI when command == "run". 
args, extras = parser.parse_known_args(argv) diff --git a/src/ria_toolkit_oss/agent/config.py b/src/ria_toolkit_oss/agent/config.py index d1f0e00..431094a 100644 --- a/src/ria_toolkit_oss/agent/config.py +++ b/src/ria_toolkit_oss/agent/config.py @@ -7,7 +7,11 @@ Schema:: "agent_id": "agent-abc123", "token": "rha_xxxx", "name": "lab-bench-1", - "insecure": false + "insecure": false, + "tx_enabled": false, + "tx_max_gain_db": null, + "tx_max_duration_s": null, + "tx_allowed_freq_ranges": null } """ @@ -18,7 +22,8 @@ import os from dataclasses import asdict, dataclass, field from pathlib import Path -_DEFAULT_PATH = Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json"))) +def _resolve_default_path() -> Path: + return Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json"))) @dataclass @@ -29,15 +34,29 @@ class AgentConfig: name: str = "" insecure: bool = False api_key: str = "" + tx_enabled: bool = False + tx_max_gain_db: float | None = None + tx_max_duration_s: float | None = None + tx_allowed_freq_ranges: list[list[float]] | None = None extra: dict = field(default_factory=dict) def default_path() -> Path: - return _DEFAULT_PATH + return _resolve_default_path() + + +def _coerce_ranges(raw) -> list[list[float]] | None: + if raw is None: + return None + out: list[list[float]] = [] + for pair in raw: + lo, hi = pair + out.append([float(lo), float(hi)]) + return out def load(path: Path | None = None) -> AgentConfig: - p = path or _DEFAULT_PATH + p = path or _resolve_default_path() if not p.exists(): return AgentConfig() data = json.loads(p.read_text()) @@ -50,12 +69,16 @@ def load(path: Path | None = None) -> AgentConfig: name=data.get("name", ""), insecure=bool(data.get("insecure", False)), api_key=data.get("api_key", ""), + tx_enabled=bool(data.get("tx_enabled", False)), + tx_max_gain_db=(float(v) if (v := data.get("tx_max_gain_db")) is not None else None), + tx_max_duration_s=(float(v) if (v := 
data.get("tx_max_duration_s")) is not None else None), + tx_allowed_freq_ranges=_coerce_ranges(data.get("tx_allowed_freq_ranges")), extra=extra, ) def save(cfg: AgentConfig, path: Path | None = None) -> Path: - p = path or _DEFAULT_PATH + p = path or _resolve_default_path() p.parent.mkdir(parents=True, exist_ok=True) data = asdict(cfg) extra = data.pop("extra", {}) or {} diff --git a/src/ria_toolkit_oss/agent/hardware.py b/src/ria_toolkit_oss/agent/hardware.py index 417bf1c..32a65e5 100644 --- a/src/ria_toolkit_oss/agent/hardware.py +++ b/src/ria_toolkit_oss/agent/hardware.py @@ -4,19 +4,41 @@ from __future__ import annotations from ria_toolkit_oss.sdr import detect_available +from .config import AgentConfig + def available_devices() -> list[str]: """Return a sorted list of device names whose driver modules import cleanly.""" return sorted(detect_available().keys()) -def heartbeat_payload(status: str = "idle", app_id: str | None = None) -> dict: - """Build the JSON body of a periodic heartbeat frame.""" +def heartbeat_payload( + status: str = "idle", + app_id: str | None = None, + *, + cfg: AgentConfig | None = None, + sessions: dict | None = None, +) -> dict: + """Build the JSON body of a periodic heartbeat frame. + + *cfg* drives the ``capabilities`` list and the ``tx_enabled`` flag. If not + supplied, the heartbeat advertises RX-only with ``tx_enabled=False`` — + matching the pre-TX shape. 
binary IQ frames; agent feeds them into ``sdr._stream_tx``). The agent enforces its own interlocks and rejects TX when ``cfg.tx_enabled`` is False.
""" from __future__ import annotations import asyncio import logging +import queue +import threading +import time +from dataclasses import dataclass, field from typing import Any import numpy as np +from .config import AgentConfig from .hardware import heartbeat_payload from .ws_client import WsClient @@ -23,6 +36,98 @@ logger = logging.getLogger("ria_agent.streamer") _DEFAULT_BUFFER_SIZE = 1024 +# --------------------------------------------------------------------------- +# Session dataclasses + + +@dataclass +class RxSession: + app_id: str + sdr: Any + device_key: tuple[str, str | None] + buffer_size: int + task: asyncio.Task | None = None + pending_config: dict = field(default_factory=dict) + + +@dataclass +class TxSession: + app_id: str + sdr: Any + device_key: tuple[str, str | None] + buffer_size: int + task: Any = None # concurrent.futures.Future from run_in_executor + pending_config: dict = field(default_factory=dict) + underrun_policy: str = "pause" + last_buffer: np.ndarray | None = None + stop_event: threading.Event = field(default_factory=threading.Event) + started_at: float = 0.0 + max_duration_s: float | None = None + state: str = "armed" + # Thread-safe queue of inbound interleaved-float32 IQ frames. Bounded so + # hub-side over-production triggers WS backpressure rather than memory + # growth in the agent. + in_queue: "queue.Queue[bytes]" = field(default_factory=lambda: queue.Queue(maxsize=8)) + # Set by the TX callback when it hits an underrun while policy=="pause"; + # asyncio side flips the session state and emits tx_status. 
+ underrun_flag: threading.Event = field(default_factory=threading.Event) + + +# --------------------------------------------------------------------------- +# SDR registry (ref-counted so one Pluto handle serves RX + TX simultaneously) + + +class _SdrRegistry: + def __init__(self, factory): + self._factory = factory + self._instances: dict[tuple[str, str | None], tuple[Any, int]] = {} + self._lock = threading.Lock() + + def acquire(self, device: str, identifier: str | None) -> tuple[Any, tuple[str, str | None]]: + key = (device, identifier) + with self._lock: + if key in self._instances: + sdr, rc = self._instances[key] + self._instances[key] = (sdr, rc + 1) + return sdr, key + # Build outside the lock: driver init can be slow and we don't want to + # block concurrent releases on unrelated devices. + sdr = self._factory(device, identifier) + with self._lock: + if key in self._instances: + # Raced another acquirer; discard our duplicate and share theirs. + other_sdr, rc = self._instances[key] + try: + sdr.close() + except Exception: + pass + self._instances[key] = (other_sdr, rc + 1) + return other_sdr, key + self._instances[key] = (sdr, 1) + return sdr, key + + def release(self, key: tuple[str, str | None]) -> bool: + """Decrement refcount. Returns True if the caller owns the last reference + and should close the SDR.""" + with self._lock: + sdr, rc = self._instances.get(key, (None, 0)) + if sdr is None: + return False + if rc <= 1: + del self._instances[key] + return True + self._instances[key] = (sdr, rc - 1) + return False + + def refcount(self, key: tuple[str, str | None]) -> int: + with self._lock: + return self._instances.get(key, (None, 0))[1] + + +# --------------------------------------------------------------------------- +# Streamer + + class Streamer: """Main streamer loop. @@ -31,103 +136,186 @@ class Streamer: ws: Connected :class:`WsClient`. sdr_factory: - Callable ``(device, identifier) -> SDR``. 
Defaults to - :func:`ria_toolkit_oss.sdr.get_sdr_device`. Injectable for tests. + Callable ``(device, identifier) -> SDR``. Defaults to the helper in + :mod:`ria_toolkit_oss.sdr`. Injectable for tests. + cfg: + :class:`AgentConfig` for interlocks (``tx_enabled`` and caps) and + heartbeat capabilities. Defaults to an empty ``AgentConfig()`` which + leaves TX disabled. """ - def __init__(self, ws: WsClient, sdr_factory=None) -> None: + def __init__( + self, + ws, + sdr_factory=None, + cfg: AgentConfig | None = None, + ) -> None: self.ws = ws - self._sdr_factory = sdr_factory - self._app_id: str | None = None - self._sdr: Any = None - self._pending_config: dict = {} - self._capture_task: asyncio.Task | None = None - self._status = "idle" + self._cfg = cfg or AgentConfig() + self._registry = _SdrRegistry(sdr_factory or _default_sdr_factory) + self._rx: RxSession | None = None + self._tx: TxSession | None = None + # Pending radio_config accepted via ``configure`` before ``start``. + self._standalone_pending_config: dict = {} + # Cached asyncio event loop, set the first time a handler runs. Used + # to schedule async callbacks from the TX executor thread. + self._loop: asyncio.AbstractEventLoop | None = None + + # ------------------------------------------------------------------ + # Back-compat read-only shims for callers that check ``._sdr`` etc. + # Writes to these attributes are not supported — use the session objects. 
+ + @property + def _sdr(self): + return self._rx.sdr if self._rx is not None else None + + @property + def _pending_config(self) -> dict: + return self._rx.pending_config if self._rx is not None else self._standalone_pending_config # ------------------------------------------------------------------ # WsClient wiring def build_heartbeat(self) -> dict: - return heartbeat_payload(status=self._status, app_id=self._app_id) + status = "streaming" if (self._rx is not None or self._tx is not None) else "idle" + app_id: str | None = None + if self._rx is not None: + app_id = self._rx.app_id + elif self._tx is not None: + app_id = self._tx.app_id + + sessions: dict[str, dict] = {} + if self._rx is not None: + sessions["rx"] = {"app_id": self._rx.app_id, "state": "streaming"} + if self._tx is not None: + sessions["tx"] = {"app_id": self._tx.app_id, "state": self._tx.state} + + return heartbeat_payload( + status=status, + app_id=app_id, + cfg=self._cfg, + sessions=sessions or None, + ) async def on_message(self, msg: dict) -> None: t = msg.get("type") - if t == "start": - await self._handle_start(msg) - elif t == "stop": - await self._handle_stop(msg) - elif t == "configure": - self._pending_config.update(msg.get("radio_config") or {}) - logger.debug("Queued configure: %s", self._pending_config) - else: + handler = { + "start": self._handle_rx_start, + "stop": self._handle_rx_stop, + "configure": self._handle_rx_configure, + "tx_start": self._handle_tx_start, + "tx_stop": self._handle_tx_stop, + "tx_configure": self._handle_tx_configure, + }.get(t) + if handler is None: logger.warning("Unknown server message type: %r", t) + return + await handler(msg) - # ------------------------------------------------------------------ - async def _handle_start(self, msg: dict) -> None: - if self._capture_task is not None and not self._capture_task.done(): + async def on_binary(self, data: bytes) -> None: + tx = self._tx + if tx is None: + logger.debug("Dropping %d-byte binary frame: no TX 
session", len(data)) + return + # Backpressure: if the TX queue is full, await briefly so the hub's + # ``await ws.send`` throttles naturally via TCP. We don't block + # indefinitely — a 2s stall means something else is wrong. + loop = asyncio.get_running_loop() + try: + await loop.run_in_executor(None, lambda: tx.in_queue.put(data, timeout=2.0)) + except queue.Full: + logger.warning("TX queue stalled; dropping frame") + + # ================================================================== + # RX + + async def _handle_rx_start(self, msg: dict) -> None: + if self._rx is not None: logger.warning("start received while already streaming — ignoring") return - self._app_id = msg.get("app_id") + app_id = msg.get("app_id") or "" radio_config = dict(msg.get("radio_config") or {}) device = radio_config.pop("device", None) identifier = radio_config.pop("identifier", None) buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE)) if not device: - await self._send_error("start missing radio_config.device") + await self._send_error(app_id, "start missing radio_config.device") return try: - factory = self._sdr_factory or _default_sdr_factory - self._sdr = factory(device, identifier) - _apply_sdr_config(self._sdr, radio_config) + sdr, device_key = self._registry.acquire(device, identifier) + _apply_sdr_config(sdr, radio_config) except Exception as exc: logger.exception("Failed to open SDR %r", device) - await self._send_error(f"SDR init failed: {exc}") + await self._send_error(app_id, f"SDR init failed: {exc}") return - self._status = "streaming" - await self._send_status("streaming") - self._capture_task = asyncio.create_task( - self._capture_loop(buffer_size), name="ria-streamer-capture" + # Inherit any pending config that was queued before start. 
+ pending = dict(self._standalone_pending_config) + self._standalone_pending_config = {} + + session = RxSession( + app_id=app_id, + sdr=sdr, + device_key=device_key, + buffer_size=buffer_size, + pending_config=pending, + ) + self._rx = session + await self._send_status("streaming", app_id) + session.task = asyncio.create_task( + self._capture_loop(session), name="ria-streamer-capture" ) - async def _handle_stop(self, msg: dict) -> None: - if self._capture_task is not None: - self._capture_task.cancel() + async def _handle_rx_stop(self, msg: dict) -> None: + session = self._rx + if session is None: + return + if session.task is not None: + session.task.cancel() try: - await self._capture_task + await session.task except (asyncio.CancelledError, Exception): pass - self._capture_task = None - self._close_sdr() - self._app_id = None - self._status = "idle" - await self._send_status("idle") + self._close_session_sdr(session) + app_id = session.app_id + self._rx = None + await self._send_status("idle", app_id) - async def _capture_loop(self, buffer_size: int) -> None: + async def _handle_rx_configure(self, msg: dict) -> None: + cfg = dict(msg.get("radio_config") or {}) + if self._rx is not None: + self._rx.pending_config.update(cfg) + else: + self._standalone_pending_config.update(cfg) + logger.debug("Queued configure: %s", cfg) + + async def _capture_loop(self, session: RxSession) -> None: loop = asyncio.get_running_loop() try: while True: - if self._pending_config: - cfg = self._pending_config - self._pending_config = {} + if session.pending_config: + cfg = session.pending_config + session.pending_config = {} try: - _apply_sdr_config(self._sdr, cfg) + _apply_sdr_config(session.sdr, cfg) except Exception as exc: logger.warning("Applying configure failed: %s", exc) try: - samples = await loop.run_in_executor(None, self._sdr.rx, buffer_size) + samples = await loop.run_in_executor( + None, session.sdr.rx, session.buffer_size + ) except Exception as exc: from 
ria_toolkit_oss.sdr import SdrDisconnectedError if isinstance(exc, SdrDisconnectedError): logger.warning("SDR disconnected: %s", exc) - await self._send_error(f"SDR disconnected: {exc}") + await self._send_error(session.app_id, f"SDR disconnected: {exc}") else: logger.exception("SDR rx error") - await self._send_error(f"SDR capture failed: {exc}") + await self._send_error(session.app_id, f"SDR capture failed: {exc}") break payload = _samples_to_interleaved_float32(samples) @@ -139,29 +327,305 @@ class Streamer: except asyncio.CancelledError: raise finally: - self._close_sdr() + self._close_session_sdr(session) + # If the loop died on its own (e.g. SDR disconnect), clear the + # session handle so future ``start`` messages can proceed. + if self._rx is session: + self._rx = None - def _close_sdr(self) -> None: - if self._sdr is None: + # ================================================================== + # TX + + async def _handle_tx_start(self, msg: dict) -> None: + app_id = msg.get("app_id") or "" + radio_config = dict(msg.get("radio_config") or {}) + + # --- interlocks (agent-enforced; never trust the hub alone) --- + if not self._cfg.tx_enabled: + await self._send_tx_status(app_id, "error", "tx disabled on this agent") return + tx_gain = radio_config.get("tx_gain") + if ( + self._cfg.tx_max_gain_db is not None + and tx_gain is not None + and float(tx_gain) > float(self._cfg.tx_max_gain_db) + ): + await self._send_tx_status( + app_id, + "error", + f"tx_gain {tx_gain} exceeds cap {self._cfg.tx_max_gain_db}", + ) + return + tx_freq = radio_config.get("tx_center_frequency") + if self._cfg.tx_allowed_freq_ranges and tx_freq is not None: + f = float(tx_freq) + if not any(float(lo) <= f <= float(hi) for lo, hi in self._cfg.tx_allowed_freq_ranges): + await self._send_tx_status( + app_id, + "error", + f"tx_center_frequency {tx_freq} outside allowed ranges", + ) + return + + if self._tx is not None: + await self._send_tx_status(app_id, "error", "tx already active on this 
agent") + return + + # --- device --- + device = radio_config.pop("device", None) + identifier = radio_config.pop("identifier", None) + buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE)) + underrun_policy = str(radio_config.pop("underrun_policy", "pause")) + if underrun_policy not in ("pause", "zero", "repeat"): + await self._send_tx_status( + app_id, "error", f"invalid underrun_policy {underrun_policy!r}" + ) + return + if not device: + await self._send_tx_status(app_id, "error", "tx_start missing radio_config.device") + return + + device_key: tuple[str, str | None] | None = None + sdr: Any = None try: - self._sdr.close() + sdr, device_key = self._registry.acquire(device, identifier) + _apply_sdr_config(sdr, radio_config) + # Only call init_tx when the hub supplied the three required + # parameters. Drivers that gate _stream_tx on _tx_initialized + # (e.g. Pluto) need this; drivers that don't (e.g. Mock) tolerate + # its absence. + init_args = { + k: radio_config.get(f"tx_{k}") + for k in ("sample_rate", "center_frequency", "gain") + } + if hasattr(sdr, "init_tx") and all(v is not None for v in init_args.values()): + sdr.init_tx( + sample_rate=init_args["sample_rate"], + center_frequency=init_args["center_frequency"], + gain=init_args["gain"], + channel=radio_config.get("tx_channel", 0), + gain_mode=radio_config.get("tx_gain_mode", "manual"), + ) + except Exception as exc: + if device_key is not None: + if self._registry.release(device_key): + try: + sdr.close() + except Exception: + pass + logger.exception("Failed to init TX on %r", device) + await self._send_tx_status(app_id, "error", f"tx init failed: {exc}") + return + + self._loop = asyncio.get_running_loop() + session = TxSession( + app_id=app_id, + sdr=sdr, + device_key=device_key, + buffer_size=buffer_size, + underrun_policy=underrun_policy, + started_at=time.monotonic(), + max_duration_s=self._cfg.tx_max_duration_s, + ) + self._tx = session + await self._send_tx_status(app_id, 
# Discard any queued frames; the executor's ``queue.get`` uses a short timeout, so it observes ``stop_event`` on its next pass rather than being woken here.
+ if ( + session.max_duration_s is not None + and (time.monotonic() - session.started_at) >= float(session.max_duration_s) + ): + session.stop_event.set() + try: + session.sdr.pause_tx() + except Exception: + pass + self._schedule(self._send_tx_status(session.app_id, "done", "max duration reached")) + return _silence(n) + + # Apply queued configure at buffer boundary. + if session.pending_config: + cfg = session.pending_config + session.pending_config = {} + try: + _apply_sdr_config(session.sdr, cfg) + except Exception as exc: + logger.debug("tx_configure apply failed: %s", exc) + + try: + raw = session.in_queue.get(timeout=0.1) + except queue.Empty: + return self._underrun_fill(session, n) + + arr = np.frombuffer(raw, dtype=np.float32) + if arr.size < 2 or arr.size % 2 != 0: + logger.warning("Malformed TX frame: %d floats (must be non-zero even count)", arr.size) + return self._underrun_fill(session, n) + samples = (arr[0::2].astype(np.complex64) + 1j * arr[1::2].astype(np.complex64)) + if samples.size < n: + out = np.zeros(n, dtype=np.complex64) + out[: samples.size] = samples + session.last_buffer = out + return out + if samples.size > n: + samples = samples[:n] + session.last_buffer = samples + if session.state == "armed": + session.state = "transmitting" + self._schedule(self._send_tx_status(session.app_id, "transmitting")) + return samples + + def _underrun_fill(self, session: TxSession, n: int) -> np.ndarray: + policy = session.underrun_policy + if policy == "zero": + return _silence(n) + if policy == "repeat" and session.last_buffer is not None: + buf = session.last_buffer + if buf.size == n: + return buf + if buf.size > n: + return buf[:n].copy() + out = np.zeros(n, dtype=np.complex64) + out[: buf.size] = buf + return out + # "pause" policy (default) or "repeat" before any buffer arrived. 
+ if not session.underrun_flag.is_set(): + session.underrun_flag.set() + session.stop_event.set() + try: + session.sdr.pause_tx() except Exception: pass - self._sdr = None + return _silence(n) - async def _send_status(self, status: str) -> None: + async def _tx_watchdog(self, session: TxSession) -> None: + # Poll the underrun flag so we can emit status + tear down cleanly + # when the callback flips the flag from the executor thread. Check + # underrun_flag before stop_event, since the "pause" path sets both. + while session is self._tx: + if session.underrun_flag.is_set(): + await self._send_tx_status(session.app_id, "underrun") + await self._teardown_tx_after_underrun(session) + return + if session.stop_event.is_set(): + return + await asyncio.sleep(0.05) + + async def _teardown_tx_after_underrun(self, session: TxSession) -> None: + if self._tx is not session: + return + self._drain_tx_queue(session) + if session.task is not None: + try: + await asyncio.wait_for(asyncio.wrap_future(session.task), timeout=1.0) + except asyncio.TimeoutError: + logger.warning("TX executor did not exit within 1s after underrun") + except Exception: + logger.debug("TX executor raised during underrun teardown", exc_info=True) + self._close_session_sdr(session) + if self._tx is session: + self._tx = None + + def _drain_tx_queue(self, session: TxSession) -> None: try: - await self.ws.send_json({"type": "status", "status": status, "app_id": self._app_id}) + while True: + session.in_queue.get_nowait() + except queue.Empty: + pass + + def _schedule(self, coro) -> None: + loop = self._loop + if loop is None: + return + try: + asyncio.run_coroutine_threadsafe(coro, loop) + except Exception: + logger.debug("_schedule failed", exc_info=True) + + # ================================================================== + # Helpers + + def _close_session_sdr(self, session) -> None: + if session.sdr is None: + return + should_close = self._registry.release(session.device_key) + if should_close: + try: 
+ session.sdr.close() + except Exception: + logger.debug("SDR close raised", exc_info=True) + + async def _send_status(self, status: str, app_id: str) -> None: + try: + await self.ws.send_json({"type": "status", "status": status, "app_id": app_id}) except Exception as exc: logger.debug("Status send failed: %s", exc) - async def _send_error(self, message: str) -> None: + async def _send_error(self, app_id: str, message: str) -> None: try: - await self.ws.send_json({"type": "error", "app_id": self._app_id, "message": message}) + await self.ws.send_json({"type": "error", "app_id": app_id, "message": message}) except Exception as exc: logger.debug("Error-frame send failed: %s", exc) + async def _send_tx_status(self, app_id: str, state: str, message: str | None = None) -> None: + payload: dict = {"type": "tx_status", "app_id": app_id, "state": state} + if message is not None: + payload["message"] = message + try: + await self.ws.send_json(payload) + except Exception as exc: + logger.debug("tx_status send failed: %s", exc) + # --------------------------------------------------------------------------- # Helpers @@ -172,6 +636,10 @@ _CONFIG_ATTR_MAP = { "center_freq": ("center_freq", "rx_center_frequency"), "gain": ("gain", "rx_gain"), "bandwidth": ("bandwidth", "rx_bandwidth"), + "tx_sample_rate": ("tx_sample_rate",), + "tx_center_frequency": ("tx_center_frequency", "tx_lo"), + "tx_gain": ("tx_gain",), + "tx_bandwidth": ("tx_bandwidth",), } @@ -194,6 +662,11 @@ def _apply_sdr_config(sdr: Any, cfg: dict) -> None: logger.debug("radio_config key %r ignored (no matching attr)", key) +def _silence(num_samples: int) -> np.ndarray: + """Return a ``num_samples``-length zero-filled complex64 buffer.""" + return np.zeros(int(num_samples), dtype=np.complex64) + + def _samples_to_interleaved_float32(samples: Any) -> bytes: """Convert complex IQ samples (any numeric dtype) to interleaved float32 bytes.""" arr = np.asarray(samples) @@ -214,8 +687,12 @@ def _default_sdr_factory(device: 
str, identifier: str | None): # --------------------------------------------------------------------------- # Top-level entry -async def run_streamer(ws_url: str, token: str) -> None: +async def run_streamer(ws_url: str, token: str, *, cfg: AgentConfig | None = None) -> None: """Connect to *ws_url* and run the streamer loop until cancelled.""" ws = WsClient(ws_url, token) - streamer = Streamer(ws) - await ws.run(streamer.on_message, streamer.build_heartbeat) + streamer = Streamer(ws, cfg=cfg) + await ws.run( + streamer.on_message, + streamer.build_heartbeat, + on_binary=streamer.on_binary, + ) diff --git a/src/ria_toolkit_oss/agent/ws_client.py b/src/ria_toolkit_oss/agent/ws_client.py index 1bc66f6..a33991d 100644 --- a/src/ria_toolkit_oss/agent/ws_client.py +++ b/src/ria_toolkit_oss/agent/ws_client.py @@ -15,6 +15,7 @@ logger = logging.getLogger("ria_agent.ws") MessageHandler = Callable[[dict], Awaitable[None]] HeartbeatBuilder = Callable[[], dict] +BinaryHandler = Callable[[bytes], Awaitable[None]] class WsClient: @@ -65,7 +66,12 @@ class WsClient: self._stop.set() # ------------------------------------------------------------------ - async def run(self, on_message: MessageHandler, heartbeat: HeartbeatBuilder) -> None: + async def run( + self, + on_message: MessageHandler, + heartbeat: HeartbeatBuilder, + on_binary: BinaryHandler | None = None, + ) -> None: """Main loop: connect, heartbeat, dispatch messages, reconnect on drop.""" while not self._stop.is_set(): try: @@ -75,8 +81,13 @@ class WsClient: try: async for raw in self._ws: if isinstance(raw, bytes): - # Server shouldn't send binary to the agent; log and drop. 
- logger.debug("Discarding unexpected %d-byte binary frame", len(raw)) + if on_binary is None: + logger.debug("Discarding unexpected %d-byte binary frame", len(raw)) + continue + try: + await on_binary(raw) + except Exception: + logger.exception("on_binary handler raised; dropping frame") continue try: msg = json.loads(raw) diff --git a/src/ria_toolkit_oss/app/cli.py b/src/ria_toolkit_oss/app/cli.py index 6cd0c1c..9bfb479 100644 --- a/src/ria_toolkit_oss/app/cli.py +++ b/src/ria_toolkit_oss/app/cli.py @@ -21,6 +21,7 @@ from __future__ import annotations import argparse import json +import os import shutil import subprocess import sys @@ -77,24 +78,33 @@ def _inspect_labels(engine: list[str], ref: str) -> dict: return {} -def _hardware_flags(labels: dict) -> list[str]: +def _gpu_available() -> bool: + if os.path.exists("/dev/nvidia0"): + return True + return shutil.which("nvidia-smi") is not None + + +def _hardware_flags(labels: dict, no_gpu: bool, no_usb: bool, no_host_net: bool) -> tuple[list[str], list[str]]: flags: list[str] = [] + notes: list[str] = [] profile = (labels.get(_LABEL_PROFILE) or "").lower() hardware = (labels.get(_LABEL_HARDWARE) or "").lower() hw_items = {h.strip() for h in hardware.split(",") if h.strip()} - if "nvidia" in profile or "holoscan" in profile or "cuda" in profile: - flags += ["--gpus", "all"] + wants_gpu = any(k in profile for k in ("nvidia", "holoscan", "cuda")) + if wants_gpu and not no_gpu: + if _gpu_available(): + flags += ["--gpus", "all"] + else: + notes.append("image wants GPU but no NVIDIA runtime detected — skipping --gpus (use --force-gpu to override)") - needs_usb = hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"} - if needs_usb: + if hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"} and not no_usb: flags += ["--device", "/dev/bus/usb"] - needs_net = hw_items & {"usrp", "thinkrf", "pluto"} - if needs_net: + if hw_items & {"usrp", "thinkrf", "pluto"} and not no_host_net: flags += ["--net", "host"] - return flags + 
return flags, notes def _cmd_configure(args: argparse.Namespace) -> int: @@ -132,7 +142,10 @@ def _cmd_run(args: argparse.Namespace) -> int: return rc labels = _inspect_labels(engine, ref) - hw_flags = _hardware_flags(labels) + no_gpu = args.no_gpu and not args.force_gpu + hw_flags, notes = _hardware_flags(labels, no_gpu=no_gpu, no_usb=args.no_usb, no_host_net=args.no_host_net) + if args.force_gpu and "--gpus" not in hw_flags: + hw_flags = ["--gpus", "all", *hw_flags] cmd = [*engine, "run", "--rm"] if not args.foreground: @@ -162,6 +175,8 @@ def _cmd_run(args: argparse.Namespace) -> int: print(f"Running {ref} [{label_str}]") if hw_flags: print(f" auto flags: {' '.join(hw_flags)}") + for note in notes: + print(f" note: {note}") return subprocess.call(cmd) @@ -225,6 +240,10 @@ def main() -> None: p_run.add_argument("-v", "--volume", action="append", help="Extra volume mount") p_run.add_argument("-p", "--publish", action="append", help="Publish port") p_run.add_argument("--foreground", "-F", action="store_true", help="Run in foreground (no -d)") + p_run.add_argument("--no-gpu", action="store_true", help="Skip --gpus flag even if image wants GPU") + p_run.add_argument("--force-gpu", action="store_true", help="Force --gpus all even if no NVIDIA runtime detected") + p_run.add_argument("--no-usb", action="store_true", help="Skip --device /dev/bus/usb") + p_run.add_argument("--no-host-net", action="store_true", help="Skip --net host") p_run.add_argument("--dry-run", action="store_true", help="Print the container command and exit") p_run.add_argument("--docker-args", nargs=argparse.REMAINDER, help="Pass remaining args to docker/podman run") p_run.add_argument("--app-args", nargs=argparse.REMAINDER, help="Pass remaining args to the app entrypoint") diff --git a/tests/agent/test_cli_tx.py b/tests/agent/test_cli_tx.py new file mode 100644 index 0000000..1543d4c --- /dev/null +++ b/tests/agent/test_cli_tx.py @@ -0,0 +1,111 @@ +"""CLI flags for TX opt-in and interlocks.""" + 
+from __future__ import annotations + +import json +import sys +from unittest.mock import patch + +from ria_toolkit_oss.agent import cli as agent_cli +from ria_toolkit_oss.agent import config as agent_config + + +class _FakeResp: + def __init__(self, payload: dict): + self._payload = payload + + def read(self) -> bytes: + return json.dumps(self._payload).encode() + + def __enter__(self): + return self + + def __exit__(self, *_a): + return False + + +def _run_register(argv: list[str], cfg_path) -> int: + fake_resp = _FakeResp({"agent_id": "agent-1", "token": "tok-abc"}) + with patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), \ + patch("urllib.request.urlopen", return_value=fake_resp), \ + patch.object(sys, "argv", ["ria-agent", *argv]): + try: + agent_cli.main() + except SystemExit as exc: + return int(exc.code or 0) + return 0 + + +def test_register_without_allow_tx_keeps_tx_disabled(tmp_path): + cfg_path = tmp_path / "agent.json" + _run_register( + ["register", "--hub", "http://hub:3005", "--api-key", "K"], + cfg_path, + ) + cfg = agent_config.load(path=cfg_path) + assert cfg.agent_id == "agent-1" + assert cfg.tx_enabled is False + assert cfg.tx_max_gain_db is None + + +def test_register_with_allow_tx_and_caps(tmp_path): + cfg_path = tmp_path / "agent.json" + _run_register( + [ + "register", + "--hub", + "http://hub:3005", + "--api-key", + "K", + "--allow-tx", + "--tx-max-gain-db", + "-10", + "--tx-max-duration-s", + "60", + "--tx-freq-range", + "2.4e9", + "2.5e9", + "--tx-freq-range", + "5.7e9", + "5.8e9", + ], + cfg_path, + ) + cfg = agent_config.load(path=cfg_path) + assert cfg.tx_enabled is True + assert cfg.tx_max_gain_db == -10.0 + assert cfg.tx_max_duration_s == 60.0 + assert cfg.tx_allowed_freq_ranges == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]] + + +def test_stream_allow_tx_does_not_persist(tmp_path): + # Pre-register with tx_enabled=False, then simulate `stream --allow-tx`. 
+ # The on-disk config must remain unchanged; the runtime flag is process-local. + cfg_path = tmp_path / "agent.json" + base = agent_config.AgentConfig( + hub_url="http://hub:3005", + agent_id="agent-1", + token="tok-abc", + tx_enabled=False, + ) + agent_config.save(base, path=cfg_path) + + captured: dict = {} + + async def _fake_run_streamer(url, token, *, cfg): + captured["cfg"] = cfg + return None + + with patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), \ + patch("ria_toolkit_oss.agent.streamer.run_streamer", new=_fake_run_streamer), \ + patch.object(sys, "argv", ["ria-agent", "stream", "--allow-tx"]): + try: + agent_cli.main() + except SystemExit: + pass + + # Runtime cfg had TX flipped on + assert captured["cfg"].tx_enabled is True + # But the persisted file is untouched + on_disk = agent_config.load(path=cfg_path) + assert on_disk.tx_enabled is False diff --git a/tests/agent/test_config.py b/tests/agent/test_config.py index 2532abd..7d2a6b4 100644 --- a/tests/agent/test_config.py +++ b/tests/agent/test_config.py @@ -20,6 +20,36 @@ def test_load_missing_returns_empty(tmp_path): assert loaded == agent_config.AgentConfig() +def test_tx_fields_round_trip(tmp_path): + p = tmp_path / "agent.json" + cfg = agent_config.AgentConfig( + hub_url="https://hub.example.com", + agent_id="agent-1", + token="t", + tx_enabled=True, + tx_max_gain_db=-10.0, + tx_max_duration_s=60.0, + tx_allowed_freq_ranges=[[2.4e9, 2.5e9], [5.7e9, 5.8e9]], + ) + agent_config.save(cfg, path=p) + loaded = agent_config.load(path=p) + assert loaded.tx_enabled is True + assert loaded.tx_max_gain_db == -10.0 + assert loaded.tx_max_duration_s == 60.0 + assert loaded.tx_allowed_freq_ranges == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]] + + +def test_tx_fields_default_when_absent(tmp_path): + # Old configs written before TX existed should load cleanly with safe defaults. 
+ p = tmp_path / "agent.json" + p.write_text('{"hub_url": "x", "agent_id": "a", "token": "t"}') + cfg = agent_config.load(path=p) + assert cfg.tx_enabled is False + assert cfg.tx_max_gain_db is None + assert cfg.tx_max_duration_s is None + assert cfg.tx_allowed_freq_ranges is None + + def test_extra_keys_preserved(tmp_path): p = tmp_path / "agent.json" p.write_text('{"hub_url": "x", "custom": 42}') diff --git a/tests/agent/test_disconnect.py b/tests/agent/test_disconnect.py index f063e3a..3063613 100644 --- a/tests/agent/test_disconnect.py +++ b/tests/agent/test_disconnect.py @@ -67,9 +67,9 @@ def test_streamer_reports_disconnected_and_ends_capture(): "radio_config": {"device": "fake", "buffer_size": 8}, } ) - # Wait for the capture task to fail out. - for _ in range(50): - if streamer._capture_task and streamer._capture_task.done(): + # Wait for the capture loop to emit its error frame and tear down the session. + for _ in range(100): + if any(m.get("type") == "error" for m in ws.json_sent) and streamer._rx is None: break await asyncio.sleep(0.01) return ws, sdr, streamer @@ -79,3 +79,5 @@ def test_streamer_reports_disconnected_and_ends_capture(): errors = [m for m in ws.json_sent if m.get("type") == "error"] assert errors, "expected an error frame" assert "disconnected" in errors[-1]["message"].lower() + # Session handle cleared so future starts can proceed. 
+ assert streamer._rx is None diff --git a/tests/agent/test_full_duplex.py b/tests/agent/test_full_duplex.py new file mode 100644 index 0000000..6ad2f62 --- /dev/null +++ b/tests/agent/test_full_duplex.py @@ -0,0 +1,133 @@ +"""Concurrent RX + TX sessions on the same agent — shared SDR via registry.""" + +from __future__ import annotations + +import asyncio +import time + +import numpy as np + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.sdr.mock import MockSDR + + +class FullDuplexMockSDR(MockSDR): + """MockSDR with a recording TX path so the test can assert both directions.""" + + def __init__(self, buffer_size: int): + super().__init__(buffer_size=buffer_size) + self.tx_produced: list[np.ndarray] = [] + + def _stream_tx(self, callback): + self._enable_tx = True + self._tx_initialized = True + while self._enable_tx: + result = callback(self.rx_buffer_size) + self.tx_produced.append(np.asarray(result).copy()) + time.sleep(0.005) + + +class FakeWs: + def __init__(self): + self.json_sent = [] + self.bytes_sent = [] + + async def send_json(self, p): + self.json_sent.append(p) + + async def send_bytes(self, b): + self.bytes_sent.append(b) + + +def _iq_frame(samples: np.ndarray) -> bytes: + interleaved = np.empty(samples.size * 2, dtype=np.float32) + interleaved[0::2] = samples.real + interleaved[1::2] = samples.imag + return interleaved.tobytes() + + +def test_rx_and_tx_share_one_sdr_instance(): + built: list[FullDuplexMockSDR] = [] + + def factory(device, identifier): + sdr = FullDuplexMockSDR(buffer_size=16) + built.append(sdr) + return sdr + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=factory, cfg=AgentConfig(tx_enabled=True)) + + # Start RX first. + await s.on_message( + { + "type": "start", + "app_id": "app-1", + "radio_config": {"device": "mock", "buffer_size": 16}, + } + ) + # Then start TX on the same device — should share the SDR handle. 
+ await s.on_message( + { + "type": "tx_start", + "app_id": "app-1", + "radio_config": { + "device": "mock", + "buffer_size": 16, + "tx_gain": -20, + "tx_center_frequency": 2.45e9, + "underrun_policy": "zero", + }, + } + ) + + # Push a known TX buffer. + marker = np.arange(16, dtype=np.complex64) + 7 + await s.on_binary(_iq_frame(marker)) + + # Let both directions produce output. + for _ in range(80): + rx_ok = len(ws.bytes_sent) >= 2 + tx_ok = any(np.array_equal(b, marker) for b in built[0].tx_produced) if built else False + if rx_ok and tx_ok: + break + await asyncio.sleep(0.01) + + # Heartbeat should show both sessions. + hb = s.build_heartbeat() + + # Stop TX first, RX keeps running. + await s.on_message({"type": "tx_stop", "app_id": "app-1"}) + tx_after_stop = s._tx is None + rx_still_active = s._rx is not None + + # Now stop RX. + await s.on_message({"type": "stop", "app_id": "app-1"}) + + return ws, s, built, hb, tx_after_stop, rx_still_active + + ws, s, built, hb, tx_after_stop, rx_still_active = asyncio.run(scenario()) + + # One SDR was built and shared. + assert len(built) == 1, f"expected exactly one SDR instance, got {len(built)}" + + # Both directions produced output. + assert len(ws.bytes_sent) >= 1, "RX produced no IQ frames" + marker = np.arange(16, dtype=np.complex64) + 7 + assert any( + np.array_equal(b, marker) for b in built[0].tx_produced + ), "TX callback never saw the pushed marker buffer" + + # Heartbeat reflected both sessions while they were active. + assert hb["sessions"]["rx"]["app_id"] == "app-1" + assert hb["sessions"]["tx"]["app_id"] == "app-1" + + # Stopping TX does not tear down RX. + assert tx_after_stop + assert rx_still_active + + # After both stops, registry is empty. 
+ assert s._registry.refcount(("mock", None)) == 0 + assert s._rx is None + assert s._tx is None diff --git a/tests/agent/test_hardware.py b/tests/agent/test_hardware.py index ab9fcdf..51b2e45 100644 --- a/tests/agent/test_hardware.py +++ b/tests/agent/test_hardware.py @@ -23,7 +23,24 @@ def test_heartbeat_payload_shape(): assert p["status"] == "idle" assert "mock" in p["hardware"] assert "app_id" not in p + # New fields, default shape + assert p["capabilities"] == ["rx"] + assert p["tx_enabled"] is False p2 = hardware.heartbeat_payload(status="streaming", app_id="abc") assert p2["status"] == "streaming" assert p2["app_id"] == "abc" + + +def test_heartbeat_payload_tx_capability_from_cfg(): + from ria_toolkit_oss.agent.config import AgentConfig + + p = hardware.heartbeat_payload(cfg=AgentConfig(tx_enabled=True)) + assert p["capabilities"] == ["rx", "tx"] + assert p["tx_enabled"] is True + + +def test_heartbeat_payload_sessions_field(): + sessions = {"rx": {"app_id": "a", "state": "streaming"}} + p = hardware.heartbeat_payload(status="streaming", app_id="a", sessions=sessions) + assert p["sessions"] == sessions diff --git a/tests/agent/test_integration_tx.py b/tests/agent/test_integration_tx.py new file mode 100644 index 0000000..4fc13af --- /dev/null +++ b/tests/agent/test_integration_tx.py @@ -0,0 +1,144 @@ +"""End-to-end: local websockets server drives a Streamer's TX path.""" + +from __future__ import annotations + +import asyncio +import json +import time + +import numpy as np +import websockets + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.agent.ws_client import WsClient +from ria_toolkit_oss.sdr.mock import MockSDR + + +class RecordingMockSDR(MockSDR): + def __init__(self, buffer_size: int): + super().__init__(buffer_size=buffer_size) + self.tx_produced: list[np.ndarray] = [] + + def _stream_tx(self, callback): + self._enable_tx = True + self._tx_initialized = True + while 
self._enable_tx: + result = callback(self.rx_buffer_size) + self.tx_produced.append(np.asarray(result).copy()) + time.sleep(0.005) + + +def _iq_frame(samples: np.ndarray) -> bytes: + interleaved = np.empty(samples.size * 2, dtype=np.float32) + interleaved[0::2] = samples.real + interleaved[1::2] = samples.imag + return interleaved.tobytes() + + +def test_server_tx_start_binary_stop_cycle_over_real_ws(): + BUF = 16 + sdr = RecordingMockSDR(buffer_size=BUF) + marker = np.arange(BUF, dtype=np.complex64) + 1 + + async def scenario(): + control_frames: list[dict] = [] + done = asyncio.Event() + + async def server_handler(ws): + try: + # Drain initial heartbeat. + first = await asyncio.wait_for(ws.recv(), timeout=2.0) + control_frames.append(json.loads(first)) + + await ws.send( + json.dumps( + { + "type": "tx_start", + "app_id": "tx-app", + "radio_config": { + "device": "mock", + "buffer_size": BUF, + "tx_sample_rate": 1_000_000, + "tx_center_frequency": 2.45e9, + "tx_gain": -20, + "underrun_policy": "zero", + }, + } + ) + ) + + # Push a few binary IQ frames. + for _ in range(3): + await ws.send(_iq_frame(marker)) + + # Wait for at least "armed" + "transmitting" statuses. + for _ in range(100): + msg = await asyncio.wait_for(ws.recv(), timeout=2.0) + if isinstance(msg, str): + control_frames.append(json.loads(msg)) + if any( + f.get("type") == "tx_status" and f.get("state") == "transmitting" + for f in control_frames + ): + break + + await ws.send(json.dumps({"type": "tx_stop", "app_id": "tx-app"})) + + # Drain trailing statuses. 
+ try: + while True: + msg = await asyncio.wait_for(ws.recv(), timeout=0.5) + if isinstance(msg, str): + control_frames.append(json.loads(msg)) + except (asyncio.TimeoutError, Exception): + pass + finally: + done.set() + + server = await websockets.serve(server_handler, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + streamer = Streamer( + ws=client, + sdr_factory=lambda d, i: sdr, + cfg=AgentConfig(tx_enabled=True), + ) + task = asyncio.create_task( + client.run( + on_message=streamer.on_message, + heartbeat=streamer.build_heartbeat, + on_binary=streamer.on_binary, + ) + ) + await asyncio.wait_for(done.wait(), timeout=5.0) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return control_frames, streamer + + controls, streamer = asyncio.run(scenario()) + + # Heartbeat reached the server. + assert any(f.get("type") == "heartbeat" for f in controls) + # tx_status lifecycle: armed → transmitting → done. + tx_states = [f["state"] for f in controls if f.get("type") == "tx_status"] + assert tx_states[0] == "armed" + assert "transmitting" in tx_states + assert tx_states[-1] == "done" + # TX callback saw our marker buffer at least once. + assert any(np.array_equal(b, marker) for b in sdr.tx_produced) + # Session cleared. 
+ assert streamer._tx is None diff --git a/tests/agent/test_streamer.py b/tests/agent/test_streamer.py index 1bb2081..2aa842e 100644 --- a/tests/agent/test_streamer.py +++ b/tests/agent/test_streamer.py @@ -46,15 +46,29 @@ def test_apply_sdr_config_sets_attributes(): def test_heartbeat_reflects_status_and_app(): - s = Streamer(ws=FakeWs(), sdr_factory=_factory) - hb = s.build_heartbeat() - assert hb["type"] == "heartbeat" - assert hb["status"] == "idle" - s._status = "streaming" - s._app_id = "app-42" - hb2 = s.build_heartbeat() - assert hb2["status"] == "streaming" - assert hb2["app_id"] == "app-42" + async def scenario(): + s = Streamer(ws=FakeWs(), sdr_factory=_factory) + hb = s.build_heartbeat() + assert hb["type"] == "heartbeat" + assert hb["status"] == "idle" + # capabilities default to rx-only + assert hb["capabilities"] == ["rx"] + assert hb["tx_enabled"] is False + + await s.on_message( + { + "type": "start", + "app_id": "app-42", + "radio_config": {"device": "mock", "buffer_size": 32}, + } + ) + hb2 = s.build_heartbeat() + assert hb2["status"] == "streaming" + assert hb2["app_id"] == "app-42" + assert hb2["sessions"]["rx"]["app_id"] == "app-42" + await s.on_message({"type": "stop", "app_id": "app-42"}) + + asyncio.run(scenario()) def test_full_start_stream_stop_cycle(): @@ -89,7 +103,7 @@ def test_full_start_stream_stop_cycle(): statuses = [m for m in ws.json_sent if m.get("type") == "status"] assert statuses[0]["status"] == "streaming" assert statuses[-1]["status"] == "idle" - assert streamer._sdr is None + assert streamer._rx is None def test_start_without_device_emits_error(): @@ -110,6 +124,7 @@ def test_configure_queues_update(): await streamer.on_message( {"type": "configure", "app_id": "x", "radio_config": {"center_frequency": 915e6}} ) + # Before start(), pending config lives on the standalone dict exposed via the _pending_config shim. 
return streamer._pending_config pending = asyncio.run(scenario()) @@ -122,3 +137,56 @@ def test_unknown_message_type_is_ignored(): await s.on_message({"type": "nope"}) asyncio.run(scenario()) + + +def test_registry_shares_sdr_across_start_stop_cycles(): + # Two sequential start/stop cycles with the same (device, identifier) + # should hit the registry's cache path rather than constructing a new SDR. + built: list[MockSDR] = [] + + def counting_factory(device: str, identifier): + sdr = MockSDR(buffer_size=16, seed=0) + built.append(sdr) + return sdr + + async def scenario(): + s = Streamer(ws=FakeWs(), sdr_factory=counting_factory) + for _ in range(2): + await s.on_message( + { + "type": "start", + "app_id": "a", + "radio_config": {"device": "mock", "buffer_size": 16}, + } + ) + # Let one capture buffer flow before stopping so the loop is engaged. + await asyncio.sleep(0.02) + await s.on_message({"type": "stop", "app_id": "a"}) + + asyncio.run(scenario()) + # A new SDR per cycle (we fully close between starts) — registry refcount + # drops to zero on each stop. This test confirms close-and-rebuild works; + # the ref-counting share-while-open case is covered in the full-duplex tests. 
+ assert len(built) == 2 + + +def test_tx_start_rejected_when_tx_disabled(): + from ria_toolkit_oss.agent.config import AgentConfig + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=_factory, cfg=AgentConfig(tx_enabled=False)) + await s.on_message( + { + "type": "tx_start", + "app_id": "a", + "radio_config": {"device": "mock", "tx_center_frequency": 2.45e9, "tx_gain": -20}, + } + ) + return ws + + ws = asyncio.run(scenario()) + tx_statuses = [m for m in ws.json_sent if m.get("type") == "tx_status"] + assert tx_statuses, "expected a tx_status frame" + assert tx_statuses[-1]["state"] == "error" + assert "disabled" in tx_statuses[-1]["message"].lower() diff --git a/tests/agent/test_streamer_tx.py b/tests/agent/test_streamer_tx.py new file mode 100644 index 0000000..6cb2bb4 --- /dev/null +++ b/tests/agent/test_streamer_tx.py @@ -0,0 +1,133 @@ +"""TX streaming happy path + shutdown semantics.""" + +from __future__ import annotations + +import asyncio +import time + +import numpy as np + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.sdr.mock import MockSDR + + +class RecordingMockSDR(MockSDR): + """MockSDR that records each TX callback's returned buffer.""" + + def __init__(self, buffer_size: int): + super().__init__(buffer_size=buffer_size) + self.tx_produced: list[np.ndarray] = [] + + def _stream_tx(self, callback) -> None: + self._enable_tx = True + self._tx_initialized = True + while self._enable_tx: + result = callback(self.rx_buffer_size) + self.tx_produced.append(np.asarray(result)) + time.sleep(0.005) + + +class FakeWs: + def __init__(self): + self.json_sent: list[dict] = [] + self.bytes_sent: list[bytes] = [] + + async def send_json(self, payload): + self.json_sent.append(payload) + + async def send_bytes(self, data): + self.bytes_sent.append(data) + + +def _iq_frame(samples: np.ndarray) -> bytes: + interleaved = np.empty(samples.size * 2, 
dtype=np.float32) + interleaved[0::2] = samples.real + interleaved[1::2] = samples.imag + return interleaved.tobytes() + + +def test_tx_start_streams_binary_to_callback(): + BUF = 16 + sdr = RecordingMockSDR(buffer_size=BUF) + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)) + + # Frames of distinct content so we can assert ordering. + frame_a = np.arange(BUF, dtype=np.complex64) * (1 + 0j) + frame_b = (np.arange(BUF, dtype=np.complex64) + BUF) * (1 + 0j) + frame_c = (np.arange(BUF, dtype=np.complex64) + 2 * BUF) * (1 + 0j) + + await s.on_message( + { + "type": "tx_start", + "app_id": "app-1", + "radio_config": { + "device": "mock", + "buffer_size": BUF, + "tx_sample_rate": 1_000_000, + "tx_center_frequency": 2.45e9, + "tx_gain": -20, + "underrun_policy": "zero", + }, + } + ) + # Push three IQ frames. + await s.on_binary(_iq_frame(frame_a)) + await s.on_binary(_iq_frame(frame_b)) + await s.on_binary(_iq_frame(frame_c)) + + # Let the executor thread consume them. + for _ in range(100): + # At least the 3 real frames, plus any zero-fill from before they + # arrived. We stop once 3 non-trivial buffers are recorded. + nontrivial = [b for b in sdr.tx_produced if np.any(b != 0)] + if len(nontrivial) >= 3: + break + await asyncio.sleep(0.01) + + await s.on_message({"type": "tx_stop", "app_id": "app-1"}) + return ws, sdr, s + + ws, sdr, streamer = asyncio.run(scenario()) + + nontrivial = [b for b in sdr.tx_produced if np.any(b != 0)] + assert len(nontrivial) >= 3, "expected ≥3 nontrivial TX buffers" + + # First three nontrivial buffers match the order we pushed them. + np.testing.assert_array_equal(nontrivial[0], np.arange(BUF, dtype=np.complex64)) + np.testing.assert_array_equal(nontrivial[1], np.arange(BUF, 2 * BUF, dtype=np.complex64)) + np.testing.assert_array_equal(nontrivial[2], np.arange(2 * BUF, 3 * BUF, dtype=np.complex64)) + + # Lifecycle: armed → transmitting → done. 
+ states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"] + assert states[0] == "armed" + assert "transmitting" in states + assert states[-1] == "done" + # Session cleared. + assert streamer._tx is None + + +def test_tx_stop_releases_sdr(): + sdr = RecordingMockSDR(buffer_size=8) + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)) + await s.on_message( + { + "type": "tx_start", + "app_id": "a", + "radio_config": {"device": "mock", "buffer_size": 8, "underrun_policy": "zero"}, + } + ) + await asyncio.sleep(0.03) + await s.on_message({"type": "tx_stop", "app_id": "a"}) + return s + + s = asyncio.run(scenario()) + # After stop, the registry has no outstanding references to ("mock", None). + assert s._registry.refcount(("mock", None)) == 0 + assert s._tx is None diff --git a/tests/agent/test_tx_safety.py b/tests/agent/test_tx_safety.py new file mode 100644 index 0000000..5307917 --- /dev/null +++ b/tests/agent/test_tx_safety.py @@ -0,0 +1,167 @@ +"""Agent-side TX interlocks: gain cap, freq ranges, duplicate sessions, disabled.""" + +from __future__ import annotations + +import asyncio + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.sdr.mock import MockSDR + + +class FakeWs: + def __init__(self): + self.json_sent = [] + self.bytes_sent = [] + + async def send_json(self, p): + self.json_sent.append(p) + + async def send_bytes(self, b): + self.bytes_sent.append(b) + + +def _last_tx_status(ws): + frames = [m for m in ws.json_sent if m.get("type") == "tx_status"] + return frames[-1] if frames else None + + +def _tx_start(app_id="a", **radio): + rc = {"device": "mock", "buffer_size": 16, "underrun_policy": "zero"} + rc.update(radio) + return {"type": "tx_start", "app_id": app_id, "radio_config": rc} + + +def _make_streamer(cfg): + built: list = [] + + def factory(device, identifier): + sdr = 
MockSDR(buffer_size=16) + built.append(sdr) + return sdr + + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=factory, cfg=cfg) + return s, ws, built + + +def test_rejects_when_tx_disabled(): + async def scenario(): + s, ws, built = _make_streamer(AgentConfig(tx_enabled=False)) + await s.on_message(_tx_start(tx_gain=-20, tx_center_frequency=2.45e9)) + return s, ws, built + + s, ws, built = asyncio.run(scenario()) + status = _last_tx_status(ws) + assert status and status["state"] == "error" + assert "disabled" in status["message"].lower() + assert not built, "SDR should never have been constructed" + assert s._tx is None + + +def test_rejects_when_tx_gain_exceeds_cap(): + async def scenario(): + s, ws, built = _make_streamer(AgentConfig(tx_enabled=True, tx_max_gain_db=-15.0)) + await s.on_message(_tx_start(tx_gain=-5, tx_center_frequency=2.45e9)) + return ws, built + + ws, built = asyncio.run(scenario()) + status = _last_tx_status(ws) + assert status and status["state"] == "error" + assert "exceeds cap" in status["message"] + assert not built + + +def test_allows_gain_at_cap_boundary(): + async def scenario(): + s, ws, _ = _make_streamer(AgentConfig(tx_enabled=True, tx_max_gain_db=-10.0)) + await s.on_message(_tx_start(tx_gain=-10, tx_center_frequency=2.45e9)) + # Stop promptly to avoid keeping an executor thread around. 
+ await asyncio.sleep(0.02) + await s.on_message({"type": "tx_stop", "app_id": "a"}) + return ws + + ws = asyncio.run(scenario()) + states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"] + assert "armed" in states + assert states[-1] == "done" + + +def test_rejects_when_freq_outside_ranges(): + async def scenario(): + s, ws, built = _make_streamer( + AgentConfig( + tx_enabled=True, + tx_allowed_freq_ranges=[[2.4e9, 2.5e9]], + ) + ) + await s.on_message(_tx_start(tx_center_frequency=5.8e9, tx_gain=-20)) + return ws, built + + ws, built = asyncio.run(scenario()) + status = _last_tx_status(ws) + assert status and status["state"] == "error" + assert "outside allowed ranges" in status["message"] + assert not built + + +def test_allows_freq_inside_a_range(): + async def scenario(): + s, ws, _ = _make_streamer( + AgentConfig( + tx_enabled=True, + tx_allowed_freq_ranges=[[2.4e9, 2.5e9], [5.7e9, 5.8e9]], + ) + ) + await s.on_message(_tx_start(tx_center_frequency=5.75e9, tx_gain=-20)) + await asyncio.sleep(0.02) + await s.on_message({"type": "tx_stop", "app_id": "a"}) + return ws + + ws = asyncio.run(scenario()) + states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"] + assert "armed" in states + assert states[-1] == "done" + + +def test_rejects_duplicate_tx_session(): + async def scenario(): + s, ws, _ = _make_streamer(AgentConfig(tx_enabled=True)) + await s.on_message(_tx_start(app_id="a", tx_gain=-20, tx_center_frequency=2.45e9)) + await asyncio.sleep(0.01) + await s.on_message(_tx_start(app_id="b", tx_gain=-20, tx_center_frequency=2.45e9)) + # Let the second request process, then stop cleanly. 
+ await asyncio.sleep(0.01) + await s.on_message({"type": "tx_stop", "app_id": "a"}) + return ws + + ws = asyncio.run(scenario()) + errors = [ + m for m in ws.json_sent + if m.get("type") == "tx_status" and m.get("state") == "error" + ] + assert any("already active" in e.get("message", "") for e in errors) + + +def test_rejects_invalid_underrun_policy(): + async def scenario(): + s, ws, _ = _make_streamer(AgentConfig(tx_enabled=True)) + await s.on_message( + { + "type": "tx_start", + "app_id": "a", + "radio_config": { + "device": "mock", + "buffer_size": 8, + "tx_gain": -20, + "tx_center_frequency": 2.45e9, + "underrun_policy": "teleport", + }, + } + ) + return ws + + ws = asyncio.run(scenario()) + status = _last_tx_status(ws) + assert status and status["state"] == "error" + assert "underrun_policy" in status["message"] diff --git a/tests/agent/test_tx_underrun.py b/tests/agent/test_tx_underrun.py new file mode 100644 index 0000000..e95feec --- /dev/null +++ b/tests/agent/test_tx_underrun.py @@ -0,0 +1,136 @@ +"""Underrun policies: pause, zero, repeat.""" + +from __future__ import annotations + +import asyncio +import time + +import numpy as np + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.sdr.mock import MockSDR + + +class RecordingMockSDR(MockSDR): + def __init__(self, buffer_size: int): + super().__init__(buffer_size=buffer_size) + self.tx_produced: list[np.ndarray] = [] + + def _stream_tx(self, callback): + self._enable_tx = True + self._tx_initialized = True + while self._enable_tx: + result = callback(self.rx_buffer_size) + self.tx_produced.append(np.asarray(result).copy()) + time.sleep(0.005) + + +class FakeWs: + def __init__(self): + self.json_sent = [] + self.bytes_sent = [] + + async def send_json(self, p): + self.json_sent.append(p) + + async def send_bytes(self, b): + self.bytes_sent.append(b) + + +def _iq_frame(samples: np.ndarray) -> bytes: + interleaved = 
np.empty(samples.size * 2, dtype=np.float32) + interleaved[0::2] = samples.real + interleaved[1::2] = samples.imag + return interleaved.tobytes() + + +def _start_cfg(policy: str, buf: int = 8) -> dict: + return { + "type": "tx_start", + "app_id": "a", + "radio_config": { + "device": "mock", + "buffer_size": buf, + "tx_gain": -20, + "tx_center_frequency": 2.45e9, + "underrun_policy": policy, + }, + } + + +def test_underrun_pause_stops_session_and_emits_status(): + sdr = RecordingMockSDR(buffer_size=8) + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)) + await s.on_message(_start_cfg("pause")) + # Do not push any buffers. The callback underruns on first tick and + # the watchdog should emit "underrun" and tear down. + for _ in range(100): + if any( + m.get("type") == "tx_status" and m.get("state") == "underrun" + for m in ws.json_sent + ): + break + await asyncio.sleep(0.01) + for _ in range(50): + if s._tx is None: + break + await asyncio.sleep(0.01) + return ws, s + + ws, s = asyncio.run(scenario()) + states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"] + assert "underrun" in states + assert s._tx is None + + +def test_underrun_zero_keeps_session_alive(): + sdr = RecordingMockSDR(buffer_size=8) + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)) + await s.on_message(_start_cfg("zero")) + # Let it produce several underrun-filled buffers. + await asyncio.sleep(0.08) + still_alive = s._tx is not None + await s.on_message({"type": "tx_stop", "app_id": "a"}) + return ws, still_alive + + ws, still_alive = asyncio.run(scenario()) + # No underrun status emitted (policy absorbs it silently). + assert not any( + m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent + ) + assert still_alive + # All produced buffers are zero (no real data was pushed). 
+ assert sdr.tx_produced, "expected at least one TX callback invocation" + assert all(not np.any(b != 0) for b in sdr.tx_produced) + + +def test_underrun_repeat_replays_last_buffer(): + BUF = 8 + sdr = RecordingMockSDR(buffer_size=BUF) + marker = np.arange(BUF, dtype=np.complex64) + 1 # distinct non-zero buffer + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)) + await s.on_message(_start_cfg("repeat", buf=BUF)) + await s.on_binary(_iq_frame(marker)) + # Give the executor time to consume the real frame + several repeats. + await asyncio.sleep(0.08) + await s.on_message({"type": "tx_stop", "app_id": "a"}) + return ws, sdr + + ws, sdr = asyncio.run(scenario()) + # No underrun status emitted. + assert not any( + m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent + ) + # At least two buffers equal to the marker — the real one and ≥1 repeat. + matching = [b for b in sdr.tx_produced if np.array_equal(b, marker)] + assert len(matching) >= 2, f"expected ≥2 buffers matching marker, got {len(matching)}" diff --git a/tests/agent/test_ws_client.py b/tests/agent/test_ws_client.py index 0994a5b..4061f32 100644 --- a/tests/agent/test_ws_client.py +++ b/tests/agent/test_ws_client.py @@ -113,6 +113,109 @@ def test_reconnects_after_server_drop(): assert n >= 2 +def test_binary_frame_forwarded_to_handler(): + payload = bytes(range(128)) + + async def scenario(): + received: list[bytes] = [] + done = asyncio.Event() + + async def handler(ws): + await ws.send(payload) + done.set() + try: + await ws.wait_closed() + except Exception: + pass + + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + + async def on_bin(data): + received.append(data) + + task = asyncio.create_task( + client.run( + on_message=lambda _m: asyncio.sleep(0), + heartbeat=lambda: {"type": 
"heartbeat"}, + on_binary=on_bin, + ) + ) + for _ in range(50): + if received: + break + await asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return received + + received = asyncio.run(scenario()) + assert received == [payload] + + +def test_binary_frame_dropped_when_no_handler(): + # Regression guard: existing behavior (drop server-sent binary) preserved when + # on_binary is not supplied. + async def scenario(): + crashes: list[Exception] = [] + + async def handler(ws): + await ws.send(b"\x00\x01\x02\x03") + await ws.send(json.dumps({"type": "ping"})) + try: + await ws.wait_closed() + except Exception: + pass + + messages: list[dict] = [] + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + + async def on_msg(m): + messages.append(m) + + task = asyncio.create_task( + client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"}) + ) + for _ in range(50): + if messages: + break + await asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception) as exc: + crashes.append(exc) + finally: + server.close() + await server.wait_closed() + return messages, crashes + + messages, crashes = asyncio.run(scenario()) + # JSON still delivered; binary silently dropped; no uncaught crash. + assert messages and messages[0] == {"type": "ping"} + + def test_malformed_control_frame_does_not_crash(): async def scenario(): handled: list[dict] = []