transmit further updates
This commit is contained in:
parent
b955256479
commit
8c247f9f7a
225
scripts/pluto_tx_smoke.py
Executable file
225
scripts/pluto_tx_smoke.py
Executable file
|
|
@ -0,0 +1,225 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Transmit a continuous tone through the agent's TX pipeline on a real Pluto.
|
||||
|
||||
End-to-end smoke test for the Pluto + Streamer TX path. Drives the same
|
||||
``Streamer`` the hub talks to, but in-process with a logging ``FakeWs`` so
|
||||
the script is self-contained — no hub required.
|
||||
|
||||
Default: 100 kHz baseband tone × 2 450 MHz LO → carrier at 2 450.1 MHz,
|
||||
continuous until you Ctrl-C (or the ``--duration`` timer fires). A spectrum
|
||||
analyzer tuned to 2 450.1 MHz should show a clean CW spike as long as
|
||||
``tx_status: transmitting`` prints.
|
||||
|
||||
Usage::
|
||||
|
||||
python3 scripts/pluto_tx_smoke.py # auto-discover Pluto
|
||||
python3 scripts/pluto_tx_smoke.py --identifier 192.168.3.1
|
||||
python3 scripts/pluto_tx_smoke.py --frequency 2.4e9 --gain -20 --duration 60
|
||||
|
||||
Flags map 1:1 onto the agent's ``radio_config``:
|
||||
|
||||
--identifier Pluto IP or hostname (omitted → ip:pluto.local).
|
||||
--frequency TX LO in Hz. Default 2 450 MHz.
|
||||
--gain Pluto TX gain in dB. Pluto range is ``[-89, 0]``; more negative
|
||||
= more attenuation = less power. Default -30.
|
||||
--sample-rate Baseband sample rate. Default 1 MHz.
|
||||
--tone Baseband tone offset in Hz. Default 100 kHz; set 0 for DC
|
||||
(unmodulated carrier at exactly --frequency, but Pluto's
|
||||
LO leakage will dominate).
|
||||
--buffer-size Complex samples per WS frame. Default 4096.
|
||||
--duration Stop after this many seconds (0 = run until Ctrl-C).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ria_toolkit_oss.agent.config import AgentConfig
|
||||
from ria_toolkit_oss.agent.streamer import Streamer
|
||||
|
||||
|
||||
class LoggingFakeWs:
    """Stand-in WebSocket used in-process instead of the hub.

    Echoes every ``tx_status`` and ``error`` frame from the Streamer to
    stdout so the operator can follow the lifecycle (armed → transmitting
    → done) live.
    """

    async def send_json(self, payload: dict) -> None:
        kind = payload.get("type")
        if kind == "error":
            print(f"[error] {payload.get('message')}")
            return
        if kind != "tx_status":
            # Anything else (heartbeats etc.) is irrelevant to the smoke test.
            return
        state = payload.get("state")
        msg = payload.get("message")
        tail = f" — {msg}" if msg else ""
        print(f"[tx_status] {state}{tail}")

    async def send_bytes(self, data: bytes) -> None:
        # No RX session exists in this script, so IQ bytes never flow here.
        pass
|
||||
|
||||
|
||||
def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float,
|
||||
phase_offset: float = 0.0) -> tuple[bytes, float]:
|
||||
"""Return ``(interleaved_float32_bytes, next_phase)`` for a sine tone.
|
||||
|
||||
Emitting one continuous phase-coherent tone requires threading the phase
|
||||
across frames; the returned ``next_phase`` should be fed back as
|
||||
``phase_offset`` on the next call so the sinusoid doesn't glitch at frame
|
||||
boundaries. Amplitude is 0.7 to leave some headroom below the [-1, 1] cap
|
||||
that ``_verify_sample_format`` polices elsewhere in the toolkit.
|
||||
"""
|
||||
n = np.arange(buffer_size, dtype=np.float64)
|
||||
phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset
|
||||
amp = 0.7
|
||||
iq = amp * (np.cos(phase) + 1j * np.sin(phase))
|
||||
iq = iq.astype(np.complex64)
|
||||
interleaved = np.empty(buffer_size * 2, dtype=np.float32)
|
||||
interleaved[0::2] = iq.real
|
||||
interleaved[1::2] = iq.imag
|
||||
next_phase = (2.0 * np.pi * tone_hz / sample_rate * buffer_size + phase_offset) % (2.0 * np.pi)
|
||||
return interleaved.tobytes(), next_phase
|
||||
|
||||
|
||||
def _make_pluto_factory(identifier: str | None):
|
||||
def factory(device: str, _ident: str | None):
|
||||
if device != "pluto":
|
||||
raise ValueError(f"this script only drives pluto; got device={device!r}")
|
||||
from ria_toolkit_oss.sdr.pluto import Pluto
|
||||
return Pluto(identifier=identifier)
|
||||
return factory
|
||||
|
||||
|
||||
async def _run(args: argparse.Namespace) -> int:
    """Drive one complete TX session through the in-process Streamer.

    Arms TX via ``tx_start``, then runs a producer task that feeds
    phase-coherent tone frames into ``Streamer.on_binary`` until the
    duration elapses or a signal fires, and finally sends ``tx_stop``.

    Returns 0 on clean shutdown, 2 if ``tx_start`` was rejected.
    """
    ws = LoggingFakeWs()
    cfg = AgentConfig(
        tx_enabled=True,
        # Pluto's TX gain range is [-89, 0]. Cap at 0 so a fat-fingered
        # --gain=+5 still gets rejected at the agent boundary rather than
        # turned into mystery attenuation by Pluto's setter.
        tx_max_gain_db=0.0,
        tx_max_duration_s=float(args.duration) if args.duration > 0 else None,
    )
    streamer = Streamer(ws=ws, sdr_factory=_make_pluto_factory(args.identifier), cfg=cfg)

    # Same message shape the hub would send over the wire.
    await streamer.on_message(
        {
            "type": "tx_start",
            "app_id": "smoke",
            "radio_config": {
                "device": "pluto",
                "identifier": args.identifier,
                "tx_sample_rate": int(args.sample_rate),
                "tx_center_frequency": int(args.frequency),
                "tx_gain": int(args.gain),
                "buffer_size": int(args.buffer_size),
                # "repeat" keeps the last buffer on the air if we ever stall,
                # so a continuous carrier stays up even when Python GC or
                # asyncio scheduling briefly pauses the producer.
                "underrun_policy": "repeat",
            },
        }
    )

    # Abort if tx_start was rejected by an interlock (no session → nothing to do).
    # NOTE(review): reaches into the private ``_tx`` attribute — acceptable for
    # a smoke script, but confirm it survives Streamer refactors.
    if streamer._tx is None:
        print("tx_start rejected — see [tx_status] line above for the reason.",
              file=sys.stderr)
        return 2

    print(f"Transmitting at {args.frequency/1e6:.3f} MHz with "
          f"{args.tone/1e3:.1f} kHz baseband tone at gain {args.gain} dB. "
          f"{'Running for ' + str(args.duration) + 's' if args.duration > 0 else 'Run until Ctrl-C'}.")

    # Arrange a clean shutdown on Ctrl-C.
    stop = asyncio.Event()
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, stop.set)
        except NotImplementedError:
            # add_signal_handler is not available on Windows event loops.
            pass

    # Produce buffers at the nominal sample-rate pace. We deliberately stay
    # slightly ahead of the radio — queue is bounded at 8, so backpressure
    # flows naturally.
    phase = 0.0
    buffer_dt = args.buffer_size / args.sample_rate
    # Aim for one buffer every ``buffer_dt * 0.5`` seconds so the queue stays
    # topped up. The queue's own backpressure keeps us from spinning.
    produce_interval = buffer_dt * 0.5
    try:
        async def producer():
            # ``phase`` threads through every call so the tone stays coherent
            # across frame boundaries.
            nonlocal phase
            while not stop.is_set():
                frame, phase = _make_iq_frame(
                    args.buffer_size, args.tone, args.sample_rate, phase
                )
                await streamer.on_binary(frame)
                await asyncio.sleep(produce_interval)

        producer_task = asyncio.create_task(producer())

        if args.duration > 0:
            try:
                await asyncio.wait_for(stop.wait(), timeout=args.duration)
            except asyncio.TimeoutError:
                pass
        else:
            await stop.wait()

        stop.set()
        producer_task.cancel()
        try:
            await producer_task
        except (asyncio.CancelledError, Exception):
            # CancelledError is a BaseException on 3.8+, so it must be listed
            # explicitly; the broad swallow is deliberate best-effort teardown.
            pass
    finally:
        # Always send tx_stop, even if the producer crashed, so the radio
        # keys down and the session is released.
        await streamer.on_message({"type": "tx_stop", "app_id": "smoke"})

    print("TX session closed.")
    return 0
