# ria-toolkit-oss/tests/remote_control/test_remote_transmitter_controller.py
"""Tests for RemoteTransmitterController — mocks paramiko and ZMQ entirely.
paramiko and zmq are optional runtime deps; these tests inject fakes into
sys.modules so they run regardless of whether the packages are installed.
"""
from __future__ import annotations

import json
import threading
import time
from types import ModuleType
from unittest.mock import MagicMock, patch

import pytest
# ---------------------------------------------------------------------------
# Fake modules injected into sys.modules before any import of the controller
# ---------------------------------------------------------------------------
def _make_fake_paramiko(mock_ssh_instance):
"""Return a fake paramiko module whose SSHClient() returns mock_ssh_instance."""
mod = MagicMock(spec=ModuleType)
mod.SSHClient = MagicMock(return_value=mock_ssh_instance)
mod.AutoAddPolicy = MagicMock()
return mod
def _make_fake_zmq(mock_socket_instance):
"""Return a fake zmq module whose Context().socket() returns mock_socket_instance."""
mock_context = MagicMock()
mock_context.socket.return_value = mock_socket_instance
mod = MagicMock(spec=ModuleType)
mod.Context = MagicMock(return_value=mock_context)
mod.REQ = "REQ"
return mod, mock_context
# ---------------------------------------------------------------------------
# Test helpers (plain factory functions, not pytest fixtures)
# ---------------------------------------------------------------------------
def _ok_response(fn="set_radio") -> bytes:
return json.dumps({"status": True, "message": "", "error_message": ""}).encode()
def _err_response(fn="set_radio", msg="boom") -> bytes:
return json.dumps({"status": False, "message": "", "error_message": msg}).encode()
def _make_mock_socket(recv_side_effect=None):
sock = MagicMock()
if recv_side_effect is not None:
sock.recv.side_effect = recv_side_effect
else:
sock.recv.return_value = _ok_response()
return sock
def _make_controller(mock_socket=None, *, startup_wait=0):
    """Build a RemoteTransmitterController with all external I/O mocked.

    Fake ``paramiko`` and ``zmq`` modules are injected via ``sys.modules``, and
    the controller module's ``_STARTUP_WAIT_S`` is patched to *startup_wait*
    (presumably the delay after launching the remote server — keeping it 0
    keeps tests fast).

    Args:
        mock_socket: ZMQ socket mock handed to the controller. When ``None``, a
            fresh ``_make_mock_socket()`` that always replies OK is used.
        startup_wait: Value patched into ``_STARTUP_WAIT_S`` during construction.

    Returns:
        The controller, with the mocks attached for assertions as
        ``_mock_ssh``, ``_mock_socket``, ``_mock_context`` and ``_fake_paramiko``.
    """
    # Explicit None check: a caller-supplied mock must never be swapped out just
    # because it happens to evaluate as falsy.
    mock_sock = _make_mock_socket() if mock_socket is None else mock_socket

    mock_ssh = MagicMock()
    mock_stdout = MagicMock()
    # NOTE(review): the controller presumably inspects stdout.channel (e.g. exit
    # status); an explicit mock keeps that attribute stable.
    mock_stdout.channel = MagicMock()
    # paramiko's exec_command returns a (stdin, stdout, stderr) triple.
    mock_ssh.exec_command.return_value = (MagicMock(), mock_stdout, MagicMock())

    fake_paramiko = _make_fake_paramiko(mock_ssh)
    fake_zmq, mock_context = _make_fake_zmq(mock_sock)

    # NOTE: patch.dict restores sys.modules to its prior contents on exit, so a
    # controller module first imported inside this block is dropped again
    # afterwards (unless it was already imported beforehand). The controller
    # must therefore import paramiko/zmq while it is being constructed here; a
    # later lazy import would see the real (or missing) packages.
    with (
        patch.dict("sys.modules", {"paramiko": fake_paramiko, "zmq": fake_zmq}),
        patch(
            "ria_toolkit_oss.remote_control.remote_transmitter_controller._STARTUP_WAIT_S",
            startup_wait,
        ),
    ):
        from ria_toolkit_oss.remote_control.remote_transmitter_controller import (
            RemoteTransmitterController,
        )

        ctrl = RemoteTransmitterController(
            host="192.168.1.10",
            ssh_user="ubuntu",
            ssh_key_path="/home/user/.ssh/id_rsa",
            zmq_port=5556,
        )

    ctrl._mock_ssh = mock_ssh
    ctrl._mock_socket = mock_sock
    ctrl._mock_context = mock_context
    ctrl._fake_paramiko = fake_paramiko
    return ctrl
