ria-toolkit-oss/src/ria_toolkit_oss/sdr/sdr.py
M madrigal 8a66860d33
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 15m51s
Build Project / Build Project (3.10) (pull_request) Successful in 16m14s
Build Project / Build Project (3.11) (pull_request) Successful in 17m9s
Build Project / Build Project (3.12) (pull_request) Successful in 2m29s
Test with tox / Test with tox (3.12) (pull_request) Successful in 21m28s
Test with tox / Test with tox (3.10) (pull_request) Successful in 22m50s
Test with tox / Test with tox (3.11) (pull_request) Successful in 23m18s
Moved all contents of to , refactored accordingly
2026-04-21 14:38:06 -04:00

612 lines
21 KiB
Python
Raw RIA Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
import pickle
import threading
import warnings
from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
import zmq
from ria_toolkit_oss.data.recording import Recording
class SDR(ABC):
    """
    This class defines a common interface (a template) for all SDR devices.

    Each specific SDR implementation should subclass SDR and provide concrete implementations
    for the abstract methods (``close``, ``init_rx``, ``init_tx``, ``_stream_rx``, ``_stream_tx``
    and ``set_clock_source``). On top of those primitives this base class provides block capture
    (:meth:`record`, :meth:`rx`), ZMQ streaming (:meth:`stream_to_zmq`,
    :meth:`pickle_buffer_to_zmq`) and looped transmission (:meth:`tx_recording`).

    Contract assumed by the helpers below (inferred from how they use the abstract methods --
    confirm against existing drivers):

    * ``_stream_rx(callback)`` calls ``callback(buffer)`` (optionally with ``metadata=``) for each
      received buffer and returns once :meth:`pause_rx` / :meth:`stop` clears ``_enable_rx``.
    * ``_stream_tx(callback)`` calls ``callback(num_samples)`` and transmits the returned samples
      until :meth:`pause_tx` / :meth:`stop` clears ``_enable_tx``.
    * :meth:`record` reads ``rx_buffer_size`` and :meth:`pickle_buffer_to_zmq` calls
      ``set_rx_buffer_size()``; neither is defined on this base class.

    To add support for a new radio, subclass this interface and implement all abstract methods.
    If you experience difficulties, please `contact us <mailto:info@qoherent.ai>`_, we are happy to
    provide additional direction and/or help with the implementation details.
    """

    def __init__(self):
        # Set by init_rx()/init_tx() implementations; record()/tx_recording() refuse to run without them.
        self._rx_initialized = False
        self._tx_initialized = False
        # Streaming enable flags, cleared by pause_rx()/pause_tx()/stop().
        self._enable_rx = False
        self._enable_tx = False

        # Block-capture / streaming bookkeeping used by the callbacks below.
        self._accumulated_buffer = None
        self._max_num_buffers = None
        self._num_buffers_processed = 0
        self._num_samples_to_record = None  # checked by _accumulate_buffers_callback
        self._write_position = 0
        self._last_buffer = None
        self._corrupted_buffer_count = 0

        # Current radio settings; None until a subclass fills them in.
        self.rx_sample_rate = None
        self.rx_center_frequency = None
        self.rx_gain = None
        self.tx_sample_rate = None
        self.tx_center_frequency = None
        self.tx_gain = None
        self._param_lock = threading.RLock()  # Reentrant lock

        # Pending config consumed by rx() on first call and by _apply_sdr_config
        # in the agent inference loop. Subclasses that need different defaults
        # (e.g. MockSDR) can overwrite these in their own __init__.
        self.center_freq: float = 2.4e9
        self.sample_rate: float = 10e6
        self.gain: float = 40.0

    def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
        """
        Create a radio recording of a given length. Either ``num_samples`` or ``rx_time`` must be provided.

        Note that ``init_rx()`` must be called before ``record()``. The subclass must expose
        ``rx_buffer_size``, which bounds how many buffers are captured.

        :param num_samples: The number of samples to record.
        :type num_samples: int, optional
        :param rx_time: The time to record, in seconds (converted using ``rx_sample_rate``).
        :type rx_time: int or float, optional
        :return: The Recording object; its data has shape (channels, samples).
        :rtype: Recording
        :raises RuntimeError: If RX was not initialized, or no valid buffer was received.
        :raises ValueError: If both or neither of ``num_samples`` / ``rx_time`` are given,
            or the requested length is negative.
        """
        if not self._rx_initialized:
            raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")
        if num_samples is not None and rx_time is not None:
            raise ValueError("Only input one of num_samples or rx_time")
        elif num_samples is not None:
            requested = int(num_samples)
        elif rx_time is not None:
            requested = int(rx_time * self.rx_sample_rate)
        else:
            raise ValueError("Must provide input of one of num_samples or rx_time")
        if requested < 0:
            raise ValueError(f"Number of samples to record must be non-negative, got {requested}")
        self._num_samples_to_record = requested

        self.buffer_size = self.rx_buffer_size
        # Safety bound on buffers: the callback normally stops as soon as the recording is full.
        self._max_num_buffers = requested // self.buffer_size + 1
        self._num_buffers_processed = 0
        self._last_buffer = None
        self._accumulated_buffer = None
        self._write_position = 0

        try:
            print("Starting stream")
            self._stream_rx(
                callback=self._accumulate_buffers_callback,
            )
            print("Finished stream")

            if self._accumulated_buffer is None:
                raise RuntimeError("No valid RX buffers were received; cannot create a recording")
            written = self._write_position
            if written < requested:
                warnings.warn(f"Only {written} of {requested} requested samples were captured")

            metadata = {
                "source": self.__class__.__name__,
                "sample_rate": self.rx_sample_rate,
                "center_frequency": self.rx_center_frequency,
                "gain": self.rx_gain,
            }
            print("Creating recording")
            # Return only samples that were actually written: the buffer was allocated with
            # np.empty, so anything past the write position is uninitialized memory.
            recording = Recording(data=self._accumulated_buffer[:, :written], metadata=metadata)
        finally:
            # Reset so record() can be called again, even after a failed capture.
            self._accumulated_buffer = None
            self._num_buffers_processed = 0
        return recording

    def rx(self, num_samples: int) -> "np.ndarray":
        """Return *num_samples* complex IQ samples from channel 0.

        This is the interface used by the agent inference loop. On first call,
        ``init_rx()`` is invoked automatically using the values stored in
        ``center_freq``, ``sample_rate``, and ``gain`` (set beforehand by
        ``_apply_sdr_config``). Subsequent calls stream directly.

        Subclasses may override this for hardware-native capture APIs (e.g.
        ``MockSDR`` uses AWGN generation; ``PlutoSDR`` could use
        ``self.radio.rx()``).

        :param num_samples: Number of samples to capture.
        :type num_samples: int
        :return: Channel-0 samples of the captured recording (dtype as delivered by the driver).
        :rtype: np.ndarray
        """
        if not self._rx_initialized:
            # Fall back to 40 dB when ``gain`` is not a plain number (presumably a mode string).
            gain = self.gain if isinstance(self.gain, (int, float)) else 40.0
            self.init_rx(
                sample_rate=self.sample_rate,
                center_frequency=self.center_freq,
                gain=gain,
                channel=0,
            )
        recording = self.record(num_samples=num_samples)
        # Recording.data is either a list of 1-D arrays (one per channel) or a
        # 2-D ndarray (channels x samples). Either way, index 0 is channel 0.
        data = recording.data
        return data[0] if hasattr(data, "__getitem__") else data

    def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
        """
        Stream IQ samples as raw bytes via a ZMQ PUB socket.

        Each received buffer is published as one message (``np.asarray(buffer).tobytes()``).
        Streaming pauses after ``ceil(n_samples / buffer_size)`` buffers, or runs until stopped
        when ``n_samples`` is ``np.inf``. The socket and context are always released on return.

        :param zmq_address: A full ZMQ endpoint, or a bare port number (bound as ``tcp://*:<port>``).
        :type zmq_address: str or int
        :param n_samples: The number of samples to stream; ``np.inf`` streams until stopped.
        :type n_samples: int or float
        :param buffer_size: Samples per buffer, used only to convert ``n_samples`` into a buffer
            count (it is not applied to the radio here). Defaults to 10000; ``None`` means 10000.
        :type buffer_size: int, optional
        :raises ValueError: If ``buffer_size`` is not positive.
        """
        if buffer_size is None:
            buffer_size = 10000
        if buffer_size <= 0:
            raise ValueError(f"buffer_size must be positive, got {buffer_size}")
        self._previous_buffer = None
        self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size)
        self._num_buffers_processed = 0
        self._publish_rx_to_zmq(zmq_address, self._zmq_bytestream_callback)

    def _publish_rx_to_zmq(self, zmq_address, callback):
        """Bind a ZMQ PUB socket at ``zmq_address``, run ``_stream_rx(callback)``, always release socket/context."""
        self.zmq_address = _generate_full_zmq_address(str(zmq_address))
        self.context = zmq.Context()
        try:
            self.socket = self.context.socket(zmq.PUB)
            try:
                self.socket.bind(self.zmq_address)
                self._stream_rx(callback)
            finally:
                self.socket.close()
        finally:
            self.context.destroy()

    def _accumulate_buffers_callback(self, buffer, metadata=None):
        """
        RX callback used by :meth:`record`: copy ``buffer`` into the preallocated recording.

        Buffers rejected by :meth:`_validate_buffer` are counted in ``_corrupted_buffer_count``
        and skipped. A 1-D buffer is treated as a single channel. The stream is stopped as soon
        as the recording is full, or after ``_max_num_buffers`` valid buffers as a safety bound.

        :param buffer: Samples for one buffer, 1-D ``(N,)`` or 2-D ``(channels, N)``.
        :param metadata: Optional driver metadata; stored as ``received_metadata`` on the first buffer.
        :raises ValueError: If :meth:`record` did not set up the buffer / sample counts.
        """
        if not self._validate_buffer(buffer):
            print("Warning: Corrupted buffer detected, skipping")
            self._corrupted_buffer_count += 1
            return

        buffer = np.asarray(buffer)
        if buffer.ndim == 1:
            buffer = buffer[np.newaxis, :]  # single channel -> shape (1, N)

        # First valid buffer: allocate the full recording once.
        if self._accumulated_buffer is None:
            if self._max_num_buffers is None:
                raise ValueError("Number of buffers for block capture not set.")
            if self._num_samples_to_record is None:
                raise ValueError("Number of samples not set before RX start.")
            if metadata is not None:
                self.received_metadata = metadata
            # np.empty (not zeros) for speed; record() only returns the written prefix.
            num_channels = buffer.shape[0]
            self._accumulated_buffer = np.empty((num_channels, self._num_samples_to_record), dtype=buffer.dtype)
            self._write_position = 0
            print(f"Pre-allocated buffer for {self._num_samples_to_record:,} samples.")

        start = self._write_position
        end = min(start + buffer.shape[1], self._num_samples_to_record)
        if end > start:
            self._accumulated_buffer[:, start:end] = buffer[:, : end - start]
            self._write_position = end

        self._num_buffers_processed += 1
        recording_full = self._write_position >= self._num_samples_to_record
        if recording_full or self._num_buffers_processed >= self._max_num_buffers:
            self.stop()

    def _validate_buffer(self, buffer):
        """
        Heuristically reject obviously corrupt buffers.

        The buffer is flattened first so 2-D ``(channels, N)`` buffers are compared sample by
        sample (comparing against ``buffer[0]`` would compare whole rows).

        :param buffer: Samples for one buffer (any array-like).
        :return: False if the buffer is empty, all zeros, or one repeated value; True otherwise.
        :rtype: bool
        """
        samples = np.asarray(buffer).ravel()
        if samples.size == 0:
            return False
        # All-zero buffers are a special case of "every sample equals the first one".
        return not bool(np.all(samples == samples[0]))

    def _zmq_bytestream_callback(self, buffer, metadata=None):
        """RX callback for :meth:`stream_to_zmq`: publish raw bytes and pause RX after the buffer budget."""
        self.socket.send(np.asarray(buffer).tobytes())
        self._num_buffers_processed += 1
        if self._max_num_buffers is not None and self._num_buffers_processed >= self._max_num_buffers:
            self.pause_rx()

    def pickle_buffer_to_zmq(self, zmq_address, buffer_size, num_buffers):
        """
        Stream samples to a ZMQ address, one pickled buffer per message (Python ``pickle``).

        Useful for inference applications with a known input size.
        May reduce transfer rates, but individual buffers will not have discontinuities.
        Subscribers unpickle each message, so they must only connect to trusted publishers.
        The socket and context are always released on return.

        :param zmq_address: The ZMQ endpoint, or a bare port number (bound as ``tcp://*:<port>``).
        :type zmq_address: str or int
        :param buffer_size: The number of iq samples in a buffer (passed to ``set_rx_buffer_size``).
        :type buffer_size: int
        :param num_buffers: The number of buffers to stream before stopping.
        :type num_buffers: int
        """
        self._max_num_buffers = num_buffers
        self.buffer_size = buffer_size
        self._num_buffers_processed = 0
        self._last_buffer = None  # don't compare against a buffer from a previous session
        self.set_rx_buffer_size(buffer_size)
        self._publish_rx_to_zmq(zmq_address, self._zmq_pickle_buffer_callback)

    def _zmq_pickle_buffer_callback(self, buffer, metadata=None):
        """RX callback for :meth:`pickle_buffer_to_zmq`: publish a pickled buffer, stop after the budget."""
        self.socket.send(pickle.dumps(buffer))
        self._num_buffers_processed += 1
        if self._max_num_buffers is not None and self._num_buffers_processed >= self._max_num_buffers:
            self.stop()
        # Two identical consecutive buffers are reported as an overflow (driver re-delivered stale data).
        if self._last_buffer is not None and np.array_equal(buffer, self._last_buffer):
            print("\033[93mWarning: Buffer Overflow Detected\033[0m")
        self._last_buffer = buffer.copy()

    def tx_recording(
        self,
        recording: Recording | np.ndarray,
        num_samples: Optional[int] = None,
        tx_time: Optional[int | float] = None,
    ):
        """
        Transmit the given iq samples from the provided recording, looping or truncating as needed.

        init_tx() must be called before this function. Only channel 0 of multichannel input is used.

        :param recording: The recording to transmit.
        :type recording: Recording or np.ndarray
        :param num_samples: The number of samples to transmit, will repeat or
            truncate the recording to this length. Defaults to the recording length.
        :type num_samples: int, optional
        :param tx_time: The time to transmit in seconds, will repeat or truncate the
            recording to this length. Defaults to None.
        :type tx_time: int or float, optional
        :raises RuntimeError: If TX was not initialized.
        :raises ValueError: If both ``num_samples`` and ``tx_time`` are given, or the recording is empty.
        :raises TypeError: If ``recording`` is neither a Recording nor an ndarray.
        """
        if not self._tx_initialized:
            raise RuntimeError(
                "TX was not initialized. init_tx() must be called before _stream_tx() or transmit_recording()"
            )
        if num_samples is not None and tx_time is not None:
            raise ValueError("Only input one of num_samples or tx_time")

        if isinstance(recording, np.ndarray):
            samples = recording
            if samples.ndim > 1:
                if samples.shape[0] > 1:
                    warnings.warn("Array is multichannel, only channel 0 data was used for transmission")
                samples = samples[0]
        elif isinstance(recording, Recording):
            if len(recording.data) > 1:
                warnings.warn("Recording object is multichannel, only channel 0 data was used for transmission")
            samples = recording.data[0]
        else:
            # Without this, a stale _samples_to_transmit from a previous call would be sent.
            raise TypeError(f"recording must be a Recording or np.ndarray, got {type(recording).__name__}")
        if len(samples) == 0:
            raise ValueError("Cannot transmit an empty recording")

        if num_samples is not None:
            self._num_samples_to_transmit = num_samples
        elif tx_time is not None:
            self._num_samples_to_transmit = int(tx_time * self.tx_sample_rate)
        else:
            self._num_samples_to_transmit = len(samples)

        self._samples_to_transmit = samples
        self._num_samples_transmitted = 0
        self._stream_tx(self._loop_recording_callback)

    def _loop_recording_callback(self, num_samples):
        """
        TX callback: return the next ``num_samples`` samples, wrapping around the recording.

        When the requested total is reached, samples past it are zeroed (the driver still gets a
        full-size buffer with exactly the right number of non-zero samples) and TX is paused.
        """
        total = len(self._samples_to_transmit)
        samples_left = max(self._num_samples_to_transmit - self._num_samples_transmitted, 0)
        start_index = self._num_samples_transmitted % total
        # Indices wrap around the recording as many times as necessary.
        indices = np.arange(start_index, start_index + num_samples) % total
        # Fancy indexing returns a copy, so zero-padding never modifies the source recording.
        samples = self._samples_to_transmit[indices]
        # ">=" so an exact fit pauses now instead of sending an extra all-zero buffer.
        if len(samples) >= samples_left:
            samples[int(samples_left) :] = 0
            self.pause_tx()
        self._num_samples_transmitted += num_samples
        return samples

    def supports_bias_tee(self) -> bool:
        """Return True when the radio supports bias-tee control."""
        return False

    def set_bias_tee(self, enable: bool):
        """Enable or disable bias-tee power when supported by the radio."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support bias-tee control")

    def pause_rx(self):
        """Clear the RX enable flag so the ``_stream_rx`` loop can finish."""
        self._enable_rx = False

    def pause_tx(self):
        """Clear the TX enable flag so the ``_stream_tx`` loop can finish."""
        self._enable_tx = False

    def stop(self):
        """Pause both RX and TX streaming."""
        self.pause_rx()
        self.pause_tx()

    def get_rx_sample_rate(self):
        """
        Retrieve the current sample rate of the receiver.

        :return: The receiver's sample rate in samples per second (Hz), or None if unset.
        :rtype: float
        """
        return self.rx_sample_rate

    def get_rx_center_frequency(self):
        """
        Retrieve the current center frequency of the receiver.

        :return: The receiver's center frequency in Hertz (Hz), or None if unset.
        :rtype: float
        """
        return self.rx_center_frequency

    def get_rx_gain(self):
        """
        Retrieve the current gain setting of the receiver.

        :return: The receiver's gain in decibels (dB), or None if unset.
        :rtype: float
        """
        return self.rx_gain

    def get_tx_sample_rate(self):
        """
        Retrieve the current sample rate of the transmitter.

        :return: The transmitter's sample rate in samples per second (Hz), or None if unset.
        :rtype: float
        """
        return self.tx_sample_rate

    def get_tx_center_frequency(self):
        """
        Retrieve the current center frequency of the transmitter.

        :return: The transmitter's center frequency in Hertz (Hz), or None if unset.
        :rtype: float
        """
        return self.tx_center_frequency

    def get_tx_gain(self):
        """
        Retrieve the current gain setting of the transmitter.

        :return: The transmitter's gain in decibels (dB), or None if unset.
        :rtype: float
        """
        return self.tx_gain

    # The setters accept arbitrary arguments so that calling them with a value on a radio
    # that does not override them raises NotImplementedError rather than TypeError.
    def set_rx_sample_rate(self, *args, **kwargs):
        """Set the sample rate of the receiver. Must be overridden by subclasses that support it."""
        raise NotImplementedError(f"{self.__class__.__name__} does not implement set_rx_sample_rate()")

    def set_rx_center_frequency(self, *args, **kwargs):
        """Set the center frequency of the receiver. Must be overridden by subclasses that support it."""
        raise NotImplementedError(f"{self.__class__.__name__} does not implement set_rx_center_frequency()")

    def set_rx_gain(self, *args, **kwargs):
        """Set the gain setting of the receiver. Must be overridden by subclasses that support it."""
        raise NotImplementedError(f"{self.__class__.__name__} does not implement set_rx_gain()")

    def set_tx_sample_rate(self, *args, **kwargs):
        """Set the sample rate of the transmitter. Must be overridden by subclasses that support it."""
        raise NotImplementedError(f"{self.__class__.__name__} does not implement set_tx_sample_rate()")

    def set_tx_center_frequency(self, *args, **kwargs):
        """Set the center frequency of the transmitter. Must be overridden by subclasses that support it."""
        raise NotImplementedError(f"{self.__class__.__name__} does not implement set_tx_center_frequency()")

    def set_tx_gain(self, *args, **kwargs):
        """Set the gain setting of the transmitter. Must be overridden by subclasses that support it."""
        raise NotImplementedError(f"{self.__class__.__name__} does not implement set_tx_gain()")

    def supports_dynamic_updates(self) -> dict:
        """
        Report which parameters can be updated during streaming.

        :return: ``{'center_frequency': bool, 'sample_rate': bool, 'gain': bool}``; all False here.
        :rtype: dict
        """
        return {"center_frequency": False, "sample_rate": False, "gain": False}

    def __del__(self):
        """Best-effort cleanup on garbage collection; errors from close() are deliberately ignored."""
        try:
            self.close()
        except Exception:
            pass

    @abstractmethod
    def close(self):
        """Release the radio and any driver resources. Also invoked from ``__del__``."""

    @abstractmethod
    def init_rx(self, sample_rate, center_frequency, gain, channel, gain_mode):
        """
        Configure the receiver.

        Implementations must set ``_rx_initialized = True`` on success (:meth:`record` refuses to
        run otherwise). :meth:`rx` calls this without ``gain_mode``, so implementations should
        give it a default.
        """

    @abstractmethod
    def init_tx(self, sample_rate, center_frequency, gain, channel, gain_mode):
        """
        Configure the transmitter.

        Implementations must set ``_tx_initialized = True`` on success
        (:meth:`tx_recording` refuses to run otherwise).
        """

    @abstractmethod
    def _stream_rx(self, callback):
        """Stream received buffers into ``callback`` until ``_enable_rx`` is cleared (see class docstring)."""

    @abstractmethod
    def _stream_tx(self, callback):
        """Transmit buffers produced by ``callback(num_samples)`` until ``_enable_tx`` is cleared."""

    @abstractmethod
    def set_clock_source(self, source):
        """
        Sets the clock source to external or internal.

        :param source: The clock source
        :type source: str
        """
def _generate_full_zmq_address(input_address):
    """
    Helper function for zmq streaming: expand a bare port into a TCP endpoint.

    If given a port number like ``5556`` (no ``://`` transport prefix), return
    ``"tcp://*:5556"``. Note that binding ``tcp://*`` listens on *all* network
    interfaces, not only localhost. Any other input is returned untouched.

    :param input_address: A port number (int or numeric string) or a full ZMQ endpoint.
    :return: The endpoint string to bind to, or ``input_address`` unchanged.
    """
    if "://" not in str(input_address) and _is_valid_port(input_address):
        # No transport protocol specified: assume TCP on every interface.
        return "tcp://*:" + str(input_address)
    # Already a full endpoint (or not a port at all): leave it to zmq to validate.
    return input_address
def _is_valid_port(port):
"""
Helper function for zmq address.
"""
try:
port_num = int(port)
return 0 <= port_num <= 65535
except ValueError:
return False
def _verify_sample_format(samples):
"""
Verify that the sample data is in the range -1 to 1.
:param buffer: An array of samples.
:Return: True if the buffer is in the correct format, false if not.
:rtype: bool
"""
return np.max(np.abs(samples)) <= 1
class SDRError(Exception):
    """Root of the SDR exception hierarchy; catch this to handle any SDR failure."""


class SDRParameterError(SDRError):
    """An invalid radio parameter was supplied (sample rate, frequency or gain)."""


class SDROverflowError(SDRError):
    """The radio reported (or the driver detected) a buffer overflow."""


class SdrDisconnectedError(SDRError):
    """The SDR device vanished mid-operation (e.g. USB unplugged or network link lost)."""
# Substrings that strongly indicate a device has disappeared rather than a
# transient / recoverable error. Checked case-insensitively against str(exc)
# by translate_disconnect(), so every entry must be lowercase.
# NOTE(review): a few entries are broad and may classify unrelated errors as
# disconnects -- "not found" (any lookup failure), "usb" (any message that
# mentions USB), and "errno 5" (also a prefix of "errno 50".."errno 59" under
# plain substring matching). Confirm these are intended.
_DISCONNECT_MARKERS = (
    "no such device",
    "device not found",
    "not found",
    "broken pipe",
    "disconnected",
    "no device",
    "device unplugged",
    "usb",
    "i/o error",
    "input/output error",
    "errno 19",  # ENODEV
    "errno 5",  # EIO
)
def translate_disconnect(exc: BaseException) -> BaseException:
    """Return ``SdrDisconnectedError`` if *exc* looks like a USB/device drop, else *exc*.

    Drivers wrap their native-API calls with::

        try:
            return self.radio.rx()
        except Exception as exc:
            raise translate_disconnect(exc) from exc

    The caller (e.g. the streamer) can then catch ``SdrDisconnectedError``
    specifically and report it to the hub rather than crashing the loop.

    An exception is treated as a disconnect when it is already an
    ``SdrDisconnectedError``, when it is an ``OSError`` whose errno is
    ``EIO``/``ENODEV``, or when its lowercased message contains one of
    ``_DISCONNECT_MARKERS``. Markers ending in a digit (e.g. ``"errno 5"``)
    only match when not followed by another digit, so ``"[Errno 52]"`` is
    not mistaken for EIO.

    :param exc: The exception raised by the driver.
    :return: A new ``SdrDisconnectedError`` carrying ``str(exc)``, or *exc* itself.
    """
    import errno
    import re

    if isinstance(exc, SdrDisconnectedError):
        return exc
    # An OSError with ENODEV / EIO is a disconnect signal regardless of its message text.
    if isinstance(exc, OSError) and getattr(exc, "errno", None) in (errno.EIO, errno.ENODEV):
        return SdrDisconnectedError(str(exc))
    msg = str(exc).lower()
    for marker in _DISCONNECT_MARKERS:
        # Prevent numeric markers from matching a longer number ("errno 5" vs "errno 52").
        pattern = re.escape(marker) + (r"(?!\d)" if marker[-1:].isdigit() else "")
        if re.search(pattern, msg):
            return SdrDisconnectedError(str(exc))
    return exc