ria-toolkit-oss/src/ria_toolkit_oss/sdr/sdr.py

612 lines
21 KiB
Python
Raw Normal View History

2025-09-12 11:32:49 -04:00
import math
import pickle
import threading
2025-09-12 11:32:49 -04:00
import warnings
from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
import zmq
from ria_toolkit_oss.data.recording import Recording
2025-09-12 11:32:49 -04:00
class SDR(ABC):
"""
2025-09-12 15:49:47 -04:00
This class defines a common interface (a template) for all SDR devices.
2025-09-12 14:51:45 -04:00
Each specific SDR implementation should subclass SDR and provide concrete implementations
for the abstract methods.
To add support for a new radio, subclass this interface and implement all abstract methods.
If you experience difficulties, please `contact us <mailto:info@qoherent.ai>`_, we are happy to
provide additional direction and/or help with the implementation details.
2025-09-12 11:32:49 -04:00
"""
def __init__(self):
2025-09-12 15:49:47 -04:00
2025-09-12 11:32:49 -04:00
self._rx_initialized = False
self._tx_initialized = False
self._enable_rx = False
self._enable_tx = False
2025-09-12 11:32:49 -04:00
self._accumulated_buffer = None
self._max_num_buffers = None
self._num_buffers_processed = 0
self._last_buffer = None
self._corrupted_buffer_count = 0
M
2025-10-23 16:44:43 -04:00
self.rx_sample_rate = None
self.rx_center_frequency = None
self.rx_gain = None
self.tx_sample_rate = None
self.tx_center_frequency = None
self.tx_gain = None
self._param_lock = threading.RLock() # Reentrant lock
2025-09-12 11:32:49 -04:00
2026-04-14 10:45:54 -04:00
# Pending config consumed by rx() on first call and by _apply_sdr_config
# in the agent inference loop. Subclasses that need different defaults
# (e.g. MockSDR) can overwrite these in their own __init__.
self.center_freq: float = 2.4e9
self.sample_rate: float = 10e6
self.gain: float = 40.0
2025-09-12 11:32:49 -04:00
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
"""
Create a radio recording of a given length. Either ``num_samples`` or ``rx_time`` must be provided.
2025-09-12 15:49:47 -04:00
2025-09-12 11:32:49 -04:00
Note that ``init_rx()`` must be called before ``record()``.
:param num_samples: The number of samples to record.
:type num_samples: int, optional
:param rx_time: The time to record.
:type rx_time: int or float, optional
:return: The Recording object
:rtype: Recording
"""
if not self._rx_initialized:
raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")
if num_samples is not None and rx_time is not None:
raise ValueError("Only input one of num_samples or rx_time")
elif num_samples is not None:
self._num_samples_to_record = num_samples
elif rx_time is not None:
self._num_samples_to_record = int(rx_time * self.rx_sample_rate)
else:
raise ValueError("Must provide input of one of num_samples or rx_time")
self.buffer_size = self.rx_buffer_size
num_buffers = self._num_samples_to_record // self.buffer_size + 1
self._max_num_buffers = num_buffers
self._num_buffers_processed = 0
self._last_buffer = None
self._accumulated_buffer = None
print("Starting stream")
self._stream_rx(
callback=self._accumulate_buffers_callback,
)
print("Finished stream")
metadata = {
"source": self.__class__.__name__,
"sample_rate": self.rx_sample_rate,
"center_frequency": self.rx_center_frequency,
"gain": self.rx_gain,
}
print("Creating recording")
# build recording, truncate to self._num_samples_to_record
recording = Recording(data=self._accumulated_buffer[:, : self._num_samples_to_record], metadata=metadata)
# reset to record again
self._accumulated_buffer = None
self._num_buffers_processed = 0
2025-09-12 11:32:49 -04:00
return recording
2026-04-14 10:45:54 -04:00
def rx(self, num_samples: int) -> "np.ndarray":
"""Return *num_samples* complex IQ samples as a 1-D complex64 array.
This is the interface used by the agent inference loop. On first call,
``init_rx()`` is invoked automatically using the values stored in
``center_freq``, ``sample_rate``, and ``gain`` (set beforehand by
``_apply_sdr_config``). Subsequent calls stream directly.
Subclasses may override this for hardware-native capture APIs (e.g.
``MockSDR`` uses AWGN generation; ``PlutoSDR`` could use
``self.radio.rx()``).
"""
if not self._rx_initialized:
gain = self.gain if isinstance(self.gain, (int, float)) else 40.0
self.init_rx(
sample_rate=self.sample_rate,
center_frequency=self.center_freq,
gain=gain,
channel=0,
)
recording = self.record(num_samples=num_samples)
# Recording.data is either a list of 1-D arrays (one per channel) or a
# 2-D ndarray (channels × samples). Either way, index 0 is channel 0.
data = recording.data
return data[0] if hasattr(data, "__getitem__") else data
2025-09-12 11:32:49 -04:00
def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
"""
Stream iq samples as interleaved bytes via zmq.
:param zmq_address: The zmq address.
:type zmq_address:
:param n_samples: The number of samples to stream.
:type n_samples: int
:param buffer_size: The buffer size during streaming. Defaults to 10000.
:type buffer_size: int, optional
:return: The trimmed Recording.
:rtype: Recording
"""
try:
self._previous_buffer = None
self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size)
self._num_buffers_processed = 0
self.zmq_address = _generate_full_zmq_address(str(zmq_address))
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUB)
self.socket.bind(self.zmq_address)
self._stream_rx(
self._zmq_bytestream_callback,
)
finally:
if hasattr(self, "socket"):
self.socket.close()
if hasattr(self, "context"):
self.context.destroy()
2025-09-12 11:32:49 -04:00
def _accumulate_buffers_callback(self, buffer, metadata=None):
"""
Receives a buffer and saves it to self.accumulated_buffer.
"""
# expected buffer is complex samples range -1 to 1
# save the buffer until max reached
# return a recording
# Validate buffer
if not self._validate_buffer(buffer):
print("Warning: Corrupted buffer detected, skipping")
self._corrupted_buffer_count += 1
return # Skip this buffer
2025-09-12 11:32:49 -04:00
if isinstance(buffer, np.ndarray):
if buffer.ndim == 1:
buffer = buffer[np.newaxis, :] # make shape (1, N)
2025-09-12 11:32:49 -04:00
else:
buffer = np.array(buffer) # make it 1d
if len(buffer.shape) == 1:
buffer = np.array([buffer])
# First call: pre-allocate if we know the final size
if self._accumulated_buffer is None:
# Check that _max_num_buffers is set
if self._max_num_buffers is None:
raise ValueError("Number of buffers for block capture not set.")
if self._num_samples_to_record is None:
raise ValueError("Number of samples not set before RX start.")
if metadata is not None:
self.received_metadata = metadata
# Preallocate once (avoid np.zeros; use np.empty for speed)
num_channels = buffer.shape[0]
self._accumulated_buffer = np.empty((num_channels, self._num_samples_to_record), dtype=buffer.dtype)
self._write_position = 0
print(f"Pre-allocated buffer for {self._num_samples_to_record:,} samples.")
# Write new buffer into pre-allocated array
n = buffer.shape[1]
start = self._write_position
end = min(start + n, self._num_samples_to_record)
samples_to_write = end - start
if samples_to_write > 0:
self._accumulated_buffer[:, start:end] = buffer[:, : end - start]
self._write_position = end
# Check if we're done
self._num_buffers_processed += 1
2025-09-12 11:32:49 -04:00
if self._num_buffers_processed >= self._max_num_buffers:
self.stop()
def _validate_buffer(self, buffer):
"""Check for obviously corrupt data."""
# Check for all zeros
if np.all(buffer == 0):
return False
# Check for all same value
if np.all(buffer == buffer[0]):
return False
return True
2025-09-12 11:32:49 -04:00
def _zmq_bytestream_callback(self, buffer, metadata=None):
# push to ZMQ port
data = np.array(buffer).tobytes() # convert to bytes for transport
self.socket.send(data)
self._num_buffers_processed = self._num_buffers_processed + 1
if self._max_num_buffers is not None:
if self._num_buffers_processed >= self._max_num_buffers:
self.pause_rx()
def pickle_buffer_to_zmq(self, zmq_address, buffer_size, num_buffers):
"""
Stream samples to a zmq address, packaged in binary buffers using numpy.pickle.
Useful for inference applications with a known input size.
May reduce transfer rates, but individual buffers will not have discontinuities.
:param zmq_address: The tcp address to stream to.
:type zmq_address: str
:param buffer_size: The number of iq samples in a buffer.
:type buffer_size: int
:param num_buffers: The number of buffers to stream before stopping.
:type num_buffers: int
"""
self._max_num_buffers = num_buffers
self.buffer_size = buffer_size
self._num_buffers_processed = 0
self.zmq_address = _generate_full_zmq_address(str(zmq_address))
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUB)
self.socket.bind(self.zmq_address)
self.set_rx_buffer_size(buffer_size)
self._stream_rx(self._zmq_pickle_buffer_callback)
def _zmq_pickle_buffer_callback(self, buffer, metadata=None):
# push to ZMQ port
# data = np.array(buffer).tobytes() # convert to bytes for transport
# self.socket.send(data)
self.socket.send(pickle.dumps(buffer))
# print(f"Sent {self._num_buffers_processed} ZMQ buffers to {self.zmq_address}")
self._num_buffers_processed = self._num_buffers_processed + 1
if self._max_num_buffers is not None:
if self._num_buffers_processed >= self._max_num_buffers:
self.stop()
if self._last_buffer is not None:
if np.array_equal(buffer, self._last_buffer):
2025-09-12 11:32:49 -04:00
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
self._last_buffer = buffer.copy()
else:
self._last_buffer = buffer.copy()
def tx_recording(
self,
recording: Recording | np.ndarray,
num_samples: Optional[int] = None,
tx_time: Optional[int | float] = None,
):
"""
Transmit the given iq samples from the provided recording.
init_tx() must be called before this function.
:param recording: The recording to transmit.
:type recording: Recording or np.ndarray
:param num_samples: The number of samples to transmit, will repeat or
truncate the recording to this length. Defaults to None.
:type num_samples: int, optional
:param tx_time: The time to transmit, will repeat or truncate the
recording to this length. Defaults to None.
:type tx_time: int or float, optional
"""
if not self._tx_initialized:
raise RuntimeError(
"TX was not initialized. init_tx() must be called before _stream_tx() or transmit_recording()"
)
if num_samples is not None and tx_time is not None:
raise ValueError("Only input one of num_samples or tx_time")
elif num_samples is not None:
self._num_samples_to_transmit = num_samples
elif tx_time is not None:
2026-03-31 13:51:10 -04:00
self._num_samples_to_transmit = int(tx_time * self.tx_sample_rate)
2025-09-12 11:32:49 -04:00
else:
self._num_samples_to_transmit = len(recording)
if isinstance(recording, np.ndarray):
self._samples_to_transmit = recording
elif isinstance(recording, Recording):
if len(recording.data) > 1:
warnings.warn("Recording object is multichannel, only channel 0 data was used for transmission")
self._samples_to_transmit = recording.data[0]
self._num_samples_transmitted = 0
self._stream_tx(self._loop_recording_callback)
def _loop_recording_callback(self, num_samples):
samples_left = self._num_samples_to_transmit - self._num_samples_transmitted
# find where to start based on num_samples_transmitted
start_index = self._num_samples_transmitted % len(self._samples_to_transmit)
# generates an array of indices that wrap around as many times as necessary.
indices = np.arange(start_index, start_index + num_samples) % len(self._samples_to_transmit)
samples = self._samples_to_transmit[indices]
# zero pad at the end so we are still giving the requested buffer size
# while also giving the exact number of non zero samples
if len(samples) > samples_left:
samples[int(samples_left) :] = 0
self.pause_tx()
self._num_samples_transmitted = self._num_samples_transmitted + num_samples
return samples
M
2025-10-16 15:22:07 -04:00
def supports_bias_tee(self) -> bool:
"""Return True when the radio supports bias-tee control."""
return False
def set_bias_tee(self, enable: bool):
"""Enable or disable bias-tee power when supported by the radio."""
raise NotImplementedError(f"{self.__class__.__name__} does not support bias-tee control")
2025-09-12 11:32:49 -04:00
def pause_rx(self):
self._enable_rx = False
def pause_tx(self):
self._enable_tx = False
M
2025-10-16 15:22:07 -04:00
def stop(self):
self.pause_rx()
self.pause_tx()
M
2025-10-23 16:44:43 -04:00
def get_rx_sample_rate(self):
"""
Retrieve the current sample rate of the receiver.
Returns:
float: The receiver's sample rate in samples per second (Hz).
"""
return self.rx_sample_rate
def get_rx_center_frequency(self):
"""
Retrieve the current center frequency of the receiver.
Returns:
float: The receiver's center frequency in Hertz (Hz).
"""
return self.rx_center_frequency
def get_rx_gain(self):
"""
Retrieve the current gain setting of the receiver.
Returns:
float: The receiver's gain in decibels (dB).
"""
return self.rx_gain
def get_tx_sample_rate(self):
"""
Retrieve the current sample rate of the transmitter.
Returns:
float: The transmitter's sample rate in samples per second (Hz).
"""
return self.tx_sample_rate
def get_tx_center_frequency(self):
"""
Retrieve the current center frequency of the transmitter.
Returns:
float: The transmitter's center frequency in Hertz (Hz).
"""
return self.tx_center_frequency
def get_tx_gain(self):
"""
Retrieve the current gain setting of the transmitter.
Returns:
float: The transmitter's gain in decibels (dB).
"""
return self.tx_gain
def set_rx_sample_rate(self):
"""
Set the sample rate of the receiver.
"""
raise NotImplementedError
def set_rx_center_frequency(self):
"""
Set the center frequency of the receiver.
"""
raise NotImplementedError
def set_rx_gain(self):
"""
Set the gain setting of the receiver.
"""
raise NotImplementedError
def set_tx_sample_rate(self):
"""
Set the sample rate of the transmitter.
"""
raise NotImplementedError
def set_tx_center_frequency(self):
"""
Set the center frequency of the transmitter.
"""
raise NotImplementedError
def set_tx_gain(self):
"""
Set the gain setting of the transmitter.
"""
raise NotImplementedError
def supports_dynamic_updates(self) -> dict:
"""
Report which parameters can be updated during streaming.
Returns:
dict: {'center_frequency': bool, 'sample_rate': bool, 'gain': bool}
"""
return {"center_frequency": False, "sample_rate": False, "gain": False}
def __del__(self):
"""Cleanup on garbage collection."""
try:
self.close()
except Exception:
pass
M
2025-10-16 15:22:07 -04:00
@abstractmethod
def close(self):
pass
2025-09-12 11:32:49 -04:00
@abstractmethod
def init_rx(self, sample_rate, center_frequency, gain, channel, gain_mode):
pass
@abstractmethod
def init_tx(self, sample_rate, center_frequency, gain, channel, gain_mode):
pass
@abstractmethod
def _stream_rx(self, callback):
pass
@abstractmethod
def _stream_tx(self, callback):
pass
@abstractmethod
def set_clock_source(self, source):
"""
Sets the clock source to external or internal.
:param source: The clock source
:type source: str
"""
pass
def _generate_full_zmq_address(input_address):
"""
Helper function for zmq streaming.
If given a port number like 5556,
return tcp localhost address at that port.
Otherwise, return the address untouched.
"""
if ("://" not in str(input_address)) and _is_valid_port(input_address):
# If no transport protocol specified, assume TCP
return "tcp://*:" + str(input_address)
else:
# Otherwise, return the input unchanged
return input_address
def _is_valid_port(port):
"""
Helper function for zmq address.
"""
try:
port_num = int(port)
return 0 <= port_num <= 65535
except ValueError:
return False
def _verify_sample_format(samples):
"""
Verify that the sample data is in the range -1 to 1.
:param buffer: An array of samples.
:Return: True if the buffer is in the correct format, false if not.
:rtype: bool
"""
return np.max(np.abs(samples)) <= 1
class SDRError(Exception):
    """Base exception for SDR errors."""


