"""Reconnect + heartbeat timing against a real local websockets server.""" from __future__ import annotations import asyncio import json import pytest import websockets from ria_toolkit_oss.agent.ws_client import WsClient async def _recv_json(ws) -> dict: raw = await ws.recv() return json.loads(raw) async def _open_server(handler): # websockets 13 ignores extra positional args; bind to localhost:0 for an # ephemeral port and return both the server and the port. server = await websockets.serve(handler, "127.0.0.1", 0) port = server.sockets[0].getsockname()[1] return server, port def test_heartbeat_sent_on_connect(): async def scenario(): received: list[dict] = [] connected = asyncio.Event() async def handler(ws): connected.set() msg = await _recv_json(ws) received.append(msg) server, port = await _open_server(handler) try: client = WsClient( f"ws://127.0.0.1:{port}", token="", heartbeat_interval=0.05, reconnect_pause=0.05, ) task = asyncio.create_task( client.run(on_message=lambda _m: asyncio.sleep(0), heartbeat=lambda: {"type": "heartbeat", "n": 1}) ) await asyncio.wait_for(connected.wait(), timeout=2.0) for _ in range(50): if received: break await asyncio.sleep(0.02) client.stop() task.cancel() try: await task except (asyncio.CancelledError, Exception): pass finally: server.close() await server.wait_closed() return received received = asyncio.run(scenario()) assert received and received[0]["type"] == "heartbeat" def test_reconnects_after_server_drop(): async def scenario(): connections = 0 first_dropped = asyncio.Event() async def handler(ws): nonlocal connections connections += 1 if connections == 1: await ws.close() first_dropped.set() else: try: await ws.recv() except Exception: pass server, port = await _open_server(handler) try: client = WsClient( f"ws://127.0.0.1:{port}", token="", heartbeat_interval=10.0, reconnect_pause=0.05, ) task = asyncio.create_task( client.run(on_message=lambda _m: asyncio.sleep(0), heartbeat=lambda: {"type": "heartbeat"}) ) await asyncio.wait_for(first_dropped.wait(), timeout=2.0) for _ in range(100): if connections >= 2: break await asyncio.sleep(0.02) client.stop() task.cancel() try: await task except (asyncio.CancelledError, Exception): pass finally: server.close() await server.wait_closed() return connections n = asyncio.run(scenario()) assert n >= 2 def test_binary_frame_forwarded_to_handler(): payload = bytes(range(128)) async def scenario(): received: list[bytes] = [] done = asyncio.Event() async def handler(ws): await ws.send(payload) done.set() try: await ws.wait_closed() except Exception: pass server, port = await _open_server(handler) try: client = WsClient( f"ws://127.0.0.1:{port}", token="", heartbeat_interval=10.0, reconnect_pause=0.05, ) async def on_bin(data): received.append(data) task = asyncio.create_task( client.run( on_message=lambda _m: asyncio.sleep(0), heartbeat=lambda: {"type": "heartbeat"}, on_binary=on_bin, ) ) for _ in range(50): if received: break await asyncio.sleep(0.02) client.stop() task.cancel() try: await task except (asyncio.CancelledError, Exception): pass finally: server.close() await server.wait_closed() return received received = asyncio.run(scenario()) assert received == [payload] def test_binary_frame_dropped_when_no_handler(): # Regression guard: existing behavior (drop server-sent binary) preserved when # on_binary is not supplied. async def scenario(): crashes: list[Exception] = [] async def handler(ws): await ws.send(b"\x00\x01\x02\x03") await ws.send(json.dumps({"type": "ping"})) try: await ws.wait_closed() except Exception: pass messages: list[dict] = [] server, port = await _open_server(handler) try: client = WsClient( f"ws://127.0.0.1:{port}", token="", heartbeat_interval=10.0, reconnect_pause=0.05, ) async def on_msg(m): messages.append(m) task = asyncio.create_task( client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"}) ) for _ in range(50): if messages: break await asyncio.sleep(0.02) client.stop() task.cancel() try: await task except (asyncio.CancelledError, Exception) as exc: crashes.append(exc) finally: server.close() await server.wait_closed() return messages, crashes messages, crashes = asyncio.run(scenario()) # JSON still delivered; binary silently dropped; no uncaught crash. assert messages and messages[0] == {"type": "ping"} def test_malformed_control_frame_does_not_crash(): async def scenario(): handled: list[dict] = [] done = asyncio.Event() async def handler(ws): await ws.send("not json") await ws.send(json.dumps({"type": "ping"})) done.set() try: await ws.wait_closed() except Exception: pass server, port = await _open_server(handler) try: client = WsClient( f"ws://127.0.0.1:{port}", token="", heartbeat_interval=10.0, reconnect_pause=0.05, ) async def on_msg(m): handled.append(m) task = asyncio.create_task( client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"}) ) for _ in range(50): if handled: break await asyncio.sleep(0.02) client.stop() task.cancel() try: await task except (asyncio.CancelledError, Exception): pass finally: server.close() await server.wait_closed() return handled handled = asyncio.run(scenario()) assert handled and handled[0] == {"type": "ping"}