|
||||
|
||||
|
||||
def main() -> int:
    """Parse CLI flags, configure logging, and run the TX smoke test.

    Returns the process exit code: 0 on clean shutdown, 2 if ``tx_start``
    was rejected, 130 on Ctrl-C.
    """
    p = argparse.ArgumentParser(
        description="End-to-end TX smoke test: agent → Pluto continuous tone.",
    )
    p.add_argument("--identifier", default=None,
                   help="Pluto IP/hostname (default: auto-discover pluto.local)")
    # Defaults below must agree with the module docstring and each flag's
    # help string; they had drifted (3.41 GHz / -0.0 dB debugging leftovers
    # while the docs promised 2.45 GHz / -30 dB).
    p.add_argument("--frequency", type=float, default=2_450_000_000.0,
                   help="TX LO in Hz (default 2.45 GHz)")
    p.add_argument("--gain", type=float, default=-30.0,
                   help="TX gain in dB; Pluto range [-89, 0] (default -30)")
    p.add_argument("--sample-rate", type=float, default=1_000_000.0,
                   help="Baseband sample rate (default 1 Msps)")
    p.add_argument("--tone", type=float, default=100_000.0,
                   help="Baseband tone offset in Hz; 0 = DC (default 100 kHz)")
    p.add_argument("--buffer-size", type=int, default=4096,
                   help="Complex samples per frame (default 4096)")
    p.add_argument("--duration", type=float, default=60.0,
                   help="Seconds to transmit; 0 = run until Ctrl-C (default 60)")
    p.add_argument("--log-level", default="INFO")
    args = p.parse_args()

    logging.basicConfig(
        level=getattr(logging, args.log_level.upper(), logging.INFO),
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )

    try:
        return asyncio.run(_run(args))
    except KeyboardInterrupt:
        # Conventional 128+SIGINT exit code.
        return 130
