From c673967a905a3c1d30a99259ceb4856b4cda6e25 Mon Sep 17 00:00:00 2001 From: madrigal Date: Mon, 17 Nov 2025 11:20:38 -0500 Subject: [PATCH] Updated methods, added setters, and created standardized SDRError classes --- src/ria_toolkit_oss/sdr/sdr.py | 197 ++++++++++++++++++++++++--------- 1 file changed, 142 insertions(+), 55 deletions(-) diff --git a/src/ria_toolkit_oss/sdr/sdr.py b/src/ria_toolkit_oss/sdr/sdr.py index c2464bf..489fd48 100644 --- a/src/ria_toolkit_oss/sdr/sdr.py +++ b/src/ria_toolkit_oss/sdr/sdr.py @@ -1,5 +1,6 @@ import math import pickle +import threading import warnings from abc import ABC, abstractmethod from typing import Optional @@ -27,17 +28,21 @@ class SDR(ABC): self._tx_initialized = False self._enable_rx = False self._enable_tx = False + self._accumulated_buffer = None self._max_num_buffers = None self._num_buffers_processed = 0 self._accumulated_buffer = None self._last_buffer = None + self._corrupted_buffer_count = 0 + self.rx_sample_rate = None self.rx_center_frequency = None self.rx_gain = None self.tx_sample_rate = None self.tx_center_frequency = None self.tx_gain = None + self._param_lock = threading.RLock() # Reentrant lock def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording: """ @@ -71,7 +76,6 @@ class SDR(ABC): self._max_num_buffers = num_buffers self._num_buffers_processed = 0 - self._num_buffers_processed = 0 self._last_buffer = None self._accumulated_buffer = None print("Starting stream") @@ -94,6 +98,7 @@ class SDR(ABC): # reset to record again self._accumulated_buffer = None + self._num_buffers_processed = 0 return recording def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000): @@ -110,21 +115,23 @@ class SDR(ABC): :return: The trimmed Recording. :rtype: Recording """ + try: + self._previous_buffer = None + self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size) + self._num_buffers_processed = 0 + self.zmq_address = _generate_full_zmq_address(str(zmq_address)) + self.context = zmq.Context() + self.socket = self.context.socket(zmq.PUB) + self.socket.bind(self.zmq_address) - self._previous_buffer = None - self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size) - self._num_buffers_processed = 0 - self.zmq_address = _generate_full_zmq_address(str(zmq_address)) - self.context = zmq.Context() - self.socket = self.context.socket(zmq.PUB) - self.socket.bind(self.zmq_address) - - self._stream_rx( - self._zmq_bytestream_callback, - ) - - self.context.destroy() - self.socket.close() + self._stream_rx( + self._zmq_bytestream_callback, + ) + finally: + if hasattr(self, "socket"): + self.socket.close() + if hasattr(self, "context"): + self.context.destroy() def _accumulate_buffers_callback(self, buffer, metadata=None): """ @@ -134,62 +141,72 @@ class SDR(ABC): # save the buffer until max reached # return a recording - buffer = np.array(buffer) # make it 1d - if len(buffer.shape) == 1: - buffer = np.array([buffer]) + # Validate buffer + if not self._validate_buffer(buffer): + print("Warning: Corrupted buffer detected, skipping") + self._corrupted_buffer_count += 1 + return # Skip this buffer - # it runs these checks each time, is that an efficiency issue? - - if self._max_num_buffers is None: - # default then - # this should probably print, but that would happen every buffer... - raise ValueError("Number of buffers for block capture not set.") - - # add the given buffer to the pre-allocated buffer - - if metadata is not None: - self.received_metadata = metadata - - # TODO optimize, pre-allocate - if self._accumulated_buffer is not None: - self._accumulated_buffer = np.concatenate((self._accumulated_buffer, buffer), axis=1) + if isinstance(buffer, np.ndarray): + if buffer.ndim == 1: + buffer = buffer[np.newaxis, :] # make shape (1, N) else: - # the first time - self._accumulated_buffer = buffer.copy() + buffer = np.array(buffer) # make it 1d + if len(buffer.shape) == 1: + buffer = np.array([buffer]) - self._num_buffers_processed = self._num_buffers_processed + 1 + # First call: pre-allocate if we know the final size + if self._accumulated_buffer is None: + # Check that _max_num_buffers is set + if self._max_num_buffers is None: + raise ValueError("Number of buffers for block capture not set.") + if self._num_samples_to_record is None: + raise ValueError("Number of samples not set before RX start.") + + if metadata is not None: + self.received_metadata = metadata + + # Preallocate once (avoid np.zeros; use np.empty for speed) + num_channels = buffer.shape[0] + self._accumulated_buffer = np.empty((num_channels, self._num_samples_to_record), dtype=buffer.dtype) + self._write_position = 0 + print(f"Pre-allocated buffer for {self._num_samples_to_record:,} samples.") + + # Write new buffer into pre-allocated array + n = buffer.shape[1] + start = self._write_position + end = min(start + n, self._num_samples_to_record) + samples_to_write = end - start + + if samples_to_write > 0: + self._accumulated_buffer[:, start:end] = buffer[:, : end - start] + self._write_position = end + + # Check if we're done + self._num_buffers_processed += 1 if self._num_buffers_processed >= self._max_num_buffers: self.stop() - if self._last_buffer is not None: - if (buffer == self._last_buffer).all(): - print("\033[93mWarning: Buffer Overflow Detected\033[0m") - self._last_buffer = buffer.copy() - else: - self._last_buffer = buffer.copy() - - # print("Number of buffers received: " + str(self._num_buffers_processed)) + def _validate_buffer(self, buffer): + """Check for obviously corrupt data.""" + # Check for all zeros + if np.all(buffer == 0): + return False + # Check for all same value + if np.all(buffer == buffer[0]): + return False + return True def _zmq_bytestream_callback(self, buffer, metadata=None): # push to ZMQ port data = np.array(buffer).tobytes() # convert to bytes for transport self.socket.send(data) - # print(f"Sent {self._num_buffers_processed} ZMQ buffers to {self.zmq_address}") - self._num_buffers_processed = self._num_buffers_processed + 1 if self._max_num_buffers is not None: if self._num_buffers_processed >= self._max_num_buffers: self.pause_rx() - if self._previous_buffer is not None: - if (buffer == self._previous_buffer).all(): - print("\033[93mWarning: Buffer Overflow Detected\033[0m") - # TODO: I suggest we think about moving this part to the top of this function - # and skip the rest of the function in case of overflow. - # like, it's not necessary to stream repeated IQ data anyways! - self._previous_buffer = buffer.copy() - def pickle_buffer_to_zmq(self, zmq_address, buffer_size, num_buffers): """ Stream samples to a zmq address, packaged in binary buffers using numpy.pickle. @@ -229,7 +246,7 @@ class SDR(ABC): self.stop() if self._last_buffer is not None: - if (buffer == self._last_buffer).all(): + if np.array_equal(buffer, self._last_buffer): print("\033[93mWarning: Buffer Overflow Detected\033[0m") self._last_buffer = buffer.copy() else: @@ -373,6 +390,58 @@ class SDR(ABC): """ return self.tx_gain + def set_rx_sample_rate(self): + """ + Set the sample rate of the receiver. + """ + raise NotImplementedError + + def set_rx_center_frequency(self): + """ + Set the center frequency of the receiver. + """ + raise NotImplementedError + + def set_rx_gain(self): + """ + Set the gain setting of the receiver. + """ + raise NotImplementedError + + def set_tx_sample_rate(self): + """ + Set the sample rate of the transmitter. + """ + raise NotImplementedError + + def set_tx_center_frequency(self): + """ + Set the center frequency of the transmitter. + """ + raise NotImplementedError + + def set_tx_gain(self): + """ + Set the gain setting of the transmitter. + """ + raise NotImplementedError + + def supports_dynamic_updates(self) -> dict: + """ + Report which parameters can be updated during streaming. + + Returns: + dict: {'center_frequency': bool, 'sample_rate': bool, 'gain': bool} + """ + return {"center_frequency": False, "sample_rate": False, "gain": False} + + def __del__(self): + """Cleanup on garbage collection.""" + try: + self.close() + except Exception: + pass + @abstractmethod def close(self): pass @@ -442,3 +511,21 @@ def _verify_sample_format(samples): """ return np.max(np.abs(samples)) <= 1 + + +class SDRError(Exception): + """Base exception for SDR errors.""" + + pass + + +class SDRParameterError(SDRError): + """Invalid parameter (sample rate, freq, gain).""" + + pass + + +class SDROverflowError(SDRError): + """Buffer overflow detected.""" + + pass