From a268f2ab25b81dbbc5411af578ba680c3163a1a0 Mon Sep 17 00:00:00 2001 From: jonny Date: Mon, 13 Apr 2026 11:48:15 -0400 Subject: [PATCH] Add changes for screens agent connections. --- poetry.lock | 159 +++++++------ pyproject.toml | 3 +- src/ria_toolkit_oss/agent/__init__.py | 26 +++ src/ria_toolkit_oss/agent/cli.py | 131 +++++++++++ src/ria_toolkit_oss/agent/config.py | 63 +++++ src/ria_toolkit_oss/agent/hardware.py | 22 ++ .../{agent.py => agent/legacy_executor.py} | 0 src/ria_toolkit_oss/agent/streamer.py | 221 ++++++++++++++++++ src/ria_toolkit_oss/agent/ws_client.py | 117 ++++++++++ src/ria_toolkit_oss/sdr/__init__.py | 42 +++- src/ria_toolkit_oss/sdr/pluto.py | 21 +- src/ria_toolkit_oss/sdr/sdr.py | 48 ++++ tests/agent/__init__.py | 0 tests/agent/test_config.py | 33 +++ tests/agent/test_disconnect.py | 81 +++++++ tests/agent/test_hardware.py | 29 +++ tests/agent/test_integration.py | 100 ++++++++ tests/agent/test_legacy.py | 19 ++ tests/agent/test_streamer.py | 124 ++++++++++ tests/agent/test_ws_client.py | 161 +++++++++++++ 20 files changed, 1329 insertions(+), 71 deletions(-) create mode 100644 src/ria_toolkit_oss/agent/__init__.py create mode 100644 src/ria_toolkit_oss/agent/cli.py create mode 100644 src/ria_toolkit_oss/agent/config.py create mode 100644 src/ria_toolkit_oss/agent/hardware.py rename src/ria_toolkit_oss/{agent.py => agent/legacy_executor.py} (100%) create mode 100644 src/ria_toolkit_oss/agent/streamer.py create mode 100644 src/ria_toolkit_oss/agent/ws_client.py create mode 100644 tests/agent/__init__.py create mode 100644 tests/agent/test_config.py create mode 100644 tests/agent/test_disconnect.py create mode 100644 tests/agent/test_hardware.py create mode 100644 tests/agent/test_integration.py create mode 100644 tests/agent/test_legacy.py create mode 100644 tests/agent/test_streamer.py create mode 100644 tests/agent/test_ws_client.py diff --git a/poetry.lock b/poetry.lock index d2ddd55..cb7a9f0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.3.4 and should not be changed by hand. [[package]] name = "alabaster" @@ -1096,7 +1096,7 @@ files = [ [package.dependencies] attrs = ">=22.2.0" -jsonschema-specifications = ">=2023.03.6" +jsonschema-specifications = ">=2023.3.6" referencing = ">=0.28.4" rpds-py = ">=0.25.0" @@ -3451,76 +3451,101 @@ anyio = ">=3.0.0" [[package]] name = "websockets" -version = "16.0" +version = "13.1" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" optional = false -python-versions = ">=3.10" -groups = ["docs", "server", "test"] +python-versions = ">=3.8" +groups = ["agent", "docs", "server", "test"] files = [ - {file = "websockets-16.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04cdd5d2d1dacbad0a7bf36ccbcd3ccd5a30ee188f2560b7a62a30d14107b31a"}, - {file = "websockets-16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8ff32bb86522a9e5e31439a58addbb0166f0204d64066fb955265c4e214160f0"}, - {file = "websockets-16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:583b7c42688636f930688d712885cf1531326ee05effd982028212ccc13e5957"}, - {file = "websockets-16.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7d837379b647c0c4c2355c2499723f82f1635fd2c26510e1f587d89bc2199e72"}, - {file = "websockets-16.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df57afc692e517a85e65b72e165356ed1df12386ecb879ad5693be08fac65dde"}, - {file = "websockets-16.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2b9f1e0d69bc60a4a87349d50c09a037a2607918746f07de04df9e43252c77a3"}, - {file = "websockets-16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:335c23addf3d5e6a8633f9f8eda77efad001671e80b95c491dd0924587ece0b3"}, - {file = "websockets-16.0-cp310-cp310-win32.whl", hash = "sha256:37b31c1623c6605e4c00d466c9d633f9b812ea430c11c8a278774a1fde1acfa9"}, - {file = "websockets-16.0-cp310-cp310-win_amd64.whl", hash = "sha256:8e1dab317b6e77424356e11e99a432b7cb2f3ec8c5ab4dabbcee6add48f72b35"}, - {file = "websockets-16.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:31a52addea25187bde0797a97d6fc3d2f92b6f72a9370792d65a6e84615ac8a8"}, - {file = "websockets-16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:417b28978cdccab24f46400586d128366313e8a96312e4b9362a4af504f3bbad"}, - {file = "websockets-16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af80d74d4edfa3cb9ed973a0a5ba2b2a549371f8a741e0800cb07becdd20f23d"}, - {file = "websockets-16.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:08d7af67b64d29823fed316505a89b86705f2b7981c07848fb5e3ea3020c1abe"}, - {file = "websockets-16.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7be95cfb0a4dae143eaed2bcba8ac23f4892d8971311f1b06f3c6b78952ee70b"}, - {file = "websockets-16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d6297ce39ce5c2e6feb13c1a996a2ded3b6832155fcfc920265c76f24c7cceb5"}, - {file = "websockets-16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1c1b30e4f497b0b354057f3467f56244c603a79c0d1dafce1d16c283c25f6e64"}, - {file = "websockets-16.0-cp311-cp311-win32.whl", hash = "sha256:5f451484aeb5cafee1ccf789b1b66f535409d038c56966d6101740c1614b86c6"}, - {file = "websockets-16.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d7f0659570eefb578dacde98e24fb60af35350193e4f56e11190787bee77dac"}, - {file = "websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00"}, - {file = "websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79"}, - {file = "websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39"}, - {file = "websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c"}, - {file = "websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f"}, - {file = "websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1"}, - {file = "websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2"}, - {file = "websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89"}, - {file = "websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea"}, - {file = "websockets-16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878b336ac47938b474c8f982ac2f7266a540adc3fa4ad74ae96fea9823a02cc9"}, - {file = "websockets-16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52a0fec0e6c8d9a784c2c78276a48a2bdf099e4ccc2a4cad53b27718dbfd0230"}, - {file = "websockets-16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e6578ed5b6981005df1860a56e3617f14a6c307e6a71b4fff8c48fdc50f3ed2c"}, - {file = "websockets-16.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:95724e638f0f9c350bb1c2b0a7ad0e83d9cc0c9259f3ea94e40d7b02a2179ae5"}, - {file = "websockets-16.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0204dc62a89dc9d50d682412c10b3542d748260d743500a85c13cd1ee4bde82"}, - {file = "websockets-16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:52ac480f44d32970d66763115edea932f1c5b1312de36df06d6b219f6741eed8"}, - {file = "websockets-16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6e5a82b677f8f6f59e8dfc34ec06ca6b5b48bc4fcda346acd093694cc2c24d8f"}, - {file = "websockets-16.0-cp313-cp313-win32.whl", hash = "sha256:abf050a199613f64c886ea10f38b47770a65154dc37181bfaff70c160f45315a"}, - {file = "websockets-16.0-cp313-cp313-win_amd64.whl", hash = "sha256:3425ac5cf448801335d6fdc7ae1eb22072055417a96cc6b31b3861f455fbc156"}, - {file = "websockets-16.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8cc451a50f2aee53042ac52d2d053d08bf89bcb31ae799cb4487587661c038a0"}, - {file = "websockets-16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:daa3b6ff70a9241cf6c7fc9e949d41232d9d7d26fd3522b1ad2b4d62487e9904"}, - {file = "websockets-16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:fd3cb4adb94a2a6e2b7c0d8d05cb94e6f1c81a0cf9dc2694fb65c7e8d94c42e4"}, - {file = "websockets-16.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:781caf5e8eee67f663126490c2f96f40906594cb86b408a703630f95550a8c3e"}, - {file = "websockets-16.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:caab51a72c51973ca21fa8a18bd8165e1a0183f1ac7066a182ff27107b71e1a4"}, - {file = "websockets-16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19c4dc84098e523fd63711e563077d39e90ec6702aff4b5d9e344a60cb3c0cb1"}, - {file = "websockets-16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a5e18a238a2b2249c9a9235466b90e96ae4795672598a58772dd806edc7ac6d3"}, - {file = "websockets-16.0-cp314-cp314-win32.whl", hash = "sha256:a069d734c4a043182729edd3e9f247c3b2a4035415a9172fd0f1b71658a320a8"}, - {file = "websockets-16.0-cp314-cp314-win_amd64.whl", hash = "sha256:c0ee0e63f23914732c6d7e0cce24915c48f3f1512ec1d079ed01fc629dab269d"}, - {file = "websockets-16.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:a35539cacc3febb22b8f4d4a99cc79b104226a756aa7400adc722e83b0d03244"}, - {file = "websockets-16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b784ca5de850f4ce93ec85d3269d24d4c82f22b7212023c974c401d4980ebc5e"}, - {file = "websockets-16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:569d01a4e7fba956c5ae4fc988f0d4e187900f5497ce46339c996dbf24f17641"}, - {file = "websockets-16.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:50f23cdd8343b984957e4077839841146f67a3d31ab0d00e6b824e74c5b2f6e8"}, - {file = "websockets-16.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:152284a83a00c59b759697b7f9e9cddf4e3c7861dd0d964b472b70f78f89e80e"}, - {file = "websockets-16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bc59589ab64b0022385f429b94697348a6a234e8ce22544e3681b2e9331b5944"}, - {file = "websockets-16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:32da954ffa2814258030e5a57bc73a3635463238e797c7375dc8091327434206"}, - {file = "websockets-16.0-cp314-cp314t-win32.whl", hash = "sha256:5a4b4cc550cb665dd8a47f868c8d04c8230f857363ad3c9caf7a0c3bf8c61ca6"}, - {file = "websockets-16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b14dc141ed6d2dde437cddb216004bcac6a1df0935d79656387bd41632ba0bbd"}, - {file = "websockets-16.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:349f83cd6c9a415428ee1005cadb5c2c56f4389bc06a9af16103c3bc3dcc8b7d"}, - {file = "websockets-16.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:4a1aba3340a8dca8db6eb5a7986157f52eb9e436b74813764241981ca4888f03"}, - {file = "websockets-16.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f4a32d1bd841d4bcbffdcb3d2ce50c09c3909fbead375ab28d0181af89fd04da"}, - {file = "websockets-16.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0298d07ee155e2e9fda5be8a9042200dd2e3bb0b8a38482156576f863a9d457c"}, - {file = "websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767"}, - {file = "websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec"}, - {file = "websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5"}, + {file = "websockets-13.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f48c749857f8fb598fb890a75f540e3221d0976ed0bf879cf3c7eef34151acee"}, + {file = "websockets-13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c7e72ce6bda6fb9409cc1e8164dd41d7c91466fb599eb047cfda72fe758a34a7"}, + {file = "websockets-13.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f779498eeec470295a2b1a5d97aa1bc9814ecd25e1eb637bd9d1c73a327387f6"}, + {file = "websockets-13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4676df3fe46956fbb0437d8800cd5f2b6d41143b6e7e842e60554398432cf29b"}, + {file = "websockets-13.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7affedeb43a70351bb811dadf49493c9cfd1ed94c9c70095fd177e9cc1541fa"}, + {file = "websockets-13.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1971e62d2caa443e57588e1d82d15f663b29ff9dfe7446d9964a4b6f12c1e700"}, + {file = "websockets-13.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5f2e75431f8dc4a47f31565a6e1355fb4f2ecaa99d6b89737527ea917066e26c"}, + {file = "websockets-13.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:58cf7e75dbf7e566088b07e36ea2e3e2bd5676e22216e4cad108d4df4a7402a0"}, + {file = "websockets-13.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c90d6dec6be2c7d03378a574de87af9b1efea77d0c52a8301dd831ece938452f"}, + {file = "websockets-13.1-cp310-cp310-win32.whl", hash = "sha256:730f42125ccb14602f455155084f978bd9e8e57e89b569b4d7f0f0c17a448ffe"}, + {file = "websockets-13.1-cp310-cp310-win_amd64.whl", hash = "sha256:5993260f483d05a9737073be197371940c01b257cc45ae3f1d5d7adb371b266a"}, + {file = "websockets-13.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:61fc0dfcda609cda0fc9fe7977694c0c59cf9d749fbb17f4e9483929e3c48a19"}, + {file = "websockets-13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ceec59f59d092c5007e815def4ebb80c2de330e9588e101cf8bd94c143ec78a5"}, + {file = "websockets-13.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c1dca61c6db1166c48b95198c0b7d9c990b30c756fc2923cc66f68d17dc558fd"}, + {file = "websockets-13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:308e20f22c2c77f3f39caca508e765f8725020b84aa963474e18c59accbf4c02"}, + {file = "websockets-13.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62d516c325e6540e8a57b94abefc3459d7dab8ce52ac75c96cad5549e187e3a7"}, + {file = "websockets-13.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87c6e35319b46b99e168eb98472d6c7d8634ee37750d7693656dc766395df096"}, + {file = "websockets-13.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5f9fee94ebafbc3117c30be1844ed01a3b177bb6e39088bc6b2fa1dc15572084"}, + {file = "websockets-13.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7c1e90228c2f5cdde263253fa5db63e6653f1c00e7ec64108065a0b9713fa1b3"}, + {file = "websockets-13.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6548f29b0e401eea2b967b2fdc1c7c7b5ebb3eeb470ed23a54cd45ef078a0db9"}, + {file = "websockets-13.1-cp311-cp311-win32.whl", hash = "sha256:c11d4d16e133f6df8916cc5b7e3e96ee4c44c936717d684a94f48f82edb7c92f"}, + {file = "websockets-13.1-cp311-cp311-win_amd64.whl", hash = "sha256:d04f13a1d75cb2b8382bdc16ae6fa58c97337253826dfe136195b7f89f661557"}, + {file = "websockets-13.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9d75baf00138f80b48f1eac72ad1535aac0b6461265a0bcad391fc5aba875cfc"}, + {file = "websockets-13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9b6f347deb3dcfbfde1c20baa21c2ac0751afaa73e64e5b693bb2b848efeaa49"}, + {file = "websockets-13.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:de58647e3f9c42f13f90ac7e5f58900c80a39019848c5547bc691693098ae1bd"}, + {file = "websockets-13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1b54689e38d1279a51d11e3467dd2f3a50f5f2e879012ce8f2d6943f00e83f0"}, + {file = "websockets-13.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf1781ef73c073e6b0f90af841aaf98501f975d306bbf6221683dd594ccc52b6"}, + {file = "websockets-13.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d23b88b9388ed85c6faf0e74d8dec4f4d3baf3ecf20a65a47b836d56260d4b9"}, + {file = "websockets-13.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3c78383585f47ccb0fcf186dcb8a43f5438bd7d8f47d69e0b56f71bf431a0a68"}, + {file = "websockets-13.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:d6d300f8ec35c24025ceb9b9019ae9040c1ab2f01cddc2bcc0b518af31c75c14"}, + {file = "websockets-13.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a9dcaf8b0cc72a392760bb8755922c03e17a5a54e08cca58e8b74f6902b433cf"}, + {file = "websockets-13.1-cp312-cp312-win32.whl", hash = "sha256:2f85cf4f2a1ba8f602298a853cec8526c2ca42a9a4b947ec236eaedb8f2dc80c"}, + {file = "websockets-13.1-cp312-cp312-win_amd64.whl", hash = "sha256:38377f8b0cdeee97c552d20cf1865695fcd56aba155ad1b4ca8779a5b6ef4ac3"}, + {file = "websockets-13.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a9ab1e71d3d2e54a0aa646ab6d4eebfaa5f416fe78dfe4da2839525dc5d765c6"}, + {file = "websockets-13.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b9d7439d7fab4dce00570bb906875734df13d9faa4b48e261c440a5fec6d9708"}, + {file = "websockets-13.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:327b74e915cf13c5931334c61e1a41040e365d380f812513a255aa804b183418"}, + {file = "websockets-13.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:325b1ccdbf5e5725fdcb1b0e9ad4d2545056479d0eee392c291c1bf76206435a"}, + {file = "websockets-13.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:346bee67a65f189e0e33f520f253d5147ab76ae42493804319b5716e46dddf0f"}, + {file = "websockets-13.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91a0fa841646320ec0d3accdff5b757b06e2e5c86ba32af2e0815c96c7a603c5"}, + {file = "websockets-13.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:18503d2c5f3943e93819238bf20df71982d193f73dcecd26c94514f417f6b135"}, + {file = "websockets-13.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a9cd1af7e18e5221d2878378fbc287a14cd527fdd5939ed56a18df8a31136bb2"}, + {file = "websockets-13.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:70c5be9f416aa72aab7a2a76c90ae0a4fe2755c1816c153c1a2bcc3333ce4ce6"}, + {file = "websockets-13.1-cp313-cp313-win32.whl", hash = "sha256:624459daabeb310d3815b276c1adef475b3e6804abaf2d9d2c061c319f7f187d"}, + {file = "websockets-13.1-cp313-cp313-win_amd64.whl", hash = "sha256:c518e84bb59c2baae725accd355c8dc517b4a3ed8db88b4bc93c78dae2974bf2"}, + {file = "websockets-13.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c7934fd0e920e70468e676fe7f1b7261c1efa0d6c037c6722278ca0228ad9d0d"}, + {file = "websockets-13.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:149e622dc48c10ccc3d2760e5f36753db9cacf3ad7bc7bbbfd7d9c819e286f23"}, + {file = "websockets-13.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a569eb1b05d72f9bce2ebd28a1ce2054311b66677fcd46cf36204ad23acead8c"}, + {file = "websockets-13.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95df24ca1e1bd93bbca51d94dd049a984609687cb2fb08a7f2c56ac84e9816ea"}, + {file = "websockets-13.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8dbb1bf0c0a4ae8b40bdc9be7f644e2f3fb4e8a9aca7145bfa510d4a374eeb7"}, + {file = "websockets-13.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:035233b7531fb92a76beefcbf479504db8c72eb3bff41da55aecce3a0f729e54"}, + {file = "websockets-13.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:e4450fc83a3df53dec45922b576e91e94f5578d06436871dce3a6be38e40f5db"}, + {file = "websockets-13.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:463e1c6ec853202dd3657f156123d6b4dad0c546ea2e2e38be2b3f7c5b8e7295"}, + {file = "websockets-13.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6d6855bbe70119872c05107e38fbc7f96b1d8cb047d95c2c50869a46c65a8e96"}, + {file = "websockets-13.1-cp38-cp38-win32.whl", hash = "sha256:204e5107f43095012b00f1451374693267adbb832d29966a01ecc4ce1db26faf"}, + {file = "websockets-13.1-cp38-cp38-win_amd64.whl", hash = "sha256:485307243237328c022bc908b90e4457d0daa8b5cf4b3723fd3c4a8012fce4c6"}, + {file = "websockets-13.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9b37c184f8b976f0c0a231a5f3d6efe10807d41ccbe4488df8c74174805eea7d"}, + {file = "websockets-13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:163e7277e1a0bd9fb3c8842a71661ad19c6aa7bb3d6678dc7f89b17fbcc4aeb7"}, + {file = "websockets-13.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4b889dbd1342820cc210ba44307cf75ae5f2f96226c0038094455a96e64fb07a"}, + {file = "websockets-13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:586a356928692c1fed0eca68b4d1c2cbbd1ca2acf2ac7e7ebd3b9052582deefa"}, + {file = "websockets-13.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7bd6abf1e070a6b72bfeb71049d6ad286852e285f146682bf30d0296f5fbadfa"}, + {file = "websockets-13.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2aad13a200e5934f5a6767492fb07151e1de1d6079c003ab31e1823733ae79"}, + {file = "websockets-13.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:df01aea34b6e9e33572c35cd16bae5a47785e7d5c8cb2b54b2acdb9678315a17"}, + {file = "websockets-13.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:e54affdeb21026329fb0744ad187cf812f7d3c2aa702a5edb562b325191fcab6"}, + {file = "websockets-13.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9ef8aa8bdbac47f4968a5d66462a2a0935d044bf35c0e5a8af152d58516dbeb5"}, + {file = "websockets-13.1-cp39-cp39-win32.whl", hash = "sha256:deeb929efe52bed518f6eb2ddc00cc496366a14c726005726ad62c2dd9017a3c"}, + {file = "websockets-13.1-cp39-cp39-win_amd64.whl", hash = "sha256:7c65ffa900e7cc958cd088b9a9157a8141c991f8c53d11087e6fb7277a03f81d"}, + {file = "websockets-13.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:5dd6da9bec02735931fccec99d97c29f47cc61f644264eb995ad6c0c27667238"}, + {file = "websockets-13.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:2510c09d8e8df777177ee3d40cd35450dc169a81e747455cc4197e63f7e7bfe5"}, + {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1c3cf67185543730888b20682fb186fc8d0fa6f07ccc3ef4390831ab4b388d9"}, + {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bcc03c8b72267e97b49149e4863d57c2d77f13fae12066622dc78fe322490fe6"}, + {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:004280a140f220c812e65f36944a9ca92d766b6cc4560be652a0a3883a79ed8a"}, + {file = "websockets-13.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e2620453c075abeb0daa949a292e19f56de518988e079c36478bacf9546ced23"}, + {file = "websockets-13.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9156c45750b37337f7b0b00e6248991a047be4aa44554c9886fe6bdd605aab3b"}, + {file = "websockets-13.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:80c421e07973a89fbdd93e6f2003c17d20b69010458d3a8e37fb47874bd67d51"}, + {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82d0ba76371769d6a4e56f7e83bb8e81846d17a6190971e38b5de108bde9b0d7"}, + {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9875a0143f07d74dc5e1ded1c4581f0d9f7ab86c78994e2ed9e95050073c94d"}, + {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a11e38ad8922c7961447f35c7b17bffa15de4d17c70abd07bfbe12d6faa3e027"}, + {file = "websockets-13.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4059f790b6ae8768471cddb65d3c4fe4792b0ab48e154c9f0a04cefaabcd5978"}, + {file = "websockets-13.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:25c35bf84bf7c7369d247f0b8cfa157f989862c49104c5cf85cb5436a641d93e"}, + {file = "websockets-13.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:83f91d8a9bb404b8c2c41a707ac7f7f75b9442a0a876df295de27251a856ad09"}, + {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a43cfdcddd07f4ca2b1afb459824dd3c6d53a51410636a2c7fc97b9a8cf4842"}, + {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48a2ef1381632a2f0cb4efeff34efa97901c9fbc118e01951ad7cfc10601a9bb"}, + {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:459bf774c754c35dbb487360b12c5727adab887f1622b8aed5755880a21c4a20"}, + {file = "websockets-13.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:95858ca14a9f6fa8413d29e0a585b31b278388aa775b8a81fa24830123874678"}, + {file = "websockets-13.1-py3-none-any.whl", hash = "sha256:a9a396a6ad26130cdae92ae10c36af09d9bfe6cafe69670fd3b6da9b07b4044f"}, + {file = "websockets-13.1.tar.gz", hash = "sha256:a3b3366087c1bc0a2795111edcadddb8b3b59509d5db5d7ea3fdd69f954a8878"}, ] [metadata] lock-version = "2.1" python-versions = ">=3.10" -content-hash = "b1e5ddd7284aecf49624e51740b7a4c31bc8d0e703c255126ba5d9b2a4a0e519" +content-hash = "7ddbf7d85e9ae7bd3a1b99ae481df20aaf6fd185d5f628b0fdf9b7bd278730ed" diff --git a/pyproject.toml b/pyproject.toml index 8db3469..a0bd664 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ optional = true [tool.poetry.group.agent.dependencies] requests = ">=2.28,<3.0" +websockets = ">=12.0,<14.0" [tool.poetry.group.dev.dependencies] flake8 = "^7.1.0" @@ -116,7 +117,7 @@ pylint = "^3.2.6" # For pyreverse, to automate the creation of UML diagrams ria = "ria_toolkit_oss_cli.cli:cli" ria-tools = "ria_toolkit_oss_cli.cli:cli" ria-server = "ria_toolkit_oss.server.cli:serve" -ria-agent = "ria_toolkit_oss.agent:main" +ria-agent = "ria_toolkit_oss.agent.cli:main" [tool.poetry.group.server.dependencies] fastapi = ">=0.111,<1.0" diff --git a/src/ria_toolkit_oss/agent/__init__.py b/src/ria_toolkit_oss/agent/__init__.py new file mode 100644 index 0000000..11647ef --- /dev/null +++ b/src/ria_toolkit_oss/agent/__init__.py @@ -0,0 +1,26 @@ +"""RIA Toolkit agent package. + +Provides two execution modes: + +- **Legacy long-poll executor** (`NodeAgent` in :mod:`legacy_executor`) — an + HTTP long-polling agent that runs ONNX inference locally on the host. +- **Streamer** (:mod:`streamer`) — a thin WebSocket client that opens an SDR + and streams raw IQ to the RIA Hub server, which performs all inference. + +Back-compat: ``from ria_toolkit_oss.agent import NodeAgent`` and the ``main`` +entry point continue to work. +""" + +from __future__ import annotations + +from .legacy_executor import NodeAgent +from .legacy_executor import main as _legacy_main + +__all__ = ["NodeAgent", "main"] + + +def main() -> None: + """Unified CLI entry point. Dispatches to streamer/legacy subcommands.""" + from .cli import main as _cli_main + + _cli_main() diff --git a/src/ria_toolkit_oss/agent/cli.py b/src/ria_toolkit_oss/agent/cli.py new file mode 100644 index 0000000..2873293 --- /dev/null +++ b/src/ria_toolkit_oss/agent/cli.py @@ -0,0 +1,131 @@ +"""Unified ``ria-agent`` CLI. + +Subcommands: + +- ``ria-agent run [legacy args]`` — legacy long-poll NodeAgent (unchanged). +- ``ria-agent stream`` — new WebSocket-based IQ streamer. +- ``ria-agent detect`` — print SDR drivers whose modules import cleanly. +- ``ria-agent register --url URL --token TOKEN`` — save credentials to + ``~/.ria/agent.json``. + +Invoking ``ria-agent`` with no subcommand falls through to the legacy +long-poll behavior for back-compatibility with existing deployments. +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import sys + +from . import config as _config +from .hardware import available_devices +from .legacy_executor import main as _legacy_main + +_LEGACY_ALIASES = {"--hub", "--key", "--name", "--device", "--insecure", "--log-level", "--config"} + + +def _cmd_detect(_args: argparse.Namespace) -> int: + devices = available_devices() + if not devices: + print("No SDR drivers available (install ria-toolkit-oss[all-sdr] or per-driver extras).") + return 0 + for name in devices: + print(name) + return 0 + + +def _cmd_register(args: argparse.Namespace) -> int: + cfg = _config.load() + cfg.hub_url = args.url + cfg.token = args.token + if args.name: + cfg.name = args.name + if args.agent_id: + cfg.agent_id = args.agent_id + cfg.insecure = bool(args.insecure) + path = _config.save(cfg) + print(f"Saved agent credentials to {path}") + return 0 + + +def _cmd_stream(args: argparse.Namespace) -> int: + from .streamer import run_streamer + + cfg = _config.load() + url = args.url or _derive_ws_url(cfg.hub_url, cfg.agent_id) + token = args.token or cfg.token + if not url: + print("error: --url is required (or run `ria-agent register` first)", file=sys.stderr) + return 2 + try: + asyncio.run(run_streamer(url, token)) + except KeyboardInterrupt: + pass + return 0 + + +def _derive_ws_url(hub_url: str, agent_id: str) -> str: + if not hub_url: + return "" + base = hub_url.rstrip("/") + if base.startswith("https://"): + base = "wss://" + base[len("https://"):] + elif base.startswith("http://"): + base = "ws://" + base[len("http://"):] + suffix = f"/api/agent/ws/{agent_id}" if agent_id else "/api/agent/ws" + return base + suffix + + +def main() -> None: + # Back-compat: if the first non-flag token matches a known legacy flag, + # or there is no subcommand at all, dispatch to the legacy CLI. + argv = sys.argv[1:] + if not argv or (argv[0].startswith("--") and argv[0] in _LEGACY_ALIASES): + _legacy_main() + return + + parser = argparse.ArgumentParser(prog="ria-agent") + sub = parser.add_subparsers(dest="command", required=True) + + sub.add_parser("run", help="Legacy long-poll agent (NodeAgent)") + sub.add_parser("detect", help="List available SDR drivers") + + p_reg = sub.add_parser("register", help="Save agent credentials to ~/.ria/agent.json") + p_reg.add_argument("--url", required=True, help="RIA Hub base URL") + p_reg.add_argument("--token", required=True, help="Agent registration token") + p_reg.add_argument("--name", default=None) + p_reg.add_argument("--agent-id", dest="agent_id", default=None) + p_reg.add_argument("--insecure", action="store_true") + + p_stream = sub.add_parser("stream", help="Run the WebSocket IQ streamer") + p_stream.add_argument("--url", default=None, help="Override WebSocket URL") + p_stream.add_argument("--token", default=None, help="Override bearer token") + p_stream.add_argument("--log-level", default="INFO") + + # Unknown extras are forwarded to the legacy CLI when command == "run". + args, extras = parser.parse_known_args(argv) + + logging.basicConfig( + level=getattr(logging, getattr(args, "log_level", "INFO"), logging.INFO), + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + ) + + if args.command == "run": + sys.argv = [sys.argv[0], *extras] + _legacy_main() + return + if args.command == "detect": + sys.exit(_cmd_detect(args)) + if args.command == "register": + sys.exit(_cmd_register(args)) + if args.command == "stream": + sys.exit(_cmd_stream(args)) + + parser.error(f"unknown command: {args.command}") + + +if __name__ == "__main__": + main() diff --git a/src/ria_toolkit_oss/agent/config.py b/src/ria_toolkit_oss/agent/config.py new file mode 100644 index 0000000..01f99ba --- /dev/null +++ b/src/ria_toolkit_oss/agent/config.py @@ -0,0 +1,63 @@ +"""Agent configuration stored at ``~/.ria/agent.json``. + +Schema:: + + { + "hub_url": "https://riahub.example.com", + "agent_id": "agent-abc123", + "token": "rha_xxxx", + "name": "lab-bench-1", + "insecure": false + } +""" + +from __future__ import annotations + +import json +import os +from dataclasses import asdict, dataclass, field +from pathlib import Path + +_DEFAULT_PATH = Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json"))) + + +@dataclass +class AgentConfig: + hub_url: str = "" + agent_id: str = "" + token: str = "" + name: str = "" + insecure: bool = False + extra: dict = field(default_factory=dict) + + +def default_path() -> Path: + return _DEFAULT_PATH + + +def load(path: Path | None = None) -> AgentConfig: + p = path or _DEFAULT_PATH + if not p.exists(): + return AgentConfig() + data = json.loads(p.read_text()) + known = {f for f in AgentConfig.__dataclass_fields__ if f != "extra"} + extra = {k: v for k, v in data.items() if k not in known} + return AgentConfig( + hub_url=data.get("hub_url", ""), + agent_id=data.get("agent_id", ""), + token=data.get("token", ""), + name=data.get("name", ""), + insecure=bool(data.get("insecure", False)), + extra=extra, + ) + + +def save(cfg: AgentConfig, path: Path | None = None) -> Path: + p = path or _DEFAULT_PATH + p.parent.mkdir(parents=True, exist_ok=True) + data = asdict(cfg) + extra = data.pop("extra", {}) or {} + data.update(extra) + p.write_text(json.dumps(data, indent=2)) + os.chmod(p, 0o600) + return p diff --git a/src/ria_toolkit_oss/agent/hardware.py b/src/ria_toolkit_oss/agent/hardware.py new file mode 100644 index 0000000..417bf1c --- /dev/null +++ b/src/ria_toolkit_oss/agent/hardware.py @@ -0,0 +1,22 @@ +"""Hardware detection and heartbeat payload construction for the streamer.""" + +from __future__ import annotations + +from ria_toolkit_oss.sdr import detect_available + + +def available_devices() -> list[str]: + """Return a sorted list of device names whose driver modules import cleanly.""" + return sorted(detect_available().keys()) + + +def heartbeat_payload(status: str = "idle", app_id: str | None = None) -> dict: + """Build the JSON body of a periodic heartbeat frame.""" + payload: dict = { + "type": "heartbeat", + "hardware": available_devices(), + "status": status, + } + if app_id: + payload["app_id"] = app_id + return payload diff --git a/src/ria_toolkit_oss/agent.py b/src/ria_toolkit_oss/agent/legacy_executor.py similarity index 100% rename from src/ria_toolkit_oss/agent.py rename to src/ria_toolkit_oss/agent/legacy_executor.py diff --git a/src/ria_toolkit_oss/agent/streamer.py b/src/ria_toolkit_oss/agent/streamer.py new file mode 100644 index 0000000..4d89743 --- /dev/null +++ b/src/ria_toolkit_oss/agent/streamer.py @@ -0,0 +1,221 @@ +"""Thin IQ-streaming agent. + +Listens for control messages from the RIA Hub over a persistent WebSocket. +When the server sends ``start``, opens the SDR described in ``radio_config``, +loops over ``sdr.rx(buffer_size)``, and sends each buffer as raw +interleaved float32 bytes. ``stop`` closes the SDR; ``configure`` applies +parameter updates at the next capture boundary. +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any + +import numpy as np + +from .hardware import heartbeat_payload +from .ws_client import WsClient + +logger = logging.getLogger("ria_agent.streamer") + +_DEFAULT_BUFFER_SIZE = 1024 + + +class Streamer: + """Main streamer loop. + + Parameters + ---------- + ws: + Connected :class:`WsClient`. + sdr_factory: + Callable ``(device, identifier) -> SDR``. Defaults to + :func:`ria_toolkit_oss.sdr.get_sdr_device`. Injectable for tests. + """ + + def __init__(self, ws: WsClient, sdr_factory=None) -> None: + self.ws = ws + self._sdr_factory = sdr_factory + self._app_id: str | None = None + self._sdr: Any = None + self._pending_config: dict = {} + self._capture_task: asyncio.Task | None = None + self._status = "idle" + + # ------------------------------------------------------------------ + # WsClient wiring + + def build_heartbeat(self) -> dict: + return heartbeat_payload(status=self._status, app_id=self._app_id) + + async def on_message(self, msg: dict) -> None: + t = msg.get("type") + if t == "start": + await self._handle_start(msg) + elif t == "stop": + await self._handle_stop(msg) + elif t == "configure": + self._pending_config.update(msg.get("radio_config") or {}) + logger.debug("Queued configure: %s", self._pending_config) + else: + logger.warning("Unknown server message type: %r", t) + + # ------------------------------------------------------------------ + async def _handle_start(self, msg: dict) -> None: + if self._capture_task is not None and not self._capture_task.done(): + logger.warning("start received while already streaming — ignoring") + return + + self._app_id = msg.get("app_id") + radio_config = dict(msg.get("radio_config") or {}) + device = radio_config.pop("device", None) + identifier = radio_config.pop("identifier", None) + buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE)) + if not device: + await self._send_error("start missing radio_config.device") + return + + try: + factory = self._sdr_factory or _default_sdr_factory + self._sdr = factory(device, identifier) + _apply_sdr_config(self._sdr, radio_config) + except Exception as exc: + logger.exception("Failed to open SDR %r", device) + await self._send_error(f"SDR init failed: {exc}") + return + + self._status = "streaming" + await self._send_status("streaming") + self._capture_task = asyncio.create_task( + self._capture_loop(buffer_size), name="ria-streamer-capture" + ) + + async def _handle_stop(self, msg: dict) -> None: + if self._capture_task is not None: + self._capture_task.cancel() + try: + await self._capture_task + except (asyncio.CancelledError, Exception): + pass + self._capture_task = None + self._close_sdr() + self._app_id = None + self._status = "idle" + await self._send_status("idle") + + async def _capture_loop(self, buffer_size: int) -> None: + loop = asyncio.get_running_loop() + try: + while True: + if self._pending_config: + cfg = self._pending_config + self._pending_config = {} + try: + _apply_sdr_config(self._sdr, cfg) + except Exception as exc: + logger.warning("Applying configure failed: %s", exc) + + try: + samples = await loop.run_in_executor(None, self._sdr.rx, buffer_size) + except Exception as exc: + from ria_toolkit_oss.sdr import SdrDisconnectedError + + if isinstance(exc, SdrDisconnectedError): + logger.warning("SDR disconnected: %s", exc) + await self._send_error(f"SDR disconnected: {exc}") + else: + logger.exception("SDR rx error") + await self._send_error(f"SDR capture failed: {exc}") + break + + payload = _samples_to_interleaved_float32(samples) + try: + await self.ws.send_bytes(payload) + except Exception as exc: + logger.warning("Send failed: %s — ending capture", exc) + break + except asyncio.CancelledError: + raise + finally: + self._close_sdr() + + def _close_sdr(self) -> None: + if self._sdr is None: + return + try: + self._sdr.close() + except Exception: + pass + self._sdr = None + + async def _send_status(self, status: str) -> None: + try: + await self.ws.send_json({"type": "status", "status": status, "app_id": self._app_id}) + except Exception as exc: + logger.debug("Status send failed: %s", exc) + + async def _send_error(self, message: str) -> None: + try: + await self.ws.send_json({"type": "error", "app_id": self._app_id, "message": message}) + except Exception as exc: + logger.debug("Error-frame send failed: %s", exc) + + +# --------------------------------------------------------------------------- +# Helpers + +_CONFIG_ATTR_MAP = { + "sample_rate": ("sample_rate", "rx_sample_rate"), + "center_frequency": ("center_freq", "rx_center_frequency"), + "center_freq": ("center_freq", "rx_center_frequency"), + "gain": ("gain", "rx_gain"), + "bandwidth": ("bandwidth", "rx_bandwidth"), +} + + +def _apply_sdr_config(sdr: Any, cfg: dict) -> None: + """Apply a radio_config dict to an SDR, trying multiple attribute aliases.""" + for key, value in cfg.items(): + if value is None: + continue + attrs = _CONFIG_ATTR_MAP.get(key, (key,)) + applied = False + for attr in attrs: + if hasattr(sdr, attr): + try: + setattr(sdr, attr, value) + applied = True + break + except Exception as exc: + logger.debug("setattr %s=%r failed: %s", attr, value, exc) + if not applied: + logger.debug("radio_config key %r ignored (no matching attr)", key) + + +def _samples_to_interleaved_float32(samples: Any) -> bytes: + """Convert complex IQ samples (any numeric dtype) to interleaved float32 bytes.""" + arr = np.asarray(samples) + if np.iscomplexobj(arr): + interleaved = np.empty(arr.size * 2, dtype=np.float32) + interleaved[0::2] = arr.real.astype(np.float32, copy=False).ravel() + interleaved[1::2] = arr.imag.astype(np.float32, copy=False).ravel() + return interleaved.tobytes() + return arr.astype(np.float32, copy=False).tobytes() + + +def _default_sdr_factory(device: str, identifier: str | None): + from ria_toolkit_oss.sdr import get_sdr_device + + return get_sdr_device(device, ident=identifier) + + +# --------------------------------------------------------------------------- +# Top-level entry + +async def run_streamer(ws_url: str, token: str) -> None: + """Connect to *ws_url* and run the streamer loop until cancelled.""" + ws = WsClient(ws_url, token) + streamer = Streamer(ws) + await ws.run(streamer.on_message, streamer.build_heartbeat) diff --git a/src/ria_toolkit_oss/agent/ws_client.py b/src/ria_toolkit_oss/agent/ws_client.py new file mode 100644 index 0000000..1bc66f6 --- /dev/null +++ b/src/ria_toolkit_oss/agent/ws_client.py @@ -0,0 +1,117 @@ +"""Persistent WebSocket client for the streamer agent. + +Handles connection lifecycle: connect, heartbeat, auto-reconnect on drop. +The caller drives the I/O loop via ``run()`` with a message handler callback. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from typing import Awaitable, Callable + +logger = logging.getLogger("ria_agent.ws") + +MessageHandler = Callable[[dict], Awaitable[None]] +HeartbeatBuilder = Callable[[], dict] + + +class WsClient: + """Persistent WebSocket connection with heartbeat and auto-reconnect. + + ``url`` should be a full ``wss://host/path`` (or ``ws://``) URL. ``token`` + is sent as a bearer in the ``Authorization`` header on connect. + """ + + def __init__( + self, + url: str, + token: str, + *, + heartbeat_interval: float = 30.0, + reconnect_pause: float = 5.0, + ) -> None: + self.url = url + self.token = token + self.heartbeat_interval = heartbeat_interval + self.reconnect_pause = reconnect_pause + self._ws = None + self._stop = asyncio.Event() + + # ------------------------------------------------------------------ + async def _connect(self): + import websockets + + headers = [("Authorization", f"Bearer {self.token}")] if self.token else None + # websockets >= 12 accepts additional_headers; fall back to extra_headers for older versions. + try: + return await websockets.connect(self.url, additional_headers=headers) + except TypeError: + return await websockets.connect(self.url, extra_headers=headers) + + # ------------------------------------------------------------------ + async def send_json(self, payload: dict) -> None: + if self._ws is None: + raise ConnectionError("WebSocket is not connected") + await self._ws.send(json.dumps(payload)) + + async def send_bytes(self, data: bytes) -> None: + if self._ws is None: + raise ConnectionError("WebSocket is not connected") + await self._ws.send(data) + + def stop(self) -> None: + self._stop.set() + + # ------------------------------------------------------------------ + async def run(self, on_message: MessageHandler, heartbeat: HeartbeatBuilder) -> None: + """Main loop: connect, heartbeat, dispatch messages, reconnect on drop.""" + while not self._stop.is_set(): + try: + self._ws = await self._connect() + logger.info("Connected to %s", self.url) + hb_task = asyncio.create_task(self._heartbeat_loop(heartbeat)) + try: + async for raw in self._ws: + if isinstance(raw, bytes): + # Server shouldn't send binary to the agent; log and drop. + logger.debug("Discarding unexpected %d-byte binary frame", len(raw)) + continue + try: + msg = json.loads(raw) + except json.JSONDecodeError: + logger.warning("Malformed control frame: %r", raw[:200]) + continue + await on_message(msg) + finally: + hb_task.cancel() + try: + await hb_task + except (asyncio.CancelledError, Exception): + pass + except asyncio.CancelledError: + raise + except Exception as exc: + if self._stop.is_set(): + break + logger.warning("WS error: %s — reconnecting in %.1fs", exc, self.reconnect_pause) + finally: + try: + if self._ws is not None: + await self._ws.close() + except Exception: + pass + self._ws = None + if self._stop.is_set(): + break + await asyncio.sleep(self.reconnect_pause) + + async def _heartbeat_loop(self, heartbeat: HeartbeatBuilder) -> None: + while True: + try: + await self.send_json(heartbeat()) + except Exception as exc: + logger.debug("Heartbeat send failed: %s", exc) + return + await asyncio.sleep(self.heartbeat_interval) diff --git a/src/ria_toolkit_oss/sdr/__init__.py b/src/ria_toolkit_oss/sdr/__init__.py index 78a13a9..4b327a2 100644 --- a/src/ria_toolkit_oss/sdr/__init__.py +++ b/src/ria_toolkit_oss/sdr/__init__.py @@ -4,10 +4,48 @@ It streamlines tasks involving signal reception and transmission, as well as com operations such as detecting and configuring available devices. """ -__all__ = ["SDR", "SDRError", "SDRParameterError", "MockSDR", "get_sdr_device"] +__all__ = [ + "SDR", + "SDRError", + "SDRParameterError", + "SdrDisconnectedError", + "MockSDR", + "get_sdr_device", + "detect_available", +] from .mock import MockSDR -from .sdr import SDR, SDRError, SDRParameterError +from .sdr import SDR, SDRError, SdrDisconnectedError, SDRParameterError, translate_disconnect # noqa: F401 + + +_DRIVER_CANDIDATES: tuple[tuple[str, str, str], ...] = ( + ("mock", "ria_toolkit_oss.sdr.mock", "MockSDR"), + ("pluto", "ria_toolkit_oss.sdr.pluto", "Pluto"), + ("hackrf", "ria_toolkit_oss.sdr.hackrf", "HackRF"), + ("rtlsdr", "ria_toolkit_oss.sdr.rtlsdr", "RTLSDR"), + ("usrp", "ria_toolkit_oss.sdr.usrp", "USRP"), + ("blade", "ria_toolkit_oss.sdr.blade", "Blade"), + ("thinkrf", "ria_toolkit_oss.sdr.thinkrf", "ThinkRF"), +) + + +def detect_available() -> dict[str, type]: + """Return ``{device_name: driver_class}`` for every driver whose module imports cleanly. + + Importability is a proxy for "the user has installed this driver's optional dependency". + It does not probe for physical hardware presence — that requires actually instantiating + the driver, which can be slow and side-effectful. + """ + import importlib + + out: dict[str, type] = {} + for name, module_path, cls_name in _DRIVER_CANDIDATES: + try: + mod = importlib.import_module(module_path) + out[name] = getattr(mod, cls_name) + except Exception: + continue + return out def get_sdr_device(device_type: str, ident: str | None = None, tx: bool = False) -> SDR: diff --git a/src/ria_toolkit_oss/sdr/pluto.py b/src/ria_toolkit_oss/sdr/pluto.py index 68b3973..7ed3be0 100644 --- a/src/ria_toolkit_oss/sdr/pluto.py +++ b/src/ria_toolkit_oss/sdr/pluto.py @@ -8,7 +8,7 @@ import adi import numpy as np from ria_toolkit_oss.datatypes.recording import Recording -from ria_toolkit_oss.sdr.sdr import SDR, SDRError, SDRParameterError +from ria_toolkit_oss.sdr.sdr import SDR, SDRError, SDRParameterError, translate_disconnect class Pluto(SDR): @@ -164,6 +164,25 @@ class Pluto(SDR): # send callback complex signal callback(buffer=signal, metadata=None) + def rx(self, num_samples: Optional[int] = None) -> np.ndarray: + """PlutoSDR-style single-buffer capture returning a complex64 array. + + Sets the radio buffer size to *num_samples* (if given) and returns one + buffer directly from ``self.radio.rx()``. Raises + :class:`SdrDisconnectedError` on USB/device drop so callers (e.g. the + streamer) can report the failure and stop cleanly instead of crashing. + """ + if num_samples is not None: + try: + self.set_rx_buffer_size(buffer_size=int(num_samples)) + except Exception as exc: + raise translate_disconnect(exc) from exc + try: + samples = self.radio.rx() + except Exception as exc: + raise translate_disconnect(exc) from exc + return np.asarray(samples) + def _record_fast(self, num_samples): """Optimized single-buffer capture for ≤16M samples.""" diff --git a/src/ria_toolkit_oss/sdr/sdr.py b/src/ria_toolkit_oss/sdr/sdr.py index 36e26f7..abab125 100644 --- a/src/ria_toolkit_oss/sdr/sdr.py +++ b/src/ria_toolkit_oss/sdr/sdr.py @@ -528,3 +528,51 @@ class SDROverflowError(SDRError): """Buffer overflow detected.""" pass + + +class SdrDisconnectedError(SDRError): + """Raised when the SDR device disappears mid-operation (USB unplug, network drop).""" + + pass + + +# Substrings that strongly indicate a device has disappeared rather than a +# transient / recoverable error. Checked case-insensitively against str(exc). +_DISCONNECT_MARKERS = ( + "no such device", + "device not found", + "not found", + "broken pipe", + "disconnected", + "no device", + "device unplugged", + "usb", + "i/o error", + "input/output error", + "errno 19", # ENODEV + "errno 5", # EIO +) + + +def translate_disconnect(exc: BaseException) -> BaseException: + """Return ``SdrDisconnectedError`` if *exc* looks like a USB/device drop, else *exc*. + + Drivers wrap their native-API calls with:: + + try: + return self.radio.rx() + except Exception as exc: + raise translate_disconnect(exc) from exc + + The caller (e.g. the streamer) can then catch ``SdrDisconnectedError`` + specifically and report it to the hub rather than crashing the loop. + """ + if isinstance(exc, SdrDisconnectedError): + return exc + msg = str(exc).lower() + if any(marker in msg for marker in _DISCONNECT_MARKERS): + return SdrDisconnectedError(str(exc)) + # OSError subclass with ENODEV / EIO errno is also a disconnect signal. + if isinstance(exc, OSError) and getattr(exc, "errno", None) in (5, 19): + return SdrDisconnectedError(str(exc)) + return exc diff --git a/tests/agent/__init__.py b/tests/agent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/agent/test_config.py b/tests/agent/test_config.py new file mode 100644 index 0000000..2532abd --- /dev/null +++ b/tests/agent/test_config.py @@ -0,0 +1,33 @@ +from ria_toolkit_oss.agent import config as agent_config + + +def test_round_trip(tmp_path): + p = tmp_path / "agent.json" + cfg = agent_config.AgentConfig( + hub_url="https://hub.example.com", + agent_id="agent-1", + token="t", + name="bench", + insecure=True, + ) + agent_config.save(cfg, path=p) + loaded = agent_config.load(path=p) + assert loaded == cfg + + +def test_load_missing_returns_empty(tmp_path): + loaded = agent_config.load(path=tmp_path / "none.json") + assert loaded == agent_config.AgentConfig() + + +def test_extra_keys_preserved(tmp_path): + p = tmp_path / "agent.json" + p.write_text('{"hub_url": "x", "custom": 42}') + cfg = agent_config.load(path=p) + assert cfg.hub_url == "x" + assert cfg.extra == {"custom": 42} + agent_config.save(cfg, path=p) + import json + + data = json.loads(p.read_text()) + assert data["custom"] == 42 diff --git a/tests/agent/test_disconnect.py b/tests/agent/test_disconnect.py new file mode 100644 index 0000000..f063e3a --- /dev/null +++ b/tests/agent/test_disconnect.py @@ -0,0 +1,81 @@ +"""SdrDisconnectedError translation + streamer handling.""" + +from __future__ import annotations + +import asyncio + +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.sdr import SdrDisconnectedError +from ria_toolkit_oss.sdr.sdr import translate_disconnect + + +def test_translate_disconnect_usb_message(): + exc = RuntimeError("libiio: Input/output error (errno 5)") + out = translate_disconnect(exc) + assert isinstance(out, SdrDisconnectedError) + + +def test_translate_disconnect_enodev_oserror(): + exc = OSError(19, "No such device") + assert isinstance(translate_disconnect(exc), SdrDisconnectedError) + + +def test_translate_disconnect_passes_through_unrelated(): + exc = ValueError("bad sample rate") + assert translate_disconnect(exc) is exc + + +def test_translate_disconnect_preserves_sdr_disconnected(): + original = SdrDisconnectedError("already typed") + assert translate_disconnect(original) is original + + +class _FlakySdr: + """SDR that raises SdrDisconnectedError on the first rx() call.""" + + def __init__(self) -> None: + self.closed = False + + def rx(self, n): # noqa: D401 - trivial + raise SdrDisconnectedError("usb gone") + + def close(self): + self.closed = True + + +class _Ws: + def __init__(self): + self.json_sent = [] + self.bytes_sent = [] + + async def send_json(self, p): + self.json_sent.append(p) + + async def send_bytes(self, b): + self.bytes_sent.append(b) + + +def test_streamer_reports_disconnected_and_ends_capture(): + async def scenario(): + ws = _Ws() + sdr = _FlakySdr() + streamer = Streamer(ws=ws, sdr_factory=lambda d, i: sdr) + await streamer.on_message( + { + "type": "start", + "app_id": "a", + "radio_config": {"device": "fake", "buffer_size": 8}, + } + ) + # Wait for the capture task to fail out. + for _ in range(50): + if streamer._capture_task and streamer._capture_task.done(): + break + await asyncio.sleep(0.01) + return ws, sdr, streamer + + ws, sdr, streamer = asyncio.run(scenario()) + assert sdr.closed + errors = [m for m in ws.json_sent if m.get("type") == "error"] + assert errors, "expected an error frame" + assert "disconnected" in errors[-1]["message"].lower() diff --git a/tests/agent/test_hardware.py b/tests/agent/test_hardware.py new file mode 100644 index 0000000..ab9fcdf --- /dev/null +++ b/tests/agent/test_hardware.py @@ -0,0 +1,29 @@ +from ria_toolkit_oss.agent import hardware +from ria_toolkit_oss.sdr import detect_available + + +def test_detect_available_includes_mock(): + drivers = detect_available() + assert "mock" in drivers + from ria_toolkit_oss.sdr.mock import MockSDR + + assert drivers["mock"] is MockSDR + + +def test_available_devices_sorted_list(): + devices = hardware.available_devices() + assert isinstance(devices, list) + assert devices == sorted(devices) + assert "mock" in devices + + +def test_heartbeat_payload_shape(): + p = hardware.heartbeat_payload() + assert p["type"] == "heartbeat" + assert p["status"] == "idle" + assert "mock" in p["hardware"] + assert "app_id" not in p + + p2 = hardware.heartbeat_payload(status="streaming", app_id="abc") + assert p2["status"] == "streaming" + assert p2["app_id"] == "abc" diff --git a/tests/agent/test_integration.py b/tests/agent/test_integration.py new file mode 100644 index 0000000..168e7a6 --- /dev/null +++ b/tests/agent/test_integration.py @@ -0,0 +1,100 @@ +"""End-to-end: local websockets server drives a Streamer with a MockSDR.""" + +from __future__ import annotations + +import asyncio +import json + +import websockets + +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.agent.ws_client import WsClient +from ria_toolkit_oss.sdr.mock import MockSDR + + +def test_server_start_stream_stop_cycle_over_real_ws(): + async def scenario(): + control_frames: list[dict] = [] + binary_frames: list[bytes] = [] + ready = asyncio.Event() + stopped = asyncio.Event() + + async def server_handler(ws): + # Agent will open the connection; wait for heartbeat first. + try: + first = await asyncio.wait_for(ws.recv(), timeout=2.0) + control_frames.append(json.loads(first)) + await ws.send( + json.dumps( + { + "type": "start", + "app_id": "app-1", + "radio_config": { + "device": "mock", + "buffer_size": 32, + "sample_rate": 1_000_000, + "center_frequency": 2.45e9, + }, + } + ) + ) + while len(binary_frames) < 3: + msg = await asyncio.wait_for(ws.recv(), timeout=2.0) + if isinstance(msg, bytes): + binary_frames.append(msg) + else: + control_frames.append(json.loads(msg)) + ready.set() + await ws.send(json.dumps({"type": "stop", "app_id": "app-1"})) + # Drain final status frame. + try: + while True: + msg = await asyncio.wait_for(ws.recv(), timeout=0.5) + if isinstance(msg, bytes): + binary_frames.append(msg) + else: + control_frames.append(json.loads(msg)) + except (asyncio.TimeoutError, Exception): + pass + stopped.set() + except Exception: + stopped.set() + + server = await websockets.serve(server_handler, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + streamer = Streamer(ws=client, sdr_factory=lambda d, i: MockSDR(buffer_size=32, seed=0)) + task = asyncio.create_task( + client.run(on_message=streamer.on_message, heartbeat=streamer.build_heartbeat) + ) + await asyncio.wait_for(ready.wait(), timeout=3.0) + await asyncio.wait_for(stopped.wait(), timeout=3.0) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return control_frames, binary_frames + + controls, binaries = asyncio.run(scenario()) + + # Heartbeat reached the server. + assert any(f.get("type") == "heartbeat" for f in controls) + # Status transitioned idle -> streaming -> idle. + statuses = [f["status"] for f in controls if f.get("type") == "status"] + assert "streaming" in statuses + assert statuses[-1] == "idle" + # Three binary IQ frames of 32 samples × 2 floats × 4 bytes. + assert len(binaries) >= 3 + for b in binaries[:3]: + assert len(b) == 32 * 2 * 4 diff --git a/tests/agent/test_legacy.py b/tests/agent/test_legacy.py new file mode 100644 index 0000000..36e4ea0 --- /dev/null +++ b/tests/agent/test_legacy.py @@ -0,0 +1,19 @@ +"""Regression: legacy NodeAgent still importable after the package move.""" + + +def test_import_node_agent_from_package(): + from ria_toolkit_oss.agent import NodeAgent + + assert NodeAgent.__name__ == "NodeAgent" + + +def test_main_entry_point_exists(): + from ria_toolkit_oss.agent import main + + assert callable(main) + + +def test_legacy_module_still_direct_importable(): + from ria_toolkit_oss.agent.legacy_executor import NodeAgent as LegacyNodeAgent + + assert LegacyNodeAgent.__name__ == "NodeAgent" diff --git a/tests/agent/test_streamer.py b/tests/agent/test_streamer.py new file mode 100644 index 0000000..1bb2081 --- /dev/null +++ b/tests/agent/test_streamer.py @@ -0,0 +1,124 @@ +"""Unit tests for the streamer: drive it with a fake WsClient + MockSDR.""" + +from __future__ import annotations + +import asyncio + +import numpy as np + +from ria_toolkit_oss.agent.streamer import ( + Streamer, + _apply_sdr_config, + _samples_to_interleaved_float32, +) +from ria_toolkit_oss.sdr.mock import MockSDR + + +class FakeWs: + def __init__(self): + self.json_sent: list[dict] = [] + self.bytes_sent: list[bytes] = [] + + async def send_json(self, payload: dict) -> None: + self.json_sent.append(payload) + + async def send_bytes(self, data: bytes) -> None: + self.bytes_sent.append(data) + + +def _factory(device: str, identifier): + return MockSDR(buffer_size=32, seed=0) + + +def test_samples_to_interleaved_float32_roundtrip(): + c = np.array([1 + 2j, 3 + 4j], dtype=np.complex64) + raw = _samples_to_interleaved_float32(c) + arr = np.frombuffer(raw, dtype=np.float32) + assert arr.tolist() == [1.0, 2.0, 3.0, 4.0] + + +def test_apply_sdr_config_sets_attributes(): + sdr = MockSDR(buffer_size=16) + _apply_sdr_config(sdr, {"sample_rate": 2e6, "center_frequency": 9.15e8, "gain": 30}) + assert sdr.sample_rate == 2e6 + assert sdr.center_freq == 9.15e8 + assert sdr.gain == 30 + + +def test_heartbeat_reflects_status_and_app(): + s = Streamer(ws=FakeWs(), sdr_factory=_factory) + hb = s.build_heartbeat() + assert hb["type"] == "heartbeat" + assert hb["status"] == "idle" + s._status = "streaming" + s._app_id = "app-42" + hb2 = s.build_heartbeat() + assert hb2["status"] == "streaming" + assert hb2["app_id"] == "app-42" + + +def test_full_start_stream_stop_cycle(): + async def scenario(): + ws = FakeWs() + streamer = Streamer(ws=ws, sdr_factory=_factory) + + await streamer.on_message( + { + "type": "start", + "app_id": "abc", + "radio_config": { + "device": "mock", + "sample_rate": 1_000_000, + "center_frequency": 2.45e9, + "gain": 40, + "buffer_size": 64, + }, + } + ) + for _ in range(30): + if len(ws.bytes_sent) >= 2: + break + await asyncio.sleep(0.02) + await streamer.on_message({"type": "stop", "app_id": "abc"}) + return ws, streamer + + ws, streamer = asyncio.run(scenario()) + assert len(ws.bytes_sent) >= 1 + for frame in ws.bytes_sent: + assert len(frame) == 64 * 2 * 4 # 64 samples × (I,Q) × float32 + statuses = [m for m in ws.json_sent if m.get("type") == "status"] + assert statuses[0]["status"] == "streaming" + assert statuses[-1]["status"] == "idle" + assert streamer._sdr is None + + +def test_start_without_device_emits_error(): + async def scenario(): + ws = FakeWs() + streamer = Streamer(ws=ws, sdr_factory=_factory) + await streamer.on_message({"type": "start", "app_id": "x", "radio_config": {}}) + return ws + + ws = asyncio.run(scenario()) + errors = [m for m in ws.json_sent if m.get("type") == "error"] + assert errors and "device" in errors[0]["message"] + + +def test_configure_queues_update(): + async def scenario(): + streamer = Streamer(ws=FakeWs(), sdr_factory=_factory) + await streamer.on_message( + {"type": "configure", "app_id": "x", "radio_config": {"center_frequency": 915e6}} + ) + return streamer._pending_config + + pending = asyncio.run(scenario()) + assert pending == {"center_frequency": 915e6} + + +def test_unknown_message_type_is_ignored(): + async def scenario(): + s = Streamer(ws=FakeWs(), sdr_factory=_factory) + await s.on_message({"type": "nope"}) + + asyncio.run(scenario()) diff --git a/tests/agent/test_ws_client.py b/tests/agent/test_ws_client.py new file mode 100644 index 0000000..0994a5b --- /dev/null +++ b/tests/agent/test_ws_client.py @@ -0,0 +1,161 @@ +"""Reconnect + heartbeat timing against a real local websockets server.""" + +from __future__ import annotations + +import asyncio +import json + +import pytest +import websockets + +from ria_toolkit_oss.agent.ws_client import WsClient + + +async def _recv_json(ws) -> dict: + raw = await ws.recv() + return json.loads(raw) + + +async def _open_server(handler): + # websockets 13 ignores extra positional args; bind to localhost:0 for an + # ephemeral port and return both the server and the port. + server = await websockets.serve(handler, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + return server, port + + +def test_heartbeat_sent_on_connect(): + async def scenario(): + received: list[dict] = [] + connected = asyncio.Event() + + async def handler(ws): + connected.set() + msg = await _recv_json(ws) + received.append(msg) + + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=0.05, + reconnect_pause=0.05, + ) + task = asyncio.create_task( + client.run(on_message=lambda _m: asyncio.sleep(0), heartbeat=lambda: {"type": "heartbeat", "n": 1}) + ) + await asyncio.wait_for(connected.wait(), timeout=2.0) + for _ in range(50): + if received: + break + await asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return received + + received = asyncio.run(scenario()) + assert received and received[0]["type"] == "heartbeat" + + +def test_reconnects_after_server_drop(): + async def scenario(): + connections = 0 + first_dropped = asyncio.Event() + + async def handler(ws): + nonlocal connections + connections += 1 + if connections == 1: + await ws.close() + first_dropped.set() + else: + try: + await ws.recv() + except Exception: + pass + + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + task = asyncio.create_task( + client.run(on_message=lambda _m: asyncio.sleep(0), heartbeat=lambda: {"type": "heartbeat"}) + ) + await asyncio.wait_for(first_dropped.wait(), timeout=2.0) + for _ in range(100): + if connections >= 2: + break + await asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return connections + + n = asyncio.run(scenario()) + assert n >= 2 + + +def test_malformed_control_frame_does_not_crash(): + async def scenario(): + handled: list[dict] = [] + done = asyncio.Event() + + async def handler(ws): + await ws.send("not json") + await ws.send(json.dumps({"type": "ping"})) + done.set() + try: + await ws.wait_closed() + except Exception: + pass + + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + + async def on_msg(m): + handled.append(m) + + task = asyncio.create_task( + client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"}) + ) + for _ in range(50): + if handled: + break + await asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return handled + + handled = asyncio.run(scenario()) + assert handled and handled[0] == {"type": "ping"}