|
||||
|
||||
|
||||
# Script entry point — propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())
|
||||
236
scripts/pluto_tx_ws_smoke.py
Executable file
236
scripts/pluto_tx_ws_smoke.py
Executable file
|
|
@ -0,0 +1,236 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Full-stack TX smoke test: localhost mock-hub → WS → agent → real Pluto.
|
||||
|
||||
Same radio output as ``pluto_tx_smoke.py`` (continuous tone at 2 450.1 MHz),
|
||||
but drives the agent through the *real* WebSocket path instead of calling
|
||||
handlers in-process. Proves that the hub-driven path behaves identically:
|
||||
|
||||
mock hub ── ws:// ──▶ WsClient.run() ──▶ Streamer.on_message
|
||||
└▶ Streamer.on_binary
|
||||
│
|
||||
▼
|
||||
real Pluto
|
||||
|
||||
This is the most rigorous check short of pointing the real ``ria-agent stream``
|
||||
at a live ria-hub. If a tone appears on the spectrum analyzer here but *not*
|
||||
when ria-hub drives it, the fault is above the WS decoder (registration,
|
||||
capability gate, TX operator, hub's binary-frame publisher); everything
|
||||
downstream of ``ws.recv()`` is this script's code path.
|
||||
|
||||
Usage::
|
||||
|
||||
python3 scripts/pluto_tx_ws_smoke.py # default 30s tone
|
||||
python3 scripts/pluto_tx_ws_smoke.py --identifier 192.168.3.1
|
||||
python3 scripts/pluto_tx_ws_smoke.py --duration 0 # until Ctrl-C
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import websockets
|
||||
|
||||
from ria_toolkit_oss.agent.config import AgentConfig
|
||||
from ria_toolkit_oss.agent.streamer import Streamer
|
||||
from ria_toolkit_oss.agent.ws_client import WsClient
|
||||
|
||||
|
||||
def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float,
|
||||
phase_offset: float) -> tuple[bytes, float]:
|
||||
n = np.arange(buffer_size, dtype=np.float64)
|
||||
phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset
|
||||
amp = 0.7
|
||||
iq = (amp * (np.cos(phase) + 1j * np.sin(phase))).astype(np.complex64)
|
||||
interleaved = np.empty(buffer_size * 2, dtype=np.float32)
|
||||
interleaved[0::2] = iq.real
|
||||
interleaved[1::2] = iq.imag
|
||||
next_phase = (2.0 * np.pi * tone_hz / sample_rate * buffer_size + phase_offset) % (2.0 * np.pi)
|
||||
return interleaved.tobytes(), next_phase
|
||||
|
||||
|
||||
def _make_pluto_factory(identifier: str | None):
|
||||
def factory(device: str, _ident: str | None):
|
||||
if device != "pluto":
|
||||
raise ValueError(f"this script only drives pluto; got device={device!r}")
|
||||
from ria_toolkit_oss.sdr.pluto import Pluto
|
||||
return Pluto(identifier=identifier)
|
||||
return factory
|
||||
|
||||
|
||||
async def _mock_hub_handler(ws, args, stop: asyncio.Event):
    """Server side of the WS. Sends tx_start, streams IQ, then tx_stop.

    Parameters:
        ws: the accepted ``websockets`` server connection to the agent.
        args: parsed CLI namespace (frequency/gain/sample_rate/tone/...).
        stop: cooperative shutdown flag shared with ``_run``.
    """
    # Drain the first heartbeat so the log is clean; we don't need to gate on
    # it for a localhost smoke test.
    try:
        first = await asyncio.wait_for(ws.recv(), timeout=2.0)
        if isinstance(first, str):
            payload = json.loads(first)
            if payload.get("type") == "heartbeat":
                caps = payload.get("capabilities")
                print(f"[mock-hub] agent heartbeat: capabilities={caps} "
                      f"tx_enabled={payload.get('tx_enabled')}")
    except asyncio.TimeoutError:
        print("[mock-hub] warning: no heartbeat received in first 2s")

    # Arm the agent's TX path.
    await ws.send(json.dumps({
        "type": "tx_start",
        "app_id": "ws-smoke",
        "radio_config": {
            "device": "pluto",
            "identifier": args.identifier,
            "tx_sample_rate": int(args.sample_rate),
            "tx_center_frequency": int(args.frequency),
            "tx_gain": int(args.gain),
            "buffer_size": int(args.buffer_size),
            "underrun_policy": "repeat",
        },
    }))
    print(f"[mock-hub] sent tx_start at {args.frequency/1e6:.3f} MHz, "
          f"gain={args.gain} dB")

    # Producer: push IQ frames at a steady clip. Use a concurrent receiver so
    # tx_status frames show up in real time rather than being queued behind
    # the sends.
    phase = 0.0
    buffer_dt = args.buffer_size / args.sample_rate

    async def receiver():
        # Echo every text frame (tx_status / error) the agent sends back.
        try:
            while True:
                msg = await ws.recv()
                if isinstance(msg, str):
                    print(f"[mock-hub] ← {msg}")
        except (websockets.ConnectionClosed, asyncio.CancelledError):
            pass

    recv_task = asyncio.create_task(receiver())
    try:
        # NOTE(review): asyncio.get_event_loop() inside a coroutine returns
        # the running loop, but get_running_loop() is the modern spelling —
        # consider switching on the next touch.
        deadline = None if args.duration <= 0 else (
            asyncio.get_event_loop().time() + args.duration
        )
        while not stop.is_set():
            if deadline is not None and asyncio.get_event_loop().time() >= deadline:
                break
            frame, phase = _make_iq_frame(
                args.buffer_size, args.tone, args.sample_rate, phase
            )
            try:
                await ws.send(frame)
            except websockets.ConnectionClosed:
                break
            # Slightly ahead of real-time; WS backpressure handles the rest.
            await asyncio.sleep(buffer_dt * 0.5)
    finally:
        try:
            await ws.send(json.dumps({"type": "tx_stop", "app_id": "ws-smoke"}))
            print("[mock-hub] sent tx_stop")
        except websockets.ConnectionClosed:
            pass
        # Give the agent a moment to emit `tx_status: done` before we tear down.
        await asyncio.sleep(0.3)
        recv_task.cancel()
        try:
            await recv_task
        except (asyncio.CancelledError, Exception):
            # CancelledError is a BaseException on 3.8+; both are swallowed
            # deliberately — teardown is best-effort.
            pass
|
||||
|
||||
|
||||
async def _run(args: argparse.Namespace) -> int:
    """Stand up the mock hub and the real agent client, run until stopped.

    Server side: ``websockets.serve`` on an ephemeral localhost port running
    ``_mock_hub_handler``. Client side: the production ``WsClient`` +
    ``Streamer`` pair, exactly as ``ria-agent stream`` wires them.

    Returns 0 on clean shutdown.
    """
    stop = asyncio.Event()
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, stop.set)
        except NotImplementedError:
            # add_signal_handler is unavailable on Windows event loops.
            pass

    # Start the mock hub on a local port.
    async def handler(ws):
        try:
            await _mock_hub_handler(ws, args, stop)
        finally:
            # Whatever ends the handler (duration, error, disconnect) also
            # ends the whole script.
            stop.set()

    # Port 0 → the OS picks a free ephemeral port; read it back below.
    server = await websockets.serve(handler, "127.0.0.1", 0)
    port = server.sockets[0].getsockname()[1]
    print(f"[mock-hub] listening on ws://127.0.0.1:{port}")

    # Run the agent — exactly as ``ria-agent stream`` would, just with a
    # different URL and an in-memory AgentConfig instead of one loaded from
    # ``~/.ria/agent.json``.
    client = WsClient(
        f"ws://127.0.0.1:{port}",
        token="",
        heartbeat_interval=5.0,
        reconnect_pause=0.5,
    )
    streamer = Streamer(
        ws=client,
        sdr_factory=_make_pluto_factory(args.identifier),
        cfg=AgentConfig(tx_enabled=True, tx_max_gain_db=0.0),
    )
    client_task = asyncio.create_task(
        client.run(
            on_message=streamer.on_message,
            heartbeat=streamer.build_heartbeat,
            on_binary=streamer.on_binary,
        )
    )

    try:
        await stop.wait()
    finally:
        # Teardown order: stop the client loop, cancel its task, then close
        # the server so the agent sees a clean disconnect rather than a RST.
        client.stop()
        client_task.cancel()
        try:
            await client_task
        except (asyncio.CancelledError, Exception):
            # Best-effort teardown; CancelledError needs listing on 3.8+.
            pass
        server.close()
        await server.wait_closed()

    print("Done.")
    return 0
