ria-toolkit-oss/tests/ria_toolkit_oss_cli/test_transmit.py
2025-12-11 15:59:08 -05:00

346 lines
13 KiB
Python

"""Tests for transmit command."""
import os
import tempfile
from contextlib import contextmanager
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
import yaml
from click.exceptions import ClickException
from click.testing import CliRunner

from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit import (
    auto_select_tx_device,
    check_sample_rate_mismatch,
    load_input_file,
    transmit,
    validate_tx_gain,
)
class TestGetTxDevice:
    """Tests for get_sdr_device function."""

    # NOTE(review): the stubbed module keys below carry a ``src.`` prefix,
    # while the code under test is imported as ``ria_toolkit_oss...``. The
    # stub only takes effect if get_sdr_device imports exactly these module
    # paths -- confirm against common.py.
    @staticmethod
    def _get_with_stubbed_driver(device, module, class_name):
        """Call get_sdr_device(device) with *module* replaced in sys.modules.

        The replacement module exposes a mock class named *class_name*.

        Returns:
            A ``(device, expected)`` tuple: what get_sdr_device returned, and
            the instance the mock class hands out when called.
        """
        sdr_class = MagicMock()
        stub_module = MagicMock(**{class_name: sdr_class})
        with patch.dict("sys.modules", {module: stub_module}):
            return get_sdr_device(device), sdr_class.return_value

    def test_get_pluto_device(self):
        """Test getting PlutoSDR device."""
        device, expected = self._get_with_stubbed_driver(
            "pluto", "src.ria_toolkit_oss.sdr.pluto", "Pluto"
        )
        assert device is expected

    def test_get_hackrf_device(self):
        """Test getting HackRF device."""
        device, expected = self._get_with_stubbed_driver(
            "hackrf", "src.ria_toolkit_oss.sdr.hackrf", "HackRF"
        )
        assert device is expected

    def test_get_unknown_device(self):
        """Test getting unknown device type."""
        with pytest.raises(ClickException, match="Unknown device type"):
            get_sdr_device("unknown_device")
class TestAutoSelectTxDevice:
    """Tests for auto_select_tx_device function."""

    @staticmethod
    @contextmanager
    def _patched_discovery(uhd=(), pluto=(), hackrf=(), bladerf=()):
        """Stub driver loading and every device finder in the transmit module.

        Each keyword gives the devices the matching ``find_*_devices`` stub
        returns (default: none). Tuples are used as defaults and copied into
        fresh lists, so no mutable default is shared between calls.
        """
        target = "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit"
        with (
            patch(f"{target}.load_sdr_drivers"),
            patch(f"{target}.find_uhd_devices", return_value=list(uhd)),
            patch(f"{target}.find_pluto_devices", return_value=list(pluto)),
            patch(f"{target}.find_hackrf_devices", return_value=list(hackrf)),
            patch(f"{target}.find_bladerf_devices", return_value=list(bladerf)),
        ):
            yield

    def test_auto_select_no_devices(self):
        """Test auto-select with no TX devices found."""
        with self._patched_discovery():
            with pytest.raises(ClickException, match="No TX-capable SDR devices found"):
                auto_select_tx_device()

    def test_auto_select_single_device(self):
        """Test auto-select with single TX device."""
        with self._patched_discovery(hackrf=[{"type": "HackRF One", "serial": "123456"}]):
            device_type = auto_select_tx_device(quiet=True)
        assert device_type == "hackrf"

    def test_auto_select_multiple_devices(self):
        """Test auto-select with multiple TX devices raises error."""
        with self._patched_discovery(
            pluto=[{"type": "PlutoSDR", "uri": "ip:pluto.local"}],
            hackrf=[{"type": "HackRF One", "serial": "123456"}],
        ):
            with pytest.raises(ClickException, match="Multiple TX-capable devices found"):
                auto_select_tx_device()

    @pytest.mark.parametrize(
        ("device_name", "expected_type"),
        [
            ("PlutoSDR", "pluto"),
            ("HackRF One", "hackrf"),
            ("BladeRF", "bladerf"),
            ("b200", "usrp"),
            ("B210", "usrp"),
        ],
    )
    def test_auto_select_device_mapping(self, device_name, expected_type):
        """Test device type name mapping.

        Every name is reported through the BladeRF finder, so the expected
        results check that the selected type follows the reported ``type``
        string rather than which finder produced it.
        """
        with self._patched_discovery(bladerf=[{"type": device_name}]):
            device_type = auto_select_tx_device(quiet=True)
        assert device_type == expected_type
class TestLoadInputFile:
    """Tests for load_input_file function."""

    _TRANSMIT = "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit"

    def test_load_file_not_found(self, tmp_path):
        """Test loading non-existent file."""
        # Resolve the missing path inside tmp_path so a stray file with the
        # same name in the current working directory cannot affect the test.
        missing = tmp_path / "nonexistent.sigmf"
        with pytest.raises(ClickException, match="Input file not found"):
            load_input_file(str(missing))

    def test_load_sigmf_file(self, tmp_path):
        """Test loading SigMF file."""
        data_file = tmp_path / "recording.sigmf-data"
        data_file.touch()
        mock_recording = MagicMock()
        with patch(f"{self._TRANSMIT}.load_recording", return_value=mock_recording):
            recording = load_input_file(str(data_file), legacy=False)
        assert recording is mock_recording

    def test_load_legacy_npy_file(self, tmp_path):
        """Test loading legacy NPY file."""
        npy_file = tmp_path / "recording.npy"
        npy_file.touch()
        mock_recording = MagicMock()
        with patch(f"{self._TRANSMIT}.from_npy_legacy", return_value=mock_recording):
            recording = load_input_file(str(npy_file), legacy=True)
        assert recording is mock_recording

    def test_load_unsupported_format(self, tmp_path):
        """Test loading unsupported file format."""
        bin_file = tmp_path / "recording.bin"
        bin_file.touch()
        with patch(
            f"{self._TRANSMIT}.load_recording",
            side_effect=Exception("Unsupported format"),
        ):
            with pytest.raises(ClickException) as exc_info:
                load_input_file(str(bin_file))
        message = str(exc_info.value)
        assert "Could not load" in message
        assert "Supported formats" in message
