diff --git a/src/ria_toolkit_oss/sdr/usrp.py b/src/ria_toolkit_oss/sdr/usrp.py index abf4e3d..3ae7569 100644 --- a/src/ria_toolkit_oss/sdr/usrp.py +++ b/src/ria_toolkit_oss/sdr/usrp.py @@ -7,7 +7,7 @@ import numpy as np import uhd from ria_toolkit_oss.data.recording import Recording -from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError +from ria_toolkit_oss.sdr.sdr import SDR, SdrDisconnectedError, SDRError, SDRParameterError class USRP(SDR): @@ -32,6 +32,13 @@ class USRP(SDR): self._rx_initialized = False self._tx_initialized = False + # True once a continuous RX stream has been started (see rx()). Kept + # running across rx() calls so the agent streamer gets gapless capture + # instead of a start/stop per buffer. + self._rx_streaming = False + # Samples received past the end of one rx() request, carried into the + # next call so nothing is dropped between buffers. + self._rx_residual = np.empty(0, dtype=np.complex64) def init_rx( self, @@ -96,6 +103,8 @@ class USRP(SDR): # flag to prevent user from calling certain functions before this one. self._rx_initialized = True self._tx_initialized = False + self._rx_streaming = False # (re)started lazily on the first rx() call + self._rx_residual = np.empty(0, dtype=np.complex64) return {"sample_rate": self.rx_sample_rate, "center_frequency": self.rx_center_frequency, "gain": self.rx_gain} @@ -265,6 +274,97 @@ class USRP(SDR): return Recording(data=store_array[:, :num_samples], metadata=metadata) + def rx(self, num_samples: int) -> "np.ndarray": + """Return *num_samples* complex64 IQ samples from a continuous RX stream. + + This is the interface the agent streamer's capture loop calls every + buffer. Unlike ``record()`` (a one-shot that issues ``start_cont`` / + ``stop_cont`` and sleeps each call), this keeps a single continuous + stream running across calls, so capture is gapless — no per-buffer + start/stop churn, transients, or zero-filled gaps that show up as black + bands in the spectrogram. + + On the first call it auto-initializes RX (from ``sample_rate`` / + ``center_freq`` / ``gain`` set by the caller) and issues ``start_cont`` + once. ``close()`` (or ``stop()``) stops the stream. + """ + if not self._rx_initialized: + gain = self.gain if isinstance(self.gain, (int, float)) else 40.0 + self.init_rx( + sample_rate=self.sample_rate, + center_frequency=self.center_freq, + gain=gain, + channel=0, + ) + + if not self._rx_streaming: + stream_command = uhd.types.StreamCMD(uhd.types.StreamMode.start_cont) + stream_command.stream_now = True + self.rx_stream.issue_stream_cmd(stream_command) + self._enable_rx = True + self._rx_streaming = True + print("USRP Starting RX (continuous)...") + + out = np.empty(num_samples, dtype=np.complex64) + filled = 0 + + # Drain any samples carried over from the previous call first. + if self._rx_residual.size: + take = min(self._rx_residual.size, num_samples) + out[:take] = self._rx_residual[:take] + self._rx_residual = self._rx_residual[take:] + filled = take + + recv_buffer = np.zeros((1, self.rx_buffer_size), dtype=np.complex64) + consecutive_timeouts = 0 + error_codes = uhd.types.RXMetadataErrorCode + + while filled < num_samples: + n = self.rx_stream.recv(recv_buffer, self.metadata, self.timeout) + err = self.metadata.error_code + + if err == error_codes.timeout: + consecutive_timeouts += 1 + # A stalled stream is a disconnect, not a transient hiccup. + if consecutive_timeouts >= 5: + self._rx_streaming = False + raise SdrDisconnectedError("USRP RX timed out repeatedly — device may be disconnected") + continue + consecutive_timeouts = 0 + + # Overflow ("O") means the host fell behind and UHD dropped samples + # upstream; the samples we did get are still valid, so keep going. + if err not in (error_codes.none, error_codes.overflow): + self._rx_streaming = False + raise SDRError(f"USRP RX error: {err}") + + if n <= 0: + continue + take = min(n, num_samples - filled) + out[filled : filled + take] = recv_buffer[0, :take] + filled += take + # Keep anything received past this request for the next call so the + # stream stays gapless across rx() boundaries. + if take < n: + self._rx_residual = recv_buffer[0, take:n].copy() + + return out + + def _stop_rx_stream(self) -> None: + """Issue stop_cont for the continuous RX stream, if running.""" + if not self._rx_streaming: + return + self._enable_rx = False + try: + stop_cmd = uhd.types.StreamCMD(uhd.types.StreamMode.stop_cont) + stop_cmd.stream_now = True + self.rx_stream.issue_stream_cmd(stop_cmd) + except Exception: + pass + self._rx_streaming = False + self._rx_residual = np.empty(0, dtype=np.complex64) + print("USRP RX stopped.") + def init_tx( self, sample_rate: int | float, @@ -371,6 +471,7 @@ class USRP(SDR): print(f"USRP TX Gain = {self.tx_gain}") def close(self): + self._stop_rx_stream() self._tx_initialized = False self._rx_initialized = False if hasattr(self, "rx_stream"): diff --git a/tests/agent/test_usrp_rx.py b/tests/agent/test_usrp_rx.py new file mode 100644 index 0000000..aad9672 --- /dev/null +++ b/tests/agent/test_usrp_rx.py @@ -0,0 +1,142 @@ +"""Hardware-free tests for the USRP continuous-streaming rx(). + +`uhd` isn't importable without the UHD install, so we stub the bits USRP.rx() +touches and drive it with a scripted fake rx_stream. The point is to prove the +capture is gapless across rx() calls — the property that fixes the choppy / +black-banded spectrogram caused by the old start/stop-per-buffer record(). +""" + +from __future__ import annotations + +import sys +import types + +import numpy as np +import pytest + + +def _install_fake_uhd(): + uhd = types.ModuleType("uhd") + + class StreamCMD: + def __init__(self, mode): + self.mode = mode + self.stream_now = False + self.time_spec = None + + uhd.types = types.SimpleNamespace( + StreamCMD=StreamCMD, + StreamMode=types.SimpleNamespace(start_cont="start_cont", stop_cont="stop_cont"), + RXMetadataErrorCode=types.SimpleNamespace(none="none", overflow="overflow", timeout="timeout"), + ) + uhd.usrp = types.SimpleNamespace() + sys.modules["uhd"] = uhd + return uhd + + +@pytest.fixture +def USRP(): + # Snapshot so the fake uhd / freshly-imported usrp don't leak into other + # tests (e.g. detect_available() would otherwise think usrp is importable). + saved_uhd = sys.modules.get("uhd") + saved_usrp = sys.modules.get("ria_toolkit_oss.sdr.usrp") + + _install_fake_uhd() + sys.modules.pop("ria_toolkit_oss.sdr.usrp", None) + from ria_toolkit_oss.sdr.usrp import USRP as _USRP + + yield _USRP + + for name, mod in (("uhd", saved_uhd), ("ria_toolkit_oss.sdr.usrp", saved_usrp)): + if mod is None: + sys.modules.pop(name, None) + else: + sys.modules[name] = mod + + +class _FakeStream: + """Delivers a contiguous ramp of samples; ``real`` part is the sample index. + + ``script`` is a list of (count, error_code) the recv loop walks through. + """ + + def __init__(self, script, metadata): + self._script = list(script) + self._metadata = metadata + self._counter = 0 + self.issued = [] + + def issue_stream_cmd(self, cmd): + self.issued.append(cmd.mode) + + def recv(self, buffer, metadata, timeout): + count, err = self._script.pop(0) + metadata.error_code = err + if count > 0: + idx = np.arange(self._counter, self._counter + count, dtype=np.float32) + buffer[0, :count] = idx.astype(np.complex64) + self._counter += count + return count + + +def _make_usrp(USRP, script, rx_buffer_size=4): + u = USRP.__new__(USRP) + u._rx_initialized = True + u._rx_streaming = False + u._rx_residual = np.empty(0, dtype=np.complex64) + u.rx_buffer_size = rx_buffer_size + u.timeout = 0.1 + u._enable_rx = False + u.metadata = types.SimpleNamespace(error_code="none") + u.rx_stream = _FakeStream(script, u.metadata) + return u + + +def test_rx_is_gapless_across_calls(USRP): + # rx_buffer_size=4; each recv yields 4 fresh samples. Two rx(6) calls must + # return a contiguous 0..11 ramp — the over-read remainder is carried over. + script = [(4, "none")] * 4 + u = _make_usrp(USRP, script) + + first = u.rx(6) + second = u.rx(6) + + assert first.dtype == np.complex64 and len(first) == 6 + combined = np.concatenate([first, second]).real + assert np.array_equal(combined, np.arange(12, dtype=np.float32)) # no drops, no zeros + assert "start_cont" in u.rx_stream.issued # stream started exactly via start_cont + assert u.rx_stream.issued.count("start_cont") == 1 # ...and only once + + +def test_rx_starts_stream_only_once(USRP): + u = _make_usrp(USRP, [(4, "none")] * 6) + u.rx(4) + u.rx(4) + assert u.rx_stream.issued.count("start_cont") == 1 + + +def test_rx_keeps_going_on_overflow(USRP): + # Overflow samples are still valid — they must be used, not dropped. + script = [(2, "none"), (2, "overflow"), (2, "none")] + u = _make_usrp(USRP, script) + out = u.rx(6).real + assert np.array_equal(out, np.arange(6, dtype=np.float32)) + + +def test_rx_raises_on_persistent_timeout(USRP): + from ria_toolkit_oss.sdr.sdr import SdrDisconnectedError + + u = _make_usrp(USRP, [(0, "timeout")] * 10) + with pytest.raises(SdrDisconnectedError): + u.rx(4) + + +def test_stop_rx_stream_resets_state(USRP): + u = _make_usrp(USRP, [(4, "none")] * 4) + u.rx(6) # leaves a 2-sample residual, stream running + assert u._rx_streaming is True + assert u._rx_residual.size == 2 + u._stop_rx_stream() + assert u._rx_streaming is False + assert u._rx_residual.size == 0 + assert "stop_cont" in u.rx_stream.issued