|
||||
|
||||
|
||||
def main() -> int:
    """Parse CLI flags, set up logging, and run the full-stack smoke test.

    Returns the process exit code: 0 on success, 130 on Ctrl-C.
    """
    parser = argparse.ArgumentParser(
        description="Full-stack TX smoke: localhost mock-hub → WS → agent → Pluto.",
    )
    parser.add_argument("--identifier", default=None,
                        help="Pluto IP/hostname (default: auto-discover pluto.local)")
    # Radio-tuning flags are all plain floats; declare them data-driven so
    # defaults and help text sit side by side.
    for flag, default, text in (
        ("--frequency", 2_450_000_000.0, "TX LO in Hz (default 2.45 GHz)"),
        ("--gain", 0.0, "TX gain in dB; Pluto range [-89, 0] (default 0)"),
        ("--sample-rate", 1_000_000.0, "Baseband sample rate (default 1 Msps)"),
        ("--tone", 100_000.0, "Baseband tone offset in Hz (default 100 kHz)"),
    ):
        parser.add_argument(flag, type=float, default=default, help=text)
    parser.add_argument("--buffer-size", type=int, default=4096,
                        help="Complex samples per frame (default 4096)")
    parser.add_argument("--duration", type=float, default=30.0,
                        help="Seconds to transmit; 0 = run until Ctrl-C (default 30)")
    parser.add_argument("--log-level", default="INFO")
    args = parser.parse_args()

    logging.basicConfig(
        level=getattr(logging, args.log_level.upper(), logging.INFO),
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )

    try:
        return asyncio.run(_run(args))
    except KeyboardInterrupt:
        return 130