class TestValidateTxGain:
    """Tests for validate_tx_gain function."""

    @pytest.mark.parametrize("gain", [-30, 0, -89])
    def test_valid_pluto_gain(self, gain):
        """Test valid PlutoSDR gain."""
        # Success means no exception is raised.
        validate_tx_gain("pluto", gain)

    def test_invalid_pluto_gain_too_high(self):
        """Test PlutoSDR gain too high."""
        with pytest.raises(ClickException, match="out of range"):
            validate_tx_gain("pluto", 10)

    def test_invalid_pluto_gain_too_low(self):
        """Test PlutoSDR gain too low."""
        with pytest.raises(ClickException, match="out of range"):
            validate_tx_gain("pluto", -100)

    @pytest.mark.parametrize("gain", [0, 20, 47])
    def test_valid_hackrf_gain(self, gain):
        """Test valid HackRF gain."""
        validate_tx_gain("hackrf", gain)

    @pytest.mark.parametrize("gain", [-10, 50])
    def test_invalid_hackrf_gain(self, gain):
        """Test invalid HackRF gain."""
        with pytest.raises(ClickException):
            validate_tx_gain("hackrf", gain)

    def test_high_gain_warning(self):
        """Test warning for high gain levels."""
        # Same click.echo patch target as TestCheckSampleRateMismatch uses.
        with patch(
            "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.click.echo"
        ) as mock_echo:
            validate_tx_gain("hackrf", 45)
        # Inspect every echo call, not just the last one, so output emitted
        # after the warning cannot mask it.
        assert any(
            "WARNING" in str(echo_call) and "high gain" in str(echo_call).lower()
            for echo_call in mock_echo.call_args_list
        )
class TestCheckSampleRateMismatch:
    """Exercise check_sample_rate_mismatch against stubbed recordings."""

    _ECHO_TARGET = "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.click.echo"

    def _run_check(self, metadata, sample_rate, quiet):
        """Run the check on a mock recording carrying *metadata*.

        Returns the patched ``click.echo`` mock. It keeps its recorded calls
        after the patch is undone, so callers can assert on it afterwards.
        """
        fake_recording = MagicMock()
        fake_recording.metadata = metadata
        with patch(self._ECHO_TARGET) as echo:
            check_sample_rate_mismatch(fake_recording, sample_rate, quiet=quiet)
        return echo

    def test_no_mismatch(self):
        """Matching rates produce no output."""
        echo = self._run_check({"sample_rate": 2e6}, 2e6, quiet=False)
        echo.assert_not_called()

    def test_mismatch_warning(self):
        """Differing rates produce exactly one warning mentioning the difference."""
        echo = self._run_check({"sample_rate": 1e6}, 2e6, quiet=False)
        echo.assert_called_once()
        rendered = str(echo.call_args)
        assert "Warning" in rendered
        assert "differs" in rendered

    def test_mismatch_quiet_mode(self):
        """Quiet mode suppresses the warning even when rates differ."""
        echo = self._run_check({"sample_rate": 1e6}, 2e6, quiet=True)
        echo.assert_not_called()

    def test_no_metadata(self):
        """A recording whose metadata is None produces no output."""
        echo = self._run_check(None, 2e6, quiet=False)
        echo.assert_not_called()
class TestTransmitCommand:
"""Tests for transmit CLI command."""
def setup_method(self):
    """Give each test its own scratch directory and a fresh CLI runner."""
    self.runner = CliRunner()
    scratch_dir = tempfile.mkdtemp()
    self.temp_dir = scratch_dir
def teardown_method(self):
    """Delete the scratch directory from setup_method if it still exists."""
    import shutil

    scratch_dir = self.temp_dir
    if not os.path.exists(scratch_dir):
        return
    shutil.rmtree(scratch_dir)
def test_transmit_basic(self):
    """Test basic transmit command.

    The SDR factory and the file loader are both patched, so the command runs
    against a mock device and a mock recording. The test then checks that the
    recording is transmitted once and the device is closed once.
    """
    test_file = os.path.join(self.temp_dir, "test.npy")
    # Empty placeholder: load_input_file is patched below, so the content is
    # never parsed (presumably the file only has to exist for --input).
    with open(test_file, "w"):
        pass

    mock_sdr = MagicMock()
    mock_recording = MagicMock()
    mock_recording.data = np.array([[0.1 + 0.1j] * 1000])  # shape (1, 1000), complex
    mock_recording.metadata = {}

    transmit_module = "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit"
    with (
        patch(f"{transmit_module}.get_sdr_device", return_value=mock_sdr),
        patch(f"{transmit_module}.load_input_file", return_value=mock_recording),
    ):
        result = self.runner.invoke(
            transmit,
            [
                "--device",
                "hackrf",
                "--sample-rate",
                "2e6",
                "--center-frequency",
                "915M",
                "--gain",
                "10",
                "--input",
                test_file,
                "--quiet",
            ],
        )

    # CliRunner.invoke catches exceptions by default, so include the CLI
    # output and any captured exception in the failure message.
    assert result.exit_code == 0, (
        f"exit code {result.exit_code}\noutput:\n{result.output}\n"
        f"exception: {result.exception!r}"
    )
    mock_sdr.tx_recording.assert_called_once()
    mock_sdr.close.assert_called_once()