Compare commits
16 Commits
a502dd97a9
...
2881aaf06e
| Author | SHA1 | Date | |
|---|---|---|---|
| 2881aaf06e | |||
| 50d04161b7 | |||
| 07c72294f5 | |||
| c9b19949ad | |||
| 53e8e5adb6 | |||
| 34b67c0c17 | |||
| 39d5d74d6a | |||
| 4d3aaf6ec8 | |||
| 4aea2841be | |||
| 4c2c9c0288 | |||
| c27a5944c7 | |||
| 062a0e766f | |||
| cdcc03327b | |||
| 912fc54f25 | |||
| b884397f1f | |||
| dae9510981 |
12
poetry.lock
generated
12
poetry.lock
generated
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 2.3.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "alabaster"
|
||||
|
|
@ -230,14 +230,14 @@ uvloop = ["uvloop (>=0.15.2) ; sys_platform != \"win32\"", "winloop (>=0.5.0) ;
|
|||
|
||||
[[package]]
|
||||
name = "cachetools"
|
||||
version = "7.0.5"
|
||||
version = "7.0.6"
|
||||
description = "Extensible memoizing collections and decorators"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["test"]
|
||||
files = [
|
||||
{file = "cachetools-7.0.5-py3-none-any.whl", hash = "sha256:46bc8ebefbe485407621d0a4264b23c080cedd913921bad7ac3ed2f26c183114"},
|
||||
{file = "cachetools-7.0.5.tar.gz", hash = "sha256:0cd042c24377200c1dcd225f8b7b12b0ca53cc2c961b43757e774ebe190fd990"},
|
||||
{file = "cachetools-7.0.6-py3-none-any.whl", hash = "sha256:4e94956cfdd3086f12042cdd29318f5ced3893014f7d0d059bf3ead3f85b7f8b"},
|
||||
{file = "cachetools-7.0.6.tar.gz", hash = "sha256:e5d524d36d65703a87243a26ff08ad84f73352adbeafb1cde81e207b456aaf24"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1271,7 +1271,7 @@ files = [
|
|||
|
||||
[package.dependencies]
|
||||
attrs = ">=22.2.0"
|
||||
jsonschema-specifications = ">=2023.3.6"
|
||||
jsonschema-specifications = ">=2023.03.6"
|
||||
referencing = ">=0.28.4"
|
||||
rpds-py = ">=0.25.0"
|
||||
|
||||
|
|
@ -3749,4 +3749,4 @@ files = [
|
|||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10"
|
||||
content-hash = "ffde300b2fc93161d2279a6e2b899bc988d3b5eb3833135821830affc9a5fb62"
|
||||
content-hash = "66c9adf647316db90f963da05e8a83574378bfa4db2c69ce751446b5ee7c408c"
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ dependencies = [
|
|||
"pyyaml (>=6.0.3,<7.0.0)",
|
||||
"click (>=8.1.0,<9.0.0)",
|
||||
"matplotlib (>=3.8.0,<4.0.0)",
|
||||
"paramiko (>=4.0.0)"
|
||||
"paramiko (>=3.5.1)"
|
||||
]
|
||||
|
||||
# [project.optional-dependencies] Commented out to prevent Tox tests from failing
|
||||
|
|
@ -149,6 +149,11 @@ exclude = '''
|
|||
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = ["src"]
|
||||
filterwarnings = [
|
||||
# FastAPI emits this internally when handling 422 responses; the constant
|
||||
# is not yet renamed in the installed starlette version, so we can't migrate.
|
||||
"ignore:'HTTP_422_UNPROCESSABLE_ENTITY' is deprecated:DeprecationWarning",
|
||||
]
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ _HEARTBEAT_INTERVAL = 30 # seconds between heartbeats
|
|||
_POLL_TIMEOUT = 30 # server-side long-poll duration
|
||||
_POLL_CLIENT_TIMEOUT = 40 # client read timeout — slightly longer than server
|
||||
_RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying
|
||||
_CHUNK_SIZE = 50 * 1024 * 1024 # 50 MB — well below Cloudflare's 100 MB limit
|
||||
_CHUNK_SIZE = 10 * 1024 * 1024 # 10 MB per chunk — fast enough for git-LFS to process within timeout
|
||||
_DIRECT_THRESHOLD = 90 * 1024 * 1024 # files above this use chunked upload
|
||||
_CAPTURE_SAMPLES = 4096 # IQ samples per inference window
|
||||
_IDLE_LABELS = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"})
|
||||
|
|
@ -93,16 +93,24 @@ class NodeAgent:
|
|||
name: str,
|
||||
sdr_device: str = "unknown",
|
||||
insecure: bool = False,
|
||||
role: str = "general",
|
||||
session_code: str | None = None,
|
||||
) -> None:
|
||||
self.hub_url = hub_url.rstrip("/")
|
||||
self.api_key = api_key
|
||||
self.name = name
|
||||
self.sdr_device = sdr_device
|
||||
self.insecure = insecure
|
||||
self.role = role
|
||||
self.session_code = session_code
|
||||
|
||||
self.node_id: str | None = None
|
||||
self._stop = threading.Event()
|
||||
|
||||
# ── TX state ────────────────────────────────────────────────────────
|
||||
self._tx_stop = threading.Event()
|
||||
self._tx_thread: threading.Thread | None = None
|
||||
|
||||
# ── Inference state ─────────────────────────────────────────────────
|
||||
# Protected by _inf_lock for cross-thread model swaps.
|
||||
self._inf_lock = threading.Lock()
|
||||
|
|
@ -172,19 +180,27 @@ class NodeAgent:
|
|||
capabilities = ["campaign"]
|
||||
if self._ort_available:
|
||||
capabilities.append("inference")
|
||||
resp = self._post(
|
||||
"/composer/nodes/register",
|
||||
json={
|
||||
if self.role == "tx":
|
||||
capabilities.append("transmit")
|
||||
payload: dict = {
|
||||
"name": self.name,
|
||||
"sdr_device": self.sdr_device,
|
||||
"ria_toolkit_version": self._ria_version,
|
||||
"capabilities": capabilities,
|
||||
},
|
||||
timeout=15,
|
||||
)
|
||||
"role": self.role,
|
||||
}
|
||||
if self.session_code:
|
||||
payload["session_code"] = self.session_code
|
||||
resp = self._post("/composer/nodes/register", json=payload, timeout=15)
|
||||
resp.raise_for_status()
|
||||
self.node_id = resp.json()["node_id"]
|
||||
logger.info("Registered as %r (node_id=%s)", self.name, self.node_id)
|
||||
logger.info(
|
||||
"Registered as %r (node_id=%s, role=%s%s)",
|
||||
self.name,
|
||||
self.node_id,
|
||||
self.role,
|
||||
f", session_code={self.session_code!r}" if self.session_code else "",
|
||||
)
|
||||
|
||||
def _deregister(self) -> None:
|
||||
if not self.node_id:
|
||||
|
|
@ -245,9 +261,10 @@ class NodeAgent:
|
|||
if command == "run_campaign":
|
||||
campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4())
|
||||
config_dict: dict = cmd.get("payload") or {}
|
||||
skip_local_tx: bool = bool(cmd.get("skip_local_tx", False))
|
||||
threading.Thread(
|
||||
target=self._run_campaign,
|
||||
args=(campaign_id, config_dict),
|
||||
args=(campaign_id, config_dict, skip_local_tx),
|
||||
daemon=True,
|
||||
name=f"campaign-{campaign_id[:8]}",
|
||||
).start()
|
||||
|
|
@ -269,6 +286,17 @@ class NodeAgent:
|
|||
self._stop_inference()
|
||||
elif command == "configure_inference":
|
||||
self._queue_sdr_config(cmd)
|
||||
elif command == "start_transmit":
|
||||
threading.Thread(
|
||||
target=self._start_transmit,
|
||||
args=(cmd,),
|
||||
daemon=True,
|
||||
name="ria-start-tx",
|
||||
).start()
|
||||
elif command == "stop_transmit":
|
||||
self._stop_transmit()
|
||||
elif command == "configure_transmit":
|
||||
logger.info("configure_transmit received — will apply on next step boundary")
|
||||
else:
|
||||
logger.warning("Unknown command %r — ignored", command)
|
||||
|
||||
|
|
@ -276,7 +304,7 @@ class NodeAgent:
|
|||
# Campaign execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_campaign(self, campaign_id: str, config_dict: dict) -> None:
|
||||
def _run_campaign(self, campaign_id: str, config_dict: dict, skip_local_tx: bool = False) -> None:
|
||||
try:
|
||||
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
|
||||
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
|
||||
|
|
@ -288,10 +316,10 @@ class NodeAgent:
|
|||
)
|
||||
return
|
||||
|
||||
logger.info("Campaign %s starting", campaign_id[:8])
|
||||
logger.info("Campaign %s starting (skip_local_tx=%s)", campaign_id[:8], skip_local_tx)
|
||||
try:
|
||||
config = CampaignConfig.from_dict(config_dict)
|
||||
executor = CampaignExecutor(config)
|
||||
executor = CampaignExecutor(config, skip_local_tx=skip_local_tx)
|
||||
result = executor.run()
|
||||
logger.info("Campaign %s completed — uploading recordings", campaign_id[:8])
|
||||
self._upload_recordings(campaign_id, config, result)
|
||||
|
|
@ -301,6 +329,58 @@ class NodeAgent:
|
|||
logger.error("Campaign %s failed: %s", campaign_id[:8], exc)
|
||||
self._report_campaign_status(campaign_id, "failed", error=str(exc))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TX execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _start_transmit(self, cmd: dict) -> None:
|
||||
"""Execute a synthetic transmit campaign using TxExecutor.
|
||||
|
||||
The command payload mirrors a TransmitterConfig dict with an optional
|
||||
``schedule`` of steps. Each step synthesises a signal and transmits it
|
||||
via the local SDR in TX mode.
|
||||
"""
|
||||
try:
|
||||
from ria_toolkit_oss.orchestration.tx_executor import TxExecutor
|
||||
except ImportError as exc:
|
||||
logger.error("start_transmit: TxExecutor not available: %s", exc)
|
||||
return
|
||||
|
||||
if self._tx_thread and self._tx_thread.is_alive():
|
||||
logger.warning("start_transmit: TX already running — ignoring duplicate command")
|
||||
return
|
||||
|
||||
self._tx_stop.clear()
|
||||
campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4())
|
||||
executor = TxExecutor(
|
||||
config=cmd,
|
||||
sdr_device=self.sdr_device,
|
||||
stop_event=self._tx_stop,
|
||||
)
|
||||
self._tx_thread = threading.Thread(
|
||||
target=self._run_tx_campaign,
|
||||
args=(executor, campaign_id),
|
||||
daemon=True,
|
||||
name=f"tx-campaign-{campaign_id[:8]}",
|
||||
)
|
||||
self._tx_thread.start()
|
||||
|
||||
def _run_tx_campaign(self, executor: Any, campaign_id: str) -> None:
|
||||
try:
|
||||
executor.run()
|
||||
logger.info("TX campaign %s completed", campaign_id[:8])
|
||||
self._report_campaign_status(campaign_id, "completed")
|
||||
except Exception as exc:
|
||||
logger.error("TX campaign %s failed: %s", campaign_id[:8], exc)
|
||||
self._report_campaign_status(campaign_id, "failed", error=str(exc))
|
||||
|
||||
def _stop_transmit(self) -> None:
|
||||
"""Signal the TX loop to stop gracefully."""
|
||||
self._tx_stop.set()
|
||||
if self._tx_thread and self._tx_thread.is_alive():
|
||||
self._tx_thread.join(timeout=5.0)
|
||||
logger.info("TX stopped")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inference — model loading
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -579,13 +659,18 @@ class NodeAgent:
|
|||
base_url = f"{self.hub_url}/datasets/upload"
|
||||
steps = (result.get("steps") if isinstance(result, dict) else getattr(result, "steps", None)) or []
|
||||
|
||||
output_obj = getattr(config, "output", None)
|
||||
folder = getattr(output_obj, "folder", None)
|
||||
campaign_name: str = folder if folder is not None else (getattr(config, "name", None) or "")
|
||||
for step in steps:
|
||||
output_path: str | None = getattr(step, "output_path", None)
|
||||
if not output_path:
|
||||
continue
|
||||
device_id: str = getattr(step, "transmitter_id", "") or ""
|
||||
for fpath in _sigmf_files(output_path):
|
||||
filename = os.path.basename(fpath)
|
||||
basename = os.path.basename(fpath)
|
||||
path_parts = [p for p in (campaign_name, device_id) if p]
|
||||
filename = "/".join(path_parts + [basename])
|
||||
metadata = {
|
||||
"filename": filename,
|
||||
"repo_owner": repo_owner,
|
||||
|
|
@ -671,7 +756,7 @@ class NodeAgent:
|
|||
headers=headers,
|
||||
files={"file": (filename, chunk, "application/octet-stream")},
|
||||
data={**metadata, "upload_id": upload_id, "chunk_index": i, "total_chunks": total_chunks},
|
||||
timeout=120,
|
||||
timeout=(30, None), # 30s connect, no read timeout — server may take minutes on final chunk
|
||||
verify=verify,
|
||||
)
|
||||
if not resp.ok:
|
||||
|
|
@ -848,6 +933,21 @@ def main() -> None:
|
|||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
help="Logging verbosity (default: INFO)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--role",
|
||||
default=None,
|
||||
choices=["general", "rx", "tx"],
|
||||
help=("Node role reported to the hub. " "'tx' enables synthetic transmission commands. " "Default: general"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--session-code",
|
||||
default=None,
|
||||
metavar="CODE",
|
||||
help=(
|
||||
"3-word session code to pair this TX agent with a waiting campaign, "
|
||||
"e.g. 'amber-peak-transmit'. Supplied by the campaign UI."
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -861,6 +961,8 @@ def main() -> None:
|
|||
device = args.device or cfg.get("device", "unknown")
|
||||
insecure = args.insecure if args.insecure is not None else cfg.get("insecure", False)
|
||||
log_level = args.log_level or cfg.get("log_level", "INFO")
|
||||
role = args.role or cfg.get("role", "general")
|
||||
session_code = args.session_code or cfg.get("session_code")
|
||||
|
||||
if not hub:
|
||||
parser.error("--hub is required (or set 'hub' in the config file)")
|
||||
|
|
@ -888,6 +990,8 @@ def main() -> None:
|
|||
name=name,
|
||||
sdr_device=device,
|
||||
insecure=insecure,
|
||||
role=role,
|
||||
session_code=session_code,
|
||||
)
|
||||
agent.run()
|
||||
|
||||
|
|
|
|||
|
|
@ -233,6 +233,9 @@ class TransmitterConfig:
|
|||
# For sdr_remote control — keys: host, ssh_user, ssh_key_path, device_type, device_id, zmq_port
|
||||
sdr_remote: Optional[dict] = None
|
||||
|
||||
# For sdr_agent control — keys: modulation, order, symbol_rate, center_frequency, filter, rolloff
|
||||
sdr_agent: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "TransmitterConfig":
|
||||
schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])]
|
||||
|
|
@ -244,6 +247,7 @@ class TransmitterConfig:
|
|||
script=d.get("script"),
|
||||
device=d.get("device"),
|
||||
sdr_remote=d.get("sdr_remote"),
|
||||
sdr_agent=d.get("sdr_agent"),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -272,6 +276,7 @@ class OutputConfig:
|
|||
path: str = "recordings"
|
||||
device_id: Optional[str] = None # for device-profile campaigns
|
||||
repo: Optional[str] = None
|
||||
folder: Optional[str] = None # repo subfolder: None = use campaign name, "" = no subfolder, str = custom
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "OutputConfig":
|
||||
|
|
@ -280,6 +285,7 @@ class OutputConfig:
|
|||
path=str(d.get("path", "recordings")),
|
||||
device_id=d.get("device_id"),
|
||||
repo=d.get("repo"),
|
||||
folder=d.get("folder"),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -293,6 +299,7 @@ class CampaignConfig:
|
|||
qa: QAConfig = field(default_factory=QAConfig)
|
||||
output: OutputConfig = field(default_factory=OutputConfig)
|
||||
mode: str = "controlled_testbed"
|
||||
loops: int = 1 # repeat full schedule this many times; labels get _run{N:02d} suffix
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Loaders
|
||||
|
|
@ -320,6 +327,7 @@ class CampaignConfig:
|
|||
return cls(
|
||||
name=safe_name,
|
||||
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
||||
loops=max(1, int(campaign_meta.get("loops", 1))),
|
||||
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
||||
transmitters=transmitters,
|
||||
qa=QAConfig.from_dict(raw.get("qa", {})),
|
||||
|
|
@ -384,6 +392,7 @@ class CampaignConfig:
|
|||
return cls(
|
||||
name=safe_name,
|
||||
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
||||
loops=max(1, int(campaign_meta.get("loops", 1))),
|
||||
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
||||
transmitters=transmitters,
|
||||
qa=QAConfig.from_dict(raw.get("qa", {})),
|
||||
|
|
@ -486,9 +495,9 @@ class CampaignConfig:
|
|||
)
|
||||
|
||||
def total_capture_time_s(self) -> float:
|
||||
"""Sum of all step durations across all transmitters."""
|
||||
return sum(step.duration for tx in self.transmitters for step in tx.schedule)
|
||||
"""Sum of all step durations across all transmitters and loops."""
|
||||
return sum(step.duration for tx in self.transmitters for step in tx.schedule) * self.loops
|
||||
|
||||
def total_steps(self) -> int:
|
||||
"""Total number of capture steps across all transmitters."""
|
||||
return sum(len(tx.schedule) for tx in self.transmitters)
|
||||
"""Total number of capture steps across all transmitters and loops."""
|
||||
return sum(len(tx.schedule) for tx in self.transmitters) * self.loops
|
||||
|
|
|
|||
|
|
@ -5,8 +5,9 @@ from __future__ import annotations
|
|||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, replace
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
|
@ -16,6 +17,7 @@ from ria_toolkit_oss.io.recording import to_sigmf
|
|||
from .campaign import CampaignConfig, CaptureStep, TransmitterConfig
|
||||
from .labeler import build_output_filename, label_recording
|
||||
from .qa import QAResult, check_recording
|
||||
from .tx_executor import TxExecutor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -169,6 +171,21 @@ def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _extract_tx_params(transmitter: TransmitterConfig) -> dict | None:
|
||||
"""Build a tx_params dict from a transmitter's signal config for SigMF labeling.
|
||||
|
||||
For sdr_agent transmitters, returns the synthetic generation parameters
|
||||
(modulation, order, symbol_rate, etc.) so recordings capture what was
|
||||
transmitted. Returns None for control methods without signal-level params.
|
||||
"""
|
||||
sdr_agent_cfg = getattr(transmitter, "sdr_agent", None)
|
||||
if not sdr_agent_cfg:
|
||||
return None
|
||||
# Extract known signal-level fields; ignore infra fields
|
||||
_INFRA_KEYS = {"node_id", "session_code"}
|
||||
return {k: v for k, v in sdr_agent_cfg.items() if k not in _INFRA_KEYS and v is not None}
|
||||
|
||||
|
||||
class CampaignExecutor:
|
||||
"""Executes a :class:`CampaignConfig` end-to-end.
|
||||
|
||||
|
|
@ -192,11 +209,14 @@ class CampaignExecutor:
|
|||
config: CampaignConfig,
|
||||
progress_cb: Optional[Callable[[int, int, StepResult], None]] = None,
|
||||
verbose: bool = False,
|
||||
skip_local_tx: bool = False,
|
||||
):
|
||||
self.config = config
|
||||
self.progress_cb = progress_cb
|
||||
self.skip_local_tx = skip_local_tx
|
||||
self._sdr = None
|
||||
self._remote_tx_controllers: dict = {}
|
||||
self._tx_executors: dict[str, tuple] = {} # tx_id → (TxExecutor, stop_event, thread)
|
||||
|
||||
if verbose:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
|
@ -216,10 +236,12 @@ class CampaignExecutor:
|
|||
"""
|
||||
result = CampaignResult(campaign_name=self.config.name)
|
||||
|
||||
loops = self.config.loops
|
||||
logger.info(
|
||||
f"Starting campaign '{self.config.name}': "
|
||||
f"{self.config.total_steps()} steps, "
|
||||
f"~{self.config.total_capture_time_s():.0f}s capture time"
|
||||
f"{self.config.total_steps()} steps"
|
||||
+ (f" ({self.config.total_steps() // loops} × {loops} loops)" if loops > 1 else "")
|
||||
+ f", ~{self.config.total_capture_time_s():.0f}s capture time"
|
||||
)
|
||||
|
||||
self._init_sdr()
|
||||
|
|
@ -228,10 +250,14 @@ class CampaignExecutor:
|
|||
total = self.config.total_steps()
|
||||
step_index = 0
|
||||
|
||||
for loop_idx in range(loops):
|
||||
if loops > 1:
|
||||
logger.info(f"Loop {loop_idx + 1}/{loops}")
|
||||
for transmitter in self.config.transmitters:
|
||||
logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)")
|
||||
for step in transmitter.schedule:
|
||||
step_result = self._execute_step(transmitter, step)
|
||||
looped_step = replace(step, label=f"{step.label}_run{loop_idx + 1:02d}") if loops > 1 else step
|
||||
step_result = self._execute_step(transmitter, looped_step)
|
||||
result.steps.append(step_result)
|
||||
step_index += 1
|
||||
|
||||
|
|
@ -239,18 +265,21 @@ class CampaignExecutor:
|
|||
self.progress_cb(step_index, total, step_result)
|
||||
|
||||
if step_result.error:
|
||||
logger.warning(f"Step '{step.label}' error: {step_result.error}")
|
||||
logger.warning(f"Step '{looped_step.label}' error: {step_result.error}")
|
||||
elif step_result.qa.flagged:
|
||||
logger.warning(f"Step '{step.label}' flagged for review: " + "; ".join(step_result.qa.issues))
|
||||
logger.warning(
|
||||
f"Step '{looped_step.label}' flagged for review: " + "; ".join(step_result.qa.issues)
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Step '{step.label}' OK "
|
||||
f"Step '{looped_step.label}' OK "
|
||||
f"(SNR {step_result.qa.snr_db:.1f} dB, "
|
||||
f"{step_result.qa.duration_s:.1f}s)"
|
||||
)
|
||||
finally:
|
||||
self._close_sdr()
|
||||
self._close_remote_tx_controllers()
|
||||
self._close_tx_executors()
|
||||
|
||||
result.end_time = time.time()
|
||||
logger.info(
|
||||
|
|
@ -325,6 +354,12 @@ class CampaignExecutor:
|
|||
logger.warning(f"Error closing remote Tx controller {tx_id}: {exc}")
|
||||
self._remote_tx_controllers.clear()
|
||||
|
||||
def _close_tx_executors(self) -> None:
|
||||
for tx_id, (_, stop_event, t) in list(self._tx_executors.items()):
|
||||
stop_event.set()
|
||||
t.join(timeout=5.0)
|
||||
self._tx_executors.clear()
|
||||
|
||||
def _record(self, duration_s: float) -> Recording:
|
||||
"""Capture ``duration_s`` seconds of IQ samples."""
|
||||
num_samples = int(duration_s * self.config.recorder.sample_rate)
|
||||
|
|
@ -369,6 +404,7 @@ class CampaignExecutor:
|
|||
step=step,
|
||||
capture_timestamp=capture_timestamp,
|
||||
campaign_name=self.config.name,
|
||||
tx_params=_extract_tx_params(transmitter),
|
||||
)
|
||||
|
||||
# QA
|
||||
|
|
@ -437,6 +473,30 @@ class CampaignExecutor:
|
|||
# Start transmission in background; _record() runs concurrently
|
||||
ctrl.transmit_async(step.duration + 1.0)
|
||||
|
||||
elif transmitter.control_method == "sdr_agent":
|
||||
if self.skip_local_tx:
|
||||
logger.debug(f"skip_local_tx — TX for '{transmitter.id}' delegated to TX agent node")
|
||||
return
|
||||
if not transmitter.sdr_agent:
|
||||
logger.warning(f"Transmitter '{transmitter.id}' has no sdr_agent config — skipping")
|
||||
return
|
||||
step_dict: dict = {"label": step.label, "duration": step.duration + 1.0}
|
||||
if step.power_dbm is not None:
|
||||
step_dict["power_dbm"] = step.power_dbm
|
||||
tx_config = {
|
||||
"id": transmitter.id,
|
||||
"sdr_agent": transmitter.sdr_agent,
|
||||
"schedule": [step_dict],
|
||||
}
|
||||
rec = self.config.recorder
|
||||
tx_device = transmitter.device or rec.device
|
||||
sdr_device = _DEVICE_ALIASES.get(tx_device.lower(), tx_device.lower())
|
||||
stop_event = threading.Event()
|
||||
executor = TxExecutor(tx_config, sdr_device=sdr_device, stop_event=stop_event)
|
||||
t = threading.Thread(target=executor.run, daemon=True, name=f"tx-{transmitter.id}")
|
||||
self._tx_executors[transmitter.id] = (executor, stop_event, t)
|
||||
t.start()
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
|
||||
|
||||
|
|
@ -459,6 +519,13 @@ class CampaignExecutor:
|
|||
if ctrl is not None:
|
||||
ctrl.wait_transmit(timeout=step.duration + 10.0)
|
||||
|
||||
elif transmitter.control_method == "sdr_agent":
|
||||
entry = self._tx_executors.pop(transmitter.id, None)
|
||||
if entry is not None:
|
||||
_, stop_event, t = entry
|
||||
stop_event.set()
|
||||
t.join(timeout=step.duration + 10.0)
|
||||
|
||||
@staticmethod
|
||||
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
|
||||
"""Serialise step parameters to a JSON string for the control script."""
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ def label_recording(
|
|||
step: CaptureStep,
|
||||
capture_timestamp: float,
|
||||
campaign_name: Optional[str] = None,
|
||||
tx_params: Optional[dict] = None,
|
||||
) -> Recording:
|
||||
"""Apply device identity and capture configuration labels to a recording's metadata.
|
||||
|
||||
|
|
@ -27,6 +28,9 @@ def label_recording(
|
|||
step: The capture step that was active during this recording.
|
||||
capture_timestamp: Unix timestamp (float) of when capture started.
|
||||
campaign_name: Optional campaign name for cross-recording reference.
|
||||
tx_params: Optional dict of transmitter signal parameters (e.g. modulation,
|
||||
order, symbol_rate) written as ``ria:tx_<key>`` fields so downstream
|
||||
training pipelines know what was transmitted into the recording.
|
||||
|
||||
Returns:
|
||||
The same recording with updated metadata.
|
||||
|
|
@ -57,6 +61,11 @@ def label_recording(
|
|||
if step.power_dbm is not None:
|
||||
recording.update_metadata("tx_power_dbm", step.power_dbm)
|
||||
|
||||
# Transmitter signal parameters (e.g. from sdr_agent synthetic generation)
|
||||
if tx_params:
|
||||
for key, value in tx_params.items():
|
||||
recording.update_metadata(f"tx_{key}", value)
|
||||
|
||||
return recording
|
||||
|
||||
|
||||
|
|
|
|||
299
src/ria_toolkit_oss/orchestration/tx_executor.py
Normal file
299
src/ria_toolkit_oss/orchestration/tx_executor.py
Normal file
|
|
@ -0,0 +1,299 @@
|
|||
"""TX campaign executor — synthesises and transmits signals via a local SDR.
|
||||
|
||||
The TxExecutor receives a transmitter config dict (matching the
|
||||
``sdr_agent`` control method's schema) and a step schedule, then for each
|
||||
step builds a signal chain with the block generator and transmits it via
|
||||
the local SDR device.
|
||||
|
||||
Supported modulations (``modulation`` field in config):
|
||||
BPSK, QPSK, 8PSK, 16QAM, 64QAM, 256QAM, FSK, OOK, GMSK, OQPSK
|
||||
|
||||
Example config dict (matches CampaignConfig transmitter with
|
||||
``control_method: sdr_agent``)::
|
||||
|
||||
{
|
||||
"id": "synthetic-tx",
|
||||
"type": "sdr",
|
||||
"control_method": "sdr_agent",
|
||||
"sdr_agent": {
|
||||
"modulation": "QPSK",
|
||||
"order": 4,
|
||||
"symbol_rate": 1000000,
|
||||
"center_frequency": 0.0,
|
||||
"filter": "rrc",
|
||||
"rolloff": 0.35
|
||||
},
|
||||
"schedule": [
|
||||
{"label": "step1", "duration": 10, "power_dbm": -10}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_hz(val: object) -> float:
|
||||
"""Parse a frequency value that may be a float (Hz) or a string like '2.45GHz'."""
|
||||
if isinstance(val, (int, float)):
|
||||
return float(val)
|
||||
s = str(val).strip()
|
||||
for suffix, mult in (("GHz", 1e9), ("MHz", 1e6), ("kHz", 1e3), ("Hz", 1.0)):
|
||||
if s.endswith(suffix):
|
||||
return float(s[: -len(suffix)]) * mult
|
||||
return float(s)
|
||||
|
||||
|
||||
def _parse_seconds(val: object) -> float:
|
||||
"""Parse a duration value that may be a float (seconds) or a string like '5s'."""
|
||||
if isinstance(val, (int, float)):
|
||||
return float(val)
|
||||
s = str(val).strip()
|
||||
return float(s[:-1]) if s.endswith("s") else float(s)
|
||||
|
||||
|
||||
# Mapping from modulation name → (PSK/QAM order, generator_type)
|
||||
# 'psk' uses PSKGenerator, 'qam' uses QAMGenerator
|
||||
_MOD_TABLE: dict[str, tuple[int, str]] = {
|
||||
"BPSK": (1, "psk"),
|
||||
"QPSK": (2, "psk"),
|
||||
"8PSK": (3, "psk"),
|
||||
"16QAM": (4, "qam"),
|
||||
"64QAM": (6, "qam"),
|
||||
"256QAM": (8, "qam"),
|
||||
}
|
||||
|
||||
_SPECIAL_MODS = {"FSK", "OOK", "GMSK", "OQPSK"}
|
||||
|
||||
# usrp-uhd-client's tx_recording() streams 2 000-sample chunks and loops the
|
||||
# source buffer for the full tx_time, so only this many samples ever need to
|
||||
# be in RAM regardless of step duration or sample rate.
|
||||
# 50 000 complex64 samples ≈ 400 kB — enough spectral diversity for looping.
|
||||
_SYNTH_BLOCK_SAMPLES = 50_000
|
||||
|
||||
|
||||
class TxExecutor:
|
||||
"""Synthesise and transmit a signal campaign via a local SDR.
|
||||
|
||||
Args:
|
||||
config: Transmitter config dict (must have ``sdr_agent`` sub-dict with
|
||||
modulation params, and ``schedule`` list of step dicts).
|
||||
sdr_device: SDR device name to open in TX mode (e.g. "pluto", "usrp").
|
||||
stop_event: External event that aborts the TX loop mid-step.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
sdr_device: str = "unknown",
|
||||
stop_event: threading.Event | None = None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.sdr_device = sdr_device
|
||||
self.stop_event = stop_event or threading.Event()
|
||||
self._sdr: Any = None
|
||||
|
||||
def run(self) -> None:
|
||||
"""Execute all steps in the schedule, transmitting for each step duration."""
|
||||
agent_cfg: dict = self.config.get("sdr_agent") or {}
|
||||
schedule: list[dict] = self.config.get("schedule") or []
|
||||
|
||||
if not schedule:
|
||||
logger.warning("TxExecutor: no schedule steps — nothing to transmit")
|
||||
return
|
||||
|
||||
modulation: str = agent_cfg.get("modulation", "QPSK").upper()
|
||||
symbol_rate: float = float(agent_cfg.get("symbol_rate", 1e6))
|
||||
center_freq: float = _parse_hz(agent_cfg.get("center_frequency", 0.0))
|
||||
filter_type: str = agent_cfg.get("filter", "rrc").lower()
|
||||
rolloff: float = float(agent_cfg.get("rolloff", 0.35))
|
||||
loops: int = max(1, int(self.config.get("loops", 1)))
|
||||
|
||||
# Upsampling factor: samples_per_symbol, fixed at 8 for SDR compatibility.
|
||||
sps = 8
|
||||
sample_rate = symbol_rate * sps
|
||||
|
||||
self._init_sdr(sample_rate, center_freq)
|
||||
try:
|
||||
for loop_idx in range(loops):
|
||||
if self.stop_event.is_set():
|
||||
break
|
||||
if loops > 1:
|
||||
logger.info("TX loop %d/%d", loop_idx + 1, loops)
|
||||
for step in schedule:
|
||||
if self.stop_event.is_set():
|
||||
break
|
||||
looped_step = (
|
||||
{**step, "label": f"{step.get('label', 'step')}_run{loop_idx + 1:02d}"} if loops > 1 else step
|
||||
)
|
||||
self._execute_step(looped_step, modulation, sps, symbol_rate, filter_type, rolloff)
|
||||
finally:
|
||||
self._close_sdr()
|
||||
|
||||
def _execute_step(
|
||||
self,
|
||||
step: dict,
|
||||
modulation: str,
|
||||
sps: int,
|
||||
symbol_rate: float,
|
||||
filter_type: str,
|
||||
rolloff: float,
|
||||
) -> None:
|
||||
duration: float = _parse_seconds(step.get("duration", 10.0))
|
||||
label: str = step.get("label", "step")
|
||||
gain: float = float(step.get("power_dbm") or 0.0)
|
||||
sample_rate = symbol_rate * sps
|
||||
|
||||
logger.info(
|
||||
"TX step '%s': %.0f s, %s @ %.3f MHz (sps=%d, filter=%s)",
|
||||
label,
|
||||
duration,
|
||||
modulation,
|
||||
symbol_rate / 1e6,
|
||||
sps,
|
||||
filter_type,
|
||||
)
|
||||
|
||||
num_samples = int(duration * sample_rate)
|
||||
|
||||
# Synthesise a short representative block. tx_recording() loops this
|
||||
# buffer for the full tx_time using a 2 000-sample streaming callback,
|
||||
# so peak memory is O(_SYNTH_BLOCK_SAMPLES) regardless of duration.
|
||||
block_size = min(num_samples, _SYNTH_BLOCK_SAMPLES)
|
||||
signal = self._synthesise(modulation, sps, block_size, filter_type, rolloff)
|
||||
|
||||
if self._sdr is not None:
|
||||
try:
|
||||
# Apply gain update if SDR supports it
|
||||
if hasattr(self._sdr, "set_tx_gain"):
|
||||
self._sdr.set_tx_gain(gain)
|
||||
self._sdr.tx_recording(signal, tx_time=duration)
|
||||
except Exception as exc:
|
||||
logger.error("TX step '%s' SDR error: %s", label, exc)
|
||||
else:
|
||||
# No SDR available — simulate by sleeping for the step duration.
|
||||
logger.warning("TX step '%s': no SDR — simulating %.0f s delay", label, duration)
|
||||
self.stop_event.wait(timeout=duration)
|
||||
|
||||
def _synthesise(
|
||||
self,
|
||||
modulation: str,
|
||||
sps: int,
|
||||
num_samples: int,
|
||||
filter_type: str,
|
||||
rolloff: float,
|
||||
):
|
||||
"""Build a block-generator chain and return IQ samples as a numpy array."""
|
||||
try:
|
||||
import numpy as np
|
||||
|
||||
from ria_toolkit_oss.signal.block_generator import (
|
||||
BinarySource,
|
||||
GMSKModulator,
|
||||
Mapper,
|
||||
OOKModulator,
|
||||
OQPSKModulator,
|
||||
RaisedCosineFilter,
|
||||
RootRaisedCosineFilter,
|
||||
Upsampling,
|
||||
)
|
||||
from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import (
|
||||
FSKModulator,
|
||||
)
|
||||
except ImportError as exc:
|
||||
raise RuntimeError(f"ria_toolkit_oss block generator not available: {exc}") from exc
|
||||
|
||||
# ── Special modulations with their own source-connected modulator ──
|
||||
if modulation in ("OOK", "GMSK", "OQPSK"):
|
||||
src = BinarySource()
|
||||
if modulation == "OOK":
|
||||
mod = OOKModulator(src, samples_per_symbol=sps)
|
||||
elif modulation == "GMSK":
|
||||
mod = GMSKModulator(src, samples_per_symbol=sps)
|
||||
else:
|
||||
mod = OQPSKModulator(src, samples_per_symbol=sps)
|
||||
recording = mod.record(num_samples)
|
||||
flat = np.asarray(recording.data).flatten().astype(np.complex64)
|
||||
if len(flat) < num_samples:
|
||||
flat = np.tile(flat, num_samples // len(flat) + 1)
|
||||
return flat[:num_samples]
|
||||
|
||||
if modulation == "FSK":
|
||||
symbol_rate = num_samples / sps
|
||||
bits_per_sym = 1 # 2-FSK
|
||||
num_bits = max(num_samples // sps, 128) * bits_per_sym
|
||||
bits = BinarySource()((1, num_bits))
|
||||
mod = FSKModulator(
|
||||
num_bits_per_symbol=bits_per_sym,
|
||||
frequency_spacing=symbol_rate * 0.5,
|
||||
symbol_duration=1.0 / max(symbol_rate, 1.0),
|
||||
sampling_frequency=symbol_rate * sps,
|
||||
)
|
||||
flat = np.asarray(mod(bits)).flatten().astype(np.complex64)
|
||||
if len(flat) < num_samples:
|
||||
flat = np.tile(flat, num_samples // len(flat) + 1)
|
||||
return flat[:num_samples]
|
||||
|
||||
# ── PSK / QAM via Mapper → Upsampling → pulse filter ──────────────
|
||||
if modulation not in _MOD_TABLE:
|
||||
logger.warning("Unknown modulation %r — defaulting to QPSK", modulation)
|
||||
modulation = "QPSK"
|
||||
|
||||
bits_per_sym, gen_type = _MOD_TABLE[modulation]
|
||||
mod_family = "QAM" if gen_type == "qam" else "PSK"
|
||||
|
||||
source = BinarySource()
|
||||
mapper = Mapper(constellation_type=mod_family, num_bits_per_symbol=bits_per_sym)
|
||||
upsampler = Upsampling(factor=sps)
|
||||
|
||||
mapper.connect_input([source])
|
||||
upsampler.connect_input([mapper])
|
||||
|
||||
if filter_type in ("rrc",):
|
||||
pulse_filter = RootRaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff)
|
||||
pulse_filter.connect_input([upsampler])
|
||||
recording = pulse_filter.record(num_samples)
|
||||
elif filter_type in ("rc",):
|
||||
pulse_filter = RaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff)
|
||||
pulse_filter.connect_input([upsampler])
|
||||
recording = pulse_filter.record(num_samples)
|
||||
else:
|
||||
# "none", "rect", "gaussian" — use upsampler output directly
|
||||
recording = upsampler.record(num_samples)
|
||||
|
||||
flat = np.asarray(recording.data).flatten().astype(np.complex64)
|
||||
if len(flat) < num_samples:
|
||||
flat = np.tile(flat, num_samples // len(flat) + 1)
|
||||
return flat[:num_samples]
|
||||
|
||||
def _init_sdr(self, sample_rate: float, center_freq: float) -> None:
|
||||
try:
|
||||
from ria_toolkit_oss.sdr import get_sdr_device
|
||||
|
||||
self._sdr = get_sdr_device(self.sdr_device)
|
||||
self._sdr.init_tx(
|
||||
sample_rate=sample_rate,
|
||||
center_frequency=center_freq,
|
||||
gain=0,
|
||||
channel=0,
|
||||
gain_mode="manual",
|
||||
)
|
||||
logger.info(
|
||||
"TX SDR initialised: %s @ %.3f MHz, %.1f Msps", self.sdr_device, center_freq / 1e6, sample_rate / 1e6
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("TX SDR init failed (%s) — will simulate: %s", self.sdr_device, exc)
|
||||
self._sdr = None
|
||||
|
||||
def _close_sdr(self) -> None:
|
||||
if self._sdr is not None:
|
||||
try:
|
||||
self._sdr.close()
|
||||
except Exception as exc:
|
||||
logger.debug("TX SDR close error: %s", exc)
|
||||
self._sdr = None
|
||||
|
|
@ -3,7 +3,7 @@
|
|||
from fastapi import Depends, FastAPI
|
||||
|
||||
from .auth import require_api_key
|
||||
from .routers import inference, orchestrator
|
||||
from .routers import conductor, inference
|
||||
|
||||
|
||||
def create_app(api_key: str = "") -> FastAPI:
|
||||
|
|
@ -28,9 +28,9 @@ def create_app(api_key: str = "") -> FastAPI:
|
|||
app.state.api_key = api_key
|
||||
|
||||
app.include_router(
|
||||
orchestrator.router,
|
||||
prefix="/orchestrator",
|
||||
tags=["Orchestrator"],
|
||||
conductor.router,
|
||||
prefix="/conductor",
|
||||
tags=["Conductor"],
|
||||
dependencies=[Depends(require_api_key)],
|
||||
)
|
||||
app.include_router(
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from pathlib import Path
|
|||
from pydantic import BaseModel, field_validator
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Orchestrator
|
||||
# Conductor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Orchestrator routes: campaign deployment, status, and cancellation."""
|
||||
"""Conductor routes: campaign deployment, status, and cancellation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -23,9 +23,9 @@ def serve(host: str, port: int, api_key: str, log_level: str):
|
|||
|
||||
\b
|
||||
Endpoints:
|
||||
POST /orchestrator/deploy
|
||||
GET /orchestrator/status/{campaign_id}
|
||||
POST /orchestrator/cancel/{campaign_id}
|
||||
POST /conductor/deploy
|
||||
GET /conductor/status/{campaign_id}
|
||||
POST /conductor/cancel/{campaign_id}
|
||||
POST /inference/load
|
||||
POST /inference/start
|
||||
POST /inference/stop
|
||||
|
|
|
|||
314
tests/orchestration/test_executor.py
Normal file
314
tests/orchestration/test_executor.py
Normal file
|
|
@ -0,0 +1,314 @@
|
|||
"""Tests for orchestration executor — StepResult, CampaignResult, _run_script, _extract_tx_params."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import stat
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from ria_toolkit_oss.orchestration.executor import (
|
||||
CampaignResult,
|
||||
StepResult,
|
||||
_extract_tx_params,
|
||||
_run_script,
|
||||
)
|
||||
from ria_toolkit_oss.orchestration.qa import QAResult
|
||||
|
||||
|
||||
def _ok_qa() -> QAResult:
|
||||
return QAResult(passed=True, flagged=False, snr_db=20.0, duration_s=1.0)
|
||||
|
||||
|
||||
def _flagged_qa() -> QAResult:
|
||||
return QAResult(passed=True, flagged=True, snr_db=5.0, duration_s=1.0, issues=["low SNR"])
|
||||
|
||||
|
||||
def _failed_qa() -> QAResult:
|
||||
return QAResult(passed=False, flagged=True, snr_db=0.0, duration_s=0.0, issues=["no signal"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StepResult
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStepResult:
|
||||
def test_ok_true_when_no_error_and_qa_passed(self):
|
||||
r = StepResult(
|
||||
transmitter_id="tx1",
|
||||
step_label="step1",
|
||||
output_path="/out/rec.sigmf-data",
|
||||
qa=_ok_qa(),
|
||||
capture_timestamp=0.0,
|
||||
)
|
||||
assert r.ok is True
|
||||
|
||||
def test_ok_false_when_error_set(self):
|
||||
r = StepResult(
|
||||
transmitter_id="tx1",
|
||||
step_label="step1",
|
||||
output_path=None,
|
||||
qa=_ok_qa(),
|
||||
capture_timestamp=0.0,
|
||||
error="SDR failed",
|
||||
)
|
||||
assert r.ok is False
|
||||
|
||||
def test_ok_false_when_qa_not_passed(self):
|
||||
r = StepResult(
|
||||
transmitter_id="tx1",
|
||||
step_label="step1",
|
||||
output_path="/out",
|
||||
qa=_failed_qa(),
|
||||
capture_timestamp=0.0,
|
||||
)
|
||||
assert r.ok is False
|
||||
|
||||
def test_to_dict_contains_required_keys(self):
|
||||
r = StepResult(
|
||||
transmitter_id="tx1",
|
||||
step_label="step1",
|
||||
output_path="/out/rec.sigmf-data",
|
||||
qa=_ok_qa(),
|
||||
capture_timestamp=1234.5,
|
||||
)
|
||||
d = r.to_dict()
|
||||
assert d["transmitter_id"] == "tx1"
|
||||
assert d["step_label"] == "step1"
|
||||
assert d["output_path"] == "/out/rec.sigmf-data"
|
||||
assert d["capture_timestamp"] == pytest.approx(1234.5)
|
||||
assert d["error"] is None
|
||||
assert d["qa"]["passed"] is True
|
||||
|
||||
def test_to_dict_includes_error_when_set(self):
|
||||
r = StepResult(
|
||||
transmitter_id="tx1",
|
||||
step_label="step1",
|
||||
output_path=None,
|
||||
qa=_failed_qa(),
|
||||
capture_timestamp=0.0,
|
||||
error="disk full",
|
||||
)
|
||||
assert r.to_dict()["error"] == "disk full"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CampaignResult
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCampaignResult:
|
||||
def _make(self, steps: list) -> CampaignResult:
|
||||
r = CampaignResult(campaign_name="test_campaign")
|
||||
r.steps = steps
|
||||
r.end_time = r.start_time + 5.0
|
||||
return r
|
||||
|
||||
def test_total_steps(self):
|
||||
r = self._make(
|
||||
[
|
||||
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
|
||||
StepResult("tx1", "s2", "/out", _ok_qa(), 0.0),
|
||||
]
|
||||
)
|
||||
assert r.total_steps == 2
|
||||
|
||||
def test_passed_count(self):
|
||||
r = self._make(
|
||||
[
|
||||
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
|
||||
StepResult("tx1", "s2", "/out", _failed_qa(), 0.0),
|
||||
]
|
||||
)
|
||||
assert r.passed == 1
|
||||
|
||||
def test_failed_count(self):
|
||||
r = self._make(
|
||||
[
|
||||
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
|
||||
StepResult("tx1", "s2", "/out", _failed_qa(), 0.0),
|
||||
]
|
||||
)
|
||||
assert r.failed == 1
|
||||
|
||||
def test_flagged_count(self):
|
||||
r = self._make(
|
||||
[
|
||||
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
|
||||
StepResult("tx1", "s2", "/out", _flagged_qa(), 0.0),
|
||||
]
|
||||
)
|
||||
assert r.flagged == 1
|
||||
|
||||
def test_error_step_counts_as_failed_not_passed(self):
|
||||
r = self._make(
|
||||
[
|
||||
StepResult("tx1", "s1", None, _ok_qa(), 0.0, error="disk full"),
|
||||
]
|
||||
)
|
||||
assert r.failed == 1
|
||||
assert r.passed == 0
|
||||
|
||||
def test_duration_s_from_end_time(self):
|
||||
r = CampaignResult(campaign_name="c")
|
||||
r.start_time = 100.0
|
||||
r.end_time = 115.0
|
||||
assert r.duration_s == pytest.approx(15.0)
|
||||
|
||||
def test_to_dict_structure(self):
|
||||
r = self._make([StepResult("tx1", "s1", "/out", _ok_qa(), 0.0)])
|
||||
d = r.to_dict()
|
||||
assert d["campaign_name"] == "test_campaign"
|
||||
assert d["total_steps"] == 1
|
||||
assert d["passed"] == 1
|
||||
assert len(d["steps"]) == 1
|
||||
|
||||
def test_write_report(self, tmp_path):
|
||||
r = self._make([StepResult("tx1", "s1", "/out", _ok_qa(), 0.0)])
|
||||
out = tmp_path / "report.json"
|
||||
r.write_report(str(out))
|
||||
assert out.exists()
|
||||
data = json.loads(out.read_text())
|
||||
assert data["campaign_name"] == "test_campaign"
|
||||
|
||||
def test_write_report_creates_nested_dirs(self, tmp_path):
|
||||
r = self._make([])
|
||||
out = tmp_path / "nested" / "deep" / "report.json"
|
||||
r.write_report(str(out))
|
||||
assert out.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_script
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunScript:
|
||||
def _script(self, tmp_path, body: str) -> str:
|
||||
s = tmp_path / "script.sh"
|
||||
s.write_text("#!/bin/sh\n" + body)
|
||||
s.chmod(s.stat().st_mode | stat.S_IEXEC)
|
||||
return str(s)
|
||||
|
||||
def test_returns_stdout(self, tmp_path):
|
||||
out = _run_script(self._script(tmp_path, 'echo "hello world"'))
|
||||
assert out == "hello world"
|
||||
|
||||
def test_passes_args_to_script(self, tmp_path):
|
||||
out = _run_script(self._script(tmp_path, 'echo "$1 $2"'), "configure", "arg2")
|
||||
assert "configure" in out
|
||||
|
||||
def test_raises_on_nonzero_exit(self, tmp_path):
|
||||
with pytest.raises(RuntimeError, match="exited 1"):
|
||||
_run_script(self._script(tmp_path, "exit 1"))
|
||||
|
||||
def test_raises_on_relative_path(self):
|
||||
with pytest.raises(RuntimeError, match="absolute"):
|
||||
_run_script("relative/script.sh")
|
||||
|
||||
def test_raises_on_missing_file(self, tmp_path):
|
||||
with pytest.raises(RuntimeError):
|
||||
_run_script(str(tmp_path / "nonexistent.sh"))
|
||||
|
||||
def test_raises_on_timeout(self, tmp_path):
|
||||
with pytest.raises(RuntimeError, match="timed out"):
|
||||
_run_script(self._script(tmp_path, "sleep 60"), timeout=0.1)
|
||||
|
||||
def test_stderr_included_in_error_message(self, tmp_path):
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
_run_script(self._script(tmp_path, "echo 'bad thing' >&2; exit 1"))
|
||||
assert "bad thing" in str(exc_info.value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_tx_params
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractTxParams:
|
||||
def test_returns_none_when_no_sdr_agent_attribute(self):
|
||||
tx = SimpleNamespace()
|
||||
assert _extract_tx_params(tx) is None
|
||||
|
||||
def test_returns_none_when_sdr_agent_is_none(self):
|
||||
tx = SimpleNamespace(sdr_agent=None)
|
||||
assert _extract_tx_params(tx) is None
|
||||
|
||||
def test_returns_none_when_sdr_agent_is_empty_dict(self):
|
||||
tx = SimpleNamespace(sdr_agent={})
|
||||
assert _extract_tx_params(tx) is None
|
||||
|
||||
def test_returns_signal_params(self):
|
||||
tx = SimpleNamespace(
|
||||
sdr_agent={
|
||||
"modulation": "QPSK",
|
||||
"symbol_rate": 1e6,
|
||||
"center_frequency": 2.4e9,
|
||||
}
|
||||
)
|
||||
result = _extract_tx_params(tx)
|
||||
assert result == {"modulation": "QPSK", "symbol_rate": 1e6, "center_frequency": 2.4e9}
|
||||
|
||||
def test_strips_infra_key_node_id(self):
|
||||
tx = SimpleNamespace(
|
||||
sdr_agent={
|
||||
"modulation": "BPSK",
|
||||
"node_id": "node_abc123",
|
||||
}
|
||||
)
|
||||
result = _extract_tx_params(tx)
|
||||
assert "node_id" not in result
|
||||
assert result == {"modulation": "BPSK"}
|
||||
|
||||
def test_strips_infra_key_session_code(self):
|
||||
tx = SimpleNamespace(
|
||||
sdr_agent={
|
||||
"modulation": "FSK",
|
||||
"session_code": "amber-peak-transmit",
|
||||
}
|
||||
)
|
||||
result = _extract_tx_params(tx)
|
||||
assert "session_code" not in result
|
||||
|
||||
def test_strips_none_values(self):
|
||||
tx = SimpleNamespace(
|
||||
sdr_agent={
|
||||
"modulation": "QPSK",
|
||||
"order": None,
|
||||
"rolloff": 0.35,
|
||||
}
|
||||
)
|
||||
result = _extract_tx_params(tx)
|
||||
assert "order" not in result
|
||||
assert result == {"modulation": "QPSK", "rolloff": 0.35}
|
||||
|
||||
def test_does_not_mutate_source_dict(self):
|
||||
cfg = {"modulation": "QPSK", "node_id": "nid", "session_code": "code"}
|
||||
tx = SimpleNamespace(sdr_agent=cfg)
|
||||
_extract_tx_params(tx)
|
||||
assert "node_id" in cfg
|
||||
|
||||
def test_full_sdr_agent_config(self):
|
||||
tx = SimpleNamespace(
|
||||
sdr_agent={
|
||||
"modulation": "16QAM",
|
||||
"order": 4,
|
||||
"symbol_rate": 5e6,
|
||||
"center_frequency": 915e6,
|
||||
"filter": "rrc",
|
||||
"rolloff": 0.35,
|
||||
"node_id": "node_xyz",
|
||||
"session_code": "some-code",
|
||||
}
|
||||
)
|
||||
result = _extract_tx_params(tx)
|
||||
assert result == {
|
||||
"modulation": "16QAM",
|
||||
"order": 4,
|
||||
"symbol_rate": 5e6,
|
||||
"center_frequency": 915e6,
|
||||
"filter": "rrc",
|
||||
"rolloff": 0.35,
|
||||
}
|
||||
|
|
@ -109,6 +109,38 @@ class TestLabelRecording:
|
|||
result = label_recording(rec, "iphone13_001", _wifi_step(), time.time())
|
||||
assert result is rec
|
||||
|
||||
def test_tx_params_none_by_default(self):
|
||||
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||
tx_keys = [k for k in rec.metadata if k.startswith("tx_")]
|
||||
assert tx_keys == []
|
||||
|
||||
def test_tx_params_written_as_tx_prefix_keys(self):
|
||||
params = {"modulation": "QPSK", "symbol_rate": 1e6}
|
||||
rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params)
|
||||
assert rec.metadata["tx_modulation"] == "QPSK"
|
||||
assert rec.metadata["tx_symbol_rate"] == pytest.approx(1e6)
|
||||
|
||||
def test_tx_params_multiple_fields(self):
|
||||
params = {
|
||||
"modulation": "16QAM",
|
||||
"order": 4,
|
||||
"symbol_rate": 5e6,
|
||||
"center_frequency": 915e6,
|
||||
"filter": "rrc",
|
||||
"rolloff": 0.35,
|
||||
}
|
||||
rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params)
|
||||
for k, v in params.items():
|
||||
assert f"tx_{k}" in rec.metadata
|
||||
assert (
|
||||
rec.metadata[f"tx_{k}"] == pytest.approx(v) if isinstance(v, float) else rec.metadata[f"tx_{k}"] == v
|
||||
)
|
||||
|
||||
def test_tx_params_empty_dict_writes_nothing(self):
|
||||
rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params={})
|
||||
tx_keys = [k for k in rec.metadata if k.startswith("tx_") and k != "tx_power_dbm"]
|
||||
assert tx_keys == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_output_filename
|
||||
|
|
|
|||
153
tests/orchestration/test_tx_executor.py
Normal file
153
tests/orchestration/test_tx_executor.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
"""Tests for TxExecutor — signal synthesis and step execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from ria_toolkit_oss.orchestration.tx_executor import TxExecutor
|
||||
|
||||
|
||||
def _cfg(modulation="QPSK", symbol_rate=100_000, steps=None):
|
||||
return {
|
||||
"id": "test-tx",
|
||||
"type": "sdr",
|
||||
"control_method": "sdr_agent",
|
||||
"sdr_agent": {
|
||||
"modulation": modulation,
|
||||
"symbol_rate": symbol_rate,
|
||||
"center_frequency": 0.0,
|
||||
"filter": "rrc",
|
||||
"rolloff": 0.35,
|
||||
},
|
||||
"schedule": steps or [{"label": "step1", "duration": 0.001, "power_dbm": -10}],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Initialisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTxExecutorInit:
|
||||
def test_stores_sdr_device(self):
|
||||
ex = TxExecutor(_cfg(), sdr_device="pluto")
|
||||
assert ex.sdr_device == "pluto"
|
||||
|
||||
def test_stop_event_created_when_not_supplied(self):
|
||||
ex = TxExecutor(_cfg())
|
||||
assert isinstance(ex.stop_event, threading.Event)
|
||||
assert not ex.stop_event.is_set()
|
||||
|
||||
def test_accepts_external_stop_event(self):
|
||||
ev = threading.Event()
|
||||
ex = TxExecutor(_cfg(), stop_event=ev)
|
||||
assert ex.stop_event is ev
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run() — schedule iteration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTxExecutorRun:
|
||||
def test_empty_schedule_returns_immediately(self):
|
||||
cfg = _cfg(steps=[])
|
||||
ex = TxExecutor(cfg)
|
||||
ex.run() # must not raise or block
|
||||
|
||||
def test_pre_set_stop_event_skips_all_steps(self):
|
||||
ev = threading.Event()
|
||||
ev.set()
|
||||
ex = TxExecutor(_cfg(), stop_event=ev)
|
||||
# If stop was set, _execute_step should never be called.
|
||||
# run() should return cleanly without attempting synthesis.
|
||||
ex.run()
|
||||
|
||||
def test_no_sdr_falls_back_to_simulation(self, monkeypatch):
|
||||
"""Without SDR hardware TxExecutor simulates by calling stop_event.wait."""
|
||||
cfg = _cfg(steps=[{"label": "s", "duration": 0.001, "power_dbm": 0}])
|
||||
waited = []
|
||||
real_ev = threading.Event()
|
||||
|
||||
def _fake_wait(timeout=None):
|
||||
waited.append(timeout)
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(real_ev, "wait", _fake_wait)
|
||||
|
||||
# Patch SDR init to always fail (forces simulation path)
|
||||
with patch.object(TxExecutor, "_init_sdr", lambda self, *a, **kw: setattr(self, "_sdr", None)):
|
||||
ex = TxExecutor(cfg, sdr_device="nonexistent_xyz", stop_event=real_ev)
|
||||
ex.run()
|
||||
|
||||
assert len(waited) >= 1, "expected stop_event.wait to be called for simulation"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _synthesise — all modulation types and filter types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSynthesise:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _ex(self):
|
||||
self.ex = TxExecutor(_cfg())
|
||||
|
||||
def _synth(self, mod, num_samples=256):
|
||||
return self.ex._synthesise(mod, sps=4, num_samples=num_samples, filter_type="rrc", rolloff=0.35)
|
||||
|
||||
@pytest.mark.parametrize("mod", ["BPSK", "QPSK", "8PSK", "16QAM", "64QAM", "256QAM"])
|
||||
def test_psk_qam_returns_complex64_array(self, mod):
|
||||
sig = self._synth(mod)
|
||||
assert sig.dtype == np.complex64
|
||||
assert len(sig) == 256
|
||||
|
||||
def test_fsk_returns_correct_length(self):
|
||||
sig = self._synth("FSK")
|
||||
assert len(sig) == 256
|
||||
|
||||
def test_ook_returns_correct_length(self):
|
||||
sig = self._synth("OOK")
|
||||
assert len(sig) == 256
|
||||
|
||||
def test_gmsk_returns_correct_length(self):
|
||||
sig = self._synth("GMSK")
|
||||
assert len(sig) == 256
|
||||
|
||||
def test_oqpsk_returns_correct_length(self):
|
||||
sig = self._synth("OQPSK")
|
||||
assert len(sig) == 256
|
||||
|
||||
@pytest.mark.parametrize("mod", ["BPSK", "QPSK", "16QAM", "FSK", "OOK", "GMSK"])
|
||||
def test_samples_are_finite(self, mod):
|
||||
sig = self._synth(mod)
|
||||
assert np.all(np.isfinite(sig.real)), f"{mod}: non-finite real samples"
|
||||
assert np.all(np.isfinite(sig.imag)), f"{mod}: non-finite imag samples"
|
||||
|
||||
def test_unknown_modulation_defaults_to_qpsk(self):
|
||||
sig = self._synth("UNKNOWN_MOD_XYZ")
|
||||
assert len(sig) == 256
|
||||
assert sig.dtype == np.complex64
|
||||
|
||||
@pytest.mark.parametrize("filter_type", ["rrc", "rc", "gaussian", "rect", "none"])
|
||||
def test_all_filter_types(self, filter_type):
|
||||
sig = self.ex._synthesise("QPSK", sps=4, num_samples=128, filter_type=filter_type, rolloff=0.35)
|
||||
assert len(sig) == 128
|
||||
|
||||
@pytest.mark.parametrize("n", [64, 128, 512, 1024])
|
||||
def test_output_length_matches_requested_samples(self, n):
|
||||
sig = self._synth("QPSK", num_samples=n)
|
||||
assert len(sig) == n
|
||||
|
||||
def test_bpsk_output_is_complex_not_real(self):
|
||||
sig = self._synth("BPSK")
|
||||
# complex64 always has imag part; just check dtype
|
||||
assert sig.dtype == np.complex64
|
||||
|
||||
def test_256qam_correct_length(self):
|
||||
sig = self._synth("256QAM")
|
||||
assert len(sig) == 256
|
||||
|
|
@ -189,6 +189,8 @@ class TestNoiseCommand:
|
|||
"10000",
|
||||
"--noise-type",
|
||||
"gaussian",
|
||||
"--power",
|
||||
"0.01",
|
||||
"--output",
|
||||
output,
|
||||
"-q",
|
||||
|
|
@ -234,7 +236,7 @@ class TestNoiseCommand:
|
|||
"--num-samples",
|
||||
"10000",
|
||||
"--power",
|
||||
"0.5",
|
||||
"0.01",
|
||||
"--output",
|
||||
output,
|
||||
"-q",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Tests for the RT-OSS HTTP server.
|
||||
|
||||
Covers: auth, inference lifecycle (without SDR/ONNX hardware), orchestrator
|
||||
Covers: auth, inference lifecycle (without SDR/ONNX hardware), conductor
|
||||
lifecycle (with mocked executor), and state helpers.
|
||||
|
||||
``start_inference`` and ``_inference_loop`` require real SDR hardware and an
|
||||
|
|
@ -286,17 +286,17 @@ class TestInferenceStop:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /orchestrator/deploy
|
||||
# POST /conductor/deploy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOrchestratorDeploy:
|
||||
class TestConductorDeploy:
|
||||
def test_deploy_422_on_invalid_config(self, client):
|
||||
with patch(
|
||||
"ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict",
|
||||
"ria_toolkit_oss.server.routers.conductor.CampaignConfig.from_dict",
|
||||
side_effect=ValueError("missing required field 'name'"),
|
||||
):
|
||||
resp = client.post("/orchestrator/deploy", json={"config": {}})
|
||||
resp = client.post("/conductor/deploy", json={"config": {}})
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_deploy_returns_campaign_id(self, client):
|
||||
|
|
@ -307,10 +307,10 @@ class TestOrchestratorDeploy:
|
|||
mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {})
|
||||
|
||||
with (
|
||||
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict", return_value=mock_cfg),
|
||||
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignExecutor", mock_executor),
|
||||
patch("ria_toolkit_oss.server.routers.conductor.CampaignConfig.from_dict", return_value=mock_cfg),
|
||||
patch("ria_toolkit_oss.server.routers.conductor.CampaignExecutor", mock_executor),
|
||||
):
|
||||
resp = client.post("/orchestrator/deploy", json={"config": {"name": "test_campaign"}})
|
||||
resp = client.post("/conductor/deploy", json={"config": {"name": "test_campaign"}})
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
|
|
@ -325,23 +325,23 @@ class TestOrchestratorDeploy:
|
|||
mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {})
|
||||
|
||||
with (
|
||||
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict", return_value=mock_cfg),
|
||||
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignExecutor", mock_executor),
|
||||
patch("ria_toolkit_oss.server.routers.conductor.CampaignConfig.from_dict", return_value=mock_cfg),
|
||||
patch("ria_toolkit_oss.server.routers.conductor.CampaignExecutor", mock_executor),
|
||||
):
|
||||
resp = client.post("/orchestrator/deploy", json={"config": {}})
|
||||
resp = client.post("/conductor/deploy", json={"config": {}})
|
||||
|
||||
campaign_id = resp.json()["campaign_id"]
|
||||
assert state_module._campaigns.get(campaign_id) is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /orchestrator/status/{campaign_id}
|
||||
# GET /conductor/status/{campaign_id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOrchestratorStatus:
|
||||
class TestConductorStatus:
|
||||
def test_status_404_for_unknown_id(self, client):
|
||||
resp = client.get("/orchestrator/status/nonexistent-id")
|
||||
resp = client.get("/conductor/status/nonexistent-id")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_status_returns_campaign_state(self, client):
|
||||
|
|
@ -357,7 +357,7 @@ class TestOrchestratorStatus:
|
|||
)
|
||||
state_module._campaigns["abc-123"] = state
|
||||
|
||||
resp = client.get("/orchestrator/status/abc-123")
|
||||
resp = client.get("/conductor/status/abc-123")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["campaign_id"] == "abc-123"
|
||||
|
|
@ -367,13 +367,13 @@ class TestOrchestratorStatus:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /orchestrator/cancel/{campaign_id}
|
||||
# POST /conductor/cancel/{campaign_id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOrchestratorCancel:
|
||||
class TestConductorCancel:
|
||||
def test_cancel_404_for_unknown_id(self, client):
|
||||
resp = client.post("/orchestrator/cancel/no-such-id")
|
||||
resp = client.post("/conductor/cancel/no-such-id")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_cancel_sets_cancel_event(self, client):
|
||||
|
|
@ -387,7 +387,7 @@ class TestOrchestratorCancel:
|
|||
)
|
||||
state_module._campaigns["camp-to-cancel"] = state
|
||||
|
||||
resp = client.post("/orchestrator/cancel/camp-to-cancel")
|
||||
resp = client.post("/conductor/cancel/camp-to-cancel")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["cancelled"] is True
|
||||
assert cancel_event.is_set()
|
||||
|
|
@ -403,7 +403,7 @@ class TestOrchestratorCancel:
|
|||
)
|
||||
state_module._campaigns["done"] = state
|
||||
|
||||
resp = client.post("/orchestrator/cancel/done")
|
||||
resp = client.post("/conductor/cancel/done")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["cancelled"] is False
|
||||
assert not cancel_event.is_set()
|
||||
|
|
|
|||
247
tests/test_agent.py
Normal file
247
tests/test_agent.py
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
"""Tests for NodeAgent — TX role, session code, and TX command dispatch."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from ria_toolkit_oss.agent import NodeAgent
|
||||
|
||||
|
||||
def _agent(role="general", session_code=None, **kwargs):
|
||||
return NodeAgent(
|
||||
hub_url="http://hub.test",
|
||||
api_key="test-key",
|
||||
name="test-node",
|
||||
sdr_device="mock",
|
||||
role=role,
|
||||
session_code=session_code,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _mock_register(agent, node_id="node_abc123"):
|
||||
"""Patch _post so _register() returns a fake node_id response."""
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"node_id": node_id}
|
||||
resp.raise_for_status.return_value = None
|
||||
agent._post = MagicMock(return_value=resp)
|
||||
return agent._post
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Initialisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNodeAgentInit:
|
||||
def test_stores_role_general(self):
|
||||
assert _agent(role="general").role == "general"
|
||||
|
||||
def test_stores_role_tx(self):
|
||||
assert _agent(role="tx").role == "tx"
|
||||
|
||||
def test_stores_role_rx(self):
|
||||
assert _agent(role="rx").role == "rx"
|
||||
|
||||
def test_session_code_stored(self):
|
||||
assert _agent(session_code="amber-peak-transmit").session_code == "amber-peak-transmit"
|
||||
|
||||
def test_session_code_none_by_default(self):
|
||||
assert _agent().session_code is None
|
||||
|
||||
def test_tx_stop_event_created(self):
|
||||
a = _agent()
|
||||
assert isinstance(a._tx_stop, threading.Event)
|
||||
|
||||
def test_tx_thread_none_initially(self):
|
||||
assert _agent()._tx_thread is None
|
||||
|
||||
def test_hub_url_trailing_slash_stripped(self):
|
||||
a = NodeAgent(hub_url="http://hub.test/", api_key="k", name="n")
|
||||
assert a.hub_url == "http://hub.test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _register payload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNodeAgentRegisterPayload:
|
||||
def _payload(self, agent):
|
||||
post = _mock_register(agent)
|
||||
agent._register()
|
||||
_, kwargs = post.call_args
|
||||
return kwargs["json"]
|
||||
|
||||
def test_general_role_in_payload(self):
|
||||
payload = self._payload(_agent(role="general"))
|
||||
assert payload["role"] == "general"
|
||||
|
||||
def test_tx_role_in_payload(self):
|
||||
payload = self._payload(_agent(role="tx"))
|
||||
assert payload["role"] == "tx"
|
||||
|
||||
def test_tx_role_adds_transmit_capability(self):
|
||||
payload = self._payload(_agent(role="tx"))
|
||||
assert "transmit" in payload["capabilities"]
|
||||
|
||||
def test_general_role_omits_transmit_capability(self):
|
||||
payload = self._payload(_agent(role="general"))
|
||||
assert "transmit" not in payload.get("capabilities", [])
|
||||
|
||||
def test_session_code_included_when_set(self):
|
||||
payload = self._payload(_agent(role="tx", session_code="amber-peak-transmit"))
|
||||
assert payload["session_code"] == "amber-peak-transmit"
|
||||
|
||||
def test_session_code_omitted_when_none(self):
|
||||
payload = self._payload(_agent())
|
||||
assert "session_code" not in payload
|
||||
|
||||
def test_register_stores_returned_node_id(self):
|
||||
a = _agent()
|
||||
_mock_register(a, node_id="node_xyz999")
|
||||
a._register()
|
||||
assert a.node_id == "node_xyz999"
|
||||
|
||||
def test_name_in_payload(self):
|
||||
a = NodeAgent(hub_url="http://h", api_key="k", name="my-bench")
|
||||
_mock_register(a)
|
||||
a._register()
|
||||
_, kwargs = a._post.call_args
|
||||
assert kwargs["json"]["name"] == "my-bench"
|
||||
|
||||
def test_sdr_device_in_payload(self):
|
||||
a = _agent()
|
||||
post = _mock_register(a)
|
||||
a._register()
|
||||
_, kwargs = post.call_args
|
||||
assert kwargs["json"]["sdr_device"] == "mock"
|
||||
|
||||
def test_campaign_capability_always_present(self):
|
||||
for role in ("general", "rx", "tx"):
|
||||
a = _agent(role=role)
|
||||
payload = self._payload(a)
|
||||
assert "campaign" in payload["capabilities"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _dispatch — TX commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNodeAgentDispatch:
|
||||
def _make_agent(self):
|
||||
a = _agent(role="tx")
|
||||
a.node_id = "node_abc"
|
||||
a._report_campaign_status = MagicMock()
|
||||
return a
|
||||
|
||||
def test_start_transmit_spawns_thread(self):
|
||||
a = self._make_agent()
|
||||
done = threading.Event()
|
||||
|
||||
class _FakeExecutor:
|
||||
def run(self_):
|
||||
done.wait(timeout=2)
|
||||
|
||||
with patch("ria_toolkit_oss.orchestration.tx_executor.TxExecutor", return_value=_FakeExecutor()):
|
||||
a._dispatch({"command": "start_transmit", "sdr_agent": {}, "schedule": []})
|
||||
time.sleep(0.05)
|
||||
assert a._tx_thread is not None
|
||||
done.set()
|
||||
|
||||
def test_start_transmit_clears_stop_event(self):
|
||||
a = self._make_agent()
|
||||
a._tx_stop.set() # pre-set
|
||||
|
||||
done = threading.Event()
|
||||
|
||||
class _FakeExecutor:
|
||||
def run(self_):
|
||||
done.wait(timeout=2)
|
||||
|
||||
with patch("ria_toolkit_oss.orchestration.tx_executor.TxExecutor", return_value=_FakeExecutor()):
|
||||
a._dispatch({"command": "start_transmit", "sdr_agent": {}, "schedule": []})
|
||||
time.sleep(0.05)
|
||||
assert not a._tx_stop.is_set()
|
||||
done.set()
|
||||
|
||||
def test_stop_transmit_sets_stop_event(self):
|
||||
a = self._make_agent()
|
||||
a._dispatch({"command": "stop_transmit"})
|
||||
assert a._tx_stop.is_set()
|
||||
|
||||
def test_configure_transmit_does_not_raise(self):
|
||||
a = self._make_agent()
|
||||
a._dispatch({"command": "configure_transmit", "modulation": "BPSK"})
|
||||
|
||||
def test_unknown_command_is_ignored(self):
|
||||
a = self._make_agent()
|
||||
a._dispatch({"command": "frobnicate_xyz"})
|
||||
|
||||
def test_duplicate_start_transmit_ignored_while_running(self):
|
||||
a = self._make_agent()
|
||||
done = threading.Event()
|
||||
run_calls = []
|
||||
|
||||
class _FakeExecutor:
|
||||
def run(self_):
|
||||
run_calls.append(1)
|
||||
done.wait(timeout=2)
|
||||
|
||||
with patch("ria_toolkit_oss.orchestration.tx_executor.TxExecutor", return_value=_FakeExecutor()):
|
||||
a._dispatch({"command": "start_transmit"})
|
||||
time.sleep(0.05)
|
||||
a._dispatch({"command": "start_transmit"}) # second while first alive
|
||||
done.set()
|
||||
time.sleep(0.05)
|
||||
|
||||
assert len(run_calls) == 1
|
||||
|
||||
def test_run_campaign_dispatched_in_thread(self):
|
||||
a = self._make_agent()
|
||||
done = threading.Event()
|
||||
|
||||
with patch("ria_toolkit_oss.agent.NodeAgent._run_campaign") as mock_run:
|
||||
mock_run.side_effect = lambda *_: done.set()
|
||||
a._dispatch({"command": "run_campaign", "campaign_id": "c1", "payload": {}})
|
||||
done.wait(timeout=2)
|
||||
assert mock_run.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _stop_transmit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStopTransmit:
|
||||
def test_no_thread_noop(self):
|
||||
a = _agent()
|
||||
a._stop_transmit() # must not raise
|
||||
|
||||
def test_sets_stop_event(self):
|
||||
a = _agent()
|
||||
a._stop_transmit()
|
||||
assert a._tx_stop.is_set()
|
||||
|
||||
def test_joins_live_thread(self):
|
||||
a = _agent()
|
||||
finished = threading.Event()
|
||||
unblock = threading.Event()
|
||||
|
||||
def _task():
|
||||
unblock.wait(timeout=2)
|
||||
finished.set()
|
||||
|
||||
t = threading.Thread(target=_task, daemon=True)
|
||||
t.start()
|
||||
a._tx_thread = t
|
||||
|
||||
# Signal stop and trigger thread exit
|
||||
a._tx_stop.set()
|
||||
unblock.set()
|
||||
a._stop_transmit()
|
||||
|
||||
assert not t.is_alive()
|
||||
Loading…
Reference in New Issue
Block a user