# ---------------------------------------------------------------------------
# Connection setup
# ---------------------------------------------------------------------------
class TestConnectionSetup:
    """Connection-time behaviour: SSH login, remote server launch, ZMQ connect."""

    def test_ssh_connects_with_correct_args(self):
        ctrl = _make_controller()
        ctrl._mock_ssh.connect.assert_called_once_with(
            hostname="192.168.1.10",
            username="ubuntu",
            key_filename="/home/user/.ssh/id_rsa",
        )

    def test_ssh_starts_remote_server(self):
        ctrl = _make_controller()
        # First positional argument of the last exec_command call is the command line.
        cmd = ctrl._mock_ssh.exec_command.call_args[0][0]
        assert "remote_transmitter" in cmd
        assert "--port" in cmd
        assert "5556" in cmd

    def test_zmq_connects_to_host_port(self):
        ctrl = _make_controller()
        ctrl._mock_socket.connect.assert_called_once_with("tcp://192.168.1.10:5556")

    def test_host_key_policy_set_to_auto_add(self):
        """AutoAddPolicy is applied so unknown host keys don't abort headless runs.

        paramiko's default policy rejects unknown hosts, which would make an
        unattended first connection fail.
        """
        ctrl = _make_controller()
        setter = ctrl._mock_ssh.set_missing_host_key_policy
        setter.assert_called_once()
        # Verify it is really AutoAddPolicy; paramiko accepts either the policy
        # class or an instance of it.
        call = setter.call_args
        policy = call.args[0] if call.args else call.kwargs.get("policy")
        auto_add = ctrl._fake_paramiko.AutoAddPolicy
        assert policy is auto_add or policy is auto_add.return_value
# ---------------------------------------------------------------------------
# ZMQ message format
# ---------------------------------------------------------------------------
class TestSendFormat:
    """Public commands are serialised into the expected JSON request dicts."""

    @staticmethod
    def _last_request(ctrl):
        """Decode the JSON bytes passed to the most recent ``socket.send`` call."""
        payload = ctrl._mock_socket.send.call_args[0][0]
        return json.loads(payload.decode())

    def test_set_radio_sends_correct_dict(self):
        ctrl = _make_controller()
        ctrl.set_radio("pluto", "ip:192.168.2.1")
        request = self._last_request(ctrl)
        assert request["function_name"] == "set_radio"
        assert request["radio_str"] == "pluto"
        assert request["identifier"] == "ip:192.168.2.1"

    def test_set_radio_default_identifier(self):
        ctrl = _make_controller()
        ctrl.set_radio("hackrf")
        assert self._last_request(ctrl)["identifier"] == ""

    def test_init_tx_sends_correct_dict(self):
        ctrl = _make_controller()
        ctrl.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=30, channel=1)
        request = self._last_request(ctrl)
        assert request["function_name"] == "init_tx"
        assert request["center_frequency"] == pytest.approx(2.4e9)
        assert request["sample_rate"] == pytest.approx(20e6)
        assert request["gain"] == pytest.approx(30)
        assert request["channel"] == 1
        # gain_mode is not passed explicitly; "absolute" is expected by default.
        assert request["gain_mode"] == "absolute"

    def test_init_tx_default_channel_zero(self):
        ctrl = _make_controller()
        ctrl.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=0)
        assert self._last_request(ctrl)["channel"] == 0

    def test_stop_sends_correct_dict(self):
        ctrl = _make_controller()
        ctrl.stop()
        assert self._last_request(ctrl)["function_name"] == "stop"
