ria-toolkit-oss/tests/agent/test_ws_client.py
2026-04-16 11:13:43 -04:00

265 lines
7.5 KiB
Python

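"""WsClient behavior against a real local websockets server.

Covers the heartbeat sent on connect, reconnecting after a server-side drop,
binary-frame forwarding/dropping, and tolerance of malformed text frames.
"""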
from __future__ import annotations
import asyncio
import json
import pytest
import websockets
from ria_toolkit_oss.agent.ws_client import WsClient
async def _recv_json(ws) -> dict:
    """Await the next frame on *ws* and return it decoded with ``json.loads``."""
    return json.loads(await ws.recv())
async def _open_server(handler):
    """Start a websockets server for *handler* on an ephemeral localhost port.

    Returns ``(server, port)``. The tests that use this helper are
    responsible for calling ``server.close()`` / ``server.wait_closed()``.
    """
    # Port 0 lets the OS pick any free port; the port actually chosen is
    # read back from the first listening socket.
    server = await websockets.serve(handler, "127.0.0.1", 0)
    listener = server.sockets[0]
    bound_port = listener.getsockname()[1]
    return server, bound_port
def test_heartbeat_sent_on_connect():
    """The first frame the server receives from a new client is a heartbeat."""

    async def scenario():
        frames: list[dict] = []
        server_saw_client = asyncio.Event()

        async def handler(ws):
            server_saw_client.set()
            frames.append(await _recv_json(ws))

        server, port = await _open_server(handler)
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=0.05,
                reconnect_pause=0.05,
            )
            runner = asyncio.create_task(
                client.run(
                    on_message=lambda _m: asyncio.sleep(0),
                    heartbeat=lambda: {"type": "heartbeat", "n": 1},
                )
            )
            await asyncio.wait_for(server_saw_client.wait(), timeout=2.0)

            # Poll for up to ~1 s (50 x 20 ms) for the first frame to arrive.
            attempts = 0
            while not frames and attempts < 50:
                await asyncio.sleep(0.02)
                attempts += 1

            client.stop()
            runner.cancel()
            try:
                await runner
            except (asyncio.CancelledError, Exception):
                pass  # shutdown noise is irrelevant to this test
        finally:
            server.close()
            await server.wait_closed()
        return frames

    frames = asyncio.run(scenario())
    assert frames and frames[0]["type"] == "heartbeat"
def test_reconnects_after_server_drop():
    """When the server closes the first connection, the client connects again."""

    async def scenario():
        hits = 0
        first_closed = asyncio.Event()

        async def handler(ws):
            nonlocal hits
            hits += 1
            if hits != 1:
                # Any later connection: wait for one frame (or a close).
                try:
                    await ws.recv()
                except Exception:
                    pass
                return
            # First connection: drop it from the server side.
            await ws.close()
            first_closed.set()

        server, port = await _open_server(handler)
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )
            runner = asyncio.create_task(
                client.run(
                    on_message=lambda _m: asyncio.sleep(0),
                    heartbeat=lambda: {"type": "heartbeat"},
                )
            )
            await asyncio.wait_for(first_closed.wait(), timeout=2.0)

            # Give the client up to ~2 s (100 x 20 ms) to come back.
            attempts = 0
            while hits < 2 and attempts < 100:
                await asyncio.sleep(0.02)
                attempts += 1

            client.stop()
            runner.cancel()
            try:
                await runner
            except (asyncio.CancelledError, Exception):
                pass  # shutdown noise is irrelevant to this test
        finally:
            server.close()
            await server.wait_closed()
        return hits

    assert asyncio.run(scenario()) >= 2
def test_binary_frame_forwarded_to_handler():
    """A binary frame sent by the server reaches ``on_binary`` byte-for-byte."""
    payload = bytes(range(128))

    async def scenario():
        received: list[bytes] = []

        async def handler(ws):
            await ws.send(payload)
            # Hold the connection open until the client disconnects; a new
            # connection would run this handler again and send a second copy.
            try:
                await ws.wait_closed()
            except Exception:
                pass

        server, port = await _open_server(handler)
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )

            async def on_bin(data):
                received.append(data)

            task = asyncio.create_task(
                client.run(
                    on_message=lambda _m: asyncio.sleep(0),
                    heartbeat=lambda: {"type": "heartbeat"},
                    on_binary=on_bin,
                )
            )
            # Poll for up to ~1 s (50 x 20 ms) for the frame to be delivered.
            for _ in range(50):
                if received:
                    break
                await asyncio.sleep(0.02)
            client.stop()
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                pass  # shutdown noise is irrelevant to this test
        finally:
            server.close()
            await server.wait_closed()
        return received

    received = asyncio.run(scenario())
    assert received == [payload]
def test_binary_frame_dropped_when_no_handler():
    """Without ``on_binary``, server binary frames are dropped and JSON still flows.

    Regression guard: the existing behavior (drop server-sent binary) must be
    preserved when ``on_binary`` is not supplied, and the run loop must not
    die on the binary frame.
    """

    async def scenario():
        messages: list[dict] = []
        # Exceptions that ended the run task *before* we asked it to stop.
        crashes: list[BaseException] = []

        async def handler(ws):
            await ws.send(b"\x00\x01\x02\x03")
            await ws.send(json.dumps({"type": "ping"}))
            try:
                await ws.wait_closed()
            except Exception:
                pass

        server, port = await _open_server(handler)
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )

            async def on_msg(m):
                messages.append(m)

            task = asyncio.create_task(
                client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"})
            )
            # Poll for up to ~1 s (50 x 20 ms); stop early if the run loop exits.
            for _ in range(50):
                if messages or task.done():
                    break
                await asyncio.sleep(0.02)

            # A task that finished with an exception on its own is the crash
            # this test guards against. The CancelledError we trigger below is
            # expected and must not be counted.
            if task.done() and not task.cancelled():
                exc = task.exception()
                if exc is not None:
                    crashes.append(exc)

            client.stop()
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                pass  # shutdown noise (incl. our own cancel) is not a crash
        finally:
            server.close()
            await server.wait_closed()
        return messages, crashes

    messages, crashes = asyncio.run(scenario())
    # No uncaught crash; JSON still delivered; binary silently dropped.
    assert not crashes, f"WsClient.run crashed: {crashes!r}"
    assert messages and messages[0] == {"type": "ping"}
def test_malformed_control_frame_does_not_crash():
    """A non-JSON text frame is skipped and the following valid frame is handled."""

    async def scenario():
        handled: list[dict] = []
        # Exceptions that ended the run task *before* we asked it to stop.
        crashes: list[BaseException] = []

        async def handler(ws):
            await ws.send("not json")
            await ws.send(json.dumps({"type": "ping"}))
            try:
                await ws.wait_closed()
            except Exception:
                pass

        server, port = await _open_server(handler)
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )

            async def on_msg(m):
                handled.append(m)

            task = asyncio.create_task(
                client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"})
            )
            # Poll for up to ~1 s (50 x 20 ms); stop early if the run loop exits.
            for _ in range(50):
                if handled or task.done():
                    break
                await asyncio.sleep(0.02)

            # The run task dying on its own (not via our cancel below) is a crash.
            if task.done() and not task.cancelled():
                exc = task.exception()
                if exc is not None:
                    crashes.append(exc)

            client.stop()
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                pass  # shutdown noise (incl. our own cancel) is not a crash
        finally:
            server.close()
            await server.wait_closed()
        return handled, crashes

    handled, crashes = asyncio.run(scenario())
    assert not crashes, f"WsClient.run crashed: {crashes!r}"
    assert handled and handled[0] == {"type": "ping"}