updates_and_fixes #12
|
|
@ -1,5 +1,6 @@
|
|||
import math
|
||||
import pickle
|
||||
import threading
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
|
@ -27,17 +28,21 @@ class SDR(ABC):
|
|||
self._tx_initialized = False
|
||||
self._enable_rx = False
|
||||
self._enable_tx = False
|
||||
|
||||
self._accumulated_buffer = None
|
||||
self._max_num_buffers = None
|
||||
self._num_buffers_processed = 0
|
||||
self._accumulated_buffer = None
|
||||
self._last_buffer = None
|
||||
self._corrupted_buffer_count = 0
|
||||
|
||||
self.rx_sample_rate = None
|
||||
self.rx_center_frequency = None
|
||||
self.rx_gain = None
|
||||
self.tx_sample_rate = None
|
||||
self.tx_center_frequency = None
|
||||
self.tx_gain = None
|
||||
self._param_lock = threading.RLock() # Reentrant lock
|
||||
|
||||
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
|
||||
"""
|
||||
|
|
@ -71,7 +76,6 @@ class SDR(ABC):
|
|||
|
||||
self._max_num_buffers = num_buffers
|
||||
self._num_buffers_processed = 0
|
||||
self._num_buffers_processed = 0
|
||||
self._last_buffer = None
|
||||
self._accumulated_buffer = None
|
||||
print("Starting stream")
|
||||
|
|
@ -94,6 +98,7 @@ class SDR(ABC):
|
|||
|
||||
# reset to record again
|
||||
self._accumulated_buffer = None
|
||||
self._num_buffers_processed = 0
|
||||
return recording
|
||||
|
||||
def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
|
||||
|
|
@ -110,21 +115,23 @@ class SDR(ABC):
|
|||
:return: The trimmed Recording.
|
||||
:rtype: Recording
|
||||
"""
|
||||
try:
|
||||
self._previous_buffer = None
|
||||
self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size)
|
||||
self._num_buffers_processed = 0
|
||||
self.zmq_address = _generate_full_zmq_address(str(zmq_address))
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.PUB)
|
||||
self.socket.bind(self.zmq_address)
|
||||
|
||||
self._previous_buffer = None
|
||||
self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size)
|
||||
self._num_buffers_processed = 0
|
||||
self.zmq_address = _generate_full_zmq_address(str(zmq_address))
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.PUB)
|
||||
self.socket.bind(self.zmq_address)
|
||||
|
||||
self._stream_rx(
|
||||
self._zmq_bytestream_callback,
|
||||
)
|
||||
|
||||
self.context.destroy()
|
||||
self.socket.close()
|
||||
self._stream_rx(
|
||||
self._zmq_bytestream_callback,
|
||||
)
|
||||
finally:
|
||||
if hasattr(self, "socket"):
|
||||
self.socket.close()
|
||||
if hasattr(self, "context"):
|
||||
self.context.destroy()
|
||||
|
||||
def _accumulate_buffers_callback(self, buffer, metadata=None):
|
||||
"""
|
||||
|
|
@ -134,62 +141,72 @@ class SDR(ABC):
|
|||
# save the buffer until max reached
|
||||
# return a recording
|
||||
|
||||
buffer = np.array(buffer) # make it 1d
|
||||
if len(buffer.shape) == 1:
|
||||
buffer = np.array([buffer])
|
||||
# Validate buffer
|
||||
if not self._validate_buffer(buffer):
|
||||
print("Warning: Corrupted buffer detected, skipping")
|
||||
self._corrupted_buffer_count += 1
|
||||
return # Skip this buffer
|
||||
|
||||
# it runs these checks each time, is that an efficiency issue?
|
||||
|
||||
if self._max_num_buffers is None:
|
||||
# default then
|
||||
# this should probably print, but that would happen every buffer...
|
||||
raise ValueError("Number of buffers for block capture not set.")
|
||||
|
||||
# add the given buffer to the pre-allocated buffer
|
||||
|
||||
if metadata is not None:
|
||||
self.received_metadata = metadata
|
||||
|
||||
# TODO optimize, pre-allocate
|
||||
if self._accumulated_buffer is not None:
|
||||
self._accumulated_buffer = np.concatenate((self._accumulated_buffer, buffer), axis=1)
|
||||
if isinstance(buffer, np.ndarray):
|
||||
if buffer.ndim == 1:
|
||||
buffer = buffer[np.newaxis, :] # make shape (1, N)
|
||||
else:
|
||||
# the first time
|
||||
self._accumulated_buffer = buffer.copy()
|
||||
buffer = np.array(buffer) # make it 1d
|
||||
if len(buffer.shape) == 1:
|
||||
buffer = np.array([buffer])
|
||||
|
||||
self._num_buffers_processed = self._num_buffers_processed + 1
|
||||
# First call: pre-allocate if we know the final size
|
||||
if self._accumulated_buffer is None:
|
||||
# Check that _max_num_buffers is set
|
||||
if self._max_num_buffers is None:
|
||||
raise ValueError("Number of buffers for block capture not set.")
|
||||
if self._num_samples_to_record is None:
|
||||
raise ValueError("Number of samples not set before RX start.")
|
||||
|
||||
if metadata is not None:
|
||||
self.received_metadata = metadata
|
||||
|
||||
# Preallocate once (avoid np.zeros; use np.empty for speed)
|
||||
num_channels = buffer.shape[0]
|
||||
self._accumulated_buffer = np.empty((num_channels, self._num_samples_to_record), dtype=buffer.dtype)
|
||||
self._write_position = 0
|
||||
print(f"Pre-allocated buffer for {self._num_samples_to_record:,} samples.")
|
||||
|
||||
# Write new buffer into pre-allocated array
|
||||
n = buffer.shape[1]
|
||||
start = self._write_position
|
||||
end = min(start + n, self._num_samples_to_record)
|
||||
samples_to_write = end - start
|
||||
|
||||
if samples_to_write > 0:
|
||||
self._accumulated_buffer[:, start:end] = buffer[:, : end - start]
|
||||
self._write_position = end
|
||||
|
||||
# Check if we're done
|
||||
self._num_buffers_processed += 1
|
||||
if self._num_buffers_processed >= self._max_num_buffers:
|
||||
self.stop()
|
||||
|
||||
if self._last_buffer is not None:
|
||||
if (buffer == self._last_buffer).all():
|
||||
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
|
||||
self._last_buffer = buffer.copy()
|
||||
else:
|
||||
self._last_buffer = buffer.copy()
|
||||
|
||||
# print("Number of buffers received: " + str(self._num_buffers_processed))
|
||||
def _validate_buffer(self, buffer):
|
||||
"""Check for obviously corrupt data."""
|
||||
# Check for all zeros
|
||||
if np.all(buffer == 0):
|
||||
return False
|
||||
# Check for all same value
|
||||
if np.all(buffer == buffer[0]):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _zmq_bytestream_callback(self, buffer, metadata=None):
|
||||
# push to ZMQ port
|
||||
data = np.array(buffer).tobytes() # convert to bytes for transport
|
||||
self.socket.send(data)
|
||||
|
||||
# print(f"Sent {self._num_buffers_processed} ZMQ buffers to {self.zmq_address}")
|
||||
|
||||
self._num_buffers_processed = self._num_buffers_processed + 1
|
||||
if self._max_num_buffers is not None:
|
||||
if self._num_buffers_processed >= self._max_num_buffers:
|
||||
self.pause_rx()
|
||||
|
||||
if self._previous_buffer is not None:
|
||||
if (buffer == self._previous_buffer).all():
|
||||
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
|
||||
# TODO: I suggest we think about moving this part to the top of this function
|
||||
# and skip the rest of the function in case of overflow.
|
||||
# like, it's not necessary to stream repeated IQ data anyways!
|
||||
self._previous_buffer = buffer.copy()
|
||||
|
||||
def pickle_buffer_to_zmq(self, zmq_address, buffer_size, num_buffers):
|
||||
"""
|
||||
Stream samples to a zmq address, packaged in binary buffers using numpy.pickle.
|
||||
|
|
@ -229,7 +246,7 @@ class SDR(ABC):
|
|||
self.stop()
|
||||
|
||||
if self._last_buffer is not None:
|
||||
if (buffer == self._last_buffer).all():
|
||||
if np.array_equal(buffer, self._last_buffer):
|
||||
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
|
||||
self._last_buffer = buffer.copy()
|
||||
else:
|
||||
|
|
@ -373,6 +390,58 @@ class SDR(ABC):
|
|||
"""
|
||||
return self.tx_gain
|
||||
|
||||
def set_rx_sample_rate(self):
|
||||
"""
|
||||
Set the sample rate of the receiver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_rx_center_frequency(self):
|
||||
"""
|
||||
Set the center frequency of the receiver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_rx_gain(self):
|
||||
"""
|
||||
Set the gain setting of the receiver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_tx_sample_rate(self):
|
||||
"""
|
||||
Set the sample rate of the transmitter.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_tx_center_frequency(self):
|
||||
"""
|
||||
Set the center frequency of the transmitter.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_tx_gain(self):
|
||||
"""
|
||||
Set the gain setting of the transmitter.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def supports_dynamic_updates(self) -> dict:
|
||||
"""
|
||||
Report which parameters can be updated during streaming.
|
||||
|
||||
Returns:
|
||||
dict: {'center_frequency': bool, 'sample_rate': bool, 'gain': bool}
|
||||
"""
|
||||
return {"center_frequency": False, "sample_rate": False, "gain": False}
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup on garbage collection."""
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def close(self):
|
||||
pass
|
||||
|
|
@ -442,3 +511,21 @@ def _verify_sample_format(samples):
|
|||
"""
|
||||
|
||||
return np.max(np.abs(samples)) <= 1
|
||||
|
||||
|
||||
class SDRError(Exception):
|
||||
"""Base exception for SDR errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SDRParameterError(SDRError):
|
||||
"""Invalid parameter (sample rate, freq, gain)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SDROverflowError(SDRError):
|
||||
"""Buffer overflow detected."""
|
||||
|
||||
pass
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user