updates_and_fixes #12

Merged
madrigal merged 9 commits from updates_and_fixes into main 2025-11-18 15:01:25 -05:00
Showing only changes of commit c673967a90 - Show all commits

View File

@ -1,5 +1,6 @@
import math
import pickle
import threading
import warnings
from abc import ABC, abstractmethod
from typing import Optional
@ -27,17 +28,21 @@ class SDR(ABC):
self._tx_initialized = False
self._enable_rx = False
self._enable_tx = False
self._accumulated_buffer = None
self._max_num_buffers = None
self._num_buffers_processed = 0
self._accumulated_buffer = None
self._last_buffer = None
self._corrupted_buffer_count = 0
self.rx_sample_rate = None
self.rx_center_frequency = None
self.rx_gain = None
self.tx_sample_rate = None
self.tx_center_frequency = None
self.tx_gain = None
self._param_lock = threading.RLock() # Reentrant lock
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
"""
@ -71,7 +76,6 @@ class SDR(ABC):
self._max_num_buffers = num_buffers
self._num_buffers_processed = 0
self._num_buffers_processed = 0
self._last_buffer = None
self._accumulated_buffer = None
print("Starting stream")
@ -94,6 +98,7 @@ class SDR(ABC):
# reset to record again
self._accumulated_buffer = None
self._num_buffers_processed = 0
return recording
def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
@ -110,21 +115,23 @@ class SDR(ABC):
:return: The trimmed Recording.
:rtype: Recording
"""
try:
self._previous_buffer = None
self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size)
self._num_buffers_processed = 0
self.zmq_address = _generate_full_zmq_address(str(zmq_address))
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUB)
self.socket.bind(self.zmq_address)
self._previous_buffer = None
self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size)
self._num_buffers_processed = 0
self.zmq_address = _generate_full_zmq_address(str(zmq_address))
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUB)
self.socket.bind(self.zmq_address)
self._stream_rx(
self._zmq_bytestream_callback,
)
self.context.destroy()
self.socket.close()
self._stream_rx(
self._zmq_bytestream_callback,
)
finally:
if hasattr(self, "socket"):
self.socket.close()
if hasattr(self, "context"):
self.context.destroy()
def _accumulate_buffers_callback(self, buffer, metadata=None):
"""
@ -134,62 +141,72 @@ class SDR(ABC):
# save the buffer until max reached
# return a recording
buffer = np.array(buffer) # make it 1d
if len(buffer.shape) == 1:
buffer = np.array([buffer])
# Validate buffer
if not self._validate_buffer(buffer):
print("Warning: Corrupted buffer detected, skipping")
self._corrupted_buffer_count += 1
return # Skip this buffer
# it runs these checks each time, is that an efficiency issue?
if self._max_num_buffers is None:
# default then
# this should probably print, but that would happen every buffer...
raise ValueError("Number of buffers for block capture not set.")
# add the given buffer to the pre-allocated buffer
if metadata is not None:
self.received_metadata = metadata
# TODO optimize, pre-allocate
if self._accumulated_buffer is not None:
self._accumulated_buffer = np.concatenate((self._accumulated_buffer, buffer), axis=1)
if isinstance(buffer, np.ndarray):
if buffer.ndim == 1:
buffer = buffer[np.newaxis, :] # make shape (1, N)
else:
# the first time
self._accumulated_buffer = buffer.copy()
buffer = np.array(buffer) # make it 1d
if len(buffer.shape) == 1:
buffer = np.array([buffer])
self._num_buffers_processed = self._num_buffers_processed + 1
# First call: pre-allocate if we know the final size
if self._accumulated_buffer is None:
# Check that _max_num_buffers is set
if self._max_num_buffers is None:
raise ValueError("Number of buffers for block capture not set.")
if self._num_samples_to_record is None:
raise ValueError("Number of samples not set before RX start.")
if metadata is not None:
self.received_metadata = metadata
# Preallocate once (avoid np.zeros; use np.empty for speed)
num_channels = buffer.shape[0]
self._accumulated_buffer = np.empty((num_channels, self._num_samples_to_record), dtype=buffer.dtype)
self._write_position = 0
print(f"Pre-allocated buffer for {self._num_samples_to_record:,} samples.")
# Write new buffer into pre-allocated array
n = buffer.shape[1]
start = self._write_position
end = min(start + n, self._num_samples_to_record)
samples_to_write = end - start
if samples_to_write > 0:
self._accumulated_buffer[:, start:end] = buffer[:, : end - start]
self._write_position = end
# Check if we're done
self._num_buffers_processed += 1
if self._num_buffers_processed >= self._max_num_buffers:
self.stop()
if self._last_buffer is not None:
if (buffer == self._last_buffer).all():
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
self._last_buffer = buffer.copy()
else:
self._last_buffer = buffer.copy()
# print("Number of buffers received: " + str(self._num_buffers_processed))
def _validate_buffer(self, buffer):
"""Check for obviously corrupt data."""
# Check for all zeros
if np.all(buffer == 0):
return False
# Check for all same value
if np.all(buffer == buffer[0]):
return False
return True
def _zmq_bytestream_callback(self, buffer, metadata=None):
# push to ZMQ port
data = np.array(buffer).tobytes() # convert to bytes for transport
self.socket.send(data)
# print(f"Sent {self._num_buffers_processed} ZMQ buffers to {self.zmq_address}")
self._num_buffers_processed = self._num_buffers_processed + 1
if self._max_num_buffers is not None:
if self._num_buffers_processed >= self._max_num_buffers:
self.pause_rx()
if self._previous_buffer is not None:
if (buffer == self._previous_buffer).all():
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
# TODO: I suggest we think about moving this part to the top of this function
# and skip the rest of the function in case of overflow.
# like, it's not necessary to stream repeated IQ data anyways!
self._previous_buffer = buffer.copy()
def pickle_buffer_to_zmq(self, zmq_address, buffer_size, num_buffers):
"""
Stream samples to a zmq address, packaged in binary buffers using numpy.pickle.
@ -229,7 +246,7 @@ class SDR(ABC):
self.stop()
if self._last_buffer is not None:
if (buffer == self._last_buffer).all():
if np.array_equal(buffer, self._last_buffer):
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
self._last_buffer = buffer.copy()
else:
@ -373,6 +390,58 @@ class SDR(ABC):
"""
return self.tx_gain
def set_rx_sample_rate(self):
"""
Set the sample rate of the receiver.
"""
raise NotImplementedError
def set_rx_center_frequency(self):
"""
Set the center frequency of the receiver.
"""
raise NotImplementedError
def set_rx_gain(self):
"""
Set the gain setting of the receiver.
"""
raise NotImplementedError
def set_tx_sample_rate(self):
"""
Set the sample rate of the transmitter.
"""
raise NotImplementedError
def set_tx_center_frequency(self):
"""
Set the center frequency of the transmitter.
"""
raise NotImplementedError
def set_tx_gain(self):
"""
Set the gain setting of the transmitter.
"""
raise NotImplementedError
def supports_dynamic_updates(self) -> dict:
"""
Report which parameters can be updated during streaming.
Returns:
dict: {'center_frequency': bool, 'sample_rate': bool, 'gain': bool}
"""
return {"center_frequency": False, "sample_rate": False, "gain": False}
def __del__(self):
"""Cleanup on garbage collection."""
try:
self.close()
except Exception:
pass
@abstractmethod
def close(self):
pass
@ -442,3 +511,21 @@ def _verify_sample_format(samples):
"""
return np.max(np.abs(samples)) <= 1
class SDRError(Exception):
"""Base exception for SDR errors."""
pass
class SDRParameterError(SDRError):
"""Invalid parameter (sample rate, freq, gain)."""
pass
class SDROverflowError(SDRError):
"""Buffer overflow detected."""
pass