From c27a5944c7ee852794bbe6823610f356e6ca4535 Mon Sep 17 00:00:00 2001 From: ben Date: Mon, 20 Apr 2026 16:49:52 -0400 Subject: [PATCH] formats --- src/ria_toolkit_oss/agent/legacy_executor.py | 6 +- .../orchestration/tx_executor.py | 28 +++-- tests/orchestration/test_executor.py | 115 ++++++++++-------- tests/orchestration/test_labeler.py | 16 +-- tests/orchestration/test_tx_executor.py | 4 +- tests/test_agent.py | 2 - 6 files changed, 92 insertions(+), 79 deletions(-) diff --git a/src/ria_toolkit_oss/agent/legacy_executor.py b/src/ria_toolkit_oss/agent/legacy_executor.py index 776238a..d8a56d6 100644 --- a/src/ria_toolkit_oss/agent/legacy_executor.py +++ b/src/ria_toolkit_oss/agent/legacy_executor.py @@ -931,11 +931,7 @@ def main() -> None: "--role", default=None, choices=["general", "rx", "tx"], - help=( - "Node role reported to the hub. " - "'tx' enables synthetic transmission commands. " - "Default: general" - ), + help=("Node role reported to the hub. " "'tx' enables synthetic transmission commands. " "Default: general"), ) parser.add_argument( "--session-code", diff --git a/src/ria_toolkit_oss/orchestration/tx_executor.py b/src/ria_toolkit_oss/orchestration/tx_executor.py index e666a1a..6ae32b1 100644 --- a/src/ria_toolkit_oss/orchestration/tx_executor.py +++ b/src/ria_toolkit_oss/orchestration/tx_executor.py @@ -33,7 +33,6 @@ from __future__ import annotations import logging import threading -import time from typing import Any logger = logging.getLogger(__name__) @@ -41,11 +40,11 @@ logger = logging.getLogger(__name__) # Mapping from modulation name → (PSK/QAM order, generator_type) # 'psk' uses PSKGenerator, 'qam' uses QAMGenerator _MOD_TABLE: dict[str, tuple[int, str]] = { - "BPSK": (1, "psk"), - "QPSK": (2, "psk"), - "8PSK": (3, "psk"), - "16QAM": (4, "qam"), - "64QAM": (6, "qam"), + "BPSK": (1, "psk"), + "QPSK": (2, "psk"), + "8PSK": (3, "psk"), + "16QAM": (4, "qam"), + "64QAM": (6, "qam"), "256QAM": (8, "qam"), } @@ -117,7 +116,12 @@ class TxExecutor: logger.info( "TX step '%s': %.0f s, %s @ %.3f MHz (sps=%d, filter=%s)", - label, duration, modulation, symbol_rate / 1e6, sps, filter_type, + label, + duration, + modulation, + symbol_rate / 1e6, + sps, + filter_type, ) num_samples = int(duration * sample_rate) @@ -133,9 +137,7 @@ class TxExecutor: logger.error("TX step '%s' SDR error: %s", label, exc) else: # No SDR available — simulate by sleeping for the step duration. - logger.warning( - "TX step '%s': no SDR — simulating %.0f s delay", label, duration - ) + logger.warning("TX step '%s': no SDR — simulating %.0f s delay", label, duration) self.stop_event.wait(timeout=duration) def _synthesise( @@ -149,6 +151,7 @@ class TxExecutor: """Build a block-generator chain and return IQ samples as a numpy array.""" try: import numpy as np + from ria_toolkit_oss.signal.block_generator import ( BinarySource, GMSKModulator, @@ -231,6 +234,7 @@ class TxExecutor: def _init_sdr(self, sample_rate: float, center_freq: float) -> None: try: from ria_toolkit_oss.sdr import get_sdr_device + self._sdr = get_sdr_device(self.sdr_device) self._sdr.init_tx( sample_rate=sample_rate, @@ -239,7 +243,9 @@ class TxExecutor: channel=0, gain_mode="manual", ) - logger.info("TX SDR initialised: %s @ %.3f MHz, %.1f Msps", self.sdr_device, center_freq / 1e6, sample_rate / 1e6) + logger.info( + "TX SDR initialised: %s @ %.3f MHz, %.1f Msps", self.sdr_device, center_freq / 1e6, sample_rate / 1e6 + ) except Exception as exc: logger.warning("TX SDR init failed (%s) — will simulate: %s", self.sdr_device, exc) self._sdr = None diff --git a/tests/orchestration/test_executor.py b/tests/orchestration/test_executor.py index 260c883..7aba499 100644 --- a/tests/orchestration/test_executor.py +++ b/tests/orchestration/test_executor.py @@ -4,7 +4,6 @@ from __future__ import annotations import json import stat -import threading from types import SimpleNamespace import pytest @@ -108,37 +107,47 @@ class TestCampaignResult: return r def test_total_steps(self): - r = self._make([ - StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), - StepResult("tx1", "s2", "/out", _ok_qa(), 0.0), - ]) + r = self._make( + [ + StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), + StepResult("tx1", "s2", "/out", _ok_qa(), 0.0), + ] + ) assert r.total_steps == 2 def test_passed_count(self): - r = self._make([ - StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), - StepResult("tx1", "s2", "/out", _failed_qa(), 0.0), - ]) + r = self._make( + [ + StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), + StepResult("tx1", "s2", "/out", _failed_qa(), 0.0), + ] + ) assert r.passed == 1 def test_failed_count(self): - r = self._make([ - StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), - StepResult("tx1", "s2", "/out", _failed_qa(), 0.0), - ]) + r = self._make( + [ + StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), + StepResult("tx1", "s2", "/out", _failed_qa(), 0.0), + ] + ) assert r.failed == 1 def test_flagged_count(self): - r = self._make([ - StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), - StepResult("tx1", "s2", "/out", _flagged_qa(), 0.0), - ]) + r = self._make( + [ + StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), + StepResult("tx1", "s2", "/out", _flagged_qa(), 0.0), + ] + ) assert r.flagged == 1 def test_error_step_counts_as_failed_not_passed(self): - r = self._make([ - StepResult("tx1", "s1", None, _ok_qa(), 0.0, error="disk full"), - ]) + r = self._make( + [ + StepResult("tx1", "s1", None, _ok_qa(), 0.0, error="disk full"), + ] + ) assert r.failed == 1 assert r.passed == 0 @@ -232,37 +241,45 @@ class TestExtractTxParams: assert _extract_tx_params(tx) is None def test_returns_signal_params(self): - tx = SimpleNamespace(sdr_agent={ - "modulation": "QPSK", - "symbol_rate": 1e6, - "center_frequency": 2.4e9, - }) + tx = SimpleNamespace( + sdr_agent={ + "modulation": "QPSK", + "symbol_rate": 1e6, + "center_frequency": 2.4e9, + } + ) result = _extract_tx_params(tx) assert result == {"modulation": "QPSK", "symbol_rate": 1e6, "center_frequency": 2.4e9} def test_strips_infra_key_node_id(self): - tx = SimpleNamespace(sdr_agent={ - "modulation": "BPSK", - "node_id": "node_abc123", - }) + tx = SimpleNamespace( + sdr_agent={ + "modulation": "BPSK", + "node_id": "node_abc123", + } + ) result = _extract_tx_params(tx) assert "node_id" not in result assert result == {"modulation": "BPSK"} def test_strips_infra_key_session_code(self): - tx = SimpleNamespace(sdr_agent={ - "modulation": "FSK", - "session_code": "amber-peak-transmit", - }) + tx = SimpleNamespace( + sdr_agent={ + "modulation": "FSK", + "session_code": "amber-peak-transmit", + } + ) result = _extract_tx_params(tx) assert "session_code" not in result def test_strips_none_values(self): - tx = SimpleNamespace(sdr_agent={ - "modulation": "QPSK", - "order": None, - "rolloff": 0.35, - }) + tx = SimpleNamespace( + sdr_agent={ + "modulation": "QPSK", + "order": None, + "rolloff": 0.35, + } + ) result = _extract_tx_params(tx) assert "order" not in result assert result == {"modulation": "QPSK", "rolloff": 0.35} @@ -274,16 +291,18 @@ class TestExtractTxParams: assert "node_id" in cfg def test_full_sdr_agent_config(self): - tx = SimpleNamespace(sdr_agent={ - "modulation": "16QAM", - "order": 4, - "symbol_rate": 5e6, - "center_frequency": 915e6, - "filter": "rrc", - "rolloff": 0.35, - "node_id": "node_xyz", - "session_code": "some-code", - }) + tx = SimpleNamespace( + sdr_agent={ + "modulation": "16QAM", + "order": 4, + "symbol_rate": 5e6, + "center_frequency": 915e6, + "filter": "rrc", + "rolloff": 0.35, + "node_id": "node_xyz", + "session_code": "some-code", + } + ) result = _extract_tx_params(tx) assert result == { "modulation": "16QAM", diff --git a/tests/orchestration/test_labeler.py b/tests/orchestration/test_labeler.py index 670a9dc..2e47739 100644 --- a/tests/orchestration/test_labeler.py +++ b/tests/orchestration/test_labeler.py @@ -116,9 +116,7 @@ class TestLabelRecording: def test_tx_params_written_as_tx_prefix_keys(self): params = {"modulation": "QPSK", "symbol_rate": 1e6} - rec = label_recording( - _simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params - ) + rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params) assert rec.metadata["tx_modulation"] == "QPSK" assert rec.metadata["tx_symbol_rate"] == pytest.approx(1e6) @@ -131,17 +129,15 @@ class TestLabelRecording: "filter": "rrc", "rolloff": 0.35, } - rec = label_recording( - _simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params - ) + rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params) for k, v in params.items(): assert f"tx_{k}" in rec.metadata - assert rec.metadata[f"tx_{k}"] == pytest.approx(v) if isinstance(v, float) else rec.metadata[f"tx_{k}"] == v + assert ( + rec.metadata[f"tx_{k}"] == pytest.approx(v) if isinstance(v, float) else rec.metadata[f"tx_{k}"] == v + ) def test_tx_params_empty_dict_writes_nothing(self): - rec = label_recording( - _simple_recording(), "dev", _wifi_step(), time.time(), tx_params={} - ) + rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params={}) tx_keys = [k for k in rec.metadata if k.startswith("tx_") and k != "tx_power_dbm"] assert tx_keys == [] diff --git a/tests/orchestration/test_tx_executor.py b/tests/orchestration/test_tx_executor.py index 9a03e6b..9d66850 100644 --- a/tests/orchestration/test_tx_executor.py +++ b/tests/orchestration/test_tx_executor.py @@ -3,7 +3,7 @@ from __future__ import annotations import threading -from unittest.mock import MagicMock, patch +from unittest.mock import patch import numpy as np import pytest @@ -73,8 +73,6 @@ class TestTxExecutorRun: waited = [] real_ev = threading.Event() - orig_wait = real_ev.wait - def _fake_wait(timeout=None): waited.append(timeout) return False diff --git a/tests/test_agent.py b/tests/test_agent.py index ce7ac9c..67991f9 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -6,8 +6,6 @@ import threading import time from unittest.mock import MagicMock, patch -import pytest - from ria_toolkit_oss.agent import NodeAgent