129 lines
4.7 KiB
Python
129 lines
4.7 KiB
Python
"""Persistent WebSocket client for the streamer agent.
|
|
|
|
Handles connection lifecycle: connect, heartbeat, auto-reconnect on drop.
|
|
The caller drives the I/O loop via ``run()`` with a message handler callback.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Awaitable, Callable
|
|
|
|
# Logger for this client; the name is fixed rather than derived from __name__.
logger = logging.getLogger(name="ria_agent.ws")

# Callback signatures used by WsClient.run():
#   MessageHandler   - async callback receiving each JSON-decoded text frame.
#   BinaryHandler    - async callback receiving each raw binary frame.
#   HeartbeatBuilder - plain (sync) factory returning the heartbeat payload.
MessageHandler = Callable[[dict], Awaitable[None]]
BinaryHandler = Callable[[bytes], Awaitable[None]]
HeartbeatBuilder = Callable[[], dict]
|
|
|
|
|
|
class WsClient:
    """Persistent WebSocket connection with heartbeat and auto-reconnect.

    ``url`` should be a full ``wss://host/path`` (or ``ws://``) URL. ``token``
    is sent as a bearer in the ``Authorization`` header on connect (no header
    is sent when ``token`` is empty).

    Drive the connection with :meth:`run`; end it with :meth:`stop`. ``stop()``
    closes the live socket and cuts short any pending reconnect pause, so
    ``run()`` returns promptly instead of waiting for the server to hang up.
    """

    def __init__(
        self,
        url: str,
        token: str,
        *,
        heartbeat_interval: float = 30.0,
        reconnect_pause: float = 5.0,
    ) -> None:
        self.url = url
        self.token = token
        self.heartbeat_interval = heartbeat_interval  # seconds between heartbeats
        self.reconnect_pause = reconnect_pause  # seconds to wait before reconnecting
        self._ws = None  # live connection, or None while disconnected
        self._stop = asyncio.Event()  # set by stop(); observed by run()

    # ------------------------------------------------------------------
    async def _connect(self):
        """Open a new connection to ``self.url``, sending the bearer token if set."""
        import websockets  # imported lazily; only needed once a connection is attempted

        headers = [("Authorization", f"Bearer {self.token}")] if self.token else None
        # Newer websockets clients take ``additional_headers``; the legacy client
        # only understands ``extra_headers`` and fails with TypeError on the
        # unknown keyword, so retry with the older spelling.
        try:
            return await websockets.connect(self.url, additional_headers=headers)
        except TypeError:
            return await websockets.connect(self.url, extra_headers=headers)

    # ------------------------------------------------------------------
    async def send_json(self, payload: dict) -> None:
        """Serialize ``payload`` with ``json.dumps`` and send it on the live socket.

        Raises:
            ConnectionError: if there is no live connection.
        """
        if self._ws is None:
            raise ConnectionError("WebSocket is not connected")
        await self._ws.send(json.dumps(payload))

    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` unchanged on the live socket.

        Raises:
            ConnectionError: if there is no live connection.
        """
        if self._ws is None:
            raise ConnectionError("WebSocket is not connected")
        await self._ws.send(data)

    def stop(self) -> None:
        """Ask :meth:`run` to exit.

        Call from the event loop's thread (``asyncio.Event`` is not
        thread-safe). ``run()`` reacts immediately: the live socket is closed,
        which ends the receive loop, and any reconnect pause is interrupted.
        """
        self._stop.set()

    # ------------------------------------------------------------------
    async def run(
        self,
        on_message: MessageHandler,
        heartbeat: HeartbeatBuilder,
        on_binary: BinaryHandler | None = None,
    ) -> None:
        """Main loop: connect, heartbeat, dispatch messages, reconnect on drop.

        Args:
            on_message: awaited with each text frame decoded as JSON. Frames
                that are not valid JSON are logged and skipped. An exception
                raised by ``on_message`` drops the connection and triggers a
                reconnect.
            heartbeat: called to build each heartbeat payload, which is sent as
                JSON on connect and then every ``heartbeat_interval`` seconds.
            on_binary: awaited with each binary frame. When None, binary frames
                are discarded. An exception it raises is logged and the frame
                dropped; the connection stays up.

        Returns once :meth:`stop` has been called. Cancellation propagates.
        """
        while not self._stop.is_set():
            try:
                self._ws = await self._connect()
                logger.info("Connected to %s", self.url)
                await self._serve_connection(self._ws, on_message, heartbeat, on_binary)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                if self._stop.is_set():
                    break
                logger.warning("WS error: %s — reconnecting in %.1fs", exc, self.reconnect_pause)
            finally:
                # Best-effort close; a failure here must not mask the real error.
                try:
                    if self._ws is not None:
                        await self._ws.close()
                except Exception:
                    pass
                self._ws = None
            if self._stop.is_set():
                break
            # Interruptible pause: returns early as soon as stop() is called,
            # and the while-condition then ends the loop.
            try:
                await asyncio.wait_for(self._stop.wait(), timeout=self.reconnect_pause)
            except asyncio.TimeoutError:
                pass

    async def _serve_connection(self, ws, on_message, heartbeat, on_binary) -> None:
        """Run heartbeat and receive loop on ``ws`` until it closes or stop() is called."""
        hb_task = asyncio.create_task(self._heartbeat_loop(heartbeat))
        # Closing the socket ends ``async for`` below, so stop() is honoured even
        # while no frames are arriving.
        stop_task = asyncio.create_task(self._close_on_stop(ws))
        try:
            async for raw in ws:
                if self._stop.is_set():
                    break
                if isinstance(raw, bytes):
                    if on_binary is None:
                        logger.debug("Discarding unexpected %d-byte binary frame", len(raw))
                        continue
                    try:
                        await on_binary(raw)
                    except Exception:
                        logger.exception("on_binary handler raised; dropping frame")
                    continue
                try:
                    msg = json.loads(raw)
                except json.JSONDecodeError:
                    logger.warning("Malformed control frame: %r", raw[:200])
                    continue
                await on_message(msg)
        finally:
            hb_task.cancel()
            stop_task.cancel()
            # asyncio.wait() does not re-raise the helper tasks' own
            # CancelledError, but a cancellation aimed at *this* task still
            # propagates instead of being swallowed.
            await asyncio.wait({hb_task, stop_task})

    async def _close_on_stop(self, ws) -> None:
        """Wait for stop() and then close ``ws`` so the receive loop unblocks."""
        await self._stop.wait()
        try:
            await ws.close()
        except Exception as exc:
            logger.debug("Close on stop failed: %s", exc)

    async def _heartbeat_loop(self, heartbeat: HeartbeatBuilder) -> None:
        """Send ``heartbeat()`` now and then every ``heartbeat_interval`` seconds.

        Returns quietly on the first failure (building or sending the payload);
        the receive loop is left to notice a dead connection. Cancelled by the
        per-connection teardown.
        """
        while True:
            try:
                await self.send_json(heartbeat())
            except Exception as exc:
                logger.debug("Heartbeat send failed: %s", exc)
                return
            await asyncio.sleep(self.heartbeat_interval)
|