class SDRParameterError(SDRError):
    """Invalid parameter (sample rate, freq, gain)."""


class SDROverflowError(SDRError):
    """Buffer overflow detected."""


class SdrDisconnectedError(SDRError):
    """Raised when the SDR device disappears mid-operation (USB unplug, network drop)."""


# Substrings that strongly indicate a device has disappeared rather than a
# transient / recoverable error. Checked case-insensitively against str(exc).
# NOTE(review): "not found" and "usb" are broad and may also match transient
# driver errors (e.g. a message that merely mentions USB); consider narrowing.
_DISCONNECT_MARKERS = (
    "no such device",
    "device not found",
    "not found",
    "broken pipe",
    "disconnected",
    "no device",
    "device unplugged",
    "usb",
    "i/o error",
    "input/output error",
)

# errno codes treated as a disconnect: EIO (5) and ENODEV (19). They are matched
# as whole numbers so that e.g. "errno 54" is not mistaken for errno 5.
_DISCONNECT_ERRNOS = (5, 19)


def translate_disconnect(exc: BaseException) -> BaseException:
    """Return ``SdrDisconnectedError`` if *exc* looks like a USB/device drop, else *exc*.

    Drivers wrap their native-API calls with::

        try:
            return self.radio.rx()
        except Exception as exc:
            raise translate_disconnect(exc) from exc

    The caller (e.g. the streamer) can then catch ``SdrDisconnectedError``
    specifically and report it to the hub rather than crashing the loop.

    :param exc: The exception raised by the native driver call.
    :return: A new ``SdrDisconnectedError`` carrying ``str(exc)``, *exc* itself
        if it already is one, or *exc* unchanged when it does not look like a disconnect.
    """
    import re

    if isinstance(exc, SdrDisconnectedError):
        return exc
    msg = str(exc).lower()
    if any(marker in msg for marker in _DISCONNECT_MARKERS):
        return SdrDisconnectedError(str(exc))
    # Text such as "[errno 19] ..." — compare the whole number, not a prefix.
    if any(int(code) in _DISCONNECT_ERRNOS for code in re.findall(r"errno\s+(\d+)", msg)):
        return SdrDisconnectedError(str(exc))
    # OSError subclass with ENODEV / EIO errno is also a disconnect signal.
    if isinstance(exc, OSError) and getattr(exc, "errno", None) in _DISCONNECT_ERRNOS:
        return SdrDisconnectedError(str(exc))
    return exc