From 8c247f9f7a7838d46ff1dff96485fbcbf1938430 Mon Sep 17 00:00:00 2001 From: jonny Date: Thu, 16 Apr 2026 15:12:56 -0400 Subject: [PATCH] transmit further updates --- scripts/pluto_tx_smoke.py | 225 +++++++++++++++++++++ scripts/pluto_tx_ws_smoke.py | 236 ++++++++++++++++++++++ src/ria_toolkit_oss/agent/hardware.py | 12 ++ src/ria_toolkit_oss/agent/streamer.py | 58 +++++- src/ria_toolkit_oss/sdr/pluto.py | 128 ++++++------ tests/agent/test_hardware.py | 27 +++ tests/agent/test_param_lock_contention.py | 210 +++++++++++++++++++ tests/agent/test_streamer.py | 15 ++ tests/agent/test_ws_client.py | 110 +--------- tests/agent/test_ws_client_binary.py | 186 +++++++++++++++++ 10 files changed, 1042 insertions(+), 165 deletions(-) create mode 100755 scripts/pluto_tx_smoke.py create mode 100755 scripts/pluto_tx_ws_smoke.py create mode 100644 tests/agent/test_param_lock_contention.py create mode 100644 tests/agent/test_ws_client_binary.py diff --git a/scripts/pluto_tx_smoke.py b/scripts/pluto_tx_smoke.py new file mode 100755 index 0000000..64adbb9 --- /dev/null +++ b/scripts/pluto_tx_smoke.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +"""Transmit a continuous tone through the agent's TX pipeline on a real Pluto. + +End-to-end smoke test for the Pluto + Streamer TX path. Drives the same +``Streamer`` the hub talks to, but in-process with a logging ``FakeWs`` so +the script is self-contained — no hub required. + +Default: 100 kHz baseband tone × 2 450 MHz LO → carrier at 2 450.1 MHz, +continuous until you Ctrl-C (or the ``--duration`` timer fires). A spectrum +analyzer tuned to 2 450.1 MHz should show a clean CW spike as long as +``tx_status: transmitting`` prints. 
+ +Usage:: + + python3 scripts/pluto_tx_smoke.py # auto-discover Pluto + python3 scripts/pluto_tx_smoke.py --identifier 192.168.3.1 + python3 scripts/pluto_tx_smoke.py --frequency 2.4e9 --gain -20 --duration 60 + +Flags map 1:1 onto the agent's ``radio_config``: + + --identifier Pluto IP or hostname (omitted → ip:pluto.local). + --frequency TX LO in Hz. Default 2 450 MHz. + --gain Pluto TX gain in dB. Pluto range is ``[-89, 0]``; more negative + = more attenuation = less power. Default -30. + --sample-rate Baseband sample rate. Default 1 MHz. + --tone Baseband tone offset in Hz. Default 100 kHz; set 0 for DC + (unmodulated carrier at exactly --frequency, but Pluto's + LO leakage will dominate). + --buffer-size Complex samples per WS frame. Default 4096. + --duration Stop after this many seconds (0 = run until Ctrl-C). +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import signal +import sys + +import numpy as np + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer + + +class LoggingFakeWs: + """In-process stand-in for the hub's WebSocket. + + Prints every ``tx_status`` + ``error`` frame the Streamer emits so the + operator can watch the lifecycle (armed → transmitting → done) on stdout. + """ + + async def send_json(self, payload: dict) -> None: + t = payload.get("type") + if t == "tx_status": + state = payload.get("state") + msg = payload.get("message") + tail = f" — {msg}" if msg else "" + print(f"[tx_status] {state}{tail}") + elif t == "error": + print(f"[error] {payload.get('message')}") + + async def send_bytes(self, data: bytes) -> None: + # Agent side won't send RX bytes in this script (no RX session). + pass + + +def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float, + phase_offset: float = 0.0) -> tuple[bytes, float]: + """Return ``(interleaved_float32_bytes, next_phase)`` for a sine tone. 
+ + Emitting one continuous phase-coherent tone requires threading the phase + across frames; the returned ``next_phase`` should be fed back as + ``phase_offset`` on the next call so the sinusoid doesn't glitch at frame + boundaries. Amplitude is 0.7 to leave some headroom below the [-1, 1] cap + that ``_verify_sample_format`` polices elsewhere in the toolkit. + """ + n = np.arange(buffer_size, dtype=np.float64) + phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset + amp = 0.7 + iq = amp * (np.cos(phase) + 1j * np.sin(phase)) + iq = iq.astype(np.complex64) + interleaved = np.empty(buffer_size * 2, dtype=np.float32) + interleaved[0::2] = iq.real + interleaved[1::2] = iq.imag + next_phase = (2.0 * np.pi * tone_hz / sample_rate * buffer_size + phase_offset) % (2.0 * np.pi) + return interleaved.tobytes(), next_phase + + +def _make_pluto_factory(identifier: str | None): + def factory(device: str, _ident: str | None): + if device != "pluto": + raise ValueError(f"this script only drives pluto; got device={device!r}") + from ria_toolkit_oss.sdr.pluto import Pluto + return Pluto(identifier=identifier) + return factory + + +async def _run(args: argparse.Namespace) -> int: + ws = LoggingFakeWs() + cfg = AgentConfig( + tx_enabled=True, + # Pluto's TX gain range is [-89, 0]. Cap at 0 so a fat-fingered + # --gain=+5 still gets rejected at the agent boundary rather than + # turned into mystery attenuation by Pluto's setter. 
+ tx_max_gain_db=0.0, + tx_max_duration_s=float(args.duration) if args.duration > 0 else None, + ) + streamer = Streamer(ws=ws, sdr_factory=_make_pluto_factory(args.identifier), cfg=cfg) + + await streamer.on_message( + { + "type": "tx_start", + "app_id": "smoke", + "radio_config": { + "device": "pluto", + "identifier": args.identifier, + "tx_sample_rate": int(args.sample_rate), + "tx_center_frequency": int(args.frequency), + "tx_gain": int(args.gain), + "buffer_size": int(args.buffer_size), + # "repeat" keeps the last buffer on the air if we ever stall, + # so a continuous carrier stays up even when Python GC or + # asyncio scheduling briefly pauses the producer. + "underrun_policy": "repeat", + }, + } + ) + + # Abort if tx_start was rejected by an interlock (no session → nothing to do). + if streamer._tx is None: + print("tx_start rejected — see [tx_status] line above for the reason.", + file=sys.stderr) + return 2 + + print(f"Transmitting at {args.frequency/1e6:.3f} MHz with " + f"{args.tone/1e3:.1f} kHz baseband tone at gain {args.gain} dB. " + f"{'Running for ' + str(args.duration) + 's' if args.duration > 0 else 'Run until Ctrl-C'}.") + + # Arrange a clean shutdown on Ctrl-C. + stop = asyncio.Event() + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, stop.set) + except NotImplementedError: + # add_signal_handler is not available on Windows event loops. + pass + + # Produce buffers at the nominal sample-rate pace. We deliberately stay + # slightly ahead of the radio — queue is bounded at 8, so backpressure + # flows naturally. + phase = 0.0 + buffer_dt = args.buffer_size / args.sample_rate + # Aim for one buffer every ``buffer_dt * 0.5`` seconds so the queue stays + # topped up. The queue's own backpressure keeps us from spinning. 
+ produce_interval = buffer_dt * 0.5 + try: + async def producer(): + nonlocal phase + while not stop.is_set(): + frame, phase = _make_iq_frame( + args.buffer_size, args.tone, args.sample_rate, phase + ) + await streamer.on_binary(frame) + await asyncio.sleep(produce_interval) + + producer_task = asyncio.create_task(producer()) + + if args.duration > 0: + try: + await asyncio.wait_for(stop.wait(), timeout=args.duration) + except asyncio.TimeoutError: + pass + else: + await stop.wait() + + stop.set() + producer_task.cancel() + try: + await producer_task + except (asyncio.CancelledError, Exception): + pass + finally: + await streamer.on_message({"type": "tx_stop", "app_id": "smoke"}) + + print("TX session closed.") + return 0 + + +def main() -> int: + p = argparse.ArgumentParser( + description="End-to-end TX smoke test: agent → Pluto continuous tone.", + ) + p.add_argument("--identifier", default=None, + help="Pluto IP/hostname (default: auto-discover pluto.local)") + p.add_argument("--frequency", type=float, default=2_450_000_000.0, + help="TX LO in Hz (default 2.45 GHz)") + p.add_argument("--gain", type=float, default=-30.0, + help="TX gain in dB; Pluto range [-89, 0] (default -30)") + p.add_argument("--sample-rate", type=float, default=1_000_000.0, + help="Baseband sample rate (default 1 Msps)") + p.add_argument("--tone", type=float, default=100_000.0, + help="Baseband tone offset in Hz; 0 = DC (default 100 kHz)") + p.add_argument("--buffer-size", type=int, default=4096, + help="Complex samples per frame (default 4096)") + p.add_argument("--duration", type=float, default=30.0, + help="Seconds to transmit; 0 = run until Ctrl-C (default 30)") + p.add_argument("--log-level", default="INFO") + args = p.parse_args() + + logging.basicConfig( + level=getattr(logging, args.log_level.upper(), logging.INFO), + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + ) + + try: + return asyncio.run(_run(args)) + except KeyboardInterrupt: + return 130 + + +if __name__ == 
"__main__": + sys.exit(main()) diff --git a/scripts/pluto_tx_ws_smoke.py b/scripts/pluto_tx_ws_smoke.py new file mode 100755 index 0000000..d4c8344 --- /dev/null +++ b/scripts/pluto_tx_ws_smoke.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +"""Full-stack TX smoke test: localhost mock-hub → WS → agent → real Pluto. + +Same radio output as ``pluto_tx_smoke.py`` (continuous tone at 2 450.1 MHz), +but drives the agent through the *real* WebSocket path instead of calling +handlers in-process. Proves that the hub-driven path behaves identically: + + mock hub ── ws:// ──▶ WsClient.run() ──▶ Streamer.on_message + └▶ Streamer.on_binary + │ + ▼ + real Pluto + +This is the most rigorous check short of pointing the real ``ria-agent stream`` +at a live ria-hub. If a tone appears on the spectrum analyzer here but *not* +when ria-hub drives it, the fault is above the WS decoder (registration, +capability gate, TX operator, hub's binary-frame publisher); everything +downstream of ``ws.recv()`` is this script's code path. 
+ +Usage:: + + python3 scripts/pluto_tx_ws_smoke.py # default 30s tone + python3 scripts/pluto_tx_ws_smoke.py --identifier 192.168.3.1 + python3 scripts/pluto_tx_ws_smoke.py --duration 0 # until Ctrl-C +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import signal +import sys + +import numpy as np +import websockets + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.agent.ws_client import WsClient + + +def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float, + phase_offset: float) -> tuple[bytes, float]: + n = np.arange(buffer_size, dtype=np.float64) + phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset + amp = 0.7 + iq = (amp * (np.cos(phase) + 1j * np.sin(phase))).astype(np.complex64) + interleaved = np.empty(buffer_size * 2, dtype=np.float32) + interleaved[0::2] = iq.real + interleaved[1::2] = iq.imag + next_phase = (2.0 * np.pi * tone_hz / sample_rate * buffer_size + phase_offset) % (2.0 * np.pi) + return interleaved.tobytes(), next_phase + + +def _make_pluto_factory(identifier: str | None): + def factory(device: str, _ident: str | None): + if device != "pluto": + raise ValueError(f"this script only drives pluto; got device={device!r}") + from ria_toolkit_oss.sdr.pluto import Pluto + return Pluto(identifier=identifier) + return factory + + +async def _mock_hub_handler(ws, args, stop: asyncio.Event): + """Server side of the WS. Sends tx_start, streams IQ, then tx_stop.""" + # Drain the first heartbeat so the log is clean; we don't need to gate on + # it for a localhost smoke test. 
+ try: + first = await asyncio.wait_for(ws.recv(), timeout=2.0) + if isinstance(first, str): + payload = json.loads(first) + if payload.get("type") == "heartbeat": + caps = payload.get("capabilities") + print(f"[mock-hub] agent heartbeat: capabilities={caps} " + f"tx_enabled={payload.get('tx_enabled')}") + except asyncio.TimeoutError: + print("[mock-hub] warning: no heartbeat received in first 2s") + + # Arm the agent's TX path. + await ws.send(json.dumps({ + "type": "tx_start", + "app_id": "ws-smoke", + "radio_config": { + "device": "pluto", + "identifier": args.identifier, + "tx_sample_rate": int(args.sample_rate), + "tx_center_frequency": int(args.frequency), + "tx_gain": int(args.gain), + "buffer_size": int(args.buffer_size), + "underrun_policy": "repeat", + }, + })) + print(f"[mock-hub] sent tx_start at {args.frequency/1e6:.3f} MHz, " + f"gain={args.gain} dB") + + # Producer: push IQ frames at a steady clip. Use a concurrent receiver so + # tx_status frames show up in real time rather than being queued behind + # the sends. + phase = 0.0 + buffer_dt = args.buffer_size / args.sample_rate + + async def receiver(): + try: + while True: + msg = await ws.recv() + if isinstance(msg, str): + print(f"[mock-hub] ← {msg}") + except (websockets.ConnectionClosed, asyncio.CancelledError): + pass + + recv_task = asyncio.create_task(receiver()) + try: + deadline = None if args.duration <= 0 else ( + asyncio.get_event_loop().time() + args.duration + ) + while not stop.is_set(): + if deadline is not None and asyncio.get_event_loop().time() >= deadline: + break + frame, phase = _make_iq_frame( + args.buffer_size, args.tone, args.sample_rate, phase + ) + try: + await ws.send(frame) + except websockets.ConnectionClosed: + break + # Slightly ahead of real-time; WS backpressure handles the rest. 
+ await asyncio.sleep(buffer_dt * 0.5) + finally: + try: + await ws.send(json.dumps({"type": "tx_stop", "app_id": "ws-smoke"})) + print("[mock-hub] sent tx_stop") + except websockets.ConnectionClosed: + pass + # Give the agent a moment to emit `tx_status: done` before we tear down. + await asyncio.sleep(0.3) + recv_task.cancel() + try: + await recv_task + except (asyncio.CancelledError, Exception): + pass + + +async def _run(args: argparse.Namespace) -> int: + stop = asyncio.Event() + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, stop.set) + except NotImplementedError: + pass + + # Start the mock hub on a local port. + async def handler(ws): + try: + await _mock_hub_handler(ws, args, stop) + finally: + stop.set() + + server = await websockets.serve(handler, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + print(f"[mock-hub] listening on ws://127.0.0.1:{port}") + + # Run the agent — exactly as ``ria-agent stream`` would, just with a + # different URL and an in-memory AgentConfig instead of one loaded from + # ``~/.ria/agent.json``. 
+ client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=5.0, + reconnect_pause=0.5, + ) + streamer = Streamer( + ws=client, + sdr_factory=_make_pluto_factory(args.identifier), + cfg=AgentConfig(tx_enabled=True, tx_max_gain_db=0.0), + ) + client_task = asyncio.create_task( + client.run( + on_message=streamer.on_message, + heartbeat=streamer.build_heartbeat, + on_binary=streamer.on_binary, + ) + ) + + try: + await stop.wait() + finally: + client.stop() + client_task.cancel() + try: + await client_task + except (asyncio.CancelledError, Exception): + pass + server.close() + await server.wait_closed() + + print("Done.") + return 0 + + +def main() -> int: + p = argparse.ArgumentParser( + description="Full-stack TX smoke: localhost mock-hub → WS → agent → Pluto.", + ) + p.add_argument("--identifier", default=None, + help="Pluto IP/hostname (default: auto-discover pluto.local)") + p.add_argument("--frequency", type=float, default=2_450_000_000.0, + help="TX LO in Hz (default 2.45 GHz)") + p.add_argument("--gain", type=float, default=0.0, + help="TX gain in dB; Pluto range [-89, 0] (default 0)") + p.add_argument("--sample-rate", type=float, default=1_000_000.0, + help="Baseband sample rate (default 1 Msps)") + p.add_argument("--tone", type=float, default=100_000.0, + help="Baseband tone offset in Hz (default 100 kHz)") + p.add_argument("--buffer-size", type=int, default=4096, + help="Complex samples per frame (default 4096)") + p.add_argument("--duration", type=float, default=30.0, + help="Seconds to transmit; 0 = run until Ctrl-C (default 30)") + p.add_argument("--log-level", default="INFO") + args = p.parse_args() + + logging.basicConfig( + level=getattr(logging, args.log_level.upper(), logging.INFO), + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + ) + + try: + return asyncio.run(_run(args)) + except KeyboardInterrupt: + return 130 + + +if __name__ == "__main__": + sys.exit(main()) diff --git 
a/src/ria_toolkit_oss/agent/hardware.py b/src/ria_toolkit_oss/agent/hardware.py index 32a65e5..d585e8f 100644 --- a/src/ria_toolkit_oss/agent/hardware.py +++ b/src/ria_toolkit_oss/agent/hardware.py @@ -37,6 +37,18 @@ def heartbeat_payload( "capabilities": capabilities, "tx_enabled": bool(c.tx_enabled), } + # Surface configured interlock values so the hub can pre-filter UI controls + # before sending a tx_start that would be rejected. Only included when TX + # is opted in AND the operator set a cap. + if c.tx_enabled: + if c.tx_max_gain_db is not None: + payload["tx_max_gain_db"] = float(c.tx_max_gain_db) + if c.tx_max_duration_s is not None: + payload["tx_max_duration_s"] = float(c.tx_max_duration_s) + if c.tx_allowed_freq_ranges: + payload["tx_allowed_freq_ranges"] = [ + [float(lo), float(hi)] for lo, hi in c.tx_allowed_freq_ranges + ] if app_id: payload["app_id"] = app_id if sessions: diff --git a/src/ria_toolkit_oss/agent/streamer.py b/src/ria_toolkit_oss/agent/streamer.py index 8570a73..6cf73e6 100644 --- a/src/ria_toolkit_oss/agent/streamer.py +++ b/src/ria_toolkit_oss/agent/streamer.py @@ -197,8 +197,14 @@ class Streamer: sessions=sessions or None, ) + # Advisory / keepalive message types we accept and ignore without warning. 
+ _IGNORED_MESSAGE_TYPES = frozenset({"tx_data_available"}) + async def on_message(self, msg: dict) -> None: t = msg.get("type") + if t in self._IGNORED_MESSAGE_TYPES: + logger.debug("Ignoring advisory message: %r", t) + return handler = { "start": self._handle_rx_start, "stop": self._handle_rx_stop, @@ -469,9 +475,12 @@ class Streamer: def _tx_executor_body(self, session: TxSession) -> None: try: session.sdr._stream_tx(lambda n: self._tx_callback(session, n)) - except Exception: + except Exception as exc: logger.exception("TX stream crashed") - self._schedule(self._send_tx_status(session.app_id, "error", "tx stream crashed")) + # Schedule both the error frame and session teardown on the loop + # so ``self._tx`` clears, subsequent binary frames are rejected, + # and the SDR handle is released. + self._schedule(self._tx_crash_teardown(session, str(exc))) def _tx_callback(self, session: TxSession, num_samples) -> np.ndarray: n = int(num_samples) @@ -561,6 +570,18 @@ class Streamer: return await asyncio.sleep(0.05) + async def _tx_crash_teardown(self, session: TxSession, message: str) -> None: + # Called from the executor thread via _schedule when _stream_tx raises. + # Emit the error, mark stopped, drain the queue, release the SDR. + await self._send_tx_status(session.app_id, "error", f"tx stream crashed: {message}") + if self._tx is not session: + return + session.stop_event.set() + self._drain_tx_queue(session) + self._close_session_sdr(session) + if self._tx is session: + self._tx = None + async def _teardown_tx_after_underrun(self, session: TxSession) -> None: if self._tx is not session: return @@ -643,13 +664,44 @@ _CONFIG_ATTR_MAP = { } +def _is_stub_setter(method: Any) -> bool: + """True when *method* is an unimplemented base-class stub. + + The ``SDR`` abstract base defines ``set_rx_sample_rate`` / ``set_tx_gain`` + etc. as zero-argument ``NotImplementedError`` stubs. 
A driver (Pluto) that + actually transmits overrides them with a real ``(value, ...)`` signature. + Comparing ``__qualname__`` against ``SDR.`` lets us skip the stubs cheaply. + """ + return getattr(method, "__qualname__", "").startswith("SDR.") + + def _apply_sdr_config(sdr: Any, cfg: dict) -> None: - """Apply a radio_config dict to an SDR, trying multiple attribute aliases.""" + """Apply a radio_config dict to an SDR. + + Prefers ``sdr.set_(value)`` when the driver implements it — Pluto's + setters take ``_param_lock``, so routing through them keeps concurrent + RX + TX reconfigures from racing on shared native attributes. Falls back + to ``setattr`` for drivers (MockSDR, tests) that don't override the + base-class stubs. + """ for key, value in cfg.items(): if value is None: continue attrs = _CONFIG_ATTR_MAP.get(key, (key,)) applied = False + for attr in attrs: + setter = getattr(sdr, f"set_{attr}", None) + if callable(setter) and not _is_stub_setter(setter): + try: + setter(value) + applied = True + break + except Exception as exc: + logger.debug("set_%s(%r) failed: %s", attr, value, exc) + # Fall through to setattr; some drivers may partially + # implement setters. + if applied: + continue for attr in attrs: if hasattr(sdr, attr): try: diff --git a/src/ria_toolkit_oss/sdr/pluto.py b/src/ria_toolkit_oss/sdr/pluto.py index 7ed3be0..88243b1 100644 --- a/src/ria_toolkit_oss/sdr/pluto.py +++ b/src/ria_toolkit_oss/sdr/pluto.py @@ -384,7 +384,10 @@ class Pluto(SDR): self._enable_tx = True while self._enable_tx is True: buffer = self._convert_tx_samples(callback(self.tx_buffer_size)) - self.radio.tx(buffer[0]) + # pyadi-iio's ``radio.tx`` auto-wraps single-channel 1-D input. + # Indexing ``buffer[0]`` was a latent bug for callbacks that + # returned 1-D samples (scalar → TypeError inside pyadi). 
+ self.radio.tx(buffer) def set_rx_center_frequency(self, center_frequency): """ @@ -514,74 +517,85 @@ class Pluto(SDR): raise SDRError(e) def set_tx_center_frequency(self, center_frequency): - if center_frequency < 70e6 or center_frequency > 6e9: - raise SDRParameterError( - f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz " - f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t" - f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]" - ) + # ``adi.Pluto`` exposes one radio handle shared between RX and TX; concurrent + # RX + TX sessions (see the agent ``_SdrRegistry``) may call RX and TX + # setters at the same time. Serialize with ``_param_lock`` — RX setters hold + # the same reentrant lock — so native attribute writes don't interleave. + with self._param_lock: + if center_frequency < 70e6 or center_frequency > 6e9: + raise SDRParameterError( + f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz " + f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t" + f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]" + ) - try: - self.radio.tx_lo = int(center_frequency) - self.tx_center_frequency = center_frequency - except OSError as e: - raise SDRError(e) - except ValueError: - raise SDRParameterError( - f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz " - f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t" - f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]" - ) + try: + self.radio.tx_lo = int(center_frequency) + self.tx_center_frequency = center_frequency + except OSError as e: + raise SDRError(e) + except ValueError: + raise SDRParameterError( + f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz " + f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t" + f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]" + ) def set_tx_sample_rate(self, sample_rate): - min_rate, max_rate = 65.1e3, 61.44e6 - if sample_rate 
< min_rate or sample_rate > max_rate: - raise SDRParameterError( - f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps " - f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]" - ) + # ``self.radio.sample_rate`` is shared between RX and TX on Pluto — RX's + # ``set_rx_sample_rate`` writes the same native attribute. Hold ``_param_lock`` + # so full-duplex sessions can't interleave writes. + with self._param_lock: + min_rate, max_rate = 65.1e3, 61.44e6 + if sample_rate < min_rate or sample_rate > max_rate: + raise SDRParameterError( + f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps " + f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]" + ) - try: - self.radio.sample_rate = sample_rate - self.tx_sample_rate = sample_rate - except OSError as e: - raise SDRError(e) - except ValueError: - raise SDRParameterError( - f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps " - f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]" - ) + try: + self.radio.sample_rate = sample_rate + self.tx_sample_rate = sample_rate + except OSError as e: + raise SDRError(e) + except ValueError: + raise SDRParameterError( + f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps " + f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]" + ) def set_tx_gain(self, gain, channel=0, gain_mode="absolute"): - tx_gain_min = -89 - tx_gain_max = 0 + # Serialize with RX setters: see ``set_tx_sample_rate`` above. + with self._param_lock: + tx_gain_min = -89 + tx_gain_max = 0 - if gain_mode == "relative": - if gain > 0: - raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\ - the gain relative to the maximum possible gain.") + if gain_mode == "relative": + if gain > 0: + raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. 
This sets\ + the gain relative to the maximum possible gain.") + else: + abs_gain = tx_gain_max + gain else: - abs_gain = tx_gain_max + gain - else: - abs_gain = gain + abs_gain = gain - if abs_gain < tx_gain_min or abs_gain > tx_gain_max: - abs_gain = min(max(gain, tx_gain_min), tx_gain_max) - print(f"Gain {gain} out of range for Pluto.") - print(f"Gain range: {tx_gain_min} to {tx_gain_max} dB") + if abs_gain < tx_gain_min or abs_gain > tx_gain_max: + abs_gain = min(max(gain, tx_gain_min), tx_gain_max) + print(f"Gain {gain} out of range for Pluto.") + print(f"Gain range: {tx_gain_min} to {tx_gain_max} dB") - try: - self.tx_gain = abs_gain + try: + self.tx_gain = abs_gain - if channel == 0: - self.radio.tx_hardwaregain_chan0 = int(abs_gain) - elif channel == 1: - self.radio.tx_hardwaregain_chan1 = int(abs_gain) - else: - raise SDRParameterError(f"Pluto channel must be 0 or 1 but was {channel}.") + if channel == 0: + self.radio.tx_hardwaregain_chan0 = int(abs_gain) + elif channel == 1: + self.radio.tx_hardwaregain_chan1 = int(abs_gain) + else: + raise SDRParameterError(f"Pluto channel must be 0 or 1 but was {channel}.") - except Exception as e: - raise SDRError(e) + except Exception as e: + raise SDRError(e) def set_tx_channel(self, channel): if channel == 0: diff --git a/tests/agent/test_hardware.py b/tests/agent/test_hardware.py index 51b2e45..6a9cdf3 100644 --- a/tests/agent/test_hardware.py +++ b/tests/agent/test_hardware.py @@ -44,3 +44,30 @@ def test_heartbeat_payload_sessions_field(): sessions = {"rx": {"app_id": "a", "state": "streaming"}} p = hardware.heartbeat_payload(status="streaming", app_id="a", sessions=sessions) assert p["sessions"] == sessions + + +def test_heartbeat_payload_surfaces_tx_caps_when_enabled(): + from ria_toolkit_oss.agent.config import AgentConfig + + cfg = AgentConfig( + tx_enabled=True, + tx_max_gain_db=-10.0, + tx_max_duration_s=60.0, + tx_allowed_freq_ranges=[[2.4e9, 2.5e9], [5.7e9, 5.8e9]], + ) + p = 
hardware.heartbeat_payload(cfg=cfg) + assert p["tx_max_gain_db"] == -10.0 + assert p["tx_max_duration_s"] == 60.0 + assert p["tx_allowed_freq_ranges"] == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]] + + +def test_heartbeat_payload_omits_caps_when_tx_disabled(): + from ria_toolkit_oss.agent.config import AgentConfig + + # Caps set but tx_enabled=False — don't leak them; they're only meaningful + # when the hub can attempt a tx_start. + cfg = AgentConfig(tx_enabled=False, tx_max_gain_db=-10.0) + p = hardware.heartbeat_payload(cfg=cfg) + assert "tx_max_gain_db" not in p + assert "tx_max_duration_s" not in p + assert "tx_allowed_freq_ranges" not in p diff --git a/tests/agent/test_param_lock_contention.py b/tests/agent/test_param_lock_contention.py new file mode 100644 index 0000000..e3d84fc --- /dev/null +++ b/tests/agent/test_param_lock_contention.py @@ -0,0 +1,210 @@ +"""Step-A6 (Pluto lock audit) coverage. + +Verifies the two invariants the handoff doc calls for when RX and TX run +concurrently on one shared SDR handle: + +1. ``_param_lock`` actually serializes concurrent RX + TX setter calls — the + spec's §A6 acceptance criterion is *"``_param_lock`` instrumented for + contention"*. We drive parallel ``set_{rx,tx}_sample_rate`` calls through + the lock and assert it's hit often enough to prove both paths fight for it. +2. Under a sustained full-duplex session (RX capturing + TX transmitting on + one ``(device, identifier)``), no setter write is dropped and no exception + escapes the executor — i.e., the shared-handle assumption holds. Runs + against ``MockSDR`` per the spec; the real Pluto driver now takes the + same lock on its TX setters so the production code path is isomorphic. + +The stress window is 2 seconds by default — the handoff mentions 30 s but +that's impractical in CI. Set ``RIA_LOCK_STRESS_S`` to override. 
+""" + +from __future__ import annotations + +import asyncio +import os +import threading +import time + +import numpy as np + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.sdr.mock import MockSDR + + +_STRESS_S = float(os.environ.get("RIA_LOCK_STRESS_S", "2.0")) + + +class InstrumentedMockSDR(MockSDR): + """MockSDR that counts lock acquisitions and exposes a real ``_param_lock``. + + ``_param_lock`` is inherited from ``SDR`` as a reentrant lock; we wrap it + with a counter that records every time RX or TX setters grab it, so the + test can assert real contention rather than just "the code compiles". + """ + + def __init__(self, buffer_size: int): + super().__init__(buffer_size=buffer_size) + self.rx_lock_hits = 0 + self.tx_lock_hits = 0 + self.param_lock_hits = 0 + # Shadow lock that increments a counter each time __enter__ fires. + real_lock = self._param_lock + + test = self + + class CountingLock: + def __enter__(self_inner): + test.param_lock_hits += 1 + real_lock.acquire() + return self_inner + + def __exit__(self_inner, *a): + real_lock.release() + return False + + # ``threading.RLock`` interop for any code that calls acquire/release directly. + def acquire(self_inner, *a, **k): + test.param_lock_hits += 1 + return real_lock.acquire(*a, **k) + + def release(self_inner): + return real_lock.release() + + self._param_lock = CountingLock() + + # The MockSDR doesn't ship RX setter methods that hit the lock — override + # ``sample_rate`` / ``center_freq`` / ``gain`` writes to route through the + # same lock the real Pluto driver uses, so this test faithfully models the + # production contention path. 
+ def set_rx_sample_rate(self, sample_rate): + with self._param_lock: + self.rx_lock_hits += 1 + self.rx_sample_rate = float(sample_rate) + self.sample_rate = self.rx_sample_rate + + def set_tx_sample_rate(self, sample_rate): + with self._param_lock: + self.tx_lock_hits += 1 + self.tx_sample_rate = float(sample_rate) + # Mirror Pluto: both RX and TX write the same native attribute. + self.sample_rate = self.tx_sample_rate + + +class FakeWs: + def __init__(self): + self.json_sent: list[dict] = [] + self.bytes_sent: list[bytes] = [] + + async def send_json(self, p): + self.json_sent.append(p) + + async def send_bytes(self, b): + self.bytes_sent.append(b) + + +def _iq_frame(samples: np.ndarray) -> bytes: + interleaved = np.empty(samples.size * 2, dtype=np.float32) + interleaved[0::2] = samples.real + interleaved[1::2] = samples.imag + return interleaved.tobytes() + + +def test_param_lock_contended_under_concurrent_setters(): + """Run two threads that hammer RX + TX sample-rate setters and assert both + lock paths fire. This proves the lock is doing work — if either setter + bypassed ``_param_lock``, one of the counters would stay at zero.""" + sdr = InstrumentedMockSDR(buffer_size=16) + stop = threading.Event() + + def rx_setter(): + i = 0 + while not stop.is_set(): + sdr.set_rx_sample_rate(1_000_000 + (i % 1000)) + i += 1 + + def tx_setter(): + i = 0 + while not stop.is_set(): + sdr.set_tx_sample_rate(2_000_000 + (i % 1000)) + i += 1 + + t1 = threading.Thread(target=rx_setter) + t2 = threading.Thread(target=tx_setter) + t1.start() + t2.start() + time.sleep(min(_STRESS_S, 2.0)) + stop.set() + t1.join() + t2.join() + + assert sdr.rx_lock_hits > 100, f"RX setter barely ran: {sdr.rx_lock_hits}" + assert sdr.tx_lock_hits > 100, f"TX setter barely ran: {sdr.tx_lock_hits}" + # Every setter call should have passed through _param_lock exactly once. 
+ assert sdr.param_lock_hits >= sdr.rx_lock_hits + sdr.tx_lock_hits + + +def test_full_duplex_stays_healthy_over_stress_window(): + """Start RX + TX on one shared SDR and drive both paths for ``_STRESS_S`` + seconds, pushing binary frames and emitting ``tx_configure`` mid-stream. + The session must survive, deliver buffers in both directions, and leave + the registry clean on shutdown.""" + BUF = 32 + sdr = InstrumentedMockSDR(buffer_size=BUF) + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)) + + await s.on_message( + {"type": "start", "app_id": "app-1", + "radio_config": {"device": "mock", "buffer_size": BUF}} + ) + await s.on_message( + {"type": "tx_start", "app_id": "app-1", + "radio_config": { + "device": "mock", "buffer_size": BUF, + "tx_sample_rate": 1_000_000, + "tx_center_frequency": 2.45e9, + "tx_gain": -20, + "underrun_policy": "zero", + }} + ) + + marker = np.arange(BUF, dtype=np.complex64) + 1 + deadline = time.monotonic() + _STRESS_S + i = 0 + while time.monotonic() < deadline: + await s.on_binary(_iq_frame(marker)) + if i % 8 == 0: + # Mid-stream parameter reconfiguration touches _apply_sdr_config, + # which routes through the same setters the stress test above + # verifies. + await s.on_message( + {"type": "tx_configure", "app_id": "app-1", + "radio_config": {"tx_sample_rate": 1_000_000 + i}} + ) + await s.on_message( + {"type": "configure", "app_id": "app-1", + "radio_config": {"sample_rate": 2_000_000 + i}} + ) + i += 1 + await asyncio.sleep(0.005) + + await s.on_message({"type": "tx_stop", "app_id": "app-1"}) + await s.on_message({"type": "stop", "app_id": "app-1"}) + return ws, s + + ws, s = asyncio.run(scenario()) + + # No error frame leaked out. 
+ errors = [m for m in ws.json_sent + if m.get("type") == "error" or m.get("state") == "error"] + assert errors == [], f"Unexpected error frames: {errors}" + # RX produced IQ frames, and the instrumented _param_lock was acquired at + # least once during the session (a liveness check, not a contention proof). + assert ws.bytes_sent, "RX produced no IQ frames" + assert sdr.param_lock_hits > 0 + # Sessions cleaned up; registry drained. + assert s._tx is None + assert s._rx is None + assert s._registry.refcount(("mock", None)) == 0 diff --git a/tests/agent/test_streamer.py b/tests/agent/test_streamer.py index 2aa842e..da2956c 100644 --- a/tests/agent/test_streamer.py +++ b/tests/agent/test_streamer.py @@ -139,6 +139,21 @@ def test_unknown_message_type_is_ignored(): asyncio.run(scenario()) +def test_tx_data_available_is_a_silent_noop(): + # Hub sends this as a keepalive; we should accept and ignore without + # emitting a WARNING or treating it as an error. + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=_factory) + await s.on_message({"type": "tx_data_available", "app_id": "x"}) + return ws + + ws = asyncio.run(scenario()) + # No outbound frames emitted. + assert ws.json_sent == [] + assert ws.bytes_sent == [] + + def test_registry_shares_sdr_across_start_stop_cycles(): # Two sequential start/stop cycles with the same (device, identifier) # should hit the registry's cache path rather than constructing a new SDR. diff --git a/tests/agent/test_ws_client.py b/tests/agent/test_ws_client.py index 4061f32..c113b64 100644 --- a/tests/agent/test_ws_client.py +++ b/tests/agent/test_ws_client.py @@ -1,11 +1,14 @@ -"""Reconnect + heartbeat timing against a real local websockets server.""" +"""Reconnect + heartbeat + malformed-control-frame behavior. + +Binary-frame delivery lives in ``test_ws_client_binary.py`` to match the +test matrix spelled out in ``Agent TX Streaming Handoff.md`` §A7.
+""" from __future__ import annotations import asyncio import json -import pytest import websockets from ria_toolkit_oss.agent.ws_client import WsClient @@ -113,109 +116,6 @@ def test_reconnects_after_server_drop(): assert n >= 2 -def test_binary_frame_forwarded_to_handler(): - payload = bytes(range(128)) - - async def scenario(): - received: list[bytes] = [] - done = asyncio.Event() - - async def handler(ws): - await ws.send(payload) - done.set() - try: - await ws.wait_closed() - except Exception: - pass - - server, port = await _open_server(handler) - try: - client = WsClient( - f"ws://127.0.0.1:{port}", - token="", - heartbeat_interval=10.0, - reconnect_pause=0.05, - ) - - async def on_bin(data): - received.append(data) - - task = asyncio.create_task( - client.run( - on_message=lambda _m: asyncio.sleep(0), - heartbeat=lambda: {"type": "heartbeat"}, - on_binary=on_bin, - ) - ) - for _ in range(50): - if received: - break - await asyncio.sleep(0.02) - client.stop() - task.cancel() - try: - await task - except (asyncio.CancelledError, Exception): - pass - finally: - server.close() - await server.wait_closed() - return received - - received = asyncio.run(scenario()) - assert received == [payload] - - -def test_binary_frame_dropped_when_no_handler(): - # Regression guard: existing behavior (drop server-sent binary) preserved when - # on_binary is not supplied. 
- async def scenario(): - crashes: list[Exception] = [] - - async def handler(ws): - await ws.send(b"\x00\x01\x02\x03") - await ws.send(json.dumps({"type": "ping"})) - try: - await ws.wait_closed() - except Exception: - pass - - messages: list[dict] = [] - server, port = await _open_server(handler) - try: - client = WsClient( - f"ws://127.0.0.1:{port}", - token="", - heartbeat_interval=10.0, - reconnect_pause=0.05, - ) - - async def on_msg(m): - messages.append(m) - - task = asyncio.create_task( - client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"}) - ) - for _ in range(50): - if messages: - break - await asyncio.sleep(0.02) - client.stop() - task.cancel() - try: - await task - except (asyncio.CancelledError, Exception) as exc: - crashes.append(exc) - finally: - server.close() - await server.wait_closed() - return messages, crashes - - messages, crashes = asyncio.run(scenario()) - # JSON still delivered; binary silently dropped; no uncaught crash. - assert messages and messages[0] == {"type": "ping"} - - def test_malformed_control_frame_does_not_crash(): async def scenario(): handled: list[dict] = [] diff --git a/tests/agent/test_ws_client_binary.py b/tests/agent/test_ws_client_binary.py new file mode 100644 index 0000000..4d9ddc1 --- /dev/null +++ b/tests/agent/test_ws_client_binary.py @@ -0,0 +1,186 @@ +"""Binary-frame delivery on the hub → agent WebSocket. + +Named to match the test matrix in ``Agent TX Streaming Handoff.md`` §A7. +Exercises: + +- Binary frames are forwarded to an ``on_binary`` coroutine when supplied. +- Binary frames are silently dropped (no crash) when ``on_binary`` is omitted, + preserving the pre-TX behavior for RX-only deployments. 
+""" + +from __future__ import annotations + +import asyncio +import json + +import websockets + +from ria_toolkit_oss.agent.ws_client import WsClient + + +async def _open_server(handler): + server = await websockets.serve(handler, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + return server, port + + +def test_binary_frame_forwarded_to_handler(): + payload = bytes(range(128)) + + async def scenario(): + received: list[bytes] = [] + done = asyncio.Event() + + async def handler(ws): + await ws.send(payload) + done.set() + try: + await ws.wait_closed() + except Exception: + pass + + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + + async def on_bin(data): + received.append(data) + + task = asyncio.create_task( + client.run( + on_message=lambda _m: asyncio.sleep(0), + heartbeat=lambda: {"type": "heartbeat"}, + on_binary=on_bin, + ) + ) + for _ in range(50): + if received: + break + await asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return received + + received = asyncio.run(scenario()) + assert received == [payload] + + +def test_binary_frame_dropped_when_no_handler(): + async def scenario(): + crashes: list[Exception] = [] + + async def handler(ws): + await ws.send(b"\x00\x01\x02\x03") + await ws.send(json.dumps({"type": "ping"})) + try: + await ws.wait_closed() + except Exception: + pass + + messages: list[dict] = [] + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + + async def on_msg(m): + messages.append(m) + + task = asyncio.create_task( + client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"}) + ) + for _ in range(50): + if messages: + break + await 
asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception) as exc: + crashes.append(exc) + finally: + server.close() + await server.wait_closed() + return messages, crashes + + messages, _ = asyncio.run(scenario()) + assert messages and messages[0] == {"type": "ping"} + + +def test_on_binary_exception_does_not_kill_connection(): + """A buggy ``on_binary`` raises mid-stream; the WS loop keeps accepting frames.""" + + async def scenario(): + delivered_binary = 0 + delivered_control: list[dict] = [] + + async def handler(ws): + await ws.send(b"\x10\x20\x30") + await ws.send(b"\x40\x50\x60") + await ws.send(json.dumps({"type": "ping"})) + try: + await ws.wait_closed() + except Exception: + pass + + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + + async def on_bin(data): + nonlocal delivered_binary + delivered_binary += 1 + raise RuntimeError("handler broke") + + async def on_msg(m): + delivered_control.append(m) + + task = asyncio.create_task( + client.run( + on_message=on_msg, + heartbeat=lambda: {"type": "heartbeat"}, + on_binary=on_bin, + ) + ) + for _ in range(60): + if delivered_control: + break + await asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return delivered_binary, delivered_control + + bins, ctrls = asyncio.run(scenario()) + # Both binary frames were delivered to the (crashing) handler. + assert bins == 2 + # The subsequent JSON frame still arrived — loop didn't die on the exceptions. + assert ctrls and ctrls[0] == {"type": "ping"}