Compare commits

..

No commits in common. "main" and "screens-connection" have entirely different histories.

132 changed files with 1213 additions and 12271 deletions

1
.gitignore vendored
View File

@ -52,7 +52,6 @@ tests/sdr/
# Sphinx documentation # Sphinx documentation
docs/build/ docs/build/
docs/_build/
# Jupyter Notebook # Jupyter Notebook
.ipynb_checkpoints .ipynb_checkpoints

View File

@ -1,21 +1,5 @@
# Changelog # Changelog
## [0.1.0] - 2026-02-20
### Added
- **Dual-Threshold Detection:** Logic to capture the start and end of signals, not just the peak.
- **Signal Smoothing & Noise Filters:** Prevents detections from breaking into fragments and ignores short interference spikes.
- **Auto-Frequency Calculation:** Automatically adjusts bounding boxes to fit signal frequency ranges tightly.
### Changed
- **Signal Power Detection:** Switched from raw signal strength to power for improved accuracy.
- **CLI Workflow:** `Clear` and `Remove` commands now modify files directly (in-place) to avoid redundant copies.
- **Metadata Logic:** Updated labels to show detection percentages and overhauled internal metadata cleaning.
- **Viewer UI:** Moved legend outside the plot, added a black background, and adjusted transparency for better spectrogram visibility.
### Fixed
- Prevented redundant `_annotated` suffixes in file naming patterns.
- Simplified internal math to increase processing speed and precision.
All notable changes to this project will be documented in this file. All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) and [Semantic Versioning](https://semver.org/spec/v2.0.0.html). The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) and [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

View File

@ -159,7 +159,7 @@ Finally, RIA Toolkit OSS can be installed directly from the source code. This ap
Once the project is installed, you can import modules, functions, and classes from the Toolkit for use in your Python code. For example, you can use the following import statement to access the `Recording` object: Once the project is installed, you can import modules, functions, and classes from the Toolkit for use in your Python code. For example, you can use the following import statement to access the `Recording` object:
```python ```python
from ria_toolkit_oss.data import Recording from ria_toolkit_oss.datatypes import Recording
``` ```
Additional usage information is provided in the project documentation: [RIA Toolkit OSS Documentation](https://ria-toolkit-oss.readthedocs.io/). Additional usage information is provided in the project documentation: [RIA Toolkit OSS Documentation](https://ria-toolkit-oss.readthedocs.io/).

File diff suppressed because it is too large Load Diff

View File

@ -1,29 +0,0 @@
/* Change the hex values below to customize heading colours */
.rst-content h1 { color: #2c3e50; }
.rst-content h2,
.rst-content h2 a { color: #ffffff !important; font-size: 22px !important; }
.rst-content h3,
.rst-content h3 a { color: #ffffff !important; font-size: 16px !important; }
.rst-content h3 code { font-size: inherit !important; }
.rst-content .admonition.warning {
background: #1a1a2e !important;
border-left: 4px solid #c0392b !important;
}
.rst-content .admonition.warning .admonition-title {
background: #c0392b !important;
color: #ffffff !important;
}
.rst-content .admonition.warning p {
color: #ffffff !important;
}
.rst-content h4 { color: #404040; }
.highlight * { color: #ffffff !important; }
.ria-cmd { color: #2980b9 !important; }

View File

@ -1,8 +0,0 @@
document.addEventListener('DOMContentLoaded', function () {
document.querySelectorAll('.highlight pre').forEach(function (pre) {
pre.innerHTML = pre.innerHTML.replace(
/((?:^|\n|>))(ria)(?=[ \t]|<)/g,
'$1<span class="ria-cmd">$2</span>'
);
});
});

View File

@ -12,9 +12,9 @@ sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = 'ria-toolkit-oss' project = 'ria-toolkit-oss'
copyright = '2026, Qoherent Inc' copyright = '2025, Qoherent Inc'
author = 'Qoherent Inc.' author = 'Qoherent Inc.'
release = '0.1.5' release = '0.1.4'
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
@ -73,6 +73,3 @@ def setup(app):
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'sphinx_rtd_theme' html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']
html_css_files = ['custom.css']
html_js_files = ['custom.js']

View File

@ -1,4 +1,4 @@
.. _sdr_examples: .. _examples:
############ ############
SDR Examples SDR Examples

View File

@ -25,7 +25,7 @@ In this example, we initialize the `Blade` SDR, configure it to record a signal
import time import time
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.sdr.blade import Blade from ria_toolkit_oss.sdr.blade import Blade
my_radio = Blade() my_radio = Blade()

View File

@ -21,7 +21,7 @@ Code
import numpy as np import numpy as np
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.sdr.blade import Blade from ria_toolkit_oss.sdr.blade import Blade
# Parameters # Parameters

File diff suppressed because it is too large Load Diff

View File

@ -11,15 +11,15 @@ The Radio Dataset Framework provides a software interface to access and manipula
the need for users to interface with the source files directly. Instead, users initialize and interact with a Python the need for users to interface with the source files directly. Instead, users initialize and interact with a Python
object, while the complexities of efficient data retrieval and source file manipulation are managed behind the scenes. object, while the complexities of efficient data retrieval and source file manipulation are managed behind the scenes.
Ria Toolkit OSS includes an abstract class called :py:obj:`ria_toolkit_oss.data.datasets.RadioDataset`, which defines common properties and Utils includes an abstract class called :py:obj:`ria_toolkit_oss.datatypes.datasets.RadioDataset`, which defines common properties and
behaviors for all radio datasets. :py:obj:`ria_toolkit_oss.data.datasets.RadioDataset` can be considered a blueprint for all behaviors for all radio datasets. :py:obj:`ria_toolkit_oss.datatypes.datasets.RadioDataset` can be considered a blueprint for all
other radio dataset classes. This class is then subclassed to define more specific blueprints for different types other radio dataset classes. This class is then subclassed to define more specific blueprints for different types
of radio datasets. For example, :py:obj:`ria_toolkit_oss.data.datasets.IQDataset`, which is tailored for machine learning tasks of radio datasets. For example, :py:obj:`ria_toolkit_oss.datatypes.datasets.IQDataset`, which is tailored for machine learning tasks
involving the processing of signals represented as IQ (In-phase and Quadrature) samples. involving the processing of signals represented as IQ (In-phase and Quadrature) samples.
Then, in the various project backends, there are concrete dataset classes, which inherit from both Ria Toolkit OSS and the base Then, in the various project backends, there are concrete dataset classes, which inherit from both Utils and the base
dataset class from the respective backend. For example, the :py:obj:`TorchIQDataset` class extends both dataset class from the respective backend. For example, the :py:obj:`TorchIQDataset` class extends both
:py:obj:`ria_toolkit_oss.data.datasets.IQDataset` from Ria Toolkit OSS and :py:obj:`torch.ria_toolkit_oss.data.IterableDataset` from :py:obj:`ria_toolkit_oss.datatypes.datasets.IQDataset` from Utils and :py:obj:`torch.ria_toolkit_oss.datatypes.IterableDataset` from
PyTorch, providing a concrete dataset class tailored for IQ datasets and optimized for the PyTorch backend. PyTorch, providing a concrete dataset class tailored for IQ datasets and optimized for the PyTorch backend.
Dataset initialization Dataset initialization
@ -130,7 +130,7 @@ Dataset processing and manipulation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
All radio datasets support methods tailored specifically for radio processing. These methods are backend-independent, All radio datasets support methods tailored specifically for radio processing. These methods are backend-independent,
inherited from the blueprints in Ria Toolkit OSS like :py:obj:`ria_toolkit_oss.data.datasets.RadioDataset`. inherited from the blueprints in Utils like :py:obj:`ria_toolkit_oss.datatypes.datasets.RadioDataset`.
For example, we can trim down the length of the examples from 1,024 to 512 samples, and then augment the dataset: For example, we can trim down the length of the examples from 1,024 to 512 samples, and then augment the dataset:

View File

@ -1,7 +1,7 @@
Dataset License SubModule Dataset License SubModule
========================= =========================
.. automodule:: ria_toolkit_oss.data.datasets.license .. automodule:: ria_toolkit_oss.datatypes.datasets.license
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:

View File

@ -1,11 +1,11 @@
Datatypes Package (ria_toolkit_oss.data) Datatypes Package (ria_toolkit_oss.datatypes)
============================================= =============================================
.. |br| raw:: html .. |br| raw:: html
<br /> <br />
.. automodule:: ria_toolkit_oss.data .. automodule:: ria_toolkit_oss.datatypes
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
@ -13,7 +13,7 @@ Datatypes Package (ria_toolkit_oss.data)
Radio Dataset SubPackage Radio Dataset SubPackage
------------------------ ------------------------
.. automodule:: ria_toolkit_oss.data.datasets .. automodule:: ria_toolkit_oss.datatypes.datasets
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
@ -21,5 +21,5 @@ Radio Dataset SubPackage
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 2
Dataset License SubModule <ria_toolkit_oss.data.datasets.license> Dataset License SubModule <ria_toolkit_oss.datatypes.datasets.license>
Radio Datasets <radio_datasets> Radio Datasets <radio_datasets>

View File

@ -11,7 +11,7 @@ class and function signatures, and doctest examples where available.
:maxdepth: 2 :maxdepth: 2
:caption: Contents: :caption: Contents:
Data Package <data/ria_toolkit_oss.data> Datatypes Package <datatypes/ria_toolkit_oss.datatypes>
SDR Package <ria_toolkit_oss.sdr> SDR Package <ria_toolkit_oss.sdr>
IO Package <ria_toolkit_oss.io> IO Package <ria_toolkit_oss.io>
Transforms Package <ria_toolkit_oss.transforms> Transforms Package <ria_toolkit_oss.transforms>

View File

@ -40,36 +40,26 @@ Limitations
- USB 3.0 connectivity is required for optimal performance; using USB 2.0 will significantly limit data - USB 3.0 connectivity is required for optimal performance; using USB 2.0 will significantly limit data
transfer rates. transfer rates.
Set up instructions (Linux) Set up instructions (Linux, Radioconda)
--------------------------- ---------------------------------------
No additional Python packages are required for BladeRF beyond the base RIA Toolkit OSS installation. 1. Activate your Radioconda environment.
1. Install the system library:
.. code-block:: bash .. code-block:: bash
sudo apt install libbladerf-dev conda activate <your-env-name>
For a more complete installation including CLI tools and FPGA images, use the Nuand PPA: 2. Install the base dependencies and drivers (*Easy method*):
.. code-block:: bash .. code-block:: bash
sudo add-apt-repository ppa:nuandllc/bladerf sudo add-apt-repository ppa:nuandllc/bladerf
sudo apt-get update sudo apt-get update
sudo apt-get install bladerf libbladerf-dev sudo apt-get install bladerf
sudo apt-get install bladerf-fpga-hostedxa4 # Necessary for BladeRF 2.0 Micro xA4 sudo apt-get install libbladerf-dev
sudo apt-get install bladerf-fpga-hostedxa4 # Necessary for installation of bladeRF 2.0 Micro A4.
2. Install udev rules: 3. Install a ``udev`` rule by creating a link into your Radioconda installation:
For most users:
.. code-block:: bash
sudo udevadm control --reload
sudo udevadm trigger
For **Radioconda** users, create symlinks from your conda environment instead:
.. code-block:: bash .. code-block:: bash

View File

@ -39,28 +39,23 @@ Limitations
- Bandwidth is limited to 20 MHz. - Bandwidth is limited to 20 MHz.
- USB 2.0 connectivity might limit data transfer rates compared to USB 3.0 or Ethernet-based SDRs. - USB 2.0 connectivity might limit data transfer rates compared to USB 3.0 or Ethernet-based SDRs.
Set up instructions (Linux) Set up instructions (Linux, Radioconda)
--------------------------- ---------------------------------------
HackRF is supported out of the box after installing RIA Toolkit OSS. 1. Activate your Radioconda environment:
1. Ensure ``libhackrf`` is installed at the system level. On most Ubuntu installations this is already
present. If not:
.. code-block:: bash .. code-block:: bash
sudo apt install libhackrf-dev conda activate <your-env-name>
2. Install udev rules to allow non-root device access: 2. Install the System Package (Ubuntu / Debian):
For most users:
.. code-block:: bash .. code-block:: bash
sudo udevadm control --reload sudo apt-get update
sudo udevadm trigger sudo apt-get install hackrf
For **Radioconda** users, create a symlink from your conda environment instead: 3. Install a ``udev`` rule by creating a link into your Radioconda installation:
.. code-block:: bash .. code-block:: bash
@ -68,7 +63,7 @@ HackRF is supported out of the box after installing RIA Toolkit OSS.
sudo udevadm control --reload sudo udevadm control --reload
sudo udevadm trigger sudo udevadm trigger
Make sure your user account belongs to the ``plugdev`` group in order to access your device: Make sure your user account belongs to the plugdev group in order to access your device:
.. code-block:: bash .. code-block:: bash
@ -76,7 +71,7 @@ HackRF is supported out of the box after installing RIA Toolkit OSS.
.. note:: .. note::
You may have to restart your system for group membership changes to take effect. You may have to restart your system for changes to take effect.
Further information Further information
------------------- -------------------

View File

@ -43,34 +43,34 @@ Limitations
affect stability. affect stability.
- USB 2.0 connectivity might limit data transfer rates compared to USB 3.0 or Ethernet-based SDRs. - USB 2.0 connectivity might limit data transfer rates compared to USB 3.0 or Ethernet-based SDRs.
Set up instructions (Linux) Set up instructions (Linux, Radioconda)
--------------------------- ---------------------------------------
The PlutoSDR is supported out of the box after installing RIA Toolkit OSS. The required Python package 1. Activate your Radioconda environment:
(``pyadi-iio``) is included in the toolkit's dependencies.
1. Ensure ``libiio`` is installed at the system level. On most Ubuntu installations this is already present.
If not:
.. code-block:: bash .. code-block:: bash
sudo apt install libiio-dev libiio-utils libiio0 conda activate <your-env-name>
.. note:: 2. Install system dependencies:
PlutoSDR devices are discoverable over both USB and network (mDNS). Network discovery uses Avahi — if
``avahi-daemon`` is not running, network discovery will be skipped but USB discovery still works.
2. Install a ``udev`` rule to allow non-root device access:
For most users:
.. code-block:: bash .. code-block:: bash
sudo udevadm control --reload sudo apt-get update
sudo udevadm trigger sudo apt-get install -y \
build-essential \
git \
libxml2-dev \
bison \
flex \
libcdk5-dev \
cmake \
libusb-1.0-0-dev \
libavahi-client-dev \
libavahi-common-dev \
libaio-dev
For **Radioconda** users, create a symlink from your conda environment instead: 3. Install a ``udev`` rule by creating a link into your Radioconda installation:
.. code-block:: bash .. code-block:: bash
@ -78,18 +78,11 @@ The PlutoSDR is supported out of the box after installing RIA Toolkit OSS. The r
sudo udevadm control --reload sudo udevadm control --reload
sudo udevadm trigger sudo udevadm trigger
Once you can communicate with the hardware, you may want to perform the post-install steps detailed on Once you can talk to the hardware, you may want to perform the post-install steps detailed on the `PlutoSDR Documentation <https://wiki.analog.com/university/tools/pluto>`_.
the `PlutoSDR Documentation <https://wiki.analog.com/university/tools/pluto>`_.
3. (Optional) Building ``libiio`` or ``libad9361-iio`` from source: 4. (Optional) Building ``libiio`` or ``libad9361-iio`` from source:
This step is only required if you need a version not available via ``apt``. First install build This step is only required if you want the latest version of these libraries not provided in Radioconda.
dependencies:
.. code-block:: bash
sudo apt-get install -y build-essential git libxml2-dev bison flex libcdk5-dev cmake \
libusb-1.0-0-dev libavahi-client-dev libavahi-common-dev libaio-dev
.. code-block:: bash .. code-block:: bash

View File

@ -30,10 +30,18 @@ Limitations
- Sensitivity and performance can vary depending on the specific model and components. - Sensitivity and performance can vary depending on the specific model and components.
- Requires external software for signal processing and analysis. - Requires external software for signal processing and analysis.
Set up instructions (Linux) Set up instructions (Linux, Radioconda)
--------------------------- ---------------------------------------
1. If you previously had RTL-SDR drivers installed, purge them first: 1. Activate your Radioconda environment:
.. code-block:: bash
conda activate <your-env-name>
2. Purge drivers:
If you already have other drivers installed, purge them from your system.
.. code-block:: bash .. code-block:: bash
@ -45,95 +53,47 @@ Set up instructions (Linux)
sudo rm -rvf /usr/local/include/rtl_* sudo rm -rvf /usr/local/include/rtl_*
sudo rm -rvf /usr/local/bin/rtl_* sudo rm -rvf /usr/local/bin/rtl_*
2. Install build dependencies: 3. Install RTL-SDR Blog drivers:
.. code-block:: bash .. code-block:: bash
sudo apt install libusb-1.0-0-dev git cmake pkg-config build-essential sudo apt-get install libusb-1.0-0-dev git cmake pkg-config build-essential
git clone https://github.com/osmocom/rtl-sdr
3. Build ``librtlsdr`` from source: cd rtl-sdr
mkdir build
The standard ``librtlsdr`` package available via ``apt`` is missing symbols required by the Python cd build
bindings. Build from the **rtl-sdr-blog fork**: cmake ../ -DINSTALL_UDEV_RULES=ON
.. code-block:: bash
git clone https://github.com/rtlsdrblog/rtl-sdr-blog.git
cd rtl-sdr-blog
mkdir build && cd build
cmake .. -DINSTALL_UDEV_RULES=ON
make make
sudo make install sudo make install
sudo cp ../rtl-sdr.rules /etc/udev/rules.d/ sudo cp ../rtl-sdr.rules /etc/udev/rules.d/
sudo ldconfig sudo ldconfig
.. important:: 4. Blacklist the DVB-T modules that would otherwise claim the device:
Do not use the osmocom ``rtl-sdr`` repository or the Ubuntu ``librtlsdr-dev`` apt package. Neither
provides the ``rtlsdr_set_dithering`` symbol that the Python bindings require.
4. Blacklist the kernel DVB driver:
The kernel DVB-T driver (``dvb_usb_rtl28xxu``) claims the RTL-SDR device and prevents ``librtlsdr``
from accessing it.
For most users:
.. code-block:: bash .. code-block:: bash
echo 'blacklist dvb_usb_rtl28xxu' | sudo tee /etc/modprobe.d/blacklist-rtlsdr.conf
sudo modprobe -r dvb_usb_rtl28xxu
For **Radioconda** users, a blacklist configuration is already provided in your conda environment:
.. code-block:: bash
sudo ln -s $CONDA_PREFIX/etc/modprobe.d/rtl-sdr-blacklist.conf /etc/modprobe.d/radioconda-rtl-sdr-blacklist.conf sudo ln -s $CONDA_PREFIX/etc/modprobe.d/rtl-sdr-blacklist.conf /etc/modprobe.d/radioconda-rtl-sdr-blacklist.conf
sudo modprobe -r $(cat $CONDA_PREFIX/etc/modprobe.d/rtl-sdr-blacklist.conf | sed -n -e 's/^blacklist //p') sudo modprobe -r $(cat $CONDA_PREFIX/etc/modprobe.d/rtl-sdr-blacklist.conf | sed -n -e 's/^blacklist //p')
If ``modprobe -r`` fails with "Module is in use", unplug the RTL-SDR dongle, run the command again, .. note::
then plug it back in. Alternatively, reboot — the blacklist takes effect on next boot.
.. note:: In addition to the Radioconda blacklist file, some systems also require
manually blacklisting the following DVB-T modules to prevent them from
Some systems also require blacklisting additional DVB-T modules. Add these entries to your claiming the device:
blacklist configuration if needed:
- ``dvb_usb_rtl28xxu``
- ``rtl2832`` - ``rtl2832``
- ``rtl2830`` - ``rtl2830``
5. Reload udev rules: Add these entries to ``rtlsdr.conf`` (or create the file at
``/etc/modprobe.d/rtlsdr.conf``) if they are not already present.
For most users (rules are installed by the build step above): 5. Install a udev rule by creating a link into your radioconda installation:
.. code-block:: bash .. code-block:: bash
sudo udevadm control --reload
sudo udevadm trigger
For **Radioconda** users, create a symlink from your conda environment instead:
.. code-block:: bash
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/rtl-sdr.rules /etc/udev/rules.d/radioconda-rtl-sdr.rules sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/rtl-sdr.rules /etc/udev/rules.d/radioconda-rtl-sdr.rules
sudo udevadm control --reload sudo udevadm control --reload
sudo udevadm trigger sudo udevadm trigger
6. Install Python packages:
.. code-block:: bash
pip install pyrtlsdr==0.3.0
pip install setuptools==69.5.1
.. note::
``pyrtlsdr`` 0.4.0 references a ``rtlsdr_set_dithering`` symbol not present in standard
``librtlsdr`` builds. Version 0.3.0 works correctly.
``pyrtlsdr`` 0.3.0 depends on ``pkg_resources``, which was removed in ``setuptools`` >= 82.
Pinning to 69.5.1 ensures ``pkg_resources`` is available.
Further Information Further Information
------------------- -------------------
- `RTL-SDR Official Website <https://www.rtl-sdr.com/>`_ - `RTL-SDR Official Website <https://www.rtl-sdr.com/>`_

View File

@ -39,48 +39,18 @@ Limitations
Set up instructions (Linux) Set up instructions (Linux)
--------------------------------- ---------------------------------
ThinkRF devices require the ``pyrf`` package, which is written in Python 2 syntax and must be patched Install PyRF
after installation to work with Python 3.
.. note::
``lib2to3`` was fully removed in Python 3.13. ThinkRF support is currently limited to
**Python 3.12 and below**.
1. Install ``lib2to3``:
On some distributions (including Ubuntu 24.04+), ``lib2to3`` is not included by default:
.. code-block:: bash .. code-block:: bash
sudo apt install python3-lib2to3 pip install 'pyrf>=2.8.0'
2. Install ``pyrf``: Convert PyRF scripts to Python 3
.. code-block:: bash .. code-block:: bash
pip install pyrf cd ../scripts
./convert_pyrf_to_python3.sh
3. Patch ``pyrf`` for Python 3:
The ``pyrf`` package contains Python 2 syntax throughout (e.g., ``dict.iteritems()``, ``print``
statements). Run the following to automatically convert the entire package to Python 3:
.. code-block:: bash
python -c "
from lib2to3.refactor import RefactoringTool, get_fixers_from_package
import pyrf, os
pyrf_path = os.path.dirname(pyrf.__file__)
fixers = get_fixers_from_package('lib2to3.fixes')
tool = RefactoringTool(fixers)
tool.refactor_dir(pyrf_path, write=True)
print('Done')
"
.. note::
This patches the entire ``pyrf`` package in place, which is required for the driver to fully load.
Further Information Further Information
------------------- -------------------

View File

@ -41,97 +41,34 @@ Limitations
- Compatibility with certain software tools may vary depending on the version of the UHD. - Compatibility with certain software tools may vary depending on the version of the UHD.
- Price range can be a consideration, especially for high-end models. - Price range can be a consideration, especially for high-end models.
Set up instructions (Linux) Set up instructions (Linux, Radioconda)
--------------------------- ---------------------------------------
USRP devices require the UHD (USRP Hardware Driver) library with Python bindings. There is no pip-installable 1. Activate your Radioconda environment:
UHD package — it must either be installed via conda or built from source.
**Option A: Install via conda (recommended for conda environments)** .. code-block:: bash
conda activate <your-env-name>
2. Install UHD and Python bindings:
.. code-block:: bash .. code-block:: bash
conda install conda-forge::uhd conda install conda-forge::uhd
**Option B: Build from source (required for pip/venv environments)** 3. Download UHD images:
The Python bindings must target the same Python version used in your virtual environment.
1. Install build dependencies:
.. code-block:: bash
sudo apt install cmake build-essential libboost-all-dev libusb-1.0-0-dev \
python3-dev python3-numpy libncurses-dev
2. Install the Mako template library into your virtual environment (used by UHD's build system):
.. code-block:: bash
pip install mako
3. Clone and build UHD with your virtual environment activated:
.. code-block:: bash
git clone https://github.com/EttusResearch/uhd.git
cd uhd
git checkout v4.7.0.0
cd host
mkdir build && cd build
cmake -DENABLE_PYTHON_API=ON -DPYTHON_EXECUTABLE=$(which python3) ..
make -j$(nproc)
sudo make install
sudo ldconfig
.. important::
Run the ``cmake`` command with your virtual environment activated so ``$(which python3)`` points
to the correct interpreter. Before running ``make``, verify the cmake output includes::
-- * LibUHD - Python API → must say "Enabling"
-- Python interpreter: .../your-venv/bin/python3
If "LibUHD - Python API" is not listed under enabled components, the Python bindings will not be
built. The build typically takes 1030 minutes.
4. Copy the Python bindings into your virtual environment if ``import uhd`` fails after installation:
.. code-block:: bash
cp -r ~/uhd/host/build/python/uhd ~/.venv/lib/python3.XX/site-packages/
Replace ``python3.XX`` with your Python version (e.g., ``python3.12``).
.. note::
If you have a pre-existing UHD installation built against a different Python version, you will see
a circular import error. The bindings must match the Python version in your virtual environment exactly.
**After either installation method:**
1. Download UHD FPGA/firmware images:
.. code-block:: bash .. code-block:: bash
uhd_images_downloader uhd_images_downloader
2. Verify device access: 4. Verify access to your device:
.. code-block:: bash .. code-block:: bash
uhd_find_devices uhd_find_devices
For USB devices (e.g. B-series), install a ``udev`` rule. For USB devices only (e.g. B series), install a ``udev`` rule by creating a link into your Radioconda installation.
For most users:
.. code-block:: bash
sudo udevadm control --reload
sudo udevadm trigger
For **Radioconda** users, create a symlink from your conda environment instead:
.. code-block:: bash .. code-block:: bash
@ -139,7 +76,7 @@ UHD package — it must either be installed via conda or built from source.
sudo udevadm control --reload sudo udevadm control --reload
sudo udevadm trigger sudo udevadm trigger
3. (Optional) Update firmware/FPGA images: 5. (Optional) Update firmware/FPGA images:
.. code-block:: bash .. code-block:: bash

911
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ria-toolkit-oss" name = "ria-toolkit-oss"
version = "0.1.5" version = "0.1.4"
description = "An open-source version of the RIA Toolkit, including the fundamental tools to get started developing, testing, and deploying radio intelligence applications" description = "An open-source version of the RIA Toolkit, including the fundamental tools to get started developing, testing, and deploying radio intelligence applications"
license = { text = "AGPL-3.0-only" } license = { text = "AGPL-3.0-only" }
readme = "README.md" readme = "README.md"
@ -49,8 +49,7 @@ dependencies = [
"pyzmq (>=27.1.0,<28.0.0)", "pyzmq (>=27.1.0,<28.0.0)",
"pyyaml (>=6.0.3,<7.0.0)", "pyyaml (>=6.0.3,<7.0.0)",
"click (>=8.1.0,<9.0.0)", "click (>=8.1.0,<9.0.0)",
"matplotlib (>=3.8.0,<4.0.0)", "matplotlib (>=3.8.0,<4.0.0)"
"paramiko (>=3.5.1)"
] ]
# [project.optional-dependencies] Commented out to prevent Tox tests from failing # [project.optional-dependencies] Commented out to prevent Tox tests from failing
@ -88,7 +87,7 @@ pytest = "^8.0.0"
tox = "^4.19.0" tox = "^4.19.0"
fastapi = ">=0.111,<1.0" fastapi = ">=0.111,<1.0"
uvicorn = {version = ">=0.29,<1.0", extras = ["standard"]} uvicorn = {version = ">=0.29,<1.0", extras = ["standard"]}
onnxruntime = {version = ">=1.17,<2.0", python = ">=3.11"} onnxruntime = ">=1.17,<2.0"
httpx = ">=0.27,<1.0" httpx = ">=0.27,<1.0"
[tool.poetry.group.docs.dependencies] [tool.poetry.group.docs.dependencies]
@ -119,12 +118,11 @@ ria = "ria_toolkit_oss_cli.cli:cli"
ria-tools = "ria_toolkit_oss_cli.cli:cli" ria-tools = "ria_toolkit_oss_cli.cli:cli"
ria-server = "ria_toolkit_oss.server.cli:serve" ria-server = "ria_toolkit_oss.server.cli:serve"
ria-agent = "ria_toolkit_oss.agent.cli:main" ria-agent = "ria_toolkit_oss.agent.cli:main"
ria-app = "ria_toolkit_oss.app.cli:main"
[tool.poetry.group.server.dependencies] [tool.poetry.group.server.dependencies]
fastapi = ">=0.111,<1.0" fastapi = ">=0.111,<1.0"
uvicorn = {version = ">=0.29,<1.0", extras = ["standard"]} uvicorn = {version = ">=0.29,<1.0", extras = ["standard"]}
onnxruntime = {version = ">=1.17,<2.0", python = ">=3.11"} onnxruntime = ">=1.17,<2.0"
[tool.black] [tool.black]
line-length = 119 line-length = 119
@ -149,11 +147,6 @@ exclude = '''
[tool.pytest.ini_options] [tool.pytest.ini_options]
pythonpath = ["src"] pythonpath = ["src"]
filterwarnings = [
# FastAPI emits this internally when handling 422 responses; the constant
# is not yet renamed in the installed starlette version, so we can't migrate.
"ignore:'HTTP_422_UNPROCESSABLE_ENTITY' is deprecated:DeprecationWarning",
]
[tool.isort] [tool.isort]
profile = "black" profile = "black"

View File

@ -1,225 +0,0 @@
#!/usr/bin/env python3
"""Transmit a continuous tone through the agent's TX pipeline on a real Pluto.
End-to-end smoke test for the Pluto + Streamer TX path. Drives the same
``Streamer`` the hub talks to, but in-process with a logging ``FakeWs`` so
the script is self-contained no hub required.
Default: 100 kHz baseband tone × 2 450 MHz LO carrier at 2 450.1 MHz,
continuous until you Ctrl-C (or the ``--duration`` timer fires). A spectrum
analyzer tuned to 2 450.1 MHz should show a clean CW spike as long as
``tx_status: transmitting`` prints.
Usage::
python3 scripts/pluto_tx_smoke.py # auto-discover Pluto
python3 scripts/pluto_tx_smoke.py --identifier 192.168.3.1
python3 scripts/pluto_tx_smoke.py --frequency 2.4e9 --gain -20 --duration 60
Flags map 1:1 onto the agent's ``radio_config``:
--identifier Pluto IP or hostname (omitted ip:pluto.local).
--frequency TX LO in Hz. Default 2 450 MHz.
--gain Pluto TX gain in dB. Pluto range is ``[-89, 0]``; more negative
= more attenuation = less power. Default -30.
--sample-rate Baseband sample rate. Default 1 MHz.
--tone Baseband tone offset in Hz. Default 100 kHz; set 0 for DC
(unmodulated carrier at exactly --frequency, but Pluto's
LO leakage will dominate).
--buffer-size Complex samples per WS frame. Default 4096.
--duration Stop after this many seconds (0 = run until Ctrl-C).
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import signal
import sys
import numpy as np
from ria_toolkit_oss.agent.config import AgentConfig
from ria_toolkit_oss.agent.streamer import Streamer
class LoggingFakeWs:
"""In-process stand-in for the hub's WebSocket.
Prints every ``tx_status`` + ``error`` frame the Streamer emits so the
operator can watch the lifecycle (armed transmitting done) on stdout.
"""
async def send_json(self, payload: dict) -> None:
t = payload.get("type")
if t == "tx_status":
state = payload.get("state")
msg = payload.get("message")
tail = f"{msg}" if msg else ""
print(f"[tx_status] {state}{tail}")
elif t == "error":
print(f"[error] {payload.get('message')}")
async def send_bytes(self, data: bytes) -> None:
# Agent side won't send RX bytes in this script (no RX session).
pass
def _make_iq_frame(
buffer_size: int, tone_hz: float, sample_rate: float, phase_offset: float = 0.0
) -> tuple[bytes, float]:
"""Return ``(interleaved_float32_bytes, next_phase)`` for a sine tone.
Emitting one continuous phase-coherent tone requires threading the phase
across frames; the returned ``next_phase`` should be fed back as
``phase_offset`` on the next call so the sinusoid doesn't glitch at frame
boundaries. Amplitude is 0.7 to leave some headroom below the [-1, 1] cap
that ``_verify_sample_format`` polices elsewhere in the toolkit.
"""
n = np.arange(buffer_size, dtype=np.float64)
phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset
amp = 0.7
iq = amp * (np.cos(phase) + 1j * np.sin(phase))
iq = iq.astype(np.complex64)
interleaved = np.empty(buffer_size * 2, dtype=np.float32)
interleaved[0::2] = iq.real
interleaved[1::2] = iq.imag
next_phase = (2.0 * np.pi * tone_hz / sample_rate * buffer_size + phase_offset) % (2.0 * np.pi)
return interleaved.tobytes(), next_phase
def _make_pluto_factory(identifier: str | None):
def factory(device: str, _ident: str | None):
if device != "pluto":
raise ValueError(f"this script only drives pluto; got device={device!r}")
from ria_toolkit_oss.sdr.pluto import Pluto
return Pluto(identifier=identifier)
return factory
async def _run(args: argparse.Namespace) -> int:
ws = LoggingFakeWs()
cfg = AgentConfig(
tx_enabled=True,
# Pluto's TX gain range is [-89, 0]. Cap at 0 so a fat-fingered
# --gain=+5 still gets rejected at the agent boundary rather than
# turned into mystery attenuation by Pluto's setter.
tx_max_gain_db=0.0,
tx_max_duration_s=float(args.duration) if args.duration > 0 else None,
)
streamer = Streamer(ws=ws, sdr_factory=_make_pluto_factory(args.identifier), cfg=cfg)
await streamer.on_message(
{
"type": "tx_start",
"app_id": "smoke",
"radio_config": {
"device": "pluto",
"identifier": args.identifier,
"tx_sample_rate": int(args.sample_rate),
"tx_center_frequency": int(args.frequency),
"tx_gain": int(args.gain),
"buffer_size": int(args.buffer_size),
# "repeat" keeps the last buffer on the air if we ever stall,
# so a continuous carrier stays up even when Python GC or
# asyncio scheduling briefly pauses the producer.
"underrun_policy": "repeat",
},
}
)
# Abort if tx_start was rejected by an interlock (no session → nothing to do).
if streamer._tx is None:
print("tx_start rejected — see [tx_status] line above for the reason.", file=sys.stderr)
return 2
print(
f"Transmitting at {args.frequency/1e6:.3f} MHz with "
f"{args.tone/1e3:.1f} kHz baseband tone at gain {args.gain} dB. "
f"{'Running for ' + str(args.duration) + 's' if args.duration > 0 else 'Run until Ctrl-C'}."
)
# Arrange a clean shutdown on Ctrl-C.
stop = asyncio.Event()
loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
try:
loop.add_signal_handler(sig, stop.set)
except NotImplementedError:
# add_signal_handler is not available on Windows event loops.
pass
# Produce buffers at the nominal sample-rate pace. We deliberately stay
# slightly ahead of the radio — queue is bounded at 8, so backpressure
# flows naturally.
phase = 0.0
buffer_dt = args.buffer_size / args.sample_rate
# Aim for one buffer every ``buffer_dt * 0.5`` seconds so the queue stays
# topped up. The queue's own backpressure keeps us from spinning.
produce_interval = buffer_dt * 0.5
try:
async def producer():
nonlocal phase
while not stop.is_set():
frame, phase = _make_iq_frame(args.buffer_size, args.tone, args.sample_rate, phase)
await streamer.on_binary(frame)
await asyncio.sleep(produce_interval)
producer_task = asyncio.create_task(producer())
if args.duration > 0:
try:
await asyncio.wait_for(stop.wait(), timeout=args.duration)
except asyncio.TimeoutError:
pass
else:
await stop.wait()
stop.set()
producer_task.cancel()
try:
await producer_task
except (asyncio.CancelledError, Exception):
pass
finally:
await streamer.on_message({"type": "tx_stop", "app_id": "smoke"})
print("TX session closed.")
return 0
def main() -> int:
p = argparse.ArgumentParser(
description="End-to-end TX smoke test: agent → Pluto continuous tone.",
)
p.add_argument("--identifier", default=None, help="Pluto IP/hostname (default: auto-discover pluto.local)")
p.add_argument("--frequency", type=float, default=3_410_000_000.0, help="TX LO in Hz (default 2.45 GHz)")
p.add_argument("--gain", type=float, default=-0.0, help="TX gain in dB; Pluto range [-89, 0] (default -30)")
p.add_argument("--sample-rate", type=float, default=1_000_000.0, help="Baseband sample rate (default 1 Msps)")
p.add_argument(
"--tone", type=float, default=100_000.0, help="Baseband tone offset in Hz; 0 = DC (default 100 kHz)"
)
p.add_argument("--buffer-size", type=int, default=4096, help="Complex samples per frame (default 4096)")
p.add_argument(
"--duration", type=float, default=60.0, help="Seconds to transmit; 0 = run until Ctrl-C (default 30)"
)
p.add_argument("--log-level", default="INFO")
args = p.parse_args()
logging.basicConfig(
level=getattr(logging, args.log_level.upper(), logging.INFO),
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
try:
return asyncio.run(_run(args))
except KeyboardInterrupt:
return 130
if __name__ == "__main__":
sys.exit(main())

View File

@ -1,230 +0,0 @@
#!/usr/bin/env python3
"""Full-stack TX smoke test: localhost mock-hub → WS → agent → real Pluto.
Same radio output as ``pluto_tx_smoke.py`` (continuous tone at 2 450.1 MHz),
but drives the agent through the *real* WebSocket path instead of calling
handlers in-process. Proves that the hub-driven path behaves identically:
mock hub ws:// WsClient.run() Streamer.on_message
Streamer.on_binary
real Pluto
This is the most rigorous check short of pointing the real ``ria-agent stream``
at a live ria-hub. If a tone appears on the spectrum analyzer here but *not*
when ria-hub drives it, the fault is above the WS decoder (registration,
capability gate, TX operator, hub's binary-frame publisher); everything
downstream of ``ws.recv()`` is this script's code path.
Usage::
python3 scripts/pluto_tx_ws_smoke.py # default 30s tone
python3 scripts/pluto_tx_ws_smoke.py --identifier 192.168.3.1
python3 scripts/pluto_tx_ws_smoke.py --duration 0 # until Ctrl-C
"""
from __future__ import annotations
import argparse
import asyncio
import json
import logging
import signal
import sys
import numpy as np
import websockets
from ria_toolkit_oss.agent.config import AgentConfig
from ria_toolkit_oss.agent.streamer import Streamer
from ria_toolkit_oss.agent.ws_client import WsClient
def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float, phase_offset: float) -> tuple[bytes, float]:
n = np.arange(buffer_size, dtype=np.float64)
phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset
amp = 0.7
iq = (amp * (np.cos(phase) + 1j * np.sin(phase))).astype(np.complex64)
interleaved = np.empty(buffer_size * 2, dtype=np.float32)
interleaved[0::2] = iq.real
interleaved[1::2] = iq.imag
next_phase = (2.0 * np.pi * tone_hz / sample_rate * buffer_size + phase_offset) % (2.0 * np.pi)
return interleaved.tobytes(), next_phase
def _make_pluto_factory(identifier: str | None):
def factory(device: str, _ident: str | None):
if device != "pluto":
raise ValueError(f"this script only drives pluto; got device={device!r}")
from ria_toolkit_oss.sdr.pluto import Pluto
return Pluto(identifier=identifier)
return factory
async def _mock_hub_handler(ws, args, stop: asyncio.Event):
"""Server side of the WS. Sends tx_start, streams IQ, then tx_stop."""
# Drain the first heartbeat so the log is clean; we don't need to gate on
# it for a localhost smoke test.
try:
first = await asyncio.wait_for(ws.recv(), timeout=2.0)
if isinstance(first, str):
payload = json.loads(first)
if payload.get("type") == "heartbeat":
caps = payload.get("capabilities")
print(f"[mock-hub] agent heartbeat: capabilities={caps} " f"tx_enabled={payload.get('tx_enabled')}")
except asyncio.TimeoutError:
print("[mock-hub] warning: no heartbeat received in first 2s")
# Arm the agent's TX path.
await ws.send(
json.dumps(
{
"type": "tx_start",
"app_id": "ws-smoke",
"radio_config": {
"device": "pluto",
"identifier": args.identifier,
"tx_sample_rate": int(args.sample_rate),
"tx_center_frequency": int(args.frequency),
"tx_gain": int(args.gain),
"buffer_size": int(args.buffer_size),
"underrun_policy": "repeat",
},
}
)
)
print(f"[mock-hub] sent tx_start at {args.frequency/1e6:.3f} MHz, " f"gain={args.gain} dB")
# Producer: push IQ frames at a steady clip. Use a concurrent receiver so
# tx_status frames show up in real time rather than being queued behind
# the sends.
phase = 0.0
buffer_dt = args.buffer_size / args.sample_rate
async def receiver():
try:
while True:
msg = await ws.recv()
if isinstance(msg, str):
print(f"[mock-hub] ← {msg}")
except (websockets.ConnectionClosed, asyncio.CancelledError):
pass
recv_task = asyncio.create_task(receiver())
try:
deadline = None if args.duration <= 0 else (asyncio.get_event_loop().time() + args.duration)
while not stop.is_set():
if deadline is not None and asyncio.get_event_loop().time() >= deadline:
break
frame, phase = _make_iq_frame(args.buffer_size, args.tone, args.sample_rate, phase)
try:
await ws.send(frame)
except websockets.ConnectionClosed:
break
# Slightly ahead of real-time; WS backpressure handles the rest.
await asyncio.sleep(buffer_dt * 0.5)
finally:
try:
await ws.send(json.dumps({"type": "tx_stop", "app_id": "ws-smoke"}))
print("[mock-hub] sent tx_stop")
except websockets.ConnectionClosed:
pass
# Give the agent a moment to emit `tx_status: done` before we tear down.
await asyncio.sleep(0.3)
recv_task.cancel()
try:
await recv_task
except (asyncio.CancelledError, Exception):
pass
async def _run(args: argparse.Namespace) -> int:
stop = asyncio.Event()
loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
try:
loop.add_signal_handler(sig, stop.set)
except NotImplementedError:
pass
# Start the mock hub on a local port.
async def handler(ws):
try:
await _mock_hub_handler(ws, args, stop)
finally:
stop.set()
server = await websockets.serve(handler, "127.0.0.1", 0)
port = server.sockets[0].getsockname()[1]
print(f"[mock-hub] listening on ws://127.0.0.1:{port}")
# Run the agent — exactly as ``ria-agent stream`` would, just with a
# different URL and an in-memory AgentConfig instead of one loaded from
# ``~/.ria/agent.json``.
client = WsClient(
f"ws://127.0.0.1:{port}",
token="",
heartbeat_interval=5.0,
reconnect_pause=0.5,
)
streamer = Streamer(
ws=client,
sdr_factory=_make_pluto_factory(args.identifier),
cfg=AgentConfig(tx_enabled=True, tx_max_gain_db=0.0),
)
client_task = asyncio.create_task(
client.run(
on_message=streamer.on_message,
heartbeat=streamer.build_heartbeat,
on_binary=streamer.on_binary,
)
)
try:
await stop.wait()
finally:
client.stop()
client_task.cancel()
try:
await client_task
except (asyncio.CancelledError, Exception):
pass
server.close()
await server.wait_closed()
print("Done.")
return 0
def main() -> int:
p = argparse.ArgumentParser(
description="Full-stack TX smoke: localhost mock-hub → WS → agent → Pluto.",
)
p.add_argument("--identifier", default=None, help="Pluto IP/hostname (default: auto-discover pluto.local)")
p.add_argument("--frequency", type=float, default=2_450_000_000.0, help="TX LO in Hz (default 2.45 GHz)")
p.add_argument("--gain", type=float, default=0.0, help="TX gain in dB; Pluto range [-89, 0] (default 0)")
p.add_argument("--sample-rate", type=float, default=1_000_000.0, help="Baseband sample rate (default 1 Msps)")
p.add_argument("--tone", type=float, default=100_000.0, help="Baseband tone offset in Hz (default 100 kHz)")
p.add_argument("--buffer-size", type=int, default=4096, help="Complex samples per frame (default 4096)")
p.add_argument(
"--duration", type=float, default=30.0, help="Seconds to transmit; 0 = run until Ctrl-C (default 30)"
)
p.add_argument("--log-level", default="INFO")
args = p.parse_args()
logging.basicConfig(
level=getattr(logging, args.log_level.upper(), logging.INFO),
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
try:
return asyncio.run(_run(args))
except KeyboardInterrupt:
return 130
if __name__ == "__main__":
sys.exit(main())

View File

@ -5,8 +5,8 @@ Subcommands:
- ``ria-agent run [legacy args]`` legacy long-poll NodeAgent (unchanged). - ``ria-agent run [legacy args]`` legacy long-poll NodeAgent (unchanged).
- ``ria-agent stream`` new WebSocket-based IQ streamer. - ``ria-agent stream`` new WebSocket-based IQ streamer.
- ``ria-agent detect`` print SDR drivers whose modules import cleanly. - ``ria-agent detect`` print SDR drivers whose modules import cleanly.
- ``ria-agent register --hub URL --api-key KEY`` register with the hub and - ``ria-agent register --url URL --token TOKEN`` save credentials to
save credentials (and optional TX interlocks) to ``~/.ria/agent.json``. ``~/.ria/agent.json``.
Invoking ``ria-agent`` with no subcommand falls through to the legacy Invoking ``ria-agent`` with no subcommand falls through to the legacy
long-poll behavior for back-compatibility with existing deployments. long-poll behavior for back-compatibility with existing deployments.
@ -23,7 +23,6 @@ import sys
from . import config as _config from . import config as _config
from .hardware import available_devices from .hardware import available_devices
from .legacy_executor import main as _legacy_main from .legacy_executor import main as _legacy_main
from .namegen import generate_agent_name
_LEGACY_ALIASES = {"--hub", "--key", "--name", "--device", "--insecure", "--log-level", "--config"} _LEGACY_ALIASES = {"--hub", "--key", "--name", "--device", "--insecure", "--log-level", "--config"}
@ -43,8 +42,7 @@ def _cmd_register(args: argparse.Namespace) -> int:
hub_url = args.hub.rstrip("/") hub_url = args.hub.rstrip("/")
url = f"{hub_url}/screens/agents/register" url = f"{hub_url}/screens/agents/register"
name = args.name or generate_agent_name() body = json.dumps({"name": args.name or ""}).encode()
body = json.dumps({"name": name}).encode()
req = urllib.request.Request( req = urllib.request.Request(
url, url,
data=body, data=body,
@ -68,29 +66,12 @@ def _cmd_register(args: argparse.Namespace) -> int:
cfg.agent_id = agent_id cfg.agent_id = agent_id
cfg.token = token cfg.token = token
cfg.api_key = args.api_key cfg.api_key = args.api_key
cfg.name = name if args.name:
cfg.name = args.name
cfg.insecure = bool(args.insecure) cfg.insecure = bool(args.insecure)
cfg.tx_enabled = bool(getattr(args, "allow_tx", False))
if (v := getattr(args, "tx_max_gain_db", None)) is not None:
cfg.tx_max_gain_db = float(v)
if (v := getattr(args, "tx_max_duration_s", None)) is not None:
cfg.tx_max_duration_s = float(v)
freq_ranges = getattr(args, "tx_freq_range", None) or []
if freq_ranges:
cfg.tx_allowed_freq_ranges = [[float(lo), float(hi)] for lo, hi in freq_ranges]
path = _config.save(cfg) path = _config.save(cfg)
print(f"Registered agent: {agent_id}") print(f"Registered agent: {agent_id}")
if cfg.tx_enabled:
caps: list[str] = []
if cfg.tx_max_gain_db is not None:
caps.append(f"gain<={cfg.tx_max_gain_db} dB")
if cfg.tx_max_duration_s is not None:
caps.append(f"duration<={cfg.tx_max_duration_s} s")
if cfg.tx_allowed_freq_ranges:
caps.append(f"freq in {cfg.tx_allowed_freq_ranges}")
tail = f" ({', '.join(caps)})" if caps else ""
print(f"TX enabled{tail}")
print(f"Credentials saved to {path}") print(f"Credentials saved to {path}")
return 0 return 0
@ -104,10 +85,8 @@ def _cmd_stream(args: argparse.Namespace) -> int:
if not url: if not url:
print("error: --url is required (or run `ria-agent register` first)", file=sys.stderr) print("error: --url is required (or run `ria-agent register` first)", file=sys.stderr)
return 2 return 2
if getattr(args, "allow_tx", False):
cfg.tx_enabled = True
try: try:
asyncio.run(run_streamer(url, token, cfg=cfg)) asyncio.run(run_streamer(url, token))
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
return 0 return 0
@ -118,9 +97,9 @@ def _derive_ws_url(hub_url: str, agent_id: str) -> str:
return "" return ""
base = hub_url.rstrip("/") base = hub_url.rstrip("/")
if base.startswith("https://"): if base.startswith("https://"):
base = "wss://" + base[len("https://") :] base = "wss://" + base[len("https://"):]
elif base.startswith("http://"): elif base.startswith("http://"):
base = "ws://" + base[len("http://") :] base = "ws://" + base[len("http://"):]
suffix = f"/screens/agent/ws?agent_id={agent_id}" if agent_id else "/screens/agent/ws" suffix = f"/screens/agent/ws?agent_id={agent_id}" if agent_id else "/screens/agent/ws"
return base + suffix return base + suffix
@ -144,47 +123,11 @@ def main() -> None:
p_reg.add_argument("--api-key", dest="api_key", required=True, help="Hub API key") p_reg.add_argument("--api-key", dest="api_key", required=True, help="Hub API key")
p_reg.add_argument("--name", default=None, help="Human-friendly agent name") p_reg.add_argument("--name", default=None, help="Human-friendly agent name")
p_reg.add_argument("--insecure", action="store_true", help="Skip TLS verification") p_reg.add_argument("--insecure", action="store_true", help="Skip TLS verification")
p_reg.add_argument(
"--allow-tx",
dest="allow_tx",
action="store_true",
help="Opt this agent in to TX (required for any transmission from the hub)",
)
p_reg.add_argument(
"--tx-max-gain-db",
dest="tx_max_gain_db",
type=float,
default=None,
help="Reject tx_start frames whose tx_gain exceeds this cap (dB)",
)
p_reg.add_argument(
"--tx-max-duration-s",
dest="tx_max_duration_s",
type=float,
default=None,
help="Auto-stop any TX session after this many seconds",
)
p_reg.add_argument(
"--tx-freq-range",
dest="tx_freq_range",
type=float,
nargs=2,
action="append",
metavar=("LO", "HI"),
default=None,
help="Allowed TX center-frequency range in Hz (repeat for multiple bands)",
)
p_stream = sub.add_parser("stream", help="Run the WebSocket IQ streamer") p_stream = sub.add_parser("stream", help="Run the WebSocket IQ streamer")
p_stream.add_argument("--url", default=None, help="Override WebSocket URL") p_stream.add_argument("--url", default=None, help="Override WebSocket URL")
p_stream.add_argument("--token", default=None, help="Override bearer token") p_stream.add_argument("--token", default=None, help="Override bearer token")
p_stream.add_argument("--log-level", default="INFO") p_stream.add_argument("--log-level", default="INFO")
p_stream.add_argument(
"--allow-tx",
dest="allow_tx",
action="store_true",
help="Runtime override: enable TX for this process without writing config",
)
# Unknown extras are forwarded to the legacy CLI when command == "run". # Unknown extras are forwarded to the legacy CLI when command == "run".
args, extras = parser.parse_known_args(argv) args, extras = parser.parse_known_args(argv)

View File

@ -7,11 +7,7 @@ Schema::
"agent_id": "agent-abc123", "agent_id": "agent-abc123",
"token": "rha_xxxx", "token": "rha_xxxx",
"name": "lab-bench-1", "name": "lab-bench-1",
"insecure": false, "insecure": false
"tx_enabled": false,
"tx_max_gain_db": null,
"tx_max_duration_s": null,
"tx_allowed_freq_ranges": null
} }
""" """
@ -22,9 +18,7 @@ import os
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from pathlib import Path from pathlib import Path
_DEFAULT_PATH = Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json")))
def _resolve_default_path() -> Path:
return Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json")))
@dataclass @dataclass
@ -35,29 +29,15 @@ class AgentConfig:
name: str = "" name: str = ""
insecure: bool = False insecure: bool = False
api_key: str = "" api_key: str = ""
tx_enabled: bool = False
tx_max_gain_db: float | None = None
tx_max_duration_s: float | None = None
tx_allowed_freq_ranges: list[list[float]] | None = None
extra: dict = field(default_factory=dict) extra: dict = field(default_factory=dict)
def default_path() -> Path: def default_path() -> Path:
return _resolve_default_path() return _DEFAULT_PATH
def _coerce_ranges(raw) -> list[list[float]] | None:
if raw is None:
return None
out: list[list[float]] = []
for pair in raw:
lo, hi = pair
out.append([float(lo), float(hi)])
return out
def load(path: Path | None = None) -> AgentConfig: def load(path: Path | None = None) -> AgentConfig:
p = path or _resolve_default_path() p = path or _DEFAULT_PATH
if not p.exists(): if not p.exists():
return AgentConfig() return AgentConfig()
data = json.loads(p.read_text()) data = json.loads(p.read_text())
@ -70,16 +50,12 @@ def load(path: Path | None = None) -> AgentConfig:
name=data.get("name", ""), name=data.get("name", ""),
insecure=bool(data.get("insecure", False)), insecure=bool(data.get("insecure", False)),
api_key=data.get("api_key", ""), api_key=data.get("api_key", ""),
tx_enabled=bool(data.get("tx_enabled", False)),
tx_max_gain_db=(float(v) if (v := data.get("tx_max_gain_db")) is not None else None),
tx_max_duration_s=(float(v) if (v := data.get("tx_max_duration_s")) is not None else None),
tx_allowed_freq_ranges=_coerce_ranges(data.get("tx_allowed_freq_ranges")),
extra=extra, extra=extra,
) )
def save(cfg: AgentConfig, path: Path | None = None) -> Path: def save(cfg: AgentConfig, path: Path | None = None) -> Path:
p = path or _resolve_default_path() p = path or _DEFAULT_PATH
p.parent.mkdir(parents=True, exist_ok=True) p.parent.mkdir(parents=True, exist_ok=True)
data = asdict(cfg) data = asdict(cfg)
extra = data.pop("extra", {}) or {} extra = data.pop("extra", {}) or {}

View File

@ -4,51 +4,19 @@ from __future__ import annotations
from ria_toolkit_oss.sdr import detect_available from ria_toolkit_oss.sdr import detect_available
from .config import AgentConfig
def available_devices() -> list[str]: def available_devices() -> list[str]:
"""Return a sorted list of device names whose driver modules import cleanly.""" """Return a sorted list of device names whose driver modules import cleanly."""
return sorted(detect_available().keys()) return sorted(detect_available().keys())
def heartbeat_payload( def heartbeat_payload(status: str = "idle", app_id: str | None = None) -> dict:
status: str = "idle", """Build the JSON body of a periodic heartbeat frame."""
app_id: str | None = None,
*,
cfg: AgentConfig | None = None,
sessions: dict | None = None,
) -> dict:
"""Build the JSON body of a periodic heartbeat frame.
*cfg* drives the ``capabilities`` list and the ``tx_enabled`` flag. If not
supplied, the heartbeat advertises RX-only with ``tx_enabled=False``
matching the pre-TX shape.
"""
c = cfg or AgentConfig()
capabilities = ["rx"]
if c.tx_enabled:
capabilities.append("tx")
payload: dict = { payload: dict = {
"type": "heartbeat", "type": "heartbeat",
"hardware": available_devices(), "hardware": available_devices(),
"status": status, "status": status,
"capabilities": capabilities,
"tx_enabled": bool(c.tx_enabled),
} }
# Surface configured interlock values so the hub can pre-filter UI controls
# before sending a tx_start that would be rejected. Only included when TX
# is opted in AND the operator set a cap.
if c.tx_enabled:
if c.tx_max_gain_db is not None:
payload["tx_max_gain_db"] = float(c.tx_max_gain_db)
if c.tx_max_duration_s is not None:
payload["tx_max_duration_s"] = float(c.tx_max_duration_s)
if c.tx_allowed_freq_ranges:
payload["tx_allowed_freq_ranges"] = [[float(lo), float(hi)] for lo, hi in c.tx_allowed_freq_ranges]
if app_id: if app_id:
payload["app_id"] = app_id payload["app_id"] = app_id
if sessions:
payload["sessions"] = sessions
return payload return payload

View File

@ -20,7 +20,7 @@ Usage::
The agent: The agent:
1. Registers with RIA Hub and receives a ``node_id``. 1. Registers with RIA Hub and receives a ``node_id``.
2. Sends a heartbeat every 30 s so the hub knows it is online. 2. Sends a heartbeat every 30 s so the hub knows it is online.
3. Long-polls ``GET /composer/nodes/{id}/commands`` (30 s timeout). 3. Long-polls ``GET /orchestrator/nodes/{id}/commands`` (30 s timeout).
4. Dispatches received commands: 4. Dispatches received commands:
- ``run_campaign``: executes via CampaignExecutor, uploads recordings. - ``run_campaign``: executes via CampaignExecutor, uploads recordings.
- ``load_model``: loads an ONNX fingerprint or detector model. - ``load_model``: loads an ONNX fingerprint or detector model.
@ -68,7 +68,7 @@ _HEARTBEAT_INTERVAL = 30 # seconds between heartbeats
_POLL_TIMEOUT = 30 # server-side long-poll duration _POLL_TIMEOUT = 30 # server-side long-poll duration
_POLL_CLIENT_TIMEOUT = 40 # client read timeout — slightly longer than server _POLL_CLIENT_TIMEOUT = 40 # client read timeout — slightly longer than server
_RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying _RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying
_CHUNK_SIZE = 10 * 1024 * 1024 # 10 MB per chunk — fast enough for git-LFS to process within timeout _CHUNK_SIZE = 50 * 1024 * 1024 # 50 MB — well below Cloudflare's 100 MB limit
_DIRECT_THRESHOLD = 90 * 1024 * 1024 # files above this use chunked upload _DIRECT_THRESHOLD = 90 * 1024 * 1024 # files above this use chunked upload
_CAPTURE_SAMPLES = 4096 # IQ samples per inference window _CAPTURE_SAMPLES = 4096 # IQ samples per inference window
_IDLE_LABELS = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"}) _IDLE_LABELS = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"})
@ -93,24 +93,16 @@ class NodeAgent:
name: str, name: str,
sdr_device: str = "unknown", sdr_device: str = "unknown",
insecure: bool = False, insecure: bool = False,
role: str = "general",
session_code: str | None = None,
) -> None: ) -> None:
self.hub_url = hub_url.rstrip("/") self.hub_url = hub_url.rstrip("/")
self.api_key = api_key self.api_key = api_key
self.name = name self.name = name
self.sdr_device = sdr_device self.sdr_device = sdr_device
self.insecure = insecure self.insecure = insecure
self.role = role
self.session_code = session_code
self.node_id: str | None = None self.node_id: str | None = None
self._stop = threading.Event() self._stop = threading.Event()
# ── TX state ────────────────────────────────────────────────────────
self._tx_stop = threading.Event()
self._tx_thread: threading.Thread | None = None
# ── Inference state ───────────────────────────────────────────────── # ── Inference state ─────────────────────────────────────────────────
# Protected by _inf_lock for cross-thread model swaps. # Protected by _inf_lock for cross-thread model swaps.
self._inf_lock = threading.Lock() self._inf_lock = threading.Lock()
@ -180,33 +172,25 @@ class NodeAgent:
capabilities = ["campaign"] capabilities = ["campaign"]
if self._ort_available: if self._ort_available:
capabilities.append("inference") capabilities.append("inference")
if self.role == "tx": resp = self._post(
capabilities.append("transmit") "/orchestrator/nodes/register",
payload: dict = { json={
"name": self.name, "name": self.name,
"sdr_device": self.sdr_device, "sdr_device": self.sdr_device,
"ria_toolkit_version": self._ria_version, "ria_toolkit_version": self._ria_version,
"capabilities": capabilities, "capabilities": capabilities,
"role": self.role, },
} timeout=15,
if self.session_code: )
payload["session_code"] = self.session_code
resp = self._post("/composer/nodes/register", json=payload, timeout=15)
resp.raise_for_status() resp.raise_for_status()
self.node_id = resp.json()["node_id"] self.node_id = resp.json()["node_id"]
logger.info( logger.info("Registered as %r (node_id=%s)", self.name, self.node_id)
"Registered as %r (node_id=%s, role=%s%s)",
self.name,
self.node_id,
self.role,
f", session_code={self.session_code!r}" if self.session_code else "",
)
def _deregister(self) -> None: def _deregister(self) -> None:
if not self.node_id: if not self.node_id:
return return
try: try:
self._delete(f"/composer/nodes/{self.node_id}", timeout=10) self._delete(f"/orchestrator/nodes/{self.node_id}", timeout=10)
logger.info("Deregistered %s", self.node_id) logger.info("Deregistered %s", self.node_id)
except Exception as exc: except Exception as exc:
logger.debug("Deregister failed (ignored on shutdown): %s", exc) logger.debug("Deregister failed (ignored on shutdown): %s", exc)
@ -218,7 +202,7 @@ class NodeAgent:
def _heartbeat_loop(self) -> None: def _heartbeat_loop(self) -> None:
while not self._stop.wait(_HEARTBEAT_INTERVAL): while not self._stop.wait(_HEARTBEAT_INTERVAL):
try: try:
resp = self._post(f"/composer/nodes/{self.node_id}/heartbeat", timeout=10) resp = self._post(f"/orchestrator/nodes/{self.node_id}/heartbeat", timeout=10)
if resp.status_code == 404: if resp.status_code == 404:
logger.warning("Heartbeat got 404 — hub lost registration, re-registering") logger.warning("Heartbeat got 404 — hub lost registration, re-registering")
self._register() self._register()
@ -233,7 +217,7 @@ class NodeAgent:
while not self._stop.is_set(): while not self._stop.is_set():
try: try:
resp = self._get( resp = self._get(
f"/composer/nodes/{self.node_id}/commands", f"/orchestrator/nodes/{self.node_id}/commands",
timeout=_POLL_CLIENT_TIMEOUT, timeout=_POLL_CLIENT_TIMEOUT,
) )
if resp.status_code == 204: if resp.status_code == 204:
@ -261,10 +245,9 @@ class NodeAgent:
if command == "run_campaign": if command == "run_campaign":
campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4()) campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4())
config_dict: dict = cmd.get("payload") or {} config_dict: dict = cmd.get("payload") or {}
skip_local_tx: bool = bool(cmd.get("skip_local_tx", False))
threading.Thread( threading.Thread(
target=self._run_campaign, target=self._run_campaign,
args=(campaign_id, config_dict, skip_local_tx), args=(campaign_id, config_dict),
daemon=True, daemon=True,
name=f"campaign-{campaign_id[:8]}", name=f"campaign-{campaign_id[:8]}",
).start() ).start()
@ -286,17 +269,6 @@ class NodeAgent:
self._stop_inference() self._stop_inference()
elif command == "configure_inference": elif command == "configure_inference":
self._queue_sdr_config(cmd) self._queue_sdr_config(cmd)
elif command == "start_transmit":
threading.Thread(
target=self._start_transmit,
args=(cmd,),
daemon=True,
name="ria-start-tx",
).start()
elif command == "stop_transmit":
self._stop_transmit()
elif command == "configure_transmit":
logger.info("configure_transmit received — will apply on next step boundary")
else: else:
logger.warning("Unknown command %r — ignored", command) logger.warning("Unknown command %r — ignored", command)
@ -304,7 +276,7 @@ class NodeAgent:
# Campaign execution # Campaign execution
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def _run_campaign(self, campaign_id: str, config_dict: dict, skip_local_tx: bool = False) -> None: def _run_campaign(self, campaign_id: str, config_dict: dict) -> None:
try: try:
from ria_toolkit_oss.orchestration.campaign import CampaignConfig from ria_toolkit_oss.orchestration.campaign import CampaignConfig
from ria_toolkit_oss.orchestration.executor import CampaignExecutor from ria_toolkit_oss.orchestration.executor import CampaignExecutor
@ -316,10 +288,10 @@ class NodeAgent:
) )
return return
logger.info("Campaign %s starting (skip_local_tx=%s)", campaign_id[:8], skip_local_tx) logger.info("Campaign %s starting", campaign_id[:8])
try: try:
config = CampaignConfig.from_dict(config_dict) config = CampaignConfig.from_dict(config_dict)
executor = CampaignExecutor(config, skip_local_tx=skip_local_tx) executor = CampaignExecutor(config)
result = executor.run() result = executor.run()
logger.info("Campaign %s completed — uploading recordings", campaign_id[:8]) logger.info("Campaign %s completed — uploading recordings", campaign_id[:8])
self._upload_recordings(campaign_id, config, result) self._upload_recordings(campaign_id, config, result)
@ -329,58 +301,6 @@ class NodeAgent:
logger.error("Campaign %s failed: %s", campaign_id[:8], exc) logger.error("Campaign %s failed: %s", campaign_id[:8], exc)
self._report_campaign_status(campaign_id, "failed", error=str(exc)) self._report_campaign_status(campaign_id, "failed", error=str(exc))
# ------------------------------------------------------------------
# TX execution
# ------------------------------------------------------------------
def _start_transmit(self, cmd: dict) -> None:
"""Execute a synthetic transmit campaign using TxExecutor.
The command payload mirrors a TransmitterConfig dict with an optional
``schedule`` of steps. Each step synthesises a signal and transmits it
via the local SDR in TX mode.
"""
try:
from ria_toolkit_oss.orchestration.tx_executor import TxExecutor
except ImportError as exc:
logger.error("start_transmit: TxExecutor not available: %s", exc)
return
if self._tx_thread and self._tx_thread.is_alive():
logger.warning("start_transmit: TX already running — ignoring duplicate command")
return
self._tx_stop.clear()
campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4())
executor = TxExecutor(
config=cmd,
sdr_device=self.sdr_device,
stop_event=self._tx_stop,
)
self._tx_thread = threading.Thread(
target=self._run_tx_campaign,
args=(executor, campaign_id),
daemon=True,
name=f"tx-campaign-{campaign_id[:8]}",
)
self._tx_thread.start()
def _run_tx_campaign(self, executor: Any, campaign_id: str) -> None:
try:
executor.run()
logger.info("TX campaign %s completed", campaign_id[:8])
self._report_campaign_status(campaign_id, "completed")
except Exception as exc:
logger.error("TX campaign %s failed: %s", campaign_id[:8], exc)
self._report_campaign_status(campaign_id, "failed", error=str(exc))
def _stop_transmit(self) -> None:
"""Signal the TX loop to stop gracefully."""
self._tx_stop.set()
if self._tx_thread and self._tx_thread.is_alive():
self._tx_thread.join(timeout=5.0)
logger.info("TX stopped")
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Inference — model loading # Inference — model loading
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@ -620,7 +540,7 @@ class NodeAgent:
logger.info("Inference loop exited") logger.info("Inference loop exited")
def _post_event(self, device_id: str | None, confidence: float, snr_db: float) -> None: def _post_event(self, device_id: str | None, confidence: float, snr_db: float) -> None:
"""POST a single detection event to ``POST /composer/nodes/{id}/events``. """POST a single detection event to ``POST /orchestrator/nodes/{id}/events``.
Failures are logged at DEBUG level and silently swallowed so that a Failures are logged at DEBUG level and silently swallowed so that a
transient network blip does not crash the inference loop. transient network blip does not crash the inference loop.
@ -636,7 +556,7 @@ class NodeAgent:
} }
try: try:
resp = self._post( resp = self._post(
f"/composer/nodes/{self.node_id}/events", f"/orchestrator/nodes/{self.node_id}/events",
json=payload, json=payload,
timeout=5, timeout=5,
) )
@ -659,18 +579,13 @@ class NodeAgent:
base_url = f"{self.hub_url}/datasets/upload" base_url = f"{self.hub_url}/datasets/upload"
steps = (result.get("steps") if isinstance(result, dict) else getattr(result, "steps", None)) or [] steps = (result.get("steps") if isinstance(result, dict) else getattr(result, "steps", None)) or []
output_obj = getattr(config, "output", None)
folder = getattr(output_obj, "folder", None)
campaign_name: str = folder if folder is not None else (getattr(config, "name", None) or "")
for step in steps: for step in steps:
output_path: str | None = getattr(step, "output_path", None) output_path: str | None = getattr(step, "output_path", None)
if not output_path: if not output_path:
continue continue
device_id: str = getattr(step, "transmitter_id", "") or "" device_id: str = getattr(step, "transmitter_id", "") or ""
for fpath in _sigmf_files(output_path): for fpath in _sigmf_files(output_path):
basename = os.path.basename(fpath) filename = os.path.basename(fpath)
path_parts = [p for p in (campaign_name, device_id) if p]
filename = "/".join(path_parts + [basename])
metadata = { metadata = {
"filename": filename, "filename": filename,
"repo_owner": repo_owner, "repo_owner": repo_owner,
@ -704,7 +619,7 @@ class NodeAgent:
payload["error"] = error payload["error"] = error
try: try:
resp = self._post( resp = self._post(
f"/composer/nodes/{self.node_id}/campaign-status", f"/orchestrator/nodes/{self.node_id}/campaign-status",
json=payload, json=payload,
timeout=15, timeout=15,
) )
@ -756,7 +671,7 @@ class NodeAgent:
headers=headers, headers=headers,
files={"file": (filename, chunk, "application/octet-stream")}, files={"file": (filename, chunk, "application/octet-stream")},
data={**metadata, "upload_id": upload_id, "chunk_index": i, "total_chunks": total_chunks}, data={**metadata, "upload_id": upload_id, "chunk_index": i, "total_chunks": total_chunks},
timeout=(30, None), # 30s connect, no read timeout — server may take minutes on final chunk timeout=120,
verify=verify, verify=verify,
) )
if not resp.ok: if not resp.ok:
@ -933,21 +848,6 @@ def main() -> None:
choices=["DEBUG", "INFO", "WARNING", "ERROR"], choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Logging verbosity (default: INFO)", help="Logging verbosity (default: INFO)",
) )
parser.add_argument(
"--role",
default=None,
choices=["general", "rx", "tx"],
help=("Node role reported to the hub. " "'tx' enables synthetic transmission commands. " "Default: general"),
)
parser.add_argument(
"--session-code",
default=None,
metavar="CODE",
help=(
"3-word session code to pair this TX agent with a waiting campaign, "
"e.g. 'amber-peak-transmit'. Supplied by the campaign UI."
),
)
args = parser.parse_args() args = parser.parse_args()
@ -961,8 +861,6 @@ def main() -> None:
device = args.device or cfg.get("device", "unknown") device = args.device or cfg.get("device", "unknown")
insecure = args.insecure if args.insecure is not None else cfg.get("insecure", False) insecure = args.insecure if args.insecure is not None else cfg.get("insecure", False)
log_level = args.log_level or cfg.get("log_level", "INFO") log_level = args.log_level or cfg.get("log_level", "INFO")
role = args.role or cfg.get("role", "general")
session_code = args.session_code or cfg.get("session_code")
if not hub: if not hub:
parser.error("--hub is required (or set 'hub' in the config file)") parser.error("--hub is required (or set 'hub' in the config file)")
@ -990,8 +888,6 @@ def main() -> None:
name=name, name=name,
sdr_device=device, sdr_device=device,
insecure=insecure, insecure=insecure,
role=role,
session_code=session_code,
) )
agent.run() agent.run()

View File

@ -1,147 +0,0 @@
"""Generate random human-readable agent names.
Produces names in the form ``adjective-colour-animal``, e.g.
``swift-teal-falcon`` or ``brave-coral-otter``. All words are chosen
to be friendly and inoffensive.
"""
from __future__ import annotations
import random
ADJECTIVES: list[str] = [
"brave",
"bright",
"calm",
"clever",
"cool",
"daring",
"eager",
"fair",
"fancy",
"fast",
"fierce",
"gentle",
"grand",
"happy",
"jolly",
"keen",
"kind",
"lively",
"lucky",
"mighty",
"noble",
"plucky",
"proud",
"quick",
"quiet",
"sharp",
"shiny",
"sleek",
"smart",
"steady",
"stellar",
"strong",
"sturdy",
"sunny",
"sure",
"swift",
"tall",
"vivid",
"warm",
"wise",
]
COLOURS: list[str] = [
"amber",
"aqua",
"azure",
"beige",
"blue",
"bronze",
"coral",
"copper",
"crimson",
"cyan",
"denim",
"gold",
"green",
"grey",
"indigo",
"ivory",
"jade",
"lemon",
"lilac",
"lime",
"maroon",
"mint",
"navy",
"olive",
"onyx",
"peach",
"pearl",
"plum",
"red",
"rose",
"ruby",
"rust",
"sage",
"sand",
"scarlet",
"silver",
"slate",
"steel",
"teal",
"violet",
]
ANIMALS: list[str] = [
"badger",
"bear",
"bison",
"crane",
"deer",
"dolphin",
"eagle",
"elk",
"falcon",
"finch",
"fox",
"gecko",
"hawk",
"heron",
"horse",
"ibis",
"jaguar",
"jay",
"kite",
"koala",
"lark",
"lion",
"lynx",
"marten",
"moose",
"newt",
"orca",
"osprey",
"otter",
"owl",
"panda",
"puma",
"raven",
"robin",
"salmon",
"seal",
"shark",
"stork",
"swift",
"wolf",
]
def generate_agent_name() -> str:
"""Return a random ``adjective-colour-animal`` name."""
adj = random.choice(ADJECTIVES)
col = random.choice(COLOURS)
ani = random.choice(ANIMALS)
return f"{adj}-{col}-{ani}"

View File

@ -1,33 +1,20 @@
"""IQ-streaming agent. """Thin IQ-streaming agent.
Listens for control messages from the RIA Hub over a persistent WebSocket. Listens for control messages from the RIA Hub over a persistent WebSocket.
Supports: When the server sends ``start``, opens the SDR described in ``radio_config``,
loops over ``sdr.rx(buffer_size)``, and sends each buffer as raw
- An **RX session** (hub sends ``start``/``stop``/``configure``; agent opens interleaved float32 bytes. ``stop`` closes the SDR; ``configure`` applies
the SDR, loops ``sdr.rx()`` and ships raw interleaved float32 IQ). parameter updates at the next capture boundary.
- A **TX session** (hub sends ``tx_start``/``tx_stop``/``tx_configure`` plus
binary IQ frames; agent feeds them into ``sdr._stream_tx``). Phase 3 wires
up the session plumbing and rejects TX when ``cfg.tx_enabled`` is False;
Phase 4 implements the full TX loop.
Both sessions can run concurrently on the same physical SDR (FDD) a
ref-counted SDR registry shares one driver instance when RX and TX name the
same ``(device, identifier)``.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import logging import logging
import queue
import threading
import time
from dataclasses import dataclass, field
from typing import Any from typing import Any
import numpy as np import numpy as np
from .config import AgentConfig
from .hardware import heartbeat_payload from .hardware import heartbeat_payload
from .ws_client import WsClient from .ws_client import WsClient
@ -36,98 +23,6 @@ logger = logging.getLogger("ria_agent.streamer")
_DEFAULT_BUFFER_SIZE = 1024 _DEFAULT_BUFFER_SIZE = 1024
# ---------------------------------------------------------------------------
# Session dataclasses
@dataclass
class RxSession:
app_id: str
sdr: Any
device_key: tuple[str, str | None]
buffer_size: int
task: asyncio.Task | None = None
pending_config: dict = field(default_factory=dict)
@dataclass
class TxSession:
app_id: str
sdr: Any
device_key: tuple[str, str | None]
buffer_size: int
task: Any = None # concurrent.futures.Future from run_in_executor
pending_config: dict = field(default_factory=dict)
underrun_policy: str = "pause"
last_buffer: np.ndarray | None = None
stop_event: threading.Event = field(default_factory=threading.Event)
started_at: float = 0.0
max_duration_s: float | None = None
state: str = "armed"
# Thread-safe queue of inbound interleaved-float32 IQ frames. Bounded so
# hub-side over-production triggers WS backpressure rather than memory
# growth in the agent.
in_queue: "queue.Queue[bytes]" = field(default_factory=lambda: queue.Queue(maxsize=8))
# Set by the TX callback when it hits an underrun while policy=="pause";
# asyncio side flips the session state and emits tx_status.
underrun_flag: threading.Event = field(default_factory=threading.Event)
# ---------------------------------------------------------------------------
# SDR registry (ref-counted so one Pluto handle serves RX + TX simultaneously)
class _SdrRegistry:
def __init__(self, factory):
self._factory = factory
self._instances: dict[tuple[str, str | None], tuple[Any, int]] = {}
self._lock = threading.Lock()
def acquire(self, device: str, identifier: str | None) -> tuple[Any, tuple[str, str | None]]:
key = (device, identifier)
with self._lock:
if key in self._instances:
sdr, rc = self._instances[key]
self._instances[key] = (sdr, rc + 1)
return sdr, key
# Build outside the lock: driver init can be slow and we don't want to
# block concurrent releases on unrelated devices.
sdr = self._factory(device, identifier)
with self._lock:
if key in self._instances:
# Raced another acquirer; discard our duplicate and share theirs.
other_sdr, rc = self._instances[key]
try:
sdr.close()
except Exception:
pass
self._instances[key] = (other_sdr, rc + 1)
return other_sdr, key
self._instances[key] = (sdr, 1)
return sdr, key
def release(self, key: tuple[str, str | None]) -> bool:
"""Decrement refcount. Returns True if the caller owns the last reference
and should close the SDR."""
with self._lock:
sdr, rc = self._instances.get(key, (None, 0))
if sdr is None:
return False
if rc <= 1:
del self._instances[key]
return True
self._instances[key] = (sdr, rc - 1)
return False
def refcount(self, key: tuple[str, str | None]) -> int:
with self._lock:
return self._instances.get(key, (None, 0))[1]
# ---------------------------------------------------------------------------
# Streamer
class Streamer: class Streamer:
"""Main streamer loop. """Main streamer loop.
@ -136,188 +31,103 @@ class Streamer:
ws: ws:
Connected :class:`WsClient`. Connected :class:`WsClient`.
sdr_factory: sdr_factory:
Callable ``(device, identifier) -> SDR``. Defaults to the helper in Callable ``(device, identifier) -> SDR``. Defaults to
:mod:`ria_toolkit_oss.sdr`. Injectable for tests. :func:`ria_toolkit_oss.sdr.get_sdr_device`. Injectable for tests.
cfg:
:class:`AgentConfig` for interlocks (``tx_enabled`` and caps) and
heartbeat capabilities. Defaults to an empty ``AgentConfig()`` which
leaves TX disabled.
""" """
def __init__( def __init__(self, ws: WsClient, sdr_factory=None) -> None:
self,
ws,
sdr_factory=None,
cfg: AgentConfig | None = None,
) -> None:
self.ws = ws self.ws = ws
self._cfg = cfg or AgentConfig() self._sdr_factory = sdr_factory
self._registry = _SdrRegistry(sdr_factory or _default_sdr_factory) self._app_id: str | None = None
self._rx: RxSession | None = None self._sdr: Any = None
self._tx: TxSession | None = None self._pending_config: dict = {}
# Pending radio_config accepted via ``configure`` before ``start``. self._capture_task: asyncio.Task | None = None
self._standalone_pending_config: dict = {} self._status = "idle"
# Cached asyncio event loop, set the first time a handler runs. Used
# to schedule async callbacks from the TX executor thread.
self._loop: asyncio.AbstractEventLoop | None = None
# ------------------------------------------------------------------
# Back-compat read-only shims for callers that check ``._sdr`` etc.
# Writes to these attributes are not supported — use the session objects.
@property
def _sdr(self):
return self._rx.sdr if self._rx is not None else None
@property
def _pending_config(self) -> dict:
return self._rx.pending_config if self._rx is not None else self._standalone_pending_config
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# WsClient wiring # WsClient wiring
def build_heartbeat(self) -> dict: def build_heartbeat(self) -> dict:
status = "streaming" if (self._rx is not None or self._tx is not None) else "idle" return heartbeat_payload(status=self._status, app_id=self._app_id)
app_id: str | None = None
if self._rx is not None:
app_id = self._rx.app_id
elif self._tx is not None:
app_id = self._tx.app_id
sessions: dict[str, dict] = {}
if self._rx is not None:
sessions["rx"] = {"app_id": self._rx.app_id, "state": "streaming"}
if self._tx is not None:
sessions["tx"] = {"app_id": self._tx.app_id, "state": self._tx.state}
return heartbeat_payload(
status=status,
app_id=app_id,
cfg=self._cfg,
sessions=sessions or None,
)
# Advisory / keepalive message types we accept and ignore without warning.
_IGNORED_MESSAGE_TYPES = frozenset({"tx_data_available"})
async def on_message(self, msg: dict) -> None: async def on_message(self, msg: dict) -> None:
t = msg.get("type") t = msg.get("type")
if t in self._IGNORED_MESSAGE_TYPES: if t == "start":
logger.debug("Ignoring advisory message: %r", t) await self._handle_start(msg)
return elif t == "stop":
handler = { await self._handle_stop(msg)
"start": self._handle_rx_start, elif t == "configure":
"stop": self._handle_rx_stop, self._pending_config.update(msg.get("radio_config") or {})
"configure": self._handle_rx_configure, logger.debug("Queued configure: %s", self._pending_config)
"tx_start": self._handle_tx_start, else:
"tx_stop": self._handle_tx_stop,
"tx_configure": self._handle_tx_configure,
}.get(t)
if handler is None:
logger.warning("Unknown server message type: %r", t) logger.warning("Unknown server message type: %r", t)
return
await handler(msg)
async def on_binary(self, data: bytes) -> None: # ------------------------------------------------------------------
tx = self._tx async def _handle_start(self, msg: dict) -> None:
if tx is None: if self._capture_task is not None and not self._capture_task.done():
logger.debug("Dropping %d-byte binary frame: no TX session", len(data))
return
# Backpressure: if the TX queue is full, await briefly so the hub's
# ``await ws.send`` throttles naturally via TCP. We don't block
# indefinitely — a 2s stall means something else is wrong.
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(None, lambda: tx.in_queue.put(data, timeout=2.0))
except queue.Full:
logger.warning("TX queue stalled; dropping frame")
# ==================================================================
# RX
async def _handle_rx_start(self, msg: dict) -> None:
if self._rx is not None:
logger.warning("start received while already streaming — ignoring") logger.warning("start received while already streaming — ignoring")
return return
app_id = msg.get("app_id") or "" self._app_id = msg.get("app_id")
radio_config = dict(msg.get("radio_config") or {}) radio_config = dict(msg.get("radio_config") or {})
device = radio_config.pop("device", None) device = radio_config.pop("device", None)
identifier = radio_config.pop("identifier", None) identifier = radio_config.pop("identifier", None)
buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE)) buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE))
if not device: if not device:
await self._send_error(app_id, "start missing radio_config.device") await self._send_error("start missing radio_config.device")
return return
try: try:
sdr, device_key = self._registry.acquire(device, identifier) factory = self._sdr_factory or _default_sdr_factory
_apply_sdr_config(sdr, radio_config) self._sdr = factory(device, identifier)
_apply_sdr_config(self._sdr, radio_config)
except Exception as exc: except Exception as exc:
logger.exception("Failed to open SDR %r", device) logger.exception("Failed to open SDR %r", device)
await self._send_error(app_id, f"SDR init failed: {exc}") await self._send_error(f"SDR init failed: {exc}")
return return
# Inherit any pending config that was queued before start. self._status = "streaming"
pending = dict(self._standalone_pending_config) await self._send_status("streaming")
self._standalone_pending_config = {} self._capture_task = asyncio.create_task(
self._capture_loop(buffer_size), name="ria-streamer-capture"
session = RxSession(
app_id=app_id,
sdr=sdr,
device_key=device_key,
buffer_size=buffer_size,
pending_config=pending,
) )
self._rx = session
await self._send_status("streaming", app_id)
session.task = asyncio.create_task(self._capture_loop(session), name="ria-streamer-capture")
async def _handle_rx_stop(self, msg: dict) -> None: async def _handle_stop(self, msg: dict) -> None:
session = self._rx if self._capture_task is not None:
if session is None: self._capture_task.cancel()
return
if session.task is not None:
session.task.cancel()
try: try:
await session.task await self._capture_task
except (asyncio.CancelledError, Exception): except (asyncio.CancelledError, Exception):
pass pass
self._close_session_sdr(session) self._capture_task = None
app_id = session.app_id self._close_sdr()
self._rx = None self._app_id = None
await self._send_status("idle", app_id) self._status = "idle"
await self._send_status("idle")
async def _handle_rx_configure(self, msg: dict) -> None: async def _capture_loop(self, buffer_size: int) -> None:
cfg = dict(msg.get("radio_config") or {})
if self._rx is not None:
self._rx.pending_config.update(cfg)
else:
self._standalone_pending_config.update(cfg)
logger.debug("Queued configure: %s", cfg)
async def _capture_loop(self, session: RxSession) -> None:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
try: try:
while True: while True:
if session.pending_config: if self._pending_config:
cfg = session.pending_config cfg = self._pending_config
session.pending_config = {} self._pending_config = {}
try: try:
_apply_sdr_config(session.sdr, cfg) _apply_sdr_config(self._sdr, cfg)
except Exception as exc: except Exception as exc:
logger.warning("Applying configure failed: %s", exc) logger.warning("Applying configure failed: %s", exc)
try: try:
samples = await loop.run_in_executor(None, session.sdr.rx, session.buffer_size) samples = await loop.run_in_executor(None, self._sdr.rx, buffer_size)
except Exception as exc: except Exception as exc:
from ria_toolkit_oss.sdr import SdrDisconnectedError from ria_toolkit_oss.sdr import SdrDisconnectedError
if isinstance(exc, SdrDisconnectedError): if isinstance(exc, SdrDisconnectedError):
logger.warning("SDR disconnected: %s", exc) logger.warning("SDR disconnected: %s", exc)
await self._send_error(session.app_id, f"SDR disconnected: {exc}") await self._send_error(f"SDR disconnected: {exc}")
else: else:
logger.exception("SDR rx error") logger.exception("SDR rx error")
await self._send_error(session.app_id, f"SDR capture failed: {exc}") await self._send_error(f"SDR capture failed: {exc}")
break break
payload = _samples_to_interleaved_float32(samples) payload = _samples_to_interleaved_float32(samples)
@ -329,320 +139,29 @@ class Streamer:
except asyncio.CancelledError: except asyncio.CancelledError:
raise raise
finally: finally:
self._close_session_sdr(session) self._close_sdr()
# If the loop died on its own (e.g. SDR disconnect), clear the
# session handle so future ``start`` messages can proceed.
if self._rx is session:
self._rx = None
# ================================================================== def _close_sdr(self) -> None:
# TX if self._sdr is None:
async def _handle_tx_start(self, msg: dict) -> None: # noqa: C901
app_id = msg.get("app_id") or ""
radio_config = dict(msg.get("radio_config") or {})
# --- interlocks (agent-enforced; never trust the hub alone) ---
if not self._cfg.tx_enabled:
await self._send_tx_status(app_id, "error", "tx disabled on this agent")
return return
tx_gain = radio_config.get("tx_gain")
if (
self._cfg.tx_max_gain_db is not None
and tx_gain is not None
and float(tx_gain) > float(self._cfg.tx_max_gain_db)
):
await self._send_tx_status(
app_id,
"error",
f"tx_gain {tx_gain} exceeds cap {self._cfg.tx_max_gain_db}",
)
return
tx_freq = radio_config.get("tx_center_frequency")
if self._cfg.tx_allowed_freq_ranges and tx_freq is not None:
f = float(tx_freq)
if not any(float(lo) <= f <= float(hi) for lo, hi in self._cfg.tx_allowed_freq_ranges):
await self._send_tx_status(
app_id,
"error",
f"tx_center_frequency {tx_freq} outside allowed ranges",
)
return
if self._tx is not None:
await self._send_tx_status(app_id, "error", "tx already active on this agent")
return
# --- device ---
device = radio_config.pop("device", None)
identifier = radio_config.pop("identifier", None)
buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE))
underrun_policy = str(radio_config.pop("underrun_policy", "pause"))
if underrun_policy not in ("pause", "zero", "repeat"):
await self._send_tx_status(app_id, "error", f"invalid underrun_policy {underrun_policy!r}")
return
if not device:
await self._send_tx_status(app_id, "error", "tx_start missing radio_config.device")
return
device_key: tuple[str, str | None] | None = None
sdr: Any = None
try: try:
sdr, device_key = self._registry.acquire(device, identifier) self._sdr.close()
_apply_sdr_config(sdr, radio_config)
# init_tx is mandatory for any driver that exposes it: drivers
# that gate _stream_tx on _tx_initialized (Pluto, HackRF, USRP,
# …) crash with a confusing "TX was not initialized" error 2 s
# later in the executor thread if we skip it. Treat the three
# required keys as a hard contract — a missing one is a hub-side
# manifest bug and we want it surfaced immediately, not papered
# over with stale radio state.
if hasattr(sdr, "init_tx"):
init_args = {k: radio_config.get(f"tx_{k}") for k in ("sample_rate", "center_frequency", "gain")}
missing = [f"tx_{k}" for k, v in init_args.items() if v is None]
if missing:
raise ValueError(f"tx_start missing required radio_config keys: {missing}")
sdr.init_tx(
sample_rate=init_args["sample_rate"],
center_frequency=init_args["center_frequency"],
gain=init_args["gain"],
channel=radio_config.get("tx_channel", 0),
gain_mode=radio_config.get("tx_gain_mode", "manual"),
)
except Exception as exc:
if device_key is not None:
if self._registry.release(device_key):
try:
sdr.close()
except Exception: except Exception:
pass pass
logger.exception("Failed to init TX on %r", device) self._sdr = None
await self._send_tx_status(app_id, "error", f"tx init failed: {exc}")
return
self._loop = asyncio.get_running_loop() async def _send_status(self, status: str) -> None:
session = TxSession(
app_id=app_id,
sdr=sdr,
device_key=device_key,
buffer_size=buffer_size,
underrun_policy=underrun_policy,
started_at=time.monotonic(),
max_duration_s=self._cfg.tx_max_duration_s,
)
self._tx = session
await self._send_tx_status(app_id, "armed")
session.task = self._loop.run_in_executor(None, self._tx_executor_body, session)
# Spawn a small watchdog that transitions armed → transmitting when
# the first buffer has been consumed, and surfaces underrun / max-
# duration terminations back to the hub.
asyncio.create_task(self._tx_watchdog(session))
async def _handle_tx_stop(self, msg: dict) -> None:
session = self._tx
if session is None:
return
app_id = session.app_id
session.stop_event.set()
try: try:
session.sdr.pause_tx() await self.ws.send_json({"type": "status", "status": status, "app_id": self._app_id})
except Exception:
logger.debug("pause_tx raised during stop", exc_info=True)
# Wake the executor thread if it's blocked on ``queue.get``.
self._drain_tx_queue(session)
if session.task is not None:
try:
await asyncio.wait_for(asyncio.wrap_future(session.task), timeout=1.5)
except asyncio.TimeoutError:
logger.warning("TX executor did not exit within 1.5s after stop")
except Exception:
logger.debug("TX executor raised on shutdown", exc_info=True)
self._close_session_sdr(session)
self._tx = None
await self._send_tx_status(app_id, "done")
async def _handle_tx_configure(self, msg: dict) -> None:
if self._tx is None:
return
self._tx.pending_config.update(msg.get("radio_config") or {})
# ------------------------------------------------------------------
# TX executor & watchdog
def _tx_executor_body(self, session: TxSession) -> None:
try:
session.sdr._stream_tx(lambda n: self._tx_callback(session, n))
except Exception as exc:
logger.exception("TX stream crashed")
# Schedule both the error frame and session teardown on the loop
# so ``self._tx`` clears, subsequent binary frames are rejected,
# and the SDR handle is released.
self._schedule(self._tx_crash_teardown(session, str(exc)))
def _tx_callback(self, session: TxSession, num_samples) -> np.ndarray:
n = int(num_samples)
# Honor stop requests: return silence one last time and let the driver
# exit its loop on the next iteration (pause_tx flips _enable_tx).
if session.stop_event.is_set():
return _silence(n)
# Max-duration watchdog.
if session.max_duration_s is not None and (time.monotonic() - session.started_at) >= float(
session.max_duration_s
):
session.stop_event.set()
try:
session.sdr.pause_tx()
except Exception:
pass
self._schedule(self._send_tx_status(session.app_id, "done", "max duration reached"))
return _silence(n)
# Apply queued configure at buffer boundary.
if session.pending_config:
cfg = session.pending_config
session.pending_config = {}
try:
_apply_sdr_config(session.sdr, cfg)
except Exception as exc:
logger.debug("tx_configure apply failed: %s", exc)
try:
raw = session.in_queue.get(timeout=0.1)
except queue.Empty:
return self._underrun_fill(session, n)
arr = np.frombuffer(raw, dtype=np.float32)
if arr.size < 2 or arr.size % 2 != 0:
logger.warning("Malformed TX frame: %d floats (must be non-zero even count)", arr.size)
return self._underrun_fill(session, n)
samples = arr[0::2].astype(np.complex64) + 1j * arr[1::2].astype(np.complex64)
if samples.size < n:
out = np.zeros(n, dtype=np.complex64)
out[: samples.size] = samples
session.last_buffer = out
return out
if samples.size > n:
samples = samples[:n]
session.last_buffer = samples
if session.state == "armed":
session.state = "transmitting"
self._schedule(self._send_tx_status(session.app_id, "transmitting"))
return samples
def _underrun_fill(self, session: TxSession, n: int) -> np.ndarray:
policy = session.underrun_policy
if policy == "zero":
return _silence(n)
if policy == "repeat" and session.last_buffer is not None:
buf = session.last_buffer
if buf.size == n:
return buf
if buf.size > n:
return buf[:n].copy()
out = np.zeros(n, dtype=np.complex64)
out[: buf.size] = buf
return out
# "pause" policy (default) or "repeat" before any buffer arrived.
if not session.underrun_flag.is_set():
session.underrun_flag.set()
session.stop_event.set()
try:
session.sdr.pause_tx()
except Exception:
pass
return _silence(n)
async def _tx_watchdog(self, session: TxSession) -> None:
# Poll the underrun flag so we can emit status + tear down cleanly
# when the callback flips the flag from the executor thread. Check
# underrun_flag before stop_event, since the "pause" path sets both.
while session is self._tx:
if session.underrun_flag.is_set():
await self._send_tx_status(session.app_id, "underrun")
await self._teardown_tx_after_underrun(session)
return
if session.stop_event.is_set():
return
await asyncio.sleep(0.05)
async def _tx_crash_teardown(self, session: TxSession, message: str) -> None:
# Called from the executor thread via _schedule when _stream_tx raises.
# Emit the error, mark stopped, drain the queue, release the SDR.
await self._send_tx_status(session.app_id, "error", f"tx stream crashed: {message}")
if self._tx is not session:
return
session.stop_event.set()
self._drain_tx_queue(session)
self._close_session_sdr(session)
if self._tx is session:
self._tx = None
async def _teardown_tx_after_underrun(self, session: TxSession) -> None:
if self._tx is not session:
return
self._drain_tx_queue(session)
if session.task is not None:
try:
await asyncio.wait_for(asyncio.wrap_future(session.task), timeout=1.0)
except asyncio.TimeoutError:
logger.warning("TX executor did not exit within 1s after underrun")
except Exception:
logger.debug("TX executor raised during underrun teardown", exc_info=True)
self._close_session_sdr(session)
if self._tx is session:
self._tx = None
def _drain_tx_queue(self, session: TxSession) -> None:
try:
while True:
session.in_queue.get_nowait()
except queue.Empty:
pass
def _schedule(self, coro) -> None:
loop = self._loop
if loop is None:
return
try:
asyncio.run_coroutine_threadsafe(coro, loop)
except Exception:
logger.debug("_schedule failed", exc_info=True)
# ==================================================================
# Helpers
def _close_session_sdr(self, session) -> None:
if session.sdr is None:
return
should_close = self._registry.release(session.device_key)
if should_close:
try:
session.sdr.close()
except Exception:
logger.debug("SDR close raised", exc_info=True)
async def _send_status(self, status: str, app_id: str) -> None:
try:
await self.ws.send_json({"type": "status", "status": status, "app_id": app_id})
except Exception as exc: except Exception as exc:
logger.debug("Status send failed: %s", exc) logger.debug("Status send failed: %s", exc)
async def _send_error(self, app_id: str, message: str) -> None: async def _send_error(self, message: str) -> None:
try: try:
await self.ws.send_json({"type": "error", "app_id": app_id, "message": message}) await self.ws.send_json({"type": "error", "app_id": self._app_id, "message": message})
except Exception as exc: except Exception as exc:
logger.debug("Error-frame send failed: %s", exc) logger.debug("Error-frame send failed: %s", exc)
async def _send_tx_status(self, app_id: str, state: str, message: str | None = None) -> None:
payload: dict = {"type": "tx_status", "app_id": app_id, "state": state}
if message is not None:
payload["message"] = message
try:
await self.ws.send_json(payload)
except Exception as exc:
logger.debug("tx_status send failed: %s", exc)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
@ -653,51 +172,16 @@ _CONFIG_ATTR_MAP = {
"center_freq": ("center_freq", "rx_center_frequency"), "center_freq": ("center_freq", "rx_center_frequency"),
"gain": ("gain", "rx_gain"), "gain": ("gain", "rx_gain"),
"bandwidth": ("bandwidth", "rx_bandwidth"), "bandwidth": ("bandwidth", "rx_bandwidth"),
"tx_sample_rate": ("tx_sample_rate",),
"tx_center_frequency": ("tx_center_frequency", "tx_lo"),
"tx_gain": ("tx_gain",),
"tx_bandwidth": ("tx_bandwidth",),
} }
def _is_stub_setter(method: Any) -> bool:
"""True when *method* is an unimplemented base-class stub.
The ``SDR`` abstract base defines ``set_rx_sample_rate`` / ``set_tx_gain``
etc. as zero-argument ``NotImplementedError`` stubs. A driver (Pluto) that
actually transmits overrides them with a real ``(value, ...)`` signature.
Comparing ``__qualname__`` against ``SDR.`` lets us skip the stubs cheaply.
"""
return getattr(method, "__qualname__", "").startswith("SDR.")
def _apply_sdr_config(sdr: Any, cfg: dict) -> None: def _apply_sdr_config(sdr: Any, cfg: dict) -> None:
"""Apply a radio_config dict to an SDR. """Apply a radio_config dict to an SDR, trying multiple attribute aliases."""
Prefers ``sdr.set_<attr>(value)`` when the driver implements it Pluto's
setters take ``_param_lock``, so routing through them keeps concurrent
RX + TX reconfigures from racing on shared native attributes. Falls back
to ``setattr`` for drivers (MockSDR, tests) that don't override the
base-class stubs.
"""
for key, value in cfg.items(): for key, value in cfg.items():
if value is None: if value is None:
continue continue
attrs = _CONFIG_ATTR_MAP.get(key, (key,)) attrs = _CONFIG_ATTR_MAP.get(key, (key,))
applied = False applied = False
for attr in attrs:
setter = getattr(sdr, f"set_{attr}", None)
if callable(setter) and not _is_stub_setter(setter):
try:
setter(value)
applied = True
break
except Exception as exc:
logger.debug("set_%s(%r) failed: %s", attr, value, exc)
# Fall through to setattr; some drivers may partially
# implement setters.
if applied:
continue
for attr in attrs: for attr in attrs:
if hasattr(sdr, attr): if hasattr(sdr, attr):
try: try:
@ -710,11 +194,6 @@ def _apply_sdr_config(sdr: Any, cfg: dict) -> None:
logger.debug("radio_config key %r ignored (no matching attr)", key) logger.debug("radio_config key %r ignored (no matching attr)", key)
def _silence(num_samples: int) -> np.ndarray:
"""Return a ``num_samples``-length zero-filled complex64 buffer."""
return np.zeros(int(num_samples), dtype=np.complex64)
def _samples_to_interleaved_float32(samples: Any) -> bytes: def _samples_to_interleaved_float32(samples: Any) -> bytes:
"""Convert complex IQ samples (any numeric dtype) to interleaved float32 bytes.""" """Convert complex IQ samples (any numeric dtype) to interleaved float32 bytes."""
arr = np.asarray(samples) arr = np.asarray(samples)
@ -735,13 +214,8 @@ def _default_sdr_factory(device: str, identifier: str | None):
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Top-level entry # Top-level entry
async def run_streamer(ws_url: str, token: str) -> None:
async def run_streamer(ws_url: str, token: str, *, cfg: AgentConfig | None = None) -> None:
"""Connect to *ws_url* and run the streamer loop until cancelled.""" """Connect to *ws_url* and run the streamer loop until cancelled."""
ws = WsClient(ws_url, token) ws = WsClient(ws_url, token)
streamer = Streamer(ws, cfg=cfg) streamer = Streamer(ws)
await ws.run( await ws.run(streamer.on_message, streamer.build_heartbeat)
streamer.on_message,
streamer.build_heartbeat,
on_binary=streamer.on_binary,
)

View File

@ -15,7 +15,6 @@ logger = logging.getLogger("ria_agent.ws")
MessageHandler = Callable[[dict], Awaitable[None]] MessageHandler = Callable[[dict], Awaitable[None]]
HeartbeatBuilder = Callable[[], dict] HeartbeatBuilder = Callable[[], dict]
BinaryHandler = Callable[[bytes], Awaitable[None]]
class WsClient: class WsClient:
@ -66,12 +65,7 @@ class WsClient:
self._stop.set() self._stop.set()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
async def run( async def run(self, on_message: MessageHandler, heartbeat: HeartbeatBuilder) -> None:
self,
on_message: MessageHandler,
heartbeat: HeartbeatBuilder,
on_binary: BinaryHandler | None = None,
) -> None:
"""Main loop: connect, heartbeat, dispatch messages, reconnect on drop.""" """Main loop: connect, heartbeat, dispatch messages, reconnect on drop."""
while not self._stop.is_set(): while not self._stop.is_set():
try: try:
@ -81,14 +75,9 @@ class WsClient:
try: try:
async for raw in self._ws: async for raw in self._ws:
if isinstance(raw, bytes): if isinstance(raw, bytes):
if on_binary is None: # Server shouldn't send binary to the agent; log and drop.
logger.debug("Discarding unexpected %d-byte binary frame", len(raw)) logger.debug("Discarding unexpected %d-byte binary frame", len(raw))
continue continue
try:
await on_binary(raw)
except Exception:
logger.exception("on_binary handler raised; dropping frame")
continue
try: try:
msg = json.loads(raw) msg = json.loads(raw)
except json.JSONDecodeError: except json.JSONDecodeError:

View File

@ -1,54 +0,0 @@
"""
The annotations package contains tools and utilities for creating, managing, and processing annotations.
Provides automatic annotation generation using various signal detection algorithms:
- Energy-based detection (detect_signals_energy)
- CUSUM-based segmentation (annotate_with_cusum)
- Threshold-based qualification (threshold_qualifier)
- Signal isolation and extraction (isolate_signal)
- Occupied bandwidth analysis (calculate_occupied_bandwidth, calculate_nominal_bandwidth)
All detection functions return Recording objects with added annotations.
"""
__all__ = [
# Energy-based detection
"detect_signals_energy",
"calculate_occupied_bandwidth",
"calculate_nominal_bandwidth",
"calculate_full_detected_bandwidth",
"annotate_with_obw",
# CUSUM detection
"annotate_with_cusum",
# Threshold detection
"threshold_qualifier",
# Parallel signal separation (Phase 2)
"find_spectral_components",
"split_annotation_by_components",
"split_recording_annotations",
# Signal isolation
"isolate_signal",
# Annotation transforms
"remove_contained_boxes",
"is_annotation_contained",
# Dataset creation
"qualify_slice_from_annotations",
]
from .annotation_transforms import is_annotation_contained, remove_contained_boxes
from .cusum_annotator import annotate_with_cusum
from .energy_detector import (
annotate_with_obw,
calculate_full_detected_bandwidth,
calculate_nominal_bandwidth,
calculate_occupied_bandwidth,
detect_signals_energy,
)
from .parallel_signal_separator import (
find_spectral_components,
split_annotation_by_components,
split_recording_annotations,
)
from .qualify_slice import qualify_slice_from_annotations
from .signal_isolation import isolate_signal
from .threshold_qualifier import threshold_qualifier

View File

@ -1,55 +0,0 @@
from ria_toolkit_oss.data.annotation import Annotation
# TODO figure out how to transfer labels in the merge case
def remove_contained_boxes(annotations: list[Annotation]):
"""
Remove all annotations (bounding boxes) that are entirely contained within other boxes in the list.
:param annotations: A list of Annotation objects.
:type annotations: list[Annotation]
:returns: A new list of Annotation objects.
:rtype: list[Annotation]"""
output_boxes = []
for i in range(len(annotations)):
contained = False
for j in range(len(annotations)):
if i != j and is_annotation_contained(annotations[i], annotations[j]):
contained = True
break
if not contained:
output_boxes.append(annotations[i])
return output_boxes
def is_annotation_contained(inner: Annotation, outer: Annotation) -> bool:
"""
Check if an annotation box is entirely contained within another annotation bounding box.
:param inner: The inner box.
:type inner: Annotation.
:param outer: The outer box.
:type outer: Annotation.
:returns: True if inner is within outer, false otherwise.
:rtype: bool
"""
inner_sample_stop = inner.sample_start + inner.sample_count
outer_sample_stop = outer.sample_start + outer.sample_count
if inner.sample_start > outer.sample_start and inner_sample_stop < outer_sample_stop:
if inner.freq_lower_edge > outer.freq_lower_edge and inner.freq_upper_edge < outer.freq_upper_edge:
return True
return False
def merge_annotations(annotations: list[Annotation], overlap_threshold) -> list[Annotation]:
raise NotImplementedError

View File

@ -1,203 +0,0 @@
import json
from typing import Optional
import numpy as np
from ria_toolkit_oss.data import Annotation, Recording
def annotate_with_cusum(
recording: Recording,
label: Optional[str] = "segment",
window_size: Optional[int] = 1,
min_duration: Optional[float] = None,
tolerance: Optional[int] = None,
annotation_type: Optional[str] = "standalone",
):
"""
Add annotations that divide the recording into distinct time segments.
This algorithm computes the cumulative sum of the sample magnitudes and
determines break points in the signal.
This tool can be used to find points where a signal turns on or off, or
changes between a low and high amplitude.
:param recording: A ``Recording`` object to annotate.
:type recording: ``ria_toolkit_oss.data.Recording``
:param label: Label for the detected segments.
:type label: str
:param window_size: The length (in samples) of the moving average window.
:type window_size: int
:param min_duration: The minimum duration (in ms) of a segment.
The algorithm will not produce annotations shorter than this length.
:type min_duration: float
:param tolerance: The minimum length (in samples) of a segment.
:type tolerance: int
:param annotation_type: Annotation type (standalone, parallel, intersection).
:type annotation_type: str
"""
sample_rate = recording.metadata["sample_rate"]
center_frequency = recording.metadata.get("center_frequency", 0)
# Create an object of the time segmenter
time_segmenter = TimeSegmenter(sample_rate, min_duration, window_size, tolerance)
change_points = time_segmenter.apply(recording.data[0])
time_segments_indices = np.append(np.insert(change_points, 0, 0), len(recording.data[0]))
annotations = []
for i in range(len(time_segments_indices) - 1):
# Build comment JSON with type metadata
comment_data = {
"type": annotation_type,
"generator": "cusum_annotator",
"params": {
"window_size": window_size,
"min_duration": min_duration,
"tolerance": tolerance,
},
}
f_min, f_max = detect_frequency(
signal=recording.data[0],
start=time_segments_indices[i],
stop=time_segments_indices[i + 1],
sample_rate=sample_rate,
)
annotations.append(
Annotation(
sample_start=time_segments_indices[i],
sample_count=time_segments_indices[i + 1] - time_segments_indices[i],
freq_lower_edge=center_frequency + f_min,
freq_upper_edge=center_frequency + f_max,
label=label,
comment=json.dumps(comment_data),
detail={"generator": "cusum_annotator"},
)
)
return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + annotations)
def _compute_cusum(_signal, sample_rate: int, tolerance: int = None, min_duration: float = -1):
"""
This function efficiently computes the cumulative sum of a give list (_signal), with an optional tolerance.
Args:
- _signal: array of iq samples.
- Tolerance: the least acceptable length of a block, Defaults to None.
Returns:
- cusum (array): Array of the cumulative sum of the given list
- sample_rate (int): __description_
- change_points (array): Array of the indices at which a change in the CUSUM direction happens.
- min_duration (float): The least acceptable time width of each segment (in ms). Defaults to -1.
"""
# efficiently calculate the running sum of the signal
# cusum = list(itertools.accumulate((_signal - np.mean(_signal))))
x = _signal - np.mean(_signal)
cusum = np.cumsum(x)
# 'diff' computes the differences between the consecutive values,
# then 'sign' determines if it is +ve or -ve.
change_indicators = np.sign(np.diff(cusum))
change_points = np.where(np.diff(change_indicators))[0] + 1
# Limit the change_points
# Reject those whose number of samples < minimum accepted #n of samples in (min duration) ms.
if min_duration is not None and min_duration > 0:
min_samples_wide = int(min_duration * sample_rate / 1000)
segments_lengths = np.diff(change_points)
segments_lengths = np.insert(segments_lengths, 0, change_points[0])
change_points = change_points[np.where(segments_lengths > min_samples_wide)[0]]
return cusum, change_points
def detect_frequency(signal, start, stop, sample_rate):
signal_segment = signal[start:stop]
if len(signal_segment) > 0:
fft_data = np.abs(np.fft.fftshift(np.fft.fft(signal_segment)))
fft_freqs = np.fft.fftshift(np.fft.fftfreq(len(signal_segment), 1 / sample_rate))
# Use a spectral threshold to find the 'height' of the orange block
spectral_thresh = np.max(fft_data) * 0.15
sig_indices = np.where(fft_data > spectral_thresh)[0]
if len(sig_indices) > 4:
return fft_freqs[sig_indices[0]], fft_freqs[sig_indices[-1]]
else:
return -sample_rate / 4, sample_rate / 4
else:
return -sample_rate / 4, sample_rate / 4
class TimeSegmenter:
"""Time Segmenter class, it creates a segmenter object with certain\
characteristics to easily split an input signal to segments based on\
the cumulative sum of deviations (of the signal mean)
"""
def __init__(
self, sample_rate: int, min_duration: float = 1, moving_average_window: int = 3, tolerance: int = None
):
"""_summary_
Args:
sample_rate (int): _description_
min_duration (float, optional): _description_. Defaults to 1.
moving_average_window (int, optional): _description_. Defaults to 3.
tolerance (int, optional): _description_. Defaults to None.
"""
self.sample_rate = sample_rate
self.min_duration = min_duration
self.moving_average_window = moving_average_window
self._moving_avg_filter = self._init_filter()
self.tolerance = tolerance
def _init_filter(self):
"""_summary_
Returns:
_type_: _description_
"""
return np.ones(self.moving_average_window) / self.moving_average_window
def _apply_filter(self, iqsignal: np.array):
"""_summary_
Args:
iqsignal (np.array): _description_
Returns:
_type_: _description_
"""
return np.convolve(abs(iqsignal), self._moving_avg_filter, mode="same")
def _create_segments(self, iq_signal: np.array, change_points: np.array):
"""_summary_
Args:
iq_signal (np.array): _description_
change_points (np.array): _description_
Returns:
_type_: _description_
"""
return np.split(iq_signal, change_points)
def apply(self, iq_signal: np.array):
"""_summary_
Args:
iq_signal (np.array): _description_
Returns:
_type_: _description_
"""
smoothed_signal = self._apply_filter(iq_signal)
_, change_points = _compute_cusum(smoothed_signal, self.sample_rate, self.tolerance, self.min_duration)
# segments = self._create_segments(iq_signal, change_points)
return change_points

View File

@ -1,438 +0,0 @@
"""
Energy-based signal detection and bandwidth analysis.
Provides automatic annotation generation using energy-based signal detection
and occupied bandwidth calculation following ITU-R SM.328 standard.
"""
import json
from typing import Tuple
import numpy as np
from scipy.signal import filtfilt
from ria_toolkit_oss.data import Annotation, Recording
def detect_signals_energy(
recording: Recording,
k: int = 10,
threshold_factor: float = 1.2,
window_size: int = 200,
min_distance: int = 5000,
label: str = "signal",
annotation_type: str = "standalone",
freq_method: str = "nbw",
nfft: int = None,
obw_power: float = 0.99,
) -> Recording:
"""
Detect signal bursts using energy-based method with adaptive noise floor estimation.
This algorithm smooths the signal with a moving average filter, estimates the noise
floor from k segments, applies a threshold to detect regions above noise, and merges
nearby detections. Detected time boundaries are then assigned frequency bounds based
on the selected frequency method.
Time Detection Algorithm:
1. Smooth signal using moving average (envelope detection)
2. Divide smoothed signal into k segments
3. Estimate noise floor as median of segment mean powers
4. Detect regions where power exceeds threshold_factor * noise_floor
5. Merge regions closer than min_distance samples
Frequency Bounding (freq_method):
- 'nbw': Nominal bandwidth (OBW + center frequency) - DEFAULT
- 'obw': Occupied bandwidth (99.99% power, includes siedelobes)
- 'full-detected': Lowest to highest spectral component
- 'full-bandwidth': Entire Nyquist span (center_freq ± sample_rate/2)
:param recording: Recording to analyze
:type recording: Recording
:param k: Number of segments for noise floor estimation (default: 10)
:type k: int
:param threshold_factor: Threshold multiplier above noise floor (typical: 1.2-2.0, default: 1.2)
:type threshold_factor: float
:param window_size: Moving average window size in samples (default: 200)
:type window_size: int
:param min_distance: Minimum distance between separate signals in samples (default: 5000)
:type min_distance: int
:param label: Label for detected annotations (default: "signal")
:type label: str
:param annotation_type: Annotation type (standalone, parallel, intersection, default: standalone)
:type annotation_type: str
:param freq_method: How to calculate frequency bounds (default: 'nbw')
:type freq_method: str
:param nfft: FFT size for frequency calculations (default: None)
:type nfft: int
:param obw_power: Power percentage for OBW (0.9999 = 99.99%, default: 0.99)
:type obw_power: float
:returns: New Recording with added annotations
:rtype: Recording
**Example**::
>>> from ria.io import load_recording
>>> from ria_toolkit_oss.annotations import detect_signals_energy
>>> recording = load_recording("capture.sigmf")
>>> # Detect with NBW frequency bounds (default, best for real signals)
>>> annotated = detect_signals_energy(recording, label="burst")
>>> # Detect with OBW (more conservative, includes siedelobes)
>>> annotated = detect_signals_energy(
... recording, label="burst", freq_method="obw"
... )
>>> # Detect with full detected range (captures all spectral components)
>>> annotated = detect_signals_energy(
... recording, label="burst", freq_method="full-detected"
... )
"""
# Extract signal data (use first channel only)
signal = recording.data[0]
# Calculate smoothed signal power
kernel = np.ones(window_size) / window_size
smoothed_power = filtfilt(kernel, [1], np.abs(signal) ** 2)
# Estimate noise floor using segment-based median (robust to signal presence)
segments = np.array_split(smoothed_power, k)
noise_floor = np.median([np.mean(s) for s in segments])
# Detect signal boundaries (regions above threshold)
enter = noise_floor * threshold_factor
exit = enter * 0.8
boundaries = []
start = None
active = False
for i, p in enumerate(smoothed_power):
if not active and p > enter:
start = i
active = True
elif active and p < exit:
boundaries.append((start, i - start))
active = False
if active:
boundaries.append((start, len(smoothed_power) - start))
# Merge boundaries that are closer than min_distance
merged_boundaries = []
if boundaries:
start, length = boundaries[0]
for next_start, next_length in boundaries[1:]:
if next_start - (start + length) < min_distance:
# Merge with current boundary
length = next_start + next_length - start
else:
# Save current and start new boundary
merged_boundaries.append((start, length))
start, length = next_start, next_length
# Add final boundary
merged_boundaries.append((start, length))
# Create annotations from detected boundaries
sample_rate = recording.metadata["sample_rate"]
center_frequency = recording.metadata.get("center_frequency", 0)
# Validate frequency method
valid_freq_methods = ["nbw", "obw", "full-detected", "full-bandwidth"]
if freq_method not in valid_freq_methods:
raise ValueError(f"Invalid freq_method '{freq_method}'. " f"Must be one of: {', '.join(valid_freq_methods)}")
annotations = []
for start_sample, sample_count in merged_boundaries:
# Calculate frequency bounds based on method
freq_lower, freq_upper = calculate_frequency_bounds(
freq_method, center_frequency, sample_rate, nfft, signal, start_sample, sample_count, obw_power
)
# Build comment JSON with type metadata
comment_data = {
"type": annotation_type,
"generator": "energy_detector",
"freq_method": freq_method,
"params": {
"threshold_factor": threshold_factor,
"window_size": window_size,
"noise_floor": float(noise_floor),
"threshold": float(enter),
},
}
anno = Annotation(
sample_start=start_sample,
sample_count=sample_count,
freq_lower_edge=freq_lower,
freq_upper_edge=freq_upper,
label=label,
comment=json.dumps(comment_data),
detail={"generator": "energy_detector", "freq_method": freq_method},
)
annotations.append(anno)
return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + annotations)
def calculate_occupied_bandwidth(
signal: np.ndarray,
sampling_rate: float,
nfft: int = None,
power_percentage: float = 0.99,
):
if nfft is None:
nfft = max(65536, 2 ** int(np.floor(np.log2(len(signal)))))
window = np.blackman(len(signal))
spec = np.fft.fftshift(np.fft.fft(signal * window, n=nfft))
psd = np.abs(spec) ** 2
psd = psd / psd.sum() # normalize
freqs = np.fft.fftshift(np.fft.fftfreq(nfft, 1 / sampling_rate))
cdf = np.cumsum(psd)
tail = (1 - power_percentage) / 2
lower_idx = np.searchsorted(cdf, tail)
upper_idx = np.searchsorted(cdf, 1 - tail)
return freqs[upper_idx] - freqs[lower_idx], freqs[lower_idx], freqs[upper_idx]
def calculate_nominal_bandwidth(
signal: np.ndarray,
sampling_rate: float,
nfft: int = None,
power_percentage: float = 0.99,
) -> Tuple[float, float]:
"""
Calculate nominal bandwidth and center frequency.
Nominal bandwidth (NBW) is the occupied bandwidth along with the center
frequency of the signal's spectral occupancy. Useful for characterizing
signals with unknown or drifting center frequencies.
:param signal: Complex IQ signal samples
:type signal: np.ndarray
:param sampling_rate: Sample rate in Hz
:type sampling_rate: float
:param nfft: FFT size
:type nfft: int
:param power_percentage: Fraction of power to contain
:type power_percentage: float
:returns: Tuple of (nominal_bandwidth_hz, center_frequency_hz)
:rtype: Tuple[float, float]
**Example**::
>>> from ria_toolkit_oss.annotations import calculate_nominal_bandwidth
>>> nbw, center = calculate_nominal_bandwidth(signal, sampling_rate=10e6)
>>> print(f"NBW: {nbw/1e6:.3f} MHz, Center: {center/1e6:.3f} MHz")
"""
bw, lower_freq, upper_freq = calculate_occupied_bandwidth(signal, sampling_rate, nfft, power_percentage)
# Center frequency is midpoint of occupied band
center_freq = (lower_freq + upper_freq) / 2
return lower_freq, upper_freq, center_freq
def calculate_full_detected_bandwidth(
signal: np.ndarray,
sampling_rate: float,
nfft: int = None,
start_offset: int = 1000,
) -> Tuple[float, float, float]:
"""
Calculate frequency range from lowest to highest spectral component.
Unlike OBW/NBW which define a power-based bandwidth, this calculates
the absolute frequency span from the lowest non-zero spectral component
to the highest non-zero component.
Useful for:
- Signals with spectral gaps
- Multiple parallel signals (captures all of them)
- Understanding total occupied spectrum vs. actual bandwidth
:param signal: Complex IQ signal samples
:type signal: np.ndarray
:param sampling_rate: Sample rate in Hz
:type sampling_rate: float
:param nfft: FFT size
:type nfft: int
:param start_offset: Skip samples at start
:type start_offset: int
:returns: Tuple of (bandwidth_hz, lower_freq_hz, upper_freq_hz)
:rtype: Tuple[float, float, float]
**Example**::
>>> # Signal with two components at different frequencies
>>> bw, f_low, f_high = calculate_full_detected_bandwidth(
... signal, sampling_rate=10e6, nfft=65536
... )
>>> print(f"Full span: {f_low/1e6:.3f} to {f_high/1e6:.3f} MHz")
"""
# Validate input
if len(signal) < nfft + start_offset:
raise ValueError(
f"Signal too short: need {nfft + start_offset} samples, "
f"got {len(signal)}. Reduce nfft or start_offset."
)
# Extract segment
signal_segment = signal[start_offset : nfft + start_offset]
# Compute FFT and power spectral density
freq_spectrum = np.fft.fft(signal_segment, n=nfft)
psd = np.abs(freq_spectrum) ** 2
# Shift to center DC
psd_shifted = np.fft.fftshift(psd)
freq_bins = np.fft.fftshift(np.fft.fftfreq(nfft, 1 / sampling_rate))
# Find noise floor (mean of lowest 10% of bins) and all bins above noise floor
noise_floor = np.mean(np.sort(psd_shifted)[: int(len(psd_shifted) * 0.1)])
above_noise = np.where(psd_shifted > noise_floor * 1.5)[0]
if len(above_noise) == 0:
# No signal above noise, return zero bandwidth
return 0.0, 0.0, 0.0
# Get frequency range of signal components
lower_idx = above_noise[0]
upper_idx = above_noise[-1]
lower_freq = freq_bins[lower_idx]
upper_freq = freq_bins[upper_idx]
bandwidth = upper_freq - lower_freq
return bandwidth, lower_freq, upper_freq
def annotate_with_obw(
recording: Recording,
label: str = "signal",
annotation_type: str = "standalone",
nfft: int = None,
power_percentage: float = 0.99,
) -> Recording:
"""
Create a single annotation spanning the occupied bandwidth of the entire recording.
Analyzes the full recording to find its occupied bandwidth and creates an annotation
covering that frequency range for the entire time duration.
:param recording: Recording to analyze
:type recording: Recording
:param label: Annotation label
:type label: str
:param annotation_type: Annotation type
:type annotation_type: str
:param nfft: FFT size
:type nfft: int
:param power_percentage: Power percentage for OBW calculation
:type power_percentage: float
:returns: Recording with OBW annotation added
:rtype: Recording
**Example**::
>>> from ria_toolkit_oss.annotations import annotate_with_obw
>>> annotated = annotate_with_obw(recording, label="signal_obw")
"""
signal = recording.data[0]
sample_rate = recording.metadata["sample_rate"]
center_freq = recording.metadata.get("center_frequency", 0)
# Calculate OBW
obw, lower_offset, upper_offset = calculate_occupied_bandwidth(signal, sample_rate, nfft, power_percentage)
# Convert baseband offsets to absolute frequencies
freq_lower = center_freq + lower_offset
freq_upper = center_freq + upper_offset
# Create comment JSON
comment_data = {
"type": annotation_type,
"generator": "obw_annotator",
"obw_hz": float(obw),
"power_percentage": power_percentage,
"params": {"nfft": nfft},
}
# Create annotation spanning entire recording
anno = Annotation(
sample_start=0,
sample_count=len(signal),
freq_lower_edge=freq_lower,
freq_upper_edge=freq_upper,
label=label,
comment=json.dumps(comment_data),
detail={"generator": "obw_annotator", "obw_hz": float(obw)},
)
return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + [anno])
def calculate_frequency_bounds(
freq_method, center_frequency, sample_rate, nfft, signal, start_sample, sample_count, obw_power
):
if freq_method == "full-bandwidth":
# Full Nyquist span
freq_lower = center_frequency - (sample_rate / 2)
freq_upper = center_frequency + (sample_rate / 2)
else:
# Extract segment for frequency analysis
segment_start = start_sample
segment_end = min(start_sample + sample_count, len(signal))
segment = signal[segment_start:segment_end]
if nfft is None or len(segment) >= nfft:
if freq_method == "nbw":
# Nominal bandwidth (OBW + center frequency)
try:
lower_freq, upper_freq, _ = calculate_nominal_bandwidth(segment, sample_rate, nfft, obw_power)
freq_lower = center_frequency + lower_freq
freq_upper = center_frequency + upper_freq
except (ValueError, IndexError):
# Fallback if calculation fails
freq_lower = center_frequency - (sample_rate / 2)
freq_upper = center_frequency + (sample_rate / 2)
elif freq_method == "obw":
# Occupied bandwidth
try:
_, f_lower, f_upper = calculate_occupied_bandwidth(segment, sample_rate, nfft, obw_power)
freq_lower = center_frequency + f_lower
freq_upper = center_frequency + f_upper
except (ValueError, IndexError):
# Fallback if calculation fails
freq_lower = center_frequency - (sample_rate / 2)
freq_upper = center_frequency + (sample_rate / 2)
elif freq_method == "full-detected":
# Full detected range (lowest to highest component)
try:
_, f_lower, f_upper = calculate_full_detected_bandwidth(segment, sample_rate, nfft)
freq_lower = center_frequency + f_lower
freq_upper = center_frequency + f_upper
except (ValueError, IndexError):
# Fallback if calculation fails
freq_lower = center_frequency - (sample_rate / 2)
freq_upper = center_frequency + (sample_rate / 2)
else:
# Segment too short for FFT, use full bandwidth
freq_lower = center_frequency - (sample_rate / 2)
freq_upper = center_frequency + (sample_rate / 2)
return freq_lower, freq_upper

View File

@ -1,435 +0,0 @@
"""
Parallel signal separation for multi-component frequency-offset signals.
Provides methods to detect and separate overlapping frequency-domain signals
that occupy the same time window but different frequency bands.
This module implements **spectral peak detection** to identify distinct frequency
components and split single time-domain annotations into frequency-specific
sub-annotations.
**Key Design Decisions** (per Codex review):
1. **Complex IQ Support**: Uses `scipy.signal.welch` with `return_onesided=False`
for proper complex signal handling. Window length automatically adapts to
signal length via `nperseg=min(nfft, len(signal))` to handle bursts <nfft.
2. **Frequency Representation**: Components are detected in **relative** frequency
(baseband, centered at 0 Hz). Caller must add RF center_frequency_hz when
writing to SigMF annotations. This separation of concerns avoids the frequency
context bug where absolute Hz would be meaningless for baseband processing.
3. **Bandwidth Estimation**: Dual strategy avoids -3dB limitations:
- Primary: -3dB rolloff for typical narrowband signals
- Fallback: Cumulative power (99% like OBW) for wide/OFDM signals
- Auto-fallback when -3dB bandwidth is anomalous
4. **Noise Floor**: Auto-estimated via `np.percentile(psd_db, 10)` from data
to adapt across hardware (Pluto vs. ThinkRF). User can override if needed.
5. **Filter Sizing (Optional FIR extraction)**: When extracting components,
uses Kaiser window FIR with proper stopband specification. Auto-sizes
numtaps based on desired transition bandwidth. Includes downsampling
guidance for long captures.
6. **CLI Surface**: Single `separate` subcommand for all separation operations.
Can be chained after any detector or used standalone on existing annotations.
Example:
Two WiFi channels captured simultaneously:
>>> from ria_toolkit_oss.annotations import find_spectral_components
>>> # Detect the two distinct channels (returns relative frequencies)
>>> components = find_spectral_components(signal, sampling_rate=20e6)
>>> print(f"Found {len(components)} components")
Found 2 components
The module is designed to work with detected time-domain annotations,
allowing splitting of overlapping signals into separate training samples.
"""
import json
from typing import List, Optional, Tuple
import numpy as np
from scipy import ndimage
from scipy import signal as scipy_signal
from ria_toolkit_oss.data import Annotation, Recording
def find_spectral_components(
signal_data: np.ndarray,
sampling_rate: float,
nfft: int = 65536,
noise_threshold_db: Optional[float] = None,
min_component_bw: float = 50e3,
time_percentile: float = 70.0,
) -> List[Tuple[float, float, float]]:
"""
Find distinct frequency components using spectral peak detection.
Identifies separate frequency components in a signal by analyzing the power
spectral density and finding peaks corresponding to distinct signals. This is
useful for separating parallel signals that occupy different frequency bands.
**Frequency Representation**: Returns frequencies in **baseband/relative** Hz
(centered at 0). To get absolute RF frequencies, add center_frequency_hz from
recording metadata to all returned values.
Algorithm:
1. Compute power spectral density using Welch (properly handles complex IQ)
2. Auto-estimate noise floor from data if not specified
3. Smooth PSD to reduce spurious peaks
4. Find local maxima above noise floor
5. Estimate bandwidth per peak using -3dB (fallback: cumulative power)
6. Filter components below minimum bandwidth threshold
:param signal_data: Complex IQ signal samples (np.complex64/128)
:type signal_data: np.ndarray
:param sampling_rate: Sample rate in Hz
:type sampling_rate: float
:param nfft: FFT size / window length for Welch. Automatically capped at
signal length to handle bursts (default: 65536)
:type nfft: int
:param noise_threshold_db: Minimum SNR threshold in dB. If None (default),
auto-estimates as np.percentile(psd_db, 10).
Adapt this across hardware (Pluto: ~-100, ThinkRF: ~-60).
:type noise_threshold_db: Optional[float]
:param min_component_bw: Minimum component bandwidth in Hz (default: 50 kHz)
:type min_component_bw: float
:param power_threshold: Cumulative power threshold for fallback bandwidth
estimation (default: 0.99 = 99% power, like OBW)
:type power_threshold: float
:returns: List of (center_freq_hz, lower_freq_hz, upper_freq_hz) tuples.
**All frequencies are relative (baseband, 0-centered).**
Add recording metadata['center_frequency'] to get absolute RF frequencies.
:rtype: List[Tuple[float, float, float]]
:raises ValueError: If signal has fewer than 256 samples
**Example**::
>>> from ria.io import load_recording
>>> from ria_toolkit_oss.annotations import find_spectral_components
>>> recording = load_recording("capture.sigmf")
>>> segment = recording.data[0][start:end]
>>> # Components in relative (baseband) frequency
>>> components = find_spectral_components(segment, sampling_rate=20e6)
>>> for center_rel, lower_rel, upper_rel in components:
... # Convert to absolute RF frequency
... center_abs = recording.metadata['center_frequency'] + center_rel
... print(f"Component @ {center_abs/1e9:.3f} GHz")
"""
# Validate input
min_samples = 256
if len(signal_data) < min_samples:
raise ValueError(f"Signal too short: need at least {min_samples} samples, " f"got {len(signal_data)}.")
# Compute PSD using Welch method for complex IQ signals
# CRITICAL: return_onesided=False for proper complex signal handling
nperseg = min(nfft, len(signal_data))
noverlap = nperseg // 2
# --- STFT ---
freqs, times, Zxx = scipy_signal.stft(
signal_data,
fs=sampling_rate,
window="blackman",
nperseg=nperseg,
noverlap=noverlap,
return_onesided=False,
boundary=None,
)
# Shift zero freq to center
Zxx = np.fft.fftshift(Zxx, axes=0)
freqs = np.fft.fftshift(freqs)
# Power spectrogram
power = np.abs(Zxx) ** 2
power_db = 10 * np.log10(power + 1e-12)
# --- Aggregate across time robustly ---
# Using percentile instead of mean prevents short signals from being diluted
freq_profile_db = np.percentile(power_db, time_percentile, axis=1)
# --- Noise floor estimation ---
if noise_threshold_db is None:
noise_threshold_db = np.percentile(freq_profile_db, 20)
threshold = noise_threshold_db + 3 # 3 dB above noise floor
# --- Smooth lightly (avoid merging nearby signals) ---
freq_profile_db = ndimage.gaussian_filter1d(freq_profile_db, sigma=1.5)
# --- Binary mask of significant frequencies ---
mask = freq_profile_db > threshold
# --- Find contiguous frequency regions ---
labeled, num_features = ndimage.label(mask)
components = []
for region_label in range(1, num_features + 1):
region_indices = np.where(labeled == region_label)[0]
if len(region_indices) == 0:
continue
lower_idx = region_indices[0]
upper_idx = region_indices[-1]
lower_freq = freqs[lower_idx]
upper_freq = freqs[upper_idx]
bw = upper_freq - lower_freq
if bw < min_component_bw:
continue
center_freq = (lower_freq + upper_freq) / 2
components.append((center_freq, lower_freq, upper_freq))
return components
def split_annotation_by_components(
annotation: Annotation,
signal: np.ndarray,
sampling_rate: float,
center_frequency_hz: float = 0.0,
nfft: int = 65536,
noise_threshold_db: Optional[float] = None,
min_component_bw: float = 50e3,
) -> List[Annotation]:
"""
Split an annotation into multiple annotations by detected frequency components.
Takes an existing annotation spanning multiple frequency components and
analyzes the frequency content to create separate sub-annotations for
each distinct frequency component.
**Use case**: Energy detection found a time window with 2-3 parallel WiFi
channels. This function splits it into separate annotations per channel.
**Frequency Handling**: `find_spectral_components` returns relative (baseband)
frequencies. This function adds `center_frequency_hz` to convert to absolute
RF frequencies for SigMF annotation bounds. This ensures correct frequency
context across baseband and RF domains.
:param annotation: Original annotation to split
:type annotation: Annotation
:param signal: Full signal array (complex IQ)
:type signal: np.ndarray
:param sampling_rate: Sample rate in Hz
:type sampling_rate: float
:param center_frequency_hz: RF center frequency to add to relative frequencies
from peak detection (default: 0.0 = baseband)
:type center_frequency_hz: float
:param nfft: FFT size for analysis (default: 65536, auto-capped at signal length)
:type nfft: int
:param noise_threshold_db: Noise floor threshold in dB. If None (default),
auto-estimates from data.
:type noise_threshold_db: Optional[float]
:param min_component_bw: Minimum component bandwidth in Hz (default: 50 kHz)
:type min_component_bw: float
:returns: List of new annotations (one per detected component).
Returns empty list if no components found or segment too short.
:rtype: List[Annotation]
**Example**::
>>> from ria.io import load_recording
>>> from ria_toolkit_oss.annotations import split_annotation_by_components
>>> recording = load_recording("capture.sigmf")
>>> # Original annotation spans multiple channels
>>> original = recording.annotations[0]
>>> # Split using RF center frequency from metadata
>>> components = split_annotation_by_components(
... original,
... recording.data[0],
... recording.metadata['sample_rate'],
... center_frequency_hz=recording.metadata.get('center_frequency', 0.0)
... )
>>> print(f"Split into {len(components)} components")
Split into 2 components
**Algorithm**:
1. Extract segment corresponding to annotation time bounds
2. Find frequency components in that segment (returns relative frequencies)
3. Add center_frequency_hz to get absolute RF frequencies
4. Create new annotation for each component
5. Preserve original metadata (label, type, etc.)
6. Add component info to comment JSON
**Notes**:
- Original annotation is not modified
- Returns empty list if segment too short (<256 samples)
- Segments <nfft get auto-downsampled to nfft (see find_spectral_components)
- Each component inherits label from original
- Component frequencies in comment JSON are absolute (RF) frequencies
"""
# Extract segment corresponding to annotation time bounds
start_sample = annotation.sample_start
end_sample = min(start_sample + annotation.sample_count, len(signal))
segment = signal[start_sample:end_sample]
# Validate segment length is enough for spectral analysis
if len(segment) < 256:
return []
# Find components in this segment (returns relative/baseband frequencies)
try:
components = find_spectral_components(segment, sampling_rate, nfft, noise_threshold_db, min_component_bw)
except ValueError:
# Spectral analysis failed (e.g., not complex IQ)
return []
if not components:
# No components found
return []
# Create annotations for each component
new_annotations = []
for center_freq_rel, lower_freq_rel, upper_freq_rel in components:
# Convert relative (baseband) frequencies to absolute (RF) frequencies
center_freq_abs = center_frequency_hz + center_freq_rel
lower_freq_abs = center_frequency_hz + lower_freq_rel
upper_freq_abs = center_frequency_hz + upper_freq_rel
# Parse original annotation metadata
try:
comment_data = json.loads(annotation.comment)
except (json.JSONDecodeError, TypeError):
comment_data = {"type": "standalone"}
# Add component information (with absolute RF frequencies)
comment_data["split_from_annotation"] = True
comment_data["original_freq_bounds"] = {
"lower": float(annotation.freq_lower_edge),
"upper": float(annotation.freq_upper_edge),
}
comment_data["component_freq_bounds_rf"] = {
"center": float(center_freq_abs),
"lower": float(lower_freq_abs),
"upper": float(upper_freq_abs),
}
# Create new annotation with absolute RF frequency bounds
new_anno = Annotation(
sample_start=annotation.sample_start,
sample_count=annotation.sample_count,
freq_lower_edge=lower_freq_abs,
freq_upper_edge=upper_freq_abs,
label=annotation.label,
comment=json.dumps(comment_data),
detail={
"generator": "parallel_signal_separator",
"center_freq_hz": float(center_freq_abs),
},
)
new_annotations.append(new_anno)
return new_annotations
def split_recording_annotations(
recording: Recording,
indices: Optional[List[int]] = None,
nfft: int = 65536,
noise_threshold_db: Optional[float] = None,
min_component_bw: float = 50e3,
) -> Recording:
"""
Split multiple annotations in a recording by frequency components.
Processes specified annotations (or all if indices=None), replacing each
with its frequency-separated components. Uses RF center_frequency from
recording metadata for proper absolute frequency conversion.
:param recording: Recording to process
:type recording: Recording
:param indices: Annotation indices to split (None = all, default: None).
Use indices=[] to skip splitting (returns unchanged recording).
:type indices: Optional[List[int]]
:param nfft: FFT size for spectral analysis (default: 65536,
auto-capped at signal segment length)
:type nfft: int
:param noise_threshold_db: Noise floor threshold in dB. If None (default),
auto-estimates from each segment.
:type noise_threshold_db: Optional[float]
:param min_component_bw: Minimum component bandwidth in Hz (default: 50 kHz).
Components narrower than this are filtered out.
:type min_component_bw: float
:returns: New Recording with split annotations
:rtype: Recording
**Example**::
>>> from ria.io import load_recording
>>> from ria_toolkit_oss.annotations import split_recording_annotations
>>> recording = load_recording("capture.sigmf")
>>> # Split all annotations
>>> split_rec = split_recording_annotations(recording)
>>> print(f"Original: {len(recording.annotations)} annotations")
>>> print(f"Split: {len(split_rec.annotations)} annotations")
Original: 5 annotations
Split: 9 annotations
**Algorithm**:
1. For each annotation in indices (or all if None):
2. Call split_annotation_by_components with RF center_frequency
3. If components found, replace annotation with components
4. If no components found, keep original annotation
5. Annotations not in indices are kept unchanged
**Notes**:
- Original recording is not modified
- Returns empty Recording.annotations if recording has no annotations
- RF center_frequency from metadata ensures correct absolute frequencies
- If an annotation can't be split (too short, wrong format), original kept
"""
if indices is None:
# Split all annotations
indices = list(range(len(recording.annotations)))
if not recording.annotations:
# No annotations to split
return recording
signal = recording.data[0]
sample_rate = recording.metadata["sample_rate"]
center_frequency = recording.metadata.get("center_frequency", 0.0)
# Build new annotation list
new_annotations = []
for i, anno in enumerate(recording.annotations):
if i in indices:
# Attempt to split this annotation
try:
components = split_annotation_by_components(
anno,
signal,
sample_rate,
center_frequency_hz=center_frequency,
nfft=nfft,
noise_threshold_db=noise_threshold_db,
min_component_bw=min_component_bw,
)
if components:
# Split successful, use components
new_annotations.extend(components)
else:
# No components found, keep original
new_annotations.append(anno)
except Exception:
# Split failed for any reason, keep original
new_annotations.append(anno)
else:
# Not in split list, keep as-is
new_annotations.append(anno)
return Recording(data=recording.data, metadata=recording.metadata, annotations=new_annotations)

View File

@ -1,35 +0,0 @@
import numpy as np
from ria_toolkit_oss.data import Recording
def qualify_slice_from_annotations(recording: Recording, slice_length: int):
"""
Slice a recording into many smaller recordings,
discarding any slices which do not have annotations that apply to those samples.
Used together with an annotation based qualifier.
:param recording: The recording to slice.
:type recording: Recording
:param slice_length: The length in samples of a slice.
:type slice_length: int"""
if len(recording.annotations) == 0:
print("Warning, no annotations.")
annotation_mask = np.zeros(len(recording.data[0]))
for annotation in recording.annotations:
annotation_mask[annotation.sample_start : annotation.sample_start + annotation.sample_count] = 1
output_recordings = []
for i in range((len(recording.data[0]) // slice_length) - 1):
start_index = slice_length * i
end_index = slice_length * (i + 1)
if 1 in annotation_mask[start_index:end_index]:
sl = recording.data[:, start_index:end_index]
output_recordings.append(Recording(data=sl, metadata=recording.metadata))
return output_recordings

View File

@ -1,97 +0,0 @@
import numpy as np
from scipy.signal import butter, lfilter
from ria_toolkit_oss.data.annotation import Annotation
from ria_toolkit_oss.data.recording import Recording
def isolate_signal(recording: Recording, annotation: Annotation) -> Recording:
"""
Slice, filter and frequency shift the input recording according to the bounding box defined by the annotation.
:param recording: The input Recording to be sliced.
:type recording: Recording
:param annotation: The Annotation object defining the area of the recording to isolate.
:type annotation: Annotation
:param decimate: Decimate the input signal after filtering to reduce the sample rate.
:type decimate: bool
:returns: The subsection of the original recording defined by the annotation.
:rtype: Recording"""
sample_start = max(0, annotation.sample_start)
sample_stop = min(len(recording), annotation.sample_start + annotation.sample_count)
anno_base_center_freq = (annotation.freq_lower_edge + annotation.freq_upper_edge) / 2 - recording.metadata.get(
"center_frequency", 0
)
anno_bw = annotation.freq_upper_edge - annotation.freq_lower_edge
signal_slice = recording.data[0, sample_start:sample_stop]
# normalize
signal_slice = signal_slice / np.max(np.abs(signal_slice))
isolation_bw = anno_bw
# frequency shift the center of the box about zero
shifted_signal_slice = frequency_shift_iq_samples(
iq_samples=signal_slice,
sample_rate=recording.metadata["sample_rate"],
shift_frequency=-1 * anno_base_center_freq,
)
# filter
if isolation_bw < recording.metadata["sample_rate"] - 1:
filtered_signal = apply_complex_lowpass_filter(
signal=shifted_signal_slice, cutoff_frequency=isolation_bw, sample_rate=recording.metadata["sample_rate"]
)
else:
filtered_signal = shifted_signal_slice
output = Recording(data=[filtered_signal], metadata=recording.metadata)
return output
def frequency_shift_iq_samples(iq_samples, sample_rate, shift_frequency):
# Number of samples
num_samples = len(iq_samples)
# Create a time vector from 0 to the total duration in seconds
time_vector = np.arange(num_samples) / sample_rate
# Generate the complex exponential for the frequency shift
complex_exponential = np.exp(1j * 2 * np.pi * shift_frequency * time_vector)
# Apply the frequency shift to the IQ samples
shifted_samples = iq_samples * complex_exponential
return shifted_samples
# Function to apply a lowpass Butterworth filter to a complex signal
def apply_complex_lowpass_filter(signal, cutoff_frequency, sample_rate, order=5):
# Design the lowpass filter
b, a = design_complex_lowpass_filter(cutoff_frequency, sample_rate, order)
# Apply the lowpass filter
filtered_signal = lfilter(b, a, signal)
return filtered_signal
def design_complex_lowpass_filter(cutoff_frequency, sample_rate, order=5):
# Nyquist frequency for complex signals is the sample rate
nyquist = sample_rate
# Ensure the cutoff frequency is positive and within the Nyquist limit
if cutoff_frequency <= 0 or cutoff_frequency > nyquist:
raise ValueError("Cutoff frequency must be between 0 and the Nyquist frequency.")
# Normalize the cutoff frequency to the Nyquist frequency
cutoff_normalized = cutoff_frequency / nyquist
# Create a Butterworth lowpass filter
b, a = butter(order, cutoff_normalized, btype="low")
return b, a

View File

@ -1,359 +0,0 @@
"""
Temporal signal detection and boundary refinement via Hysteresis Thresholding.
Provides methods to detect signal bursts in the time domain by triggering on
smoothed power peaks and expanding boundaries to capture the full energy envelope.
This module implements a **dual-threshold trigger** to solve the 'chatter'
problem in noisy environments, ensuring that signal annotations encapsulate
the entire rise and fall of a burst rather than just the peak.
**Key Design Decisions**:
1. **Hysteresis Logic (Dual-Threshold)**:
- **Trigger**: High threshold (`threshold * max_power`) ensures high confidence
in signal presence.
- **Boundary**: Low threshold (`0.5 * trigger`) allows the annotation to
"crawl" outward, capturing the lower-energy start and end of the burst
often missed by simple single-threshold detectors.
2. **Temporal Smoothing**: Uses a moving average window (`window_size`) prior
- to thresholding. This prevents high-frequency noise spikes from causing
fragmented annotations and provides a more stable estimate of the
signal's power envelope.
3. **Spectral Profiling**: Once a temporal segment is isolated, the module
- performs an automated FFT analysis. It identifies the **90% spectral
occupancy** to define the frequency boundaries (`f_min`, `f_max`),
allowing the detector to work on narrowband and wideband signals without
manual frequency tuning.
4. **Baseband/RF Mapping**: Automatically handles the conversion from
- relative FFT bin frequencies to absolute RF frequencies by referencing
`recording.metadata["center_frequency"]`.
5. **False Positive Mitigation**: Implements a hard minimum duration check
- (10ms) to ignore transient hardware spikes or noise floor fluctuations
that do not constitute a valid signal burst.
The module is designed to be the primary "first-pass" detector for pulsed
waveforms (like ADS-B, Lora, or bursty FSK) before passing them to
classification or demodulation stages.
"""
import json
from typing import Optional
import numpy as np
from ria_toolkit_oss.data import Annotation, Recording
def _find_ranges(indices, max_gap):
"""
Groups individual indices into continuous temporal ranges.
Args:
indices: Array of indices where the signal exceeded a threshold.
max_gap: Maximum gap allowed between indices to consider them part
of the same range.
Returns:
A list of (start, stop) tuples representing detected signal segments.
"""
if len(indices) == 0:
return []
start = indices[0]
prev = indices[0]
ranges = []
for i in range(1, len(indices)):
if indices[i] - prev > max_gap:
ranges.append((start, prev))
start = indices[i]
prev = indices[i]
ranges.append((start, prev))
return ranges
def _expand_and_filter_ranges(
smoothed_power: np.ndarray,
initial_ranges: list[tuple[int, int]],
boundary_val: float,
min_duration_samples: int,
) -> list[tuple[int, int]]:
"""Apply hysteresis expansion and minimum-duration filtering."""
out: list[tuple[int, int]] = []
n = len(smoothed_power)
for start, stop in initial_ranges:
if (stop - start) < min_duration_samples:
continue
true_start = start
while true_start > 0 and smoothed_power[true_start] > boundary_val:
true_start -= 1
true_stop = stop
while true_stop < n - 1 and smoothed_power[true_stop] > boundary_val:
true_stop += 1
if (true_stop - true_start) >= min_duration_samples:
out.append((true_start, true_stop))
return out
def _merge_ranges(ranges: list[tuple[int, int]], max_gap: int) -> list[tuple[int, int]]:
"""Merge overlapping or near-adjacent ranges."""
if not ranges:
return []
ranges = sorted(ranges, key=lambda r: r[0])
merged = [ranges[0]]
for s, e in ranges[1:]:
last_s, last_e = merged[-1]
if s <= last_e + max_gap:
merged[-1] = (last_s, max(last_e, e))
else:
merged.append((s, e))
return merged
def _estimate_noise_floor(power: np.ndarray, quantile: float = 20.0) -> float:
"""Estimate baseline from the quieter portion of the envelope."""
return float(np.percentile(power, quantile))
def _estimate_group_gap(sample_rate: float) -> int:
"""Use a fixed temporal grouping gap instead of reusing the smoothing window."""
return max(1, int(0.001 * sample_rate))
def _estimate_spectral_bounds(signal_segment: np.ndarray, sample_rate: float) -> tuple[float, float]:
"""Estimate occupied bandwidth from a smoothed magnitude spectrum."""
if len(signal_segment) == 0:
return -sample_rate / 4, sample_rate / 4
window = np.hanning(len(signal_segment))
windowed = signal_segment * window
fft_data = np.abs(np.fft.fftshift(np.fft.fft(windowed)))
fft_freqs = np.fft.fftshift(np.fft.fftfreq(len(signal_segment), 1 / sample_rate))
# Smooth the spectrum so noise-like wideband bursts form a contiguous mask
# instead of thousands of tiny isolated runs.
spectral_smooth_bins = max(5, min(257, (len(signal_segment) // 512) | 1))
spectral_kernel = np.ones(spectral_smooth_bins, dtype=np.float64) / spectral_smooth_bins
smoothed_fft = np.convolve(fft_data, spectral_kernel, mode="same")
spectral_floor = float(np.percentile(smoothed_fft, 20))
spectral_peak = float(np.max(smoothed_fft))
spectral_ratio = spectral_peak / max(spectral_floor, 1e-12)
if spectral_ratio < 1.2:
return -sample_rate / 4, sample_rate / 4
spectral_thresh = spectral_floor + 0.1 * (spectral_peak - spectral_floor)
sig_indices = np.where(smoothed_fft > spectral_thresh)[0]
if len(sig_indices) == 0:
peak_idx = int(np.argmax(smoothed_fft))
bin_hz = sample_rate / len(signal_segment)
half_bins = max(1, int(np.ceil(10_000.0 / bin_hz)))
lo_idx = max(0, peak_idx - half_bins)
hi_idx = min(len(smoothed_fft) - 1, peak_idx + half_bins)
else:
runs = _find_ranges(sig_indices, max_gap=max(1, spectral_smooth_bins // 2))
peak_idx = int(np.argmax(smoothed_fft))
lo_idx, hi_idx = min(
runs,
key=lambda run: 0 if run[0] <= peak_idx <= run[1] else min(abs(run[0] - peak_idx), abs(run[1] - peak_idx)),
)
# Prevent extremely narrow tone boxes from collapsing to just a few bins.
min_total_bw_hz = 20_000.0
min_half_bins = max(1, int(np.ceil((min_total_bw_hz / 2) / (sample_rate / len(signal_segment)))))
center_idx = int(round((lo_idx + hi_idx) / 2))
lo_idx = max(0, min(lo_idx, center_idx - min_half_bins))
hi_idx = min(len(smoothed_fft) - 1, max(hi_idx, center_idx + min_half_bins))
return float(fft_freqs[lo_idx]), float(fft_freqs[hi_idx])
def threshold_qualifier(
recording: Recording,
threshold: float,
window_size: Optional[int] = None,
label: Optional[str] = None,
annotation_type: Optional[str] = "standalone",
channel: int = 0,
) -> Recording:
"""
Annotate a recording with bounding boxes for regions above a threshold.
Threshold is defined as a fraction of the maximum sample magnitude.
This algorithm searches for samples above the threshold and combines them into ranges if they
are within window_size of each other.
Detects and annotates signals using energy thresholding and spectral analysis.
The algorithm follows these steps:
1. Smooths power data using a moving average.
2. Identifies 'peak' regions exceeding a high trigger threshold.
3. Uses hysteresis to expand boundaries until power drops below a lower threshold.
4. Performs an FFT on each segment to determine frequency occupancy.
Args:
recording: The Recording object containing IQ or real signal data.
threshold: Sensitivity multiplier (0.0 to 1.0) applied to max power.
window_size: Size of the smoothing filter in samples. Defaults to 1ms worth of samples.
label: Custom string label for annotations.
annotation_type: Metadata string for the 'type' field in the annotation.
channel: Index of the channel to annotate. Defaults to 0.
Returns:
A new Recording object populated with detected Annotations.
"""
# Extract signal and metadata
sample_data = recording.data[channel]
sample_rate = recording.metadata["sample_rate"]
center_frequency = recording.metadata.get("center_frequency", 0)
if window_size is None:
window_size = max(64, int(sample_rate * 0.001))
# --- 1. SIGNAL CONDITIONING ---
# Convert to power (Magnitude squared)
power_data = np.abs(sample_data) ** 2
smoothing_window = np.ones(window_size) / window_size
smoothed_power = np.convolve(power_data, smoothing_window, mode="same")
group_gap_samples = _estimate_group_gap(sample_rate)
# Define thresholds using peak relative to baseline.
max_power = np.max(smoothed_power)
noise_floor = _estimate_noise_floor(smoothed_power)
dynamic_range_ratio = max_power / max(noise_floor, 1e-12)
# Soft early exit: keep a guard for low-contrast noise, but compute it from
# the quieter tail of the envelope so burst-heavy captures are not rejected.
if dynamic_range_ratio < 1.5:
return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations)
trigger_val = noise_floor + threshold * (max_power - noise_floor)
boundary_val = noise_floor + 0.5 * threshold * (max_power - noise_floor)
# --- 2. INITIAL DETECTION ---
# Enforce an explicit minimum duration in seconds; this is stable across
# varying capture lengths and avoids over-fitting to recording length.
min_duration_samples = max(1, int(0.005 * sample_rate))
annotations = []
# Pass 1: Detect stronger bursts.
indices = np.where(smoothed_power > trigger_val)[0]
pass1_initial = _find_ranges(indices=indices, max_gap=group_gap_samples)
pass1_ranges = _expand_and_filter_ranges(
smoothed_power=smoothed_power,
initial_ranges=pass1_initial,
boundary_val=boundary_val,
min_duration_samples=min_duration_samples,
)
# Pass 2: Recover weaker bursts on residual power not already covered.
# This improves recall in mixed-amplitude captures.
# Expand each Pass-1 range by the smoothing window on both sides so the
# smoothing skirts of a strong burst are not re-detected as a weak burst
# immediately adjacent to it (mirrors the guard used in Pass 3).
mask = np.ones_like(smoothed_power, dtype=np.float32)
pass2_mask_expand = window_size
for s, e in pass1_ranges:
mask[max(0, s - pass2_mask_expand) : min(len(mask), e + pass2_mask_expand)] = 0.0
residual_power = smoothed_power * mask
residual_max = float(np.max(residual_power))
residual_ratio = residual_max / max(noise_floor, 1e-12)
pass2_ranges: list[tuple[int, int]] = []
if residual_ratio >= 2.0:
weak_threshold = max(0.3, threshold * 0.7)
weak_trigger = noise_floor + weak_threshold * (residual_max - noise_floor)
weak_boundary = noise_floor + 0.5 * weak_threshold * (residual_max - noise_floor)
weak_indices = np.where(residual_power > weak_trigger)[0]
pass2_initial = _find_ranges(indices=weak_indices, max_gap=group_gap_samples)
pass2_ranges = _expand_and_filter_ranges(
smoothed_power=residual_power,
initial_ranges=pass2_initial,
boundary_val=weak_boundary,
min_duration_samples=min_duration_samples,
)
# Pass 3: Detect sustained faint bursts via macro-window averaging.
# Targets bursts whose peak power is near the trigger level but whose
# *average* power is consistently elevated above the noise floor — these
# are missed by peak-based detection because only a few short spikes exceed
# the trigger, all too brief to pass the minimum-duration filter.
#
# The mask is applied to power_data *before* convolving so that bright
# burst energy does not bleed through the long window into adjacent regions,
# which would inflate macro_residual_max and push the trigger above the
# faint burst's average power.
macro_window_size = max(window_size * 16, int(sample_rate * 0.02))
macro_kernel = np.ones(macro_window_size, dtype=np.float64) / macro_window_size
# Expand each annotated range by half the macro window on both sides so that
# the long convolution cannot "see" the leading/trailing edges of already-
# annotated bursts, which would produce spurious short fragments in Pass 3.
macro_expand = macro_window_size * 2
masked_power_for_macro = power_data.copy()
n = len(masked_power_for_macro)
for s, e in pass1_ranges + pass2_ranges:
masked_power_for_macro[max(0, s - macro_expand) : min(n, e + macro_expand)] = 0.0
macro_residual = np.convolve(masked_power_for_macro, macro_kernel, mode="same")
macro_residual_max = float(np.max(macro_residual))
pass3_ranges: list[tuple[int, int]] = []
if macro_residual_max / max(noise_floor, 1e-12) >= 1.3:
macro_trigger = noise_floor + threshold * (macro_residual_max - noise_floor)
macro_boundary = noise_floor + 0.5 * threshold * (macro_residual_max - noise_floor)
macro_indices = np.where(macro_residual > macro_trigger)[0]
macro_initial = _find_ranges(indices=macro_indices, max_gap=group_gap_samples)
pass3_ranges = _expand_and_filter_ranges(
smoothed_power=macro_residual,
initial_ranges=macro_initial,
boundary_val=macro_boundary,
min_duration_samples=min_duration_samples,
)
all_ranges = _merge_ranges(pass1_ranges + pass2_ranges + pass3_ranges, max_gap=group_gap_samples)
for true_start, true_stop in all_ranges:
# --- 4. SPECTRAL ANALYSIS (Frequency Detection) ---
signal_segment = sample_data[true_start:true_stop]
f_min, f_max = _estimate_spectral_bounds(signal_segment, sample_rate)
# --- 5. ANNOTATION GENERATION ---
ann_label = label if label is not None else f"{int(threshold*100)}%"
# Pack metadata for the UI/Downstream processing
comment_data = {
"type": annotation_type,
"generator": "threshold_qualifier",
"params": {
"threshold": threshold,
"window_size": window_size,
},
}
anno = Annotation(
sample_start=true_start,
sample_count=true_stop - true_start,
freq_lower_edge=center_frequency + f_min,
freq_upper_edge=center_frequency + f_max,
label=ann_label,
comment=json.dumps(comment_data),
detail={"generator": "hysteresis_qualifier"},
)
annotations.append(anno)
# Return a new Recording object including the new annotations
return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + annotations)

View File

@ -1 +0,0 @@
"""App runner: pull and run containerized RIA applications."""

View File

@ -1,278 +0,0 @@
"""Unified ``ria-app`` CLI.
Subcommands:
- ``ria-app pull <app>[:tag]`` pull a RIA app image from the configured registry.
- ``ria-app run <app>[:tag]`` pull (if needed) and run, auto-configuring
GPU/USB/network flags from image labels set by CI.
- ``ria-app list`` list locally cached RIA app images.
- ``ria-app stop <app>`` stop a running app container.
- ``ria-app logs <app>`` tail logs of a running app container.
- ``ria-app configure`` set default registry/namespace.
Image references resolve as::
my-classifier -> {registry}/{namespace}/my-classifier:latest
group/my-classifier -> {registry}/group/my-classifier:latest
host/group/app:tag -> host/group/app:tag (fully-qualified passthrough)
"""
from __future__ import annotations
import argparse
import json
import os
import shutil
import subprocess
import sys
from . import config as _config
_LABEL_PROFILE = "ria.profile"
_LABEL_HARDWARE = "ria.hardware"
_LABEL_APP = "ria.app"
def _engine(cfg: _config.AppConfig, sudo_override: bool = False) -> list[str]:
for exe in ("docker", "podman"):
if shutil.which(exe):
use_sudo = sudo_override or cfg.sudo
return ["sudo", exe] if use_sudo else [exe]
print("error: neither 'docker' nor 'podman' found on PATH", file=sys.stderr)
sys.exit(2)
def _resolve_ref(app: str, cfg: _config.AppConfig) -> str:
ref = app if ":" in app.split("/")[-1] else f"{app}:latest"
slashes = ref.count("/")
if slashes >= 2:
return ref
if slashes == 1:
return f"{cfg.registry}/{ref}" if cfg.registry else ref
if not cfg.registry or not cfg.namespace:
print(
"error: app is not fully qualified and no default registry/namespace configured. "
"Run `ria-app configure` or pass a full image reference (registry/namespace/app:tag).",
file=sys.stderr,
)
sys.exit(2)
return f"{cfg.registry}/{cfg.namespace}/{ref}"
def _container_name(ref: str) -> str:
name = ref.rsplit("/", 1)[-1].split(":", 1)[0]
return f"ria-app-{name}"
def _inspect_labels(engine: list[str], ref: str) -> dict:
try:
out = subprocess.check_output(
[*engine, "image", "inspect", "--format", "{{json .Config.Labels}}", ref],
stderr=subprocess.DEVNULL,
)
except subprocess.CalledProcessError:
return {}
try:
return json.loads(out.decode().strip()) or {}
except json.JSONDecodeError:
return {}
def _gpu_available() -> bool:
if os.path.exists("/dev/nvidia0"):
return True
return shutil.which("nvidia-smi") is not None
def _hardware_flags(labels: dict, no_gpu: bool, no_usb: bool, no_host_net: bool) -> tuple[list[str], list[str]]:
flags: list[str] = []
notes: list[str] = []
profile = (labels.get(_LABEL_PROFILE) or "").lower()
hardware = (labels.get(_LABEL_HARDWARE) or "").lower()
hw_items = {h.strip() for h in hardware.split(",") if h.strip()}
wants_gpu = any(k in profile for k in ("nvidia", "holoscan", "cuda"))
if wants_gpu and not no_gpu:
if _gpu_available():
flags += ["--gpus", "all"]
else:
notes.append(
"image wants GPU but no NVIDIA runtime detected — skipping --gpus (use --force-gpu to override)"
)
if hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"} and not no_usb:
flags += ["--device", "/dev/bus/usb"]
if hw_items & {"usrp", "thinkrf", "pluto"} and not no_host_net:
flags += ["--net", "host"]
return flags, notes
def _cmd_configure(args: argparse.Namespace) -> int:
cfg = _config.load()
if args.registry:
cfg.registry = args.registry
if args.namespace:
cfg.namespace = args.namespace
if args.sudo is not None:
cfg.sudo = args.sudo
path = _config.save(cfg)
print(f"Saved app config to {path}")
print(f" registry: {cfg.registry or '(unset)'}")
print(f" namespace: {cfg.namespace or '(unset)'}")
print(f" sudo: {cfg.sudo}")
return 0
def _cmd_pull(args: argparse.Namespace) -> int:
cfg = _config.load()
engine = _engine(cfg, args.sudo)
ref = _resolve_ref(args.app, cfg)
print(f"Pulling {ref}")
return subprocess.call([*engine, "pull", ref])
def _cmd_run(args: argparse.Namespace) -> int:
cfg = _config.load()
engine = _engine(cfg, args.sudo)
ref = _resolve_ref(args.app, cfg)
if not _inspect_labels(engine, ref):
rc = subprocess.call([*engine, "pull", ref])
if rc != 0:
return rc
labels = _inspect_labels(engine, ref)
no_gpu = args.no_gpu and not args.force_gpu
hw_flags, notes = _hardware_flags(labels, no_gpu=no_gpu, no_usb=args.no_usb, no_host_net=args.no_host_net)
if args.force_gpu and "--gpus" not in hw_flags:
hw_flags = ["--gpus", "all", *hw_flags]
cmd = [*engine, "run", "--rm"]
if not args.foreground:
cmd += ["-d"]
cmd += ["--name", args.name or _container_name(ref)]
cmd += hw_flags
if args.config:
cmd += ["-v", f"{args.config}:/config/config.yaml:ro", "-e", "RIA_CONFIG=/config/config.yaml"]
for env in args.env or []:
cmd += ["-e", env]
for vol in args.volume or []:
cmd += ["-v", vol]
for port in args.publish or []:
cmd += ["-p", port]
cmd += list(args.docker_args or [])
cmd += [ref]
cmd += list(args.app_args or [])
if args.dry_run:
print(" ".join(cmd))
return 0
label_str = ", ".join(f"{k}={v}" for k, v in labels.items() if k.startswith("ria.")) or "(no ria.* labels)"
print(f"Running {ref} [{label_str}]")
if hw_flags:
print(f" auto flags: {' '.join(hw_flags)}")
for note in notes:
print(f" note: {note}")
return subprocess.call(cmd)
def _cmd_list(args: argparse.Namespace) -> int:
cfg = _config.load()
engine = _engine(cfg, args.sudo)
return subprocess.call(
[
*engine,
"images",
"--filter",
f"label={_LABEL_APP}",
"--format",
"table {{.Repository}}:{{.Tag}}\t{{.ID}}\t{{.Size}}",
]
)
def _cmd_stop(args: argparse.Namespace) -> int:
cfg = _config.load()
engine = _engine(cfg, args.sudo)
name = args.name or _container_name(_resolve_ref(args.app, cfg))
return subprocess.call([*engine, "stop", name])
def _cmd_logs(args: argparse.Namespace) -> int:
cfg = _config.load()
engine = _engine(cfg, args.sudo)
name = args.name or _container_name(_resolve_ref(args.app, cfg))
cmd = [*engine, "logs"]
if args.follow:
cmd += ["-f"]
cmd += [name]
return subprocess.call(cmd)
def main() -> None:
parser = argparse.ArgumentParser(prog="ria-app")
parser.add_argument("--sudo", action="store_true", default=False, help="Run docker/podman via sudo")
sub = parser.add_subparsers(dest="command", required=True)
p_cfg = sub.add_parser("configure", help="Set default registry/namespace")
p_cfg.add_argument("--registry", default=None, help="Default container registry (e.g. registry.riahub.ai)")
p_cfg.add_argument("--namespace", default=None, help="Default namespace (e.g. qoherent)")
p_cfg.add_argument(
"--sudo",
dest="sudo",
action=argparse.BooleanOptionalAction,
default=None,
help="Persist sudo default (--sudo / --no-sudo)",
)
p_pull = sub.add_parser("pull", help="Pull an app image")
p_pull.add_argument("app", help="App name or image reference")
p_run = sub.add_parser("run", help="Run an app, auto-detecting hardware flags")
p_run.add_argument("app", help="App name or image reference")
p_run.add_argument("--name", default=None, help="Container name (default: ria-app-<app>)")
p_run.add_argument("--config", default=None, help="Path to config.yaml to mount into the container")
p_run.add_argument("-e", "--env", action="append", help="Extra env var (KEY=VALUE)")
p_run.add_argument("-v", "--volume", action="append", help="Extra volume mount")
p_run.add_argument("-p", "--publish", action="append", help="Publish port")
p_run.add_argument("--foreground", "-F", action="store_true", help="Run in foreground (no -d)")
p_run.add_argument("--no-gpu", action="store_true", help="Skip --gpus flag even if image wants GPU")
p_run.add_argument("--force-gpu", action="store_true", help="Force --gpus all even if no NVIDIA runtime detected")
p_run.add_argument("--no-usb", action="store_true", help="Skip --device /dev/bus/usb")
p_run.add_argument("--no-host-net", action="store_true", help="Skip --net host")
p_run.add_argument("--dry-run", action="store_true", help="Print the container command and exit")
p_run.add_argument("--docker-args", nargs=argparse.REMAINDER, help="Pass remaining args to docker/podman run")
p_run.add_argument("--app-args", nargs=argparse.REMAINDER, help="Pass remaining args to the app entrypoint")
sub.add_parser("list", help="List locally cached RIA app images")
p_stop = sub.add_parser("stop", help="Stop a running app")
p_stop.add_argument("app", help="App name or image reference")
p_stop.add_argument("--name", default=None, help="Container name override")
p_logs = sub.add_parser("logs", help="Tail logs of a running app")
p_logs.add_argument("app", help="App name or image reference")
p_logs.add_argument("--name", default=None, help="Container name override")
p_logs.add_argument("-f", "--follow", action="store_true", help="Follow log output")
args = parser.parse_args()
dispatch = {
"configure": _cmd_configure,
"pull": _cmd_pull,
"run": _cmd_run,
"list": _cmd_list,
"stop": _cmd_stop,
"logs": _cmd_logs,
}
sys.exit(dispatch[args.command](args))
if __name__ == "__main__":
main()

View File

@ -1,51 +0,0 @@
"""App runner configuration at ``~/.ria/toolkit.json``.
Schema::
{
"registry": "registry.riahub.ai",
"namespace": "qoherent"
}
"""
from __future__ import annotations
import json
import os
from dataclasses import asdict, dataclass
from pathlib import Path
_DEFAULT_PATH = Path(os.environ.get("RIA_TOOLKIT_CONFIG", str(Path.home() / ".ria" / "toolkit.json")))
@dataclass
class AppConfig:
registry: str = ""
namespace: str = ""
sudo: bool = False
def default_path() -> Path:
return _DEFAULT_PATH
def load(path: Path | None = None) -> AppConfig:
p = path or _DEFAULT_PATH
if not p.exists():
return AppConfig(
registry=os.environ.get("RIA_REGISTRY", ""),
namespace=os.environ.get("RIA_NAMESPACE", ""),
)
data = json.loads(p.read_text())
return AppConfig(
registry=data.get("registry", "") or os.environ.get("RIA_REGISTRY", ""),
namespace=data.get("namespace", "") or os.environ.get("RIA_NAMESPACE", ""),
sudo=bool(data.get("sudo", False)) or os.environ.get("RIA_DOCKER_SUDO", "") not in ("", "0", "false"),
)
def save(cfg: AppConfig, path: Path | None = None) -> Path:
p = path or _DEFAULT_PATH
p.parent.mkdir(parents=True, exist_ok=True)
p.write_text(json.dumps(asdict(cfg), indent=2))
return p

View File

@ -1,8 +0,0 @@
"""
The Data package contains abstract data types tailored for radio machine learning, such as ``Recording``, as well
as the abstract interfaces for the radio dataset and radio dataset builder framework.
"""
__all__ = ["Annotation", "Recording"]
from .annotation import Annotation
from .recording import Recording

View File

@ -0,0 +1,8 @@
"""
The datatypes package contains abstract data types tailored for radio machine learning.
"""
__all__ = ["Annotation", "Recording"]
from .annotation import Annotation
from .recording import Recording

View File

@ -7,8 +7,8 @@ from typing import Any, Optional
from packaging.version import Version from packaging.version import Version
from ria_toolkit_oss.data.datasets.license.dataset_license import DatasetLicense from ria_toolkit_oss.datatypes.datasets.license.dataset_license import DatasetLicense
from ria_toolkit_oss.data.datasets.radio_dataset import RadioDataset from ria_toolkit_oss.datatypes.datasets.radio_dataset import RadioDataset
from ria_toolkit_oss.utils.abstract_attribute import abstract_attribute from ria_toolkit_oss.utils.abstract_attribute import abstract_attribute

View File

@ -7,11 +7,11 @@ from typing import Optional
import h5py import h5py
import numpy as np import numpy as np
from ria_toolkit_oss.data.datasets.h5helpers import ( from ria_toolkit_oss.datatypes.datasets.h5helpers import (
append_entry_inplace, append_entry_inplace,
copy_dataset_entry_by_index, copy_dataset_entry_by_index,
) )
from ria_toolkit_oss.data.datasets.radio_dataset import RadioDataset from ria_toolkit_oss.datatypes.datasets.radio_dataset import RadioDataset
class IQDataset(RadioDataset, ABC): class IQDataset(RadioDataset, ABC):
@ -19,7 +19,7 @@ class IQDataset(RadioDataset, ABC):
radiofrequency (RF) signals represented as In-phase (I) and Quadrature (Q) samples. radiofrequency (RF) signals represented as In-phase (I) and Quadrature (Q) samples.
For machine learning tasks that involve processing spectrograms, please use For machine learning tasks that involve processing spectrograms, please use
ria_toolkit_oss.data.datasets.SpectDataset instead. ria_toolkit_oss.datatypes.datasets.SpectDataset instead.
This is an abstract interface defining common properties and behaviour of IQDatasets. Therefore, this class This is an abstract interface defining common properties and behaviour of IQDatasets. Therefore, this class
should not be instantiated directly. Instead, it is subclassed to define custom interfaces for specific machine should not be instantiated directly. Instead, it is subclassed to define custom interfaces for specific machine

View File

@ -12,7 +12,7 @@ import numpy as np
import pandas as pd import pandas as pd
from numpy.typing import ArrayLike from numpy.typing import ArrayLike
from ria_toolkit_oss.data.datasets.h5helpers import ( from ria_toolkit_oss.datatypes.datasets.h5helpers import (
append_entry_inplace, append_entry_inplace,
copy_file, copy_file,
copy_over_example, copy_over_example,
@ -29,7 +29,7 @@ class RadioDataset(ABC):
This is an abstract interface defining common properties and behavior of radio datasets. Therefore, this class This is an abstract interface defining common properties and behavior of radio datasets. Therefore, this class
should not be instantiated directly. Instead, it should be subclassed to define specific interfaces for different should not be instantiated directly. Instead, it should be subclassed to define specific interfaces for different
types of radio datasets. For example, see ria_toolkit_oss.data.datasets.IQDataset, which is a radio dataset types of radio datasets. For example, see ria_toolkit_oss.datatypes.datasets.IQDataset, which is a radio dataset
subclass tailored for tasks involving the processing of radio signals represented as IQ (In-phase and Quadrature) subclass tailored for tasks involving the processing of radio signals represented as IQ (In-phase and Quadrature)
samples. samples.

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import os import os
from abc import ABC from abc import ABC
from ria_toolkit_oss.data.datasets.radio_dataset import RadioDataset from ria_toolkit_oss.datatypes.datasets.radio_dataset import RadioDataset
class SpectDataset(RadioDataset, ABC): class SpectDataset(RadioDataset, ABC):
@ -13,7 +13,7 @@ class SpectDataset(RadioDataset, ABC):
radio signal spectrograms. radio signal spectrograms.
For machine learning tasks that involve processing on IQ samples, please use For machine learning tasks that involve processing on IQ samples, please use
ria_toolkit_oss.data.datasets.IQDataset instead. ria_toolkit_oss.datatypes.datasets.IQDataset instead.
This is an abstract interface defining common properties and behaviour of IQDatasets. Therefore, this class This is an abstract interface defining common properties and behaviour of IQDatasets. Therefore, this class
should not be instantiated directly. Instead, it is subclassed to define custom interfaces for specific machine should not be instantiated directly. Instead, it is subclassed to define custom interfaces for specific machine

View File

@ -6,8 +6,11 @@ from typing import Optional
import numpy as np import numpy as np
from numpy.random import Generator from numpy.random import Generator
from ria_toolkit_oss.data.datasets import RadioDataset from ria_toolkit_oss.datatypes.datasets import RadioDataset
from ria_toolkit_oss.data.datasets.h5helpers import copy_over_example, make_empty_clone from ria_toolkit_oss.datatypes.datasets.h5helpers import (
copy_over_example,
make_empty_clone,
)
def split(dataset: RadioDataset, lengths: list[int | float]) -> list[RadioDataset]: def split(dataset: RadioDataset, lengths: list[int | float]) -> list[RadioDataset]:
@ -28,7 +31,7 @@ def split(dataset: RadioDataset, lengths: list[int | float]) -> list[RadioDatase
cases. cases.
This function is deterministic, meaning it will always produce the same split. For a random split, see This function is deterministic, meaning it will always produce the same split. For a random split, see
ria_toolkit_oss.data.datasets.random_split. ria_toolkit_oss.datatypes.datasets.random_split.
:param dataset: Dataset to be split. :param dataset: Dataset to be split.
:type dataset: RadioDataset :type dataset: RadioDataset
@ -47,7 +50,7 @@ def split(dataset: RadioDataset, lengths: list[int | float]) -> list[RadioDatase
>>> import string >>> import string
>>> import numpy as np >>> import numpy as np
>>> import pandas as pd >>> import pandas as pd
>>> from ria_toolkit_oss.data.datasets import split >>> from ria_toolkit_oss.datatypes.datasets import split
First, let's generate some random data: First, let's generate some random data:
@ -123,7 +126,7 @@ def random_split(
training and test datasets. training and test datasets.
This restriction makes it unlikely that a random split will produce datasets with the exact lengths specified. This restriction makes it unlikely that a random split will produce datasets with the exact lengths specified.
If it is important to ensure the closest possible split, consider using ria_toolkit_oss.data.datasets.split If it is important to ensure the closest possible split, consider using ria_toolkit_oss.datatypes.datasets.split
instead. instead.
:param dataset: Dataset to be split. :param dataset: Dataset to be split.
@ -141,7 +144,7 @@ def random_split(
:rtype: list of RadioDataset :rtype: list of RadioDataset
See Also: See Also:
ria_toolkit_oss.data.datasets.split: Usage is the same as for ``random_split()``. ria_toolkit_oss.datatypes.datasets.split: Usage is the same as for ``random_split()``.
""" """
if not isinstance(dataset, RadioDataset): if not isinstance(dataset, RadioDataset):
raise ValueError(f"'dataset' must be RadioDataset or one of its subclasses, got {type(dataset)}.") raise ValueError(f"'dataset' must be RadioDataset or one of its subclasses, got {type(dataset)}.")

View File

@ -12,7 +12,7 @@ from typing import Any, Iterator, Optional
import numpy as np import numpy as np
from numpy.typing import ArrayLike from numpy.typing import ArrayLike
from ria_toolkit_oss.data.annotation import Annotation from ria_toolkit_oss.datatypes.annotation import Annotation
PROTECTED_KEYS = ["rec_id", "timestamp"] PROTECTED_KEYS = ["rec_id", "timestamp"]
@ -26,7 +26,7 @@ class Recording:
Metadata is stored in a dictionary of key value pairs, Metadata is stored in a dictionary of key value pairs,
to include information such as sample_rate and center_frequency. to include information such as sample_rate and center_frequency.
Annotations are a list of :class:`~ria_toolkit_oss.data.Annotation`, Annotations are a list of :class:`~ria_toolkit_oss.datatypes.Annotation`,
defining bounding boxes in time and frequency with labels and metadata. defining bounding boxes in time and frequency with labels and metadata.
Here, signal data is represented as a NumPy array. This class is then extended in the RIA Backends to provide Here, signal data is represented as a NumPy array. This class is then extended in the RIA Backends to provide
@ -46,7 +46,7 @@ class Recording:
:param metadata: Additional information associated with the recording. :param metadata: Additional information associated with the recording.
:type metadata: dict, optional :type metadata: dict, optional
:param annotations: A collection of :class:`~ria_toolkit_oss.data.Annotation` objects defining bounding boxes. :param annotations: A collection of :class:`~ria_toolkit_oss.datatypes.Annotation` objects defining bounding boxes.
:type annotations: list of Annotations, optional :type annotations: list of Annotations, optional
:param dtype: Explicitly specify the data-type of the complex samples. Must be a complex NumPy type, such as :param dtype: Explicitly specify the data-type of the complex samples. Must be a complex NumPy type, such as
@ -66,7 +66,7 @@ class Recording:
**Examples:** **Examples:**
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.data import Recording, Annotation >>> from ria_toolkit_oss.datatypes import Recording, Annotation
>>> # Create an array of complex samples, just 1s in this case. >>> # Create an array of complex samples, just 1s in this case.
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
@ -244,7 +244,7 @@ class Recording:
@property @property
def sample_rate(self) -> float | None: def sample_rate(self) -> float | None:
""" """
:return: Sample rate of the recording, or None if 'sample_rate' is not in metadata. :return: Sample rate of the recording, or None is 'sample_rate' is not in metadata.
:type: str :type: str
""" """
return self.metadata.get("sample_rate") return self.metadata.get("sample_rate")
@ -311,7 +311,7 @@ class Recording:
Create a recording and add metadata: Create a recording and add metadata:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.data import Recording >>> from ria_toolkit_oss.datatypes import Recording
>>> >>>
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = { >>> metadata = {
@ -366,7 +366,7 @@ class Recording:
Create a recording and update metadata: Create a recording and update metadata:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.data import Recording >>> from ria_toolkit_oss.datatypes import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = { >>> metadata = {
@ -421,7 +421,7 @@ class Recording:
Create a recording and add metadata: Create a recording and add metadata:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.data import Recording >>> from ria_toolkit_oss.datatypes import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = { >>> metadata = {
@ -454,7 +454,7 @@ class Recording:
:param output_path: The output image path. Defaults to "images/signal.png". :param output_path: The output image path. Defaults to "images/signal.png".
:type output_path: str, optional :type output_path: str, optional
:param kwargs: Keyword arguments passed on to ria_toolkit_oss.view.view_sig. :param kwargs: Keyword arguments passed on to utils.view.view_sig.
:type: dict of keyword arguments :type: dict of keyword arguments
**Examples:** **Examples:**
@ -462,7 +462,7 @@ class Recording:
Create a recording and view it as a plot in a .png image: Create a recording and view it as a plot in a .png image:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.data import Recording >>> from utils.data import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = { >>> metadata = {
@ -480,7 +480,7 @@ class Recording:
def simple_view(self, **kwargs) -> None: def simple_view(self, **kwargs) -> None:
"""Create a plot of various signal visualizations as a PNG or SVG image. """Create a plot of various signal visualizations as a PNG or SVG image.
:param kwargs: Keyword arguments passed on to ria_toolkit_oss.view.view_signal_simple.view_simple_sig. :param kwargs: Keyword arguments passed on to utils.view.view_signal_simple.create_plots.
:type: dict of keyword arguments :type: dict of keyword arguments
**Examples:** **Examples:**
@ -488,7 +488,7 @@ class Recording:
Create a recording and view it as a plot in a .png image: Create a recording and view it as a plot in a .png image:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.data import Recording >>> from utils.data import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = { >>> metadata = {
@ -511,7 +511,7 @@ class Recording:
The SigMF io format is defined by the `SigMF Specification Project <https://github.com/sigmf/SigMF>`_ The SigMF io format is defined by the `SigMF Specification Project <https://github.com/sigmf/SigMF>`_
:param recording: The recording to be written to file. :param recording: The recording to be written to file.
:type recording: ria_toolkit_oss.data.Recording :type recording: ria_toolkit_oss.datatypes.Recording
:param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename. :param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename.
:type filename: os.PathLike or str, optional :type filename: os.PathLike or str, optional
:param path: The directory path to where the recording is to be saved. Defaults to recordings/. :param path: The directory path to where the recording is to be saved. Defaults to recordings/.
@ -545,7 +545,7 @@ class Recording:
Create a recording and save it to a .npy file: Create a recording and save it to a .npy file:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.data import Recording >>> from ria_toolkit_oss.datatypes import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = { >>> metadata = {
@ -596,7 +596,7 @@ class Recording:
Create a recording and save it to a .wav file: Create a recording and save it to a .wav file:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.data import Recording >>> from utils.data import Recording
>>> samples = numpy.exp(1j * 2 * numpy.pi * 0.1 * numpy.arange(10000)) >>> samples = numpy.exp(1j * 2 * numpy.pi * 0.1 * numpy.arange(10000))
>>> metadata = {"sample_rate": 1e6, "center_frequency": 915e6} >>> metadata = {"sample_rate": 1e6, "center_frequency": 915e6}
>>> recording = Recording(data=samples, metadata=metadata) >>> recording = Recording(data=samples, metadata=metadata)
@ -646,7 +646,7 @@ class Recording:
Create a recording and save it to a .blue file: Create a recording and save it to a .blue file:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.data import Recording >>> from utils.data import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = {"sample_rate": 1e6, "center_frequency": 2.44e9} >>> metadata = {"sample_rate": 1e6, "center_frequency": 2.44e9}
>>> recording = Recording(data=samples, metadata=metadata) >>> recording = Recording(data=samples, metadata=metadata)
@ -674,7 +674,7 @@ class Recording:
Create a recording and trim it: Create a recording and trim it:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.data import Recording >>> from ria_toolkit_oss.datatypes import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = { >>> metadata = {
@ -736,7 +736,7 @@ class Recording:
Create a recording with maximum amplitude 0.5 and normalize to a maximum amplitude of 1: Create a recording with maximum amplitude 0.5 and normalize to a maximum amplitude of 1:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.data import Recording >>> from ria_toolkit_oss.datatypes import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64) * 0.5 >>> samples = numpy.ones(10000, dtype=numpy.complex64) * 0.5
>>> metadata = { >>> metadata = {

View File

@ -1,5 +1,5 @@
""" """
Utilities for input/output operations on the ria_toolkit_oss.data.Recording object. Utilities for input/output operations on the ria_toolkit_oss.datatypes.Recording object.
""" """
import datetime import datetime
@ -19,8 +19,8 @@ from quantiphy import Quantity
from sigmf import SigMFFile, sigmffile from sigmf import SigMFFile, sigmffile
from sigmf.utils import get_data_type_str from sigmf.utils import get_data_type_str
from ria_toolkit_oss.data import Annotation from ria_toolkit_oss.datatypes import Annotation
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
_BLUE_META_PREFIX = "META_" _BLUE_META_PREFIX = "META_"
_BLUE_META_TAG_MAX_LEN = 60 _BLUE_META_TAG_MAX_LEN = 60
@ -64,7 +64,7 @@ def to_npy(
"""Write recording to ``.npy`` binary file. """Write recording to ``.npy`` binary file.
:param recording: The recording to be written to file. :param recording: The recording to be written to file.
:type recording: ria_toolkit_oss.data.Recording :type recording: ria_toolkit_oss.datatypes.Recording
:param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename. :param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename.
:type filename: os.PathLike or str, optional :type filename: os.PathLike or str, optional
:param path: The directory path to where the recording is to be saved. Defaults to recordings/. :param path: The directory path to where the recording is to be saved. Defaults to recordings/.
@ -135,7 +135,7 @@ def from_npy(file: os.PathLike | str, legacy: bool = False) -> Recording:
:raises IOError: If there is an issue encountered during the file reading process. :raises IOError: If there is an issue encountered during the file reading process.
:return: The recording, as initialized from the ``.npy`` file. :return: The recording, as initialized from the ``.npy`` file.
:rtype: ria_toolkit_oss.data.Recording :rtype: ria_toolkit_oss.datatypes.Recording
""" """
filename, extension = os.path.splitext(file) filename, extension = os.path.splitext(file)
@ -161,7 +161,7 @@ def from_npy(file: os.PathLike | str, legacy: bool = False) -> Recording:
try: try:
raw_ann = np.load(f, allow_pickle=False) raw_ann = np.load(f, allow_pickle=False)
ann_list = json.loads(raw_ann.tobytes().decode()) ann_list = json.loads(raw_ann.tobytes().decode())
from ria_toolkit_oss.data.annotation import Annotation from ria_toolkit_oss.datatypes.annotation import Annotation
annotations = [Annotation(**a) for a in ann_list] annotations = [Annotation(**a) for a in ann_list]
except EOFError: except EOFError:
@ -198,7 +198,7 @@ def from_npy_legacy(file: os.PathLike | str) -> Recording:
:raises IOError: If there is an issue encountered during the file reading process. :raises IOError: If there is an issue encountered during the file reading process.
:return: The recording, as initialized from the legacy ``.npy`` file. :return: The recording, as initialized from the legacy ``.npy`` file.
:rtype: ria_toolkit_oss.data.Recording :rtype: ria_toolkit_oss.datatypes.Recording
**Examples:** **Examples:**
@ -270,7 +270,7 @@ def to_sigmf(
The SigMF io format is defined by the `SigMF Specification Project <https://github.com/sigmf/SigMF>`_ The SigMF io format is defined by the `SigMF Specification Project <https://github.com/sigmf/SigMF>`_
:param recording: The recording to be written to file. :param recording: The recording to be written to file.
:type recording: ria_toolkit_oss.data.Recording :type recording: ria_toolkit_oss.datatypes.Recording
:param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename. :param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename.
:type filename: os.PathLike or str, optional :type filename: os.PathLike or str, optional
:param path: The directory path to where the recording is to be saved. Defaults to recordings/. :param path: The directory path to where the recording is to be saved. Defaults to recordings/.
@ -367,7 +367,9 @@ def to_sigmf(
meta_dict = sigMF_metafile.ordered_metadata() meta_dict = sigMF_metafile.ordered_metadata()
meta_dict["ria"] = metadata meta_dict["ria"] = metadata
sigMF_metafile.tofile(meta_file_path, overwrite=overwrite) if overwrite and os.path.isfile(meta_file_path):
os.remove(meta_file_path)
sigMF_metafile.tofile(meta_file_path)
def from_sigmf(file: os.PathLike | str) -> Recording: def from_sigmf(file: os.PathLike | str) -> Recording:
@ -381,7 +383,7 @@ def from_sigmf(file: os.PathLike | str) -> Recording:
:raises IOError: If there is an issue encountered during the file reading process. :raises IOError: If there is an issue encountered during the file reading process.
:return: The recording, as initialized from the SigMF files. :return: The recording, as initialized from the SigMF files.
:rtype: ria_toolkit_oss.data.Recording :rtype: ria_toolkit_oss.datatypes.Recording
""" """
file = str(file) file = str(file)
@ -443,7 +445,7 @@ def to_wav(
in the ICMT (comment) field for human readability. in the ICMT (comment) field for human readability.
:param recording: The recording to be written to file. :param recording: The recording to be written to file.
:type recording: ria_toolkit_oss.data.Recording :type recording: ria_toolkit_oss.datatypes.Recording
:param filename: The name of the file where the recording is to be saved. :param filename: The name of the file where the recording is to be saved.
Defaults to auto-generated filename. Defaults to auto-generated filename.
:type filename: str, optional :type filename: str, optional
@ -553,7 +555,7 @@ def from_wav(file: os.PathLike | str) -> Recording:
:raises ValueError: If file is not stereo or has unsupported format. :raises ValueError: If file is not stereo or has unsupported format.
:return: The recording, as initialized from the WAV file. :return: The recording, as initialized from the WAV file.
:rtype: ria_toolkit_oss.data.Recording :rtype: ria_toolkit_oss.datatypes.Recording
""" """
import wave import wave
@ -635,7 +637,7 @@ def to_blue(
Commonly used with X-Midas and other RF/radar signal processing tools. Commonly used with X-Midas and other RF/radar signal processing tools.
:param recording: The recording to be written to file. :param recording: The recording to be written to file.
:type recording: ria_toolkit_oss.data.Recording :type recording: ria_toolkit_oss.datatypes.Recording
:param filename: The name of the file where the recording is to be saved. :param filename: The name of the file where the recording is to be saved.
Defaults to auto-generated filename. Defaults to auto-generated filename.
:type filename: str, optional :type filename: str, optional
@ -792,7 +794,7 @@ def from_blue(file: os.PathLike | str) -> Recording:
:raises ValueError: If file format is not valid or unsupported. :raises ValueError: If file format is not valid or unsupported.
:return: The recording, as initialized from the Blue file. :return: The recording, as initialized from the Blue file.
:rtype: ria_toolkit_oss.data.Recording :rtype: ria_toolkit_oss.datatypes.Recording
""" """
filename = str(file) filename = str(file)
if not filename.endswith(".blue"): if not filename.endswith(".blue"):
@ -917,7 +919,7 @@ def load_recording(file: os.PathLike) -> Recording:
:raises ValueError: If the inferred file extension is not supported. :raises ValueError: If the inferred file extension is not supported.
:return: The recording, as initialized from file(s). :return: The recording, as initialized from file(s).
:rtype: ria_toolkit_oss.data.Recording :rtype: ria_toolkit_oss.datatypes.Recording
""" """
_, extension = os.path.splitext(file) _, extension = os.path.splitext(file)
extension = extension.lstrip(".") extension = extension.lstrip(".")

View File

@ -223,19 +223,13 @@ class TransmitterConfig:
id: str id: str
type: str # "wifi", "bluetooth", "sdr", "external" type: str # "wifi", "bluetooth", "sdr", "external"
control_method: str # "external_script" | "sdr" | "sdr_remote" control_method: str # "external_script" | "sdr"
schedule: list[CaptureStep] schedule: list[CaptureStep]
# For external_script control # For external_script control
script: Optional[str] = None # path to control script script: Optional[str] = None # path to control script
device: Optional[str] = None # e.g. "/dev/wlan0" device: Optional[str] = None # e.g. "/dev/wlan0"
# For sdr_remote control — keys: host, ssh_user, ssh_key_path, device_type, device_id, zmq_port
sdr_remote: Optional[dict] = None
# For sdr_agent control — keys: modulation, order, symbol_rate, center_frequency, filter, rolloff
sdr_agent: Optional[dict] = None
@classmethod @classmethod
def from_dict(cls, d: dict) -> "TransmitterConfig": def from_dict(cls, d: dict) -> "TransmitterConfig":
schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])] schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])]
@ -246,8 +240,6 @@ class TransmitterConfig:
schedule=schedule, schedule=schedule,
script=d.get("script"), script=d.get("script"),
device=d.get("device"), device=d.get("device"),
sdr_remote=d.get("sdr_remote"),
sdr_agent=d.get("sdr_agent"),
) )
@ -276,7 +268,6 @@ class OutputConfig:
path: str = "recordings" path: str = "recordings"
device_id: Optional[str] = None # for device-profile campaigns device_id: Optional[str] = None # for device-profile campaigns
repo: Optional[str] = None repo: Optional[str] = None
folder: Optional[str] = None # repo subfolder: None = use campaign name, "" = no subfolder, str = custom
@classmethod @classmethod
def from_dict(cls, d: dict) -> "OutputConfig": def from_dict(cls, d: dict) -> "OutputConfig":
@ -285,7 +276,6 @@ class OutputConfig:
path=str(d.get("path", "recordings")), path=str(d.get("path", "recordings")),
device_id=d.get("device_id"), device_id=d.get("device_id"),
repo=d.get("repo"), repo=d.get("repo"),
folder=d.get("folder"),
) )
@ -299,7 +289,6 @@ class CampaignConfig:
qa: QAConfig = field(default_factory=QAConfig) qa: QAConfig = field(default_factory=QAConfig)
output: OutputConfig = field(default_factory=OutputConfig) output: OutputConfig = field(default_factory=OutputConfig)
mode: str = "controlled_testbed" mode: str = "controlled_testbed"
loops: int = 1 # repeat full schedule this many times; labels get _run{N:02d} suffix
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Loaders # Loaders
@ -327,7 +316,6 @@ class CampaignConfig:
return cls( return cls(
name=safe_name, name=safe_name,
mode=str(campaign_meta.get("mode", "controlled_testbed")), mode=str(campaign_meta.get("mode", "controlled_testbed")),
loops=max(1, int(campaign_meta.get("loops", 1))),
recorder=RecorderConfig.from_dict(raw["recorder"]), recorder=RecorderConfig.from_dict(raw["recorder"]),
transmitters=transmitters, transmitters=transmitters,
qa=QAConfig.from_dict(raw.get("qa", {})), qa=QAConfig.from_dict(raw.get("qa", {})),
@ -392,7 +380,6 @@ class CampaignConfig:
return cls( return cls(
name=safe_name, name=safe_name,
mode=str(campaign_meta.get("mode", "controlled_testbed")), mode=str(campaign_meta.get("mode", "controlled_testbed")),
loops=max(1, int(campaign_meta.get("loops", 1))),
recorder=RecorderConfig.from_dict(raw["recorder"]), recorder=RecorderConfig.from_dict(raw["recorder"]),
transmitters=transmitters, transmitters=transmitters,
qa=QAConfig.from_dict(raw.get("qa", {})), qa=QAConfig.from_dict(raw.get("qa", {})),
@ -495,9 +482,9 @@ class CampaignConfig:
) )
def total_capture_time_s(self) -> float: def total_capture_time_s(self) -> float:
"""Sum of all step durations across all transmitters and loops.""" """Sum of all step durations across all transmitters."""
return sum(step.duration for tx in self.transmitters for step in tx.schedule) * self.loops return sum(step.duration for tx in self.transmitters for step in tx.schedule)
def total_steps(self) -> int: def total_steps(self) -> int:
"""Total number of capture steps across all transmitters and loops.""" """Total number of capture steps across all transmitters."""
return sum(len(tx.schedule) for tx in self.transmitters) * self.loops return sum(len(tx.schedule) for tx in self.transmitters)

View File

@ -5,19 +5,17 @@ from __future__ import annotations
import json import json
import logging import logging
import subprocess import subprocess
import threading
import time import time
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Callable, Optional from typing import Callable, Optional
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.io.recording import to_sigmf from ria_toolkit_oss.io.recording import to_sigmf
from .campaign import CampaignConfig, CaptureStep, TransmitterConfig from .campaign import CampaignConfig, CaptureStep, TransmitterConfig
from .labeler import build_output_filename, label_recording from .labeler import build_output_filename, label_recording
from .qa import QAResult, check_recording from .qa import QAResult, check_recording
from .tx_executor import TxExecutor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -171,21 +169,6 @@ def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _extract_tx_params(transmitter: TransmitterConfig) -> dict | None:
"""Build a tx_params dict from a transmitter's signal config for SigMF labeling.
For sdr_agent transmitters, returns the synthetic generation parameters
(modulation, order, symbol_rate, etc.) so recordings capture what was
transmitted. Returns None for control methods without signal-level params.
"""
sdr_agent_cfg = getattr(transmitter, "sdr_agent", None)
if not sdr_agent_cfg:
return None
# Extract known signal-level fields; ignore infra fields
_INFRA_KEYS = {"node_id", "session_code"}
return {k: v for k, v in sdr_agent_cfg.items() if k not in _INFRA_KEYS and v is not None}
class CampaignExecutor: class CampaignExecutor:
"""Executes a :class:`CampaignConfig` end-to-end. """Executes a :class:`CampaignConfig` end-to-end.
@ -209,14 +192,10 @@ class CampaignExecutor:
config: CampaignConfig, config: CampaignConfig,
progress_cb: Optional[Callable[[int, int, StepResult], None]] = None, progress_cb: Optional[Callable[[int, int, StepResult], None]] = None,
verbose: bool = False, verbose: bool = False,
skip_local_tx: bool = False,
): ):
self.config = config self.config = config
self.progress_cb = progress_cb self.progress_cb = progress_cb
self.skip_local_tx = skip_local_tx
self._sdr = None self._sdr = None
self._remote_tx_controllers: dict = {}
self._tx_executors: dict[str, tuple] = {} # tx_id → (TxExecutor, stop_event, thread)
if verbose: if verbose:
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@ -236,28 +215,21 @@ class CampaignExecutor:
""" """
result = CampaignResult(campaign_name=self.config.name) result = CampaignResult(campaign_name=self.config.name)
loops = self.config.loops
logger.info( logger.info(
f"Starting campaign '{self.config.name}': " f"Starting campaign '{self.config.name}': "
f"{self.config.total_steps()} steps" f"{self.config.total_steps()} steps, "
+ (f" ({self.config.total_steps() // loops} × {loops} loops)" if loops > 1 else "") f"~{self.config.total_capture_time_s():.0f}s capture time"
+ f", ~{self.config.total_capture_time_s():.0f}s capture time"
) )
self._init_sdr() self._init_sdr()
self._init_remote_tx_controllers()
try: try:
total = self.config.total_steps() total = self.config.total_steps()
step_index = 0 step_index = 0
for loop_idx in range(loops):
if loops > 1:
logger.info(f"Loop {loop_idx + 1}/{loops}")
for transmitter in self.config.transmitters: for transmitter in self.config.transmitters:
logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)") logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)")
for step in transmitter.schedule: for step in transmitter.schedule:
looped_step = replace(step, label=f"{step.label}_run{loop_idx + 1:02d}") if loops > 1 else step step_result = self._execute_step(transmitter, step)
step_result = self._execute_step(transmitter, looped_step)
result.steps.append(step_result) result.steps.append(step_result)
step_index += 1 step_index += 1
@ -265,21 +237,17 @@ class CampaignExecutor:
self.progress_cb(step_index, total, step_result) self.progress_cb(step_index, total, step_result)
if step_result.error: if step_result.error:
logger.warning(f"Step '{looped_step.label}' error: {step_result.error}") logger.warning(f"Step '{step.label}' error: {step_result.error}")
elif step_result.qa.flagged: elif step_result.qa.flagged:
logger.warning( logger.warning(f"Step '{step.label}' flagged for review: " + "; ".join(step_result.qa.issues))
f"Step '{looped_step.label}' flagged for review: " + "; ".join(step_result.qa.issues)
)
else: else:
logger.info( logger.info(
f"Step '{looped_step.label}' OK " f"Step '{step.label}' OK "
f"(SNR {step_result.qa.snr_db:.1f} dB, " f"(SNR {step_result.qa.snr_db:.1f} dB, "
f"{step_result.qa.duration_s:.1f}s)" f"{step_result.qa.duration_s:.1f}s)"
) )
finally: finally:
self._close_sdr() self._close_sdr()
self._close_remote_tx_controllers()
self._close_tx_executors()
result.end_time = time.time() result.end_time = time.time()
logger.info( logger.info(
@ -319,47 +287,6 @@ class CampaignExecutor:
logger.warning(f"SDR close error: {e}") logger.warning(f"SDR close error: {e}")
self._sdr = None self._sdr = None
# ------------------------------------------------------------------
# Remote Tx controller management
# ------------------------------------------------------------------
def _init_remote_tx_controllers(self) -> None:
"""Open SSH+ZMQ connections for all sdr_remote transmitters."""
from ria_toolkit_oss.remote_control import RemoteTransmitterController
for tx in self.config.transmitters:
if tx.control_method != "sdr_remote":
continue
cfg = tx.sdr_remote
if not cfg:
raise RuntimeError(f"Transmitter '{tx.id}' uses sdr_remote but has no sdr_remote config")
logger.info(f"Connecting remote Tx controller for {tx.id}{cfg['host']}")
ctrl = RemoteTransmitterController(
host=cfg["host"],
ssh_user=cfg["ssh_user"],
ssh_key_path=cfg["ssh_key_path"],
zmq_port=int(cfg.get("zmq_port", 5556)),
)
ctrl.set_radio(
device_type=cfg["device_type"],
device_id=cfg.get("device_id", ""),
)
self._remote_tx_controllers[tx.id] = ctrl
def _close_remote_tx_controllers(self) -> None:
for tx_id, ctrl in list(self._remote_tx_controllers.items()):
try:
ctrl.close()
except Exception as exc:
logger.warning(f"Error closing remote Tx controller {tx_id}: {exc}")
self._remote_tx_controllers.clear()
def _close_tx_executors(self) -> None:
for tx_id, (_, stop_event, t) in list(self._tx_executors.items()):
stop_event.set()
t.join(timeout=5.0)
self._tx_executors.clear()
def _record(self, duration_s: float) -> Recording: def _record(self, duration_s: float) -> Recording:
"""Capture ``duration_s`` seconds of IQ samples.""" """Capture ``duration_s`` seconds of IQ samples."""
num_samples = int(duration_s * self.config.recorder.sample_rate) num_samples = int(duration_s * self.config.recorder.sample_rate)
@ -404,7 +331,6 @@ class CampaignExecutor:
step=step, step=step,
capture_timestamp=capture_timestamp, capture_timestamp=capture_timestamp,
campaign_name=self.config.name, campaign_name=self.config.name,
tx_params=_extract_tx_params(transmitter),
) )
# QA # QA
@ -446,8 +372,7 @@ class CampaignExecutor:
traffic, etc. The script is responsible for applying the configuration traffic, etc. The script is responsible for applying the configuration
and returning promptly (i.e. not blocking for the capture duration). and returning promptly (i.e. not blocking for the capture duration).
For ``sdr_remote`` the remote ZMQ controller calls ``init_tx`` then For SDR transmitters this is a no-op placeholder (TX not yet implemented).
starts a background transmit thread that runs for the step duration.
""" """
if transmitter.control_method == "external_script": if transmitter.control_method == "external_script":
if not transmitter.script: if not transmitter.script:
@ -459,44 +384,6 @@ class CampaignExecutor:
elif transmitter.control_method == "sdr": elif transmitter.control_method == "sdr":
logger.debug("SDR TX not yet implemented — skipping start") logger.debug("SDR TX not yet implemented — skipping start")
elif transmitter.control_method == "sdr_remote":
ctrl = self._remote_tx_controllers.get(transmitter.id)
if ctrl is None:
raise RuntimeError(f"No remote Tx controller found for transmitter '{transmitter.id}'")
gain = step.power_dbm if step.power_dbm is not None else 0.0
ctrl.init_tx(
center_frequency=self.config.recorder.center_freq,
sample_rate=self.config.recorder.sample_rate,
gain=gain,
channel=step.channel or 0,
)
# Start transmission in background; _record() runs concurrently
ctrl.transmit_async(step.duration + 1.0)
elif transmitter.control_method == "sdr_agent":
if self.skip_local_tx:
logger.debug(f"skip_local_tx — TX for '{transmitter.id}' delegated to TX agent node")
return
if not transmitter.sdr_agent:
logger.warning(f"Transmitter '{transmitter.id}' has no sdr_agent config — skipping")
return
step_dict: dict = {"label": step.label, "duration": step.duration + 1.0}
if step.power_dbm is not None:
step_dict["power_dbm"] = step.power_dbm
tx_config = {
"id": transmitter.id,
"sdr_agent": transmitter.sdr_agent,
"schedule": [step_dict],
}
rec = self.config.recorder
tx_device = transmitter.device or rec.device
sdr_device = _DEVICE_ALIASES.get(tx_device.lower(), tx_device.lower())
stop_event = threading.Event()
executor = TxExecutor(tx_config, sdr_device=sdr_device, stop_event=stop_event)
t = threading.Thread(target=executor.run, daemon=True, name=f"tx-{transmitter.id}")
self._tx_executors[transmitter.id] = (executor, stop_event, t)
t.start()
else: else:
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping") logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
@ -504,7 +391,6 @@ class CampaignExecutor:
"""Signal the transmitter to stop. """Signal the transmitter to stop.
Calls ``<script> stop`` for external_script transmitters. Calls ``<script> stop`` for external_script transmitters.
For ``sdr_remote``, waits for the background transmit thread to finish.
""" """
if transmitter.control_method == "external_script": if transmitter.control_method == "external_script":
if not transmitter.script: if not transmitter.script:
@ -514,18 +400,6 @@ class CampaignExecutor:
except Exception as e: except Exception as e:
logger.warning(f"Script stop failed for {transmitter.id}: {e}") logger.warning(f"Script stop failed for {transmitter.id}: {e}")
elif transmitter.control_method == "sdr_remote":
ctrl = self._remote_tx_controllers.get(transmitter.id)
if ctrl is not None:
ctrl.wait_transmit(timeout=step.duration + 10.0)
elif transmitter.control_method == "sdr_agent":
entry = self._tx_executors.pop(transmitter.id, None)
if entry is not None:
_, stop_event, t = entry
stop_event.set()
t.join(timeout=step.duration + 10.0)
@staticmethod @staticmethod
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str: def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
"""Serialise step parameters to a JSON string for the control script.""" """Serialise step parameters to a JSON string for the control script."""

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from typing import Optional from typing import Optional
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from .campaign import CaptureStep from .campaign import CaptureStep
@ -15,7 +15,6 @@ def label_recording(
step: CaptureStep, step: CaptureStep,
capture_timestamp: float, capture_timestamp: float,
campaign_name: Optional[str] = None, campaign_name: Optional[str] = None,
tx_params: Optional[dict] = None,
) -> Recording: ) -> Recording:
"""Apply device identity and capture configuration labels to a recording's metadata. """Apply device identity and capture configuration labels to a recording's metadata.
@ -28,9 +27,6 @@ def label_recording(
step: The capture step that was active during this recording. step: The capture step that was active during this recording.
capture_timestamp: Unix timestamp (float) of when capture started. capture_timestamp: Unix timestamp (float) of when capture started.
campaign_name: Optional campaign name for cross-recording reference. campaign_name: Optional campaign name for cross-recording reference.
tx_params: Optional dict of transmitter signal parameters (e.g. modulation,
order, symbol_rate) written as ``ria:tx_<key>`` fields so downstream
training pipelines know what was transmitted into the recording.
Returns: Returns:
The same recording with updated metadata. The same recording with updated metadata.
@ -61,11 +57,6 @@ def label_recording(
if step.power_dbm is not None: if step.power_dbm is not None:
recording.update_metadata("tx_power_dbm", step.power_dbm) recording.update_metadata("tx_power_dbm", step.power_dbm)
# Transmitter signal parameters (e.g. from sdr_agent synthetic generation)
if tx_params:
for key, value in tx_params.items():
recording.update_metadata(f"tx_{key}", value)
return recording return recording

View File

@ -6,7 +6,7 @@ from dataclasses import dataclass, field
import numpy as np import numpy as np
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from .campaign import QAConfig from .campaign import QAConfig

View File

@ -1,299 +0,0 @@
"""TX campaign executor — synthesises and transmits signals via a local SDR.
The TxExecutor receives a transmitter config dict (matching the
``sdr_agent`` control method's schema) and a step schedule, then for each
step builds a signal chain with the block generator and transmits it via
the local SDR device.
Supported modulations (``modulation`` field in config):
BPSK, QPSK, 8PSK, 16QAM, 64QAM, 256QAM, FSK, OOK, GMSK, OQPSK
Example config dict (matches CampaignConfig transmitter with
``control_method: sdr_agent``)::
{
"id": "synthetic-tx",
"type": "sdr",
"control_method": "sdr_agent",
"sdr_agent": {
"modulation": "QPSK",
"order": 4,
"symbol_rate": 1000000,
"center_frequency": 0.0,
"filter": "rrc",
"rolloff": 0.35
},
"schedule": [
{"label": "step1", "duration": 10, "power_dbm": -10}
]
}
"""
from __future__ import annotations
import logging
import threading
from typing import Any
logger = logging.getLogger(__name__)
def _parse_hz(val: object) -> float:
"""Parse a frequency value that may be a float (Hz) or a string like '2.45GHz'."""
if isinstance(val, (int, float)):
return float(val)
s = str(val).strip()
for suffix, mult in (("GHz", 1e9), ("MHz", 1e6), ("kHz", 1e3), ("Hz", 1.0)):
if s.endswith(suffix):
return float(s[: -len(suffix)]) * mult
return float(s)
def _parse_seconds(val: object) -> float:
"""Parse a duration value that may be a float (seconds) or a string like '5s'."""
if isinstance(val, (int, float)):
return float(val)
s = str(val).strip()
return float(s[:-1]) if s.endswith("s") else float(s)
# Mapping from modulation name → (PSK/QAM order, generator_type)
# 'psk' uses PSKGenerator, 'qam' uses QAMGenerator
_MOD_TABLE: dict[str, tuple[int, str]] = {
"BPSK": (1, "psk"),
"QPSK": (2, "psk"),
"8PSK": (3, "psk"),
"16QAM": (4, "qam"),
"64QAM": (6, "qam"),
"256QAM": (8, "qam"),
}
_SPECIAL_MODS = {"FSK", "OOK", "GMSK", "OQPSK"}
# usrp-uhd-client's tx_recording() streams 2 000-sample chunks and loops the
# source buffer for the full tx_time, so only this many samples ever need to
# be in RAM regardless of step duration or sample rate.
# 50 000 complex64 samples ≈ 400 kB — enough spectral diversity for looping.
_SYNTH_BLOCK_SAMPLES = 50_000
class TxExecutor:
"""Synthesise and transmit a signal campaign via a local SDR.
Args:
config: Transmitter config dict (must have ``sdr_agent`` sub-dict with
modulation params, and ``schedule`` list of step dicts).
sdr_device: SDR device name to open in TX mode (e.g. "pluto", "usrp").
stop_event: External event that aborts the TX loop mid-step.
"""
def __init__(
self,
config: dict,
sdr_device: str = "unknown",
stop_event: threading.Event | None = None,
) -> None:
self.config = config
self.sdr_device = sdr_device
self.stop_event = stop_event or threading.Event()
self._sdr: Any = None
def run(self) -> None:
"""Execute all steps in the schedule, transmitting for each step duration."""
agent_cfg: dict = self.config.get("sdr_agent") or {}
schedule: list[dict] = self.config.get("schedule") or []
if not schedule:
logger.warning("TxExecutor: no schedule steps — nothing to transmit")
return
modulation: str = agent_cfg.get("modulation", "QPSK").upper()
symbol_rate: float = float(agent_cfg.get("symbol_rate", 1e6))
center_freq: float = _parse_hz(agent_cfg.get("center_frequency", 0.0))
filter_type: str = agent_cfg.get("filter", "rrc").lower()
rolloff: float = float(agent_cfg.get("rolloff", 0.35))
loops: int = max(1, int(self.config.get("loops", 1)))
# Upsampling factor: samples_per_symbol, fixed at 8 for SDR compatibility.
sps = 8
sample_rate = symbol_rate * sps
self._init_sdr(sample_rate, center_freq)
try:
for loop_idx in range(loops):
if self.stop_event.is_set():
break
if loops > 1:
logger.info("TX loop %d/%d", loop_idx + 1, loops)
for step in schedule:
if self.stop_event.is_set():
break
looped_step = (
{**step, "label": f"{step.get('label', 'step')}_run{loop_idx + 1:02d}"} if loops > 1 else step
)
self._execute_step(looped_step, modulation, sps, symbol_rate, filter_type, rolloff)
finally:
self._close_sdr()
def _execute_step(
self,
step: dict,
modulation: str,
sps: int,
symbol_rate: float,
filter_type: str,
rolloff: float,
) -> None:
duration: float = _parse_seconds(step.get("duration", 10.0))
label: str = step.get("label", "step")
gain: float = float(step.get("power_dbm") or 0.0)
sample_rate = symbol_rate * sps
logger.info(
"TX step '%s': %.0f s, %s @ %.3f MHz (sps=%d, filter=%s)",
label,
duration,
modulation,
symbol_rate / 1e6,
sps,
filter_type,
)
num_samples = int(duration * sample_rate)
# Synthesise a short representative block. tx_recording() loops this
# buffer for the full tx_time using a 2 000-sample streaming callback,
# so peak memory is O(_SYNTH_BLOCK_SAMPLES) regardless of duration.
block_size = min(num_samples, _SYNTH_BLOCK_SAMPLES)
signal = self._synthesise(modulation, sps, block_size, filter_type, rolloff)
if self._sdr is not None:
try:
# Apply gain update if SDR supports it
if hasattr(self._sdr, "set_tx_gain"):
self._sdr.set_tx_gain(gain)
self._sdr.tx_recording(signal, tx_time=duration)
except Exception as exc:
logger.error("TX step '%s' SDR error: %s", label, exc)
else:
# No SDR available — simulate by sleeping for the step duration.
logger.warning("TX step '%s': no SDR — simulating %.0f s delay", label, duration)
self.stop_event.wait(timeout=duration)
def _synthesise(
self,
modulation: str,
sps: int,
num_samples: int,
filter_type: str,
rolloff: float,
):
"""Build a block-generator chain and return IQ samples as a numpy array."""
try:
import numpy as np
from ria_toolkit_oss.signal.block_generator import (
BinarySource,
GMSKModulator,
Mapper,
OOKModulator,
OQPSKModulator,
RaisedCosineFilter,
RootRaisedCosineFilter,
Upsampling,
)
from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import (
FSKModulator,
)
except ImportError as exc:
raise RuntimeError(f"ria_toolkit_oss block generator not available: {exc}") from exc
# ── Special modulations with their own source-connected modulator ──
if modulation in ("OOK", "GMSK", "OQPSK"):
src = BinarySource()
if modulation == "OOK":
mod = OOKModulator(src, samples_per_symbol=sps)
elif modulation == "GMSK":
mod = GMSKModulator(src, samples_per_symbol=sps)
else:
mod = OQPSKModulator(src, samples_per_symbol=sps)
recording = mod.record(num_samples)
flat = np.asarray(recording.data).flatten().astype(np.complex64)
if len(flat) < num_samples:
flat = np.tile(flat, num_samples // len(flat) + 1)
return flat[:num_samples]
if modulation == "FSK":
symbol_rate = num_samples / sps
bits_per_sym = 1 # 2-FSK
num_bits = max(num_samples // sps, 128) * bits_per_sym
bits = BinarySource()((1, num_bits))
mod = FSKModulator(
num_bits_per_symbol=bits_per_sym,
frequency_spacing=symbol_rate * 0.5,
symbol_duration=1.0 / max(symbol_rate, 1.0),
sampling_frequency=symbol_rate * sps,
)
flat = np.asarray(mod(bits)).flatten().astype(np.complex64)
if len(flat) < num_samples:
flat = np.tile(flat, num_samples // len(flat) + 1)
return flat[:num_samples]
# ── PSK / QAM via Mapper → Upsampling → pulse filter ──────────────
if modulation not in _MOD_TABLE:
logger.warning("Unknown modulation %r — defaulting to QPSK", modulation)
modulation = "QPSK"
bits_per_sym, gen_type = _MOD_TABLE[modulation]
mod_family = "QAM" if gen_type == "qam" else "PSK"
source = BinarySource()
mapper = Mapper(constellation_type=mod_family, num_bits_per_symbol=bits_per_sym)
upsampler = Upsampling(factor=sps)
mapper.connect_input([source])
upsampler.connect_input([mapper])
if filter_type in ("rrc",):
pulse_filter = RootRaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff)
pulse_filter.connect_input([upsampler])
recording = pulse_filter.record(num_samples)
elif filter_type in ("rc",):
pulse_filter = RaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff)
pulse_filter.connect_input([upsampler])
recording = pulse_filter.record(num_samples)
else:
# "none", "rect", "gaussian" — use upsampler output directly
recording = upsampler.record(num_samples)
flat = np.asarray(recording.data).flatten().astype(np.complex64)
if len(flat) < num_samples:
flat = np.tile(flat, num_samples // len(flat) + 1)
return flat[:num_samples]
def _init_sdr(self, sample_rate: float, center_freq: float) -> None:
try:
from ria_toolkit_oss.sdr import get_sdr_device
self._sdr = get_sdr_device(self.sdr_device)
self._sdr.init_tx(
sample_rate=sample_rate,
center_frequency=center_freq,
gain=0,
channel=0,
gain_mode="manual",
)
logger.info(
"TX SDR initialised: %s @ %.3f MHz, %.1f Msps", self.sdr_device, center_freq / 1e6, sample_rate / 1e6
)
except Exception as exc:
logger.warning("TX SDR init failed (%s) — will simulate: %s", self.sdr_device, exc)
self._sdr = None
def _close_sdr(self) -> None:
if self._sdr is not None:
try:
self._sdr.close()
except Exception as exc:
logger.debug("TX SDR close error: %s", exc)
self._sdr = None

View File

@ -1,6 +0,0 @@
"""Remote SDR transmitter control via SSH + ZMQ."""
from .remote_transmitter import RemoteTransmitter
from .remote_transmitter_controller import RemoteTransmitterController
__all__ = ["RemoteTransmitter", "RemoteTransmitterController"]

View File

@ -1,152 +0,0 @@
"""Server-side ZMQ RPC receiver for SDR transmission.
Run this script on the Tx machine. The script binds a ZMQ REP socket and
waits for JSON-RPC commands from a :class:`RemoteTransmitterController`.
Requires: zmq, and ria-toolkit or utils installed for SDR support.
"""
from __future__ import annotations
import argparse
import io
import json
import logging
from contextlib import redirect_stderr, redirect_stdout
import zmq
logger = logging.getLogger(__name__)
class RemoteTransmitter:
"""Executes SDR Tx commands received over ZMQ.
Loads the appropriate SDR driver dynamically so the script can run on
machines that have only a subset of SDR libraries installed.
"""
def __init__(self) -> None:
self._sdr = None
def set_radio(self, radio_str: str, identifier: str = "") -> None:
"""Initialise the SDR radio.
Args:
radio_str: SDR type pluto | usrp | hackrf | bladerf.
identifier: Device-specific identifier (IP, serial, etc.).
"""
radio_str = radio_str.lower()
try:
if radio_str in ("pluto", "plutosdr"):
from ria_toolkit_oss.sdr.pluto import Pluto
self._sdr = Pluto(identifier)
elif radio_str in ("usrp",):
from ria_toolkit_oss.sdr.usrp import USRP
self._sdr = USRP(identifier)
elif radio_str in ("hackrf", "hackrf_one"):
from ria_toolkit_oss.sdr.hackrf import HackRF
self._sdr = HackRF(identifier)
elif radio_str in ("bladerf", "blade"):
from ria_toolkit_oss.sdr.blade import Blade
self._sdr = Blade(identifier)
else:
raise ValueError(f"Unknown SDR type: {radio_str!r}")
except ImportError as exc:
raise RuntimeError(f"SDR driver for '{radio_str}' is not installed: {exc}") from exc
def init_tx(
self,
center_frequency: float,
sample_rate: float,
gain: float,
channel: int = 0,
gain_mode: str = "absolute",
) -> None:
if self._sdr is None:
raise RuntimeError("Call set_radio() before init_tx()")
self._sdr.init_tx(
center_frequency=center_frequency,
sample_rate=sample_rate,
gain=gain,
channel=channel,
)
def transmit(self, duration_s: float) -> None:
"""Transmit a continuous wave for ``duration_s`` seconds."""
if self._sdr is None:
raise RuntimeError("Call set_radio() and init_tx() before transmit()")
import time
# Transmit in a loop until duration has elapsed
end = time.monotonic() + duration_s
while time.monotonic() < end:
try:
self._sdr.tx_cw()
except AttributeError:
time.sleep(0.01)
def stop(self) -> None:
"""Stop transmission and close the SDR."""
if self._sdr is not None:
try:
self._sdr.close()
except Exception:
pass
self._sdr = None
def run_function(self, command_dict: dict) -> dict:
"""Dispatch a JSON-RPC command and return a response dict."""
out_buf = io.StringIO()
err_buf = io.StringIO()
fn = command_dict.get("function_name", "")
try:
with redirect_stdout(out_buf), redirect_stderr(err_buf):
if fn == "set_radio":
self.set_radio(
radio_str=command_dict["radio_str"],
identifier=command_dict.get("identifier", ""),
)
elif fn == "init_tx":
self.init_tx(
center_frequency=command_dict["center_frequency"],
sample_rate=command_dict["sample_rate"],
gain=command_dict["gain"],
channel=command_dict.get("channel", 0),
gain_mode=command_dict.get("gain_mode", "absolute"),
)
elif fn == "transmit":
self.transmit(duration_s=command_dict.get("duration_s", 1.0))
elif fn == "stop":
self.stop()
else:
raise ValueError(f"Unknown function: {fn!r}")
return {"status": True, "message": out_buf.getvalue(), "error_message": err_buf.getvalue()}
except Exception as exc:
logger.exception("Error executing %s", fn)
return {"status": False, "message": out_buf.getvalue(), "error_message": str(exc)}
def _serve(port: int) -> None:
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{port}")
logger.info("RemoteTransmitter listening on port %d", port)
tx = RemoteTransmitter()
while True:
raw = socket.recv()
cmd = json.loads(raw.decode())
response = tx.run_function(cmd)
socket.send(json.dumps(response).encode())
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(description="SDR Tx ZMQ server")
parser.add_argument("--port", type=int, default=5556)
args = parser.parse_args()
_serve(args.port)

View File

@ -1,218 +0,0 @@
"""Client-side SSH + ZMQ controller for a remote SDR transmitter.
Run this on the Rx machine (or hub). It SSH-es into the Tx machine,
starts :mod:`remote_transmitter` there, then sends JSON-RPC commands over
ZMQ.
Requires: paramiko, zmq.
"""
from __future__ import annotations
import json
import logging
import threading
import time
import paramiko
import zmq
logger = logging.getLogger(__name__)
_STARTUP_WAIT_S = 2.0 # seconds to wait for remote ZMQ server to bind
class RemoteTransmitterController:
"""SSH into a Tx machine, start the ZMQ server, and send commands.
Args:
host: IP or hostname of the Tx machine.
ssh_user: SSH username.
ssh_key_path: Path to SSH private key file.
zmq_port: ZMQ port that the remote transmitter will bind on.
"""
def __init__(
self,
host: str,
ssh_user: str,
ssh_key_path: str,
zmq_port: int = 5556,
) -> None:
self._host = host
self._zmq_port = zmq_port
self._ssh: paramiko.SSHClient | None = None
self._ssh_stdout = None
self._context: zmq.Context | None = None
self._socket: zmq.Socket | None = None
self._tx_thread: threading.Thread | None = None
self._lock = threading.Lock()
self._connect(host, ssh_user, ssh_key_path, zmq_port)
# ------------------------------------------------------------------
# Connection management
# ------------------------------------------------------------------
def _connect(self, host: str, ssh_user: str, ssh_key_path: str, zmq_port: int) -> None:
"""Open SSH tunnel, start remote server, connect ZMQ socket."""
try:
import paramiko
except ImportError as exc:
raise RuntimeError("paramiko is required for remote SDR control: pip install paramiko") from exc
try:
import zmq
except ImportError as exc:
raise RuntimeError("pyzmq is required for remote SDR control: pip install pyzmq") from exc
logger.info("SSH connecting to %s@%s", ssh_user, host)
self._ssh = paramiko.SSHClient()
self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self._ssh.connect(hostname=host, username=ssh_user, key_filename=ssh_key_path)
cmd = f"python -m ria_toolkit_oss.remote_control.remote_transmitter --port {zmq_port}"
logger.info("Starting remote Tx server: %s", cmd)
_, self._ssh_stdout, _ = self._ssh.exec_command(cmd)
time.sleep(_STARTUP_WAIT_S)
self._context = zmq.Context()
self._socket = self._context.socket(zmq.REQ)
self._socket.connect(f"tcp://{host}:{zmq_port}")
logger.info("ZMQ connected to tcp://%s:%d", host, zmq_port)
def close(self) -> None:
"""Tear down ZMQ and SSH connections."""
if self._socket is not None:
try:
self._socket.close(linger=0)
except Exception:
pass
self._socket = None
if self._context is not None:
try:
self._context.term()
except Exception:
pass
self._context = None
if self._ssh_stdout is not None:
try:
self._ssh_stdout.channel.close()
except Exception:
pass
self._ssh_stdout = None
if self._ssh is not None:
try:
self._ssh.close()
except Exception:
pass
self._ssh = None
logger.info("RemoteTransmitterController closed")
# ------------------------------------------------------------------
# ZMQ dispatch
# ------------------------------------------------------------------
def _send(self, command: dict) -> dict:
"""Send a JSON-RPC command and return the response dict (thread-safe)."""
with self._lock:
if self._socket is None:
raise RuntimeError("Controller is closed")
self._socket.send(json.dumps(command).encode())
raw = self._socket.recv()
reply: dict = json.loads(raw.decode())
if not reply.get("status"):
raise RuntimeError(
f"Remote command '{command.get('function_name')}' failed: "
f"{reply.get('error_message', 'unknown error')}"
)
return reply
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def set_radio(self, device_type: str, device_id: str = "") -> None:
"""Initialise the SDR radio on the Tx machine.
Args:
device_type: SDR type ``pluto``, ``usrp``, ``hackrf``, ``bladerf``.
device_id: Device-specific identifier (IP, serial, etc.).
"""
logger.info("set_radio(%s, %r)", device_type, device_id)
self._send({"function_name": "set_radio", "radio_str": device_type, "identifier": device_id})
def init_tx(
self,
center_frequency: float,
sample_rate: float,
gain: float,
channel: int = 0,
gain_mode: str = "absolute",
) -> None:
"""Configure Tx parameters on the remote SDR.
Args:
center_frequency: Center frequency in Hz.
sample_rate: Sample rate in Hz.
gain: Tx gain in dB.
channel: RF channel index (default 0).
gain_mode: ``"absolute"`` (default) or ``"relative"``.
"""
logger.info(
"init_tx: fc=%.3f MHz, fs=%.3f MHz, gain=%.1f dB, ch=%d",
center_frequency / 1e6,
sample_rate / 1e6,
gain,
channel,
)
self._send(
{
"function_name": "init_tx",
"center_frequency": center_frequency,
"sample_rate": sample_rate,
"gain": gain,
"channel": channel,
"gain_mode": gain_mode,
}
)
def transmit_async(self, duration_s: float) -> None:
"""Start a timed CW transmission in a background thread.
Returns immediately. Call :meth:`wait_transmit` after recording to
ensure the transmit thread has finished before the next step.
Args:
duration_s: Transmission duration in seconds.
"""
logger.info("transmit_async: %.1f s", duration_s)
def _run() -> None:
try:
self._send({"function_name": "transmit", "duration_s": duration_s})
except Exception as exc:
logger.warning("Background transmit error: %s", exc)
self._tx_thread = threading.Thread(target=_run, daemon=True, name="remote-tx")
self._tx_thread.start()
def wait_transmit(self, timeout: float | None = None) -> None:
"""Wait for the background transmit thread to finish.
Args:
timeout: Maximum seconds to wait. ``None`` = wait indefinitely.
"""
if self._tx_thread is not None:
self._tx_thread.join(timeout=timeout)
self._tx_thread = None
def stop(self) -> None:
"""Stop transmission and release the remote SDR, then close connections."""
logger.info("Sending stop to remote Tx")
try:
self._send({"function_name": "stop"})
except Exception as exc:
logger.warning("stop command error (may be normal if connection closed): %s", exc)
finally:
self.close()

View File

@ -15,13 +15,8 @@ __all__ = [
] ]
from .mock import MockSDR from .mock import MockSDR
from .sdr import ( # noqa: F401 from .sdr import SDR, SDRError, SdrDisconnectedError, SDRParameterError, translate_disconnect # noqa: F401
SDR,
SdrDisconnectedError,
SDRError,
SDRParameterError,
translate_disconnect,
)
_DRIVER_CANDIDATES: tuple[tuple[str, str, str], ...] = ( _DRIVER_CANDIDATES: tuple[tuple[str, str, str], ...] = (
("mock", "ria_toolkit_oss.sdr.mock", "MockSDR"), ("mock", "ria_toolkit_oss.sdr.mock", "MockSDR"),

View File

@ -5,7 +5,7 @@ from typing import Optional
import numpy as np import numpy as np
from bladerf import _bladerf from bladerf import _bladerf
from ria_toolkit_oss.data import Recording from ria_toolkit_oss.datatypes import Recording
from ria_toolkit_oss.sdr import SDR, SDRError, SDRParameterError from ria_toolkit_oss.sdr import SDR, SDRError, SDRParameterError

View File

@ -4,7 +4,7 @@ from typing import Optional
import numpy as np import numpy as np
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.sdr._external.libhackrf import HackRF as hrf from ria_toolkit_oss.sdr._external.libhackrf import HackRF as hrf
from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError

View File

@ -7,13 +7,8 @@ from typing import Optional
import adi import adi
import numpy as np import numpy as np
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.sdr.sdr import ( from ria_toolkit_oss.sdr.sdr import SDR, SDRError, SDRParameterError, translate_disconnect
SDR,
SDRError,
SDRParameterError,
translate_disconnect,
)
class Pluto(SDR): class Pluto(SDR):
@ -389,10 +384,7 @@ class Pluto(SDR):
self._enable_tx = True self._enable_tx = True
while self._enable_tx is True: while self._enable_tx is True:
buffer = self._convert_tx_samples(callback(self.tx_buffer_size)) buffer = self._convert_tx_samples(callback(self.tx_buffer_size))
# pyadi-iio's ``radio.tx`` auto-wraps single-channel 1-D input. self.radio.tx(buffer[0])
# Indexing ``buffer[0]`` was a latent bug for callbacks that
# returned 1-D samples (scalar → TypeError inside pyadi).
self.radio.tx(buffer)
def set_rx_center_frequency(self, center_frequency): def set_rx_center_frequency(self, center_frequency):
""" """
@ -522,11 +514,6 @@ class Pluto(SDR):
raise SDRError(e) raise SDRError(e)
def set_tx_center_frequency(self, center_frequency): def set_tx_center_frequency(self, center_frequency):
# ``adi.Pluto`` exposes one radio handle shared between RX and TX; concurrent
# RX + TX sessions (see the agent ``_SdrRegistry``) may call RX and TX
# setters at the same time. Serialize with ``_param_lock`` — RX setters hold
# the same reentrant lock — so native attribute writes don't interleave.
with self._param_lock:
if center_frequency < 70e6 or center_frequency > 6e9: if center_frequency < 70e6 or center_frequency > 6e9:
raise SDRParameterError( raise SDRParameterError(
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz " f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
@ -547,10 +534,6 @@ class Pluto(SDR):
) )
def set_tx_sample_rate(self, sample_rate): def set_tx_sample_rate(self, sample_rate):
# ``self.radio.sample_rate`` is shared between RX and TX on Pluto — RX's
# ``set_rx_sample_rate`` writes the same native attribute. Hold ``_param_lock``
# so full-duplex sessions can't interleave writes.
with self._param_lock:
min_rate, max_rate = 65.1e3, 61.44e6 min_rate, max_rate = 65.1e3, 61.44e6
if sample_rate < min_rate or sample_rate > max_rate: if sample_rate < min_rate or sample_rate > max_rate:
raise SDRParameterError( raise SDRParameterError(
@ -570,8 +553,6 @@ class Pluto(SDR):
) )
def set_tx_gain(self, gain, channel=0, gain_mode="absolute"): def set_tx_gain(self, gain, channel=0, gain_mode="absolute"):
# Serialize with RX setters: see ``set_tx_sample_rate`` above.
with self._param_lock:
tx_gain_min = -89 tx_gain_min = -89
tx_gain_max = 0 tx_gain_max = 0

View File

@ -11,7 +11,7 @@ try:
except ImportError as exc: # pragma: no cover - dependency provided by end user except ImportError as exc: # pragma: no cover - dependency provided by end user
raise ImportError("pyrtlsdr is required to use the RTLSDR class") from exc raise ImportError("pyrtlsdr is required to use the RTLSDR class") from exc
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError

View File

@ -8,7 +8,7 @@ from typing import Optional
import numpy as np import numpy as np
import zmq import zmq
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
class SDR(ABC): class SDR(ABC):
@ -43,13 +43,6 @@ class SDR(ABC):
self.tx_gain = None self.tx_gain = None
self._param_lock = threading.RLock() # Reentrant lock self._param_lock = threading.RLock() # Reentrant lock
# Pending config consumed by rx() on first call and by _apply_sdr_config
# in the agent inference loop. Subclasses that need different defaults
# (e.g. MockSDR) can overwrite these in their own __init__.
self.center_freq: float = 2.4e9
self.sample_rate: float = 10e6
self.gain: float = 40.0
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording: def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
""" """
Create a radio recording of a given length. Either ``num_samples`` or ``rx_time`` must be provided. Create a radio recording of a given length. Either ``num_samples`` or ``rx_time`` must be provided.
@ -107,32 +100,6 @@ class SDR(ABC):
self._num_buffers_processed = 0 self._num_buffers_processed = 0
return recording return recording
def rx(self, num_samples: int) -> "np.ndarray":
"""Return *num_samples* complex IQ samples as a 1-D complex64 array.
This is the interface used by the agent inference loop. On first call,
``init_rx()`` is invoked automatically using the values stored in
``center_freq``, ``sample_rate``, and ``gain`` (set beforehand by
``_apply_sdr_config``). Subsequent calls stream directly.
Subclasses may override this for hardware-native capture APIs (e.g.
``MockSDR`` uses AWGN generation; ``PlutoSDR`` could use
``self.radio.rx()``).
"""
if not self._rx_initialized:
gain = self.gain if isinstance(self.gain, (int, float)) else 40.0
self.init_rx(
sample_rate=self.sample_rate,
center_frequency=self.center_freq,
gain=gain,
channel=0,
)
recording = self.record(num_samples=num_samples)
# Recording.data is either a list of 1-D arrays (one per channel) or a
# 2-D ndarray (channels × samples). Either way, index 0 is channel 0.
data = recording.data
return data[0] if hasattr(data, "__getitem__") else data
def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000): def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
""" """
Stream iq samples as interleaved bytes via zmq. Stream iq samples as interleaved bytes via zmq.

View File

@ -6,7 +6,7 @@ from typing import Optional
import numpy as np import numpy as np
import uhd import uhd
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError

View File

@ -3,7 +3,7 @@
from fastapi import Depends, FastAPI from fastapi import Depends, FastAPI
from .auth import require_api_key from .auth import require_api_key
from .routers import conductor, inference from .routers import inference, orchestrator
def create_app(api_key: str = "") -> FastAPI: def create_app(api_key: str = "") -> FastAPI:
@ -28,9 +28,9 @@ def create_app(api_key: str = "") -> FastAPI:
app.state.api_key = api_key app.state.api_key = api_key
app.include_router( app.include_router(
conductor.router, orchestrator.router,
prefix="/conductor", prefix="/orchestrator",
tags=["Conductor"], tags=["Orchestrator"],
dependencies=[Depends(require_api_key)], dependencies=[Depends(require_api_key)],
) )
app.include_router( app.include_router(

View File

@ -7,7 +7,7 @@ from pathlib import Path
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Conductor # Orchestrator
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@ -1,4 +1,4 @@
"""Conductor routes: campaign deployment, status, and cancellation.""" """Orchestrator routes: campaign deployment, status, and cancellation."""
from __future__ import annotations from __future__ import annotations

View File

@ -11,7 +11,7 @@ from scipy.signal import butter
from scipy.signal import chirp as sci_chirp from scipy.signal import chirp as sci_chirp
from scipy.signal import hilbert, lfilter from scipy.signal import hilbert, lfilter
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
def sine( def sine(

View File

@ -1,4 +1,4 @@
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.signal.block_generator.generators.signal_generator import ( from ria_toolkit_oss.signal.block_generator.generators.signal_generator import (
SignalGenerator, SignalGenerator,
) )

View File

@ -1,4 +1,4 @@
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.signal.block_generator.generators.signal_generator import ( from ria_toolkit_oss.signal.block_generator.generators.signal_generator import (
SignalGenerator, SignalGenerator,
) )

View File

@ -1,4 +1,4 @@
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.signal.block_generator.generators.signal_generator import ( from ria_toolkit_oss.signal.block_generator.generators.signal_generator import (
SignalGenerator, SignalGenerator,
) )

View File

@ -1,4 +1,4 @@
from ria_toolkit_oss.data import Recording from ria_toolkit_oss.datatypes import Recording
from ria_toolkit_oss.signal import Recordable from ria_toolkit_oss.signal import Recordable
from ria_toolkit_oss.signal.block_generator.block import Block from ria_toolkit_oss.signal.block_generator.block import Block

View File

@ -4,7 +4,7 @@ from datetime import datetime
import click import click
import numpy as np import numpy as np
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.signal.block_generator.mapping.mapper import Mapper from ria_toolkit_oss.signal.block_generator.mapping.mapper import Mapper
from ria_toolkit_oss.signal.block_generator.multirate.upsampling import Upsampling from ria_toolkit_oss.signal.block_generator.multirate.upsampling import Upsampling
from ria_toolkit_oss.signal.block_generator.pulse_shaping.raised_cosine_filter import ( from ria_toolkit_oss.signal.block_generator.pulse_shaping.raised_cosine_filter import (

View File

@ -1,4 +1,4 @@
from ria_toolkit_oss.data import Recording from ria_toolkit_oss.datatypes import Recording
from ria_toolkit_oss.signal.block_generator.data_types import DataType from ria_toolkit_oss.signal.block_generator.data_types import DataType
from ria_toolkit_oss.signal.block_generator.recordable_block import RecordableBlock from ria_toolkit_oss.signal.block_generator.recordable_block import RecordableBlock
from ria_toolkit_oss.signal.block_generator.source_block import SourceBlock from ria_toolkit_oss.signal.block_generator.source_block import SourceBlock

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from ria_toolkit_oss.data import Recording from ria_toolkit_oss.datatypes import Recording
class Recordable(ABC): class Recordable(ABC):

View File

@ -11,7 +11,7 @@ from typing import Optional
import numpy as np import numpy as np
from numpy.typing import ArrayLike from numpy.typing import ArrayLike
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.utils.array_conversion import convert_to_2xn from ria_toolkit_oss.utils.array_conversion import convert_to_2xn
# TODO: For round 2 of index generation, should j be at min 2 spots away from where it was to prevent adjacent patches. # TODO: For round 2 of index generation, should j be at min 2 spots away from where it was to prevent adjacent patches.
@ -29,7 +29,7 @@ def generate_awgn(signal: ArrayLike | Recording, snr: Optional[float] = 1) -> np
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:param snr: The signal-to-noise ratio in dB. Default is 1. :param snr: The signal-to-noise ratio in dB. Default is 1.
:type snr: float, optional :type snr: float, optional
@ -37,7 +37,7 @@ def generate_awgn(signal: ArrayLike | Recording, snr: Optional[float] = 1) -> np
:return: A numpy array representing the generated noise which matches the SNR of `signal`. If `signal` is a :return: A numpy array representing the generated noise which matches the SNR of `signal`. If `signal` is a
Recording, returns a Recording object with its `data` attribute containing the generated noise array. Recording, returns a Recording object with its `data` attribute containing the generated noise array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[2 + 5j, 1 + 8j]]) >>> rec = Recording(data=[[2 + 5j, 1 + 8j]])
>>> new_rec = generate_awgn(rec) >>> new_rec = generate_awgn(rec)
@ -80,14 +80,14 @@ def time_reversal(signal: ArrayLike | Recording) -> np.ndarray | Recording:
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:raises ValueError: If `signal` is not CxN complex. :raises ValueError: If `signal` is not CxN complex.
:return: A numpy array containing the reversed I and Q data samples if `signal` is an array. :return: A numpy array containing the reversed I and Q data samples if `signal` is an array.
If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the
reversed array. reversed array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[1+2j, 3+4j, 5+6j]]) >>> rec = Recording(data=[[1+2j, 3+4j, 5+6j]])
>>> new_rec = time_reversal(rec) >>> new_rec = time_reversal(rec)
@ -123,14 +123,14 @@ def spectral_inversion(signal: ArrayLike | Recording) -> np.ndarray | Recording:
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:raises ValueError: If `signal` is not CxN complex. :raises ValueError: If `signal` is not CxN complex.
:return: A numpy array containing the original I and negated Q data samples if `signal` is an array. :return: A numpy array containing the original I and negated Q data samples if `signal` is an array.
If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the
inverted array. inverted array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[0+45j, 2-10j]]) >>> rec = Recording(data=[[0+45j, 2-10j]])
>>> new_rec = spectral_inversion(rec) >>> new_rec = spectral_inversion(rec)
@ -165,14 +165,14 @@ def channel_swap(signal: ArrayLike | Recording) -> np.ndarray | Recording:
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:raises ValueError: If `signal` is not CxN complex. :raises ValueError: If `signal` is not CxN complex.
:return: A numpy array containing the swapped I and Q data samples if `signal` is an array. :return: A numpy array containing the swapped I and Q data samples if `signal` is an array.
If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the
swapped array. swapped array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[10+20j, 7+35j]]) >>> rec = Recording(data=[[10+20j, 7+35j]])
>>> new_rec = channel_swap(rec) >>> new_rec = channel_swap(rec)
@ -207,14 +207,14 @@ def amplitude_reversal(signal: ArrayLike | Recording) -> np.ndarray | Recording:
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:raises ValueError: If `signal` is not CxN complex. :raises ValueError: If `signal` is not CxN complex.
:return: A numpy array containing the negated I and Q data samples if `signal` is an array. :return: A numpy array containing the negated I and Q data samples if `signal` is an array.
If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the
negated array. negated array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[4-3j, -5-2j, -9+1j]]) >>> rec = Recording(data=[[4-3j, -5-2j, -9+1j]])
>>> new_rec = amplitude_reversal(rec) >>> new_rec = amplitude_reversal(rec)
@ -253,7 +253,7 @@ def drop_samples( # noqa: C901 # TODO: Simplify function
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:param max_section_size: Maximum allowable size of the section to be dropped and replaced. Default is 2. :param max_section_size: Maximum allowable size of the section to be dropped and replaced. Default is 2.
:type max_section_size: int, optional :type max_section_size: int, optional
:param fill_type: Fill option used to replace dropped section of data (back-fill, front-fill, mean, zeros). :param fill_type: Fill option used to replace dropped section of data (back-fill, front-fill, mean, zeros).
@ -275,7 +275,7 @@ def drop_samples( # noqa: C901 # TODO: Simplify function
:return: A numpy array containing the I and Q data samples with replaced subsections if :return: A numpy array containing the I and Q data samples with replaced subsections if
`signal` is an array. If `signal` is a `Recording`, returns a `Recording` object with its `data` `signal` is an array. If `signal` is a `Recording`, returns a `Recording` object with its `data`
attribute containing the array with dropped samples. attribute containing the array with dropped samples.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]]) >>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]])
>>> new_rec = drop_samples(rec) >>> new_rec = drop_samples(rec)
@ -346,7 +346,7 @@ def quantize_tape(
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:param bin_number: The number of bins the signal should be divided into. Default is 4. :param bin_number: The number of bins the signal should be divided into. Default is 4.
:type bin_number: int, optional :type bin_number: int, optional
:param rounding_type: The type of rounding applied during processing. Default is "floor". :param rounding_type: The type of rounding applied during processing. Default is "floor".
@ -362,7 +362,7 @@ def quantize_tape(
:return: A numpy array containing the quantized I and Q data samples if `signal` is an array. :return: A numpy array containing the quantized I and Q data samples if `signal` is an array.
If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing
the quantized array. the quantized array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[1+1j, 4+4j, 1+2j, 1+4j]]) >>> rec = Recording(data=[[1+1j, 4+4j, 1+2j, 1+4j]])
>>> new_rec = quantize_tape(rec) >>> new_rec = quantize_tape(rec)
@ -421,7 +421,7 @@ def quantize_parts(
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:param max_section_size: Maximum allowable size of the section to be quantized. Default is 2. :param max_section_size: Maximum allowable size of the section to be quantized. Default is 2.
:type max_section_size: int, optional :type max_section_size: int, optional
:param bin_number: The number of bins the signal should be divided into. Default is 4. :param bin_number: The number of bins the signal should be divided into. Default is 4.
@ -439,7 +439,7 @@ def quantize_parts(
:return: A numpy array containing the I and Q data samples with quantized subsections if `signal` :return: A numpy array containing the I and Q data samples with quantized subsections if `signal`
is an array. If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute is an array. If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute
containing the partially quantized array. containing the partially quantized array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]]) >>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]])
>>> new_rec = quantize_parts(rec) >>> new_rec = quantize_parts(rec)
@ -510,7 +510,7 @@ def magnitude_rescale(
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:param starting_bounds: The bounds (inclusive) as indices in which the starting position of the rescaling occurs. :param starting_bounds: The bounds (inclusive) as indices in which the starting position of the rescaling occurs.
Default is None, but if user does not assign any bounds, the bounds become (random index, N-1). Default is None, but if user does not assign any bounds, the bounds become (random index, N-1).
:type starting_bounds: tuple, optional :type starting_bounds: tuple, optional
@ -522,7 +522,7 @@ def magnitude_rescale(
:return: A numpy array containing the I and Q data samples with the rescaled magnitude after the random :return: A numpy array containing the I and Q data samples with the rescaled magnitude after the random
starting point if `signal` is an array. If `signal` is a `Recording`, returns a `Recording` starting point if `signal` is an array. If `signal` is a `Recording`, returns a `Recording`
object with its `data` attribute containing the rescaled array. object with its `data` attribute containing the rescaled array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]]) >>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]])
>>> new_rec = magniute_rescale(rec) >>> new_rec = magniute_rescale(rec)
@ -571,7 +571,7 @@ def cut_out( # noqa: C901 # TODO: Simplify function
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:param max_section_size: Maximum allowable size of the section to be quantized. Default is 3. :param max_section_size: Maximum allowable size of the section to be quantized. Default is 3.
:type max_section_size: int, optional :type max_section_size: int, optional
:param fill_type: Fill option used to replace cutout section of data (zeros, ones, low-snr, avg-snr-1, avg-snr-2). :param fill_type: Fill option used to replace cutout section of data (zeros, ones, low-snr, avg-snr-1, avg-snr-2).
@ -596,7 +596,7 @@ def cut_out( # noqa: C901 # TODO: Simplify function
:return: A numpy array containing the I and Q data samples with random sections cut out and replaced according to :return: A numpy array containing the I and Q data samples with random sections cut out and replaced according to
`fill_type` if `signal` is an array. If `signal` is a `Recording`, returns a `Recording` object `fill_type` if `signal` is an array. If `signal` is a `Recording`, returns a `Recording` object
with its `data` attribute containing the cut out and replaced array. with its `data` attribute containing the cut out and replaced array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]]) >>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]])
>>> new_rec = cut_out(rec) >>> new_rec = cut_out(rec)
@ -666,7 +666,7 @@ def patch_shuffle(signal: ArrayLike | Recording, max_patch_size: Optional[int] =
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:param max_patch_size: Maximum allowable patch size of the data that can be shuffled. Default is 3. :param max_patch_size: Maximum allowable patch size of the data that can be shuffled. Default is 3.
:type max_patch_size: int, optional :type max_patch_size: int, optional
@ -676,7 +676,7 @@ def patch_shuffle(signal: ArrayLike | Recording, max_patch_size: Optional[int] =
:return: A numpy array containing the I and Q data samples with randomly shuffled regions if `signal` is :return: A numpy array containing the I and Q data samples with randomly shuffled regions if `signal` is
an array. If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing an array. If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing
the shuffled array. the shuffled array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]]) >>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]])
>>> new_rec = patch_shuffle(rec) >>> new_rec = patch_shuffle(rec)

View File

@ -16,7 +16,7 @@ import numpy as np
from numpy.typing import ArrayLike from numpy.typing import ArrayLike
from scipy.signal import resample_poly from scipy.signal import resample_poly
from ria_toolkit_oss.data import Recording from ria_toolkit_oss.datatypes import Recording
from ria_toolkit_oss.transforms import iq_augmentations from ria_toolkit_oss.transforms import iq_augmentations
@ -31,7 +31,7 @@ def add_awgn_to_signal(signal: ArrayLike | Recording, snr: Optional[float] = 1)
:param signal: Input IQ data as a complex ``C x N`` array or `Recording`, where ``C`` is the number of channels :param signal: Input IQ data as a complex ``C x N`` array or `Recording`, where ``C`` is the number of channels
and ``N`` is the length of the IQ examples. and ``N`` is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:param snr: The signal-to-noise ratio in dB. Default is 1. :param snr: The signal-to-noise ratio in dB. Default is 1.
:type snr: float, optional :type snr: float, optional
@ -39,7 +39,7 @@ def add_awgn_to_signal(signal: ArrayLike | Recording, snr: Optional[float] = 1)
:return: A numpy array which is the sum of the noise (which matches the SNR) and the original signal. If `signal` :return: A numpy array which is the sum of the noise (which matches the SNR) and the original signal. If `signal`
is a `Recording`, returns a `Recording object` with its `data` attribute containing the noisy signal array. is a `Recording`, returns a `Recording object` with its `data` attribute containing the noisy signal array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[1+1j, 2+2j]]) >>> rec = Recording(data=[[1+1j, 2+2j]])
>>> new_rec = add_awgn_to_signal(rec) >>> new_rec = add_awgn_to_signal(rec)
@ -71,7 +71,7 @@ def time_shift(signal: ArrayLike | Recording, shift: Optional[int] = 1) -> np.nd
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:param shift: The number of indices to shift by. Default is 1. :param shift: The number of indices to shift by. Default is 1.
:type shift: int, optional :type shift: int, optional
@ -80,7 +80,7 @@ def time_shift(signal: ArrayLike | Recording, shift: Optional[int] = 1) -> np.nd
:return: A numpy array which represents the time-shifted signal. If `signal` is a `Recording`, :return: A numpy array which represents the time-shifted signal. If `signal` is a `Recording`,
returns a `Recording object` with its `data` attribute containing the time-shifted array. returns a `Recording object` with its `data` attribute containing the time-shifted array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j, 5+5j]]) >>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j, 5+5j]])
>>> new_rec = time_shift(rec, -2) >>> new_rec = time_shift(rec, -2)
@ -134,7 +134,7 @@ def frequency_shift(signal: ArrayLike | Recording, shift: Optional[float] = 0.5)
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:param shift: The frequency shift relative to the sample rate. Must be in the range ``[-0.5, 0.5]``. :param shift: The frequency shift relative to the sample rate. Must be in the range ``[-0.5, 0.5]``.
Default is 0.5. Default is 0.5.
:type shift: float, optional :type shift: float, optional
@ -144,7 +144,7 @@ def frequency_shift(signal: ArrayLike | Recording, shift: Optional[float] = 0.5)
:return: A numpy array which represents the frequency-shifted signal. If `signal` is a `Recording`, :return: A numpy array which represents the frequency-shifted signal. If `signal` is a `Recording`,
returns a `Recording object` with its `data` attribute containing the frequency-shifted array. returns a `Recording object` with its `data` attribute containing the frequency-shifted array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j]]) >>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j]])
>>> new_rec = frequency_shift(rec, -0.4) >>> new_rec = frequency_shift(rec, -0.4)
@ -189,7 +189,7 @@ def phase_shift(signal: ArrayLike | Recording, phase: Optional[float] = np.pi) -
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:param phase: The phase angle by which to rotate the IQ samples, in radians. Must be in the range ``[-π, π]``. :param phase: The phase angle by which to rotate the IQ samples, in radians. Must be in the range ``[-π, π]``.
Default is π. Default is π.
:type phase: float, optional :type phase: float, optional
@ -199,7 +199,7 @@ def phase_shift(signal: ArrayLike | Recording, phase: Optional[float] = np.pi) -
:return: A numpy array which represents the phase-shifted signal. If `signal` is a `Recording`, :return: A numpy array which represents the phase-shifted signal. If `signal` is a `Recording`,
returns a `Recording object` with its `data` attribute containing the phase-shifted array. returns a `Recording object` with its `data` attribute containing the phase-shifted array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j]]) >>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j]])
>>> new_rec = phase_shift(rec, np.pi/2) >>> new_rec = phase_shift(rec, np.pi/2)
@ -246,7 +246,7 @@ def iq_imbalance(
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:param amplitude_imbalance: The IQ amplitude imbalance to apply, in dB. Default is 1.5. :param amplitude_imbalance: The IQ amplitude imbalance to apply, in dB. Default is 1.5.
:type amplitude_imbalance: float, optional :type amplitude_imbalance: float, optional
:param phase_imbalance: The IQ phase imbalance to apply, in radians. Default is π. :param phase_imbalance: The IQ phase imbalance to apply, in radians. Default is π.
@ -260,7 +260,7 @@ def iq_imbalance(
:return: A numpy array which is the original signal with an applied IQ imbalance. If `signal` is a `Recording`, :return: A numpy array which is the original signal with an applied IQ imbalance. If `signal` is a `Recording`,
returns a `Recording object` with its `data` attribute containing the IQ imbalanced signal array. returns a `Recording object` with its `data` attribute containing the IQ imbalanced signal array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[2+18j, -34+2j, 3+9j]]) >>> rec = Recording(data=[[2+18j, -34+2j, 3+9j]])
>>> new_rec = iq_imbalance(rec, 1, np.pi, 2) >>> new_rec = iq_imbalance(rec, 1, np.pi, 2)
@ -315,7 +315,7 @@ def resample(signal: ArrayLike | Recording, up: Optional[int] = 4, down: Optiona
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N :param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
is the length of the IQ examples. is the length of the IQ examples.
:type signal: array_like or ria_toolkit_oss.data.Recording :type signal: array_like or ria_toolkit_oss.datatypes.Recording
:param up: The upsampling factor. Default is 4. :param up: The upsampling factor. Default is 4.
:type up: int, optional :type up: int, optional
:param down: The downsampling factor. Default is 2. :param down: The downsampling factor. Default is 2.
@ -325,7 +325,7 @@ def resample(signal: ArrayLike | Recording, up: Optional[int] = 4, down: Optiona
:return: A numpy array which represents the resampled signal If `signal` is a `Recording`, :return: A numpy array which represents the resampled signal If `signal` is a `Recording`,
returns a `Recording object` with its `data` attribute containing the resampled array. returns a `Recording object` with its `data` attribute containing the resampled array.
:rtype: np.ndarray or ria_toolkit_oss.data.Recording :rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
>>> rec = Recording(data=[[1+1j, 2+2j]]) >>> rec = Recording(data=[[1+1j, 2+2j]])
>>> new_rec = resample(rec, 2, 1) >>> new_rec = resample(rec, 2, 1)

View File

@ -4,14 +4,14 @@ import scipy.signal as signal
from plotly.graph_objs import Figure from plotly.graph_objs import Figure
from scipy.fft import fft, fftshift from scipy.fft import fft, fftshift
from ria_toolkit_oss.data import Recording from ria_toolkit_oss.datatypes import Recording
def spectrogram(rec: Recording, thumbnail: bool = False) -> Figure: def spectrogram(rec: Recording, thumbnail: bool = False) -> Figure:
"""Create a spectrogram for the recording. """Create a spectrogram for the recording.
:param rec: Signal to plot. :param rec: Signal to plot.
:type rec: ria_toolkit_oss.data.Recording :type rec: utils.data.Recording
:param thumbnail: Whether to return a small thumbnail version or full plot. :param thumbnail: Whether to return a small thumbnail version or full plot.
:type thumbnail: bool :type thumbnail: bool
@ -95,7 +95,7 @@ def iq_time_series(rec: Recording) -> Figure:
"""Create a time series plot of the real and imaginary parts of signal. """Create a time series plot of the real and imaginary parts of signal.
:param rec: Signal to plot. :param rec: Signal to plot.
:type rec: ria_toolkit_oss.data.Recording :type rec: utils.data.Recording
:return: Time series plot as a Plotly figure. :return: Time series plot as a Plotly figure.
""" """
@ -125,7 +125,7 @@ def frequency_spectrum(rec: Recording) -> Figure:
"""Create a frequency spectrum plot from the recording. """Create a frequency spectrum plot from the recording.
:param rec: Input signal to plot. :param rec: Input signal to plot.
:type rec: ria_toolkit_oss.data.Recording :type rec: utils.data.Recording
:return: Frequency spectrum as a Plotly figure. :return: Frequency spectrum as a Plotly figure.
""" """
@ -160,7 +160,7 @@ def constellation(rec: Recording) -> Figure:
"""Create a constellation plot from the recording. """Create a constellation plot from the recording.
:param rec: Input signal to plot. :param rec: Input signal to plot.
:type rec: ria_toolkit_oss.data.Recording :type rec: utils.data.Recording
:return: Constellation as a Plotly figure. :return: Constellation as a Plotly figure.
""" """

View File

@ -6,13 +6,12 @@ from typing import Optional
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from matplotlib import gridspec from matplotlib import gridspec
from matplotlib.patches import Patch
from PIL import Image from PIL import Image
from scipy.fft import fft, fftshift from scipy.fft import fft, fftshift
from scipy.signal import spectrogram from scipy.signal import spectrogram
from scipy.signal.windows import hann from scipy.signal.windows import hann
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.view.tools import ( from ria_toolkit_oss.view.tools import (
COLORS, COLORS,
decimate, decimate,
@ -40,76 +39,6 @@ def set_spines(ax, spines):
ax.spines["left"].set_visible(False) ax.spines["left"].set_visible(False)
def view_annotations(
recording: Recording,
channel: Optional[int] = 0,
output_path: Optional[str] = "images/annotations.png",
title: Optional[str] = "Annotated Spectrogram",
dpi: Optional[int] = 300,
title_fontsize: Optional[int] = 15,
dark: Optional[bool] = True,
) -> None:
# 1. Setup Plotting Environment
plt.close("all")
if dark:
plt.style.use("dark_background")
else:
plt.style.use("default")
fig, ax = plt.subplots(figsize=(12, 8))
complex_signal = recording.data[channel]
sample_rate, center_frequency, _ = extract_metadata_fields(recording.metadata)
annotations = recording.annotations
# 2. Setup Color Mapping
palette = ["#2196F3", "#9C27B0", "#64B5F6", "#7B1FA2", "#5C6BC0", "#CE93D8", "#1565C0", "#7C4DFF"]
unique_labels = sorted(list(set(ann.label for ann in annotations if ann.label)))
label_to_color = {label: palette[i % len(palette)] for i, label in enumerate(unique_labels)}
# 3. Generate Spectrogram
Pxx, freqs, times, im = ax.specgram(
complex_signal, NFFT=256, Fs=sample_rate, Fc=center_frequency, noverlap=128, cmap="twilight"
)
# 4. Draw Annotations (highest threshold % first so lower % renders on top)
def _threshold_sort_key(ann):
try:
return int(ann.label.rstrip("%"))
except (ValueError, AttributeError):
return 0
for annotation in sorted(annotations, key=_threshold_sort_key, reverse=True):
t_start = annotation.sample_start / sample_rate
t_width = annotation.sample_count / sample_rate
f_start = annotation.freq_lower_edge
f_height = annotation.freq_upper_edge - annotation.freq_lower_edge
ann_color = label_to_color.get(annotation.label, "gray")
rect = plt.Rectangle(
(t_start, f_start), t_width, f_height, linewidth=1.5, edgecolor=ann_color, facecolor="none", alpha=0.8
)
ax.add_patch(rect)
if unique_labels:
legend_elements = [
Patch(facecolor=label_to_color[label], alpha=0.3, edgecolor=label_to_color[label], label=label)
for label in unique_labels
]
ax.legend(handles=legend_elements, loc="upper right", framealpha=0.2)
ax.set_title(title, fontsize=title_fontsize, pad=20)
ax.set_xlabel("Time (s)", fontsize=12)
ax.set_ylabel("Frequency (MHz)", fontsize=12)
ax.grid(alpha=0.1)
output_path, _ = set_path(output_path=output_path)
plt.savefig(output_path, dpi=dpi, bbox_inches="tight")
plt.close(fig)
print(f"Professional annotation plot saved to {output_path}")
def view_channels( def view_channels(
recording: Recording, recording: Recording,
output_path: Optional[str] = "images/signal.png", output_path: Optional[str] = "images/signal.png",
@ -280,7 +209,9 @@ def view_sig(
) )
set_spines(spec_ax, spines) set_spines(spec_ax, spines)
spec_ax.set_title("Spectrogram", loc="center", fontsize=subtitle_fontsize) spec_ax.set_title("Spectrogram", fontsize=subtitle_fontsize)
spec_ax.set_ylabel("Frequency (Hz)")
spec_ax.set_xlabel("Time (s)")
if iq: if iq:
iq_ax = plt.subplot(gs[plot_y_indx : plot_y_indx + 2, :]) iq_ax = plt.subplot(gs[plot_y_indx : plot_y_indx + 2, :])
@ -364,11 +295,7 @@ def view_sig(
set_spines(meta_ax, spines) set_spines(meta_ax, spines)
if logo and os.path.isfile(logo_path): if logo and os.path.isfile(logo_path):
# logo_ax = plt.subplot(gs[plot_y_indx:, 2]) logo_ax = plt.subplot(gs[plot_y_indx + 2 :, 2])
logo_pos = [0.75, 0.05, 0.2, 0.08]
logo_ax = fig.add_axes(logo_pos, anchor="SE", zorder=10)
plot_x_indx = plot_x_indx + 1
logo_ax.axis("off") logo_ax.axis("off")
try: try:
@ -387,6 +314,7 @@ def view_sig(
hspace=2.5, # Vertical space between subplots hspace=2.5, # Vertical space between subplots
) )
# save path handling
output_path, _ = set_path(output_path=output_path) output_path, _ = set_path(output_path=output_path)
plt.savefig(output_path, dpi=dpi) plt.savefig(output_path, dpi=dpi)
print(f"Saved signal plot to {output_path}") print(f"Saved signal plot to {output_path}")

View File

@ -3,7 +3,6 @@
from __future__ import annotations from __future__ import annotations
import gc import gc
import json
from typing import Optional from typing import Optional
import matplotlib import matplotlib
@ -12,7 +11,7 @@ import numpy as np
from scipy.fft import fft, fftshift from scipy.fft import fft, fftshift
from scipy.signal.windows import hann from scipy.signal.windows import hann
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.view.tools import ( from ria_toolkit_oss.view.tools import (
COLORS, COLORS,
decimate, decimate,
@ -21,52 +20,6 @@ from ria_toolkit_oss.view.tools import (
) )
def _add_annotations(annotations, compact_mode, show_labels, sample_rate_hz, center_freq_hz, ax2):
if annotations and not compact_mode:
for annotation in annotations:
start_idx = annotation.get("core:sample_start", 0)
length = annotation.get("core:sample_count", 0)
start_time = start_idx / sample_rate_hz
end_time = (start_idx + length) / sample_rate_hz
freq_low = annotation.get("core:freq_lower_edge", center_freq_hz - sample_rate_hz / 4)
freq_high = annotation.get("core:freq_upper_edge", center_freq_hz + sample_rate_hz / 4)
comment = annotation.get("core:comment", "{}")
try:
comment_data = json.loads(comment) if isinstance(comment, str) else comment
ann_type = comment_data.get("type", "unknown")
if ann_type == "intersection":
color = COLORS["success"]
elif ann_type == "parallel":
color = COLORS["primary"]
elif ann_type == "standalone":
color = COLORS["warning"]
else:
color = COLORS["error"]
except Exception:
color = COLORS["error"]
rect = plt.Rectangle(
(start_time, freq_low),
end_time - start_time,
freq_high - freq_low,
color=color,
alpha=0.4,
linewidth=2,
)
ax2.add_patch(rect)
if show_labels:
label = annotation.get("core:label", "Signal")
ax2.text(
start_time,
freq_high,
label,
color=COLORS["light"],
fontsize=10,
bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.7),
)
def _get_nfft_size(signal, fast_mode): def _get_nfft_size(signal, fast_mode):
if len(signal) < 1000: if len(signal) < 1000:
nfft = 128 nfft = 128
@ -185,7 +138,6 @@ def detect_constellation_symbols(signal: np.ndarray, method: str = "differential
def view_simple_sig( def view_simple_sig(
recording: Recording, recording: Recording,
annotations: Optional[list] = None,
output_path: Optional[str] = "images/signal.png", output_path: Optional[str] = "images/signal.png",
saveplot: Optional[bool] = True, saveplot: Optional[bool] = True,
fast_mode: Optional[bool] = False, fast_mode: Optional[bool] = False,
@ -309,15 +261,6 @@ def view_simple_sig(
ax2.set_title("Spectrogram", loc="left", pad=10) ax2.set_title("Spectrogram", loc="left", pad=10)
_add_annotations(
annotations=annotations,
compact_mode=compact_mode,
show_labels=show_labels,
sample_rate_hz=sample_rate_hz,
center_freq_hz=center_freq_hz,
ax2=ax2,
)
if ax_constellation is not None: if ax_constellation is not None:
constellation_samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=50_000, fast_max=20_000) constellation_samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=50_000, fast_max=20_000)
method = "differential" if fast_mode else "combined" method = "differential" if fast_mode else "combined"
@ -367,7 +310,7 @@ def view_simple_sig(
else: else:
plt.tight_layout() plt.tight_layout()
if show_title: if show_title:
plt.subplots_adjust(top=0.92) plt.subplots_adjust(top=0.90)
if saveplot: if saveplot:
output_path, extension = set_path(output_path=output_path) output_path, extension = set_path(output_path=output_path)

View File

@ -4,14 +4,14 @@ import scipy.signal as signal
from plotly.graph_objs import Figure from plotly.graph_objs import Figure
from scipy.fft import fft, fftshift from scipy.fft import fft, fftshift
from ria_toolkit_oss.data import Recording from ria_toolkit_oss.datatypes import Recording
def spectrogram(rec: Recording, thumbnail: bool = False) -> Figure: def spectrogram(rec: Recording, thumbnail: bool = False) -> Figure:
"""Create a spectrogram for the recording. """Create a spectrogram for the recording.
:param rec: Signal to plot. :param rec: Signal to plot.
:type rec: ria_toolkit_oss.data.Recording :type rec: ria_toolkit_oss.datatypes.Recording
:param thumbnail: Whether to return a small thumbnail version or full plot. :param thumbnail: Whether to return a small thumbnail version or full plot.
:type thumbnail: bool :type thumbnail: bool
@ -107,7 +107,7 @@ def iq_time_series(rec: Recording) -> Figure:
"""Create a time series plot of the real and imaginary parts of signal. """Create a time series plot of the real and imaginary parts of signal.
:param rec: Signal to plot. :param rec: Signal to plot.
:type rec: ria_toolkit_oss.data.Recording :type rec: ria_toolkit_oss.datatypes.Recording
:return: Time series plot, as a Plotly Figure. :return: Time series plot, as a Plotly Figure.
""" """
@ -145,7 +145,7 @@ def frequency_spectrum(rec: Recording) -> Figure:
"""Create a frequency spectrum plot from the recording. """Create a frequency spectrum plot from the recording.
:param rec: Input signal to plot. :param rec: Input signal to plot.
:type rec: ria_toolkit_oss.data.Recording :type rec: ria_toolkit_oss.datatypes.Recording
:return: Frequency spectrum, as a Plotly figure. :return: Frequency spectrum, as a Plotly figure.
""" """
@ -187,7 +187,7 @@ def constellation(rec: Recording) -> Figure:
"""Create a constellation plot from the recording. """Create a constellation plot from the recording.
:param rec: Input signal to plot. :param rec: Input signal to plot.
:type rec: ria_toolkit_oss.data.Recording :type rec: ria_toolkit_oss.datatypes.Recording
:return: Constellation, as a Plotly Figure. :return: Constellation, as a Plotly Figure.
""" """
@ -222,7 +222,7 @@ def power_spectral_density(rec: Recording) -> Figure:
"""Create a Power Spectral Density (PSD) plot from the recording. """Create a Power Spectral Density (PSD) plot from the recording.
:param rec: Input signal to plot. :param rec: Input signal to plot.
:type rec: ria_toolkit_oss.data.Recording :type rec: ria_toolkit_oss.datatypes.Recording
:return: PSD plot, as a Plotly Figure. :return: PSD plot, as a Plotly Figure.
""" """
@ -268,7 +268,7 @@ def fft_plot(rec: Recording) -> Figure:
"""Create an FFT magnitude plot from the recording. """Create an FFT magnitude plot from the recording.
:param rec: Input signal to plot. :param rec: Input signal to plot.
:type rec: ria_toolkit_oss.data.Recording :type rec: ria_toolkit_oss.datatypes.Recording
:return: FFT plot, as a Plotly Figure. :return: FFT plot, as a Plotly Figure.
""" """
@ -312,7 +312,7 @@ def spectrogram_3d(rec: Recording) -> Figure:
"""Create a 3D spectrogram plot from the recording. """Create a 3D spectrogram plot from the recording.
:param rec: Input signal to plot. :param rec: Input signal to plot.
:type rec: ria_toolkit_oss.data.Recording :type rec: ria_toolkit_oss.datatypes.Recording
:return: 3D Spectrogram, as a Plotly Figure. :return: 3D Spectrogram, as a Plotly Figure.
""" """

View File

@ -1,828 +0,0 @@
"""Annotate command - Automatic detection and manual annotation management."""
import json
from pathlib import Path
import click
from ria_toolkit_oss.annotations import (
annotate_with_cusum,
detect_signals_energy,
split_recording_annotations,
threshold_qualifier,
)
from ria_toolkit_oss.data import Annotation
from ria_toolkit_oss.data.recording import Recording
from ria_toolkit_oss.io import load_recording, to_blue, to_npy, to_sigmf, to_wav
from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
format_frequency,
format_sample_count,
)
def normalize_sigmf_path(filepath):
"""Normalize SigMF path to base name without extension."""
path = Path(filepath)
# Handle .sigmf-data, .sigmf-meta, or .sigmf
if ".sigmf" in path.suffix:
# Remove the suffix to get base name
return path.with_suffix("")
else:
return path
def detect_input_format(filepath):
"""Detect file format from extension."""
path = Path(filepath)
ext = path.suffix.lower()
if ext in [".sigmf-data", ".sigmf-meta"]:
return "sigmf"
elif path.name.endswith(".sigmf"):
return "sigmf"
elif ext == ".npy":
return "npy"
elif ext == ".wav":
return "wav"
elif ext == ".blue":
return "blue"
else:
raise click.ClickException(f"Unknown format for '{filepath}'. Supported: .sigmf, .npy, .wav, .blue")
def determine_output_path(input_path, output_path, fmt, quiet, overwrite):
input_path = Path(input_path)
input_is_annotated = input_path.stem.endswith("_annotated")
if output_path:
target = Path(output_path)
elif overwrite and input_is_annotated:
# Write back in-place only when the input is already an _annotated file
target = input_path
else:
target = input_path.with_name(f"{input_path.stem}_annotated{input_path.suffix}")
if fmt == "sigmf":
final_path = normalize_sigmf_path(target)
if not quiet:
click.echo(f"Saving SigMF metadata to: {final_path}")
else:
final_path = target
if not quiet:
click.echo(f"Saving to: {final_path}")
# Always allow writing to _annotated files; guard against overwriting originals
target_is_annotated = final_path.stem.endswith("_annotated")
if final_path.exists() and not target_is_annotated and final_path != input_path:
click.echo(f"Error: {final_path} is not an annotated file and cannot be overwritten.", err=True)
return None
return final_path
def save_recording_auto(recording, output_path, input_path, quiet=False, overwrite=False):
"""Save recording, auto-detecting format from extension.
For SigMF: Only overwrites metadata file, data file is unchanged
For other formats: Creates _annotated copy by default, unless overwrite=True
"""
input_path = Path(input_path)
fmt = detect_input_format(input_path)
# Determine output path
output_path = determine_output_path(
input_path=input_path, output_path=output_path, fmt=fmt, quiet=quiet, overwrite=overwrite
)
if fmt == "sigmf":
# Normalize path for SigMF
base_path = output_path
stem = base_path.name
parent = base_path.parent
# For SigMF: only save metadata, copy data if needed
meta_path = parent / f"{stem}.sigmf-meta"
data_path = parent / f"{stem}.sigmf-data"
# If output is different from input, copy data file
input_base = normalize_sigmf_path(input_path)
if input_base != base_path:
import shutil
# Construct input data path correctly
# input_base is like /path/to/recording or /path/to/recording.sigmf
# We need /path/to/recording.sigmf-data
if str(input_base).endswith(".sigmf"):
input_data = Path(str(input_base).replace(".sigmf", ".sigmf-data"))
else:
input_data = input_base.parent / f"{input_base.name}.sigmf-data"
if not quiet:
click.echo(f" Copying: {data_path}")
shutil.copy2(input_data, data_path)
# Always save metadata (this is the whole point)
to_sigmf(recording, filename=stem, path=parent, overwrite=True)
if not quiet:
click.echo(f" Updated: {meta_path}")
if input_base != base_path:
click.echo(f" Created: {data_path}")
elif fmt == "npy":
to_npy(recording, filename=output_path.stem, path=output_path.parent, overwrite=True)
if not quiet:
click.echo(f" Created: {output_path}")
elif fmt == "wav":
to_wav(recording, filename=output_path.stem, path=output_path.parent, overwrite=True)
if not quiet:
click.echo(f" Created: {output_path}")
elif fmt == "blue":
to_blue(recording, filename=output_path.stem, path=output_path.parent, overwrite=True)
if not quiet:
click.echo(f" Created: {output_path}")
def determine_frequency_bounds(recording: Recording, freq_lower, freq_upper):
# Handle frequency bounds
if (freq_lower is None) != (freq_upper is None):
raise click.ClickException("Must specify both --freq-lower and --freq-upper, or neither")
if freq_lower is None:
# Default to full bandwidth
sample_rate = recording.metadata.get("sample_rate", 1)
center_freq = recording.metadata.get("center_frequency", 0)
freq_lower = center_freq - (sample_rate / 2)
freq_upper = center_freq + (sample_rate / 2)
freq_default = True
else:
freq_default = False
if freq_lower >= freq_upper:
raise click.ClickException(
f"Invalid frequency range: lower ({format_frequency(freq_lower)}) "
f"must be < upper ({format_frequency(freq_upper)})"
)
return freq_lower, freq_upper, freq_default
def get_indices_list(indices, recording: Recording):
if indices:
try:
indices_list = [int(idx.strip()) for idx in indices.split(",")]
# Validate indices
for idx in indices_list:
if idx < 0 or idx >= len(recording.annotations):
raise click.ClickException(
f"Invalid index {idx}. Recording has {len(recording.annotations)} annotation(s)"
)
except ValueError as e:
raise click.ClickException(f"Invalid indices format. Expected comma-separated integers: {e}")
return indices_list
else:
return None
# ============================================================================
# Main command group
# ============================================================================
@click.group()
def annotate():
"""Manage and auto-detect annotations on RF recordings.
\b
MANUAL MANAGEMENT:
list - List all current annotations
add - Manually add a specific annotation
remove - Delete an annotation by its index
clear - Remove all annotations from the recording
\b
DETECTION & SEPARATION:
energy - Auto-detect using energy-based thresholding
cusum - Auto-detect segments using signal state changes
threshold - Auto-detect samples above magnitude percentage
separate - Auto-detect parallel frequency-offset signals, split into sub-bands
\b
File Path Handling:
- SigMF files: Pass .sigmf-data, .sigmf-meta, or base name
- Other formats: .npy, .wav, .blue files
\b
Output Behavior:
- SigMF: Updates .sigmf-meta only (data unchanged), in-place
- Other: Creates _annotated copy unless --overwrite specified
"""
pass
# ============================================================================
# List subcommand
# ============================================================================
@annotate.command()
@click.argument("input", type=click.Path(exists=True))
@click.option("--verbose", is_flag=True, help="Show detailed annotation info")
def list(input, verbose):
"""List all annotations in a recording.
\b
Examples:
ria annotate list recording.sigmf-data
ria annotate list signal.npy --verbose
"""
try:
recording = load_recording(input)
except Exception as e:
raise click.ClickException(f"Failed to load recording: {e}")
if len(recording.annotations) == 0:
click.echo(f"No annotations in {Path(input).name}")
return
click.echo(f"\nAnnotations in {Path(input).name}:")
for i, ann in enumerate(recording.annotations):
# Parse type from comment JSON
try:
comment_data = json.loads(ann.comment)
ann_type = comment_data.get("type", "unknown")
user_comment = comment_data.get("user_comment", "")
except (json.JSONDecodeError, TypeError):
ann_type = "unknown"
user_comment = ann.comment or ""
# Basic info
freq_range = f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}"
click.echo(
f" [{i}] Samples {format_sample_count(ann.sample_start)}-"
f"{format_sample_count(ann.sample_start + ann.sample_count)}: {ann.label}"
)
click.echo(f" Type: {ann_type}")
if verbose:
if user_comment:
click.echo(f" Comment: {user_comment}")
click.echo(f" Frequency: {freq_range}")
if ann.detail:
click.echo(f" Detail: {ann.detail}")
click.echo(f"\nTotal: {len(recording.annotations)} annotation(s)")
# ============================================================================
# Add subcommand
# ============================================================================
@annotate.command(context_settings={"max_content_width": 200})
@click.argument("input", type=click.Path(exists=True))
@click.option("--start", type=int, required=True, help="Start sample index")
@click.option("--count", type=int, required=True, help="Sample count")
@click.option("--label", type=str, required=True, help="Annotation label")
@click.option("--freq-lower", type=float, help="Lower frequency edge (Hz)")
@click.option("--freq-upper", type=float, help="Upper frequency edge (Hz)")
@click.option("--comment", type=str, help="Human-readable comment")
@click.option(
"--type",
"annotation_type",
type=click.Choice(["standalone", "parallel", "intersection"]),
default="standalone",
help="Annotation type",
)
@click.option("--output", "-o", type=click.Path(), help="Output file path")
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
@click.option("--quiet", is_flag=True, help="Quiet mode")
def add(input, start, count, label, freq_lower, freq_upper, comment, annotation_type, output, overwrite, quiet):
"""Add a manual annotation.
\b
Examples:
ria annotate add file.npy --start 1000 --count 500 --label wifi
ria annotate add signal.sigmf-data --start 0 --count 1000 --label burst --comment "Strong signal"
"""
try:
recording = load_recording(input)
if not quiet:
click.echo(f"Loaded: {input}")
except Exception as e:
raise click.ClickException(f"Failed to load recording: {e}")
# Validate sample range
n_samples = len(recording.data[0])
if start < 0:
raise click.ClickException(f"--start must be >= 0, got {start}")
if count <= 0:
raise click.ClickException(f"--count must be > 0, got {count}")
if start + count > n_samples:
raise click.ClickException(
f"Invalid annotation range:\n"
f" Start: {start:,}\n"
f" Count: {count:,}\n"
f" End: {start + count:,}\n"
f"Recording only has {n_samples:,} samples"
)
# Handle frequency bounds
freq_lower, freq_upper, freq_default = determine_frequency_bounds(
recording=recording, freq_lower=freq_lower, freq_upper=freq_upper
)
# Build comment JSON
comment_data = {"type": annotation_type}
if comment:
comment_data["user_comment"] = comment
# Create annotation
ann = Annotation(
sample_start=start,
sample_count=count,
freq_lower_edge=freq_lower,
freq_upper_edge=freq_upper,
label=label,
comment=json.dumps(comment_data),
detail={},
)
recording._annotations.append(ann)
if not quiet:
click.echo("\nAdding annotation:")
click.echo(f" Start: {format_sample_count(start)}")
click.echo(f" Count: {format_sample_count(count)} samples")
freq_str = (
"full bandwidth" if freq_default else f"{format_frequency(freq_lower)} - {format_frequency(freq_upper)}"
)
click.echo(f" Frequency: {freq_str}")
click.echo(f" Label: {label}")
click.echo(f" Type: {annotation_type}")
if comment:
click.echo(f" Comment: {comment}")
try:
save_recording_auto(recording, output, input, quiet, overwrite)
if not quiet:
click.echo(" ✓ Saved")
except Exception as e:
raise click.ClickException(f"Failed to save: {e}")
# ============================================================================
# Remove subcommand
# ============================================================================
@annotate.command(context_settings={"max_content_width": 200})
@click.argument("input", type=click.Path(exists=True))
@click.argument("index", type=int)
@click.option("--output", "-o", type=click.Path(), help="Output file path")
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
@click.option("--quiet", is_flag=True, help="Quiet mode")
def remove(input, index, output, overwrite, quiet):
"""Remove annotation by index.
Use 'ria annotate list' to see annotation indices.
\b
Examples:
ria annotate remove signal.sigmf-data 2
ria annotate remove file.npy 0
"""
try:
recording = load_recording(input)
if not quiet:
click.echo(f"Loaded: {input}")
except Exception as e:
raise click.ClickException(f"Failed to load recording: {e}")
if index < 0 or index >= len(recording.annotations):
raise click.ClickException(
f"Cannot remove annotation at index {index}\n"
f"Recording has {len(recording.annotations)} annotation(s) (indices 0-{len(recording.annotations)-1})"
)
removed_ann = recording.annotations[index]
recording._annotations.pop(index)
if not quiet:
click.echo(f"\nRemoving annotation [{index}]:")
click.echo(
f" Removed: samples {format_sample_count(removed_ann.sample_start)}-"
f"{format_sample_count(removed_ann.sample_start + removed_ann.sample_count)} ({removed_ann.label})"
)
try:
save_recording_auto(recording, output_path=input, input_path=input, quiet=quiet, overwrite=True)
if not quiet:
click.echo(" ✓ Saved")
except Exception as e:
raise click.ClickException(f"Failed to save: {e}")
# ============================================================================
# Clear subcommand
# ============================================================================
@annotate.command(context_settings={"max_content_width": 175})
@click.argument("input", type=click.Path(exists=True))
@click.option("--output", "-o", type=click.Path(), help="Output file path")
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
@click.option("--force", is_flag=True, help="Skip confirmation")
@click.option("--quiet", is_flag=True, help="Quiet mode")
def clear(input, output, overwrite, force, quiet):
"""Clear all annotations.
\b
Examples:
ria annotate clear signal.sigmf-data
ria annotate clear file.npy --force
"""
try:
recording = load_recording(input)
if not quiet:
click.echo(f"Loaded: {input}")
except Exception as e:
raise click.ClickException(f"Failed to load recording: {e}")
count_before = len(recording.annotations)
if count_before == 0:
if not quiet:
click.echo("No annotations to clear")
return
# Confirm unless --force
if not force and not quiet:
click.echo(f"\nWarning: This will remove all {count_before} annotation(s)")
click.confirm("Continue?", abort=True)
recording._annotations = []
if not quiet:
click.echo(f"\nCleared {count_before} annotation(s)")
recording._annotations = []
try:
save_recording_auto(recording, output_path=input, input_path=input, quiet=quiet, overwrite=True)
if not quiet:
click.echo(" ✓ Saved")
except Exception as e:
raise click.ClickException(f"Failed to save: {e}")
# ============================================================================
# Energy detection subcommand
# ============================================================================
@annotate.command(context_settings={"max_content_width": 200})
@click.argument("input", type=click.Path(exists=True))
@click.option("--label", type=str, default="signal", help="Annotation label")
@click.option("--threshold", type=float, default=1.2, help="Threshold multiplier above noise floor")
@click.option("--segments", type=int, default=10, help="Number of segments for noise estimation")
@click.option("--window-size", type=int, default=200, help="Smoothing window size")
@click.option("--min-distance", type=int, default=5000, help="Min distance between detections")
@click.option(
"--freq-method",
type=click.Choice(["nbw", "obw", "full-detected", "full-bandwidth"]),
default="nbw",
help="Frequency bounding method",
)
@click.option("--nfft", type=int, default=None, help="FFT size for frequency calculation")
@click.option("--obw-power", type=float, default=0.99, help="Power percentage for OBW/NBW (0.98-0.9999)")
@click.option(
"--type",
"annotation_type",
type=click.Choice(["standalone", "parallel", "intersection"]),
default="standalone",
help="Annotation type",
)
@click.option("--output", "-o", type=click.Path(), help="Output file path")
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
@click.option("--quiet", is_flag=True, help="Quiet mode")
def energy(
input,
label,
threshold,
segments,
window_size,
min_distance,
freq_method,
nfft,
obw_power,
annotation_type,
output,
overwrite,
quiet,
):
"""Auto-detect signals using energy-based method.
Detects bursts based on energy above noise floor. Best for bursty signals
and intermittent transmissions.
\b
Frequency Bounding Methods:
nbw - Nominal bandwidth (default, best for real signals)
obw - Occupied bandwidth (more conservative, includes sidelobes)
full-detected - Lowest to highest spectral component
full-bandwidth - Entire Nyquist span
\b
Examples:
ria annotate energy capture.sigmf-data --label burst
ria annotate energy signal.npy --threshold 1.5 --min-distance 10000
ria annotate energy signal.sigmf-data --freq-method obw
ria annotate energy signal.sigmf-data --freq-method full-detected
"""
try:
recording = load_recording(input)
if not quiet:
click.echo(f"Loaded: {input}")
except Exception as e:
raise click.ClickException(f"Failed to load recording: {e}")
if not quiet:
click.echo("\nDetecting signals using energy-based method...")
click.echo(" Time detection:")
click.echo(f" Segments: {segments}")
click.echo(f" Threshold: {threshold}x noise floor")
click.echo(f" Window size: {window_size} samples")
click.echo(f" Min distance: {min_distance} samples")
click.echo(f" Frequency bounds: {freq_method}")
try:
initial_count = len(recording.annotations)
recording = detect_signals_energy(
recording,
k=segments,
threshold_factor=threshold,
window_size=window_size,
min_distance=min_distance,
label=label,
annotation_type=annotation_type,
freq_method=freq_method,
nfft=nfft,
obw_power=obw_power,
)
added = len(recording.annotations) - initial_count
if not quiet:
click.echo(f" ✓ Added {added} annotation(s)")
save_recording_auto(recording, output, input, quiet, overwrite)
if not quiet:
click.echo(" ✓ Saved")
except Exception as e:
raise click.ClickException(f"Energy detection failed: {e}")
# ============================================================================
# CUSUM detection subcommand
# ============================================================================
@annotate.command()
@click.argument("input", type=click.Path(exists=True))
@click.option("--label", type=str, default="segment", help="Annotation label")
@click.option("--min-duration", type=float, default=5.0, help="Min duration in ms (prevents over-segmentation)")
@click.option("--window-size", type=int, default=1, help="Smoothing window size")
@click.option("--tolerance", type=int, default=-1, help="Sample tolerance for merging")
@click.option(
"--type",
"annotation_type",
type=click.Choice(["standalone", "parallel", "intersection"]),
default="standalone",
help="Annotation type",
)
@click.option("--output", "-o", type=click.Path(), help="Output file path")
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
@click.option("--quiet", is_flag=True, help="Quiet mode")
def cusum(input, label, min_duration, window_size, tolerance, annotation_type, output, overwrite, quiet):
"""Auto-detect segments using CUSUM method.
Detects signal state changes (on/off, amplitude transitions). Best for
segmenting continuous signals.
IMPORTANT: Always specify --min-duration to prevent excessive segmentation.
\b
Examples:
ria annotate cusum signal.sigmf-data --min-duration 5.0
ria annotate cusum data.npy --min-duration 10.0 --label state
"""
try:
recording = load_recording(input)
if not quiet:
click.echo(f"Loaded: {input}")
except Exception as e:
raise click.ClickException(f"Failed to load recording: {e}")
if not quiet:
click.echo("\nDetecting segments using CUSUM...")
click.echo(f" Min duration: {min_duration} ms")
if window_size != 1:
click.echo(f" Window size: {window_size} samples")
try:
initial_count = len(recording.annotations)
recording = annotate_with_cusum(
recording,
label=label,
window_size=window_size,
min_duration=min_duration,
tolerance=tolerance,
annotation_type=annotation_type,
)
added = len(recording.annotations) - initial_count
if not quiet:
click.echo(f" ✓ Added {added} annotation(s)")
save_recording_auto(recording, output, input, quiet, overwrite)
if not quiet:
click.echo(" ✓ Saved")
except Exception as e:
raise click.ClickException(f"CUSUM detection failed: {e}")
# ============================================================================
# Threshold detection subcommand
# ============================================================================
@annotate.command()
@click.argument("input", type=click.Path(exists=True))
@click.option("--threshold", type=float, required=True, help="Threshold (0.0-1.0, fraction of max magnitude)")
@click.option("--label", type=str, default=None, help="Annotation label")
@click.option(
"--window-size",
type=int,
default=None,
help="Smoothing window size in samples (default: 1ms at recording sample rate)",
)
@click.option(
"--type",
"annotation_type",
type=click.Choice(["standalone", "parallel", "intersection"]),
default="standalone",
help="Annotation type",
)
@click.option("--channel", type=int, default=0, help="Channel index to annotate (default: 0)")
@click.option("--output", "-o", type=click.Path(), help="Output file path")
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
@click.option("--quiet", is_flag=True, help="Quiet mode")
def threshold(input, threshold, label, window_size, annotation_type, channel, output, overwrite, quiet):
"""Auto-detect signals using threshold method.
Detects samples above a percentage of maximum magnitude. Best for simple
power-based detection.
\b
Examples:
ria annotate threshold signal.sigmf-data --threshold 0.7 --label wifi
ria annotate threshold data.npy --threshold 0.5 --window-size 2048
"""
if not (0.0 <= threshold <= 1.0):
raise click.ClickException(f"--threshold must be between 0.0 and 1.0, got {threshold}")
try:
recording = load_recording(input)
if not quiet:
click.echo(f"Loaded: {input}")
except Exception as e:
raise click.ClickException(f"Failed to load recording: {e}")
if not quiet:
click.echo("\nDetecting signals using threshold qualifier...")
click.echo(f" Threshold: {threshold * 100:.1f}% of max magnitude")
click.echo(f" Window size: {'auto (1ms)' if window_size is None else f'{window_size} samples'}")
click.echo(f" Channel: {channel}")
try:
initial_count = len(recording.annotations)
recording = threshold_qualifier(
recording,
threshold=threshold,
window_size=window_size,
label=label,
annotation_type=annotation_type,
channel=channel,
)
added = len(recording.annotations) - initial_count
if not quiet:
click.echo(f" ✓ Added {added} annotation(s)")
save_recording_auto(recording, output, input, quiet, overwrite)
if not quiet:
click.echo(" ✓ Saved")
except Exception as e:
raise click.ClickException(f"Threshold detection failed: {e}")
# ============================================================================
# Separate subcommand (Phase 2: Parallel signal separation)
# ============================================================================
@annotate.command()
@click.argument("input", type=click.Path(exists=True))
@click.option("--indices", type=str, help="Comma-separated annotation indices to split (default: all)")
@click.option("--nfft", type=int, default=65536, help="FFT size for spectral analysis")
@click.option("--noise-threshold-db", type=float, help="Noise floor threshold in dB (auto-estimated if not specified)")
@click.option("--min-component-bw", type=float, default=50e3, help="Min component bandwidth in Hz")
@click.option("--output", "-o", type=click.Path(), help="Output file path")
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
@click.option("--quiet", is_flag=True, help="Quiet mode")
@click.option("--verbose", is_flag=True, help="Verbose output (show detected components)")
def separate(input, indices, nfft, noise_threshold_db, min_component_bw, output, overwrite, quiet, verbose):
"""
Auto-detect parallel frequency-offset signals and split into sub-bands.
Provides methods to detect and separate overlapping frequency-domain signals
that occupy the same time window but different frequency bands.
Detects multiple frequency components within single annotations and splits
them into separate annotations. Uses spectral peak detection with dual
bandwidth estimation.
\b
Key Features:
- Spectral peak detection for frequency components
- Auto noise floor estimation (or user-specified)
- Dual bandwidth estimation: -3dB primary, cumulative power fallback
- Handles narrowband and wide signals (OFDM)
\b
Examples:
ria annotate separate capture.sigmf-data
ria annotate separate signal.npy --indices 0,1,2
ria annotate separate data.sigmf-data --noise-threshold-db -70
ria annotate separate signal.npy --min-component-bw 100000
"""
try:
recording = load_recording(input)
if not quiet:
click.echo(f"Loaded: {input}")
except Exception as e:
raise click.ClickException(f"Failed to load recording: {e}")
# Parse indices if specified
indices_list = get_indices_list(indices=indices, recording=recording)
if len(recording.annotations) == 0:
if not quiet:
click.echo("No annotations to split")
return
if not quiet:
click.echo("\nSplitting annotations by frequency components...")
click.echo(f" Input annotations: {len(recording.annotations)}")
if indices_list:
click.echo(f" Splitting indices: {indices_list}")
click.echo(f" FFT size: {nfft}")
if noise_threshold_db is not None:
click.echo(f" Noise threshold: {noise_threshold_db} dB")
else:
click.echo(" Noise threshold: auto-estimated")
click.echo(f" Min component BW: {format_frequency(min_component_bw)}")
try:
initial_count = len(recording.annotations)
recording = split_recording_annotations(
recording,
indices=indices_list,
nfft=nfft,
noise_threshold_db=noise_threshold_db,
min_component_bw=min_component_bw,
)
final_count = len(recording.annotations)
added = final_count - initial_count
if not quiet:
click.echo(f" ✓ Output annotations: {final_count} ({'+' if added >= 0 else ''}{added} change)")
if verbose and added > 0:
click.echo("\n Details:")
for i in range(initial_count, final_count):
ann = recording.annotations[i]
freq_range = f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}"
click.echo(
f" [{i}] samples {format_sample_count(ann.sample_start)}-"
f"{format_sample_count(ann.sample_start + ann.sample_count)}: {freq_range}"
)
save_recording_auto(recording, output, input, quiet, overwrite)
if not quiet:
click.echo(" ✓ Saved")
except Exception as e:
raise click.ClickException(f"Spectral separation failed: {e}")

View File

@ -7,7 +7,7 @@ from pathlib import Path
import click import click
import numpy as np import numpy as np
from ria_toolkit_oss.data import Recording from ria_toolkit_oss.datatypes import Recording
from ria_toolkit_oss.io import from_npy_legacy, load_recording from ria_toolkit_oss.io import from_npy_legacy, load_recording
from ria_toolkit_oss_cli.ria_toolkit_oss.common import ( from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
echo_progress, echo_progress,

View File

@ -3,7 +3,6 @@
This module contains all the CLI bindings for the ria package. This module contains all the CLI bindings for the ria package.
""" """
from .annotate import annotate
from .campaign import campaign from .campaign import campaign
from .capture import capture from .capture import capture
from .combine import combine from .combine import combine

View File

@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional
import click import click
import yaml import yaml
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.io.recording import to_blue, to_npy, to_sigmf, to_wav from ria_toolkit_oss.io.recording import to_blue, to_npy, to_sigmf, to_wav

View File

@ -8,7 +8,7 @@ import numpy as np
import yaml import yaml
import ria_toolkit_oss.signal.basic_signal_generator as basic_gen import ria_toolkit_oss.signal.basic_signal_generator as basic_gen
from ria_toolkit_oss.data import Recording from ria_toolkit_oss.datatypes import Recording
from ria_toolkit_oss.signal.block_generator.basic import FrequencyShift from ria_toolkit_oss.signal.block_generator.basic import FrequencyShift
from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import ( from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import (
FSKModulator, FSKModulator,
@ -232,8 +232,8 @@ def generate():
\b \b
Examples: Examples:
ria synth chirp -b 1e6 -p 0.01 -s 10e6 -o chirp_basic.sigmf utils synth chirp -b 1e6 -p 0.01 -s 10e6 -o chirp_basic.sigmf
ria synth fsk -M 2 -r 100e3 -s 2e6 -o fsk2_basic.sigmf utils synth fsk -M 2 -r 100e3 -s 2e6 -o fsk2_basic.sigmf
""" """
pass pass

View File

@ -23,9 +23,9 @@ def serve(host: str, port: int, api_key: str, log_level: str):
\b \b
Endpoints: Endpoints:
POST /conductor/deploy POST /orchestrator/deploy
GET /conductor/status/{campaign_id} GET /orchestrator/status/{campaign_id}
POST /conductor/cancel/{campaign_id} POST /orchestrator/cancel/{campaign_id}
POST /inference/load POST /inference/load
POST /inference/start POST /inference/start
POST /inference/stop POST /inference/stop

View File

@ -8,7 +8,7 @@ from pathlib import Path
import click import click
from ria_toolkit_oss.data.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.io.recording import load_recording from ria_toolkit_oss.io.recording import load_recording
from ria_toolkit_oss.transforms import iq_augmentations, iq_impairments from ria_toolkit_oss.transforms import iq_augmentations, iq_impairments
from ria_toolkit_oss_cli.ria_toolkit_oss.common import ( from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
@ -270,13 +270,13 @@ def transform():
Examples:\n Examples:\n
\b \b
# List available augmentations # List available augmentations
ria transform augment --list utils transform augment --list
\b \b
# Apply channel swap # Apply channel swap
ria transform augment channel_swap input.npy utils transform augment channel_swap input.npy
\b \b
# Apply AWGN impairment # Apply AWGN impairment
ria transform impair awgn input.npy --snr-db 15 utils transform impair awgn input.npy --snr-db 15
""" """
pass pass

View File

@ -6,7 +6,7 @@ import time
import click import click
from ria_toolkit_oss.data import Recording from ria_toolkit_oss.datatypes import Recording
from ria_toolkit_oss.io import from_npy_legacy, load_recording from ria_toolkit_oss.io import from_npy_legacy, load_recording
from .common import ( from .common import (

View File

@ -7,7 +7,7 @@ from typing import Optional
import click import click
from ria_toolkit_oss.io.recording import from_npy, load_recording from ria_toolkit_oss.io.recording import from_npy, load_recording
from ria_toolkit_oss.view.view_signal import view_annotations, view_channels, view_sig from ria_toolkit_oss.view.view_signal import view_channels, view_sig
from ria_toolkit_oss.view.view_signal_simple import view_simple_sig from ria_toolkit_oss.view.view_signal_simple import view_simple_sig
from .common import echo_progress, echo_verbose, load_yaml_config from .common import echo_progress, echo_verbose, load_yaml_config
@ -34,11 +34,6 @@ VISUALIZATION_TYPES = {
"spines", "spines",
], ],
}, },
"annotations": {
"function": view_annotations,
"description": "Annotation-focused spectrogram view",
"options": ["channel", "dark"],
},
"channels": {"function": view_channels, "description": "Multi-channel IQ and spectrogram view", "options": []}, "channels": {"function": view_channels, "description": "Multi-channel IQ and spectrogram view", "options": []},
} }
@ -199,7 +194,7 @@ def print_metadata(recording, quiet):
@click.option( @click.option(
"--type", "--type",
"viz_type", "viz_type",
type=click.Choice(list(VISUALIZATION_TYPES.keys()) + ["annotate", "annotation"]), type=click.Choice(list(VISUALIZATION_TYPES.keys())),
default="simple", default="simple",
show_default=True, show_default=True,
help="Visualization type", help="Visualization type",
@ -243,7 +238,7 @@ def print_metadata(recording, quiet):
@click.option("--verbose", "-v", is_flag=True, help="Verbose output") @click.option("--verbose", "-v", is_flag=True, help="Verbose output")
@click.option("--quiet", "-q", is_flag=True, help="Suppress output") @click.option("--quiet", "-q", is_flag=True, help="Suppress output")
@click.option("--overwrite", is_flag=True, help="Overwrite existing output file") @click.option("--overwrite", is_flag=True, help="Overwrite existing output file")
def view( # noqa: C901 def view(
input, input,
viz_type, viz_type,
output, output,
@ -302,9 +297,6 @@ def view( # noqa: C901
# Legacy NPY file # Legacy NPY file
ria view old_capture.npy --legacy --type simple ria view old_capture.npy --legacy --type simple
""" """
if viz_type in ["annotate", "annotation"]:
viz_type = "annotations"
# Load config file if specified # Load config file if specified
if config: if config:
_ = load_yaml_config(config) _ = load_yaml_config(config)

View File

@ -1,115 +0,0 @@
"""CLI flags for TX opt-in and interlocks."""
from __future__ import annotations
import json
import sys
from unittest.mock import patch
from ria_toolkit_oss.agent import cli as agent_cli
from ria_toolkit_oss.agent import config as agent_config
class _FakeResp:
def __init__(self, payload: dict):
self._payload = payload
def read(self) -> bytes:
return json.dumps(self._payload).encode()
def __enter__(self):
return self
def __exit__(self, *_a):
return False
def _run_register(argv: list[str], cfg_path) -> int:
fake_resp = _FakeResp({"agent_id": "agent-1", "token": "tok-abc"})
with (
patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False),
patch("urllib.request.urlopen", return_value=fake_resp),
patch.object(sys, "argv", ["ria-agent", *argv]),
):
try:
agent_cli.main()
except SystemExit as exc:
return int(exc.code or 0)
return 0
def test_register_without_allow_tx_keeps_tx_disabled(tmp_path):
cfg_path = tmp_path / "agent.json"
_run_register(
["register", "--hub", "http://hub:3005", "--api-key", "K"],
cfg_path,
)
cfg = agent_config.load(path=cfg_path)
assert cfg.agent_id == "agent-1"
assert cfg.tx_enabled is False
assert cfg.tx_max_gain_db is None
def test_register_with_allow_tx_and_caps(tmp_path):
cfg_path = tmp_path / "agent.json"
_run_register(
[
"register",
"--hub",
"http://hub:3005",
"--api-key",
"K",
"--allow-tx",
"--tx-max-gain-db",
"-10",
"--tx-max-duration-s",
"60",
"--tx-freq-range",
"2.4e9",
"2.5e9",
"--tx-freq-range",
"5.7e9",
"5.8e9",
],
cfg_path,
)
cfg = agent_config.load(path=cfg_path)
assert cfg.tx_enabled is True
assert cfg.tx_max_gain_db == -10.0
assert cfg.tx_max_duration_s == 60.0
assert cfg.tx_allowed_freq_ranges == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]]
def test_stream_allow_tx_does_not_persist(tmp_path):
# Pre-register with tx_enabled=False, then simulate `stream --allow-tx`.
# The on-disk config must remain unchanged; the runtime flag is process-local.
cfg_path = tmp_path / "agent.json"
base = agent_config.AgentConfig(
hub_url="http://hub:3005",
agent_id="agent-1",
token="tok-abc",
tx_enabled=False,
)
agent_config.save(base, path=cfg_path)
captured: dict = {}
async def _fake_run_streamer(url, token, *, cfg):
captured["cfg"] = cfg
return None
with (
patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False),
patch("ria_toolkit_oss.agent.streamer.run_streamer", new=_fake_run_streamer),
patch.object(sys, "argv", ["ria-agent", "stream", "--allow-tx"]),
):
try:
agent_cli.main()
except SystemExit:
pass
# Runtime cfg had TX flipped on
assert captured["cfg"].tx_enabled is True
# But the persisted file is untouched
on_disk = agent_config.load(path=cfg_path)
assert on_disk.tx_enabled is False

View File

@ -20,36 +20,6 @@ def test_load_missing_returns_empty(tmp_path):
assert loaded == agent_config.AgentConfig() assert loaded == agent_config.AgentConfig()
def test_tx_fields_round_trip(tmp_path):
p = tmp_path / "agent.json"
cfg = agent_config.AgentConfig(
hub_url="https://hub.example.com",
agent_id="agent-1",
token="t",
tx_enabled=True,
tx_max_gain_db=-10.0,
tx_max_duration_s=60.0,
tx_allowed_freq_ranges=[[2.4e9, 2.5e9], [5.7e9, 5.8e9]],
)
agent_config.save(cfg, path=p)
loaded = agent_config.load(path=p)
assert loaded.tx_enabled is True
assert loaded.tx_max_gain_db == -10.0
assert loaded.tx_max_duration_s == 60.0
assert loaded.tx_allowed_freq_ranges == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]]
def test_tx_fields_default_when_absent(tmp_path):
# Old configs written before TX existed should load cleanly with safe defaults.
p = tmp_path / "agent.json"
p.write_text('{"hub_url": "x", "agent_id": "a", "token": "t"}')
cfg = agent_config.load(path=p)
assert cfg.tx_enabled is False
assert cfg.tx_max_gain_db is None
assert cfg.tx_max_duration_s is None
assert cfg.tx_allowed_freq_ranges is None
def test_extra_keys_preserved(tmp_path): def test_extra_keys_preserved(tmp_path):
p = tmp_path / "agent.json" p = tmp_path / "agent.json"
p.write_text('{"hub_url": "x", "custom": 42}') p.write_text('{"hub_url": "x", "custom": 42}')

Some files were not shown because too many files have changed in this diff Show More