# ---------------------------------------------------------------------------
# Error handling
# ---------------------------------------------------------------------------
class TestErrorHandling:
    """Server-reported failures and misuse surface as exceptions."""

    def test_error_response_raises_runtime_error(self):
        sock = _make_mock_socket()
        sock.recv.return_value = _err_response(msg="radio not found")
        ctrl = _make_controller(mock_socket=sock)
        with pytest.raises(RuntimeError, match="radio not found"):
            ctrl.set_radio("pluto")

    def test_error_message_included_in_exception(self):
        sock = _make_mock_socket()
        sock.recv.return_value = _err_response(msg="gain out of range")
        ctrl = _make_controller(mock_socket=sock)
        with pytest.raises(RuntimeError, match="gain out of range"):
            ctrl.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=999)

    def test_send_on_closed_controller_raises(self):
        ctrl = _make_controller()
        ctrl.close()
        with pytest.raises(RuntimeError, match="closed"):
            ctrl._send({"function_name": "set_radio", "radio_str": "pluto", "identifier": ""})

    def test_missing_paramiko_raises_runtime_error(self):
        """If paramiko is absent, connecting gives a clear RuntimeError.

        ImportError is also accepted in case the import failure propagates
        unwrapped. zmq is replaced with a fake so the failure is attributable to
        the missing paramiko rather than to zmq being absent, and no real ZMQ
        context is created.
        """
        import ria_toolkit_oss.remote_control.remote_transmitter_controller as mod

        fake_zmq, _ = _make_fake_zmq(_make_mock_socket())
        # A None entry in sys.modules makes `import paramiko` raise ImportError.
        with patch.dict("sys.modules", {"paramiko": None, "zmq": fake_zmq}):
            with pytest.raises((RuntimeError, ImportError)):
                mod.RemoteTransmitterController(host="h", ssh_user="u", ssh_key_path="/k")
# ---------------------------------------------------------------------------
# transmit_async / wait_transmit
# ---------------------------------------------------------------------------
class TestTransmitAsync:
    """transmit_async runs the request/reply in a background thread."""

    def test_transmit_async_returns_immediately(self):
        """transmit_async must not block — the ZMQ recv may take duration_s seconds.

        The fake ``recv`` blocks until the test releases it, so the check does
        not rely on wall-clock thresholds. If transmit_async waited for the
        reply, ``recv`` would already have finished (after its safety timeout)
        by the time the call returned.
        """
        release = threading.Event()
        recv_finished = threading.Event()

        def gated_recv(*_args, **_kwargs):
            # Safety timeout: a regression fails the test instead of hanging it.
            release.wait(timeout=5.0)
            recv_finished.set()
            return _ok_response("transmit")

        sock = _make_mock_socket(recv_side_effect=gated_recv)
        ctrl = _make_controller(mock_socket=sock)
        try:
            ctrl.transmit_async(duration_s=5.0)
            assert not recv_finished.is_set(), "transmit_async must not block"
        finally:
            # Always unblock and join the worker so it cannot leak past the test.
            release.set()
            ctrl.wait_transmit(timeout=2.0)

    def test_transmit_async_sends_correct_duration(self):
        ctrl = _make_controller()
        ctrl.transmit_async(duration_s=12.5)
        ctrl.wait_transmit(timeout=1.0)
        sent = json.loads(ctrl._mock_socket.send.call_args[0][0].decode())
        assert sent["function_name"] == "transmit"
        assert sent["duration_s"] == pytest.approx(12.5)

    def test_wait_transmit_joins_thread(self):
        ctrl = _make_controller()
        ctrl.transmit_async(duration_s=0.01)
        ctrl.wait_transmit(timeout=2.0)
        # The thread handle is expected to be cleared once the worker is joined.
        assert ctrl._tx_thread is None

    def test_wait_transmit_noop_if_no_thread(self):
        ctrl = _make_controller()
        ctrl.wait_transmit()  # should not raise

    def test_transmit_async_error_is_logged_not_raised(self):
        """Background thread errors must not propagate to caller."""
        sock = _make_mock_socket()
        sock.recv.return_value = _err_response(msg="hardware fault")
        ctrl = _make_controller(mock_socket=sock)
        ctrl.transmit_async(duration_s=0.01)
        ctrl.wait_transmit(timeout=2.0)  # should not raise
# ---------------------------------------------------------------------------
# close / teardown
# ---------------------------------------------------------------------------
class TestClose:
    """close() releases the ZMQ socket/context and the SSH client."""

    def test_close_terminates_zmq_context(self):
        ctrl = _make_controller()
        ctrl.close()
        ctrl._mock_context.term.assert_called_once()

    def test_close_closes_zmq_socket(self):
        ctrl = _make_controller()
        ctrl.close()
        ctrl._mock_socket.close.assert_called_once()

    def test_close_closes_ssh(self):
        ctrl = _make_controller()
        ctrl.close()
        ctrl._mock_ssh.close.assert_called_once()

    def test_close_is_idempotent(self):
        ctrl = _make_controller()
        ctrl.close()
        ctrl.close()  # second call must not raise
        # ...nor release the underlying resources a second time.
        ctrl._mock_socket.close.assert_called_once()
        ctrl._mock_ssh.close.assert_called_once()

    def test_stop_calls_close(self):
        ctrl = _make_controller()
        ctrl.stop()
        assert ctrl._socket is None
        assert ctrl._ssh is None