# ria-toolkit-oss/tests/agent/test_tx_underrun.py
"""Underrun policies: pause, zero, repeat."""
from __future__ import annotations
import asyncio
import time
import numpy as np
from ria_toolkit_oss.agent.config import AgentConfig
from ria_toolkit_oss.agent.streamer import Streamer
from ria_toolkit_oss.sdr.mock import MockSDR
class RecordingMockSDR(MockSDR):
    """``MockSDR`` that keeps a private copy of every buffer its TX callback returns.

    Tests inspect ``tx_produced`` afterwards to see exactly what the streamer
    handed to the radio under a given underrun policy.
    """

    def __init__(self, buffer_size: int):
        super().__init__(buffer_size=buffer_size)
        # Snapshots of the callback results, one per invocation, in call order.
        self.tx_produced: list[np.ndarray] = []

    def _stream_tx(self, callback):
        """Blocking TX loop: poll *callback* until ``_enable_tx`` is cleared.

        Each pass calls ``callback(self.rx_buffer_size)``, stores a copy of the
        result (so later mutation by the producer cannot alter the record), and
        then sleeps 5 ms to pace the loop.
        """
        # NOTE(review): the TX path is sized with ``rx_buffer_size``; presumably
        # MockSDR uses a single buffer size for both directions — confirm.
        self._enable_tx = True
        self._tx_initialized = True
        while self._enable_tx:
            chunk = callback(self.rx_buffer_size)
            snapshot = np.asarray(chunk).copy()
            self.tx_produced.append(snapshot)
            time.sleep(0.005)
class FakeWs:
    """In-memory websocket double: records outgoing frames instead of sending them."""

    def __init__(self):
        # Everything "sent" so far, in order, split by frame kind.
        self.json_sent: list[object] = []
        self.bytes_sent: list[object] = []

    async def send_json(self, p):
        """Capture a JSON control payload."""
        outbox = self.json_sent
        outbox.append(p)

    async def send_bytes(self, b):
        """Capture a binary frame."""
        outbox = self.bytes_sent
        outbox.append(b)
def _iq_frame(samples: np.ndarray) -> bytes:
interleaved = np.empty(samples.size * 2, dtype=np.float32)
interleaved[0::2] = samples.real
interleaved[1::2] = samples.imag
return interleaved.tobytes()
def _start_cfg(policy: str, buf: int = 8) -> dict:
return {
"type": "tx_start",
"app_id": "a",
"radio_config": {
"device": "mock",
"buffer_size": buf,
J
2026-04-16 15:38:35 -04:00
"tx_sample_rate": 1_000_000,
J
2026-04-16 11:13:43 -04:00
"tx_gain": -20,
"tx_center_frequency": 2.45e9,
"underrun_policy": policy,
},
}
def test_underrun_pause_stops_session_and_emits_status():
    """``pause`` policy: an underrun emits ``tx_status``/``underrun`` and ends the session."""
    sdr = RecordingMockSDR(buffer_size=8)

    async def poll(predicate, attempts: int) -> None:
        # Check every 10 ms; on timeout just return and let the asserts below fail.
        for _ in range(attempts):
            if predicate():
                return
            await asyncio.sleep(0.01)

    async def scenario():
        ws = FakeWs()
        s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
        try:
            await s.on_message(_start_cfg("pause"))
            # Do not push any buffers. The callback underruns on first tick and
            # the watchdog should emit "underrun" and tear down.
            await poll(
                lambda: any(
                    m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent
                ),
                100,
            )
            await poll(lambda: s._tx is None, 50)
        finally:
            # Safety net: if the teardown under test never happened, clear the mock's
            # TX flag so its ``while self._enable_tx`` loop exits. If the stream runs on
            # an executor thread (presumably — confirm in Streamer), asyncio.run() waits
            # for it at shutdown and a failed teardown would otherwise hang the test.
            sdr._enable_tx = False
        return ws, s

    ws, s = asyncio.run(scenario())
    states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"]
    assert "underrun" in states
    assert s._tx is None
def test_underrun_zero_keeps_session_alive():
    """``zero`` policy: underruns are filled with zeros and the session stays up."""
    sdr = RecordingMockSDR(buffer_size=8)

    async def scenario():
        ws = FakeWs()
        s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
        try:
            await s.on_message(_start_cfg("zero"))
            # Let it produce several underrun-filled buffers.
            await asyncio.sleep(0.08)
            still_alive = s._tx is not None
            await s.on_message({"type": "tx_stop", "app_id": "a"})
        finally:
            # Safety net: if anything above raised (so tx_stop was skipped) or tx_stop
            # failed to stop the stream, clear the mock's TX flag so its loop exits
            # instead of keeping asyncio.run() waiting on the streaming thread.
            sdr._enable_tx = False
        return ws, still_alive

    ws, still_alive = asyncio.run(scenario())
    # No underrun status emitted (policy absorbs it silently).
    assert not any(m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent)
    assert still_alive
    # All produced buffers are zero (no real data was pushed).
    assert sdr.tx_produced, "expected at least one TX callback invocation"
    nonzero = [i for i, b in enumerate(sdr.tx_produced) if np.any(b != 0)]
    assert not nonzero, f"buffers at indices {nonzero} contain non-zero samples"
def test_underrun_repeat_replays_last_buffer():
    """``repeat`` policy: after the queue drains, the last real buffer is replayed."""
    BUF = 8
    sdr = RecordingMockSDR(buffer_size=BUF)
    marker = np.arange(BUF, dtype=np.complex64) + 1  # distinct non-zero buffer

    async def scenario():
        ws = FakeWs()
        s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
        try:
            await s.on_message(_start_cfg("repeat", buf=BUF))
            await s.on_binary(_iq_frame(marker))
            # Give the executor time to consume the real frame + several repeats.
            await asyncio.sleep(0.08)
            await s.on_message({"type": "tx_stop", "app_id": "a"})
        finally:
            # Safety net: if anything above raised (so tx_stop was skipped) or tx_stop
            # failed to stop the stream, clear the mock's TX flag so its loop exits
            # instead of keeping asyncio.run() waiting on the streaming thread.
            sdr._enable_tx = False
        return ws

    ws = asyncio.run(scenario())
    # No underrun status emitted.
    assert not any(m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent)
    # At least two buffers equal to the marker — the real one and ≥1 repeat.
    matching = [b for b in sdr.tx_produced if np.array_equal(b, marker)]
    assert len(matching) >= 2, f"expected ≥2 buffers matching marker, got {len(matching)}"