|
||||
|
||||
|
||||
# Script entry point — propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())
|
||||
|
|
@ -37,6 +37,18 @@ def heartbeat_payload(
|
|||
"capabilities": capabilities,
|
||||
"tx_enabled": bool(c.tx_enabled),
|
||||
}
|
||||
# Surface configured interlock values so the hub can pre-filter UI controls
|
||||
# before sending a tx_start that would be rejected. Only included when TX
|
||||
# is opted in AND the operator set a cap.
|
||||
if c.tx_enabled:
|
||||
if c.tx_max_gain_db is not None:
|
||||
payload["tx_max_gain_db"] = float(c.tx_max_gain_db)
|
||||
if c.tx_max_duration_s is not None:
|
||||
payload["tx_max_duration_s"] = float(c.tx_max_duration_s)
|
||||
if c.tx_allowed_freq_ranges:
|
||||
payload["tx_allowed_freq_ranges"] = [
|
||||
[float(lo), float(hi)] for lo, hi in c.tx_allowed_freq_ranges
|
||||
]
|
||||
if app_id:
|
||||
payload["app_id"] = app_id
|
||||
if sessions:
|
||||
|
|
|
|||
|
|
@ -197,8 +197,14 @@ class Streamer:
|
|||
sessions=sessions or None,
|
||||
)
|
||||
|
||||
# Advisory / keepalive message types we accept and ignore without warning.
|
||||
_IGNORED_MESSAGE_TYPES = frozenset({"tx_data_available"})
|
||||
|
||||
async def on_message(self, msg: dict) -> None:
|
||||
t = msg.get("type")
|
||||
if t in self._IGNORED_MESSAGE_TYPES:
|
||||
logger.debug("Ignoring advisory message: %r", t)
|
||||
return
|
||||
handler = {
|
||||
"start": self._handle_rx_start,
|
||||
"stop": self._handle_rx_stop,
|
||||
|
|
@ -469,9 +475,12 @@ class Streamer:
|
|||
def _tx_executor_body(self, session: TxSession) -> None:
|
||||
try:
|
||||
session.sdr._stream_tx(lambda n: self._tx_callback(session, n))
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
logger.exception("TX stream crashed")
|
||||
self._schedule(self._send_tx_status(session.app_id, "error", "tx stream crashed"))
|
||||
# Schedule both the error frame and session teardown on the loop
|
||||
# so ``self._tx`` clears, subsequent binary frames are rejected,
|
||||
# and the SDR handle is released.
|
||||
self._schedule(self._tx_crash_teardown(session, str(exc)))
|
||||
|
||||
def _tx_callback(self, session: TxSession, num_samples) -> np.ndarray:
|
||||
n = int(num_samples)
|
||||
|
|
@ -561,6 +570,18 @@ class Streamer:
|
|||
return
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
async def _tx_crash_teardown(self, session: TxSession, message: str) -> None:
|
||||
# Called from the executor thread via _schedule when _stream_tx raises.
|
||||
# Emit the error, mark stopped, drain the queue, release the SDR.
|
||||
await self._send_tx_status(session.app_id, "error", f"tx stream crashed: {message}")
|
||||
if self._tx is not session:
|
||||
return
|
||||
session.stop_event.set()
|
||||
self._drain_tx_queue(session)
|
||||
self._close_session_sdr(session)
|
||||
if self._tx is session:
|
||||
self._tx = None
|
||||
|
||||
async def _teardown_tx_after_underrun(self, session: TxSession) -> None:
|
||||
if self._tx is not session:
|
||||
return
|
||||
|
|
@ -643,13 +664,44 @@ _CONFIG_ATTR_MAP = {
|
|||
}
|
||||
|
||||
|
||||
def _is_stub_setter(method: Any) -> bool:
|
||||
"""True when *method* is an unimplemented base-class stub.
|
||||
|
||||
The ``SDR`` abstract base defines ``set_rx_sample_rate`` / ``set_tx_gain``
|
||||
etc. as zero-argument ``NotImplementedError`` stubs. A driver (Pluto) that
|
||||
actually transmits overrides them with a real ``(value, ...)`` signature.
|
||||
Comparing ``__qualname__`` against ``SDR.`` lets us skip the stubs cheaply.
|
||||
"""
|
||||
return getattr(method, "__qualname__", "").startswith("SDR.")
|
||||
|
||||
|
||||
def _apply_sdr_config(sdr: Any, cfg: dict) -> None:
|
||||
"""Apply a radio_config dict to an SDR, trying multiple attribute aliases."""
|
||||
"""Apply a radio_config dict to an SDR.
|
||||
|
||||
Prefers ``sdr.set_<attr>(value)`` when the driver implements it — Pluto's
|
||||
setters take ``_param_lock``, so routing through them keeps concurrent
|
||||
RX + TX reconfigures from racing on shared native attributes. Falls back
|
||||
to ``setattr`` for drivers (MockSDR, tests) that don't override the
|
||||
base-class stubs.
|
||||
"""
|
||||
for key, value in cfg.items():
|
||||
if value is None:
|
||||
continue
|
||||
attrs = _CONFIG_ATTR_MAP.get(key, (key,))
|
||||
applied = False
|
||||
for attr in attrs:
|
||||
setter = getattr(sdr, f"set_{attr}", None)
|
||||
if callable(setter) and not _is_stub_setter(setter):
|
||||
try:
|
||||
setter(value)
|
||||
applied = True
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.debug("set_%s(%r) failed: %s", attr, value, exc)
|
||||
# Fall through to setattr; some drivers may partially
|
||||
# implement setters.
|
||||
if applied:
|
||||
continue
|
||||
for attr in attrs:
|
||||
if hasattr(sdr, attr):
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -384,7 +384,10 @@ class Pluto(SDR):
|
|||
self._enable_tx = True
|
||||
while self._enable_tx is True:
|
||||
buffer = self._convert_tx_samples(callback(self.tx_buffer_size))
|
||||
self.radio.tx(buffer[0])
|
||||
# pyadi-iio's ``radio.tx`` auto-wraps single-channel 1-D input.
|
||||
# Indexing ``buffer[0]`` was a latent bug for callbacks that
|
||||
# returned 1-D samples (scalar → TypeError inside pyadi).
|
||||
self.radio.tx(buffer)
|
||||
|
||||
def set_rx_center_frequency(self, center_frequency):
|
||||
"""
|
||||
|
|
@ -514,6 +517,11 @@ class Pluto(SDR):
|
|||
raise SDRError(e)
|
||||
|
||||
def set_tx_center_frequency(self, center_frequency):
|
||||
# ``adi.Pluto`` exposes one radio handle shared between RX and TX; concurrent
|
||||
# RX + TX sessions (see the agent ``_SdrRegistry``) may call RX and TX
|
||||
# setters at the same time. Serialize with ``_param_lock`` — RX setters hold
|
||||
# the same reentrant lock — so native attribute writes don't interleave.
|
||||
with self._param_lock:
|
||||
if center_frequency < 70e6 or center_frequency > 6e9:
|
||||
raise SDRParameterError(
|
||||
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
|
||||
|
|
@ -534,6 +542,10 @@ class Pluto(SDR):
|
|||
)
|
||||
|
||||
def set_tx_sample_rate(self, sample_rate):
|
||||
# ``self.radio.sample_rate`` is shared between RX and TX on Pluto — RX's
|
||||
# ``set_rx_sample_rate`` writes the same native attribute. Hold ``_param_lock``
|
||||
# so full-duplex sessions can't interleave writes.
|
||||
with self._param_lock:
|
||||
min_rate, max_rate = 65.1e3, 61.44e6
|
||||
if sample_rate < min_rate or sample_rate > max_rate:
|
||||
raise SDRParameterError(
|
||||
|
|
@ -553,6 +565,8 @@ class Pluto(SDR):
|
|||
)
|
||||
|
||||
def set_tx_gain(self, gain, channel=0, gain_mode="absolute"):
|
||||
# Serialize with RX setters: see ``set_tx_sample_rate`` above.
|
||||
with self._param_lock:
|
||||
tx_gain_min = -89
|
||||
tx_gain_max = 0
|
||||
|
||||
|
|
|
|||
|
|
@ -44,3 +44,30 @@ def test_heartbeat_payload_sessions_field():
|
|||
sessions = {"rx": {"app_id": "a", "state": "streaming"}}
|
||||
p = hardware.heartbeat_payload(status="streaming", app_id="a", sessions=sessions)
|
||||
assert p["sessions"] == sessions
|
||||
|
||||
|
||||
def test_heartbeat_payload_surfaces_tx_caps_when_enabled():
    """When TX is enabled, every configured interlock cap appears verbatim in
    the heartbeat payload so the hub can pre-filter UI controls."""
    from ria_toolkit_oss.agent.config import AgentConfig

    cfg = AgentConfig(
        tx_enabled=True,
        tx_max_gain_db=-10.0,
        tx_max_duration_s=60.0,
        tx_allowed_freq_ranges=[[2.4e9, 2.5e9], [5.7e9, 5.8e9]],
    )
    p = hardware.heartbeat_payload(cfg=cfg)
    # Values must round-trip as floats, ranges as [lo, hi] pairs.
    assert p["tx_max_gain_db"] == -10.0
    assert p["tx_max_duration_s"] == 60.0
    assert p["tx_allowed_freq_ranges"] == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]]
|
||||
|
||||
|
||||
def test_heartbeat_payload_omits_caps_when_tx_disabled():
    """Interlock caps must be absent from the heartbeat when TX is off."""
    from ria_toolkit_oss.agent.config import AgentConfig

    # Caps set but tx_enabled=False — don't leak them; they're only meaningful
    # when the hub can attempt a tx_start.
    cfg = AgentConfig(tx_enabled=False, tx_max_gain_db=-10.0)
    p = hardware.heartbeat_payload(cfg=cfg)
    assert "tx_max_gain_db" not in p
    assert "tx_max_duration_s" not in p
    assert "tx_allowed_freq_ranges" not in p
|
||||
|
|
|
|||
210
tests/agent/test_param_lock_contention.py
Normal file
210
tests/agent/test_param_lock_contention.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
"""Step-A6 (Pluto lock audit) coverage.
|
||||
|
||||
Verifies the two invariants the handoff doc calls for when RX and TX run
|
||||
concurrently on one shared SDR handle:
|
||||
|
||||
1. ``_param_lock`` actually serializes concurrent RX + TX setter calls — the
|
||||
spec's §A6 acceptance criterion is *"``_param_lock`` instrumented for
|
||||
contention"*. We drive parallel ``set_{rx,tx}_sample_rate`` calls through
|
||||
the lock and assert it's hit often enough to prove both paths fight for it.
|
||||
2. Under a sustained full-duplex session (RX capturing + TX transmitting on
|
||||
one ``(device, identifier)``), no setter write is dropped and no exception
|
||||
escapes the executor — i.e., the shared-handle assumption holds. Runs
|
||||
against ``MockSDR`` per the spec; the real Pluto driver now takes the
|
||||
same lock on its TX setters so the production code path is isomorphic.
|
||||
|
||||
The stress window is 2 seconds by default — the handoff mentions 30 s but
|
||||
that's impractical in CI. Set ``RIA_LOCK_STRESS_S`` to override.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ria_toolkit_oss.agent.config import AgentConfig
|
||||
from ria_toolkit_oss.agent.streamer import Streamer
|
||||
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||
|
||||
|
||||
_STRESS_S = float(os.environ.get("RIA_LOCK_STRESS_S", "2.0"))
|
||||
|
||||
|
||||
class InstrumentedMockSDR(MockSDR):
    """MockSDR that counts lock acquisitions and exposes a real ``_param_lock``.

    ``_param_lock`` is inherited from ``SDR`` as a reentrant lock; we wrap it
    with a counter that records every time RX or TX setters grab it, so the
    test can assert real contention rather than just "the code compiles".
    """

    def __init__(self, buffer_size: int):
        super().__init__(buffer_size=buffer_size)
        # Per-path hit counters; incremented inside the lock by the setters
        # below. NOTE(review): ``param_lock_hits`` itself is bumped *outside*
        # the underlying lock (in __enter__/acquire before acquisition), so
        # two racing += could in principle lose an update — the test only
        # asserts a lower bound, so this is tolerable, but worth confirming.
        self.rx_lock_hits = 0
        self.tx_lock_hits = 0
        self.param_lock_hits = 0
        # Shadow lock that increments a counter each time __enter__ fires.
        real_lock = self._param_lock

        # Alias so the inner class can reach the instance without fighting
        # over the name ``self``.
        test = self

        class CountingLock:
            def __enter__(self_inner):
                test.param_lock_hits += 1
                real_lock.acquire()
                return self_inner

            def __exit__(self_inner, *a):
                real_lock.release()
                # Never suppress exceptions raised inside the with-block.
                return False

            # ``threading.RLock`` interop for any code that calls acquire/release directly.
            def acquire(self_inner, *a, **k):
                test.param_lock_hits += 1
                return real_lock.acquire(*a, **k)

            def release(self_inner):
                return real_lock.release()

        self._param_lock = CountingLock()

    # The MockSDR doesn't ship RX setter methods that hit the lock — override
    # ``sample_rate`` / ``center_freq`` / ``gain`` writes to route through the
    # same lock the real Pluto driver uses, so this test faithfully models the
    # production contention path.
    def set_rx_sample_rate(self, sample_rate):
        with self._param_lock:
            self.rx_lock_hits += 1
            self.rx_sample_rate = float(sample_rate)
            self.sample_rate = self.rx_sample_rate

    def set_tx_sample_rate(self, sample_rate):
        with self._param_lock:
            self.tx_lock_hits += 1
            self.tx_sample_rate = float(sample_rate)
            # Mirror Pluto: both RX and TX write the same native attribute.
            self.sample_rate = self.tx_sample_rate
|
||||
|
||||
|
||||
class FakeWs:
    """In-memory WebSocket stand-in: records every JSON and binary send."""

    json_sent: list[dict]
    bytes_sent: list[bytes]

    def __init__(self):
        self.json_sent = []
        self.bytes_sent = []

    async def send_json(self, p):
        """Capture a JSON frame the agent would have pushed to the hub."""
        self.json_sent.append(p)

    async def send_bytes(self, b):
        """Capture a binary (IQ) frame."""
        self.bytes_sent.append(b)
|
||||
|
||||
|
||||
def _iq_frame(samples: np.ndarray) -> bytes:
|
||||
interleaved = np.empty(samples.size * 2, dtype=np.float32)
|
||||
interleaved[0::2] = samples.real
|
||||
interleaved[1::2] = samples.imag
|
||||
return interleaved.tobytes()
|
||||
|
||||
|
||||
def test_param_lock_contended_under_concurrent_setters():
    """Hammer the RX and TX sample-rate setters from two threads and assert
    both lock paths fire.

    If either setter bypassed ``_param_lock`` its hit counter would stay at
    zero, so non-zero counts on both sides prove real contention on the
    shared lock — not just "the code compiles".
    """
    sdr = InstrumentedMockSDR(buffer_size=16)
    stop = threading.Event()

    def rx_setter():
        # Spin RX-side writes until told to stop; the modulo keeps the
        # values bounded so the mock never sees a pathological rate.
        i = 0
        while not stop.is_set():
            sdr.set_rx_sample_rate(1_000_000 + (i % 1000))
            i += 1

    def tx_setter():
        i = 0
        while not stop.is_set():
            sdr.set_tx_sample_rate(2_000_000 + (i % 1000))
            i += 1

    # daemon=True so a wedged setter can't hang interpreter shutdown if the
    # joins below are never reached.
    t1 = threading.Thread(target=rx_setter, daemon=True)
    t2 = threading.Thread(target=tx_setter, daemon=True)
    t1.start()
    t2.start()
    # Honor RIA_LOCK_STRESS_S as the module docstring promises: the previous
    # ``min(_STRESS_S, 2.0)`` silently ignored any override longer than 2 s.
    time.sleep(_STRESS_S)
    stop.set()
    t1.join()
    t2.join()

    assert sdr.rx_lock_hits > 100, f"RX setter barely ran: {sdr.rx_lock_hits}"
    assert sdr.tx_lock_hits > 100, f"TX setter barely ran: {sdr.tx_lock_hits}"
    # Every setter call should have passed through _param_lock exactly once.
    assert sdr.param_lock_hits >= sdr.rx_lock_hits + sdr.tx_lock_hits
|
||||
|
||||
|
||||
def test_full_duplex_stays_healthy_over_stress_window():
    """Start RX + TX on one shared SDR and drive both paths for ``_STRESS_S``
    seconds, pushing binary frames and emitting ``tx_configure`` mid-stream.

    The session must survive, deliver buffers in both directions, and leave
    the registry clean on shutdown.
    """
    BUF = 32
    sdr = InstrumentedMockSDR(buffer_size=BUF)

    async def scenario():
        ws = FakeWs()
        s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))

        # Open RX and TX on the same (device, identifier) → shared SDR handle.
        await s.on_message(
            {"type": "start", "app_id": "app-1",
             "radio_config": {"device": "mock", "buffer_size": BUF}}
        )
        await s.on_message(
            {"type": "tx_start", "app_id": "app-1",
             "radio_config": {
                 "device": "mock", "buffer_size": BUF,
                 "tx_sample_rate": 1_000_000,
                 "tx_center_frequency": 2.45e9,
                 "tx_gain": -20,
                 "underrun_policy": "zero",
             }}
        )

        marker = np.arange(BUF, dtype=np.complex64) + 1
        deadline = time.monotonic() + _STRESS_S
        i = 0
        while time.monotonic() < deadline:
            await s.on_binary(_iq_frame(marker))
            if i % 8 == 0:
                # Mid-stream parameter reconfiguration touches _apply_sdr_config,
                # which routes through the same setters the stress test above
                # verifies.
                await s.on_message(
                    {"type": "tx_configure", "app_id": "app-1",
                     "radio_config": {"tx_sample_rate": 1_000_000 + i}}
                )
                await s.on_message(
                    {"type": "configure", "app_id": "app-1",
                     "radio_config": {"sample_rate": 2_000_000 + i}}
                )
            i += 1
            await asyncio.sleep(0.005)

        await s.on_message({"type": "tx_stop", "app_id": "app-1"})
        await s.on_message({"type": "stop", "app_id": "app-1"})
        return ws, s

    ws, s = asyncio.run(scenario())

    # No error frame leaked out. The old filter required state == "error" for
    # BOTH message types, which let plain ``error`` frames (no ``state`` key)
    # slip through unasserted; catch them explicitly.
    errors = [
        m for m in ws.json_sent
        if m.get("type") == "error"
        or (m.get("type") == "tx_status" and m.get("state") == "error")
    ]
    assert errors == [], f"Unexpected error frames: {errors}"
    # RX produced IQ frames and TX's callback ran — heartbeat-level contention
    # check: both setter paths were hit at least once during configure dispatch.
    assert ws.bytes_sent, "RX produced no IQ frames"
    assert sdr.param_lock_hits > 0
