J
2026-04-16 15:12:56 -04:00
|
|
"""Reconnect + heartbeat + malformed-control-frame behavior.

Binary-frame delivery lives in ``test_ws_client_binary.py`` to match the
test matrix spelled out in ``Agent TX Streaming Handoff.md`` §A7.
"""
|
J
2026-04-13 11:48:15 -04:00
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
import json
|
|
|
|
|
|
|
|
|
|
import websockets
|
|
|
|
|
|
|
|
|
|
from ria_toolkit_oss.agent.ws_client import WsClient
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _recv_json(ws) -> dict:
    """Await the next frame from *ws* and return it parsed as JSON.

    *ws* only needs an awaitable ``recv()`` method whose result
    ``json.loads`` accepts.
    """
    return json.loads(await ws.recv())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _open_server(handler):
    """Start a WebSocket server for *handler* on 127.0.0.1 and report its port.

    Returns:
        A ``(server, port)`` tuple: the object returned by
        ``websockets.serve`` and the port number of its first listening
        socket.
    """
    # Binding to port 0 requests an ephemeral port, so the real port has to
    # be read back from the listening socket afterwards.
    # NOTE(review): the original comment here also stated that websockets 13
    # ignores extra positional args — not verifiable from this file.
    ws_server = await websockets.serve(handler, "127.0.0.1", 0)
    listening_socket = ws_server.sockets[0]
    bound_port = listening_socket.getsockname()[1]
    return ws_server, bound_port
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_heartbeat_sent_on_connect():
    """The first frame the server receives from the client is a heartbeat.

    A local server records the first JSON frame of each connection. The
    client is started with a 50 ms heartbeat interval and the test asserts
    that the first recorded frame has ``"type": "heartbeat"``.
    """

    async def scenario():
        received: list[dict] = []
        connected = asyncio.Event()

        async def handler(ws):
            # Signal the connection first so the test can tell "never
            # connected" apart from "connected but sent nothing".
            connected.set()
            msg = await _recv_json(ws)
            received.append(msg)

        server, port = await _open_server(handler)
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=0.05,
                reconnect_pause=0.05,
            )
            task = asyncio.create_task(
                client.run(on_message=lambda _m: asyncio.sleep(0), heartbeat=lambda: {"type": "heartbeat", "n": 1})
            )
            try:
                await asyncio.wait_for(connected.wait(), timeout=2.0)
                # Poll for up to ~1 s (50 x 20 ms) for the first frame.
                for _ in range(50):
                    if received:
                        break
                    await asyncio.sleep(0.02)
            finally:
                # Always stop the client, even when the wait above timed
                # out, so it is not left running while the server closes.
                client.stop()
                task.cancel()
                try:
                    await task
                except (asyncio.CancelledError, Exception):
                    # Best-effort teardown: whatever the client task raises
                    # while being cancelled is irrelevant to the assertion.
                    pass
        finally:
            server.close()
            await server.wait_closed()
        return received

    received = asyncio.run(scenario())
    assert received and received[0]["type"] == "heartbeat"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_reconnects_after_server_drop():
    """The client opens a second connection after the server closes the first.

    The server handler counts connections and closes the first one at once.
    The test asserts that at least two connections were observed.
    """

    async def scenario():
        connections = 0
        first_dropped = asyncio.Event()

        async def handler(ws):
            nonlocal connections
            connections += 1
            if connections == 1:
                # Drop the first connection straight away to force a
                # reconnect.
                await ws.close()
                first_dropped.set()
            else:
                # Hold later connections open until the peer or the server
                # goes away; the error raised at that point is expected.
                try:
                    await ws.recv()
                except Exception:
                    pass

        server, port = await _open_server(handler)
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                # NOTE(review): presumably long enough that no heartbeat
                # fires during the test — confirm against WsClient.
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )
            task = asyncio.create_task(
                client.run(on_message=lambda _m: asyncio.sleep(0), heartbeat=lambda: {"type": "heartbeat"})
            )
            try:
                await asyncio.wait_for(first_dropped.wait(), timeout=2.0)
                # Poll for up to ~2 s (100 x 20 ms) for the reconnect.
                for _ in range(100):
                    if connections >= 2:
                        break
                    await asyncio.sleep(0.02)
            finally:
                # Always stop the client, even when the wait above timed
                # out, so it is not left reconnecting while the server
                # closes.
                client.stop()
                task.cancel()
                try:
                    await task
                except (asyncio.CancelledError, Exception):
                    # Best-effort teardown: whatever the client task raises
                    # while being cancelled is irrelevant to the assertion.
                    pass
        finally:
            server.close()
            await server.wait_closed()
        return connections

    n = asyncio.run(scenario())
    assert n >= 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_malformed_control_frame_does_not_crash():
    """A non-JSON text frame does not stop later frames reaching ``on_message``.

    The server sends the text ``"not json"`` followed by a valid
    ``{"type": "ping"}`` frame. The test asserts that the first message
    handed to ``on_message`` is the ping.
    """

    async def scenario():
        handled: list[dict] = []
        done = asyncio.Event()

        async def handler(ws):
            await ws.send("not json")
            await ws.send(json.dumps({"type": "ping"}))
            # Both sends have been issued; the test can start polling.
            done.set()
            try:
                await ws.wait_closed()
            except Exception:
                # Errors while waiting for the connection to close are
                # irrelevant to the assertion.
                pass

        server, port = await _open_server(handler)
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                # NOTE(review): presumably long enough that no heartbeat
                # fires during the test — confirm against WsClient.
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )

            async def on_msg(m):
                handled.append(m)

            task = asyncio.create_task(client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"}))
            try:
                # Wait until the server has sent both frames, with the same
                # 2 s bound the sibling tests use for their first wait.
                await asyncio.wait_for(done.wait(), timeout=2.0)
                # Poll for up to ~1 s (50 x 20 ms) for the valid frame.
                for _ in range(50):
                    if handled:
                        break
                    await asyncio.sleep(0.02)
            finally:
                # Always stop the client, even when a wait above timed out,
                # so it is not left running while the server closes.
                client.stop()
                task.cancel()
                try:
                    await task
                except (asyncio.CancelledError, Exception):
                    # Best-effort teardown: whatever the client task raises
                    # while being cancelled is irrelevant to the assertion.
                    pass
        finally:
            server.close()
            await server.wait_closed()
        return handled

    handled = asyncio.run(scenario())
    assert handled and handled[0] == {"type": "ping"}
|