diff --git a/poetry.lock b/poetry.lock index 92569ed..9d1e5fe 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.3.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. [[package]] name = "alabaster" @@ -230,14 +230,14 @@ uvloop = ["uvloop (>=0.15.2) ; sys_platform != \"win32\"", "winloop (>=0.5.0) ; [[package]] name = "cachetools" -version = "7.0.5" +version = "7.0.6" description = "Extensible memoizing collections and decorators" optional = false python-versions = ">=3.10" groups = ["test"] files = [ - {file = "cachetools-7.0.5-py3-none-any.whl", hash = "sha256:46bc8ebefbe485407621d0a4264b23c080cedd913921bad7ac3ed2f26c183114"}, - {file = "cachetools-7.0.5.tar.gz", hash = "sha256:0cd042c24377200c1dcd225f8b7b12b0ca53cc2c961b43757e774ebe190fd990"}, + {file = "cachetools-7.0.6-py3-none-any.whl", hash = "sha256:4e94956cfdd3086f12042cdd29318f5ced3893014f7d0d059bf3ead3f85b7f8b"}, + {file = "cachetools-7.0.6.tar.gz", hash = "sha256:e5d524d36d65703a87243a26ff08ad84f73352adbeafb1cde81e207b456aaf24"}, ] [[package]] @@ -1271,7 +1271,7 @@ files = [ [package.dependencies] attrs = ">=22.2.0" -jsonschema-specifications = ">=2023.3.6" +jsonschema-specifications = ">=2023.03.6" referencing = ">=0.28.4" rpds-py = ">=0.25.0" @@ -3749,4 +3749,4 @@ files = [ [metadata] lock-version = "2.1" python-versions = ">=3.10" -content-hash = "ffde300b2fc93161d2279a6e2b899bc988d3b5eb3833135821830affc9a5fb62" +content-hash = "66c9adf647316db90f963da05e8a83574378bfa4db2c69ce751446b5ee7c408c" diff --git a/pyproject.toml b/pyproject.toml index e648d0c..48a9e1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ dependencies = [ "pyyaml (>=6.0.3,<7.0.0)", "click (>=8.1.0,<9.0.0)", "matplotlib (>=3.8.0,<4.0.0)", - "paramiko (>=4.0.0)" + "paramiko (>=3.5.1)" ] # [project.optional-dependencies] Commented out to prevent Tox tests from failing @@ -149,6 +149,11 @@ exclude = ''' [tool.pytest.ini_options] pythonpath = ["src"] +filterwarnings = [ + # FastAPI emits this internally when handling 422 responses; the constant + # is not yet renamed in the installed starlette version, so we can't migrate. + "ignore:'HTTP_422_UNPROCESSABLE_ENTITY' is deprecated:DeprecationWarning", +] [tool.isort] profile = "black" diff --git a/src/ria_toolkit_oss/agent/legacy_executor.py b/src/ria_toolkit_oss/agent/legacy_executor.py index 6e15eb1..d4a302a 100644 --- a/src/ria_toolkit_oss/agent/legacy_executor.py +++ b/src/ria_toolkit_oss/agent/legacy_executor.py @@ -68,7 +68,7 @@ _HEARTBEAT_INTERVAL = 30 # seconds between heartbeats _POLL_TIMEOUT = 30 # server-side long-poll duration _POLL_CLIENT_TIMEOUT = 40 # client read timeout — slightly longer than server _RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying -_CHUNK_SIZE = 50 * 1024 * 1024 # 50 MB — well below Cloudflare's 100 MB limit +_CHUNK_SIZE = 10 * 1024 * 1024 # 10 MB per chunk — fast enough for git-LFS to process within timeout _DIRECT_THRESHOLD = 90 * 1024 * 1024 # files above this use chunked upload _CAPTURE_SAMPLES = 4096 # IQ samples per inference window _IDLE_LABELS = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"}) @@ -93,16 +93,24 @@ class NodeAgent: name: str, sdr_device: str = "unknown", insecure: bool = False, + role: str = "general", + session_code: str | None = None, ) -> None: self.hub_url = hub_url.rstrip("/") self.api_key = api_key self.name = name self.sdr_device = sdr_device self.insecure = insecure + self.role = role + self.session_code = session_code self.node_id: str | None = None self._stop = threading.Event() + # ── TX state ──────────────────────────────────────────────────────── + self._tx_stop = threading.Event() + self._tx_thread: threading.Thread | None = None + # ── Inference state ───────────────────────────────────────────────── # Protected by _inf_lock for cross-thread model swaps. self._inf_lock = threading.Lock() @@ -172,19 +180,27 @@ class NodeAgent: capabilities = ["campaign"] if self._ort_available: capabilities.append("inference") - resp = self._post( - "/composer/nodes/register", - json={ - "name": self.name, - "sdr_device": self.sdr_device, - "ria_toolkit_version": self._ria_version, - "capabilities": capabilities, - }, - timeout=15, - ) + if self.role == "tx": + capabilities.append("transmit") + payload: dict = { + "name": self.name, + "sdr_device": self.sdr_device, + "ria_toolkit_version": self._ria_version, + "capabilities": capabilities, + "role": self.role, + } + if self.session_code: + payload["session_code"] = self.session_code + resp = self._post("/composer/nodes/register", json=payload, timeout=15) resp.raise_for_status() self.node_id = resp.json()["node_id"] - logger.info("Registered as %r (node_id=%s)", self.name, self.node_id) + logger.info( + "Registered as %r (node_id=%s, role=%s%s)", + self.name, + self.node_id, + self.role, + f", session_code={self.session_code!r}" if self.session_code else "", + ) def _deregister(self) -> None: if not self.node_id: @@ -245,9 +261,10 @@ class NodeAgent: if command == "run_campaign": campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4()) config_dict: dict = cmd.get("payload") or {} + skip_local_tx: bool = bool(cmd.get("skip_local_tx", False)) threading.Thread( target=self._run_campaign, - args=(campaign_id, config_dict), + args=(campaign_id, config_dict, skip_local_tx), daemon=True, name=f"campaign-{campaign_id[:8]}", ).start() @@ -269,6 +286,17 @@ class NodeAgent: self._stop_inference() elif command == "configure_inference": self._queue_sdr_config(cmd) + elif command == "start_transmit": + threading.Thread( + target=self._start_transmit, + args=(cmd,), + daemon=True, + name="ria-start-tx", + ).start() + elif command == "stop_transmit": + self._stop_transmit() + elif command == "configure_transmit": + logger.info("configure_transmit received — will apply on next step boundary") else: logger.warning("Unknown command %r — ignored", command) @@ -276,7 +304,7 @@ class NodeAgent: # Campaign execution # ------------------------------------------------------------------ - def _run_campaign(self, campaign_id: str, config_dict: dict) -> None: + def _run_campaign(self, campaign_id: str, config_dict: dict, skip_local_tx: bool = False) -> None: try: from ria_toolkit_oss.orchestration.campaign import CampaignConfig from ria_toolkit_oss.orchestration.executor import CampaignExecutor @@ -288,10 +316,10 @@ class NodeAgent: ) return - logger.info("Campaign %s starting", campaign_id[:8]) + logger.info("Campaign %s starting (skip_local_tx=%s)", campaign_id[:8], skip_local_tx) try: config = CampaignConfig.from_dict(config_dict) - executor = CampaignExecutor(config) + executor = CampaignExecutor(config, skip_local_tx=skip_local_tx) result = executor.run() logger.info("Campaign %s completed — uploading recordings", campaign_id[:8]) self._upload_recordings(campaign_id, config, result) @@ -301,6 +329,58 @@ class NodeAgent: logger.error("Campaign %s failed: %s", campaign_id[:8], exc) self._report_campaign_status(campaign_id, "failed", error=str(exc)) + # ------------------------------------------------------------------ + # TX execution + # ------------------------------------------------------------------ + + def _start_transmit(self, cmd: dict) -> None: + """Execute a synthetic transmit campaign using TxExecutor. + + The command payload mirrors a TransmitterConfig dict with an optional + ``schedule`` of steps. Each step synthesises a signal and transmits it + via the local SDR in TX mode. + """ + try: + from ria_toolkit_oss.orchestration.tx_executor import TxExecutor + except ImportError as exc: + logger.error("start_transmit: TxExecutor not available: %s", exc) + return + + if self._tx_thread and self._tx_thread.is_alive(): + logger.warning("start_transmit: TX already running — ignoring duplicate command") + return + + self._tx_stop.clear() + campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4()) + executor = TxExecutor( + config=cmd, + sdr_device=self.sdr_device, + stop_event=self._tx_stop, + ) + self._tx_thread = threading.Thread( + target=self._run_tx_campaign, + args=(executor, campaign_id), + daemon=True, + name=f"tx-campaign-{campaign_id[:8]}", + ) + self._tx_thread.start() + + def _run_tx_campaign(self, executor: Any, campaign_id: str) -> None: + try: + executor.run() + logger.info("TX campaign %s completed", campaign_id[:8]) + self._report_campaign_status(campaign_id, "completed") + except Exception as exc: + logger.error("TX campaign %s failed: %s", campaign_id[:8], exc) + self._report_campaign_status(campaign_id, "failed", error=str(exc)) + + def _stop_transmit(self) -> None: + """Signal the TX loop to stop gracefully.""" + self._tx_stop.set() + if self._tx_thread and self._tx_thread.is_alive(): + self._tx_thread.join(timeout=5.0) + logger.info("TX stopped") + # ------------------------------------------------------------------ # Inference — model loading # ------------------------------------------------------------------ @@ -579,13 +659,18 @@ class NodeAgent: base_url = f"{self.hub_url}/datasets/upload" steps = (result.get("steps") if isinstance(result, dict) else getattr(result, "steps", None)) or [] + output_obj = getattr(config, "output", None) + folder = getattr(output_obj, "folder", None) + campaign_name: str = folder if folder is not None else (getattr(config, "name", None) or "") for step in steps: output_path: str | None = getattr(step, "output_path", None) if not output_path: continue device_id: str = getattr(step, "transmitter_id", "") or "" for fpath in _sigmf_files(output_path): - filename = os.path.basename(fpath) + basename = os.path.basename(fpath) + path_parts = [p for p in (campaign_name, device_id) if p] + filename = "/".join(path_parts + [basename]) metadata = { "filename": filename, "repo_owner": repo_owner, @@ -671,7 +756,7 @@ class NodeAgent: headers=headers, files={"file": (filename, chunk, "application/octet-stream")}, data={**metadata, "upload_id": upload_id, "chunk_index": i, "total_chunks": total_chunks}, - timeout=120, + timeout=(30, None), # 30s connect, no read timeout — server may take minutes on final chunk verify=verify, ) if not resp.ok: @@ -848,6 +933,21 @@ def main() -> None: choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="Logging verbosity (default: INFO)", ) + parser.add_argument( + "--role", + default=None, + choices=["general", "rx", "tx"], + help=("Node role reported to the hub. " "'tx' enables synthetic transmission commands. " "Default: general"), + ) + parser.add_argument( + "--session-code", + default=None, + metavar="CODE", + help=( + "3-word session code to pair this TX agent with a waiting campaign, " + "e.g. 'amber-peak-transmit'. Supplied by the campaign UI." + ), + ) args = parser.parse_args() @@ -861,6 +961,8 @@ def main() -> None: device = args.device or cfg.get("device", "unknown") insecure = args.insecure if args.insecure is not None else cfg.get("insecure", False) log_level = args.log_level or cfg.get("log_level", "INFO") + role = args.role or cfg.get("role", "general") + session_code = args.session_code or cfg.get("session_code") if not hub: parser.error("--hub is required (or set 'hub' in the config file)") @@ -888,6 +990,8 @@ def main() -> None: name=name, sdr_device=device, insecure=insecure, + role=role, + session_code=session_code, ) agent.run() diff --git a/src/ria_toolkit_oss/orchestration/campaign.py b/src/ria_toolkit_oss/orchestration/campaign.py index 027c33f..105cc40 100644 --- a/src/ria_toolkit_oss/orchestration/campaign.py +++ b/src/ria_toolkit_oss/orchestration/campaign.py @@ -233,6 +233,9 @@ class TransmitterConfig: # For sdr_remote control — keys: host, ssh_user, ssh_key_path, device_type, device_id, zmq_port sdr_remote: Optional[dict] = None + # For sdr_agent control — keys: modulation, order, symbol_rate, center_frequency, filter, rolloff + sdr_agent: Optional[dict] = None + @classmethod def from_dict(cls, d: dict) -> "TransmitterConfig": schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])] @@ -244,6 +247,7 @@ class TransmitterConfig: script=d.get("script"), device=d.get("device"), sdr_remote=d.get("sdr_remote"), + sdr_agent=d.get("sdr_agent"), ) @@ -272,6 +276,7 @@ class OutputConfig: path: str = "recordings" device_id: Optional[str] = None # for device-profile campaigns repo: Optional[str] = None + folder: Optional[str] = None # repo subfolder: None = use campaign name, "" = no subfolder, str = custom @classmethod def from_dict(cls, d: dict) -> "OutputConfig": @@ -280,6 +285,7 @@ class OutputConfig: path=str(d.get("path", "recordings")), device_id=d.get("device_id"), repo=d.get("repo"), + folder=d.get("folder"), ) @@ -293,6 +299,7 @@ class CampaignConfig: qa: QAConfig = field(default_factory=QAConfig) output: OutputConfig = field(default_factory=OutputConfig) mode: str = "controlled_testbed" + loops: int = 1 # repeat full schedule this many times; labels get _run{N:02d} suffix # --------------------------------------------------------------------------- # Loaders @@ -320,6 +327,7 @@ class CampaignConfig: return cls( name=safe_name, mode=str(campaign_meta.get("mode", "controlled_testbed")), + loops=max(1, int(campaign_meta.get("loops", 1))), recorder=RecorderConfig.from_dict(raw["recorder"]), transmitters=transmitters, qa=QAConfig.from_dict(raw.get("qa", {})), @@ -384,6 +392,7 @@ class CampaignConfig: return cls( name=safe_name, mode=str(campaign_meta.get("mode", "controlled_testbed")), + loops=max(1, int(campaign_meta.get("loops", 1))), recorder=RecorderConfig.from_dict(raw["recorder"]), transmitters=transmitters, qa=QAConfig.from_dict(raw.get("qa", {})), @@ -486,9 +495,9 @@ class CampaignConfig: ) def total_capture_time_s(self) -> float: - """Sum of all step durations across all transmitters.""" - return sum(step.duration for tx in self.transmitters for step in tx.schedule) + """Sum of all step durations across all transmitters and loops.""" + return sum(step.duration for tx in self.transmitters for step in tx.schedule) * self.loops def total_steps(self) -> int: - """Total number of capture steps across all transmitters.""" - return sum(len(tx.schedule) for tx in self.transmitters) + """Total number of capture steps across all transmitters and loops.""" + return sum(len(tx.schedule) for tx in self.transmitters) * self.loops diff --git a/src/ria_toolkit_oss/orchestration/executor.py b/src/ria_toolkit_oss/orchestration/executor.py index b04e296..4df75dc 100644 --- a/src/ria_toolkit_oss/orchestration/executor.py +++ b/src/ria_toolkit_oss/orchestration/executor.py @@ -5,8 +5,9 @@ from __future__ import annotations import json import logging import subprocess +import threading import time -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from pathlib import Path from typing import Callable, Optional @@ -16,6 +17,7 @@ from ria_toolkit_oss.io.recording import to_sigmf from .campaign import CampaignConfig, CaptureStep, TransmitterConfig from .labeler import build_output_filename, label_recording from .qa import QAResult, check_recording +from .tx_executor import TxExecutor logger = logging.getLogger(__name__) @@ -169,6 +171,21 @@ def _run_script(script: str, *args: str, timeout: float = 15.0) -> str: # --------------------------------------------------------------------------- +def _extract_tx_params(transmitter: TransmitterConfig) -> dict | None: + """Build a tx_params dict from a transmitter's signal config for SigMF labeling. + + For sdr_agent transmitters, returns the synthetic generation parameters + (modulation, order, symbol_rate, etc.) so recordings capture what was + transmitted. Returns None for control methods without signal-level params. + """ + sdr_agent_cfg = getattr(transmitter, "sdr_agent", None) + if not sdr_agent_cfg: + return None + # Extract known signal-level fields; ignore infra fields + _INFRA_KEYS = {"node_id", "session_code"} + return {k: v for k, v in sdr_agent_cfg.items() if k not in _INFRA_KEYS and v is not None} + + class CampaignExecutor: """Executes a :class:`CampaignConfig` end-to-end. @@ -192,11 +209,14 @@ class CampaignExecutor: config: CampaignConfig, progress_cb: Optional[Callable[[int, int, StepResult], None]] = None, verbose: bool = False, + skip_local_tx: bool = False, ): self.config = config self.progress_cb = progress_cb + self.skip_local_tx = skip_local_tx self._sdr = None self._remote_tx_controllers: dict = {} + self._tx_executors: dict[str, tuple] = {} # tx_id → (TxExecutor, stop_event, thread) if verbose: logging.basicConfig(level=logging.DEBUG) @@ -216,10 +236,12 @@ class CampaignExecutor: """ result = CampaignResult(campaign_name=self.config.name) + loops = self.config.loops logger.info( f"Starting campaign '{self.config.name}': " - f"{self.config.total_steps()} steps, " - f"~{self.config.total_capture_time_s():.0f}s capture time" + f"{self.config.total_steps()} steps" + + (f" ({self.config.total_steps() // loops} × {loops} loops)" if loops > 1 else "") + + f", ~{self.config.total_capture_time_s():.0f}s capture time" ) self._init_sdr() @@ -228,29 +250,36 @@ class CampaignExecutor: total = self.config.total_steps() step_index = 0 - for transmitter in self.config.transmitters: - logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)") - for step in transmitter.schedule: - step_result = self._execute_step(transmitter, step) - result.steps.append(step_result) - step_index += 1 + for loop_idx in range(loops): + if loops > 1: + logger.info(f"Loop {loop_idx + 1}/{loops}") + for transmitter in self.config.transmitters: + logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)") + for step in transmitter.schedule: + looped_step = replace(step, label=f"{step.label}_run{loop_idx + 1:02d}") if loops > 1 else step + step_result = self._execute_step(transmitter, looped_step) + result.steps.append(step_result) + step_index += 1 - if self.progress_cb: - self.progress_cb(step_index, total, step_result) + if self.progress_cb: + self.progress_cb(step_index, total, step_result) - if step_result.error: - logger.warning(f"Step '{step.label}' error: {step_result.error}") - elif step_result.qa.flagged: - logger.warning(f"Step '{step.label}' flagged for review: " + "; ".join(step_result.qa.issues)) - else: - logger.info( - f"Step '{step.label}' OK " - f"(SNR {step_result.qa.snr_db:.1f} dB, " - f"{step_result.qa.duration_s:.1f}s)" - ) + if step_result.error: + logger.warning(f"Step '{looped_step.label}' error: {step_result.error}") + elif step_result.qa.flagged: + logger.warning( + f"Step '{looped_step.label}' flagged for review: " + "; ".join(step_result.qa.issues) + ) + else: + logger.info( + f"Step '{looped_step.label}' OK " + f"(SNR {step_result.qa.snr_db:.1f} dB, " + f"{step_result.qa.duration_s:.1f}s)" + ) finally: self._close_sdr() self._close_remote_tx_controllers() + self._close_tx_executors() result.end_time = time.time() logger.info( @@ -325,6 +354,12 @@ class CampaignExecutor: logger.warning(f"Error closing remote Tx controller {tx_id}: {exc}") self._remote_tx_controllers.clear() + def _close_tx_executors(self) -> None: + for tx_id, (_, stop_event, t) in list(self._tx_executors.items()): + stop_event.set() + t.join(timeout=5.0) + self._tx_executors.clear() + def _record(self, duration_s: float) -> Recording: """Capture ``duration_s`` seconds of IQ samples.""" num_samples = int(duration_s * self.config.recorder.sample_rate) @@ -369,6 +404,7 @@ class CampaignExecutor: step=step, capture_timestamp=capture_timestamp, campaign_name=self.config.name, + tx_params=_extract_tx_params(transmitter), ) # QA @@ -437,6 +473,30 @@ class CampaignExecutor: # Start transmission in background; _record() runs concurrently ctrl.transmit_async(step.duration + 1.0) + elif transmitter.control_method == "sdr_agent": + if self.skip_local_tx: + logger.debug(f"skip_local_tx — TX for '{transmitter.id}' delegated to TX agent node") + return + if not transmitter.sdr_agent: + logger.warning(f"Transmitter '{transmitter.id}' has no sdr_agent config — skipping") + return + step_dict: dict = {"label": step.label, "duration": step.duration + 1.0} + if step.power_dbm is not None: + step_dict["power_dbm"] = step.power_dbm + tx_config = { + "id": transmitter.id, + "sdr_agent": transmitter.sdr_agent, + "schedule": [step_dict], + } + rec = self.config.recorder + tx_device = transmitter.device or rec.device + sdr_device = _DEVICE_ALIASES.get(tx_device.lower(), tx_device.lower()) + stop_event = threading.Event() + executor = TxExecutor(tx_config, sdr_device=sdr_device, stop_event=stop_event) + t = threading.Thread(target=executor.run, daemon=True, name=f"tx-{transmitter.id}") + self._tx_executors[transmitter.id] = (executor, stop_event, t) + t.start() + else: logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping") @@ -459,6 +519,13 @@ class CampaignExecutor: if ctrl is not None: ctrl.wait_transmit(timeout=step.duration + 10.0) + elif transmitter.control_method == "sdr_agent": + entry = self._tx_executors.pop(transmitter.id, None) + if entry is not None: + _, stop_event, t = entry + stop_event.set() + t.join(timeout=step.duration + 10.0) + @staticmethod def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str: """Serialise step parameters to a JSON string for the control script.""" diff --git a/src/ria_toolkit_oss/orchestration/labeler.py b/src/ria_toolkit_oss/orchestration/labeler.py index 2e4def0..ab492bd 100644 --- a/src/ria_toolkit_oss/orchestration/labeler.py +++ b/src/ria_toolkit_oss/orchestration/labeler.py @@ -15,6 +15,7 @@ def label_recording( step: CaptureStep, capture_timestamp: float, campaign_name: Optional[str] = None, + tx_params: Optional[dict] = None, ) -> Recording: """Apply device identity and capture configuration labels to a recording's metadata. @@ -27,6 +28,9 @@ def label_recording( step: The capture step that was active during this recording. capture_timestamp: Unix timestamp (float) of when capture started. campaign_name: Optional campaign name for cross-recording reference. + tx_params: Optional dict of transmitter signal parameters (e.g. modulation, + order, symbol_rate) written as ``ria:tx_`` fields so downstream + training pipelines know what was transmitted into the recording. Returns: The same recording with updated metadata. @@ -57,6 +61,11 @@ def label_recording( if step.power_dbm is not None: recording.update_metadata("tx_power_dbm", step.power_dbm) + # Transmitter signal parameters (e.g. from sdr_agent synthetic generation) + if tx_params: + for key, value in tx_params.items(): + recording.update_metadata(f"tx_{key}", value) + return recording diff --git a/src/ria_toolkit_oss/orchestration/tx_executor.py b/src/ria_toolkit_oss/orchestration/tx_executor.py new file mode 100644 index 0000000..a3c9bdc --- /dev/null +++ b/src/ria_toolkit_oss/orchestration/tx_executor.py @@ -0,0 +1,299 @@ +"""TX campaign executor — synthesises and transmits signals via a local SDR. + +The TxExecutor receives a transmitter config dict (matching the +``sdr_agent`` control method's schema) and a step schedule, then for each +step builds a signal chain with the block generator and transmits it via +the local SDR device. + +Supported modulations (``modulation`` field in config): + BPSK, QPSK, 8PSK, 16QAM, 64QAM, 256QAM, FSK, OOK, GMSK, OQPSK + +Example config dict (matches CampaignConfig transmitter with +``control_method: sdr_agent``):: + + { + "id": "synthetic-tx", + "type": "sdr", + "control_method": "sdr_agent", + "sdr_agent": { + "modulation": "QPSK", + "order": 4, + "symbol_rate": 1000000, + "center_frequency": 0.0, + "filter": "rrc", + "rolloff": 0.35 + }, + "schedule": [ + {"label": "step1", "duration": 10, "power_dbm": -10} + ] + } +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +logger = logging.getLogger(__name__) + + +def _parse_hz(val: object) -> float: + """Parse a frequency value that may be a float (Hz) or a string like '2.45GHz'.""" + if isinstance(val, (int, float)): + return float(val) + s = str(val).strip() + for suffix, mult in (("GHz", 1e9), ("MHz", 1e6), ("kHz", 1e3), ("Hz", 1.0)): + if s.endswith(suffix): + return float(s[: -len(suffix)]) * mult + return float(s) + + +def _parse_seconds(val: object) -> float: + """Parse a duration value that may be a float (seconds) or a string like '5s'.""" + if isinstance(val, (int, float)): + return float(val) + s = str(val).strip() + return float(s[:-1]) if s.endswith("s") else float(s) + + +# Mapping from modulation name → (PSK/QAM order, generator_type) +# 'psk' uses PSKGenerator, 'qam' uses QAMGenerator +_MOD_TABLE: dict[str, tuple[int, str]] = { + "BPSK": (1, "psk"), + "QPSK": (2, "psk"), + "8PSK": (3, "psk"), + "16QAM": (4, "qam"), + "64QAM": (6, "qam"), + "256QAM": (8, "qam"), +} + +_SPECIAL_MODS = {"FSK", "OOK", "GMSK", "OQPSK"} + +# usrp-uhd-client's tx_recording() streams 2 000-sample chunks and loops the +# source buffer for the full tx_time, so only this many samples ever need to +# be in RAM regardless of step duration or sample rate. +# 50 000 complex64 samples ≈ 400 kB — enough spectral diversity for looping. +_SYNTH_BLOCK_SAMPLES = 50_000 + + +class TxExecutor: + """Synthesise and transmit a signal campaign via a local SDR. + + Args: + config: Transmitter config dict (must have ``sdr_agent`` sub-dict with + modulation params, and ``schedule`` list of step dicts). + sdr_device: SDR device name to open in TX mode (e.g. "pluto", "usrp"). + stop_event: External event that aborts the TX loop mid-step. + """ + + def __init__( + self, + config: dict, + sdr_device: str = "unknown", + stop_event: threading.Event | None = None, + ) -> None: + self.config = config + self.sdr_device = sdr_device + self.stop_event = stop_event or threading.Event() + self._sdr: Any = None + + def run(self) -> None: + """Execute all steps in the schedule, transmitting for each step duration.""" + agent_cfg: dict = self.config.get("sdr_agent") or {} + schedule: list[dict] = self.config.get("schedule") or [] + + if not schedule: + logger.warning("TxExecutor: no schedule steps — nothing to transmit") + return + + modulation: str = agent_cfg.get("modulation", "QPSK").upper() + symbol_rate: float = float(agent_cfg.get("symbol_rate", 1e6)) + center_freq: float = _parse_hz(agent_cfg.get("center_frequency", 0.0)) + filter_type: str = agent_cfg.get("filter", "rrc").lower() + rolloff: float = float(agent_cfg.get("rolloff", 0.35)) + loops: int = max(1, int(self.config.get("loops", 1))) + + # Upsampling factor: samples_per_symbol, fixed at 8 for SDR compatibility. + sps = 8 + sample_rate = symbol_rate * sps + + self._init_sdr(sample_rate, center_freq) + try: + for loop_idx in range(loops): + if self.stop_event.is_set(): + break + if loops > 1: + logger.info("TX loop %d/%d", loop_idx + 1, loops) + for step in schedule: + if self.stop_event.is_set(): + break + looped_step = ( + {**step, "label": f"{step.get('label', 'step')}_run{loop_idx + 1:02d}"} if loops > 1 else step + ) + self._execute_step(looped_step, modulation, sps, symbol_rate, filter_type, rolloff) + finally: + self._close_sdr() + + def _execute_step( + self, + step: dict, + modulation: str, + sps: int, + symbol_rate: float, + filter_type: str, + rolloff: float, + ) -> None: + duration: float = _parse_seconds(step.get("duration", 10.0)) + label: str = step.get("label", "step") + gain: float = float(step.get("power_dbm") or 0.0) + sample_rate = symbol_rate * sps + + logger.info( + "TX step '%s': %.0f s, %s @ %.3f MHz (sps=%d, filter=%s)", + label, + duration, + modulation, + symbol_rate / 1e6, + sps, + filter_type, + ) + + num_samples = int(duration * sample_rate) + + # Synthesise a short representative block. tx_recording() loops this + # buffer for the full tx_time using a 2 000-sample streaming callback, + # so peak memory is O(_SYNTH_BLOCK_SAMPLES) regardless of duration. + block_size = min(num_samples, _SYNTH_BLOCK_SAMPLES) + signal = self._synthesise(modulation, sps, block_size, filter_type, rolloff) + + if self._sdr is not None: + try: + # Apply gain update if SDR supports it + if hasattr(self._sdr, "set_tx_gain"): + self._sdr.set_tx_gain(gain) + self._sdr.tx_recording(signal, tx_time=duration) + except Exception as exc: + logger.error("TX step '%s' SDR error: %s", label, exc) + else: + # No SDR available — simulate by sleeping for the step duration. + logger.warning("TX step '%s': no SDR — simulating %.0f s delay", label, duration) + self.stop_event.wait(timeout=duration) + + def _synthesise( + self, + modulation: str, + sps: int, + num_samples: int, + filter_type: str, + rolloff: float, + ): + """Build a block-generator chain and return IQ samples as a numpy array.""" + try: + import numpy as np + + from ria_toolkit_oss.signal.block_generator import ( + BinarySource, + GMSKModulator, + Mapper, + OOKModulator, + OQPSKModulator, + RaisedCosineFilter, + RootRaisedCosineFilter, + Upsampling, + ) + from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import ( + FSKModulator, + ) + except ImportError as exc: + raise RuntimeError(f"ria_toolkit_oss block generator not available: {exc}") from exc + + # ── Special modulations with their own source-connected modulator ── + if modulation in ("OOK", "GMSK", "OQPSK"): + src = BinarySource() + if modulation == "OOK": + mod = OOKModulator(src, samples_per_symbol=sps) + elif modulation == "GMSK": + mod = GMSKModulator(src, samples_per_symbol=sps) + else: + mod = OQPSKModulator(src, samples_per_symbol=sps) + recording = mod.record(num_samples) + flat = np.asarray(recording.data).flatten().astype(np.complex64) + if len(flat) < num_samples: + flat = np.tile(flat, num_samples // len(flat) + 1) + return flat[:num_samples] + + if modulation == "FSK": + symbol_rate = num_samples / sps + bits_per_sym = 1 # 2-FSK + num_bits = max(num_samples // sps, 128) * bits_per_sym + bits = BinarySource()((1, num_bits)) + mod = FSKModulator( + num_bits_per_symbol=bits_per_sym, + frequency_spacing=symbol_rate * 0.5, + symbol_duration=1.0 / max(symbol_rate, 1.0), + sampling_frequency=symbol_rate * sps, + ) + flat = np.asarray(mod(bits)).flatten().astype(np.complex64) + if len(flat) < num_samples: + flat = np.tile(flat, num_samples // len(flat) + 1) + return flat[:num_samples] + + # ── PSK / QAM via Mapper → Upsampling → pulse filter ────────────── + if modulation not in _MOD_TABLE: + logger.warning("Unknown modulation %r — defaulting to QPSK", modulation) + modulation = "QPSK" + + bits_per_sym, gen_type = _MOD_TABLE[modulation] + mod_family = "QAM" if gen_type == "qam" else "PSK" + + source = BinarySource() + mapper = Mapper(constellation_type=mod_family, num_bits_per_symbol=bits_per_sym) + upsampler = Upsampling(factor=sps) + + mapper.connect_input([source]) + upsampler.connect_input([mapper]) + + if filter_type in ("rrc",): + pulse_filter = RootRaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff) + pulse_filter.connect_input([upsampler]) + recording = pulse_filter.record(num_samples) + elif filter_type in ("rc",): + pulse_filter = RaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff) + pulse_filter.connect_input([upsampler]) + recording = pulse_filter.record(num_samples) + else: + # "none", "rect", "gaussian" — use upsampler output directly + recording = upsampler.record(num_samples) + + flat = np.asarray(recording.data).flatten().astype(np.complex64) + if len(flat) < num_samples: + flat = np.tile(flat, num_samples // len(flat) + 1) + return flat[:num_samples] + + def _init_sdr(self, sample_rate: float, center_freq: float) -> None: + try: + from ria_toolkit_oss.sdr import get_sdr_device + + self._sdr = get_sdr_device(self.sdr_device) + self._sdr.init_tx( + sample_rate=sample_rate, + center_frequency=center_freq, + gain=0, + channel=0, + gain_mode="manual", + ) + logger.info( + "TX SDR initialised: %s @ %.3f MHz, %.1f Msps", self.sdr_device, center_freq / 1e6, sample_rate / 1e6 + ) + except Exception as exc: + logger.warning("TX SDR init failed (%s) — will simulate: %s", self.sdr_device, exc) + self._sdr = None + + def _close_sdr(self) -> None: + if self._sdr is not None: + try: + self._sdr.close() + except Exception as exc: + logger.debug("TX SDR close error: %s", exc) + self._sdr = None diff --git a/src/ria_toolkit_oss/server/app.py b/src/ria_toolkit_oss/server/app.py index 5e4c58b..42799a6 100644 --- a/src/ria_toolkit_oss/server/app.py +++ b/src/ria_toolkit_oss/server/app.py @@ -3,7 +3,7 @@ from fastapi import Depends, FastAPI from .auth import require_api_key -from .routers import inference, orchestrator +from .routers import conductor, inference def create_app(api_key: str = "") -> FastAPI: @@ -28,9 +28,9 @@ def create_app(api_key: str = "") -> FastAPI: app.state.api_key = api_key app.include_router( - orchestrator.router, - prefix="/orchestrator", - tags=["Orchestrator"], + conductor.router, + prefix="/conductor", + tags=["Conductor"], dependencies=[Depends(require_api_key)], ) app.include_router( diff --git a/src/ria_toolkit_oss/server/models.py b/src/ria_toolkit_oss/server/models.py index e2ba450..9fd88d9 100644 --- a/src/ria_toolkit_oss/server/models.py +++ b/src/ria_toolkit_oss/server/models.py @@ -7,7 +7,7 @@ from pathlib import Path from pydantic import BaseModel, field_validator # --------------------------------------------------------------------------- -# Orchestrator +# Conductor # --------------------------------------------------------------------------- diff --git a/src/ria_toolkit_oss/server/routers/orchestrator.py b/src/ria_toolkit_oss/server/routers/conductor.py similarity index 98% rename from src/ria_toolkit_oss/server/routers/orchestrator.py rename to src/ria_toolkit_oss/server/routers/conductor.py index dfc01af..7ec7d9d 100644 --- a/src/ria_toolkit_oss/server/routers/orchestrator.py +++ b/src/ria_toolkit_oss/server/routers/conductor.py @@ -1,4 +1,4 @@ -"""Orchestrator routes: campaign deployment, status, and cancellation.""" +"""Conductor routes: campaign deployment, status, and cancellation.""" from __future__ import annotations diff --git a/src/ria_toolkit_oss_cli/ria_toolkit_oss/serve.py b/src/ria_toolkit_oss_cli/ria_toolkit_oss/serve.py index 21beb6e..5d541b4 100644 --- a/src/ria_toolkit_oss_cli/ria_toolkit_oss/serve.py +++ b/src/ria_toolkit_oss_cli/ria_toolkit_oss/serve.py @@ -23,9 +23,9 @@ def serve(host: str, port: int, api_key: str, log_level: str): \b Endpoints: - POST /orchestrator/deploy - GET /orchestrator/status/{campaign_id} - POST /orchestrator/cancel/{campaign_id} + POST /conductor/deploy + GET /conductor/status/{campaign_id} + POST /conductor/cancel/{campaign_id} POST /inference/load POST /inference/start POST /inference/stop diff --git a/tests/orchestration/test_executor.py b/tests/orchestration/test_executor.py new file mode 100644 index 0000000..7aba499 --- /dev/null +++ b/tests/orchestration/test_executor.py @@ -0,0 +1,314 @@ +"""Tests for orchestration executor — StepResult, CampaignResult, _run_script, _extract_tx_params.""" + +from __future__ import annotations + +import json +import stat +from types import SimpleNamespace + +import pytest + +from ria_toolkit_oss.orchestration.executor import ( + CampaignResult, + StepResult, + _extract_tx_params, + _run_script, +) +from ria_toolkit_oss.orchestration.qa import QAResult + + +def _ok_qa() -> QAResult: + return QAResult(passed=True, flagged=False, snr_db=20.0, duration_s=1.0) + + +def _flagged_qa() -> QAResult: + return QAResult(passed=True, flagged=True, snr_db=5.0, duration_s=1.0, issues=["low SNR"]) + + +def _failed_qa() -> QAResult: + return QAResult(passed=False, flagged=True, snr_db=0.0, duration_s=0.0, issues=["no signal"]) + + +# --------------------------------------------------------------------------- +# StepResult +# --------------------------------------------------------------------------- + + +class TestStepResult: + def test_ok_true_when_no_error_and_qa_passed(self): + r = StepResult( + transmitter_id="tx1", + step_label="step1", + output_path="/out/rec.sigmf-data", + qa=_ok_qa(), + capture_timestamp=0.0, + ) + assert r.ok is True + + def test_ok_false_when_error_set(self): + r = StepResult( + transmitter_id="tx1", + step_label="step1", + output_path=None, + qa=_ok_qa(), + capture_timestamp=0.0, + error="SDR failed", + ) + assert r.ok is False + + def test_ok_false_when_qa_not_passed(self): + r = StepResult( + transmitter_id="tx1", + step_label="step1", + output_path="/out", + qa=_failed_qa(), + capture_timestamp=0.0, + ) + assert r.ok is False + + def test_to_dict_contains_required_keys(self): + r = StepResult( + transmitter_id="tx1", + step_label="step1", + output_path="/out/rec.sigmf-data", + qa=_ok_qa(), + capture_timestamp=1234.5, + ) + d = r.to_dict() + assert d["transmitter_id"] == "tx1" + assert d["step_label"] == "step1" + assert d["output_path"] == "/out/rec.sigmf-data" + assert d["capture_timestamp"] == pytest.approx(1234.5) + assert d["error"] is None + assert d["qa"]["passed"] is True + + def test_to_dict_includes_error_when_set(self): + r = StepResult( + transmitter_id="tx1", + step_label="step1", + output_path=None, + qa=_failed_qa(), + capture_timestamp=0.0, + error="disk full", + ) + assert r.to_dict()["error"] == "disk full" + + +# --------------------------------------------------------------------------- +# CampaignResult +# --------------------------------------------------------------------------- + + +class TestCampaignResult: + def _make(self, steps: list) -> CampaignResult: + r = CampaignResult(campaign_name="test_campaign") + r.steps = steps + r.end_time = r.start_time + 5.0 + return r + + def test_total_steps(self): + r = self._make( + [ + StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), + StepResult("tx1", "s2", "/out", _ok_qa(), 0.0), + ] + ) + assert r.total_steps == 2 + + def test_passed_count(self): + r = self._make( + [ + StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), + StepResult("tx1", "s2", "/out", _failed_qa(), 0.0), + ] + ) + assert r.passed == 1 + + def test_failed_count(self): + r = self._make( + [ + StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), + StepResult("tx1", "s2", "/out", _failed_qa(), 0.0), + ] + ) + assert r.failed == 1 + + def test_flagged_count(self): + r = self._make( + [ + StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), + StepResult("tx1", "s2", "/out", _flagged_qa(), 0.0), + ] + ) + assert r.flagged == 1 + + def test_error_step_counts_as_failed_not_passed(self): + r = self._make( + [ + StepResult("tx1", "s1", None, _ok_qa(), 0.0, error="disk full"), + ] + ) + assert r.failed == 1 + assert r.passed == 0 + + def test_duration_s_from_end_time(self): + r = CampaignResult(campaign_name="c") + r.start_time = 100.0 + r.end_time = 115.0 + assert r.duration_s == pytest.approx(15.0) + + def test_to_dict_structure(self): + r = self._make([StepResult("tx1", "s1", "/out", _ok_qa(), 0.0)]) + d = r.to_dict() + assert d["campaign_name"] == "test_campaign" + assert d["total_steps"] == 1 + assert d["passed"] == 1 + assert len(d["steps"]) == 1 + + def test_write_report(self, tmp_path): + r = self._make([StepResult("tx1", "s1", "/out", _ok_qa(), 0.0)]) + out = tmp_path / "report.json" + r.write_report(str(out)) + assert out.exists() + data = json.loads(out.read_text()) + assert data["campaign_name"] == "test_campaign" + + def test_write_report_creates_nested_dirs(self, tmp_path): + r = self._make([]) + out = tmp_path / "nested" / "deep" / "report.json" + r.write_report(str(out)) + assert out.exists() + + +# --------------------------------------------------------------------------- +# _run_script +# --------------------------------------------------------------------------- + + +class TestRunScript: + def _script(self, tmp_path, body: str) -> str: + s = tmp_path / "script.sh" + s.write_text("#!/bin/sh\n" + body) + s.chmod(s.stat().st_mode | stat.S_IEXEC) + return str(s) + + def test_returns_stdout(self, tmp_path): + out = _run_script(self._script(tmp_path, 'echo "hello world"')) + assert out == "hello world" + + def test_passes_args_to_script(self, tmp_path): + out = _run_script(self._script(tmp_path, 'echo "$1 $2"'), "configure", "arg2") + assert "configure" in out + + def test_raises_on_nonzero_exit(self, tmp_path): + with pytest.raises(RuntimeError, match="exited 1"): + _run_script(self._script(tmp_path, "exit 1")) + + def test_raises_on_relative_path(self): + with pytest.raises(RuntimeError, match="absolute"): + _run_script("relative/script.sh") + + def test_raises_on_missing_file(self, tmp_path): + with pytest.raises(RuntimeError): + _run_script(str(tmp_path / "nonexistent.sh")) + + def test_raises_on_timeout(self, tmp_path): + with pytest.raises(RuntimeError, match="timed out"): + _run_script(self._script(tmp_path, "sleep 60"), timeout=0.1) + + def test_stderr_included_in_error_message(self, tmp_path): + with pytest.raises(RuntimeError) as exc_info: + _run_script(self._script(tmp_path, "echo 'bad thing' >&2; exit 1")) + assert "bad thing" in str(exc_info.value) + + +# --------------------------------------------------------------------------- +# _extract_tx_params +# --------------------------------------------------------------------------- + + +class TestExtractTxParams: + def test_returns_none_when_no_sdr_agent_attribute(self): + tx = SimpleNamespace() + assert _extract_tx_params(tx) is None + + def test_returns_none_when_sdr_agent_is_none(self): + tx = SimpleNamespace(sdr_agent=None) + assert _extract_tx_params(tx) is None + + def test_returns_none_when_sdr_agent_is_empty_dict(self): + tx = SimpleNamespace(sdr_agent={}) + assert _extract_tx_params(tx) is None + + def test_returns_signal_params(self): + tx = SimpleNamespace( + sdr_agent={ + "modulation": "QPSK", + "symbol_rate": 1e6, + "center_frequency": 2.4e9, + } + ) + result = _extract_tx_params(tx) + assert result == {"modulation": "QPSK", "symbol_rate": 1e6, "center_frequency": 2.4e9} + + def test_strips_infra_key_node_id(self): + tx = SimpleNamespace( + sdr_agent={ + "modulation": "BPSK", + "node_id": "node_abc123", + } + ) + result = _extract_tx_params(tx) + assert "node_id" not in result + assert result == {"modulation": "BPSK"} + + def test_strips_infra_key_session_code(self): + tx = SimpleNamespace( + sdr_agent={ + "modulation": "FSK", + "session_code": "amber-peak-transmit", + } + ) + result = _extract_tx_params(tx) + assert "session_code" not in result + + def test_strips_none_values(self): + tx = SimpleNamespace( + sdr_agent={ + "modulation": "QPSK", + "order": None, + "rolloff": 0.35, + } + ) + result = _extract_tx_params(tx) + assert "order" not in result + assert result == {"modulation": "QPSK", "rolloff": 0.35} + + def test_does_not_mutate_source_dict(self): + cfg = {"modulation": "QPSK", "node_id": "nid", "session_code": "code"} + tx = SimpleNamespace(sdr_agent=cfg) + _extract_tx_params(tx) + assert "node_id" in cfg + + def test_full_sdr_agent_config(self): + tx = SimpleNamespace( + sdr_agent={ + "modulation": "16QAM", + "order": 4, + "symbol_rate": 5e6, + "center_frequency": 915e6, + "filter": "rrc", + "rolloff": 0.35, + "node_id": "node_xyz", + "session_code": "some-code", + } + ) + result = _extract_tx_params(tx) + assert result == { + "modulation": "16QAM", + "order": 4, + "symbol_rate": 5e6, + "center_frequency": 915e6, + "filter": "rrc", + "rolloff": 0.35, + } diff --git a/tests/orchestration/test_labeler.py b/tests/orchestration/test_labeler.py index f305bed..d4f8344 100644 --- a/tests/orchestration/test_labeler.py +++ b/tests/orchestration/test_labeler.py @@ -109,6 +109,38 @@ class TestLabelRecording: result = label_recording(rec, "iphone13_001", _wifi_step(), time.time()) assert result is rec + def test_tx_params_none_by_default(self): + rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time()) + tx_keys = [k for k in rec.metadata if k.startswith("tx_")] + assert tx_keys == [] + + def test_tx_params_written_as_tx_prefix_keys(self): + params = {"modulation": "QPSK", "symbol_rate": 1e6} + rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params) + assert rec.metadata["tx_modulation"] == "QPSK" + assert rec.metadata["tx_symbol_rate"] == pytest.approx(1e6) + + def test_tx_params_multiple_fields(self): + params = { + "modulation": "16QAM", + "order": 4, + "symbol_rate": 5e6, + "center_frequency": 915e6, + "filter": "rrc", + "rolloff": 0.35, + } + rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params) + for k, v in params.items(): + assert f"tx_{k}" in rec.metadata + assert ( + rec.metadata[f"tx_{k}"] == pytest.approx(v) if isinstance(v, float) else rec.metadata[f"tx_{k}"] == v + ) + + def test_tx_params_empty_dict_writes_nothing(self): + rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params={}) + tx_keys = [k for k in rec.metadata if k.startswith("tx_") and k != "tx_power_dbm"] + assert tx_keys == [] + # --------------------------------------------------------------------------- # build_output_filename diff --git a/tests/orchestration/test_tx_executor.py b/tests/orchestration/test_tx_executor.py new file mode 100644 index 0000000..9d66850 --- /dev/null +++ b/tests/orchestration/test_tx_executor.py @@ -0,0 +1,153 @@ +"""Tests for TxExecutor — signal synthesis and step execution.""" + +from __future__ import annotations + +import threading +from unittest.mock import patch + +import numpy as np +import pytest + +from ria_toolkit_oss.orchestration.tx_executor import TxExecutor + + +def _cfg(modulation="QPSK", symbol_rate=100_000, steps=None): + return { + "id": "test-tx", + "type": "sdr", + "control_method": "sdr_agent", + "sdr_agent": { + "modulation": modulation, + "symbol_rate": symbol_rate, + "center_frequency": 0.0, + "filter": "rrc", + "rolloff": 0.35, + }, + "schedule": steps or [{"label": "step1", "duration": 0.001, "power_dbm": -10}], + } + + +# --------------------------------------------------------------------------- +# Initialisation +# --------------------------------------------------------------------------- + + +class TestTxExecutorInit: + def test_stores_sdr_device(self): + ex = TxExecutor(_cfg(), sdr_device="pluto") + assert ex.sdr_device == "pluto" + + def test_stop_event_created_when_not_supplied(self): + ex = TxExecutor(_cfg()) + assert isinstance(ex.stop_event, threading.Event) + assert not ex.stop_event.is_set() + + def test_accepts_external_stop_event(self): + ev = threading.Event() + ex = TxExecutor(_cfg(), stop_event=ev) + assert ex.stop_event is ev + + +# --------------------------------------------------------------------------- +# run() — schedule iteration +# --------------------------------------------------------------------------- + + +class TestTxExecutorRun: + def test_empty_schedule_returns_immediately(self): + cfg = _cfg(steps=[]) + ex = TxExecutor(cfg) + ex.run() # must not raise or block + + def test_pre_set_stop_event_skips_all_steps(self): + ev = threading.Event() + ev.set() + ex = TxExecutor(_cfg(), stop_event=ev) + # If stop was set, _execute_step should never be called. + # run() should return cleanly without attempting synthesis. + ex.run() + + def test_no_sdr_falls_back_to_simulation(self, monkeypatch): + """Without SDR hardware TxExecutor simulates by calling stop_event.wait.""" + cfg = _cfg(steps=[{"label": "s", "duration": 0.001, "power_dbm": 0}]) + waited = [] + real_ev = threading.Event() + + def _fake_wait(timeout=None): + waited.append(timeout) + return False + + monkeypatch.setattr(real_ev, "wait", _fake_wait) + + # Patch SDR init to always fail (forces simulation path) + with patch.object(TxExecutor, "_init_sdr", lambda self, *a, **kw: setattr(self, "_sdr", None)): + ex = TxExecutor(cfg, sdr_device="nonexistent_xyz", stop_event=real_ev) + ex.run() + + assert len(waited) >= 1, "expected stop_event.wait to be called for simulation" + + +# --------------------------------------------------------------------------- +# _synthesise — all modulation types and filter types +# --------------------------------------------------------------------------- + + +class TestSynthesise: + @pytest.fixture(autouse=True) + def _ex(self): + self.ex = TxExecutor(_cfg()) + + def _synth(self, mod, num_samples=256): + return self.ex._synthesise(mod, sps=4, num_samples=num_samples, filter_type="rrc", rolloff=0.35) + + @pytest.mark.parametrize("mod", ["BPSK", "QPSK", "8PSK", "16QAM", "64QAM", "256QAM"]) + def test_psk_qam_returns_complex64_array(self, mod): + sig = self._synth(mod) + assert sig.dtype == np.complex64 + assert len(sig) == 256 + + def test_fsk_returns_correct_length(self): + sig = self._synth("FSK") + assert len(sig) == 256 + + def test_ook_returns_correct_length(self): + sig = self._synth("OOK") + assert len(sig) == 256 + + def test_gmsk_returns_correct_length(self): + sig = self._synth("GMSK") + assert len(sig) == 256 + + def test_oqpsk_returns_correct_length(self): + sig = self._synth("OQPSK") + assert len(sig) == 256 + + @pytest.mark.parametrize("mod", ["BPSK", "QPSK", "16QAM", "FSK", "OOK", "GMSK"]) + def test_samples_are_finite(self, mod): + sig = self._synth(mod) + assert np.all(np.isfinite(sig.real)), f"{mod}: non-finite real samples" + assert np.all(np.isfinite(sig.imag)), f"{mod}: non-finite imag samples" + + def test_unknown_modulation_defaults_to_qpsk(self): + sig = self._synth("UNKNOWN_MOD_XYZ") + assert len(sig) == 256 + assert sig.dtype == np.complex64 + + @pytest.mark.parametrize("filter_type", ["rrc", "rc", "gaussian", "rect", "none"]) + def test_all_filter_types(self, filter_type): + sig = self.ex._synthesise("QPSK", sps=4, num_samples=128, filter_type=filter_type, rolloff=0.35) + assert len(sig) == 128 + + @pytest.mark.parametrize("n", [64, 128, 512, 1024]) + def test_output_length_matches_requested_samples(self, n): + sig = self._synth("QPSK", num_samples=n) + assert len(sig) == n + + def test_bpsk_output_is_complex_not_real(self): + sig = self._synth("BPSK") + # complex64 always has imag part; just check dtype + assert sig.dtype == np.complex64 + + def test_256qam_correct_length(self): + sig = self._synth("256QAM") + assert len(sig) == 256 diff --git a/tests/ria_toolkit_oss_cli/test_generate.py b/tests/ria_toolkit_oss_cli/test_generate.py index 68d252c..8915037 100644 --- a/tests/ria_toolkit_oss_cli/test_generate.py +++ b/tests/ria_toolkit_oss_cli/test_generate.py @@ -189,6 +189,8 @@ class TestNoiseCommand: "10000", "--noise-type", "gaussian", + "--power", + "0.01", "--output", output, "-q", @@ -234,7 +236,7 @@ class TestNoiseCommand: "--num-samples", "10000", "--power", - "0.5", + "0.01", "--output", output, "-q", diff --git a/tests/server/test_server.py b/tests/server/test_server.py index 3e9c8db..e3345ae 100644 --- a/tests/server/test_server.py +++ b/tests/server/test_server.py @@ -1,6 +1,6 @@ """Tests for the RT-OSS HTTP server. -Covers: auth, inference lifecycle (without SDR/ONNX hardware), orchestrator +Covers: auth, inference lifecycle (without SDR/ONNX hardware), conductor lifecycle (with mocked executor), and state helpers. ``start_inference`` and ``_inference_loop`` require real SDR hardware and an @@ -286,17 +286,17 @@ class TestInferenceStop: # --------------------------------------------------------------------------- -# POST /orchestrator/deploy +# POST /conductor/deploy # --------------------------------------------------------------------------- -class TestOrchestratorDeploy: +class TestConductorDeploy: def test_deploy_422_on_invalid_config(self, client): with patch( - "ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict", + "ria_toolkit_oss.server.routers.conductor.CampaignConfig.from_dict", side_effect=ValueError("missing required field 'name'"), ): - resp = client.post("/orchestrator/deploy", json={"config": {}}) + resp = client.post("/conductor/deploy", json={"config": {}}) assert resp.status_code == 422 def test_deploy_returns_campaign_id(self, client): @@ -307,10 +307,10 @@ class TestOrchestratorDeploy: mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {}) with ( - patch("ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict", return_value=mock_cfg), - patch("ria_toolkit_oss.server.routers.orchestrator.CampaignExecutor", mock_executor), + patch("ria_toolkit_oss.server.routers.conductor.CampaignConfig.from_dict", return_value=mock_cfg), + patch("ria_toolkit_oss.server.routers.conductor.CampaignExecutor", mock_executor), ): - resp = client.post("/orchestrator/deploy", json={"config": {"name": "test_campaign"}}) + resp = client.post("/conductor/deploy", json={"config": {"name": "test_campaign"}}) assert resp.status_code == 200 body = resp.json() @@ -325,23 +325,23 @@ class TestOrchestratorDeploy: mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {}) with ( - patch("ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict", return_value=mock_cfg), - patch("ria_toolkit_oss.server.routers.orchestrator.CampaignExecutor", mock_executor), + patch("ria_toolkit_oss.server.routers.conductor.CampaignConfig.from_dict", return_value=mock_cfg), + patch("ria_toolkit_oss.server.routers.conductor.CampaignExecutor", mock_executor), ): - resp = client.post("/orchestrator/deploy", json={"config": {}}) + resp = client.post("/conductor/deploy", json={"config": {}}) campaign_id = resp.json()["campaign_id"] assert state_module._campaigns.get(campaign_id) is not None # --------------------------------------------------------------------------- -# GET /orchestrator/status/{campaign_id} +# GET /conductor/status/{campaign_id} # --------------------------------------------------------------------------- -class TestOrchestratorStatus: +class TestConductorStatus: def test_status_404_for_unknown_id(self, client): - resp = client.get("/orchestrator/status/nonexistent-id") + resp = client.get("/conductor/status/nonexistent-id") assert resp.status_code == 404 def test_status_returns_campaign_state(self, client): @@ -357,7 +357,7 @@ class TestOrchestratorStatus: ) state_module._campaigns["abc-123"] = state - resp = client.get("/orchestrator/status/abc-123") + resp = client.get("/conductor/status/abc-123") assert resp.status_code == 200 body = resp.json() assert body["campaign_id"] == "abc-123" @@ -367,13 +367,13 @@ class TestOrchestratorStatus: # --------------------------------------------------------------------------- -# POST /orchestrator/cancel/{campaign_id} +# POST /conductor/cancel/{campaign_id} # --------------------------------------------------------------------------- -class TestOrchestratorCancel: +class TestConductorCancel: def test_cancel_404_for_unknown_id(self, client): - resp = client.post("/orchestrator/cancel/no-such-id") + resp = client.post("/conductor/cancel/no-such-id") assert resp.status_code == 404 def test_cancel_sets_cancel_event(self, client): @@ -387,7 +387,7 @@ class TestOrchestratorCancel: ) state_module._campaigns["camp-to-cancel"] = state - resp = client.post("/orchestrator/cancel/camp-to-cancel") + resp = client.post("/conductor/cancel/camp-to-cancel") assert resp.status_code == 200 assert resp.json()["cancelled"] is True assert cancel_event.is_set() @@ -403,7 +403,7 @@ class TestOrchestratorCancel: ) state_module._campaigns["done"] = state - resp = client.post("/orchestrator/cancel/done") + resp = client.post("/conductor/cancel/done") assert resp.status_code == 200 assert resp.json()["cancelled"] is False assert not cancel_event.is_set() diff --git a/tests/test_agent.py b/tests/test_agent.py new file mode 100644 index 0000000..67991f9 --- /dev/null +++ b/tests/test_agent.py @@ -0,0 +1,247 @@ +"""Tests for NodeAgent — TX role, session code, and TX command dispatch.""" + +from __future__ import annotations + +import threading +import time +from unittest.mock import MagicMock, patch + +from ria_toolkit_oss.agent import NodeAgent + + +def _agent(role="general", session_code=None, **kwargs): + return NodeAgent( + hub_url="http://hub.test", + api_key="test-key", + name="test-node", + sdr_device="mock", + role=role, + session_code=session_code, + **kwargs, + ) + + +def _mock_register(agent, node_id="node_abc123"): + """Patch _post so _register() returns a fake node_id response.""" + resp = MagicMock() + resp.json.return_value = {"node_id": node_id} + resp.raise_for_status.return_value = None + agent._post = MagicMock(return_value=resp) + return agent._post + + +# --------------------------------------------------------------------------- +# Initialisation +# --------------------------------------------------------------------------- + + +class TestNodeAgentInit: + def test_stores_role_general(self): + assert _agent(role="general").role == "general" + + def test_stores_role_tx(self): + assert _agent(role="tx").role == "tx" + + def test_stores_role_rx(self): + assert _agent(role="rx").role == "rx" + + def test_session_code_stored(self): + assert _agent(session_code="amber-peak-transmit").session_code == "amber-peak-transmit" + + def test_session_code_none_by_default(self): + assert _agent().session_code is None + + def test_tx_stop_event_created(self): + a = _agent() + assert isinstance(a._tx_stop, threading.Event) + + def test_tx_thread_none_initially(self): + assert _agent()._tx_thread is None + + def test_hub_url_trailing_slash_stripped(self): + a = NodeAgent(hub_url="http://hub.test/", api_key="k", name="n") + assert a.hub_url == "http://hub.test" + + +# --------------------------------------------------------------------------- +# _register payload +# --------------------------------------------------------------------------- + + +class TestNodeAgentRegisterPayload: + def _payload(self, agent): + post = _mock_register(agent) + agent._register() + _, kwargs = post.call_args + return kwargs["json"] + + def test_general_role_in_payload(self): + payload = self._payload(_agent(role="general")) + assert payload["role"] == "general" + + def test_tx_role_in_payload(self): + payload = self._payload(_agent(role="tx")) + assert payload["role"] == "tx" + + def test_tx_role_adds_transmit_capability(self): + payload = self._payload(_agent(role="tx")) + assert "transmit" in payload["capabilities"] + + def test_general_role_omits_transmit_capability(self): + payload = self._payload(_agent(role="general")) + assert "transmit" not in payload.get("capabilities", []) + + def test_session_code_included_when_set(self): + payload = self._payload(_agent(role="tx", session_code="amber-peak-transmit")) + assert payload["session_code"] == "amber-peak-transmit" + + def test_session_code_omitted_when_none(self): + payload = self._payload(_agent()) + assert "session_code" not in payload + + def test_register_stores_returned_node_id(self): + a = _agent() + _mock_register(a, node_id="node_xyz999") + a._register() + assert a.node_id == "node_xyz999" + + def test_name_in_payload(self): + a = NodeAgent(hub_url="http://h", api_key="k", name="my-bench") + _mock_register(a) + a._register() + _, kwargs = a._post.call_args + assert kwargs["json"]["name"] == "my-bench" + + def test_sdr_device_in_payload(self): + a = _agent() + post = _mock_register(a) + a._register() + _, kwargs = post.call_args + assert kwargs["json"]["sdr_device"] == "mock" + + def test_campaign_capability_always_present(self): + for role in ("general", "rx", "tx"): + a = _agent(role=role) + payload = self._payload(a) + assert "campaign" in payload["capabilities"] + + +# --------------------------------------------------------------------------- +# _dispatch — TX commands +# --------------------------------------------------------------------------- + + +class TestNodeAgentDispatch: + def _make_agent(self): + a = _agent(role="tx") + a.node_id = "node_abc" + a._report_campaign_status = MagicMock() + return a + + def test_start_transmit_spawns_thread(self): + a = self._make_agent() + done = threading.Event() + + class _FakeExecutor: + def run(self_): + done.wait(timeout=2) + + with patch("ria_toolkit_oss.orchestration.tx_executor.TxExecutor", return_value=_FakeExecutor()): + a._dispatch({"command": "start_transmit", "sdr_agent": {}, "schedule": []}) + time.sleep(0.05) + assert a._tx_thread is not None + done.set() + + def test_start_transmit_clears_stop_event(self): + a = self._make_agent() + a._tx_stop.set() # pre-set + + done = threading.Event() + + class _FakeExecutor: + def run(self_): + done.wait(timeout=2) + + with patch("ria_toolkit_oss.orchestration.tx_executor.TxExecutor", return_value=_FakeExecutor()): + a._dispatch({"command": "start_transmit", "sdr_agent": {}, "schedule": []}) + time.sleep(0.05) + assert not a._tx_stop.is_set() + done.set() + + def test_stop_transmit_sets_stop_event(self): + a = self._make_agent() + a._dispatch({"command": "stop_transmit"}) + assert a._tx_stop.is_set() + + def test_configure_transmit_does_not_raise(self): + a = self._make_agent() + a._dispatch({"command": "configure_transmit", "modulation": "BPSK"}) + + def test_unknown_command_is_ignored(self): + a = self._make_agent() + a._dispatch({"command": "frobnicate_xyz"}) + + def test_duplicate_start_transmit_ignored_while_running(self): + a = self._make_agent() + done = threading.Event() + run_calls = [] + + class _FakeExecutor: + def run(self_): + run_calls.append(1) + done.wait(timeout=2) + + with patch("ria_toolkit_oss.orchestration.tx_executor.TxExecutor", return_value=_FakeExecutor()): + a._dispatch({"command": "start_transmit"}) + time.sleep(0.05) + a._dispatch({"command": "start_transmit"}) # second while first alive + done.set() + time.sleep(0.05) + + assert len(run_calls) == 1 + + def test_run_campaign_dispatched_in_thread(self): + a = self._make_agent() + done = threading.Event() + + with patch("ria_toolkit_oss.agent.NodeAgent._run_campaign") as mock_run: + mock_run.side_effect = lambda *_: done.set() + a._dispatch({"command": "run_campaign", "campaign_id": "c1", "payload": {}}) + done.wait(timeout=2) + assert mock_run.called + + +# --------------------------------------------------------------------------- +# _stop_transmit +# --------------------------------------------------------------------------- + + +class TestStopTransmit: + def test_no_thread_noop(self): + a = _agent() + a._stop_transmit() # must not raise + + def test_sets_stop_event(self): + a = _agent() + a._stop_transmit() + assert a._tx_stop.is_set() + + def test_joins_live_thread(self): + a = _agent() + finished = threading.Event() + unblock = threading.Event() + + def _task(): + unblock.wait(timeout=2) + finished.set() + + t = threading.Thread(target=_task, daemon=True) + t.start() + a._tx_thread = t + + # Signal stop and trigger thread exit + a._tx_stop.set() + unblock.set() + a._stop_transmit() + + assert not t.is_alive()