|
||||
# Sessions cleaned up; registry drained.
|
||||
assert s._tx is None
|
||||
assert s._rx is None
|
||||
assert s._registry.refcount(("mock", None)) == 0
|
||||
|
|
@ -139,6 +139,21 @@ def test_unknown_message_type_is_ignored():
|
|||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_tx_data_available_is_a_silent_noop():
    """The hub emits this as a keepalive: accept and ignore it without a
    WARNING and without treating it as an error."""

    async def scenario():
        fake = FakeWs()
        streamer = Streamer(ws=fake, sdr_factory=_factory)
        await streamer.on_message({"type": "tx_data_available", "app_id": "x"})
        return fake

    fake = asyncio.run(scenario())
    # Nothing should have gone out in either direction.
    assert fake.json_sent == []
    assert fake.bytes_sent == []
||||
def test_registry_shares_sdr_across_start_stop_cycles():
|
||||
# Two sequential start/stop cycles with the same (device, identifier)
|
||||
# should hit the registry's cache path rather than constructing a new SDR.
|
||||
|
|
|
|||
|
|
@ -1,11 +1,14 @@
|
|||
"""Reconnect + heartbeat timing against a real local websockets server."""
|
||||
"""Reconnect + heartbeat + malformed-control-frame behavior.
|
||||
|
||||
Binary-frame delivery lives in ``test_ws_client_binary.py`` to match the
|
||||
test matrix spelled out in ``Agent TX Streaming Handoff.md`` §A7.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import websockets
|
||||
|
||||
from ria_toolkit_oss.agent.ws_client import WsClient
|
||||
|
|
@ -113,109 +116,6 @@ def test_reconnects_after_server_drop():
|
|||
assert n >= 2
|
||||
|
||||
|
||||
def test_binary_frame_forwarded_to_handler():
    """A server-sent binary frame must reach the ``on_binary`` coroutine intact.

    Fix: the original created ``done = asyncio.Event()`` and set it in the
    handler, but nothing ever awaited it — a dead synchronization object,
    removed here.
    """
    payload = bytes(range(128))

    async def scenario():
        received: list[bytes] = []

        async def handler(ws):
            # Push one binary frame, then hold the socket open so the client
            # isn't racing a server-side close while processing the frame.
            await ws.send(payload)
            try:
                await ws.wait_closed()
            except Exception:
                pass

        server, port = await _open_server(handler)
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )

            async def on_bin(data):
                received.append(data)

            task = asyncio.create_task(
                client.run(
                    on_message=lambda _m: asyncio.sleep(0),
                    heartbeat=lambda: {"type": "heartbeat"},
                    on_binary=on_bin,
                )
            )
            # Poll (max ~1 s) rather than sleeping a fixed interval.
            for _ in range(50):
                if received:
                    break
                await asyncio.sleep(0.02)
            client.stop()
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                pass
        finally:
            server.close()
            await server.wait_closed()
        return received

    received = asyncio.run(scenario())
    assert received == [payload]
|
||||
def test_binary_frame_dropped_when_no_handler():
    """Regression guard: with no ``on_binary`` supplied, server-sent binary
    frames are silently dropped while JSON control frames keep flowing."""

    async def scenario():
        shutdown_errors: list[Exception] = []
        control_frames: list[dict] = []

        async def handler(ws):
            # One binary frame (should be discarded) followed by one JSON
            # frame (should be delivered), then hold the socket open.
            await ws.send(b"\x00\x01\x02\x03")
            await ws.send(json.dumps({"type": "ping"}))
            try:
                await ws.wait_closed()
            except Exception:
                pass

        server, port = await _open_server(handler)
        try:
            ws_client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )

            async def record(frame):
                control_frames.append(frame)

            run_task = asyncio.create_task(
                ws_client.run(on_message=record, heartbeat=lambda: {"type": "heartbeat"})
            )
            attempts = 50
            while attempts and not control_frames:
                attempts -= 1
                await asyncio.sleep(0.02)
            ws_client.stop()
            run_task.cancel()
            try:
                await run_task
            except (asyncio.CancelledError, Exception) as exc:
                shutdown_errors.append(exc)
        finally:
            server.close()
            await server.wait_closed()
        return control_frames, shutdown_errors

    messages, crashes = asyncio.run(scenario())
    # JSON still delivered; binary silently dropped; no uncaught crash.
    assert messages and messages[0] == {"type": "ping"}
|
||||
def test_malformed_control_frame_does_not_crash():
|
||||
async def scenario():
|
||||
handled: list[dict] = []
|
||||
|
|
|
|||
186
tests/agent/test_ws_client_binary.py
Normal file
186
tests/agent/test_ws_client_binary.py
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
"""Binary-frame delivery on the hub → agent WebSocket.
|
||||
|
||||
Named to match the test matrix in ``Agent TX Streaming Handoff.md`` §A7.
|
||||
Exercises:
|
||||
|
||||
- Binary frames are forwarded to an ``on_binary`` coroutine when supplied.
|
||||
- Binary frames are silently dropped (no crash) when ``on_binary`` is omitted,
|
||||
preserving the pre-TX behavior for RX-only deployments.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import websockets
|
||||
|
||||
from ria_toolkit_oss.agent.ws_client import WsClient
|
||||
|
||||
|
||||
async def _open_server(handler):
    """Start a throwaway websockets server on an OS-assigned port.

    Returns ``(server, port)``; the caller is responsible for closing the
    server when done.
    """
    # Port 0 lets the OS pick a free port; read it back off the bound socket.
    server = await websockets.serve(handler, "127.0.0.1", 0)
    port = server.sockets[0].getsockname()[1]
    return server, port
