Compare commits

36 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 0f246e9c69 |  |
|  | 4ba0dc170e |  |
|  | 762eda9426 |  |
|  | 172f8e0ca3 |  |
|  | c035b990ef |  |
|  | 98407604ef |  |
|  | 1cf5a723e6 |  |
|  | 0bd1b6e288 |  |
|  | 8105b829be |  |
|  | 450fab6df2 |  |
|  | a0b46a35e2 |  |
|  | 4872eea116 |  |
|  | 24730850b0 |  |
|  | 5074e8f32a |  |
|  | 4420ae76c9 |  |
|  | ddf445fd4d |  |
|  | d68b9727ad |  |
|  | c06e58f5d6 |  |
|  | c7c7100d46 |  |
|  | e863040e19 |  |
|  | 77d1773370 |  |
|  | 79aa1fc0a4 |  |
|  | 2721ed866c |  |
|  | e84cb16e77 |  |
|  | c2b47ead95 |  |
|  | 34faa57ea4 |  |
|  | f430e626a6 |  |
|  | 1fb55607a2 |  |
| Aash | 2e0378ff9d |  |
| Aash | 3f8506f222 |  |
| Aash | 5c9e50fa48 |  |
| Aash | 25e5a4c6a6 |  |
| Aash | a9f8ad4bee |  |
| Aash | 8a3c80b33f |  |
|  | d919e4666c |  |
|  | 42af1a2c1e |  |
.gitignore (vendored, 1 change)
@@ -48,6 +48,7 @@ coverage.xml
 .hypothesis/
 .pytest_cache/
 cover/
+tests/sdr/
 
 # Sphinx documentation
 docs/build/
@@ -168,6 +168,10 @@ Additional usage information is provided in the project documentation: [RIA Tool
 
 Kindly report any issues on RIA Hub: [RIA Toolkit OSS Issues Board](https://riahub.ai/qoherent/ria-toolkit-oss/issues).
 
+### Upcoming Changes
+
+The ThinkRF package is currently pending further testing and potential updates.
+
 ## 🤝 Contribution
 
 Contributions are always welcome! Whether it's an enhancement, bug fix, or new example, your input is valuable. If you'd like to contribute to the project, please reach out to the project maintainers.
@@ -14,7 +14,7 @@ sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))
 project = 'ria-toolkit-oss'
 copyright = '2025, Qoherent Inc'
 author = 'Qoherent Inc.'
-release = '0.1.2'
+release = '0.1.3'
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
@@ -10,7 +10,9 @@ their key capabilities and limitations, as well as additional information needed
 .. toctree::
    :maxdepth: 2
 
-   USRP <usrp>
    BladeRF <blade>
-   PlutoSDR <pluto>
    HackRF <hackrf>
+   PlutoSDR <pluto>
+   RTL-SDR <rtlsdr>
+   ThinkRF <thinkrf>
+   USRP <usrp>
docs/source/sdr_guides/rtlsdr.rst (new file, 87 lines)

@@ -0,0 +1,87 @@
.. _rtl:

RTLSDR
======

RTL-SDR (RTL2832U Software Defined Radio) is a low-cost USB dongle originally designed for digital TV reception
that has been repurposed as a wideband software-defined radio. RTL-SDR devices are popular for hobbyist use due to
their affordability and wide range of applications.

The RTL-SDR is based on the Realtek RTL2832U chipset, which features direct sampling and demodulation of RF
signals. These devices are commonly used for tasks such as listening to FM radio, monitoring aircraft traffic
(ADS-B), receiving weather satellite images, and more.

Supported Models
----------------
- Generic RTL-SDR Dongle: The most common variant, usually featuring an R820T or R820T2 tuner.
- RTL-SDR Blog V3: An enhanced version with additional features like direct sampling mode and a bias tee for
  powering external devices.

Key Features
------------
- Frequency Range: Typically from 24 MHz to 1.7 GHz, depending on the tuner chip.
- Bandwidth: Limited to about 2.4 MHz, making it suitable for narrowband applications.
- Connectivity: USB 2.0 interface, plug-and-play on most platforms.
- Software Support: Compatible with SDR software like SDR#, GQRX, and GNU Radio.

Limitations
-----------
- Narrow bandwidth compared to more expensive SDRs, which may limit some applications.
- Sensitivity and performance can vary depending on the specific model and components.
- Requires external software for signal processing and analysis.

Set up instructions (Linux, Radioconda)
---------------------------------------

1. Activate your Radioconda environment:

   .. code-block:: bash

      conda activate <your-env-name>

2. Purge drivers:

   If you already have other drivers installed, purge them from your system.

   .. code-block:: bash

      sudo apt purge ^librtlsdr
      sudo rm -rvf /usr/lib/librtlsdr*
      sudo rm -rvf /usr/include/rtl-sdr*
      sudo rm -rvf /usr/local/lib/librtlsdr*
      sudo rm -rvf /usr/local/include/rtl-sdr*
      sudo rm -rvf /usr/local/include/rtl_*
      sudo rm -rvf /usr/local/bin/rtl_*

3. Install RTL-SDR Blog drivers:

   .. code-block:: bash

      sudo apt-get install libusb-1.0-0-dev git cmake pkg-config build-essential
      git clone https://github.com/osmocom/rtl-sdr
      cd rtl-sdr
      mkdir build
      cd build
      cmake ../ -DINSTALL_UDEV_RULES=ON
      make
      sudo make install
      sudo cp ../rtl-sdr.rules /etc/udev/rules.d/
      sudo ldconfig

4. Blacklist the DVB-T modules that would otherwise claim the device:

   .. code-block:: bash

      sudo ln -s $CONDA_PREFIX/etc/modprobe.d/rtl-sdr-blacklist.conf /etc/modprobe.d/radioconda-rtl-sdr-blacklist.conf
      sudo modprobe -r $(cat $CONDA_PREFIX/etc/modprobe.d/rtl-sdr-blacklist.conf | sed -n -e 's/^blacklist //p')

5. Install a udev rule by creating a link into your radioconda installation:

   .. code-block:: bash

      sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/rtl-sdr.rules /etc/udev/rules.d/radioconda-rtl-sdr.rules
      sudo udevadm control --reload
      sudo udevadm trigger

Further Information
-------------------
- `RTL-SDR Official Website <https://www.rtl-sdr.com/>`_
- `RTL-SDR Documentation <https://www.rtl-sdr.com/rtl-sdr-quick-start-guide/>`_
docs/source/sdr_guides/thinkrf.rst (new file, 59 lines)

@@ -0,0 +1,59 @@
.. _thinkrf:

ThinkRF
====================

The ThinkRF series of spectrum analyzers and software-defined radio platforms are designed for advanced
RF signal monitoring, analysis, and wireless research. These devices
combine high-performance RF front ends with flexible software interfaces for a wide range of applications,
including spectrum monitoring, signal intelligence, and wireless testing.

ThinkRF devices offer wide frequency coverage, deep dynamic range, and real-time analysis capabilities.
They are built for professional and research-grade environments, offering Ethernet-based connectivity and
software APIs for remote control and integration into automated systems.

Supported Models
----------------
- **ThinkRF R5550**: A real-time spectrum analyzer with frequency coverage from 9 kHz to 27 GHz, 160 MHz real-time bandwidth,
  and 100 MHz instantaneous FFT bandwidth.

Key Features
------------
- Frequency Range: 9 kHz to 27 GHz (depending on model).
- Bandwidth: Up to 160 MHz real-time bandwidth.
- Connectivity: Gigabit Ethernet interface for high-throughput streaming and remote control.
- Software Support: Compatible with ThinkRF APIs, GNU Radio, MATLAB, and third-party spectrum analysis software.
- Real-Time Analysis:

  - Enables full-band, real-time spectral visibility for dynamic signal environments.
  - Supports trigger-based capture and event-driven recording.

- Remote Operation:

  - Designed for distributed deployments and networked operation through Ethernet.
  - Can be integrated into automated RF monitoring systems or deployed for field data collection.

Limitations
-----------
- Requires external host for processing (no onboard CPU for user applications).
- Dependent on ThinkRF software drivers and API for device control.
- High data rate operation may require optimized network settings or storage systems.

Set up instructions (Linux)
---------------------------------

Install PyRF:

.. code-block:: bash

   pip install 'pyrf>=2.8.0'

Convert PyRF scripts to Python 3:

.. code-block:: bash

   cd ../scripts
   ./convert_pyrf_to_python3.sh

Further Information
-------------------
- `ThinkRF Documentation <https://thinkrf.com/resources/>`_
- `ThinkRF Product Page <https://thinkrf.com/products/>`_
- `Pyrf Github Page <https://github.com/pyrf/pyrf>`_
@@ -1,6 +1,6 @@
 [project]
 name = "ria-toolkit-oss"
-version = "0.1.2"
+version = "0.1.3"
 description = "An open-source version of the RIA Toolkit, including the fundamental tools to get started developing, testing, and deploying radio intelligence applications"
 license = { text = "AGPL-3.0-only" }
 readme = "README.md"
@@ -11,6 +11,7 @@ authors = [
 maintainers = [
     { name = "Benjamin Chinnery", email = "ben@qoherent.ai" },
+    { name = "Ashkan Beigi", email = "ash@qoherent.ai" },
     { name = "Madrigal Weersink", email = "madrigal@qoherent.ai" },
 ]
 keywords = [
     "radio",
@@ -47,6 +48,23 @@ dependencies = [
     "pyzmq (>=27.1.0,<28.0.0)",
 ]
 
+# [project.optional-dependencies] Commented out to prevent Tox tests from failing
+# # SDR hardware-specific dependencies (optional installs)
+# rtlsdr = ["pyrtlsdr>=0.2.9"]
+# pluto = ["pyadi-iio>=0.0.14"]
+# usrp = [] # Requires system UHD installation
+# hackrf = ["pyhackrf>=0.2.0"]
+# bladerf = [] # Requires system libbladerf installation
+# thinkrf = ["pyrf>=2.8.0"] # NOTE: Requires lib2to3 post-install fix (see docs/)
+
+# All SDR hardware support
+all-sdr = [
+    "pyrtlsdr>=0.2.9",
+    "pyadi-iio>=0.0.14",
+    "pyhackrf>=0.2.0",
+    "pyrf>=2.8.0",
+]
+
 [tool.poetry]
 packages = [
     { include = "ria_toolkit_oss", from = "src" }
scripts/convert_pyrf_to_python3.sh (new executable file, 45 lines)

@@ -0,0 +1,45 @@
#!/bin/bash
# Fix pyrf Python 3 compatibility
# Run this after: pip install pyrf

set -e

VENV_DIR="${1:-venv}"
PYRF_BASE="$VENV_DIR/lib/python3.12/site-packages/pyrf"

if [ ! -d "$PYRF_BASE" ]; then
    echo "❌ pyrf not found at $PYRF_BASE"
    echo "Usage: $0 [venv_directory]"
    exit 1
fi

echo "🔧 Fixing pyrf for Python 3..."

# Backup originals
cp "$PYRF_BASE/devices/thinkrf.py" "$PYRF_BASE/devices/thinkrf.py.bak" 2>/dev/null || true
cp "$PYRF_BASE/devices/thinkrf_properties.py" "$PYRF_BASE/devices/thinkrf_properties.py.bak" 2>/dev/null || true
cp "$PYRF_BASE/connectors/blocking.py" "$PYRF_BASE/connectors/blocking.py.bak" 2>/dev/null || true

# Fix thinkrf.py
echo "  Fixing thinkrf.py..."
sed -i 's/\.iteritems()/.items()/g' "$PYRF_BASE/devices/thinkrf.py"
sed -i 's/raw_input/input/g' "$PYRF_BASE/devices/thinkrf.py"
# Fix print statements (carefully to handle the format string)
sed -i '884s/.*/        print(fmt % (index, wsa["HOST"], modelstring, wsa["SERIAL"]))/' "$PYRF_BASE/devices/thinkrf.py"
sed -i 's/print "r) Refresh"/print("r) Refresh")/g' "$PYRF_BASE/devices/thinkrf.py"
sed -i 's/print "q) Abort"/print("q) Abort")/g' "$PYRF_BASE/devices/thinkrf.py"
sed -i 's/print "error: invalid selection: '\''%s'\''" % choice/print("error: invalid selection: '\''%s'\''" % choice)/g' "$PYRF_BASE/devices/thinkrf.py"

# Fix thinkrf_properties.py
echo "  Fixing thinkrf_properties.py..."
sed -i 's/\.iteritems()/.items()/g' "$PYRF_BASE/devices/thinkrf_properties.py"

# Fix blocking.py (socket bytes issue)
echo "  Fixing blocking.py..."
sed -i '29s/self._sock_scpi.send(cmd)/self._sock_scpi.send(cmd.encode())/' "$PYRF_BASE/connectors/blocking.py"
sed -i '34s/self._sock_scpi.send(cmd)/self._sock_scpi.send(cmd.encode())/' "$PYRF_BASE/connectors/blocking.py"
# Fix line 37 - replace entire line to avoid double decode
sed -i '37s/.*/        return buf.decode()/' "$PYRF_BASE/connectors/blocking.py"

echo "✅ pyrf fixed for Python 3!"
echo "   Backups saved with .bak extension"
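The sed commands in the script above amount to three mechanical Python-2-to-3 rewrites: `.iteritems()` to `.items()`, `raw_input` to `input`, and bare `print` statements to `print()` calls. A rough Python sketch of the same transformations, applied to made-up sample lines rather than the actual pyrf sources:

```python
import re

# Sample Python 2 lines of the kind the script patches (illustrative, not from pyrf).
lines = [
    'for k, v in settings.iteritems():',
    'choice = raw_input("> ")',
    'print "r) Refresh"',
]

fixed = []
for line in lines:
    line = line.replace(".iteritems()", ".items()")   # dict iteration rename
    line = line.replace("raw_input", "input")          # builtin rename
    # Wrap a simple string-literal print statement in parentheses.
    line = re.sub(r'^(\s*)print ("(?:[^"\\]|\\.)*")$', r'\1print(\2)', line)
    fixed.append(line)
```

This is why the script handles the `%`-formatted print on line 884 separately: a naive textual wrap would put the closing parenthesis in the wrong place.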
scripts/fix_pyrf_python3.py (new file, 44 lines)

@@ -0,0 +1,44 @@
#!/usr/bin/env python3
"""
Fix pyrf Python 3 compatibility.

The pyrf library ships with Python 2 syntax in pyrf/devices/thinkrf.py.
This script uses lib2to3 to automatically convert it to Python 3.

Usage:
    python scripts/fix_pyrf_python3.py

Run this after installing pyrf:
    pip install ria-toolkit-oss[thinkrf]
    python scripts/fix_pyrf_python3.py
"""

from lib2to3.refactor import RefactoringTool, get_fixers_from_package
from pathlib import Path

try:
    import pyrf
except ImportError:
    print("ERROR: pyrf is not installed.")
    print("Install with: pip install pyrf")
    print("Or install ria with ThinkRF support: pip install ria-toolkit-oss[thinkrf]")
    exit(1)

# Find the thinkrf.py file in the pyrf package
thinkrf_path = Path(pyrf.__file__).resolve().parent / "devices" / "thinkrf.py"

if not thinkrf_path.exists():
    print(f"ERROR: Could not find {thinkrf_path}")
    print("Is pyrf installed correctly?")
    exit(1)

print(f"Found pyrf ThinkRF module at: {thinkrf_path}")

# Apply lib2to3 fixes
print("Applying Python 3 compatibility fixes...")
fixers = get_fixers_from_package("lib2to3.fixes")
tool = RefactoringTool(fixers)
tool.refactor_file(str(thinkrf_path), write=True)

print(f"✅ Successfully patched {thinkrf_path} for Python 3 compatibility.")
print("\nYou can now use ria_toolkit_oss.sdr.thinkrf.ThinkRF")
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import copy
-import datetime
 import hashlib
 import json
 import os
@@ -12,7 +11,6 @@ from typing import Any, Iterator, Optional
 
 import numpy as np
 from numpy.typing import ArrayLike
-from quantiphy import Quantity
 
 from ria_toolkit_oss.datatypes.annotation import Annotation
 
@@ -450,7 +448,63 @@ class Recording:
         else:
             raise ValueError(f"Key {key} is protected and cannot be modified or removed.")
 
-    def to_sigmf(self, filename: Optional[str] = None, path: Optional[os.PathLike | str] = None) -> None:
+    def view(self, output_path: Optional[str] = "images/signal.png", **kwargs) -> None:
+        """Create a plot of various signal visualizations as a PNG image.
+
+        :param output_path: The output image path. Defaults to "images/signal.png".
+        :type output_path: str, optional
+        :param kwargs: Keyword arguments passed on to utils.view.view_sig.
+        :type: dict of keyword arguments
+
+        **Examples:**
+
+        Create a recording and view it as a plot in a .png image:
+
+        >>> import numpy
+        >>> from utils.data import Recording
+
+        >>> samples = numpy.ones(10000, dtype=numpy.complex64)
+        >>> metadata = {
+        >>>     "sample_rate": 1e6,
+        >>>     "center_frequency": 2.44e9,
+        >>> }
+
+        >>> recording = Recording(data=samples, metadata=metadata)
+        >>> recording.view()
+        """
+        from ria_toolkit_oss.view.view_signal import view_sig
+
+        view_sig(recording=self, output_path=output_path, **kwargs)
+
+    def simple_view(self, **kwargs) -> None:
+        """Create a plot of various signal visualizations as a PNG or SVG image.
+
+        :param kwargs: Keyword arguments passed on to utils.view.view_signal_simple.create_plots.
+        :type: dict of keyword arguments
+
+        **Examples:**
+
+        Create a recording and view it as a plot in a .png image:
+
+        >>> import numpy
+        >>> from utils.data import Recording
+
+        >>> samples = numpy.ones(10000, dtype=numpy.complex64)
+        >>> metadata = {
+        >>>     "sample_rate": 1e6,
+        >>>     "center_frequency": 2.44e9,
+        >>> }
+
+        >>> recording = Recording(data=samples, metadata=metadata)
+        >>> recording.simple_view()
+        """
+        from ria_toolkit_oss.view.view_signal_simple import view_simple_sig
+
+        view_simple_sig(recording=self, **kwargs)
+
+    def to_sigmf(
+        self, filename: Optional[str] = None, path: Optional[os.PathLike | str] = None, overwrite: bool = False
+    ) -> None:
         """Write recording to a set of SigMF files.
 
         The SigMF io format is defined by the `SigMF Specification Project <https://github.com/sigmf/SigMF>`_
@@ -468,9 +522,11 @@ class Recording:
         """
         from ria_toolkit_oss.io.recording import to_sigmf
 
-        to_sigmf(filename=filename, path=path, recording=self)
+        to_sigmf(filename=filename, path=path, recording=self, overwrite=overwrite)
 
-    def to_npy(self, filename: Optional[str] = None, path: Optional[os.PathLike | str] = None) -> str:
+    def to_npy(
+        self, filename: Optional[str] = None, path: Optional[os.PathLike | str] = None, overwrite: bool = False
+    ) -> str:
         """Write recording to ``.npy`` binary file.
 
         :param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename.
@@ -501,7 +557,7 @@ class Recording:
         """
         from ria_toolkit_oss.io.recording import to_npy
 
-        to_npy(recording=self, filename=filename, path=path)
+        to_npy(recording=self, filename=filename, path=path, overwrite=overwrite)
 
     def trim(self, num_samples: int, start_sample: Optional[int] = 0) -> Recording:
         """Trim Recording samples to a desired length, shifting annotations to maintain alignment.
@@ -594,40 +650,6 @@ class Recording:
         scaled_data = self.data / np.max(abs(self.data))
         return Recording(data=scaled_data, metadata=self.metadata, annotations=self.annotations)
 
-    def generate_filename(self, tag: Optional[str] = "rec"):
-        """Generate a filename from metadata.
-
-        :param tag: The string at the beginning of the generated filename. Default is "rec".
-        :type tag: str, optional
-
-        :return: A filename without an extension.
-        :rtype: str
-        """
-        # TODO: This method should be refactored to use the first 7 characters of the 'rec_id' field.
-
-        tag = tag + "_"
-        source = self.metadata.get("source", "")
-        if source != "":
-            source = source + "_"
-
-        # converts 1000 to 1k for example
-        center_frequency = str(Quantity(self.metadata.get("center_frequency", 0)))
-        if center_frequency != "0":
-            num = center_frequency[:-1]
-            suffix = center_frequency[-1]
-            num = int(np.round(float(num)))
-        else:
-            num = 0
-            suffix = ""
-        center_frequency = str(num) + suffix + "Hz_"
-
-        timestamp = int(self.timestamp)
-        timestamp = datetime.datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d_%H-%M-%S") + "_"
-
-        # Add first seven characters of rec_id for uniqueness
-        rec_id = self.rec_id[0:7]
-        return tag + source + center_frequency + timestamp + rec_id
-
     def __len__(self) -> int:
         """The length of a recording is defined by the number of complex samples in each channel of the recording."""
         return self.shape[1]
@@ -2,6 +2,7 @@
 Utilities for input/output operations on the ria_toolkit_oss.datatypes.Recording object.
 """
 
+import datetime
 import datetime as dt
 import os
 from datetime import timezone
@@ -9,6 +10,7 @@ from typing import Optional
 
 import numpy as np
 import sigmf
+from quantiphy import Quantity
 from sigmf import SigMFFile, sigmffile
 from sigmf.utils import get_data_type_str
@@ -92,7 +94,12 @@ def convert_to_serializable(obj):
     raise TypeError(f"Value of type {type(obj)} is not JSON serializable: {obj}")
 
 
-def to_sigmf(recording: Recording, filename: Optional[str] = None, path: Optional[os.PathLike | str] = None) -> None:
+def to_sigmf(
+    recording: Recording,
+    filename: Optional[str] = None,
+    path: Optional[os.PathLike | str] = None,
+    overwrite: bool = False,
+) -> None:
     """Write recording to a set of SigMF files.
 
     The SigMF io format is defined by the `SigMF Specification Project <https://github.com/sigmf/SigMF>`_
@@ -121,7 +128,7 @@
     if filename is not None:
         filename, _ = os.path.splitext(filename)
     else:
-        filename = recording.generate_filename()
+        filename = generate_filename(recording=recording)
 
     if path is None:
         path = "recordings"
@@ -140,6 +147,13 @@
     samples = multichannel_samples[0]
 
+    data_file_path = os.path.join(path, f"{filename}.sigmf-data")
+    meta_file_path = os.path.join(path, f"{filename}.sigmf-meta")
+
+    if not overwrite:
+        if os.path.isfile(data_file_path):
+            raise IOError("File already exists")
+        if os.path.isfile(meta_file_path):
+            raise IOError("File already exists")
+
     samples.tofile(data_file_path)
     global_info = {
@@ -188,7 +202,7 @@
     meta_dict = sigMF_metafile.ordered_metadata()
     meta_dict["ria"] = metadata
 
-    sigMF_metafile.tofile(f"{os.path.join(path, filename)}.sigmf-meta")
+    sigMF_metafile.tofile(meta_file_path)
 
 
 def from_sigmf(file: os.PathLike | str) -> Recording:
@@ -250,7 +264,12 @@ def from_sigmf(file: os.PathLike | str) -> Recording:
     return output_recording
 
 
-def to_npy(recording: Recording, filename: Optional[str] = None, path: Optional[os.PathLike | str] = None) -> str:
+def to_npy(
+    recording: Recording,
+    filename: Optional[str] = None,
+    path: Optional[os.PathLike | str] = None,
+    overwrite: bool = False,
+) -> str:
     """Write recording to ``.npy`` binary file.
 
     :param recording: The recording to be written to file.
@@ -277,7 +296,7 @@ def to_npy(recording: Recording, filename: Optional[str] = None, path: Optional[
     if filename is not None:
         filename, _ = os.path.splitext(filename)
     else:
-        filename = recording.generate_filename()
+        filename = generate_filename(recording=recording)
     filename = filename + ".npy"
 
     if path is None:
@@ -287,6 +306,10 @@ def to_npy(recording: Recording, filename: Optional[str] = None, path: Optional[
         os.makedirs(path)
     fullpath = os.path.join(path, filename)
 
+    if not overwrite:
+        if os.path.isfile(fullpath):
+            raise IOError("File already exists")
+
     data = np.array(recording.data)
     metadata = recording.metadata
     annotations = recording.annotations
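The `overwrite` guard threaded through `to_npy()` here (and `to_sigmf()` above) reduces to a small, self-contained pattern: check for the file before writing and raise unless the caller opts in. A standalone sketch, where `safe_write` is a hypothetical stand-in for the library functions:

```python
import os
import tempfile

def safe_write(fullpath: str, data: bytes, overwrite: bool = False) -> None:
    # Same guard the diff adds: refuse to clobber an existing file unless asked.
    if not overwrite and os.path.isfile(fullpath):
        raise IOError("File already exists")
    with open(fullpath, "wb") as f:
        f.write(data)

with tempfile.TemporaryDirectory() as d:
    target = os.path.join(d, "rec.npy")
    safe_write(target, b"first")
    try:
        safe_write(target, b"second")  # default: refuses to overwrite
        clobbered = True
    except IOError:
        clobbered = False
    safe_write(target, b"second", overwrite=True)  # explicit opt-in succeeds
    with open(target, "rb") as f:
        final = f.read()
```

Note the check-then-write is not atomic; for concurrent writers, `open(fullpath, "xb")` would close the race, but the simple check matches what the diff does.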
@@ -330,3 +353,37 @@ def from_npy(file: os.PathLike | str) -> Recording:
 
     recording = Recording(data=data, metadata=metadata, annotations=annotations)
     return recording
+
+
+def generate_filename(recording: Recording, tag: Optional[str] = "rec"):
+    """Generate a filename from metadata.
+
+    :param tag: The string at the beginning of the generated filename. Default is "rec".
+    :type tag: str, optional
+
+    :return: A filename without an extension.
+    :rtype: str
+    """
+
+    tag = tag + "_"
+    source = recording.metadata.get("source", "")
+    if source != "":
+        source = source + "_"
+
+    # converts 1000 to 1k for example
+    center_frequency = str(Quantity(recording.metadata.get("center_frequency", 0)))
+    if center_frequency != "0":
+        num = center_frequency[:-1]
+        suffix = center_frequency[-1]
+        num = int(np.round(float(num)))
+    else:
+        num = 0
+        suffix = ""
+    center_frequency = str(num) + suffix + "Hz_"
+
+    timestamp = int(recording.timestamp)
+    timestamp = datetime.datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d_%H-%M-%S") + "_"
+
+    # Add first seven characters of rec_id for uniqueness
+    rec_id = recording.rec_id[0:7]
+    return tag + source + center_frequency + timestamp + rec_id
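For reference, the filename shape assembled at the end of the relocated generate_filename() can be mimicked with the standard library alone. The sketch below hardcodes the "2GHz" label that quantiphy's Quantity would normally derive from the metadata, uses a made-up source and rec_id, and formats the timestamp in UTC for determinism (the real function uses local time):

```python
import datetime

# Hypothetical metadata values, for illustration only.
tag, source, rec_id = "rec", "Blade", "a1b2c3d"
freq_label = "2GHz"  # normally produced via quantiphy: Quantity(2.44e9) -> "2.44G" -> rounded

ts = datetime.datetime.fromtimestamp(1700000000, tz=datetime.timezone.utc)
stamp = ts.strftime("%Y-%m-%d_%H-%M-%S")
filename = f"{tag}_{source}_{freq_label}_{stamp}_{rec_id}"
```

The trailing seven characters of `rec_id` are what keep two recordings with identical source, frequency, and second-resolution timestamp from colliding.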
@@ -325,8 +325,8 @@ f.argtypes = [p_hackrf_device, POINTER(read_partid_serialno_t)]
 # libhackrf.hackrf_set_txvga_gain.argtypes = [POINTER(hackrf_device), c_uint32]
 ## extern ADDAPI int ADDCALL hackrf_set_antenna_enable(hackrf_device*
 ## device, const uint8_t value);
-# libhackrf.hackrf_set_antenna_enable.restype = c_int
-# libhackrf.hackrf_set_antenna_enable.argtypes = [POINTER(hackrf_device), c_uint8]
+libhackrf.hackrf_set_antenna_enable.restype = c_int
+libhackrf.hackrf_set_antenna_enable.argtypes = [p_hackrf_device, c_uint8]
 #
 ## extern ADDAPI const char* ADDCALL hackrf_error_name(enum hackrf_error errcode);
 ## libhackrf.hackrf_error_name.restype = POINTER(c_char)
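Un-commenting the hackrf_set_antenna_enable declarations above follows the standard ctypes discipline: set restype and argtypes before the first call so arguments are marshalled correctly. The same pattern can be demonstrated against libc's abs() rather than libhackrf, so it runs without SDR hardware (assumes a Unix-like system where CDLL(None) exposes libc symbols):

```python
import ctypes

# Load the current process's symbols; on Linux/macOS this exposes libc functions.
libc = ctypes.CDLL(None)

# Declare the C signature before calling, exactly as done for
# hackrf_set_antenna_enable in the diff above.
libc.abs.restype = ctypes.c_int
libc.abs.argtypes = [ctypes.c_int]

result = libc.abs(-5)
```

Without argtypes/restype, ctypes defaults every argument and return value to `int`-sized guesses, which happens to work for abs() but silently corrupts pointer-taking calls like the libhackrf ones.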
@@ -537,6 +537,16 @@ class HackRF(object):
             raise IOError("error disabling amp")
         return 0
 
+    def set_antenna_enable(self, enable):
+        value = 1 if enable else 0
+        result = libhackrf.hackrf_set_antenna_enable(self.dev_p, value)
+        if result != 0:
+            error_name = get_error_name(result)
+            raise IOError(f"Error setting antenna bias tee: {error_name} (Code {result})")
+        state = "enabled" if enable else "disabled"
+        print(f"HackRF antenna bias tee {state}.")
+        return 0
+
     # rounds down to multiple of 8 (15 -> 8, 39 -> 32), etc.
     # internally, hackrf_set_lna_gain does the same thing
     # But we take care of it so we can keep track of the correct gain
@@ -582,6 +592,75 @@ class HackRF(object):
         if result != 0:
             raise IOError("stop_rx failure")
 
+    def _rx_capture_callback(self, hackrf_transfer):
+        """Instance method callback for RX capture - prevents garbage collection"""
+        try:
+            c = hackrf_transfer.contents
+
+            # Append bytes to buffer using string_at
+            from ctypes import string_at
+            byte_chunk = string_at(c.buffer, c.valid_length)
+            self._capture_buffer.extend(byte_chunk)
+
+            # Check if we have enough
+            if len(self._capture_buffer) >= self._capture_target:
+                self._capture_done = True
+                return 1  # Stop streaming
+            return 0
+        except Exception as e:
+            print(f"Error in RX capture callback: {e}")
+            import traceback
+            traceback.print_exc()
+            self._capture_done = True
+            return 1
+
+    def read_samples(self, num_samples):
+        """
+        Block capture mode for HackRF - captures exactly num_samples.
+        This is safer than streaming for USB2 and avoids buffer overflow issues.
+
+        :param num_samples: Number of complex samples to capture
+        :return: numpy array of complex64 samples
+        """
+        # Initialize capture state as instance variables
+        self._capture_buffer = bytearray()
+        self._capture_target = num_samples * 2  # 2 bytes per complex sample (I+Q as int8)
+        self._capture_done = False
+
+        # Store callback as instance variable to prevent garbage collection (like TX does)
+        self._rx_cb = _callback(self._rx_capture_callback)
+
+        # Start RX with the callback
+        result = libhackrf.hackrf_start_rx(self.dev_p, self._rx_cb, None)
+        if result != 0:
+            raise IOError("start_rx failure during read_samples")
+
+        # Wait for capture to complete
+        import time
+        timeout = num_samples / self.sample_rate + 5.0  # Add 5 second buffer
+        start_time = time.time()
+
+        while not self._capture_done:
+            if time.time() - start_time > timeout:
+                print("HackRF capture timeout!")
+                break
+            time.sleep(0.01)
+
+        # Stop RX
+        self.stop_rx()
+
+        # Convert bytes to complex samples
+        byte_data = bytes(self._capture_buffer[:self._capture_target])
+        all_samples = np.frombuffer(byte_data, dtype=np.int8).astype(np.float32).view(np.complex64)
+
+        # Clean up instance variables
+        del self._capture_buffer
+        del self._capture_target
+        del self._capture_done
+        del self._rx_cb
+
+        return all_samples[:num_samples]
+
     # Add transmit gain property
     def set_txvga_gain(self, gain):
         if gain < 0 or gain > 47:
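The byte-to-sample conversion at the end of read_samples() can be checked in isolation: the HackRF delivers interleaved signed 8-bit I/Q, and the int8 to float32 cast followed by a complex64 view pairs consecutive values into complex samples.

```python
import numpy as np

# Two interleaved I/Q samples as raw bytes: (I=1, Q=2) and (I=-1, Q=-2) as int8,
# where 255 and 254 are the two's-complement encodings of -1 and -2.
raw = bytes([1, 2, 255, 254])

# Same pipeline as read_samples(): int8 -> float32 -> view pairs as complex64.
samples = np.frombuffer(raw, dtype=np.int8).astype(np.float32).view(np.complex64)
```

This also makes the `num_samples * 2` capture target in read_samples() concrete: each complex sample costs exactly two bytes on the wire, one for I and one for Q.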
@@ -1,3 +1,5 @@
+import time
+import warnings
 from typing import Optional
 
 import numpy as np
@@ -67,9 +69,6 @@ class Blade(SDR):
         print("FPGA version:\t\t" + str(device.get_fpga_version()))
         return 0
 
-    def close(self):
-        self.device.close()
-
     def init_rx(
         self,
         sample_rate: int | float,
@@ -92,6 +91,9 @@ class Blade(SDR):
         :type channel: int
         :param buffer_size: The buffer size during receive. Defaults to 8192.
         :type buffer_size: int
+        :param gain_mode: 'absolute' passes gain directly to the sdr,
+            'relative' means that gain should be a negative value, and it will be subtracted from the max gain (60).
+        :type gain_mode: str
         """
         print("Initializing RX")
 
@ -112,6 +114,93 @@ class Blade(SDR):
|
|||
self._rx_initialized = True
|
||||
self._tx_initialized = False
|
||||
|
||||
def _stream_rx(self, callback):
|
||||
if not self._rx_initialized:
|
||||
raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")
|
||||
|
||||
# Setup synchronous stream
|
||||
self.device.sync_config(
|
||||
layout=_bladerf.ChannelLayout.RX_X1,
|
||||
fmt=_bladerf.Format.SC16_Q11,
|
||||
num_buffers=16,
|
||||
buffer_size=self.rx_buffer_size,
|
||||
num_transfers=8,
|
||||
stream_timeout=3500000000,
|
||||
)
|
||||
|
||||
self.rx_ch.enable = True
|
||||
self.bytes_per_sample = 4
|
||||
|
||||
print("Blade Starting RX...")
|
||||
self._enable_rx = True
|
||||
|
||||
while self._enable_rx:
|
||||
# Create receive buffer and read in samples to buffer
|
||||
# Add them to a list to convert and save after stream is finished
|
||||
buffer = bytearray(self.rx_buffer_size * self.bytes_per_sample)
|
||||
self.device.sync_rx(buffer, self.rx_buffer_size)
|
||||
signal = self._convert_rx_samples(buffer)
|
||||
self.buffer = buffer
|
||||
# send callback complex signal
|
||||
callback(buffer=signal, metadata=None)
|
||||
|
||||
# Disable module
|
||||
print("Blade RX Completed.")
|
||||
self.rx_ch.enable = False
|
||||
|
||||
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None):
|
||||
if not self._rx_initialized:
|
||||
raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")
|
||||
|
||||
if num_samples is not None and rx_time is not None:
|
||||
raise ValueError("Only input one of num_samples or rx_time")
|
||||
elif num_samples is not None:
|
||||
self._num_samples_to_record = num_samples
|
||||
elif rx_time is not None:
|
||||
self._num_samples_to_record = int(rx_time * self.rx_sample_rate)
|
||||
else:
|
||||
raise ValueError("Must provide input of one of num_samples or rx_time")
|
||||
|
||||
# Setup synchronous stream
|
||||
self.device.sync_config(
|
||||
layout=_bladerf.ChannelLayout.RX_X1,
|
||||
fmt=_bladerf.Format.SC16_Q11,
|
||||
num_buffers=16,
|
||||
buffer_size=self.rx_buffer_size,
|
||||
num_transfers=8,
|
||||
stream_timeout=3500000000,
|
||||
)
|
||||
|
||||
self.rx_ch.enable = True
|
||||
self.bytes_per_sample = 4
|
||||
|
||||
print("Blade Starting RX...")
|
||||
self._enable_rx = True
|
||||
|
||||
store_array = np.zeros(
|
||||
(1, (self._num_samples_to_record // self.rx_buffer_size + 1) * self.rx_buffer_size), dtype=np.complex64
|
||||
)
|
||||
|
||||
for i in range(self._num_samples_to_record // self.rx_buffer_size + 1):
|
||||
# Create receive buffer and read in samples to buffer
|
||||
# Add them to a list to convert and save after stream is finished
|
||||
buffer = bytearray(self.rx_buffer_size * self.bytes_per_sample)
|
||||
self.device.sync_rx(buffer, self.rx_buffer_size)
|
||||
signal = self._convert_rx_samples(buffer)
|
||||
store_array[:, i * self.rx_buffer_size : (i + 1) * self.rx_buffer_size] = signal
|
||||
|
||||
# Disable module
|
||||
print("Blade RX Completed.")
|
||||
self.rx_ch.enable = False
|
||||
metadata = {
|
||||
"source": self.__class__.__name__,
|
||||
"sample_rate": self.rx_sample_rate,
|
||||
"center_frequency": self.rx_center_frequency,
|
||||
"gain": self.rx_gain,
|
||||
}
|
||||
|
||||
return Recording(data=store_array[:, : self._num_samples_to_record], metadata=metadata)
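The buffer sizing in `record()` above can be exercised without hardware: it reads `num_samples // buffer_size + 1` full buffers and trims the result at the end. In this sketch, `chunked_record` and `fake_read` are hypothetical stand-ins for the `sync_rx` / `_convert_rx_samples` pair, not part of the toolkit:

```python
import numpy as np

def chunked_record(num_samples: int, buffer_size: int, read_buffer) -> np.ndarray:
    """Read full buffers until at least num_samples are captured, then trim."""
    n_buffers = num_samples // buffer_size + 1  # same sizing rule as record()
    store = np.zeros((1, n_buffers * buffer_size), dtype=np.complex64)
    for i in range(n_buffers):
        store[:, i * buffer_size : (i + 1) * buffer_size] = read_buffer(buffer_size)
    return store[:, :num_samples]

# Stand-in for the device read: returns a constant buffer of the requested size.
fake_read = lambda n: np.ones(n, dtype=np.complex64)
out = chunked_record(num_samples=10_000, buffer_size=8192, read_buffer=fake_read)
print(out.shape)  # (1, 10000)
```

One quirk this makes visible: when `num_samples` is an exact multiple of `buffer_size`, the `+ 1` still reads one extra full buffer, which the final trim then discards.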
    def init_tx(
        self,
        sample_rate: int | float,

@@ -134,6 +223,9 @@ class Blade(SDR):
        :type channel: int
        :param buffer_size: The buffer size during transmission. Defaults to 8192.
        :type buffer_size: int
        :param gain_mode: 'absolute' passes gain directly to the sdr,
            'relative' means that gain should be a negative value, and it will be subtracted from the max gain (60).
        :type gain_mode: str
        """

        # Configure BladeRF
@@ -162,84 +254,6 @@ class Blade(SDR):
        self._rx_initialized = False
        return 0

    def _stream_rx(self, callback):
        if not self._rx_initialized:
            raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")

        # Setup synchronous stream
        self.device.sync_config(
            layout=_bladerf.ChannelLayout.RX_X1,
            fmt=_bladerf.Format.SC16_Q11,
            num_buffers=16,
            buffer_size=self.rx_buffer_size,
            num_transfers=8,
            stream_timeout=3500000000,
        )

        self.rx_ch.enable = True
        self.bytes_per_sample = 4

        print("Blade Starting RX...")
        self._enable_rx = True

        while self._enable_rx:
            # Create receive buffer and read in samples to buffer
            # Add them to a list to convert and save after stream is finished
            buffer = bytearray(self.rx_buffer_size * self.bytes_per_sample)
            self.device.sync_rx(buffer, self.rx_buffer_size)
            signal = self._convert_rx_samples(buffer)
            # samples = convert_to_2xn(signal)
            self.buffer = buffer
            # send callback complex signal
            callback(buffer=signal, metadata=None)

        # Disable module
        print("Blade RX Completed.")
        self.rx_ch.enable = False

    def record(self, num_samples):
        if not self._rx_initialized:
            raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")

        # Setup synchronous stream
        self.device.sync_config(
            layout=_bladerf.ChannelLayout.RX_X1,
            fmt=_bladerf.Format.SC16_Q11,
            num_buffers=16,
            buffer_size=self.rx_buffer_size,
            num_transfers=8,
            stream_timeout=3500000000,
        )

        self.rx_ch.enable = True
        self.bytes_per_sample = 4

        print("Blade Starting RX...")
        self._enable_rx = True

        store_array = np.zeros((1, (num_samples // self.rx_buffer_size + 1) * self.rx_buffer_size), dtype=np.complex64)

        for i in range(num_samples // self.rx_buffer_size + 1):
            # Create receive buffer and read in samples to buffer
            # Add them to a list to convert and save after stream is finished
            buffer = bytearray(self.rx_buffer_size * self.bytes_per_sample)
            self.device.sync_rx(buffer, self.rx_buffer_size)
            signal = self._convert_rx_samples(buffer)
            # samples = convert_to_2xn(signal)
            store_array[:, i * self.rx_buffer_size : (i + 1) * self.rx_buffer_size] = signal

        # Disable module
        print("Blade RX Completed.")
        self.rx_ch.enable = False
        metadata = {
            "source": self.__class__.__name__,
            "sample_rate": self.rx_sample_rate,
            "center_frequency": self.rx_center_frequency,
            "gain": self.rx_gain,
        }

        return Recording(data=store_array[:, :num_samples], metadata=metadata)

    def _stream_tx(self, callback):

        # Setup stream

@@ -267,6 +281,88 @@ class Blade(SDR):
        print("Blade TX Completed.")
        self.tx_ch.enable = False

    def tx_recording(
        self,
        recording: Recording | np.ndarray,
        num_samples: Optional[int] = None,
        tx_time: Optional[int | float] = None,
    ):
        """
        Transmit the given IQ samples from the provided recording.
        init_tx() must be called before this function.

        :param recording: The recording to transmit.
        :type recording: Recording or np.ndarray
        :param num_samples: The number of samples to transmit, will repeat or
            truncate the recording to this length. Defaults to None.
        :type num_samples: int, optional
        :param tx_time: The time to transmit, will repeat or truncate the
            recording to this length. Defaults to None.
        :type tx_time: int or float, optional
        """

        if num_samples is not None and tx_time is not None:
            raise ValueError("Only input one of num_samples or tx_time")
        elif num_samples is not None:
            tx_time = num_samples / self.tx_sample_rate
        elif tx_time is not None:
            pass
        else:
            tx_time = len(recording) / self.tx_sample_rate

        if isinstance(recording, np.ndarray):
            samples = recording
        elif isinstance(recording, Recording):
            if len(recording.data) > 1:
                warnings.warn("Recording object is multichannel, only channel 0 data was used for transmission")
            samples = recording.data[0]
        else:
            raise TypeError("recording must be np.ndarray or Recording")

        samples = samples.astype(np.complex64, copy=False)

        # Setup stream
        self.device.sync_config(
            layout=_bladerf.ChannelLayout.TX_X1,
            fmt=_bladerf.Format.SC16_Q11,
            num_buffers=16,
            buffer_size=self.tx_buffer_size,
            num_transfers=8,
            stream_timeout=3500,
        )

        # Enable module
        self.tx_ch.enable = True

        print("Blade Starting TX...")

        # Transmit samples - repeat as needed for the duration
        start_time = time.time()
        sample_index = 0

        try:
            while time.time() - start_time < tx_time:
                # Get next chunk
                chunk_size = min(self.tx_buffer_size, len(samples) - sample_index)
                if chunk_size == 0:
                    # Reached end, loop back
                    sample_index = 0
                    chunk_size = min(self.tx_buffer_size, len(samples))

                chunk = samples[sample_index : sample_index + chunk_size]
                sample_index += chunk_size

                # Convert and transmit
                byte_array = self._convert_tx_samples(chunk)
                self.device.sync_tx(byte_array, len(chunk))

        except KeyboardInterrupt:
            print("\nTransmission interrupted by user")

        # Disable module
        print("Blade TX Completed.")
        self.tx_ch.enable = False
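The repeat-or-truncate behaviour of `tx_recording()` above boils down to a chunking loop over the sample array. A hardware-free sketch (`tx_chunks` is a hypothetical helper; it counts samples rather than wall-clock time, which is what the real loop checks):

```python
import numpy as np

def tx_chunks(samples: np.ndarray, buffer_size: int, total_samples: int):
    """Yield chunks of at most buffer_size, looping over samples until at
    least total_samples have been produced."""
    sent = 0
    index = 0
    while sent < total_samples:
        chunk_size = min(buffer_size, len(samples) - index)
        if chunk_size == 0:  # reached the end, loop back
            index = 0
            chunk_size = min(buffer_size, len(samples))
        chunk = samples[index : index + chunk_size]
        index += chunk_size
        sent += len(chunk)
        yield chunk

data = np.arange(10, dtype=np.complex64)
chunks = list(tx_chunks(data, buffer_size=4, total_samples=12))
print([len(c) for c in chunks])  # [4, 4, 2, 4]
```

Like the wall-clock version, this can overshoot the target slightly: it stops at a chunk boundary, not exactly at `total_samples`.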
    def _convert_rx_samples(self, samples):
        samples = np.frombuffer(samples, dtype=np.int16).astype(np.float32)
        samples /= 2048

@@ -274,14 +370,18 @@ class Blade(SDR):
        return samples

    def _convert_tx_samples(self, samples):
        tx_samples = np.empty(samples.size * 2, dtype=np.float32)
        tx_samples[::2] = np.real(samples)  # Real part
        tx_samples[1::2] = np.imag(samples)  # Imaginary part
        # Normalize to maximum amplitude to prevent overflow
        max_val = np.max(np.abs(samples))
        if max_val > 0:
            samples = samples / max_val  # Normalize to [-1, 1]

        # Scale to Q11 format (use 2047 instead of 2048 to avoid overflow)
        # and interleave I/Q samples
        tx_samples = np.zeros(len(samples) * 2, dtype=np.int16)
        tx_samples[0::2] = (np.real(samples) * 2047).astype(np.int16)  # I samples
        tx_samples[1::2] = (np.imag(samples) * 2047).astype(np.int16)  # Q samples

        tx_samples *= 2048
        tx_samples = tx_samples.astype(np.int16)
        byte_array = tx_samples.tobytes()

        return byte_array
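The two converters above are near-inverses: the TX path normalizes to [-1, 1], scales by 2047, and interleaves I/Q into int16, while the RX path divides by 2048. A standalone round trip (the helper names here are illustrative, not the toolkit's API):

```python
import numpy as np

def iq_to_sc16_q11(samples: np.ndarray) -> bytes:
    """Interleave I/Q and scale to Q11 int16, as in _convert_tx_samples."""
    peak = np.max(np.abs(samples))
    if peak > 0:
        samples = samples / peak  # normalize to [-1, 1]
    out = np.zeros(len(samples) * 2, dtype=np.int16)
    out[0::2] = (np.real(samples) * 2047).astype(np.int16)  # I
    out[1::2] = (np.imag(samples) * 2047).astype(np.int16)  # Q
    return out.tobytes()

def sc16_q11_to_iq(raw: bytes) -> np.ndarray:
    """Inverse path, as in _convert_rx_samples: int16 -> floats in [-1, 1)."""
    flat = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 2048
    return flat[0::2] + 1j * flat[1::2]

tx = np.array([1 + 1j, -1 - 1j, 0.5 + 0j], dtype=np.complex64)
rx = sc16_q11_to_iq(iq_to_sc16_q11(tx))
print(np.round(rx, 2))  # approximately tx / sqrt(2)
```

Up to Q11 quantization error (about 5e-4 per component), the received samples match the normalized transmit samples.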
    def _set_rx_channel(self, channel):

@@ -381,3 +481,22 @@ class Blade(SDR):

        print(f"Clock source set to {self.device.get_clock_select()}")
        print(f"PLL Reference set to {self.device.get_pll_refclk()}")

    def supports_bias_tee(self) -> bool:
        return True

    def set_bias_tee(self, enable: bool, channel: Optional[int] = None):
        if channel is None:
            channel = getattr(self, "rx_channel", getattr(self, "tx_channel", 0))

        try:
            bladerf_channel = _bladerf.CHANNEL_RX(channel)
            self.device.set_bias_tee(bladerf_channel, bool(enable))
        except AttributeError as exc:  # pragma: no cover - depends on libbladeRF version
            raise NotImplementedError("bladeRF binding lacks bias-tee control") from exc

        state = "enabled" if enable else "disabled"
        print(f"BladeRF bias tee {state} on channel {channel}.")

    def close(self):
        self.device.close()
@@ -35,10 +35,120 @@ class HackRF(SDR):

        super().__init__()

    def init_rx(self, sample_rate, center_frequency, gain, channel, gain_mode):
    def init_rx(
        self,
        sample_rate: int | float,
        center_frequency: int | float,
        gain: int,
        channel: int,
        gain_mode: Optional[str] = "absolute",
    ):
        """
        Initializes the HackRF for receiving.

        HackRF has 3 gain stages:
        - 14 dB front-end amplifier (on/off)
        - LNA gain: 0-40 dB in 8 dB steps
        - VGA gain: 0-62 dB in 2 dB steps

        :param sample_rate: The sample rate for receiving.
        :type sample_rate: int or float
        :param center_frequency: The center frequency of the recording.
        :type center_frequency: int or float
        :param gain: The LNA gain set for receiving on the HackRF
        :type gain: int
        :param channel: The channel the HackRF is set to. (Not actually used)
        :type channel: int
        :param gain_mode: 'absolute' passes gain directly to the sdr,
            'relative' means that gain should be a negative value, and it will be subtracted from the max gain (40).
        :type gain_mode: str
        """
        print("Initializing RX")

        self.rx_sample_rate = sample_rate
        self.radio.sample_rate = int(sample_rate)
        print(f"HackRF sample rate = {self.radio.sample_rate}")

        self.rx_center_frequency = center_frequency
        self.radio.center_freq = int(center_frequency)
        print(f"HackRF center frequency = {self.radio.center_freq}")

        # Distribute gain across amplifier stages
        rx_gain_min = 0
        rx_gain_max = 40  # (LNA)

        if gain_mode == "relative":
            if gain > 0:
                raise ValueError(
                    "When gain_mode = 'relative', gain must be < 0. This "
                    "sets the gain relative to the maximum possible gain."
                )
            else:
                abs_gain = rx_gain_max + gain
        else:
            abs_gain = gain

        if abs_gain < rx_gain_min or abs_gain > rx_gain_max:
            abs_gain = min(max(abs_gain, rx_gain_min), rx_gain_max)
            print(f"Gain {gain} out of range for HackRF.")
            print(f"Gain range: {rx_gain_min} to {rx_gain_max} dB")

        self.set_gain_amp(False)
        self.set_rx_vga_gain(45)
        self.set_rx_lna_gain(abs_gain)
        self.rx_gain = abs_gain

        print(f"HackRF gain distribution: Amp={self.amp_enabled}, LNA={self.rx_lna_gain}dB, VGA={self.rx_vga_gain}dB")
        print("To individually modify the HackRF gains, use set_gain_amp(), set_rx_lna_gain(), and set_rx_vga_gain().")

        self._tx_initialized = False
        self._rx_initialized = True
        return NotImplementedError("RX not yet implemented for HackRF")
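The gain handling in `init_rx()` above, mapping a 'relative' (negative) gain onto the 0 to 40 dB LNA range and clamping out-of-range values, can be captured in a small pure function. `resolve_gain` is a hypothetical helper that mirrors the logic, not part of the class:

```python
def resolve_gain(gain: int, gain_mode: str = "absolute",
                 gain_min: int = 0, gain_max: int = 40) -> int:
    """Map a requested gain to an absolute gain, as init_rx() does:
    'relative' gains are negative offsets from gain_max, and the result
    is clamped into [gain_min, gain_max]."""
    if gain_mode == "relative":
        if gain > 0:
            raise ValueError("When gain_mode = 'relative', gain must be < 0.")
        abs_gain = gain_max + gain
    else:
        abs_gain = gain
    return min(max(abs_gain, gain_min), gain_max)

print(resolve_gain(-8, "relative"))   # 32
print(resolve_gain(55, "absolute"))   # 40 (clamped)
```

The same pattern recurs in the Blade and Pluto drivers, with gain ranges of 0 to 60 dB and 0 to 74 dB respectively.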
    def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None):
        """
        Create a radio recording (iq samples and metadata) of a given length from the SDR.
        HackRF uses block capture mode, which is more reliable than streaming for USB2 connections.
        Either num_samples or rx_time must be provided.
        init_rx() must be called before record()

        :param num_samples: The number of samples to record.
        :type num_samples: int, optional
        :param rx_time: The time to record.
        :type rx_time: int or float, optional

        returns: Recording object (iq samples and metadata)
        """
        if not self._rx_initialized:
            raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")

        if num_samples is not None and rx_time is not None:
            raise ValueError("Only input one of num_samples or rx_time")
        elif num_samples is not None:
            self._num_samples_to_record = num_samples
        elif rx_time is not None:
            self._num_samples_to_record = int(rx_time * self.rx_sample_rate)
        else:
            raise ValueError("Must provide input of one of num_samples or rx_time")

        print("HackRF Starting RX...")

        # Use libhackrf's block capture method
        all_samples = self.radio.read_samples(self._num_samples_to_record)

        print("HackRF RX Completed.")

        # Create 1xN array for single-channel recording
        store_array = np.zeros((1, self._num_samples_to_record), dtype=np.complex64)
        store_array[0, :] = all_samples

        metadata = {
            "source": self.__class__.__name__,
            "sample_rate": self.rx_sample_rate,
            "center_frequency": self.rx_center_frequency,
            "gain": self.rx_gain,
        }

        return Recording(data=store_array, metadata=metadata)

    def init_tx(
        self,
@@ -72,8 +182,6 @@ class HackRF(SDR):
        self.radio.center_freq = int(center_frequency)
        print(f"HackRF center frequency = {self.radio.center_freq}")

        self.radio.enable_amp()

        tx_gain_min = 0
        tx_gain_max = 47
        if gain_mode == "relative":

@@ -92,8 +200,11 @@ class HackRF(SDR):
            print(f"Gain {gain} out of range for Pluto.")
            print(f"Gain range: {tx_gain_min} to {tx_gain_max} dB")

        self.radio.txvga_gain = abs_gain
        print(f"HackRF gain = {self.radio.txvga_gain}")
        self.set_gain_amp(True)
        self.set_tx_vga_gain(abs_gain)
        self.tx_gain = abs_gain
        print(f"HackRF gain distribution: Amp={self.amp_enabled}, VGA={self.tx_vga_gain}dB")
        print("To individually modify the HackRF gains, use set_gain_amp() or set_tx_vga_gain().")

        self._tx_initialized = True
        self._rx_initialized = False
@@ -144,17 +255,90 @@ class HackRF(SDR):
        self.radio.stop_tx()
        print("HackRF Tx Completed.")

    def set_clock_source(self, source):
    def set_gain_amp(self, enable):
        if enable:
            self.radio.enable_amp()
            self.amp_enabled = True
        else:
            self.radio.disable_amp()
            self.amp_enabled = False

    def set_rx_lna_gain(self, lna_gain):
        self.radio.set_lna_gain(lna_gain)
        self.rx_lna_gain = lna_gain

    def set_rx_vga_gain(self, vga_gain):
        self.radio.set_vga_gain(vga_gain)
        self.rx_vga_gain = vga_gain

    def set_tx_vga_gain(self, vga_gain):
        self.radio.set_txvga_gain(vga_gain)
        self.tx_vga_gain = vga_gain

    def set_clock_source(self, source):
        self.radio.set_clock_source(source)

    def supports_bias_tee(self) -> bool:
        return True

    def set_bias_tee(self, enable: bool):
        try:
            self.radio.set_antenna_enable(bool(enable))
        except AttributeError as exc:  # pragma: no cover - defensive
            raise NotImplementedError("Underlying HackRF interface lacks bias-tee control") from exc

    def close(self):
        self.radio.close()

    def _stream_rx(self, callback):
        """
        Stream samples from the HackRF using a callback function.

        :param callback: Function to call for each buffer of samples
        :type callback: callable
        """
        if not self._rx_initialized:
            raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")
        return NotImplementedError("RX not yet implemented for HackRF")
        raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx()")

        print("HackRF Starting RX stream...")

        self._enable_rx = True

        def rx_callback(hackrf_transfer):
            """Internal callback that wraps the user's callback"""
            try:
                if not self._enable_rx:
                    return 1  # Stop

                c = hackrf_transfer.contents

                # Use ctypes string_at to safely copy the buffer
                from ctypes import string_at

                byte_data = string_at(c.buffer, c.valid_length)

                # Convert bytes to int8, then to float32, then view as complex64
                samples = np.frombuffer(byte_data, dtype=np.int8).astype(np.float32).view(np.complex64)

                # Call user's callback
                callback(buffer=samples, metadata=None)

                return 0 if self._enable_rx else 1
            except Exception as e:
                print(f"Error in rx_callback: {e}")
                return 1  # Stop on error

        # Start RX
        self.radio.start_rx(rx_callback)

        # Wait while streaming
        while self._enable_rx:
            time.sleep(0.1)

        # Stop RX
        self.radio.stop_rx()

        print("HackRF RX stream completed.")

    def _stream_tx(self, callback):
        return super()._stream_tx(callback)
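The buffer conversion inside `rx_callback` above, interleaved signed 8-bit I/Q reinterpreted as complex64, can be exercised on its own (`int8_iq_to_complex64` is an illustrative name, not the toolkit's API):

```python
import numpy as np

def int8_iq_to_complex64(byte_data: bytes) -> np.ndarray:
    """Interleaved signed 8-bit I/Q -> complex64, as in rx_callback:
    widen each int8 to float32, then view adjacent (I, Q) float pairs
    as single complex64 values."""
    return np.frombuffer(byte_data, dtype=np.int8).astype(np.float32).view(np.complex64)

# Two IQ pairs: (1, -2) and (3, 4)
raw = np.array([1, -2, 3, 4], dtype=np.int8).tobytes()
print(int8_iq_to_complex64(raw))  # two complex samples: 1-2j and 3+4j
```

The values are left unscaled, matching the diff; dividing by 128 would map them into [-1, 1).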
@@ -17,7 +17,7 @@ class Pluto(SDR):
        """
        Initialize a Pluto SDR device object and connect to the SDR hardware.

        This software supports the ADALAM Pluto SDR created by Analog Devices.
        This software supports the ADALM Pluto SDR created by Analog Devices.

        :param identifier: The value of the parameter that identifies the device.
        :type identifier: str = "192.168.3.1", "pluto.local", etc
@@ -34,8 +34,25 @@ class Pluto(SDR):
            else:
                uri = f"ip:{identifier}"

            self.radio = adi.ad9361(uri)
            print(f"Successfully found Pluto radio with identifier [{identifier}].")
            # Detect MIMO capability by checking IIO channels (one-time, during init)
            # Rev B: 2 channels (voltage0, voltage1) - single RX/TX only
            # Rev C/D: 4 channels (voltage0-3) - dual RX/TX capable
            test_radio = adi.ad9361(uri)
            ctx = test_radio.ctx
            dev = ctx.find_device("cf-ad9361-lpc")

            if dev and len(dev.channels) >= 4:
                # MIMO-capable hardware (Rev C/D)
                self.radio = test_radio
                self._mimo_capable = True
                print(f"Successfully found MIMO-capable Pluto (Rev C/D) with identifier [{identifier}].")
            else:
                # Non-MIMO hardware (Rev B) - use standard Pluto driver
                del test_radio
                self.radio = adi.Pluto(uri)
                self._mimo_capable = False
                print(f"Successfully found Pluto (Rev B) with identifier [{identifier}].")

        except Exception as e:
            print(f"Failed to find Pluto radio with identifier [{identifier}].")
            raise e
@@ -59,8 +76,9 @@ class Pluto(SDR):
        :type gain: int
        :param channel: The channel the Pluto is set to. Must be 0 or 1. 0 enables channel 1, 1 enables both channels.
        :type channel: int
        :param buffer_size: The buffer size during receive. Defaults to 10000.
        :type buffer_size: int
        :param gain_mode: 'absolute' passes gain directly to the sdr,
            'relative' means that gain should be a negative value, and it will be subtracted from the max gain (74).
        :type gain_mode: str
        """
        print("Initializing RX")

@@ -74,41 +92,30 @@ class Pluto(SDR):
            self.radio.rx_enabled_channels = [0]
            print(f"Pluto channel(s) = {self.radio.rx_enabled_channels}")
        elif channel == 1:
            if not self._mimo_capable:
                raise ValueError(
                    "Dual RX channel requested (channel=1) but hardware is not MIMO-capable. "
                    "Dual RX/TX requires Pluto Rev C/D. Detected hardware: Rev B (single channel only)."
                )
            self.radio.rx_enabled_channels = [0, 1]
            print(f"Pluto channel(s) = {self.radio.rx_enabled_channels}")
        else:
            raise ValueError("Channel must be either 0 or 1.")

        rx_gain_min = 0
        rx_gain_max = 74

        if gain_mode == "relative":
            if gain > 0:
                raise ValueError(
                    "When gain_mode = 'relative', gain must be < 0. This sets \
                    the gain relative to the maximum possible gain."
                )
            else:
                abs_gain = rx_gain_max + gain
        else:
            abs_gain = gain

        if abs_gain < rx_gain_min or abs_gain > rx_gain_max:
            abs_gain = min(max(gain, rx_gain_min), rx_gain_max)
            print(f"Gain {gain} out of range for Pluto.")
            print(f"Gain range: {rx_gain_min} to {rx_gain_max} dB")

        self.set_rx_gain(gain=abs_gain, channel=channel)
        self.set_rx_gain(gain=gain, channel=channel, gain_mode=gain_mode)
        if channel == 0:
            print(f"Pluto gain = {self.radio.rx_hardwaregain_chan0}")
        elif channel == 1:
            self.set_rx_gain(gain=abs_gain, channel=0)
            self.set_rx_gain(gain=gain, channel=0, gain_mode=gain_mode)
            print(f"Pluto gain = {self.radio.rx_hardwaregain_chan0}, {self.radio.rx_hardwaregain_chan1}")

        self.radio.rx_buffer_size = 1024  # TODO deal with this for zmq
        self.set_rx_buffer_size(getattr(self, "rx_buffer_size", 1024))

        self._rx_initialized = True
        self._tx_initialized = False

        return {"sample_rate": self.rx_sample_rate, "center_frequency": self.rx_center_frequency, "gain": self.rx_gain}

    def init_tx(
        self,
        sample_rate: int | float,

@@ -129,8 +136,9 @@ class Pluto(SDR):
        :type gain: int
        :param channel: The channel the Pluto is set to. Must be 0 or 1. 0 enables channel 1, 1 enables both channels.
        :type channel: int
        :param buffer_size: The buffer size during transmit. Defaults to 10000.
        :type buffer_size: int
        :param gain_mode: 'absolute' passes gain directly to the sdr,
            'relative' means that gain should be a negative value, and it will be subtracted from the max gain (0).
        :type gain_mode: str
        """

        print("Initializing TX")
@@ -141,44 +149,32 @@ class Pluto(SDR):
        self.set_tx_center_frequency(center_frequency=int(center_frequency))
        print(f"Pluto center frequency = {self.radio.tx_lo}")

        if channel == 1:
            self.radio.tx_enabled_channels = [0, 1]
            print(f"Pluto channel(s) = {self.radio.tx_enabled_channels}")
        elif channel == 0:
        if channel == 0:
            self.radio.tx_enabled_channels = [0]
            print(f"Pluto channel(s) = {self.radio.tx_enabled_channels}")
        elif channel == 1:
            if not self._mimo_capable:
                raise ValueError(
                    "Dual TX channel requested (channel=1) but hardware is not MIMO-capable. "
                    "Dual RX/TX requires Pluto Rev C/D. Detected hardware: Rev B (single channel only)."
                )
            self.radio.tx_enabled_channels = [0, 1]
            print(f"Pluto channel(s) = {self.radio.tx_enabled_channels}")
        else:
            raise ValueError("Channel must be either 0 or 1.")

        tx_gain_min = -89
        tx_gain_max = 0

        if gain_mode == "relative":
            if gain > 0:
                raise ValueError(
                    "When gain_mode = 'relative', gain must be < 0. This sets\
                    the gain relative to the maximum possible gain."
                )
            else:
                abs_gain = tx_gain_max + gain
        else:
            abs_gain = gain

        if abs_gain < tx_gain_min or abs_gain > tx_gain_max:
            abs_gain = min(max(gain, tx_gain_min), tx_gain_max)
            print(f"Gain {gain} out of range for Pluto.")
            print(f"Gain range: {tx_gain_min} to {tx_gain_max} dB")

        self.set_tx_gain(gain=abs_gain, channel=channel)
        self.set_tx_gain(gain=gain, channel=channel, gain_mode=gain_mode)
        if channel == 0:
            print(f"Pluto gain = {self.radio.tx_hardwaregain_chan0}")
        elif channel == 1:
            self.set_tx_gain(gain=abs_gain, channel=0)
            self.set_tx_gain(gain=gain, channel=0, gain_mode=gain_mode)
            print(f"Pluto gain = {self.radio.tx_hardwaregain_chan0}, {self.radio.tx_hardwaregain_chan1}")

        self._tx_initialized = True
        self._rx_initialized = False

        return {"sample_rate": self.tx_sample_rate, "center_frequency": self.tx_center_frequency, "gain": self.tx_gain}

    def _stream_rx(self, callback):
        if not self._rx_initialized:
            raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")

@@ -297,11 +293,6 @@ class Pluto(SDR):
        self.radio.tx_cyclic_buffer = False
        print("Pluto TX Completed.")

    def close(self):
        if self.radio.tx_cyclic_buffer:
            self.radio.tx_destroy_buffer()
        del self.radio

    def tx_recording(self, recording: Recording | np.ndarray | list, num_samples=None, tx_time=None, mode="timed"):
        """
        Transmit the given iq samples from the provided recording.
@@ -381,28 +372,47 @@ class Pluto(SDR):
        except ValueError as e:
            _handle_OSError(e)

    def set_rx_gain(self, gain, channel=0):
        self.rx_gain = gain
    def set_rx_gain(self, gain, channel=0, gain_mode="absolute"):
        rx_gain_min = 0
        rx_gain_max = 74

        if gain_mode == "relative":
            if gain > 0:
                raise ValueError(
                    "When gain_mode = 'relative', gain must be < 0. This sets \
                    the gain relative to the maximum possible gain."
                )
            else:
                abs_gain = rx_gain_max + gain
        else:
            abs_gain = gain

        if abs_gain < rx_gain_min or abs_gain > rx_gain_max:
            abs_gain = min(max(gain, rx_gain_min), rx_gain_max)
            print(f"Gain {gain} out of range for Pluto.")
            print(f"Gain range: {rx_gain_min} to {rx_gain_max} dB")

        self.rx_gain = abs_gain
        try:
            if channel == 0:

                if gain is None:
                if abs_gain is None:
                    self.radio.gain_control_mode_chan0 = "automatic"
                    print("Using Pluto Automatic Gain Control.")

                else:
                    self.radio.gain_control_mode_chan0 = "manual"
                    self.radio.rx_hardwaregain_chan0 = gain  # dB
                    self.radio.rx_hardwaregain_chan0 = abs_gain  # dB

            elif channel == 1:
                try:
                    if gain is None:
                    if abs_gain is None:
                        self.radio.gain_control_mode_chan1 = "automatic"
                        print("Using Pluto Automatic Gain Control.")

                    else:
                        self.radio.gain_control_mode_chan1 = "manual"
                        self.radio.rx_hardwaregain_chan1 = gain  # dB
                        self.radio.rx_hardwaregain_chan1 = abs_gain  # dB

                except Exception as e:
                    print("Failed to use channel 1 on the PlutoSDR. \nThis is only available for revC versions.")

@@ -417,10 +427,31 @@ class Pluto(SDR):
            _handle_OSError(e)

    def set_rx_channel(self, channel):
        self.rx_channel = channel
        if channel == 0:
            self.radio.rx_enabled_channels = [0]
            print(f"Pluto channel(s) = {self.radio.rx_enabled_channels}")
        elif channel == 1:
            self.radio.rx_enabled_channels = [0, 1]
            print(f"Pluto channel(s) = {self.radio.rx_enabled_channels}")
        else:
            raise ValueError("Channel must be either 0 or 1.")

    def set_rx_buffer_size(self, buffer_size):
        raise NotImplementedError
        if buffer_size is None:
            raise ValueError("Buffer_size must be provided.")
        buffer_size = int(buffer_size)
        if buffer_size <= 0:
            raise ValueError("Buffer_size must be a positive integer.")

        self.rx_buffer_size = buffer_size

        if hasattr(self, "radio"):
            try:
                self.radio.rx_buffer_size = buffer_size
            except OSError as e:
                _handle_OSError(e)
            except ValueError as e:
                _handle_OSError(e)

    def set_tx_center_frequency(self, center_frequency):
        try:
@ -442,14 +473,33 @@ class Pluto(SDR):
|
|||
except ValueError as e:
|
||||
_handle_OSError(e)
|
||||
|
||||
def set_tx_gain(self, gain, channel=0):
|
||||
def set_tx_gain(self, gain, channel=0, gain_mode="absolute"):
|
||||
tx_gain_min = -89
|
||||
tx_gain_max = 0
|
||||
|
||||
if gain_mode == "relative":
|
||||
if gain > 0:
|
||||
raise ValueError(
|
||||
"When gain_mode = 'relative', gain must be < 0. This sets\
|
||||
the gain relative to the maximum possible gain."
|
||||
)
|
||||
else:
|
||||
abs_gain = tx_gain_max + gain
|
||||
else:
|
||||
abs_gain = gain
|
||||
|
||||
if abs_gain < tx_gain_min or abs_gain > tx_gain_max:
|
||||
abs_gain = min(max(gain, tx_gain_min), tx_gain_max)
|
||||
print(f"Gain {gain} out of range for Pluto.")
|
||||
print(f"Gain range: {tx_gain_min} to {tx_gain_max} dB")
|
||||
|
||||
try:
|
||||
self.tx_gain = gain
|
||||
self.tx_gain = abs_gain
|
||||
|
||||
if channel == 0:
|
||||
self.radio.tx_hardwaregain_chan0 = int(gain)
|
||||
self.radio.tx_hardwaregain_chan0 = int(abs_gain)
|
||||
elif channel == 1:
|
||||
self.radio.tx_hardwaregain_chan1 = int(gain)
|
||||
self.radio.tx_hardwaregain_chan1 = int(abs_gain)
|
||||
else:
|
||||
raise ValueError(f"Pluto channel must be 0 or 1 but was {channel}.")
|
||||
|
||||
|
|
@@ -459,11 +509,23 @@ class Pluto(SDR):
            _handle_OSError(e)

    def set_tx_channel(self, channel):
-        raise NotImplementedError
+        if channel == 1:
+            self.radio.tx_enabled_channels = [0, 1]
+            print(f"Pluto channel(s) = {self.radio.tx_enabled_channels}")
+        elif channel == 0:
+            self.radio.tx_enabled_channels = [0]
+            print(f"Pluto channel(s) = {self.radio.tx_enabled_channels}")
+        else:
+            raise ValueError("Channel must be either 0 or 1.")

    def set_tx_buffer_size(self, buffer_size):
        raise NotImplementedError

    def close(self):
        if self.radio.tx_cyclic_buffer:
            self.radio.tx_destroy_buffer()
        del self.radio

    def shutdown(self):
        del self.radio
237	src/ria_toolkit_oss/sdr/rtlsdr.py	Normal file
@@ -0,0 +1,237 @@
"""RTL-SDR device integration for the RIA Toolkit."""

import time
import warnings
from typing import Optional

import numpy as np

try:
    from rtlsdr import RtlSdr
except ImportError as exc:  # pragma: no cover - dependency provided by end user
    raise ImportError("pyrtlsdr is required to use the RTLSDR class") from exc

from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.sdr.sdr import SDR


class RTLSDR(SDR):
    """SDR interface for RTL-SDR dongles using pyrtlsdr."""

    def __init__(self, identifier: Optional[str] = None):
        """
        Initialize an RTL-SDR device object and connect to the SDR hardware.

        This software supports RTL-SDR dongles driven by the pyrtlsdr library.

        :param identifier: The value of the parameter that identifies the device,
            e.g. a device index or serial number.
        :type identifier: str, optional

        If no identifier is provided, it will select the first device found, with a warning.
        If more than one device is found with the identifier, it will select the first of those devices.
        """
        print(f"Initializing RTL-SDR with identifier [{identifier}].")
        try:
            super().__init__()

            if identifier is None:
                self.radio = RtlSdr()
            else:
                self.radio = RtlSdr(identifier)

            self.rx_buffer_size = 256_000
            self.rx_channel = 0

            print(f"Initialized RTL-SDR with identifier [{identifier}].")

        except Exception as e:
            print(f"Failed to find RTL-SDR with identifier [{identifier}].")
            raise e

    def init_rx(
        self,
        sample_rate: int | float,
        center_frequency: int | float,
        gain: Optional[int],
        channel: int,
        gain_mode: Optional[str] = "absolute",
        buffer_size: Optional[int] = 256_000,
        bias_t: bool = False,
    ):
        if channel not in (0, None):
            raise ValueError("RTL-SDR supports only channel 0 for RX.")

        self.set_rx_sample_rate(sample_rate=sample_rate)
        self.set_rx_center_frequency(center_frequency=center_frequency)
        self.set_rx_gain(gain=gain, gain_mode=gain_mode)

        self.rx_buffer_size = int(buffer_size or self.rx_buffer_size)
        self.rx_channel = 0

        if bias_t:
            self.set_bias_tee(True)
            time.sleep(1)

        self._rx_initialized = True
        self._tx_initialized = False

        return {"sample_rate": self.rx_sample_rate, "center_frequency": self.rx_center_frequency, "gain": self.rx_gain}

    def set_rx_sample_rate(self, sample_rate):
        self.radio.sample_rate = float(sample_rate)
        self.rx_sample_rate = self.radio.sample_rate
        print(f"RTL RX Sample Rate = {self.radio.get_sample_rate()}")

    def set_rx_center_frequency(self, center_frequency):
        self.radio.center_freq = float(center_frequency)
        self.rx_center_frequency = self.radio.center_freq
        print(f"RTL RX Center Frequency = {self.radio.get_center_freq()}")

    def set_rx_gain(self, gain, gain_mode="absolute"):
        available_gains = self.radio.get_gains()

        if gain is None:
            self.radio.gain = "auto"
            self.rx_gain = "auto"
        else:
            if not available_gains:
                warnings.warn(
                    "No gain table reported by RTL-SDR; applying requested gain directly.",
                    RuntimeWarning,
                )
                target_gain = gain
            else:
                min_gain = min(available_gains)
                max_gain = max(available_gains)

                if gain_mode == "relative":
                    if gain > 0:
                        raise ValueError(
                            "When gain_mode = 'relative', gain must be < 0. This sets "
                            "the gain relative to the maximum possible gain."
                        )
                    target_gain = max_gain + gain
                else:
                    target_gain = gain

                if target_gain < min_gain or target_gain > max_gain:
                    print(
                        f"Requested gain {target_gain} dB out of range; "
                        f"clamping to valid span {min_gain}-{max_gain} dB."
                    )
                    target_gain = min(max(target_gain, min_gain), max_gain)

                # Snap to the nearest gain reported by the tuner's gain table.
                target_gain = min(available_gains, key=lambda g: abs(g - target_gain))

            self.radio.set_gain(target_gain)
            self.rx_gain = self.radio.get_gain()

        print(f"RTL RX Gain = {self.radio.get_gain()}")
        print(f"Available RTL RX Gains: {available_gains}")

    def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None):
        """
        Create a radio recording (IQ samples and metadata) of a given length from the RTL-SDR.
        Either num_samples or rx_time must be provided.
        init_rx() must be called before record().

        :param num_samples: The number of samples to record.
        :type num_samples: int, optional
        :param rx_time: The time to record, in seconds.
        :type rx_time: int or float, optional

        :returns: Recording object (IQ samples and metadata)
        """

        if not self._rx_initialized:
            raise RuntimeError("RX was not initialized. init_rx() must be called before record().")

        if num_samples is not None and rx_time is not None:
            raise ValueError("Only input one of num_samples or rx_time")
        elif num_samples is not None:
            pass
        elif rx_time is not None:
            num_samples = int(rx_time * self.rx_sample_rate)
        else:
            raise ValueError("Must provide input of one of num_samples or rx_time")

        # RTL-SDR has USB buffer limitations - use consistent 256k chunks
        # Always read full chunks to avoid USB overflow issues with partial reads
        max_samples_per_read = 262144  # 256k samples = stable chunk size
        num_full_reads = num_samples // max_samples_per_read
        remainder = num_samples % max_samples_per_read
        signal = np.array([], dtype=np.complex64)

        print("RTL-SDR Starting RX...")

        # Read full chunks
        for _ in range(num_full_reads):
            try:
                chunk = self.radio.read_samples(max_samples_per_read)
                signal = np.append(signal, chunk)
            except Exception as e:
                print(f"Error while reading samples: {e}")
                break

        # Read remainder if needed (round up to power of 2 for USB compatibility)
        if remainder > 0 and len(signal) == num_full_reads * max_samples_per_read:
            # Round up to next 16k boundary for USB stability
            padded_remainder = ((remainder + 16383) // 16384) * 16384
            try:
                chunk = self.radio.read_samples(padded_remainder)
                signal = np.append(signal, chunk[:remainder])  # Only keep what we need
            except Exception as e:
                print(f"Error while reading final chunk: {e}")

        print("RTL-SDR RX Completed.")

        metadata = {
            "source": self.__class__.__name__,
            "sample_rate": self.rx_sample_rate,
            "center_frequency": self.rx_center_frequency,
            "gain": self.rx_gain,
        }

        return Recording(data=signal, metadata=metadata)

    def _stream_rx(self, callback):
        if not self._rx_initialized:
            raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record().")

        print("RTL-SDR Starting RX...")
        self._enable_rx = True
        try:
            while self._enable_rx:
                samples = self.radio.read_samples(self.rx_buffer_size)
                callback(buffer=np.asarray(samples, dtype=np.complex64), metadata=None)
        finally:
            print("RTL-SDR RX Completed.")

    def _stream_tx(self, callback):  # pragma: no cover - RTL-SDR is RX only
        raise NotImplementedError("RTL-SDR does not support transmit operations")

    def init_tx(self, *args, **kwargs):  # pragma: no cover - RTL-SDR is RX only
        raise NotImplementedError("RTL-SDR does not support transmit operations")

    def tx_recording(
        self, recording: Recording | np.ndarray | list, num_samples=None, tx_time=None
    ):  # pragma: no cover - RTL-SDR is RX only
        raise NotImplementedError("RTL-SDR does not support transmit operations")

    def supports_bias_tee(self) -> bool:
        return True

    def set_bias_tee(self, enable: bool):
        self.radio.set_bias_tee(bool(enable))
        state = "enabled" if enable else "disabled"
        print(f"RTL-SDR bias tee {state}.")

    def set_clock_source(self, source):  # pragma: no cover - not applicable to RTL-SDR
        raise NotImplementedError("RTL-SDR does not support external clock configuration")

    def close(self):
        try:
            self.radio.close()
        finally:
            self._enable_rx = False
            self._enable_tx = False
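The chunked-read arithmetic in `record()` (full 256k reads, then a remainder read rounded up to a 16k boundary) can be sketched in isolation. This is an illustrative helper, not part of the driver; the name `plan_reads` is hypothetical.

```python
CHUNK = 262_144  # stable USB read size (256k samples)
ALIGN = 16_384   # remainder reads are rounded up to a 16k boundary


def plan_reads(num_samples):
    """Return (full_reads, remainder, padded_remainder) for a capture plan."""
    full_reads = num_samples // CHUNK
    remainder = num_samples % CHUNK
    # Round the final partial read up to the next 16k boundary; the extra
    # samples beyond `remainder` are read but discarded.
    padded = ((remainder + ALIGN - 1) // ALIGN) * ALIGN if remainder else 0
    return full_reads, remainder, padded


print(plan_reads(600_000))  # (2, 75712, 81920)
```

So a 600,000-sample capture issues two full 262,144-sample reads plus one padded 81,920-sample read, keeping only the first 75,712 samples of the last read.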
@@ -32,6 +32,12 @@ class SDR(ABC):
        self._num_buffers_processed = 0
        self._accumulated_buffer = None
        self._last_buffer = None
+        self.rx_sample_rate = None
+        self.rx_center_frequency = None
+        self.rx_gain = None
+        self.tx_sample_rate = None
+        self.tx_center_frequency = None
+        self.tx_gain = None

    def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
        """
@@ -295,6 +301,14 @@ class SDR(ABC):

        return samples

+    def supports_bias_tee(self) -> bool:
+        """Return True when the radio supports bias-tee control."""
+        return False
+
+    def set_bias_tee(self, enable: bool):
+        """Enable or disable bias-tee power when supported by the radio."""
+        raise NotImplementedError(f"{self.__class__.__name__} does not support bias-tee control")
+
    def pause_rx(self):
        self._enable_rx = False
@@ -303,6 +317,61 @@ class SDR(ABC):

    def stop(self):
        self.pause_rx()
        self.pause_tx()

+    def get_rx_sample_rate(self):
+        """
+        Retrieve the current sample rate of the receiver.
+
+        Returns:
+            float: The receiver's sample rate in samples per second (Hz).
+        """
+        return self.rx_sample_rate
+
+    def get_rx_center_frequency(self):
+        """
+        Retrieve the current center frequency of the receiver.
+
+        Returns:
+            float: The receiver's center frequency in Hertz (Hz).
+        """
+        return self.rx_center_frequency
+
+    def get_rx_gain(self):
+        """
+        Retrieve the current gain setting of the receiver.
+
+        Returns:
+            float: The receiver's gain in decibels (dB).
+        """
+        return self.rx_gain
+
+    def get_tx_sample_rate(self):
+        """
+        Retrieve the current sample rate of the transmitter.
+
+        Returns:
+            float: The transmitter's sample rate in samples per second (Hz).
+        """
+        return self.tx_sample_rate
+
+    def get_tx_center_frequency(self):
+        """
+        Retrieve the current center frequency of the transmitter.
+
+        Returns:
+            float: The transmitter's center frequency in Hertz (Hz).
+        """
+        return self.tx_center_frequency
+
+    def get_tx_gain(self):
+        """
+        Retrieve the current gain setting of the transmitter.
+
+        Returns:
+            float: The transmitter's gain in decibels (dB).
+        """
+        return self.tx_gain
+
    @abstractmethod
    def close(self):
448	src/ria_toolkit_oss/sdr/thinkrf.py	Normal file
@@ -0,0 +1,448 @@
"""ThinkRF integration for the RIA toolkit."""

from typing import Any, Dict, Optional

import numpy as np

try:
    from pyrf.devices.thinkrf import WSA
except ImportError as exc:  # pragma: no cover - optional dependency
    raise ImportError(
        "pyrf is required to use the ThinkRF integration. "
        "Install with: pip install ria-toolkit-oss[thinkrf]"
    ) from exc
except SyntaxError as exc:  # pragma: no cover - Python 2/3 compatibility issue
    import sys
    from pathlib import Path

    # pyrf ships with Python 2 syntax - try to auto-fix it
    print("\033[93mWARNING: pyrf has Python 2 syntax. Attempting automatic fix...\033[0m")
    try:
        from lib2to3.refactor import RefactoringTool, get_fixers_from_package

        import pyrf

        thinkrf_path = Path(pyrf.__file__).resolve().parent / "devices" / "thinkrf.py"
        print(f"Fixing: {thinkrf_path}")

        fixers = get_fixers_from_package("lib2to3.fixes")
        tool = RefactoringTool(fixers)
        tool.refactor_file(str(thinkrf_path), write=True)

        print("\033[92m✅ Fixed pyrf for Python 3. Please restart Python/reload the module.\033[0m")
        print("Or run: python -m ria_toolkit_oss.sdr.thinkrf_fix")
        sys.exit(1)  # Exit so user can reload
    except Exception as fix_exc:
        print(f"\033[91m❌ Auto-fix failed: {fix_exc}\033[0m")
        print("Manual fix: Run `python scripts/fix_pyrf_python3.py` from ria-toolkit-oss directory")
        raise exc

from ria_toolkit_oss.sdr.sdr import SDR


class ThinkRF(SDR):
    """SDR adapter for ThinkRF analyzers using the PyRF API."""

    BASE_SAMPLE_RATE = 125_000_000
    SUPPORTED_DECIMATIONS = (1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024)
    MAX_ONBOARD_SAMPLES = 33_500_000  # Confirmed: 512 packets @ dec 1 = 33.5M samples (268ms)
    DEFAULT_SPP = 65504  # VRT packet size (samples per packet)

    def __init__(self, identifier: Optional[str] = None):
        super().__init__()

        if identifier is None:
            raise ValueError("ThinkRF requires an IP address or hostname identifier")

        self.identifier = identifier
        try:
            self.radio = WSA()
            self.radio.connect(identifier)
            self.radio.request_read_perm()
            print(f"Connected to ThinkRF at [{identifier}].")
        except Exception as exc:
            print(f"Failed to connect to ThinkRF at [{identifier}].")
            raise exc

        self.configure_frontend()
        self._last_context: Optional[Any] = None

    def configure_frontend(
        self,
        *,
        rfe_mode: str = "ZIF",
        attenuation: int = 0,
        gain_profile: str = "HIGH",
        trigger_config: Optional[Dict[str, Any]] = None,
        samples_per_packet: int = 65504,
        packets_per_block: int = 1,
        capture_mode: str = "block",
        stream_id: int = 1,
        min_stream_decimation: int = 16,
    ) -> None:
        """Persist settings applied during the next RX initialisation.

        ``capture_mode`` selects between buffered ``"block"`` captures that use
        the analyser's onboard RAM and ``"stream"`` captures that push data over
        GigE in real time. Streaming requires a sufficiently large decimation to
        keep within the link budget; ``min_stream_decimation`` forms the lower
        bound.
        """

        mode = capture_mode.lower()
        if mode not in {"block", "stream"}:
            raise ValueError("capture_mode must be either 'block' or 'stream'")

        self._rfe_mode = rfe_mode
        self._attenuation = int(max(0, min(attenuation, 30)))
        self._gain_profile = gain_profile.upper()
        self._trigger_config = trigger_config
        self._samples_per_packet = int(samples_per_packet)
        self._packets_per_block = max(1, int(packets_per_block))
        self._capture_mode = mode
        self._stream_id = int(stream_id)
        self._min_stream_decimation = max(1, int(min_stream_decimation))
        self._streaming_active = False

    def init_rx(
        self,
        sample_rate: int | float,
        center_frequency: int | float,
        gain: int,
        channel: int,
        gain_mode: Optional[str] = "absolute",
        decimation: Optional[int] = None,
    ):
        if channel not in (0, None):
            raise ValueError("ThinkRF devices expose a single receive channel")

        stream_mode = getattr(self, "_capture_mode", "block") == "stream"
        actual_decimation, actual_sample_rate = self.set_rx_sample_rate(
            sample_rate=sample_rate, decimation=decimation, stream_mode=stream_mode
        )

        self.radio.reset()
        self.radio.scpiset(":SYSTEM:FLUSH")
        try:
            self.radio.scpiset(":TRACE:STREAM:STOP")
        except Exception:
            pass

        self.radio.rfe_mode(self._rfe_mode)
        self.set_rx_center_frequency(center_frequency=center_frequency)

        attenuation = self._attenuation if gain is None else int(gain)  # gain maps to attenuation
        attenuation = max(0, min(attenuation, 30))
        self.radio.attenuator(attenuation)

        gain_profile = self._gain_profile
        if gain_mode and isinstance(gain_mode, str) and gain_mode.upper() in {"LOW", "MEDIUM", "HIGH", "VLOW"}:
            gain_profile = gain_mode.upper()
        self.radio.gain(gain_profile.lower())  # WSA.gain() expects lowercase

        self.radio.decimation(actual_decimation)
        if stream_mode:
            self.radio.scpiset(f":SENSE:DECIMATION {actual_decimation}")
        trigger = self._trigger_config or self._default_trigger(center_frequency)
        self.radio.trigger(trigger)

        self.radio.scpiset(f":TRACE:SPP {self._samples_per_packet}")
        if stream_mode:
            self._streaming_active = False
        else:
            print(
                f"ThinkRF: Configuring block capture - SPP={self._samples_per_packet}, PPB={self._packets_per_block}"
            )
            self.radio.scpiset(f":TRACE:BLOCK:PACKETS {self._packets_per_block}")
            self.radio.scpiset(":TRACE:BLOCK:DATA?")

        self.rx_gain = {
            "attenuation_dB": attenuation,
            "profile": gain_profile,
            "decimation": actual_decimation,
            "rfe_mode": self._rfe_mode,
            "spp": self._samples_per_packet,
            "ppb": self._packets_per_block,
        }
        self.rx_buffer_size = self._samples_per_packet
        self.rx_channel = 0

        self._rx_initialized = True
        self._tx_initialized = False

    def set_rx_sample_rate(self, sample_rate, decimation=None, stream_mode=False):
        # Enforce sample rate / decimation
        # Note: decimation parameter takes precedence if provided
        actual_decimation, actual_sample_rate = self.enforce_sample_rate(sample_rate, decimation)

        if stream_mode and actual_decimation < self._min_stream_decimation:
            enforced = self._min_stream_decimation
            print(
                "Requested ThinkRF sample rate exceeds typical GigE throughput; "
                f"enforcing decimation {enforced} for streaming."
            )
            actual_decimation = enforced
            actual_sample_rate = self.BASE_SAMPLE_RATE / actual_decimation

        self._decimation = actual_decimation
        self.rx_sample_rate = actual_sample_rate
        print(f"ThinkRF RX Sample Rate = {actual_sample_rate}")

        return actual_decimation, actual_sample_rate

    def set_rx_center_frequency(self, center_frequency):
        self.radio.freq(int(center_frequency))
        self.rx_center_frequency = self.radio.freq()
        print(f"ThinkRF RX Center Frequency = {self.rx_center_frequency}")

    def _stream_rx(self, callback):
        if not self._rx_initialized:
            raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record().")

        self._enable_rx = True
        packets_processed = 0
        stream_mode = getattr(self, "_capture_mode", "block") == "stream"

        if stream_mode and not self._streaming_active:
            try:
                self.radio.scpiset(f":TRACE:STREAM:START {self._stream_id}")
                self._streaming_active = True
            except Exception as exc:
                print(f"Failed to start ThinkRF stream: {exc}")
                return

        print("ThinkRF Starting RX...")
        while self._enable_rx:
            packet = self._safe_read(stream_mode, packets_processed)

            if packet is None:
                # No more packets available
                if not stream_mode and packets_processed >= self._packets_per_block:
                    # Finished reading block
                    break
                continue

            if packet.is_context_packet():
                self._last_context = packet
                continue

            if not packet.is_data_packet():
                # Unknown packet type - skip
                continue

            metadata = self._extract_metadata(packet)
            complex_buffer = self._extract_iq(packet)
            if complex_buffer is None:
                continue

            # Send packet data to callback (accumulation handled by parent)
            callback(buffer=complex_buffer, metadata=metadata)
            packets_processed += 1

            # In block mode, stop after receiving all packets in the block
            if not stream_mode and packets_processed >= self._packets_per_block:
                # Got all packets for this block
                break

        print("ThinkRF RX Completed.")
        if stream_mode and self._streaming_active:
            self._stop_stream()

        self.radio.scpiset(":SYSTEM:FLUSH")

    def _safe_read(self, stream_mode, packets_processed):
        packet = None
        try:
            packet = self.radio.read()
        except Exception as e:
            # In block mode, reaching end of block can cause exceptions
            if not stream_mode and packets_processed > 0:
                # We got some packets in block mode, so finish gracefully
                print(f"ThinkRF: Block read complete ({packets_processed} packets received)")
            else:
                print(f"ThinkRF read error: {e}")
        return packet

    def _extract_iq(self, packet):
        # packet.data is an iterable IQData object that yields (I, Q) tuples
        # Convert to numpy array: collect all [I, Q] pairs
        try:
            iq_pairs = list(packet.data)
            if not iq_pairs:
                return None
            iq_array = np.array(iq_pairs, dtype=np.float32)
            return (iq_array[:, 0] + 1j * iq_array[:, 1]).astype(np.complex64)
        except Exception as e:
            print(f"Error extracting IQ from packet.data: {e}")
            return None

    def _extract_metadata(self, packet):
        if not hasattr(packet, "fields"):
            return None
        metadata = packet.fields
        if metadata.get("sample_loss"):
            print("\033[93mWarning: ThinkRF sample overflow detected\033[0m")
        return metadata

    def _stop_stream(self):
        try:
            self.radio.scpiset(":TRACE:STREAM:STOP")
        except Exception:
            pass
        self._streaming_active = False

    def init_tx(
        self,
        sample_rate: int | float,
        center_frequency: int | float,
        gain: int,
        channel: int,
        gain_mode: Optional[str] = "absolute",
    ):
        raise NotImplementedError("ThinkRF devices do not support transmit operations")

    def _stream_tx(self, callback):
        raise NotImplementedError("ThinkRF devices do not support transmit operations")

    def set_clock_source(self, source):
        raise NotImplementedError("ThinkRF clock configuration is not implemented")

    def close(self):
        try:
            self.radio.scpiset(":TRACE:STREAM:STOP")
        except Exception:  # pragma: no cover - best effort cleanup
            pass
        try:
            self.radio.scpiset(":SYSTEM:FLUSH")
        except Exception:
            pass
        try:
            self.radio.disconnect()
        finally:
            self._enable_rx = False
            self._enable_tx = False
            print(f"Disconnected from ThinkRF at [{self.identifier}].")

    def supports_bias_tee(self) -> bool:
        return False

    def set_bias_tee(self, enable: bool):  # pragma: no cover - interface compliance
        raise NotImplementedError("ThinkRF radios do not expose a controllable bias-tee")

    def _derive_decimation(self, target_sample_rate: int | float) -> int:
        """
        Derive decimation from target sample rate.
        Always rounds DOWN decimation (UP sample rate) to meet or exceed user's requested rate.

        Example: 30 MS/s requested → dec 4 (31.25 MS/s), NOT dec 8 (15.625 MS/s)
        """
        if not target_sample_rate:
            return 1
        requested = float(target_sample_rate)
        if requested >= self.BASE_SAMPLE_RATE:
            return 1

        desired_decimation = self.BASE_SAMPLE_RATE / requested

        # Round DOWN decimation (UP sample rate) to meet or exceed requested rate
        # Find largest decimation that gives sample rate >= requested
        valid_decimations = [d for d in self.SUPPORTED_DECIMATIONS if d <= desired_decimation]

        if valid_decimations:
            # Use largest valid decimation (gives sample rate >= requested)
            best = max(valid_decimations)
        else:
            # Requested rate too low, use minimum decimation (max sample rate)
            best = self.SUPPORTED_DECIMATIONS[0]

        return int(best)

    def enforce_sample_rate(
        self, requested_sample_rate: int | float, decimation: Optional[int] = None
    ) -> tuple[int, float]:
        """
        Enforce valid sample rate and decimation.

        If decimation is provided, it takes precedence.
        Otherwise, derive decimation from requested sample rate.

        Returns:
            (decimation, actual_sample_rate)
        """
        if decimation is not None:
            # Decimation provided - validate and use it
            if decimation not in self.SUPPORTED_DECIMATIONS:
                # Round to nearest supported
                decimation = min(self.SUPPORTED_DECIMATIONS, key=lambda d: abs(d - decimation))
                print(f"ThinkRF: Requested decimation not supported. Using decimation={decimation}")
        else:
            # Derive from sample rate
            decimation = self._derive_decimation(requested_sample_rate)

        actual_sample_rate = self.BASE_SAMPLE_RATE / decimation

        if abs(actual_sample_rate - requested_sample_rate) > 1e3:  # More than 1 kHz difference
            print(
                f"ThinkRF: Requested {requested_sample_rate/1e6:.2f} MS/s → "
                f"Using decimation={decimation} ({actual_sample_rate/1e6:.2f} MS/s)"
            )

        return decimation, actual_sample_rate

    def calculate_spp_ppb(self, num_samples: int, spp: Optional[int] = None) -> tuple[int, int]:
        """
        Calculate optimal SPP (samples per packet) and PPB (packets per block).

        Strategy:
        - Maximize SPP (use DEFAULT_SPP) unless num_samples < DEFAULT_SPP
        - Calculate PPB to get as close as possible to num_samples
        - Actual captured samples = SPP * PPB (may exceed num_samples slightly)

        Args:
            num_samples: Desired number of samples
            spp: Override SPP (for advanced users, not recommended)

        Returns:
            (spp, ppb)
        """
        if spp is not None:
            # User override - use as-is
            actual_spp = max(1, int(spp))
        else:
            # Maximize SPP unless samples requested is smaller
            if num_samples < self.DEFAULT_SPP:
                actual_spp = num_samples
            else:
                actual_spp = self.DEFAULT_SPP

        # Calculate PPB to get close to num_samples
        ppb = max(1, int(np.ceil(num_samples / actual_spp)))

        actual_samples = actual_spp * ppb
        if actual_samples != num_samples:
            print(
                f"ThinkRF: Requested {num_samples} samples → Capturing {actual_samples} (SPP={actual_spp}, PPB={ppb})"
            )

        return actual_spp, ppb

    def check_ram_limit(self, num_samples: int, decimation: int) -> None:
        """
        Check if requested capture exceeds onboard RAM limits.

        Raises an error if the request exceeds MAX_ONBOARD_SAMPLES at low decimations.
        For decimation 1 or 2, block captures are limited by onboard RAM.
        """
        if decimation <= 2 and num_samples > self.MAX_ONBOARD_SAMPLES:
            raise ValueError(
                f"ThinkRF: Cannot capture {num_samples} samples at decimation {decimation}. "
                f"Onboard RAM limit is ~{self.MAX_ONBOARD_SAMPLES} samples for dec 1/2. "
                f"Either reduce num_samples or use stream mode (increase decimation to >=4)."
            )

    def _default_trigger(self, center_frequency: int | float) -> Dict[str, Any]:
        span = 40_000_000
        half = span // 2
        return {
            "type": "NONE",
            "fstart": int(center_frequency) - half,
            "fstop": int(center_frequency) + half,
            "amplitude": -100,
        }
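The decimation and packet-sizing rules above can be sketched as standalone functions. This is a simplified illustration of the logic in `_derive_decimation` and `calculate_spp_ppb` under the constants the class defines, not the driver itself; the function names are hypothetical.

```python
import math

BASE = 125_000_000  # ThinkRF base sample rate
DECIMATIONS = (1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024)
DEFAULT_SPP = 65504  # VRT samples per packet


def derive_decimation(rate):
    """Largest supported decimation whose sample rate still meets the request."""
    if not rate or rate >= BASE:
        return 1
    valid = [d for d in DECIMATIONS if d <= BASE / rate]
    return max(valid) if valid else DECIMATIONS[0]


def spp_ppb(num_samples):
    """Maximize samples-per-packet, then take enough packets to cover the request."""
    spp = min(num_samples, DEFAULT_SPP)
    ppb = max(1, math.ceil(num_samples / spp))
    return spp, ppb


print(derive_decimation(30e6))  # 4 -> 31.25 MS/s, meeting the 30 MS/s request
print(spp_ppb(1_000_000))       # (65504, 16) -> 1,048,064 samples captured
```

This reproduces the docstring example: a 30 MS/s request maps to decimation 4 (31.25 MS/s), never decimation 8 (15.625 MS/s), and the SPP/PPB pair may slightly overshoot the requested sample count.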
@@ -17,11 +17,11 @@ class USRP(SDR):

        This software supports all USRP SDRs created by Ettus Research.

-        :param identifier: Identifier of the device. Can be an IP address (e.g. "192.168.0.0"),
-            a device name (e.g. "MyB210"), or any name/address found via ``uhd_find_devices``.
-            If not provided, the first available device is selected with a warning.
-            If multiple devices match the identifier, the first one is selected.
-        :type identifier: str, optional
+        :param identifier: The value of the parameter that identifies the device.
+        :type identifier: str = "192.168.0.0", "MyB210", name or address found in uhd_find_devices
+
+        If no identifier is provided, it will select the first device found, with a warning.
+        If more than one device is found with the identifier, it will select the first of those devices.
        """
        super().__init__()
@@ -43,29 +43,23 @@ class USRP(SDR):
        rx_buffer_size: int = 960000,
    ):
        """
-        Initialize the USRP for receiving.
+        Initializes the USRP for receiving.

        :param sample_rate: The sample rate for receiving.
        :type sample_rate: int or float

        :param center_frequency: The center frequency of the recording.
        :type center_frequency: int or float

-        :param gain: The gain set for receiving on the USRP
-        :type gain: int
        :param channel: The channel the USRP is set to.
        :type channel: int

+        :param gain: The gain set for receiving on the USRP.
+        :type gain: int
+
-        :param gain_mode: Gain mode setting. ``"absolute"`` passes gain directly to the SDR.
-            ``"relative"`` means gain should be a negative value, which will be subtracted
-            from the maximum gain.
+        :param gain_mode: 'absolute' passes gain directly to the sdr,
+            'relative' means that gain should be a negative value, and it will be subtracted from the max gain.
        :type gain_mode: str

        :param rx_buffer_size: Internal buffer size for receiving samples. Defaults to 960000.
        :type rx_buffer_size: int

-        :return: Dictionary with the actual RX parameters after configuration.
+        :return: A dictionary with the actual RX parameters after configuration.
        :rtype: dict
        """
@ -80,6 +74,50 @@ class USRP(SDR):
|
|||
        if channel + 1 > max_num_channels:
            raise IOError(f"Channel {channel} not valid for device with {max_num_channels} channels.")

        self.set_rx_sample_rate(sample_rate=sample_rate, channel=channel)
        self.set_rx_center_frequency(center_frequency=center_frequency, channel=channel)
        self.set_rx_gain(gain=gain, gain_mode=gain_mode, channel=channel)

        self.rx_channel = channel
        print(f"USRP RX Channel = {self.rx_channel}")

        # Flag to prevent the user from calling certain functions before this one.
        self._rx_initialized = True
        self._tx_initialized = False

        return {"sample_rate": self.rx_sample_rate, "center_frequency": self.rx_center_frequency, "gain": self.rx_gain}

    def set_rx_sample_rate(self, sample_rate, channel=0):
        # Check that the sample rate argument is valid.
        # Note: B200/B210 devices auto-adjust the master clock rate, so get_rx_rates() returns
        # the range for the CURRENT master clock, not the maximum possible range.
        # Skip validation for B-series devices and let UHD handle it.
        device_type = self.device_dict.get("type", "").lower()
        if device_type not in ["b200", "b210"]:
            sample_rate_range = self.usrp.get_rx_rates()
            if sample_rate < sample_rate_range.start() or sample_rate > sample_rate_range.stop():
                raise IOError(
                    f"Sample rate {sample_rate} not valid for this USRP.\n"
                    f"Valid range is {sample_rate_range.start()} to {sample_rate_range.stop()}."
                )
        self.usrp.set_rx_rate(sample_rate, channel)
        self.rx_sample_rate = self.usrp.get_rx_rate(channel)
        print(f"USRP RX Sample Rate = {self.rx_sample_rate}")

    def set_rx_center_frequency(self, center_frequency, channel=0):
        center_frequency_range = self.usrp.get_rx_freq_range()
        if center_frequency < center_frequency_range.start() or center_frequency > center_frequency_range.stop():
            raise IOError(
                f"Center frequency {center_frequency} out of range for USRP.\n"
                f"Valid range is {center_frequency_range.start()} to {center_frequency_range.stop()}."
            )
        self.usrp.set_rx_freq(uhd.libpyuhd.types.tune_request(center_frequency), channel)
        self.rx_center_frequency = self.usrp.get_rx_freq(channel)
        print(f"USRP RX Center Frequency = {self.rx_center_frequency}")

    def set_rx_gain(self, gain, gain_mode="absolute", channel=0):
        # Check that the gain argument is valid.
        gain_range = self.usrp.get_rx_gain_range()
        if gain_mode == "relative":
@@ -98,70 +136,9 @@ class USRP(SDR):
            print(f"Gain range: {gain_range.start()} to {gain_range.stop()} dB")
        abs_gain = min(max(abs_gain, gain_range.start()), gain_range.stop())
        self.usrp.set_rx_gain(abs_gain, channel)
        self.rx_gain = self.usrp.get_rx_gain(channel)
        print(f"USRP RX Gain = {self.rx_gain}")
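The sample-rate and center-frequency setters share the same bounds check; a minimal standalone sketch of that pattern (`check_in_range` is a hypothetical helper, and the numeric range is illustrative, not a real device's):

```python
def check_in_range(value, start, stop, what="Sample rate"):
    # Same shape as the setters' validation: reject anything outside [start, stop].
    if value < start or value > stop:
        raise IOError(f"{what} {value} not valid. Valid range is {start} to {stop}.")
    return value

check_in_range(1_000_000, 200_000, 56_000_000)  # within range, passes through
try:
    check_in_range(100_000_000, 200_000, 56_000_000)
except IOError as err:
    print(err)
```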
    def get_rx_sample_rate(self):
        """
        Retrieve the current sample rate of the receiver.

        Returns:
            float: The receiver's sample rate in samples per second (Hz).
        """
        return self.rx_sample_rate

    def get_rx_center_frequency(self):
        """
        Retrieve the current center frequency of the receiver.

        Returns:
            float: The receiver's center frequency in hertz (Hz).
        """
        return self.rx_center_frequency

    def get_rx_gain(self):
        """
        Retrieve the current gain setting of the receiver.

        Returns:
            float: The receiver's gain in decibels (dB).
        """
        return self.rx_gain

    def _stream_rx(self, callback):

        if not self._rx_initialized:
@@ -206,10 +183,31 @@ class USRP(SDR):
        del self.rx_stream
        print("USRP RX Completed.")

    def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None):
        """
        Create a radio recording (IQ samples and metadata) of a given length from the USRP.
        Exactly one of num_samples or rx_time must be provided, and init_rx() must be
        called before record().

        :param num_samples: The number of samples to record.
        :type num_samples: int, optional
        :param rx_time: The duration to record, in seconds.
        :type rx_time: int or float, optional

        :return: Recording object (IQ samples and metadata)
        """
        if not self._rx_initialized:
            raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")

        if num_samples is not None and rx_time is not None:
            raise ValueError("Only input one of num_samples or rx_time")
        elif rx_time is not None:
            num_samples = int(rx_time * self.rx_sample_rate)
        elif num_samples is None:
            raise ValueError("Must provide one of num_samples or rx_time")

        stream_args = uhd.usrp.StreamArgs("fc32", "sc16")
        stream_args.channels = [self.rx_channel]
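The rx_time branch is just a duration-to-sample-count conversion; a one-line sketch with hypothetical values (record() performs the same arithmetic internally against the configured RX rate):

```python
# Hypothetical rate and duration, for illustration only.
rx_sample_rate = 2_000_000  # samples per second
rx_time = 0.25              # seconds

num_samples = int(rx_time * rx_sample_rate)
print(num_samples)  # 500000
```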
@@ -264,23 +262,18 @@ class USRP(SDR):
        gain_mode: Optional[str] = "absolute",
    ):
        """
        Initialize the USRP for transmitting.

        :param sample_rate: The sample rate for transmitting.
        :type sample_rate: int or float

        :param center_frequency: The center frequency for transmitting.
        :type center_frequency: int or float

        :param gain: The gain set for transmitting on the USRP.
        :type gain: int

        :param channel: The channel the USRP is set to.
        :type channel: int

        :param gain_mode: Gain mode setting. ``"absolute"`` passes gain directly to the SDR.
            ``"relative"`` means gain should be a negative value, which will be subtracted
            from the maximum gain.
        :type gain_mode: str
        """
@@ -296,6 +289,52 @@ class USRP(SDR):
        if channel + 1 > max_num_channels:
            raise IOError(f"Channel {channel} not valid for device with {max_num_channels} channels.")

        self.set_tx_sample_rate(sample_rate=sample_rate, channel=channel)
        self.set_tx_center_frequency(center_frequency=center_frequency, channel=channel)
        self.set_tx_gain(gain=gain, gain_mode=gain_mode, channel=channel)

        self.tx_channel = channel
        print(f"USRP TX Channel = {self.tx_channel}")

        self.usrp.set_clock_source("internal")
        self.usrp.set_time_source("internal")
        self.usrp.set_tx_antenna("TX/RX", channel)

        self._tx_initialized = True
        self._rx_initialized = False

        return {"sample_rate": self.tx_sample_rate, "center_frequency": self.tx_center_frequency, "gain": self.tx_gain}

    def set_tx_sample_rate(self, sample_rate, channel=0):
        # Check that the sample rate argument is valid.
        # Note: B200/B210 devices auto-adjust the master clock rate, so get_tx_rates() returns
        # the range for the CURRENT master clock, not the maximum possible range.
        # Skip validation for B-series devices and let UHD handle it.
        device_type = self.device_dict.get("type", "").lower()
        if device_type not in ["b200", "b210"]:
            sample_rate_range = self.usrp.get_tx_rates()
            if sample_rate < sample_rate_range.start() or sample_rate > sample_rate_range.stop():
                raise IOError(
                    f"Sample rate {sample_rate} not valid for this USRP.\n"
                    f"Valid range is {sample_rate_range.start()} to {sample_rate_range.stop()}."
                )
        self.usrp.set_tx_rate(sample_rate, channel)
        self.tx_sample_rate = self.usrp.get_tx_rate(channel)
        print(f"USRP TX Sample Rate = {self.tx_sample_rate}")

    def set_tx_center_frequency(self, center_frequency, channel=0):
        center_frequency_range = self.usrp.get_tx_freq_range()
        if center_frequency < center_frequency_range.start() or center_frequency > center_frequency_range.stop():
            raise IOError(
                f"Center frequency {center_frequency} out of range for USRP.\n"
                f"Valid range is {center_frequency_range.start()} to {center_frequency_range.stop()}."
            )
        self.usrp.set_tx_freq(uhd.types.TuneRequest(center_frequency), channel)
        self.tx_center_frequency = self.usrp.get_tx_freq(channel)
        print(f"USRP TX Center Frequency = {self.tx_center_frequency}")

    def set_tx_gain(self, gain, gain_mode="absolute", channel=0):
        # Ensure gain is within the valid range.
        gain_range = self.usrp.get_tx_gain_range()
        if gain_mode == "relative":
@@ -315,45 +354,9 @@ class USRP(SDR):
        abs_gain = min(max(abs_gain, gain_range.start()), gain_range.stop())

        self.usrp.set_tx_gain(abs_gain, channel)
        self.tx_gain = self.usrp.get_tx_gain(channel)
        print(f"USRP TX Gain = {self.tx_gain}")

    def close(self):
        pass
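The relative gain mode resolves the requested (negative) offset against the device's maximum gain, then clamps into the device-reported range; a standalone sketch (`resolve_gain` and the numeric range are hypothetical, not part of the class):

```python
def resolve_gain(gain, gain_min, gain_max, gain_mode="absolute"):
    # "relative" expects a negative offset, applied against the device's max gain.
    if gain_mode == "relative":
        abs_gain = gain_max + gain
    else:
        abs_gain = gain
    # Clamp into the device-reported range, as the setters do.
    return min(max(abs_gain, gain_min), gain_max)

print(resolve_gain(-10, 0.0, 76.0, gain_mode="relative"))  # 66.0
print(resolve_gain(100, 0.0, 76.0))  # out-of-range request clamps to 76.0
```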
57
src/ria_toolkit_oss/view/tools.py
Normal file
@@ -0,0 +1,57 @@
import pathlib

MAX_PLOT_POINTS = 100_000
COLORS = {
    "primary": "#6366f1",
    "secondary": "#8b5cf6",
    "accent": "#06b6d4",
    "dark": "#1e293b",
    "light": "#f8fafc",
    "text": "#334155",
    "muted": "#64748b",
    "success": "#10b981",
    "warning": "#f59e0b",
    "error": "#ef4444",
    "purple": "#8b5cf6",
    "magenta": "#d946ef",
}


def decimate(x, max_points=MAX_PLOT_POINTS):
    if len(x) <= max_points:
        return x
    step = len(x) // max_points
    return x[::step]


def extract_metadata_fields(metadata):
    sample_rate = next((v for k, v in metadata.items() if "sample_rate" in k), 1)
    center_freq = next((v for k, v in metadata.items() if "center_freq" in k), 0)
    sdr = next((v for k, v in metadata.items() if "sdr" in k), "Unknown")
    return sample_rate, center_freq, sdr


def set_path(output_path):
    split_path = output_path.split("/")

    if len(split_path) == 1:
        folder = "images"
        file = split_path[0]
    elif len(split_path) > 2:
        file = split_path[-1]
        folder = "/".join(split_path[:-1])
    else:
        folder, file = split_path

    split_file = file.split(".")
    if len(split_file) == 2:
        extension = split_file[1]
    else:
        extension = "no extension"
    if extension != "png" and extension != "svg":
        print(f"{extension} not supported, saving as .png.")
        extension = "png"
        file = file + ".png"

    pathlib.Path(folder).mkdir(parents=True, exist_ok=True)
    return "/".join([folder, file]), extension
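As a quick check of decimate (reproduced inline so the snippet runs standalone): because the step uses integer floor division, the result can hold up to nearly twice max_points samples rather than exactly max_points:

```python
import numpy as np

MAX_PLOT_POINTS = 100_000


def decimate(x, max_points=MAX_PLOT_POINTS):
    # Keep every step-th sample; a no-op when x is already short enough.
    if len(x) <= max_points:
        return x
    step = len(x) // max_points
    return x[::step]


x = np.arange(250_000)
y = decimate(x)
print(len(y))  # 125000: floor(250000 / 100000) = 2, so every 2nd sample survives
```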
257
src/ria_toolkit_oss/view/view_signal.py
Normal file
@@ -0,0 +1,257 @@
import os
import textwrap
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec
from PIL import Image
from scipy.fft import fft, fftshift
from scipy.signal import spectrogram

from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.view.tools import (
    COLORS,
    decimate,
    extract_metadata_fields,
    set_path,
)


def get_fft_size(plot_length):
    if plot_length < 2000:
        return 64
    elif plot_length < 10000:
        return 256
    elif plot_length < 1000000:
        return 1024
    else:
        return 2048


def set_spines(ax, spines):
    if not spines:
        for side in ("top", "right", "bottom", "left"):
            ax.spines[side].set_visible(False)


def view_sig(
    recording: Recording,
    output_path: Optional[str] = "images/signal.png",
    title: Optional[str] = "Signal Plot",
    dpi: Optional[int] = 250,
    plot_length: Optional[int] = None,
    plot_spectrogram: Optional[bool] = True,
    iq: Optional[bool] = True,
    frequency: Optional[bool] = True,
    constellation: Optional[bool] = True,
    metadata: Optional[bool] = True,
    logo: Optional[bool] = True,
    dark: Optional[bool] = True,
    spines: Optional[bool] = False,
    title_fontsize: Optional[int] = 40,
    subtitle_fontsize: Optional[int] = 20,
) -> None:
    """
    Create a plot of various signal visualizations as a png or svg image.

    :param recording: The recording object to plot.
    :type recording: Recording
    :param output_path: The output image path. Defaults to "images/signal.png".
    :type output_path: str, optional
    :param title: The display title. Defaults to "Signal Plot".
    :type title: str, optional
    :param dpi: The dots-per-inch resolution. Defaults to 250.
    :type dpi: int, optional
    :param plot_length: The number of samples to plot; the default is the whole recording. Defaults to None.
    :type plot_length: int, optional
    :param plot_spectrogram: Display the spectrogram. Defaults to True.
    :type plot_spectrogram: bool, optional
    :param iq: Display the IQ sample plot. Defaults to True.
    :type iq: bool, optional
    :param frequency: Display the FFT of the recording. Defaults to True.
    :type frequency: bool, optional
    :param constellation: Display the constellation plot. Defaults to True.
    :type constellation: bool, optional
    :param metadata: Display the metadata text. Defaults to True.
    :type metadata: bool, optional
    :param logo: Display the Qoherent logo. Defaults to True.
    :type logo: bool, optional
    :param dark: Use dark mode. Defaults to True.
    :type dark: bool, optional
    :param spines: Display spines (bounding lines) around plots. Defaults to False.
    :type spines: bool, optional
    :param title_fontsize: The font size of the main title text. Defaults to 40.
    :type title_fontsize: int, optional
    :param subtitle_fontsize: The font size of the subplot titles. Defaults to 20.
    :type subtitle_fontsize: int, optional

    **Examples:**

    .. todo:: Usage examples coming soon.
    """
    complex_signal = recording.data[0]
    sample_rate, center_frequency, _ = extract_metadata_fields(recording.metadata)

    subplot_height = 2 * (plot_spectrogram + iq + frequency) + 3 * (constellation or metadata or logo)
    subplot_width = max((constellation + metadata or 1), logo * 3)

    if dark:
        plt.style.use("dark_background")
        logo_path = os.path.dirname(__file__) + "/graphics/Qoherent-logo-white-transparent.png"
    else:
        plt.style.use("default")
        logo_path = os.path.dirname(__file__) + "/graphics/Qoherent-logo-black-transparent.png"

    if plot_length is None:
        plot_length = len(recording.data[0])

    # Plot preparation
    fig = plt.figure(figsize=(14, 12))
    fig.suptitle(title, fontsize=title_fontsize)
    gs = gridspec.GridSpec(subplot_height, subplot_width)

    plot_y_indx = 0
    plot_x_indx = 0

    if plot_spectrogram:
        spec_ax = plt.subplot(gs[plot_y_indx : plot_y_indx + 2, :])
        plot_y_indx = plot_y_indx + 2
        fft_size = get_fft_size(plot_length=plot_length)

        f, t_spec, Sxx = spectrogram(
            complex_signal[:plot_length],
            fs=sample_rate,
            nperseg=fft_size,
            noverlap=fft_size // 8,
            mode="magnitude",
            return_onesided=False,
        )

        # Shift frequencies so zero is centered.
        Sxx = np.fft.fftshift(Sxx, axes=0)
        f = np.fft.fftshift(f) - sample_rate / 2 + center_frequency

        spec_ax.imshow(
            10 * np.log10(Sxx + 1e-12),
            aspect="auto",
            origin="lower",
            extent=[t_spec[0], t_spec[-1], f[0], f[-1]],
            cmap="twilight",
        )

        set_spines(spec_ax, spines)
        spec_ax.set_title("Spectrogram", loc="center", fontsize=subtitle_fontsize)
        spec_ax.set_ylabel("Frequency (Hz)")
        spec_ax.set_xlabel("Time (s)")

    if iq:
        iq_ax = plt.subplot(gs[plot_y_indx : plot_y_indx + 2, :])
        plot_y_indx = plot_y_indx + 2

        plot_iq = decimate(complex_signal[:plot_length])
        t = np.arange(len(plot_iq)) / sample_rate * (len(complex_signal[:plot_length]) / len(plot_iq))

        iq_ax.plot(t, plot_iq.real, color=COLORS["purple"], linewidth=0.6, alpha=0.8, label="I")
        iq_ax.plot(t, plot_iq.imag, color=COLORS["magenta"], linewidth=0.6, alpha=0.8, label="Q")
        iq_ax.grid(False)

        iq_ax.set_ylabel("Amplitude")
        iq_ax.set_xlim([min(t), max(t)])
        iq_ax.set_xlabel("Time (s)")
        iq_ax.set_title("IQ Sample Plot", fontsize=subtitle_fontsize)
        set_spines(iq_ax, spines)

    if frequency:
        freq_ax = plt.subplot(gs[plot_y_indx : plot_y_indx + 2, :])
        plot_y_indx = plot_y_indx + 2

        epsilon = 1e-10
        spectrum = np.abs(fftshift(fft(complex_signal[0:plot_length])))
        freqs = (
            np.linspace(-1 * (sample_rate / 2), (sample_rate / 2), len(complex_signal[0:plot_length]))
            + center_frequency
        )

        # Use a semi-log scale for the y-axis.
        freq_ax.semilogy(freqs, spectrum + epsilon, color=COLORS["accent"], linewidth=0.8)
        freq_ax.set_xlabel("Frequency")
        freq_ax.set_ylabel("Magnitude")
        freq_ax.set_title("Frequency Spectrum", fontsize=subtitle_fontsize)
        set_spines(freq_ax, spines)

    if constellation:
        const_ax = plt.subplot(gs[plot_y_indx:, plot_x_indx])
        plot_x_indx = plot_x_indx + 1
        plot_const = decimate(complex_signal[:plot_length], 50_000)
        const_ax.scatter(plot_const.real, plot_const.imag, c=COLORS["purple"], s=1, linewidths=0.1)
        dimension = max(abs(complex_signal)) * 1.1
        const_ax.set_xlim([-1 * dimension, dimension])
        const_ax.set_ylim([-1 * dimension, dimension])
        const_ax.set_xlabel("In-phase (I)")
        const_ax.set_ylabel("Quadrature (Q)")
        const_ax.set_title("Constellation", fontsize=subtitle_fontsize)
        const_ax.set_aspect("equal")
        set_spines(const_ax, spines)

    # Metadata text box
    if metadata:
        meta_ax = plt.subplot(gs[plot_y_indx:, plot_x_indx])
        plot_x_indx = plot_x_indx + 1
        metadata_text = "\n".join(
            [
                f"{key}: {textwrap.shorten(str(value), width=80, placeholder='...')}"
                for key, value in recording.metadata.items()
            ]
        )

        meta_ax.text(
            0.05,
            0.95,
            metadata_text,
            fontsize=10,
            va="top",
            ha="left",
            bbox=dict(facecolor="none", alpha=0.5, edgecolor="none"),
        )
        meta_ax.set_title("Metadata", fontsize=subtitle_fontsize)
        # Remove the tick marks and tick labels.
        meta_ax.xaxis.set_ticklabels([])
        meta_ax.yaxis.set_ticklabels([])
        meta_ax.set_xticks([])
        meta_ax.set_yticks([])
        set_spines(meta_ax, spines)

    if logo and os.path.isfile(logo_path):
        logo_ax = plt.subplot(gs[plot_y_indx:, 2])
        plot_x_indx = plot_x_indx + 1
        logo_ax.axis("off")

        try:
            image = Image.open(logo_path)  # Open the PNG image using PIL.
            logo_ax.imshow(image)
        except FileNotFoundError:
            print(f"Warning: {logo_path} not found.")

    fig.subplots_adjust(
        left=0.1,  # Left margin
        right=0.9,  # Right margin
        top=0.9,  # Top margin
        bottom=0.1,  # Bottom margin
        wspace=0.4,  # Horizontal space between subplots
        hspace=2.5,  # Vertical space between subplots
    )

    # Save-path handling
    output_path, _ = set_path(output_path=output_path)
    plt.savefig(output_path, dpi=dpi)
    print(f"Saved signal plot to {output_path}")
328
src/ria_toolkit_oss/view/view_signal_simple.py
Normal file
@@ -0,0 +1,328 @@
"""Shared plotting primitives for signal visualization."""

from __future__ import annotations

from typing import Optional

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from scipy.fft import fft, fftshift
from scipy.signal.windows import hann

from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.view.tools import (
    COLORS,
    decimate,
    extract_metadata_fields,
    set_path,
)


def _get_nfft_size(signal, fast_mode):
    if len(signal) < 1000:
        nfft = 128
    elif len(signal) < 10_000:
        nfft = 256
    elif len(signal) < 100_000:
        nfft = 512
    elif len(signal) < 1_000_000:
        nfft = 1024
    else:
        nfft = 2048

    if fast_mode:
        nfft = min(nfft, 512)
    overlap = nfft // 8 if fast_mode else nfft // 4
    return nfft, overlap


def _get_plot_samples(signal, fast_mode, slow_max, fast_max):
    max_samples = fast_max if fast_mode else slow_max
    if len(signal) > max_samples:
        start_idx = len(signal) // 2 - max_samples // 2
        return signal[start_idx : start_idx + max_samples]
    else:
        return signal


def _set_dpi(fast_mode, labels_mode, extension):
    if fast_mode:
        dpi = 75
    elif labels_mode:
        dpi = 200
    else:
        dpi = 150
    return dpi if extension == "png" else None


def setup_style(*, labels_mode: bool = False, compact_mode: bool = False) -> None:
    """Configure matplotlib with the signal-testbed styling."""

    plt.style.use("dark_background")

    if compact_mode:
        base_font = 8
        title_font = 10
        label_font = 8
    elif labels_mode:
        base_font = 12
        title_font = 16
        label_font = 14
    else:
        base_font = 10
        title_font = 12
        label_font = 10

    matplotlib.rcParams.update(
        {
            "figure.facecolor": "#0f172a",
            "axes.facecolor": "#1e293b",
            "axes.edgecolor": COLORS["muted"],
            "axes.labelcolor": COLORS["light"],
            "text.color": COLORS["light"],
            "xtick.color": COLORS["muted"],
            "ytick.color": COLORS["muted"],
            "grid.color": COLORS["muted"],
            "grid.alpha": 0.3,
            "font.size": base_font,
            "axes.titlesize": title_font,
            "axes.labelsize": label_font,
            "figure.titlesize": title_font + 2,
            "legend.frameon": False,
            "legend.facecolor": "none",
            "xtick.labelsize": base_font,
            "ytick.labelsize": base_font,
        }
    )


def detect_constellation_symbols(signal: np.ndarray, method: str = "differential") -> np.ndarray:
    """Heuristic symbol detector used for constellation highlighting."""

    if len(signal) < 100:
        return np.ones(len(signal), dtype=bool)

    if method == "differential":
        di = np.diff(signal.imag)
        dq = np.diff(signal.real)
        derivative_magnitude = np.sqrt(di**2 + dq**2)
        derivative_magnitude = np.append(derivative_magnitude, 0)
        threshold = np.percentile(derivative_magnitude, 15)
        return derivative_magnitude < threshold

    if method == "amplitude":
        amplitude = np.abs(signal)
        amplitude_change = np.abs(np.diff(amplitude))
        amplitude_change = np.append(amplitude_change, 0)
        threshold = np.percentile(amplitude_change, 20)
        return amplitude_change < threshold

    if method == "phase":
        phase = np.angle(signal)
        phase_diff = np.diff(np.unwrap(phase))
        phase_diff = np.append(phase_diff, 0)
        threshold = np.percentile(np.abs(phase_diff), 20)
        return np.abs(phase_diff) < threshold

    if method == "combined":
        diff_stable = detect_constellation_symbols(signal, "differential")
        amp_stable = detect_constellation_symbols(signal, "amplitude")
        phase_stable = detect_constellation_symbols(signal, "phase")
        stability_count = diff_stable.astype(int) + amp_stable.astype(int) + phase_stable.astype(int)
        return stability_count >= 2

    raise ValueError(f"Unknown method: {method}")
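The "differential" branch of detect_constellation_symbols can be exercised on a synthetic QPSK burst. The heuristic logic is copied inline so the snippet runs without the package, and the signal parameters (symbol set, 8-sample hold, noise level) are illustrative: by construction of the 15th-percentile threshold, roughly 15% of samples are flagged as stable.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic QPSK: each symbol held for 8 samples, so most samples sit near
# a constellation point and their sample-to-sample derivative is small.
symbols = rng.choice(np.array([1 + 1j, 1 - 1j, -1 + 1j, -1 - 1j]), size=500)
signal = np.repeat(symbols, 8) + 0.01 * (rng.standard_normal(4000) + 1j * rng.standard_normal(4000))

# Same logic as detect_constellation_symbols(signal, method="differential"):
di = np.diff(signal.imag)
dq = np.diff(signal.real)
derivative_magnitude = np.append(np.sqrt(di**2 + dq**2), 0)
threshold = np.percentile(derivative_magnitude, 15)
stable = derivative_magnitude < threshold

print(stable.dtype, stable.shape, stable.mean())
```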
def view_simple_sig(
    recording: Recording,
    output_path: Optional[str] = "images/signal.png",
    saveplot: Optional[bool] = True,
    fast_mode: Optional[bool] = False,
    compact_mode: Optional[bool] = False,
    horizontal_mode: Optional[bool] = False,
    constellation_mode: Optional[bool] = False,
    labels_mode: Optional[bool] = False,
    slice: Optional[tuple] = None,
    title: Optional[str] = "Signal",
):
    """
    Create a simple plot of various signal visualizations as a png or svg image.

    :param recording: The recording object to plot.
    :type recording: Recording
    :param output_path: The output image path. Defaults to "images/signal.png".
    :type output_path: str, optional
    :param saveplot: Whether or not to save the plot. Defaults to True.
    :type saveplot: bool, optional
    :param fast_mode: Use fast mode for a faster render. Defaults to False.
    :type fast_mode: bool, optional
    :param compact_mode: Use compact mode for a compact plot. Defaults to False.
    :type compact_mode: bool, optional
    :param horizontal_mode: Display plots horizontally. Defaults to False.
    :type horizontal_mode: bool, optional
    :param constellation_mode: Display the constellation plot, and the PSD if not using compact mode. Defaults to False.
    :type constellation_mode: bool, optional
    :param labels_mode: Display more thorough labels. Defaults to False.
    :type labels_mode: bool, optional
    :param slice: Slice of the signal to display. Defaults to None.
    :type slice: tuple[int, int], optional
    :param title: Title of the plot. Defaults to "Signal".
    :type title: str, optional
    """

    signal = recording.data[0]
    sample_rate_hz, center_freq_hz, sdr = extract_metadata_fields(recording.metadata)

    setup_style(labels_mode=labels_mode, compact_mode=compact_mode)

    if slice:
        start_idx, end_idx = slice
        signal = signal[start_idx:end_idx]
        print(f"Using slice: samples {start_idx} to {end_idx} ({len(signal):,} samples)")

    max_display_pixels = 100_000 if fast_mode else 250_000
    display_signal = decimate(signal, max_display_pixels) if len(signal) > max_display_pixels else signal
    spec_signal = signal

    if compact_mode:
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 6), gridspec_kw={"height_ratios": [1, 5]})
        show_title = False
        show_labels = False
        ax_constellation = ax_psd = None
    elif horizontal_mode:
        if constellation_mode:
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
            ax_constellation = ax3
        else:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
            ax_constellation = None
        show_title = True
        show_labels = labels_mode
        ax_psd = None
    else:
        if constellation_mode:
            fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))
            ax_constellation, ax_psd = ax3, ax4
        else:
            fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))
            ax_constellation = ax_psd = None
        show_title = True
        show_labels = labels_mode

    if show_title:
        fig.suptitle(title, fontsize=16, color=COLORS["light"], y=0.96)
    fig.patch.set_facecolor("#0f172a")

    total_duration_s = len(signal) / sample_rate_hz if sample_rate_hz else 0.0
    t_s = np.linspace(0, total_duration_s, len(display_signal)) if len(display_signal) else np.array([])

    ax1.plot(t_s, display_signal.real, color=COLORS["purple"], linewidth=0.8, alpha=0.8, label="I")
    ax1.plot(t_s, display_signal.imag, color=COLORS["magenta"], linewidth=0.8, alpha=0.8, label="Q")
    ax1.set_xlim(0, total_duration_s)
    ax1.grid(True, alpha=0.3)

    nfft, overlap = _get_nfft_size(signal=signal, fast_mode=fast_mode)

    _, freqs, _, _ = ax2.specgram(
        spec_signal,
        NFFT=nfft,
        Fc=center_freq_hz,
        Fs=sample_rate_hz,
        noverlap=overlap,
        cmap="twilight",
    )

    ax2.set_ylim(center_freq_hz - sample_rate_hz / 2, center_freq_hz + sample_rate_hz / 2)
    ax2.set_xlim(0, total_duration_s)

    if show_labels:
        if horizontal_mode:
            ax1.set_xlabel("Time (s)")
        else:
            ax2.set_xlabel("Time (s)")

        ax1.set_ylabel("Amplitude")
        ax1.set_title(f"Time Series - {sdr} SDR")
        ax1.legend(loc="upper right")

        ax2.set_ylabel("Frequency (Hz)")
        ax2.set_title(f"Spectrogram - {center_freq_hz / 1e6:.1f} MHz ± {sample_rate_hz / 2e6:.1f} MHz")
        yticks = ax2.get_yticks()
        ax2.set_yticklabels([f"{y / 1e6:.1f}" for y in yticks])
    elif not compact_mode:
        ax1.set_title("Time Series")
        ax1.legend(loc="upper right", fontsize=8)

        ax2.set_title("Spectrogram")

    if ax_constellation is not None:
        constellation_samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=50_000, fast_max=20_000)
        method = "differential" if fast_mode else "combined"
        stable_points = detect_constellation_symbols(constellation_samples, method=method)

        ax_constellation.scatter(
            constellation_samples.real[~stable_points],
            constellation_samples.imag[~stable_points],
            c=COLORS["muted"],
            s=0.5,
            alpha=0.2,
        )
        ax_constellation.scatter(
            constellation_samples.real[stable_points],
            constellation_samples.imag[stable_points],
            c=COLORS["purple"],
            s=3,
            alpha=0.8,
        )
        ax_constellation.set_xlabel("In-phase (I)")
        ax_constellation.set_ylabel("Quadrature (Q)")
        ax_constellation.set_title("Constellation")
        ax_constellation.grid(True, alpha=0.3)
        ax_constellation.set_aspect("equal")

    if ax_psd is not None:
        psd_samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=65_536, fast_max=16_384)
        window = hann(len(psd_samples))
        spectrum = np.abs(fftshift(fft(psd_samples * window))) ** 2
        freqs = np.linspace(-sample_rate_hz / 2, sample_rate_hz / 2, len(psd_samples))
        freqs = freqs + center_freq_hz
        spectrum_db = 10 * np.log10(spectrum + 1e-12)
|
||||
|
||||
ax_psd.plot(freqs / 1e6, spectrum_db, color=COLORS["accent"], linewidth=1.0)
|
||||
ax_psd.set_xlabel("Frequency (MHz)")
|
||||
ax_psd.set_ylabel("Power (dB)")
|
||||
ax_psd.set_title("Power Spectral Density")
|
||||
ax_psd.grid(True, alpha=0.3)
|
||||
|
||||
if compact_mode:
|
||||
ax1.set_xticks([])
|
||||
ax1.set_yticks([])
|
||||
ax2.set_xticks([])
|
||||
ax2.set_yticks([])
|
||||
|
||||
plt.subplots_adjust(left=0, right=1, top=1, bottom=0, hspace=0)
|
||||
else:
|
||||
plt.tight_layout()
|
||||
if show_title:
|
||||
plt.subplots_adjust(top=0.92)
|
||||
|
||||
if saveplot:
|
||||
output_path, extension = set_path(output_path=output_path)
|
||||
dpi_value = _set_dpi(fast_mode=fast_mode, labels_mode=labels_mode, extension=extension)
|
||||
|
||||
plt.savefig(output_path, dpi=dpi_value, bbox_inches="tight", facecolor="#0f172a", edgecolor="none")
|
||||
print(f"Saved signal plot to {output_path}")
|
||||
return output_path
|
||||
|
||||
plt.show()
|
||||
return None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"setup_style",
|
||||
"detect_constellation_symbols",
|
||||
"view_simple_sig",
|
||||
]
|
||||
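The PSD branch above (Hann window, magnitude-squared FFT, shift, log scale) can be exercised on its own. Below is a minimal NumPy-only sketch of the same computation, using `np.hanning` and `np.fft` in place of the SciPy imports the module relies on; the helper name `psd_db` is illustrative, not part of the toolkit:

```python
import numpy as np


def psd_db(samples: np.ndarray, sample_rate_hz: float, center_freq_hz: float = 0.0):
    """Windowed periodogram in dB, mirroring the PSD branch above."""
    window = np.hanning(len(samples))
    spectrum = np.abs(np.fft.fftshift(np.fft.fft(samples * window))) ** 2
    freqs = np.linspace(-sample_rate_hz / 2, sample_rate_hz / 2, len(samples)) + center_freq_hz
    return freqs, 10 * np.log10(spectrum + 1e-12)


# A 100 kHz complex tone sampled at 1 MHz should peak near 100 kHz.
fs = 1e6
t = np.arange(4096) / fs
tone = np.exp(2j * np.pi * 100e3 * t)
freqs, p_db = psd_db(tone, fs)
peak_hz = freqs[np.argmax(p_db)]
```

With a complex baseband tone like this, the spectral peak lands on the tone frequency to within one FFT bin, which is a quick sanity check on the window/shift ordering.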
src/ria_toolkit_oss/viz/onnx.py (new file, 562 lines)
@@ -0,0 +1,562 @@
"""
|
||||
ONNX model visualization utilities.
|
||||
|
||||
This module provides visualization functions for ONNX models following the same pattern
|
||||
as other ria-toolkit-oss visualization modules.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import plotly.express as px
|
||||
import plotly.graph_objects as go
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
try:
|
||||
import onnx
|
||||
import onnx.helper
|
||||
import onnx.numpy_helper
|
||||
|
||||
ONNX_AVAILABLE = True
|
||||
except ImportError:
|
||||
ONNX_AVAILABLE = False
|
||||
|
||||
|
||||
def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure:
|
||||
"""Create a professional error figure with Qoherent dark theme styling."""
|
||||
fig = go.Figure()
|
||||
|
||||
# Create a clean, centered text display using Plotly's text formatting
|
||||
main_text = f"<b style='color:#f56565;font-size:18px'>⚠️ {title}</b><br><br>"
|
||||
main_text += f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>"
|
||||
|
||||
if suggestion:
|
||||
main_text += "<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
|
||||
main_text += f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>"
|
||||
|
||||
# Add the main text annotation
|
||||
fig.add_annotation(
|
||||
text=main_text,
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=0.5,
|
||||
y=0.5,
|
||||
xanchor="center",
|
||||
yanchor="middle",
|
||||
showarrow=False,
|
||||
align="center",
|
||||
borderwidth=2,
|
||||
bordercolor="#4a5568",
|
||||
bgcolor="#2d3748",
|
||||
font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
|
||||
)
|
||||
|
||||
# Update layout with dark theme
|
||||
fig.update_layout(
|
||||
title="",
|
||||
height=400,
|
||||
template="plotly_dark",
|
||||
margin=dict(l=40, r=40, t=40, b=40),
|
||||
plot_bgcolor="#1a202c",
|
||||
paper_bgcolor="#1a202c",
|
||||
font=dict(color="#e2e8f0"),
|
||||
)
|
||||
|
||||
# Remove axes and grid
|
||||
fig.update_xaxes(visible=False)
|
||||
fig.update_yaxes(visible=False)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def graph_structure(file_path: Path) -> go.Figure:
|
||||
"""
|
||||
Visualize the ONNX model graph structure showing nodes and connections.
|
||||
Matches layout ID: graph_structure
|
||||
"""
|
||||
if not ONNX_AVAILABLE:
|
||||
return create_styled_error_figure(
|
||||
"ONNX Not Available", "ONNX library is required for model analysis.", "Install with: pip install onnx"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load ONNX model
|
||||
model = onnx.load(str(file_path))
|
||||
graph = model.graph
|
||||
nodes = graph.node
|
||||
|
||||
if len(nodes) == 0:
|
||||
return create_styled_error_figure(
|
||||
"Empty Model", "This ONNX model contains no operators.", "Please check if the model file is valid."
|
||||
)
|
||||
|
||||
# Create network diagram data
|
||||
node_info = []
|
||||
for i, node in enumerate(nodes):
|
||||
node_info.append(
|
||||
{
|
||||
"id": i,
|
||||
"name": node.name or f"{node.op_type}_{i}",
|
||||
"op_type": node.op_type,
|
||||
"inputs": len(node.input),
|
||||
"outputs": len(node.output),
|
||||
}
|
||||
)
|
||||
|
||||
# Create visualization
|
||||
fig = go.Figure()
|
||||
|
||||
# Simple linear layout for now
|
||||
x_positions = list(range(len(node_info)))
|
||||
y_positions = [0] * len(node_info)
|
||||
|
||||
# Add nodes as scatter points
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=x_positions,
|
||||
y=y_positions,
|
||||
mode="markers+text",
|
||||
marker=dict(
|
||||
size=[min(max(info["inputs"] + info["outputs"] + 15, 20), 50) for info in node_info],
|
||||
color=px.colors.qualitative.Set3[: len(node_info)],
|
||||
opacity=0.8,
|
||||
line=dict(width=2, color="white"),
|
||||
),
|
||||
text=[f"{info['op_type']}" for info in node_info],
|
||||
textposition="middle center",
|
||||
textfont=dict(size=10, color="white"),
|
||||
hovertemplate="<b>%{text}</b><br>"
|
||||
+ "Name: %{customdata[0]}<br>"
|
||||
+ "Inputs: %{customdata[1]}<br>"
|
||||
+ "Outputs: %{customdata[2]}<br>"
|
||||
+ "<extra></extra>",
|
||||
customdata=[[info["name"], info["inputs"], info["outputs"]] for info in node_info],
|
||||
name="Operators",
|
||||
)
|
||||
)
|
||||
|
||||
# Add connecting lines
|
||||
for i in range(len(node_info) - 1):
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[x_positions[i], x_positions[i + 1]],
|
||||
y=[y_positions[i], y_positions[i + 1]],
|
||||
mode="lines",
|
||||
line=dict(color="gray", width=1, dash="dot"),
|
||||
showlegend=False,
|
||||
hoverinfo="skip",
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title={
|
||||
"text": (
|
||||
"ONNX Graph Structure<br>"
|
||||
f"<span style='font-size:14px; color:#a0a0a0;'>{len(nodes)} Operators</span>"
|
||||
),
|
||||
"x": 0.5,
|
||||
"xanchor": "center",
|
||||
"font": {"size": 22},
|
||||
},
|
||||
xaxis_title="Execution Order",
|
||||
yaxis_title="",
|
||||
showlegend=False,
|
||||
height=500,
|
||||
template="plotly_dark",
|
||||
yaxis=dict(showticklabels=False, showgrid=False),
|
||||
xaxis=dict(showgrid=False),
|
||||
margin=dict(l=50, r=50, t=80, b=50),
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
except Exception as e:
|
||||
return create_styled_error_figure(
|
||||
"Graph Analysis Error", "Could not analyze ONNX model structure.", f"Error: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def operator_analysis(file_path: Path) -> go.Figure:
|
||||
"""
|
||||
Analyze the distribution and types of operators in the ONNX model.
|
||||
Matches layout ID: operator_analysis
|
||||
"""
|
||||
if not ONNX_AVAILABLE:
|
||||
return create_styled_error_figure(
|
||||
"ONNX Not Available", "ONNX library is required for operator analysis.", "Install with: pip install onnx"
|
||||
)
|
||||
|
||||
try:
|
||||
model = onnx.load(str(file_path))
|
||||
graph = model.graph
|
||||
|
||||
# Count operators
|
||||
op_counts = {}
|
||||
for node in graph.node:
|
||||
op_type = node.op_type
|
||||
op_counts[op_type] = op_counts.get(op_type, 0) + 1
|
||||
|
||||
if not op_counts:
|
||||
return create_styled_error_figure(
|
||||
"No Operators",
|
||||
"This ONNX model contains no operators to analyze.",
|
||||
"Please verify the model file is valid.",
|
||||
)
|
||||
|
||||
# Sort by frequency
|
||||
sorted_ops = sorted(op_counts.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Create pie chart and bar chart
|
||||
fig = make_subplots(
|
||||
rows=2,
|
||||
cols=1,
|
||||
subplot_titles=("Operator Distribution", "Operator Frequency"),
|
||||
specs=[[{"type": "pie"}], [{"type": "bar"}]],
|
||||
)
|
||||
|
||||
# Pie chart for operator distribution
|
||||
op_names, op_values = zip(*sorted_ops) if sorted_ops else ([], [])
|
||||
|
||||
fig.add_trace(
|
||||
go.Pie(
|
||||
labels=list(op_names),
|
||||
values=list(op_values),
|
||||
textinfo="label+percent",
|
||||
textposition="auto",
|
||||
showlegend=False,
|
||||
),
|
||||
row=1,
|
||||
col=1,
|
||||
)
|
||||
|
||||
# Bar chart for frequency
|
||||
fig.add_trace(
|
||||
go.Bar(
|
||||
x=list(op_names),
|
||||
y=list(op_values),
|
||||
marker_color=px.colors.qualitative.Set3[: len(op_names)],
|
||||
showlegend=False,
|
||||
),
|
||||
row=2,
|
||||
col=1,
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title={
|
||||
"text": (
|
||||
"ONNX Operator Analysis<br>"
|
||||
f"<span style='font-size:14px; color:#a0a0a0;'>{len(op_counts)} Unique Types</span>"
|
||||
),
|
||||
"x": 0.5,
|
||||
"xanchor": "center",
|
||||
"font": {"size": 22},
|
||||
},
|
||||
height=700,
|
||||
template="plotly_dark",
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
except Exception as e:
|
||||
return create_styled_error_figure(
|
||||
"Operator Analysis Error", "Could not analyze ONNX operators.", f"Error: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def model_metadata(file_path: Path) -> go.Figure:
|
||||
"""
|
||||
Display comprehensive metadata about the ONNX model.
|
||||
Matches layout ID: model_metadata
|
||||
"""
|
||||
if not ONNX_AVAILABLE:
|
||||
return create_styled_error_figure(
|
||||
"ONNX Not Available", "ONNX library is required for metadata analysis.", "Install with: pip install onnx"
|
||||
)
|
||||
|
||||
try:
|
||||
model = onnx.load(str(file_path))
|
||||
graph = model.graph
|
||||
|
||||
# Calculate basic statistics
|
||||
total_nodes = len(graph.node)
|
||||
total_inputs = len(graph.input)
|
||||
total_outputs = len(graph.output)
|
||||
total_initializers = len(graph.initializer)
|
||||
|
||||
# Calculate parameter count
|
||||
total_params = 0
|
||||
for initializer in graph.initializer:
|
||||
try:
|
||||
tensor = onnx.numpy_helper.to_array(initializer)
|
||||
total_params += tensor.size
|
||||
except Exception:
|
||||
pass # Skip if tensor can't be loaded
|
||||
|
||||
# Get model file size
|
||||
file_size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
|
||||
# Create metadata display
|
||||
fig = make_subplots(
|
||||
rows=2,
|
||||
cols=2,
|
||||
subplot_titles=("Model Size", "Architecture", "Inputs/Outputs", "Parameters"),
|
||||
specs=[[{"type": "indicator"}, {"type": "bar"}], [{"type": "table"}, {"type": "indicator"}]],
|
||||
)
|
||||
|
||||
# Model size indicator
|
||||
fig.add_trace(
|
||||
go.Indicator(
|
||||
mode="number+gauge",
|
||||
value=file_size_mb,
|
||||
title={"text": "Model Size (MB)"},
|
||||
number={"suffix": " MB", "valueformat": ".2f"},
|
||||
gauge={
|
||||
"axis": {"range": [0, max(100, file_size_mb * 1.5)]},
|
||||
"bar": {"color": "darkblue"},
|
||||
"steps": [
|
||||
{"range": [0, 10], "color": "lightgreen"},
|
||||
{"range": [10, 50], "color": "yellow"},
|
||||
{"range": [50, 100], "color": "orange"},
|
||||
],
|
||||
},
|
||||
),
|
||||
row=1,
|
||||
col=1,
|
||||
)
|
||||
|
||||
# Architecture components
|
||||
arch_data = ["Nodes", "Inputs", "Outputs", "Initializers"]
|
||||
arch_values = [total_nodes, total_inputs, total_outputs, total_initializers]
|
||||
|
||||
fig.add_trace(
|
||||
go.Bar(x=arch_data, y=arch_values, marker_color=["blue", "green", "orange", "red"], showlegend=False),
|
||||
row=1,
|
||||
col=2,
|
||||
)
|
||||
|
||||
# I/O Table
|
||||
io_data = []
|
||||
|
||||
# Add input info
|
||||
for inp in graph.input[:5]: # Limit to first 5
|
||||
shape = "Unknown"
|
||||
dtype = "Unknown"
|
||||
if inp.type and inp.type.tensor_type:
|
||||
# Get shape
|
||||
if inp.type.tensor_type.shape:
|
||||
dims = [str(d.dim_value) if d.dim_value > 0 else "?" for d in inp.type.tensor_type.shape.dim]
|
||||
shape = f"[{', '.join(dims)}]"
|
||||
|
||||
# Get data type
|
||||
elem_type = inp.type.tensor_type.elem_type
|
||||
type_map = {
|
||||
1: "float32",
|
||||
2: "uint8",
|
||||
3: "int8",
|
||||
6: "int32",
|
||||
7: "int64",
|
||||
9: "bool",
|
||||
10: "float16",
|
||||
11: "double",
|
||||
}
|
||||
dtype = type_map.get(elem_type, f"type_{elem_type}")
|
||||
|
||||
io_data.append(["Input", inp.name[:20], shape, dtype])
|
||||
|
||||
# Add output info
|
||||
for out in graph.output[:5]: # Limit to first 5
|
||||
shape = "Unknown"
|
||||
dtype = "Unknown"
|
||||
if out.type and out.type.tensor_type:
|
||||
if out.type.tensor_type.shape:
|
||||
dims = [str(d.dim_value) if d.dim_value > 0 else "?" for d in out.type.tensor_type.shape.dim]
|
||||
shape = f"[{', '.join(dims)}]"
|
||||
|
||||
elem_type = out.type.tensor_type.elem_type
|
||||
type_map = {
|
||||
1: "float32",
|
||||
2: "uint8",
|
||||
3: "int8",
|
||||
6: "int32",
|
||||
7: "int64",
|
||||
9: "bool",
|
||||
10: "float16",
|
||||
11: "double",
|
||||
}
|
||||
dtype = type_map.get(elem_type, f"type_{elem_type}")
|
||||
|
||||
io_data.append(["Output", out.name[:20], shape, dtype])
|
||||
|
||||
if io_data:
|
||||
fig.add_trace(
|
||||
go.Table(
|
||||
header=dict(values=["Type", "Name", "Shape", "Data Type"], fill_color="lightblue", align="left"),
|
||||
cells=dict(values=list(zip(*io_data)), fill_color="white", align="left"),
|
||||
),
|
||||
row=2,
|
||||
col=1,
|
||||
)
|
||||
|
||||
# Parameters indicator
|
||||
fig.add_trace(
|
||||
go.Indicator(
|
||||
mode="number",
|
||||
value=total_params,
|
||||
title={"text": "Total Parameters"},
|
||||
number={"suffix": "M", "valueformat": ".2f"},
|
||||
number_font_size=30,
|
||||
),
|
||||
row=2,
|
||||
col=2,
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title={
|
||||
"text": (
|
||||
"ONNX Model Metadata<br>"
|
||||
f"<span style='font-size:14px; color:#a0a0a0;'>{total_params/1e6:.2f}M Parameters</span>"
|
||||
),
|
||||
"x": 0.5,
|
||||
"xanchor": "center",
|
||||
"font": {"size": 22},
|
||||
},
|
||||
height=600,
|
||||
template="plotly_dark",
|
||||
showlegend=False,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
except Exception as e:
|
||||
return create_styled_error_figure(
|
||||
"Metadata Analysis Error", "Could not extract ONNX model metadata.", f"Error: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def performance_metrics(file_path: Path) -> go.Figure:
|
||||
"""
|
||||
Display performance and computational metrics for the ONNX model.
|
||||
Matches layout ID: performance_metrics
|
||||
"""
|
||||
if not ONNX_AVAILABLE:
|
||||
return create_styled_error_figure(
|
||||
"ONNX Not Available",
|
||||
"ONNX library is required for performance analysis.",
|
||||
"Install with: pip install onnx",
|
||||
)
|
||||
|
||||
try:
|
||||
model = onnx.load(str(file_path))
|
||||
graph = model.graph
|
||||
|
||||
# Calculate metrics
|
||||
model_size_bytes = file_path.stat().st_size
|
||||
model_size_mb = model_size_bytes / (1024 * 1024)
|
||||
|
||||
# Count parameters
|
||||
total_params = 0
|
||||
for initializer in graph.initializer:
|
||||
try:
|
||||
tensor = onnx.numpy_helper.to_array(initializer)
|
||||
total_params += tensor.size
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Estimate memory usage (rough approximation)
|
||||
param_memory_mb = (total_params * 4) / (1024 * 1024) # Assume float32
|
||||
|
||||
# Count operations by complexity
|
||||
compute_ops = ["Conv", "MatMul", "Gemm", "LSTM", "GRU"]
|
||||
efficient_ops = ["Relu", "Add", "Mul", "BatchNormalization", "Dropout"]
|
||||
|
||||
compute_count = sum(1 for node in graph.node if any(op in node.op_type for op in compute_ops))
|
||||
efficient_count = sum(1 for node in graph.node if any(op in node.op_type for op in efficient_ops))
|
||||
total_ops = len(graph.node)
|
||||
other_count = total_ops - compute_count - efficient_count
|
||||
|
||||
# Create performance dashboard
|
||||
fig = make_subplots(
|
||||
rows=2,
|
||||
cols=2,
|
||||
subplot_titles=("Model Efficiency", "Memory Usage", "Operation Types", "Complexity Score"),
|
||||
specs=[[{"type": "bar"}, {"type": "bar"}], [{"type": "pie"}, {"type": "indicator"}]],
|
||||
)
|
||||
|
||||
# Model efficiency metrics
|
||||
efficiency_metrics = ["Model Size (MB)", "Parameters (M)", "Total Ops"]
|
||||
efficiency_values = [model_size_mb, total_params / 1e6, total_ops]
|
||||
|
||||
fig.add_trace(
|
||||
go.Bar(
|
||||
x=efficiency_metrics, y=efficiency_values, marker_color=["blue", "green", "orange"], showlegend=False
|
||||
),
|
||||
row=1,
|
||||
col=1,
|
||||
)
|
||||
|
||||
# Memory usage
|
||||
memory_types = ["Parameters", "Est. Inference"]
|
||||
memory_values = [param_memory_mb, param_memory_mb * 2] # Rough estimate
|
||||
|
||||
fig.add_trace(
|
||||
go.Bar(x=memory_types, y=memory_values, marker_color=["purple", "red"], showlegend=False),
|
||||
row=1,
|
||||
col=2,
|
||||
)
|
||||
|
||||
# Operation types pie chart
|
||||
fig.add_trace(
|
||||
go.Pie(
|
||||
labels=["Compute Ops", "Efficient Ops", "Other Ops"],
|
||||
values=[compute_count, efficient_count, other_count],
|
||||
marker_colors=["red", "green", "gray"],
|
||||
),
|
||||
row=2,
|
||||
col=1,
|
||||
)
|
||||
|
||||
# Complexity score (simple heuristic)
|
||||
complexity_score = min(100, (model_size_mb * 10 + total_params / 1e6 * 20 + compute_count))
|
||||
|
||||
fig.add_trace(
|
||||
go.Indicator(
|
||||
mode="gauge+number",
|
||||
value=complexity_score,
|
||||
title={"text": "Complexity Score"},
|
||||
gauge={
|
||||
"axis": {"range": [0, 100]},
|
||||
"bar": {
|
||||
"color": "darkred" if complexity_score > 70 else "orange" if complexity_score > 40 else "green"
|
||||
},
|
||||
"steps": [
|
||||
{"range": [0, 40], "color": "lightgreen"},
|
||||
{"range": [40, 70], "color": "yellow"},
|
||||
{"range": [70, 100], "color": "lightcoral"},
|
||||
],
|
||||
},
|
||||
),
|
||||
row=2,
|
||||
col=2,
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title={
|
||||
"text": (
|
||||
"ONNX Performance Metrics<br>"
|
||||
f"<span style='font-size:14px; color:#a0a0a0;'>"
|
||||
f"Complexity Score: {complexity_score:.0f}/100</span>"
|
||||
),
|
||||
"x": 0.5,
|
||||
"xanchor": "center",
|
||||
"font": {"size": 22},
|
||||
},
|
||||
height=600,
|
||||
template="plotly_dark",
|
||||
showlegend=False,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
except Exception as e:
|
||||
return create_styled_error_figure(
|
||||
"Performance Analysis Error", "Could not analyze ONNX model performance.", f"Error: {str(e)}"
|
||||
)
|
||||
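For reference, the `elem_type` integers decoded in `model_metadata` come from ONNX's `TensorProto.DataType` enum. The lookup can be exercised standalone without loading a model; the helper name `onnx_dtype_name` below is illustrative, not part of the module:

```python
# Subset of ONNX TensorProto.DataType codes, matching the table in model_metadata.
_ONNX_DTYPES = {
    1: "float32",
    2: "uint8",
    3: "int8",
    6: "int32",
    7: "int64",
    9: "bool",
    10: "float16",
    11: "double",
}


def onnx_dtype_name(elem_type: int) -> str:
    """Human-readable name for an ONNX element-type code, with a fallback label."""
    return _ONNX_DTYPES.get(elem_type, f"type_{elem_type}")
```

For example, `onnx_dtype_name(1)` returns `"float32"`, while an unmapped code such as `99` falls back to `"type_99"`, mirroring how unknown types are rendered in the I/O table.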
src/ria_toolkit_oss/viz/pytorch_state_dict.py (new file, 194 lines)
@@ -0,0 +1,194 @@
import numpy as np
import plotly.graph_objects as go
from plotly.graph_objects import Figure


def create_styled_error_figure(title: str, message: str, suggestion: str | None = None) -> go.Figure:
    """Create a professional error figure with Qoherent dark theme styling."""
    fig = go.Figure()

    # Create a clean, centered text display using Plotly's text formatting
    main_text = f"<b style='color:#f56565;font-size:18px'>⚠️ {title}</b><br><br>"
    main_text += f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>"

    if suggestion:
        main_text += "<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
        main_text += f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>"

    # Add the main text annotation
    fig.add_annotation(
        text=main_text,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        xanchor="center",
        yanchor="middle",
        showarrow=False,
        align="center",
        borderwidth=2,
        bordercolor="#4a5568",
        bgcolor="#2d3748",
        font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
    )

    # Update layout with dark theme
    fig.update_layout(
        title="",
        height=400,
        template="plotly_dark",
        margin=dict(l=40, r=40, t=40, b=40),
        plot_bgcolor="#1a202c",
        paper_bgcolor="#1a202c",
        font=dict(color="#e2e8f0"),
    )

    # Remove axes and grid
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)

    return fig


def model_summary_plot(state_dict: dict) -> Figure:
    """Generate a summary plot of the PyTorch model state dict."""
    if not state_dict:
        return create_styled_error_figure(
            "Empty State Dict",
            "No parameters found in state dict",
            "Ensure the model state dictionary contains weight parameters",
        )

    # Count parameters by layer type
    layer_info = []
    for key, tensor in state_dict.items():
        if "weight" in key:
            try:
                layer_name = key.replace(".weight", "")
                param_count = (
                    tensor.numel()
                    if hasattr(tensor, "numel")
                    else len(tensor.flatten()) if hasattr(tensor, "flatten") else 0
                )
                shape = (
                    list(tensor.shape)
                    if hasattr(tensor, "shape")
                    else [len(tensor)] if hasattr(tensor, "__len__") else []
                )
                layer_info.append({"layer": layer_name, "parameters": param_count, "shape": shape})
            except Exception as e:
                print(f"Warning: Could not process layer {key}: {e}")
                continue

    if not layer_info:
        return create_styled_error_figure(
            "No Weight Layers Found",
            "No weight layers found in state dict",
            "Ensure the state dictionary contains layers with '.weight' parameters",
        )

    # Create bar chart of parameter counts
    fig = go.Figure(
        data=[
            go.Bar(
                x=[info["layer"] for info in layer_info],
                y=[info["parameters"] for info in layer_info],
                text=[f"Shape: {info['shape']}" for info in layer_info],
                textposition="auto",
            )
        ]
    )
    fig.update_layout(
        title="Model Layer Parameter Counts",
        xaxis_title="Layer",
        yaxis_title="Number of Parameters",
        template="plotly_dark",
    )
    return fig


def layer_weights_plot(state_dict: dict, layer_name: str | None = None) -> Figure:
    """Visualize weights for a specific layer."""
    if not state_dict:
        return create_styled_error_figure(
            "Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data"
        )

    if layer_name is None:
        # Default to the first weight tensor
        weight_keys = [k for k in state_dict.keys() if "weight" in k]
        if not weight_keys:
            return create_styled_error_figure(
                "No Weight Tensors Found",
                "No weight tensors found in state dict",
                "Ensure the state dictionary contains layers with '.weight' parameters",
            )
        layer_name = weight_keys[0]

    try:
        weights = state_dict[layer_name]

        # Convert to numpy if it's a torch tensor
        if hasattr(weights, "numpy"):
            weights_np = weights.detach().numpy() if hasattr(weights, "detach") else weights.numpy()
        elif hasattr(weights, "cpu"):
            weights_np = weights.cpu().detach().numpy()
        else:
            weights_np = np.array(weights)

        # For 2D weights, create a heatmap
        if len(weights_np.shape) == 2:
            fig = go.Figure(data=go.Heatmap(z=weights_np, colorscale="RdBu", zmid=0))
            fig.update_layout(title=f"Weights Heatmap: {layer_name}", template="plotly_dark")
        else:
            # For other shapes, flatten and show a histogram
            flat_weights = weights_np.flatten()
            fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
            fig.update_layout(title=f"Weight Distribution: {layer_name}", template="plotly_dark")

        return fig

    except Exception as e:
        return create_styled_error_figure(
            "Layer Processing Error",
            f"Error processing layer {layer_name}: {str(e)}",
            "Check that the layer name exists and contains valid tensor data",
        )


def weight_distribution_plot(state_dict: dict) -> Figure:
    """Show the distribution of weights across all layers."""
    if not state_dict:
        return create_styled_error_figure(
            "Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data"
        )

    all_weights = []
    layer_names = []

    for key, tensor in state_dict.items():
        if "weight" in key:
            try:
                # Convert to numpy if it's a torch tensor
                if hasattr(tensor, "numpy"):
                    weights_np = tensor.detach().numpy() if hasattr(tensor, "detach") else tensor.numpy()
                elif hasattr(tensor, "cpu"):
                    weights_np = tensor.cpu().detach().numpy()
                else:
                    weights_np = np.array(tensor)
                flat_weights = weights_np.flatten()
                all_weights.extend(flat_weights)
                layer_names.extend([key] * len(flat_weights))
            except Exception as e:
                print(f"Warning: Could not process weights for layer {key}: {e}")
                continue

    if not all_weights:
        return create_styled_error_figure(
            "No Weight Data Found",
            "No weight data found in state dict",
            "Ensure the state dictionary contains layers with '.weight' parameters",
        )

    fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=100, name="All Weights")])

    fig.update_layout(
        title="Overall Weight Distribution",
        xaxis_title="Weight Value",
        yaxis_title="Frequency",
        template="plotly_dark",
    )
    return fig
src/ria_toolkit_oss/viz/radio_dataset.py (new file, 432 lines)
@@ -0,0 +1,432 @@
"""
|
||||
Simple, clean visualization utilities for RadioDataset analysis.
|
||||
"""
|
||||
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import plotly.express as px
|
||||
import plotly.graph_objects as go
|
||||
from plotly.graph_objects import Figure
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
|
||||
def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> Figure:
|
||||
"""Create a professional error figure with Qoherent dark theme styling."""
|
||||
fig = go.Figure()
|
||||
|
||||
# Create a clean, centered text display using Plotly's text formatting
|
||||
main_text = f"<b style='color:#f56565;font-size:18px'>⚠️ {title}</b><br><br>"
|
||||
main_text += f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>"
|
||||
|
||||
if suggestion:
|
||||
main_text += "<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
|
||||
main_text += f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>"
|
||||
|
||||
# Add the main text annotation
|
||||
fig.add_annotation(
|
||||
text=main_text,
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=0.5,
|
||||
y=0.5,
|
||||
xanchor="center",
|
||||
yanchor="middle",
|
||||
showarrow=False,
|
||||
align="center",
|
||||
borderwidth=2,
|
||||
bordercolor="#4a5568",
|
||||
bgcolor="#2d3748",
|
||||
font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
|
||||
)
|
||||
|
||||
# Update layout with dark theme
|
||||
fig.update_layout(
|
||||
title="",
|
||||
height=400,
|
||||
template="plotly_dark",
|
||||
margin=dict(l=40, r=40, t=40, b=40),
|
||||
plot_bgcolor="#1a202c",
|
||||
paper_bgcolor="#1a202c",
|
||||
font=dict(color="#e2e8f0"),
|
||||
)
|
||||
|
||||
# Remove axes and grid
|
||||
fig.update_xaxes(visible=False)
|
||||
fig.update_yaxes(visible=False)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def _check_dataset_compatibility(dataset, plot_type: str) -> tuple[bool, str]:
|
||||
"""Check if dataset is compatible with a specific plot type.
|
||||
Returns (is_compatible, error_message)
|
||||
"""
|
||||
try:
|
||||
metadata = dataset.metadata
|
||||
|
||||
if len(metadata) == 0:
|
||||
return False, "Dataset is empty"
|
||||
|
||||
if plot_type == "class_distribution":
|
||||
# Check if we have any categorical columns
|
||||
categorical_cols = [col for col in metadata.columns if metadata[col].dtype == "object"]
|
||||
alternatives = ["class", "label", "modulation", "impairment", "use_case", "category", "labels"]
|
||||
|
||||
has_class_col = any(alt in metadata.columns for alt in alternatives)
|
||||
has_categorical = len(categorical_cols) > 0
|
||||
|
||||
if not has_class_col and not has_categorical:
|
||||
return False, "No categorical columns found for class distribution"
|
||||
|
||||
elif plot_type == "sample_spectrogram":
|
||||
# Check if we can generate a valid spectrogram
|
||||
if len(metadata) < 1:
|
||||
return False, "No samples available for spectrogram"
|
||||
|
||||
# Check if we can access sample data (basic test)
|
||||
try:
|
||||
sample_data = dataset[0] if hasattr(dataset, "__getitem__") else None
|
||||
if sample_data is None or len(sample_data) < 32:
|
||||
return False, "Insufficient sample data for spectrogram (need at least 32 points)"
|
||||
except Exception:
|
||||
# If we can't access data, we'll rely on synthetic data generation
|
||||
pass
|
||||
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
return False, f"Dataset compatibility check failed: {str(e)}"
|
||||
|
||||
|
||||
def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
    """Generate a bar plot showing the distribution of examples across classes."""
    try:
        # Check dataset compatibility first
        is_compatible, error_msg = _check_dataset_compatibility(dataset, "class_distribution")
        if not is_compatible:
            return create_styled_error_figure(
                "Dataset Not Compatible",
                "This dataset doesn't have categorical labels needed for class distribution analysis.",
                "Try using the Dataset Overview widget to explore the available data columns.",
            )

        metadata = dataset.metadata

        # Find the class column
        if class_key not in metadata.columns:
            # Try common alternatives
            alternatives = ["class", "label", "modulation", "impairment", "use_case", "category", "labels"]
            for alt in alternatives:
                if alt in metadata.columns:
                    class_key = alt
                    break
            else:
                # Use the first categorical column
                for col in metadata.columns:
                    if metadata[col].dtype == "object" or metadata[col].nunique() < 50:
                        class_key = col
                        break

        if class_key not in metadata.columns:
            return create_styled_error_figure(
                "No Class Labels Found",
                "This dataset contains numerical data without categorical labels.",
                (
                    "Try using the Dataset Overview widget for data analysis, "
                    "or check if your dataset has hidden categorical columns."
                ),
            )

        # Count examples per class (limit to top 20 for performance)
        class_counts = metadata[class_key].value_counts()
        if len(class_counts) > 20:
            class_counts = class_counts.head(20)

        class_counts = class_counts.sort_index()

        # Create simple bar plot
        fig = px.bar(x=class_counts.index, y=class_counts.values, title=f"Class Distribution: {class_key.title()}")

        fig.update_traces(texttemplate="%{y}", textposition="outside")
        fig.update_layout(
            xaxis_title=class_key.title(),
            yaxis_title="Number of Examples",
            showlegend=False,
            height=400,
            template="plotly_dark",
        )

        return fig

    except Exception as e:
        return create_styled_error_figure(
            "Class Distribution Error",
            "An error occurred while generating the class distribution plot.",
            f"Technical details: {str(e)}",
        )


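The counting step in `class_distribution_plot` reduces to a `value_counts` over one metadata column, capped at 20 classes and sorted by label. A minimal sketch of that step without the plotly rendering, using a hypothetical stand-in for the dataset object (anything exposing a pandas `DataFrame` as `.metadata` works):

```python
import pandas as pd


class StubDataset:
    """Hypothetical stand-in: only the `.metadata` attribute is assumed."""

    def __init__(self, metadata):
        self.metadata = metadata


ds = StubDataset(pd.DataFrame({"modulation": ["bpsk"] * 3 + ["qpsk"] * 2}))

# Mirror of the counting step: top-20 cap, then sort by class label
counts = ds.metadata["modulation"].value_counts().head(20).sort_index()
print(counts.to_dict())
```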
def dataset_overview_plot(dataset) -> Figure:
    """Generate an overview plot with key dataset statistics."""
    try:
        metadata = dataset.metadata
        total_examples = len(metadata)

        # Create subplot with multiple charts

        # Determine subplot titles based on data type
        categorical_cols = [col for col in metadata.columns if metadata[col].dtype == "object"]
        numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ["int64", "float64"]]

        dist_title = "Value Distribution" if categorical_cols else "Data Distribution"

        fig = make_subplots(
            rows=2,
            cols=2,
            subplot_titles=("Dataset Size", "Data Types", dist_title, "Statistics Summary"),
            specs=[
                [{"type": "indicator"}, {"type": "bar"}],
                [{"type": "histogram" if not categorical_cols else "bar"}, {"type": "table"}],
            ],
        )

        # Top left: Dataset size indicator
        fig.add_trace(
            go.Indicator(
                mode="number", value=total_examples, title={"text": "Total Examples"}, number={"font": {"size": 40}}
            ),
            row=1,
            col=1,
        )

        # Top right: Data types distribution
        dtype_counts = metadata.dtypes.value_counts()
        fig.add_trace(
            go.Bar(
                x=[str(dt) for dt in dtype_counts.index], y=dtype_counts.values, name="Data Types", showlegend=False
            ),
            row=1,
            col=2,
        )

        # Bottom left: distribution of the first categorical column,
        # or a histogram of the first numeric column if none exists
        if categorical_cols:
            col = categorical_cols[0]  # Show first categorical column
            value_counts = metadata[col].value_counts().head(10)
            fig.add_trace(
                go.Bar(x=value_counts.index, y=value_counts.values, name=f"{col} Distribution", showlegend=False),
                row=2,
                col=1,
            )
        elif numeric_cols:
            # Show histogram of first numeric column
            col = numeric_cols[0]
            fig.add_trace(
                go.Histogram(x=metadata[col], name=f"{col} Distribution", showlegend=False, nbinsx=20), row=2, col=1
            )

        # Bottom right: Basic statistics table
        stats_data = []
        display_cols = numeric_cols[:5] if len(numeric_cols) > 0 else metadata.columns[:5]

        for col in display_cols:
            if metadata[col].dtype in ["int64", "float64"]:
                stats_data.append(
                    [
                        col[:15] + "..." if len(col) > 15 else col,  # Truncate long column names
                        f"{metadata[col].mean():.3f}",
                        f"{metadata[col].std():.3f}",
                        f"{metadata[col].min():.3f}",
                        f"{metadata[col].max():.3f}",
                    ]
                )
            else:
                unique_count = metadata[col].nunique()
                stats_data.append(
                    [col[:15] + "..." if len(col) > 15 else col, "N/A", "N/A", f"{unique_count} unique", "N/A"]
                )

        if stats_data:
            fig.add_trace(
                go.Table(
                    header=dict(
                        values=["Column", "Mean", "Std", "Min/Unique", "Max"],
                        fill_color="rgba(30, 30, 30, 0.8)",
                        align="center",
                        font=dict(color="white", size=12),
                    ),
                    cells=dict(
                        values=list(zip(*stats_data)),
                        fill_color="rgba(50, 50, 50, 0.6)",
                        align="center",
                        font=dict(color="white", size=11),
                    ),
                ),
                row=2,
                col=2,
            )

        # Create informative title
        total_cols = len(metadata.columns)
        title = f"Dataset Overview - {total_examples} samples, {total_cols} columns"
        if total_cols > 5:
            title += " (showing first 5)"

        fig.update_layout(title=title, height=600, showlegend=False, template="plotly_dark")

        return fig

    except Exception as e:
        return create_styled_error_figure(
            "Dataset Overview Error",
            "An error occurred while generating the dataset overview.",
            f"Technical details: {str(e)}",
        )


def _find_class_column(metadata, class_key: str) -> str:
    """Find the appropriate class column in metadata."""
    if class_key in metadata.columns:
        return class_key

    alternatives = ["class", "label", "modulation", "impairment", "use_case"]
    for alt in alternatives:
        if alt in metadata.columns:
            return alt
    return class_key


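A quick check of the fallback order above: an exact match on `class_key` wins, otherwise the first alternative present in the columns is used. The helper is mirrored inline so the snippet runs standalone; the column names are assumptions for illustration:

```python
import pandas as pd


def find_class_column(metadata, class_key):
    # Inline mirror of _find_class_column above
    if class_key in metadata.columns:
        return class_key
    for alt in ["class", "label", "modulation", "impairment", "use_case"]:
        if alt in metadata.columns:
            return alt
    return class_key


md = pd.DataFrame({"label": ["bpsk", "qpsk"], "snr_db": [10, 20]})
print(find_class_column(md, "modulation"))  # falls back to "label"
print(find_class_column(md, "snr_db"))      # exact match wins
```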
def _get_sample_data(dataset, sample_idx: int):
    """Get sample data from dataset, with synthetic fallback."""
    try:
        return dataset[sample_idx]
    except Exception:
        # Generate synthetic signal based on class
        n_samples = 1024
        t = np.linspace(0, 1, n_samples)
        freq = 0.1 + 0.05 * (sample_idx % 5)  # Vary frequency by sample
        sample_data = np.exp(1j * 2 * np.pi * freq * t)
        # Add some noise
        sample_data += 0.1 * (np.random.randn(n_samples) + 1j * np.random.randn(n_samples))
        return sample_data


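A sanity check of the synthetic fallback: it should always yield a 1024-point complex array whose tone frequency cycles through five values as the index varies. The index used here is an arbitrary assumption, and the `(sample_idx % 5)` grouping is written out explicitly:

```python
import numpy as np

# Reproduce the synthetic-fallback signal for a hypothetical index
n_samples = 1024
sample_idx = 7
t = np.linspace(0, 1, n_samples)
freq = 0.1 + 0.05 * (sample_idx % 5)  # 7 % 5 == 2, so freq == 0.2
x = np.exp(1j * 2 * np.pi * freq * t)
x += 0.1 * (np.random.randn(n_samples) + 1j * np.random.randn(n_samples))
print(x.shape, np.iscomplexobj(x))
```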
def _calculate_spectrogram_params(n_samples: int) -> tuple[int, int, int, int]:
    """Calculate spectrogram parameters based on sample length."""
    if n_samples < 32:
        raise ValueError(f"Insufficient data: need at least 32 samples, got {n_samples}")

    nperseg = min(256, max(32, n_samples // 4))
    hop_length = max(1, nperseg // 2)

    # Adjust for very short signals
    if n_samples < nperseg:
        nperseg = n_samples
        hop_length = 1

    n_frames = max(1, (n_samples - nperseg) // hop_length + 1)
    freq_bins = max(1, nperseg // 2)

    return nperseg, hop_length, n_frames, freq_bins


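Working the heuristic above through by hand for a 1024-sample capture (an assumed, typical length — any value of at least 32 follows the same arithmetic):

```python
n_samples = 1024
nperseg = min(256, max(32, n_samples // 4))                 # 256
hop_length = max(1, nperseg // 2)                           # 128
# n_samples >= nperseg, so the short-signal adjustment does not apply
n_frames = max(1, (n_samples - nperseg) // hop_length + 1)  # (768 // 128) + 1 == 7
freq_bins = max(1, nperseg // 2)                            # 128
print(nperseg, hop_length, n_frames, freq_bins)
```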
def _compute_spectrogram(sample_data, nperseg: int, hop_length: int, n_frames: int, freq_bins: int):
    """Compute spectrogram using FFT."""
    n_samples = len(sample_data)
    Sxx = np.zeros((freq_bins, n_frames))

    for i in range(n_frames):
        start_idx = i * hop_length
        end_idx = min(start_idx + nperseg, n_samples)

        if end_idx > start_idx:
            windowed = sample_data[start_idx:end_idx]

            # Pad if necessary to maintain nperseg size
            if len(windowed) < nperseg:
                windowed = np.pad(windowed, (0, nperseg - len(windowed)), mode="constant")

            fft_result = np.fft.fft(windowed)
            Sxx[:, i] = np.abs(fft_result[:freq_bins]) ** 2

    return Sxx


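The loop above is a plain short-time FFT: bin k of each frame corresponds to normalized frequency k / nperseg. A self-contained sanity check with the windowing mirrored inline, driving it with a pure complex tone (all sizes and the tone frequency here are assumptions for illustration):

```python
import numpy as np

n, nperseg, hop = 512, 128, 64
f0 = 0.125  # normalized tone frequency, cycles per sample
x = np.exp(2j * np.pi * f0 * np.arange(n))

freq_bins = nperseg // 2
n_frames = (n - nperseg) // hop + 1
Sxx = np.zeros((freq_bins, n_frames))
for i in range(n_frames):
    seg = x[i * hop:i * hop + nperseg]
    Sxx[:, i] = np.abs(np.fft.fft(seg)[:freq_bins]) ** 2

# An on-bin tone should land in bin f0 * nperseg == 16 in every frame
print(Sxx.shape, int(Sxx[:, 0].argmax()))
```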
def _create_spectrogram_figure(
    Sxx,
    n_frames: int,
    hop_length: int,
    n_samples: int,
    freq_bins: int,
    sample_idx: int,
    class_key: str,
    sample_metadata,
) -> Figure:
    """Create the plotly figure for the spectrogram."""
    # Convert to dB
    Sxx_db = 10 * np.log10(Sxx + 1e-10)

    # Create time and frequency vectors
    t = np.arange(n_frames) * hop_length / max(1, n_samples)
    f = np.linspace(0, 0.5, freq_bins)

    # Create plot
    fig = go.Figure(data=go.Heatmap(z=Sxx_db, x=t, y=f, colorscale="viridis", colorbar=dict(title="Power (dB)")))

    # Add title with metadata
    title = f"Sample Spectrogram (Index: {sample_idx})"
    if class_key in sample_metadata:
        title += f" - {class_key}: {sample_metadata[class_key]}"

    fig.update_layout(title=title, xaxis_title="Time", yaxis_title="Frequency", height=400, template="plotly_dark")
    return fig


def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx: Optional[int] = None) -> Figure:
    """Generate a spectrogram plot from a sample in the dataset."""
    try:
        # Check dataset compatibility first
        is_compatible, error_msg = _check_dataset_compatibility(dataset, "sample_spectrogram")
        if not is_compatible:
            return create_styled_error_figure(
                "Spectrogram Not Available",
                "This dataset doesn't have sufficient signal data for spectrogram visualization.",
                "Ensure your dataset contains complex-valued signal samples with at least 32 data points per sample.",
            )

        metadata = dataset.metadata
        if len(metadata) == 0:
            raise ValueError("Dataset is empty")

        # Find class column and select sample
        class_key = _find_class_column(metadata, class_key)
        if sample_idx is None:
            sample_idx = random.randint(0, len(metadata) - 1)
        sample_metadata = metadata.iloc[sample_idx]

        # Get sample data and ensure it's complex
        sample_data = _get_sample_data(dataset, sample_idx)
        if not np.iscomplexobj(sample_data):
            sample_data = sample_data.astype(complex)

        # Calculate spectrogram parameters and compute spectrogram
        n_samples = len(sample_data)
        nperseg, hop_length, n_frames, freq_bins = _calculate_spectrogram_params(n_samples)
        Sxx = _compute_spectrogram(sample_data, nperseg, hop_length, n_frames, freq_bins)

        # Create and return the figure
        return _create_spectrogram_figure(
            Sxx, n_frames, hop_length, n_samples, freq_bins, sample_idx, class_key, sample_metadata
        )

    except Exception as e:
        return create_styled_error_figure(
            "Spectrogram Error",
            "An error occurred while generating the spectrogram plot.",
            f"Technical details: {str(e)}",
        )

@@ -28,7 +28,7 @@ def test_npy_save_1(tmp_path):

     # Save to tmp_path
     filename = tmp_path / "test"
-    to_npy(filename=filename.name, path=tmp_path, recording=recording1)
+    to_npy(filename=filename.name, path=tmp_path, recording=recording1, overwrite=True)

     # Reload
     recording2 = from_npy(filename)

@@ -44,7 +44,7 @@ def test_npy_save_2(tmp_path):

     # Save to tmp_path
     filename = tmp_path / "test"
-    to_npy(filename=filename.name, path=tmp_path, recording=recording1)
+    to_npy(filename=filename.name, path=tmp_path, recording=recording1, overwrite=True)

     # Reload
     recording2 = from_npy(filename)

@@ -63,7 +63,7 @@ def test_npy_save_3(tmp_path):

     # Save to tmp_path
     filename = tmp_path / "test"
-    to_npy(filename=filename.name, path=tmp_path, recording=recording1)
+    to_npy(filename=filename.name, path=tmp_path, recording=recording1, overwrite=True)

     # Reload
     recording2 = from_npy(filename)

@@ -73,6 +73,15 @@ def test_npy_save_3(tmp_path):
     assert recording1.metadata == recording2.metadata


+def test_npy_save_4(tmp_path):
+    recording1 = Recording(data=nd_complex_data_1)
+    try:
+        filename = tmp_path / "test"
+        to_npy(filename=filename.name, path=tmp_path, recording=recording1)
+    except IOError as e:
+        assert str(e) == "File already exists"
+
+
 def test_npy_annotations(tmp_path):
     # Create annotations
     annotation1 = Annotation(sample_start=0, sample_count=100, freq_lower_edge=0, freq_upper_edge=100)

@@ -84,7 +93,7 @@ def test_npy_annotations(tmp_path):

     # Save to tmp_path
     filename = tmp_path / "test"
-    to_npy(filename=filename.name, path=tmp_path, recording=recording1)
+    to_npy(filename=filename.name, path=tmp_path, recording=recording1, overwrite=True)

     # Reload
     recording2 = from_npy(filename)

@@ -104,7 +113,7 @@ def test_load_recording_npy(tmp_path):

     # Save to tmp_path
     filename = tmp_path / "test.npy"
-    recording1.to_npy(path=tmp_path, filename=filename.name)
+    recording1.to_npy(path=tmp_path, filename=filename.name, overwrite=True)

     # Load from tmp_path
     recording2 = load_rec(filename)

@@ -130,7 +139,7 @@ def test_sigmf_1(tmp_path):

     # Save to tmp_path in SigMF format
     filename = tmp_path / "test"
-    to_sigmf(recording=recording1, path=tmp_path, filename=filename.name)
+    to_sigmf(recording=recording1, path=tmp_path, filename=filename.name, overwrite=True)

     # Reload
     recording2 = from_sigmf(filename)

@@ -158,7 +167,7 @@ def test_sigmf_2(tmp_path):

     # Save to tmp_path using the base name
     filename = tmp_path / "test"
-    to_sigmf(recording=recording1, path=tmp_path, filename=filename.name)
+    to_sigmf(recording=recording1, path=tmp_path, filename=filename.name, overwrite=True)

     # Load from tmp_path; from_sigmf expects the base name
     recording2 = from_sigmf(filename)

@@ -171,3 +180,12 @@ def test_sigmf_2(tmp_path):
     )

     assert np.array_equal(recording1.data, recording2.data)
+
+
+def test_sigmf_3(tmp_path):
+    recording1 = Recording(data=complex_data_1, metadata=sample_metadata)
+    try:
+        filename = tmp_path / "test"
+        to_sigmf(recording=recording1, path=tmp_path, filename=filename.name)
+    except IOError as e:
+        assert str(e) == "File already exists"