import math
import pickle
import threading
import warnings
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np
import zmq

from ria_toolkit_oss.datatypes.recording import Recording


class SDR(ABC):
    """
    This class defines a common interface (a template) for all SDR devices.
    Each specific SDR implementation should subclass SDR and provide concrete implementations
    for the abstract methods.

    To add support for a new radio, subclass this interface and implement all abstract methods.
    If you experience difficulties, please contact us, we are happy to
    provide additional direction and/or help with the implementation details.
    """

    def __init__(self):
        # RX/TX lifecycle flags. ``_enable_*`` are cleared by ``pause_rx()`` / ``pause_tx()``.
        self._rx_initialized = False
        self._tx_initialized = False
        self._enable_rx = False
        self._enable_tx = False

        # Block-capture state shared between ``record()`` and the streaming callbacks.
        self._accumulated_buffer = None
        self._max_num_buffers = None
        self._num_buffers_processed = 0
        self._num_samples_to_record = None
        self._write_position = 0
        self._last_buffer = None
        self._corrupted_buffer_count = 0

        # Radio parameters; this base class only ever reads them.
        self.rx_sample_rate = None
        self.rx_center_frequency = None
        self.rx_gain = None
        self.tx_sample_rate = None
        self.tx_center_frequency = None
        self.tx_gain = None

        self._param_lock = threading.RLock()  # Reentrant lock

    def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
        """
        Create a radio recording of a given length. Either ``num_samples`` or ``rx_time`` must be provided.

        Note that ``init_rx()`` must be called before ``record()``.

        :param num_samples: The number of samples to record.
        :type num_samples: int, optional
        :param rx_time: The time to record. It is multiplied by ``rx_sample_rate`` to get a sample count.
        :type rx_time: int or float, optional

        :raises RuntimeError: If RX was not initialized, or if the stream ended without a single valid buffer.
        :raises ValueError: If both or neither of ``num_samples`` and ``rx_time`` are given, or if the
            resulting sample count is negative.

        :return: The Recording object
        :rtype: Recording
        """
        if not self._rx_initialized:
            raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")

        if num_samples is not None and rx_time is not None:
            raise ValueError("Only input one of num_samples or rx_time")
        elif num_samples is not None:
            # Coerce so that float-like counts (e.g. 1e6) can be used as an array dimension.
            self._num_samples_to_record = int(num_samples)
        elif rx_time is not None:
            self._num_samples_to_record = int(rx_time * self.rx_sample_rate)
        else:
            raise ValueError("Must provide input of one of num_samples or rx_time")

        if self._num_samples_to_record < 0:
            raise ValueError("The number of samples to record must not be negative")

        # ``rx_buffer_size`` is not defined in this base class; it is read from the instance.
        self.buffer_size = self.rx_buffer_size
        self._max_num_buffers = self._num_samples_to_record // self.buffer_size + 1

        # Reset all capture state so that consecutive recordings do not interfere.
        self._num_buffers_processed = 0
        self._last_buffer = None
        self._accumulated_buffer = None
        self._write_position = 0

        print("Starting stream")
        self._stream_rx(
            callback=self._accumulate_buffers_callback,
        )
        print("Finished stream")

        if self._accumulated_buffer is None:
            # Nothing was accumulated (e.g. every buffer was rejected by ``_validate_buffer``).
            # Fail with a clear message instead of a TypeError from slicing ``None``.
            raise RuntimeError("No valid samples were received during the recording.")

        metadata = {
            "source": self.__class__.__name__,
            "sample_rate": self.rx_sample_rate,
            "center_frequency": self.rx_center_frequency,
            "gain": self.rx_gain,
        }

        print("Creating recording")
        # build recording, truncate to self._num_samples_to_record
        recording = Recording(data=self._accumulated_buffer[:, : self._num_samples_to_record], metadata=metadata)

        # reset to record again
        self._accumulated_buffer = None
        self._num_buffers_processed = 0
        return recording

    def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
        """
        Stream iq samples as interleaved bytes via zmq.

        The socket and context are always closed when streaming ends, even on error.
        This method does not return a value.

        :param zmq_address: The zmq address, or a bare port number (see ``_generate_full_zmq_address``).
        :type zmq_address: str or int
        :param n_samples: The number of samples to stream. Use ``np.inf`` to stream without a buffer limit.
        :type n_samples: int
        :param buffer_size: The buffer size used to convert ``n_samples`` into a number of buffers.
            Defaults to 10000.
        :type buffer_size: int, optional
        """
        if buffer_size is None:
            # The annotation allows None; fall back to the documented default.
            buffer_size = 10000

        try:
            self._previous_buffer = None
            self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size)
            self._num_buffers_processed = 0
            self.zmq_address = _generate_full_zmq_address(str(zmq_address))

            self.context = zmq.Context()
            self.socket = self.context.socket(zmq.PUB)
            self.socket.bind(self.zmq_address)

            self._stream_rx(
                self._zmq_bytestream_callback,
            )
        finally:
            if hasattr(self, "socket"):
                self.socket.close()
            if hasattr(self, "context"):
                self.context.destroy()

    def _accumulate_buffers_callback(self, buffer, metadata=None):
        """
        Receive one buffer and copy it into the pre-allocated ``self._accumulated_buffer``.

        Buffers that fail ``_validate_buffer`` are counted and skipped. Streaming is stopped
        once the expected number of buffers was processed or the capture array is full.

        :param buffer: One- or two-dimensional array of samples; 1-D input is treated as one channel.
        :param metadata: Optional metadata; stored as ``self.received_metadata`` on the first valid buffer.

        :raises ValueError: If the capture was not configured before the first valid buffer arrived.
        """
        # expected buffer is complex samples range -1 to 1

        if not self._validate_buffer(buffer):
            print("Warning: Corrupted buffer detected, skipping")
            self._corrupted_buffer_count += 1
            return  # Skip this buffer

        # Normalise to shape (channels, samples); 1-D input becomes (1, N).
        buffer = np.atleast_2d(np.asarray(buffer))

        num_samples_to_record = getattr(self, "_num_samples_to_record", None)

        # First call: pre-allocate, since the final size is known.
        if self._accumulated_buffer is None:
            if self._max_num_buffers is None:
                raise ValueError("Number of buffers for block capture not set.")
            if num_samples_to_record is None:
                raise ValueError("Number of samples not set before RX start.")
            if metadata is not None:
                self.received_metadata = metadata

            # Preallocate once (avoid np.zeros; use np.empty for speed)
            num_channels = buffer.shape[0]
            self._accumulated_buffer = np.empty((num_channels, num_samples_to_record), dtype=buffer.dtype)
            self._write_position = 0
            print(f"Pre-allocated buffer for {self._num_samples_to_record:,} samples.")

        # Write the new buffer into the pre-allocated array, clipping at the end of the capture.
        start = self._write_position
        end = min(start + buffer.shape[1], num_samples_to_record)
        samples_to_write = end - start
        if samples_to_write > 0:
            self._accumulated_buffer[:, start:end] = buffer[:, :samples_to_write]
            self._write_position = end

        # Done when the buffer budget is used up, or as soon as the capture array is full
        # (avoids waiting for a surplus buffer whose samples would all be discarded).
        self._num_buffers_processed += 1
        if (
            self._num_buffers_processed >= self._max_num_buffers
            or self._write_position >= num_samples_to_record
        ):
            self.stop()

    def _validate_buffer(self, buffer):
        """
        Check for obviously corrupt data.

        A buffer is rejected when it is empty, all zeros, or when every sample
        (across all channels) has the same value.

        :param buffer: Array-like of samples, any dimensionality.

        :return: True if the buffer looks usable, False otherwise.
        :rtype: bool
        """
        samples = np.asarray(buffer)

        if samples.size == 0:
            return False

        # Check for all zeros
        if np.all(samples == 0):
            return False

        # Check for all same value. Compare against the first *scalar*: indexing with [0]
        # would pick the first row of a 2-D buffer and flag every (1, N) buffer as corrupt.
        if np.all(samples == samples.flat[0]):
            return False

        return True

    def _zmq_bytestream_callback(self, buffer, metadata=None):
        """
        Publish one buffer on ``self.socket`` as raw bytes and pause RX once the buffer limit is reached.

        :param buffer: Array-like of samples.
        :param metadata: Unused.
        """
        # push to ZMQ port
        data = np.array(buffer).tobytes()  # convert to bytes for transport
        self.socket.send(data)

        self._num_buffers_processed = self._num_buffers_processed + 1
        if self._max_num_buffers is not None:
            if self._num_buffers_processed >= self._max_num_buffers:
                self.pause_rx()

    def pickle_buffer_to_zmq(self, zmq_address, buffer_size, num_buffers):
        """
        Stream samples to a zmq address, packaged in binary buffers using pickle.
        Useful for inference applications with a known input size.
        May reduce transfer rates, but individual buffers will not have discontinuities.

        The socket and context are always closed when streaming ends, even on error.

        :param zmq_address: The tcp address to stream to.
        :type zmq_address: str
        :param buffer_size: The number of iq samples in a buffer.
        :type buffer_size: int
        :param num_buffers: The number of buffers to stream before stopping.
        :type num_buffers: int
        """
        self._max_num_buffers = num_buffers
        self.buffer_size = buffer_size
        self._num_buffers_processed = 0
        # Do not compare the first buffer of this session against one from an earlier session.
        self._last_buffer = None
        self.zmq_address = _generate_full_zmq_address(str(zmq_address))

        try:
            self.context = zmq.Context()
            self.socket = self.context.socket(zmq.PUB)
            self.socket.bind(self.zmq_address)

            # ``set_rx_buffer_size`` is not defined in this base class.
            self.set_rx_buffer_size(buffer_size)

            self._stream_rx(self._zmq_pickle_buffer_callback)
        finally:
            # Release the bound address, mirroring ``stream_to_zmq``.
            if hasattr(self, "socket"):
                self.socket.close()
            if hasattr(self, "context"):
                self.context.destroy()

    def _zmq_pickle_buffer_callback(self, buffer, metadata=None):
        """
        Publish one pickled buffer on ``self.socket`` and stop once the buffer limit is reached.

        A buffer identical to the previous one is reported as a buffer overflow.

        :param buffer: The samples to send; must provide ``copy()``.
        :param metadata: Unused.
        """
        # NOTE(security): subscribers have to unpickle this payload; unpickling is only
        # safe when the publisher is trusted.
        self.socket.send(pickle.dumps(buffer))

        self._num_buffers_processed = self._num_buffers_processed + 1
        if self._max_num_buffers is not None:
            if self._num_buffers_processed >= self._max_num_buffers:
                self.stop()

        if self._last_buffer is not None and np.array_equal(buffer, self._last_buffer):
            print("\033[93mWarning: Buffer Overflow Detected\033[0m")
        self._last_buffer = buffer.copy()

    def tx_recording(
        self,
        recording: Recording | np.ndarray,
        num_samples: Optional[int] = None,
        tx_time: Optional[int | float] = None,
    ):
        """
        Transmit the given iq samples from the provided recording.
        init_tx() must be called before this function.

        :param recording: The recording to transmit. For a multichannel Recording only channel 0 is used.
        :type recording: Recording or np.ndarray
        :param num_samples: The number of samples to transmit, will repeat or
            truncate the recording to this length. Defaults to None.
        :type num_samples: int, optional
        :param tx_time: The time to transmit, will repeat or truncate the
            recording to this length. Defaults to None.
        :type tx_time: int or float, optional

        :raises RuntimeError: If TX was not initialized.
        :raises TypeError: If ``recording`` is neither a Recording nor a numpy array.
        :raises ValueError: If both ``num_samples`` and ``tx_time`` are given, or if there are no samples.
        """
        if not self._tx_initialized:
            raise RuntimeError(
                "TX was not initialized. init_tx() must be called before _stream_tx() or transmit_recording()"
            )

        if num_samples is not None and tx_time is not None:
            raise ValueError("Only input one of num_samples or tx_time")

        # Resolve the samples first, so an unsupported input can never leave stale
        # samples from a previous call in place.
        if isinstance(recording, np.ndarray):
            samples = recording
        elif isinstance(recording, Recording):
            if len(recording.data) > 1:
                warnings.warn("Recording object is multichannel, only channel 0 data was used for transmission")
            samples = recording.data[0]
        else:
            raise TypeError(f"recording must be a Recording or np.ndarray, got {type(recording).__name__}")

        if len(samples) == 0:
            # The looping callback wraps indices modulo the length, which fails on empty input.
            raise ValueError("recording does not contain any samples to transmit")

        if num_samples is not None:
            self._num_samples_to_transmit = num_samples
        elif tx_time is not None:
            # Whole number of samples, consistent with how record() converts rx_time.
            self._num_samples_to_transmit = int(tx_time * self.tx_sample_rate)
        else:
            self._num_samples_to_transmit = len(samples)

        self._samples_to_transmit = samples
        self._num_samples_transmitted = 0

        self._stream_tx(self._loop_recording_callback)

    def _loop_recording_callback(self, num_samples):
        """
        Return the next ``num_samples`` samples to transmit, looping over the stored samples.

        Once the requested total is reached, the remainder of the returned buffer is
        zero-filled and TX is paused.

        :param num_samples: The number of samples requested for this buffer.
        :type num_samples: int

        :return: An array of ``num_samples`` samples.
        :rtype: np.ndarray
        """
        samples_left = self._num_samples_to_transmit - self._num_samples_transmitted
        source_length = len(self._samples_to_transmit)

        # find where to start based on num_samples_transmitted
        start_index = self._num_samples_transmitted % source_length

        # generates an array of indices that wrap around as many times as necessary.
        indices = np.arange(start_index, start_index + num_samples) % source_length
        # Indexing with an index array yields a copy, so zero-padding below leaves the source intact.
        samples = self._samples_to_transmit[indices]

        # zero pad at the end so we are still giving the requested buffer size
        # while also giving the exact number of non zero samples
        if len(samples) > samples_left:
            # Clamp at 0: a negative index would zero only the tail and leak real samples
            # when the callback is invoked again after the transmission is complete.
            samples[max(int(samples_left), 0) :] = 0
            self.pause_tx()

        self._num_samples_transmitted = self._num_samples_transmitted + num_samples
        return samples

    def supports_bias_tee(self) -> bool:
        """Return True when the radio supports bias-tee control."""
        return False

    def set_bias_tee(self, enable: bool):
        """Enable or disable bias-tee power when supported by the radio."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support bias-tee control")

    def pause_rx(self):
        """Clear the RX enable flag."""
        self._enable_rx = False

    def pause_tx(self):
        """Clear the TX enable flag."""
        self._enable_tx = False

    def stop(self):
        """Pause both RX and TX."""
        self.pause_rx()
        self.pause_tx()

    def get_rx_sample_rate(self):
        """
        Retrieve the current sample rate of the receiver.

        Returns:
            float: The receiver's sample rate in samples per second (Hz).
        """
        return self.rx_sample_rate

    def get_rx_center_frequency(self):
        """
        Retrieve the current center frequency of the receiver.

        Returns:
            float: The receiver's center frequency in Hertz (Hz).
        """
        return self.rx_center_frequency

    def get_rx_gain(self):
        """
        Retrieve the current gain setting of the receiver.

        Returns:
            float: The receiver's gain in decibels (dB).
        """
        return self.rx_gain

    def get_tx_sample_rate(self):
        """
        Retrieve the current sample rate of the transmitter.

        Returns:
            float: The transmitter's sample rate in samples per second (Hz).
        """
        return self.tx_sample_rate

    def get_tx_center_frequency(self):
        """
        Retrieve the current center frequency of the transmitter.

        Returns:
            float: The transmitter's center frequency in Hertz (Hz).
        """
        return self.tx_center_frequency

    def get_tx_gain(self):
        """
        Retrieve the current gain setting of the transmitter.

        Returns:
            float: The transmitter's gain in decibels (dB).
        """
        return self.tx_gain

    def set_rx_sample_rate(self):
        """
        Set the sample rate of the receiver.

        Not implemented in the base class.
        """
        raise NotImplementedError

    def set_rx_center_frequency(self):
        """
        Set the center frequency of the receiver.

        Not implemented in the base class.
        """
        raise NotImplementedError

    def set_rx_gain(self):
        """
        Set the gain setting of the receiver.

        Not implemented in the base class.
        """
        raise NotImplementedError

    def set_tx_sample_rate(self):
        """
        Set the sample rate of the transmitter.

        Not implemented in the base class.
        """
        raise NotImplementedError

    def set_tx_center_frequency(self):
        """
        Set the center frequency of the transmitter.

        Not implemented in the base class.
        """
        raise NotImplementedError

    def set_tx_gain(self):
        """
        Set the gain setting of the transmitter.

        Not implemented in the base class.
        """
        raise NotImplementedError

    def supports_dynamic_updates(self) -> dict:
        """
        Report which parameters can be updated during streaming.

        Returns:
            dict: {'center_frequency': bool, 'sample_rate': bool, 'gain': bool}
        """
        return {"center_frequency": False, "sample_rate": False, "gain": False}

    def __del__(self):
        """Cleanup on garbage collection."""
        try:
            self.close()
        except Exception:
            # Deliberately best-effort: any error from close() is ignored during finalization.
            pass

    @abstractmethod
    def close(self):
        pass

    @abstractmethod
    def init_rx(self, sample_rate, center_frequency, gain, channel, gain_mode):
        pass

    @abstractmethod
    def init_tx(self, sample_rate, center_frequency, gain, channel, gain_mode):
        pass

    @abstractmethod
    def _stream_rx(self, callback):
        pass

    @abstractmethod
    def _stream_tx(self, callback):
        pass

    @abstractmethod
    def set_clock_source(self, source):
        """
        Sets the clock source to external or internal.

        :param source: The clock source
        :type source: str
        """
        pass


def _generate_full_zmq_address(input_address):
    """
    Helper function for zmq streaming.

    If given a bare port number like 5556, return a tcp address that binds all
    interfaces at that port (``tcp://*:5556``). Otherwise, return the address untouched.

    :param input_address: A full zmq address or a port number.
    :type input_address: str or int

    :return: The zmq address to bind to.
    """
    if ("://" not in str(input_address)) and _is_valid_port(input_address):
        # If no transport protocol specified, assume TCP
        return "tcp://*:" + str(input_address)
    else:
        # Otherwise, return the input unchanged
        return input_address


def _is_valid_port(port):
    """
    Helper function for zmq address.

    :param port: The candidate port, as an int or a numeric string.

    :return: True if ``port`` converts to an integer in the range 0 to 65535, False otherwise.
    :rtype: bool
    """
    try:
        port_num = int(port)
    except (TypeError, ValueError):
        # TypeError covers inputs such as None that int() cannot convert at all.
        return False
    return 0 <= port_num <= 65535


def _verify_sample_format(samples):
    """
    Verify that the sample data is in the range -1 to 1.

    :param samples: An array of samples.

    :return: True if every sample magnitude is at most 1 (trivially True for empty input), False if not.
    :rtype: bool
    """
    samples = np.asarray(samples)
    if samples.size == 0:
        # np.max raises on empty input; there is nothing out of range.
        return True
    return bool(np.max(np.abs(samples)) <= 1)


class SDRError(Exception):
    """Base exception for SDR errors."""

    pass


class SDRParameterError(SDRError):
    """Invalid parameter (sample rate, freq, gain)."""

    pass


class SDROverflowError(SDRError):
    """Buffer overflow detected."""

    pass