|
||||
def test_binary_frame_forwarded_to_handler():
    """A server-sent binary frame must reach the ``on_binary`` coroutine intact.

    Fix: the original created ``done = asyncio.Event()`` and set it in the
    handler, but nothing ever awaited it — a dead synchronization object,
    removed here.
    """
    payload = bytes(range(128))

    async def scenario():
        received: list[bytes] = []

        async def handler(ws):
            # Push one binary frame, then hold the socket open so the client
            # isn't racing a server-side close while processing the frame.
            await ws.send(payload)
            try:
                await ws.wait_closed()
            except Exception:
                pass

        server, port = await _open_server(handler)
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )

            async def on_bin(data):
                received.append(data)

            task = asyncio.create_task(
                client.run(
                    on_message=lambda _m: asyncio.sleep(0),
                    heartbeat=lambda: {"type": "heartbeat"},
                    on_binary=on_bin,
                )
            )
            # Poll (max ~1 s) rather than sleeping a fixed interval.
            for _ in range(50):
                if received:
                    break
                await asyncio.sleep(0.02)
            client.stop()
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                pass
        finally:
            server.close()
            await server.wait_closed()
        return received

    received = asyncio.run(scenario())
    assert received == [payload]
|
||||
def test_binary_frame_dropped_when_no_handler():
    """With no ``on_binary`` supplied, server-sent binary frames are silently
    dropped while JSON control frames keep flowing — the pre-TX behavior."""

    async def scenario():
        shutdown_errors: list[Exception] = []

        async def handler(ws):
            # One binary frame (should be discarded) followed by one JSON
            # frame (should be delivered), then hold the socket open.
            await ws.send(b"\x00\x01\x02\x03")
            await ws.send(json.dumps({"type": "ping"}))
            try:
                await ws.wait_closed()
            except Exception:
                pass

        control_frames: list[dict] = []
        server, port = await _open_server(handler)
        try:
            ws_client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )

            async def record(frame):
                control_frames.append(frame)

            run_task = asyncio.create_task(
                ws_client.run(on_message=record, heartbeat=lambda: {"type": "heartbeat"})
            )
            attempts = 50
            while attempts and not control_frames:
                attempts -= 1
                await asyncio.sleep(0.02)
            ws_client.stop()
            run_task.cancel()
            try:
                await run_task
            except (asyncio.CancelledError, Exception) as exc:
                shutdown_errors.append(exc)
        finally:
            server.close()
            await server.wait_closed()
        return control_frames, shutdown_errors

    messages, _ = asyncio.run(scenario())
    assert messages and messages[0] == {"type": "ping"}
||||
def test_on_binary_exception_does_not_kill_connection():
    """A buggy ``on_binary`` raises mid-stream; the WS loop keeps accepting frames."""

    async def scenario():
        # Counter rather than a list: we only care how many times the
        # (deliberately crashing) binary handler was invoked.
        delivered_binary = 0
        delivered_control: list[dict] = []

        async def handler(ws):
            # Two binary frames (each should trigger the handler's exception),
            # then a JSON frame that must still arrive if the loop survived.
            await ws.send(b"\x10\x20\x30")
            await ws.send(b"\x40\x50\x60")
            await ws.send(json.dumps({"type": "ping"}))
            try:
                # Hold the socket open until the client side closes it.
                await ws.wait_closed()
            except Exception:
                pass

        server, port = await _open_server(handler)
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )

            async def on_bin(data):
                # Count the delivery, then blow up — simulates a buggy handler.
                nonlocal delivered_binary
                delivered_binary += 1
                raise RuntimeError("handler broke")

            async def on_msg(m):
                delivered_control.append(m)

            task = asyncio.create_task(
                client.run(
                    on_message=on_msg,
                    heartbeat=lambda: {"type": "heartbeat"},
                    on_binary=on_bin,
                )
            )
            # Poll (max ~1.2 s) until the trailing JSON frame lands.
            for _ in range(60):
                if delivered_control:
                    break
                await asyncio.sleep(0.02)
            client.stop()
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                # Cancellation fallout during teardown is expected and benign.
                pass
        finally:
            server.close()
            await server.wait_closed()
        return delivered_binary, delivered_control

    bins, ctrls = asyncio.run(scenario())
    # Both binary frames were delivered to the (crashing) handler.
    assert bins == 2
    # The subsequent JSON frame still arrived — loop didn't die on the exceptions.
    assert ctrls and ctrls[0] == {"type": "ping"}
||||
Loading…
Reference in New Issue
Block a user