diff --git a/pyproject.toml b/pyproject.toml index 86eb791..465a899 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,8 @@ all-sdr = [ [tool.poetry] packages = [ - { include = "ria_toolkit_oss", from = "src" } + { include = "ria_toolkit_oss", from = "src" }, + { include = "ria_toolkit_oss_cli", from = "src/ria_toolkit_oss" } ] include = [ "**/*.so", # Required for Nuitkaification diff --git a/tests/ria_toolkit_oss_cli/test.combine.py b/tests/ria_toolkit_oss_cli/test.combine.py new file mode 100644 index 0000000..b6f7d8b --- /dev/null +++ b/tests/ria_toolkit_oss_cli/test.combine.py @@ -0,0 +1,963 @@ +"""Tests for the combine command.""" + +import tempfile +from pathlib import Path + +import numpy as np +import pytest +from click.testing import CliRunner + +from ria_toolkit_oss.datatypes import Annotation, Recording +from ria_toolkit_oss.io import load_recording, to_npy, to_sigmf +from ria_toolkit_oss_cli.cli import cli + + +class TestCombineHelp: + """Test help and basic command structure.""" + + def test_help(self): + """Test combine help.""" + runner = CliRunner() + result = runner.invoke(cli, ["combine", "--help"]) + assert result.exit_code == 0 + assert "Combine multiple recordings" in result.output + assert "--mode" in result.output + assert "--align-mode" in result.output + + def test_no_inputs(self): + """Test error with no inputs.""" + runner = CliRunner() + result = runner.invoke(cli, ["combine"]) + assert result.exit_code != 0 + + def test_single_input(self): + """Test error with only one input.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + # Create test file + signal = np.arange(1000, dtype=np.complex64) + recording = Recording(data=signal, metadata={"sample_rate": 2e6}) + to_npy(recording, filename=str(Path(tmpdir) / "test.npy"), overwrite=True) + + result = runner.invoke(cli, ["combine", str(Path(tmpdir) / "test.npy"), str(Path(tmpdir) / "output.npy")]) + assert result.exit_code != 0 + + +class TestCombineConcat: + """Test concatenate mode.""" + + @pytest.fixture + def test_recordings(self): + """Create multiple test recording files.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create 3 recordings with different data + for i in range(3): + signal = np.arange(i * 1000, (i + 1) * 1000, dtype=np.complex64) + recording = Recording(data=signal, metadata={"sample_rate": 2e6}) + to_npy(recording, filename=str(Path(tmpdir) / f"chunk{i}.npy"), overwrite=True) + yield tmpdir + + def test_concat_basic(self, test_recordings): + """Test basic concatenation.""" + runner = CliRunner() + tmpdir = test_recordings + output_path = str(Path(tmpdir) / "combined.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "chunk0.npy"), + str(Path(tmpdir) / "chunk1.npy"), + str(Path(tmpdir) / "chunk2.npy"), + output_path, + ], + ) + + assert result.exit_code == 0 + assert Path(output_path).exists() + + # Verify result + combined = load_recording(output_path) + assert combined.data.shape[1] == 3000 + assert combined._metadata["combine_mode"] == "concat" + assert combined._metadata["num_inputs"] == 3 + + # Check data is correctly concatenated + assert np.allclose(combined.data[0, :1000], np.arange(0, 1000)) + assert np.allclose(combined.data[0, 1000:2000], np.arange(1000, 2000)) + assert np.allclose(combined.data[0, 2000:3000], np.arange(2000, 3000)) + + def test_concat_verbose(self, test_recordings): + """Test concatenation with verbose output.""" + runner = CliRunner() + tmpdir = test_recordings + output_path = str(Path(tmpdir) / "combined.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "chunk0.npy"), + str(Path(tmpdir) / "chunk1.npy"), + str(Path(tmpdir) / "chunk2.npy"), + output_path, + "--verbose", + ], + ) + + assert result.exit_code == 0 + assert "Combining 3 recordings" in result.output + assert "concat mode" in result.output + assert "Concatenating..." in result.output + + def test_concat_with_annotations(self): + """Test that annotations are preserved and shifted in concat mode.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create recordings with annotations + rec1 = Recording(data=np.ones(1000, dtype=np.complex64), metadata={"sample_rate": 2e6}) + rec1._annotations.append( + Annotation( + sample_start=100, sample_count=200, freq_lower_edge=900e6, freq_upper_edge=920e6, label="test1" + ) + ) + + rec2 = Recording(data=np.ones(1000, dtype=np.complex64) * 2, metadata={"sample_rate": 2e6}) + rec2._annotations.append( + Annotation( + sample_start=100, sample_count=200, freq_lower_edge=900e6, freq_upper_edge=920e6, label="test2" + ) + ) + + to_sigmf(rec1, filename="rec1", path=tmpdir, overwrite=True) + to_sigmf(rec2, filename="rec2", path=tmpdir, overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "combined.sigmf-data") + + result = runner.invoke( + cli, + ["combine", str(Path(tmpdir) / "rec1.sigmf-data"), str(Path(tmpdir) / "rec2.sigmf-data"), output_path], + ) + + assert result.exit_code == 0 + + combined = load_recording(output_path) + assert len(combined._annotations) == 2 + # First annotation unchanged + assert combined._annotations[0].sample_start == 100 + assert combined._annotations[0].label == "test1" + # Second annotation shifted by 1000 samples + assert combined._annotations[1].sample_start == 1100 + assert combined._annotations[1].label == "test2" + + +class TestCombineAddSameLength: + """Test add mode with same-length recordings.""" + + def test_add_basic(self): + """Test basic add with same-length recordings.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create two recordings with same length + sig1 = np.ones(1000, dtype=np.complex64) + sig2 = np.ones(1000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "sig1.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "sig2.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "added.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "sig1.npy"), + str(Path(tmpdir) / "sig2.npy"), + output_path, + "--mode", + "add", + ], + ) + + assert result.exit_code == 0 + + # Verify result + combined = load_recording(output_path) + assert combined.data.shape[1] == 1000 + assert np.allclose(combined.data, 3 + 0j) + assert combined._metadata["combine_mode"] == "add" + assert combined._metadata["align_mode"] == "error" + + def test_add_three_recordings(self): + """Test adding three same-length recordings.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create three recordings + for i in range(1, 4): + sig = np.ones(1000, dtype=np.complex64) * i + rec = Recording(data=sig, metadata={"sample_rate": 2e6}) + to_npy(rec, filename=str(Path(tmpdir) / f"sig{i}.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "added.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "sig1.npy"), + str(Path(tmpdir) / "sig2.npy"), + str(Path(tmpdir) / "sig3.npy"), + output_path, + "--mode", + "add", + ], + ) + + assert result.exit_code == 0 + + combined = load_recording(output_path) + # 1 + 2 + 3 = 6 + assert np.allclose(combined.data, 6 + 0j) + + +class TestCombineAddAlignError: + """Test add mode with error alignment (default).""" + + def test_different_length_error(self): + """Test that different lengths cause error by default.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create recordings with different lengths + sig1 = np.ones(10000, dtype=np.complex64) + sig2 = np.ones(5000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "long.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "short.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "output.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "long.npy"), + str(Path(tmpdir) / "short.npy"), + output_path, + "--mode", + "add", + ], + ) + + assert result.exit_code != 0 + assert "different lengths" in result.output + assert "--align-mode" in result.output + + +class TestCombineAddAlignTruncate: + """Test add mode with truncate alignment.""" + + def test_truncate(self): + """Test truncate to shortest recording.""" + with tempfile.TemporaryDirectory() as tmpdir: + sig1 = np.ones(10000, dtype=np.complex64) + sig2 = np.ones(5000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "long.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "short.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "truncated.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "long.npy"), + str(Path(tmpdir) / "short.npy"), + output_path, + "--mode", + "add", + "--align-mode", + "truncate", + ], + ) + + assert result.exit_code == 0 + + combined = load_recording(output_path) + assert combined.data.shape[1] == 5000 + assert np.allclose(combined.data, 3 + 0j) + + +class TestCombineAddAlignPad: + """Test add mode with pad alignment.""" + + def test_pad(self): + """Test zero-padding to longest recording.""" + with tempfile.TemporaryDirectory() as tmpdir: + sig1 = np.ones(10000, dtype=np.complex64) + sig2 = np.ones(5000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "long.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "short.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "padded.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "long.npy"), + str(Path(tmpdir) / "short.npy"), + output_path, + "--mode", + "add", + "--align-mode", + "pad", + ], + ) + + assert result.exit_code == 0 + + combined = load_recording(output_path) + assert combined.data.shape[1] == 10000 + # First 5000: 1 + 2 = 3 + assert np.allclose(combined.data[0, :5000], 3 + 0j) + # Last 5000: 1 + 0 = 1 + assert np.allclose(combined.data[0, 5000:], 1 + 0j) + + +class TestCombineAddAlignPadStart: + """Test add mode with pad-start alignment.""" + + def test_pad_start(self): + """Test pad-start at specific sample.""" + with tempfile.TemporaryDirectory() as tmpdir: + sig1 = np.ones(10000, dtype=np.complex64) + sig2 = np.ones(5000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "long.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "short.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "pad_start.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "long.npy"), + str(Path(tmpdir) / "short.npy"), + output_path, + "--mode", + "add", + "--align-mode", + "pad-start", + "--pad-start-sample", + "3000", + ], + ) + + assert result.exit_code == 0 + + combined = load_recording(output_path) + assert combined.data.shape[1] == 10000 + # Before 3000: 1 + 0 = 1 + assert np.allclose(combined.data[0, :3000], 1 + 0j) + # 3000-8000: 1 + 2 = 3 + assert np.allclose(combined.data[0, 3000:8000], 3 + 0j) + # After 8000: 1 + 0 = 1 + assert np.allclose(combined.data[0, 8000:], 1 + 0j) + + def test_pad_start_invalid(self): + """Test invalid pad-start-sample.""" + with tempfile.TemporaryDirectory() as tmpdir: + sig1 = np.ones(10000, dtype=np.complex64) + sig2 = np.ones(5000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "long.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "short.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "output.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "long.npy"), + str(Path(tmpdir) / "short.npy"), + output_path, + "--mode", + "add", + "--align-mode", + "pad-start", + "--pad-start-sample", + "7000", # Too large + ], + ) + + assert result.exit_code != 0 + assert "exceeds max length" in result.output + + +class TestCombineAddAlignPadCenter: + """Test add mode with pad-center alignment.""" + + def test_pad_center(self): + """Test centering shorter recording.""" + with tempfile.TemporaryDirectory() as tmpdir: + sig1 = np.ones(10000, dtype=np.complex64) + sig2 = np.ones(5000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "long.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "short.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "pad_center.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "long.npy"), + str(Path(tmpdir) / "short.npy"), + output_path, + "--mode", + "add", + "--align-mode", + "pad-center", + ], + ) + + assert result.exit_code == 0 + + combined = load_recording(output_path) + assert combined.data.shape[1] == 10000 + # Before 2500: 1 + 0 = 1 + assert np.allclose(combined.data[0, :2500], 1 + 0j) + # 2500-7500: 1 + 2 = 3 + assert np.allclose(combined.data[0, 2500:7500], 3 + 0j) + # After 7500: 1 + 0 = 1 + assert np.allclose(combined.data[0, 7500:], 1 + 0j) + + +class TestCombineAddAlignPadEnd: + """Test add mode with pad-end alignment.""" + + def test_pad_end(self): + """Test aligning end of recordings.""" + with tempfile.TemporaryDirectory() as tmpdir: + sig1 = np.ones(10000, dtype=np.complex64) + sig2 = np.ones(5000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "long.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "short.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "pad_end.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "long.npy"), + str(Path(tmpdir) / "short.npy"), + output_path, + "--mode", + "add", + "--align-mode", + "pad-end", + ], + ) + + assert result.exit_code == 0 + + combined = load_recording(output_path) + assert combined.data.shape[1] == 10000 + # First 5000: 1 + 0 = 1 + assert np.allclose(combined.data[0, :5000], 1 + 0j) + # Last 5000: 1 + 2 = 3 + assert np.allclose(combined.data[0, 5000:], 3 + 0j) + + +class TestCombineAddAlignRepeat: + """Test add mode with repeat alignment.""" + + def test_repeat(self): + """Test repeating shorter recording.""" + with tempfile.TemporaryDirectory() as tmpdir: + sig1 = np.ones(10000, dtype=np.complex64) + sig2 = np.ones(5000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "long.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "short.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "repeated.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "long.npy"), + str(Path(tmpdir) / "short.npy"), + output_path, + "--mode", + "add", + "--align-mode", + "repeat", + ], + ) + + assert result.exit_code == 0 + + combined = load_recording(output_path) + assert combined.data.shape[1] == 10000 + # Entire recording: 1 + 2 = 3 (pattern repeated) + assert np.allclose(combined.data, 3 + 0j) + + def test_repeat_partial(self): + """Test repeat with non-exact multiple.""" + with tempfile.TemporaryDirectory() as tmpdir: + sig1 = np.ones(10000, dtype=np.complex64) + sig2 = np.arange(3000, dtype=np.complex64) + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "long.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "short.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "repeated.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "long.npy"), + str(Path(tmpdir) / "short.npy"), + output_path, + "--mode", + "add", + "--align-mode", + "repeat", + ], + ) + + assert result.exit_code == 0 + + combined = load_recording(output_path) + # Check pattern repeats correctly + # First 3000: 1 + [0,1,2,...,2999] + assert np.allclose(combined.data[0, :3000], 1 + np.arange(3000)) + # Next 3000: 1 + [0,1,2,...,2999] + assert np.allclose(combined.data[0, 3000:6000], 1 + np.arange(3000)) + # Next 3000: 1 + [0,1,2,...,2999] + assert np.allclose(combined.data[0, 6000:9000], 1 + np.arange(3000)) + # Last 1000: 1 + [0,1,2,...,999] + assert np.allclose(combined.data[0, 9000:10000], 1 + np.arange(1000)) + + +class TestCombineAddAlignRepeatSpaced: + """Test add mode with repeat-spaced alignment.""" + + def test_repeat_spaced(self): + """Test repeating with spacing.""" + with tempfile.TemporaryDirectory() as tmpdir: + sig1 = np.ones(10000, dtype=np.complex64) + sig2 = np.ones(2000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "long.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "short.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "repeat_spaced.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "long.npy"), + str(Path(tmpdir) / "short.npy"), + output_path, + "--mode", + "add", + "--align-mode", + "repeat-spaced", + "--repeat-spacing", + "1000", + ], + ) + + assert result.exit_code == 0 + + combined = load_recording(output_path) + assert combined.data.shape[1] == 10000 + + # First 2000: 1 + 2 = 3 + assert np.allclose(combined.data[0, :2000], 3 + 0j) + # Next 1000 (gap): 1 + 0 = 1 + assert np.allclose(combined.data[0, 2000:3000], 1 + 0j) + # Next 2000: 1 + 2 = 3 + assert np.allclose(combined.data[0, 3000:5000], 3 + 0j) + # Next 1000 (gap): 1 + 0 = 1 + assert np.allclose(combined.data[0, 5000:6000], 1 + 0j) + + def test_repeat_spaced_missing_spacing(self): + """Test error when spacing not provided.""" + with tempfile.TemporaryDirectory() as tmpdir: + sig1 = np.ones(10000, dtype=np.complex64) + sig2 = np.ones(5000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "long.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "short.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "output.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "long.npy"), + str(Path(tmpdir) / "short.npy"), + output_path, + "--mode", + "add", + "--align-mode", + "repeat-spaced", + # Missing --repeat-spacing + ], + ) + + assert result.exit_code != 0 + assert "requires --repeat-spacing" in result.output + + +class TestCombineValidation: + """Test validation and error handling.""" + + def test_sample_rate_mismatch(self): + """Test error on sample rate mismatch in add mode.""" + with tempfile.TemporaryDirectory() as tmpdir: + sig1 = np.ones(1000, dtype=np.complex64) + sig2 = np.ones(1000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 1e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "sig1.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "sig2.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "output.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "sig1.npy"), + str(Path(tmpdir) / "sig2.npy"), + output_path, + "--mode", + "add", + ], + ) + + assert result.exit_code != 0 + assert "different sample rates" in result.output + + def test_channel_count_mismatch(self): + """Test error on channel count mismatch.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Single channel + sig1 = np.ones((1, 1000), dtype=np.complex64) + # Two channels + sig2 = np.ones((2, 1000), dtype=np.complex64) + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "sig1.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "sig2.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "output.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "sig1.npy"), + str(Path(tmpdir) / "sig2.npy"), + output_path, + "--mode", + "add", + ], + ) + + assert result.exit_code != 0 + assert "different channel counts" in result.output + + def test_overwrite_protection(self): + """Test overwrite protection.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test recordings + sig1 = np.ones(1000, dtype=np.complex64) + sig2 = np.ones(1000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "sig1.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "sig2.npy"), overwrite=True) + + # Create existing output file + existing = Recording(data=np.zeros(100, dtype=np.complex64), metadata={}) + output_path = str(Path(tmpdir) / "output.npy") + to_npy(existing, filename=output_path, overwrite=True) + + runner = CliRunner() + + # Should fail without --overwrite + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "sig1.npy"), + str(Path(tmpdir) / "sig2.npy"), + output_path, + "--mode", + "add", + ], + ) + + assert result.exit_code != 0 + assert "already exists" in result.output + + # Should succeed with --overwrite + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "sig1.npy"), + str(Path(tmpdir) / "sig2.npy"), + output_path, + "--mode", + "add", + "--overwrite", + ], + ) + + assert result.exit_code == 0 + + +class TestCombineOutputOptions: + """Test output format and options.""" + + def test_output_formats(self): + """Test different output formats.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test recordings + sig1 = np.ones(1000, dtype=np.complex64) + sig2 = np.ones(1000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "sig1.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "sig2.npy"), overwrite=True) + + runner = CliRunner() + + # Test SigMF output + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "sig1.npy"), + str(Path(tmpdir) / "sig2.npy"), + str(Path(tmpdir) / "output.sigmf-data"), + "--mode", + "add", + ], + ) + assert result.exit_code == 0 + assert Path(tmpdir, "output.sigmf-data").exists() + assert Path(tmpdir, "output.sigmf-meta").exists() + + def test_normalize(self): + """Test normalize option.""" + with tempfile.TemporaryDirectory() as tmpdir: + sig1 = np.ones(1000, dtype=np.complex64) * 10 + sig2 = np.ones(1000, dtype=np.complex64) * 20 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "sig1.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "sig2.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "normalized.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "sig1.npy"), + str(Path(tmpdir) / "sig2.npy"), + output_path, + "--mode", + "add", + "--normalize", + ], + ) + + assert result.exit_code == 0 + + combined = load_recording(output_path) + # Should be normalized to max magnitude 1 + assert np.allclose(np.max(np.abs(combined.data)), 1.0) + assert combined._metadata.get("normalized") is True + + def test_custom_metadata(self): + """Test adding custom metadata.""" + with tempfile.TemporaryDirectory() as tmpdir: + sig1 = np.ones(1000, dtype=np.complex64) + sig2 = np.ones(1000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "sig1.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "sig2.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "output.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "sig1.npy"), + str(Path(tmpdir) / "sig2.npy"), + output_path, + "--mode", + "add", + "--metadata", + "test_id=test123", + "--metadata", + "author=tester", + ], + ) + + assert result.exit_code == 0 + + combined = load_recording(output_path) + assert combined._metadata["test_id"] == "test123" + assert combined._metadata["author"] == "tester" + + +class TestCombineVerboseQuiet: + """Test verbose and quiet modes.""" + + def test_verbose(self): + """Test verbose output.""" + with tempfile.TemporaryDirectory() as tmpdir: + sig1 = np.ones(1000, dtype=np.complex64) + sig2 = np.ones(1000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "sig1.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "sig2.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "output.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "sig1.npy"), + str(Path(tmpdir) / "sig2.npy"), + output_path, + "--mode", + "add", + "--verbose", + ], + ) + + assert result.exit_code == 0 + assert "Loading" in result.output + assert "Done" in result.output + + def test_quiet(self): + """Test quiet output.""" + with tempfile.TemporaryDirectory() as tmpdir: + sig1 = np.ones(1000, dtype=np.complex64) + sig2 = np.ones(1000, dtype=np.complex64) * 2 + + rec1 = Recording(data=sig1, metadata={"sample_rate": 2e6}) + rec2 = Recording(data=sig2, metadata={"sample_rate": 2e6}) + + to_npy(rec1, filename=str(Path(tmpdir) / "sig1.npy"), overwrite=True) + to_npy(rec2, filename=str(Path(tmpdir) / "sig2.npy"), overwrite=True) + + runner = CliRunner() + output_path = str(Path(tmpdir) / "output.npy") + + result = runner.invoke( + cli, + [ + "combine", + str(Path(tmpdir) / "sig1.npy"), + str(Path(tmpdir) / "sig2.npy"), + output_path, + "--mode", + "add", + "--quiet", + ], + ) + + assert result.exit_code == 0 + assert result.output == "" diff --git a/tests/ria_toolkit_oss_cli/test_capture.py b/tests/ria_toolkit_oss_cli/test_capture.py new file mode 100644 index 0000000..81749d6 --- /dev/null +++ b/tests/ria_toolkit_oss_cli/test_capture.py @@ -0,0 +1,165 @@ +# flake8: noqa +"""Tests for capture command.""" + +import os +import tempfile +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import yaml +from click.testing import CliRunner + +from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture import ( + auto_select_device, + capture, + get_sdr_device, + save_visualization, +) + + +class TestGetSdrDevice: + """Tests for get_sdr_device function.""" + + def test_get_pluto_device(self): + """Test getting PlutoSDR device.""" + mock_sdr_class = MagicMock() + mock_sdr_instance = MagicMock() + mock_sdr_class.return_value = mock_sdr_instance + + with patch.dict("sys.modules", {"src.ria_toolkit_oss.sdr.pluto": MagicMock(Pluto=mock_sdr_class)}): + device = get_sdr_device("pluto") + assert device is mock_sdr_instance + + def test_get_hackrf_device(self): + """Test getting HackRF device.""" + mock_sdr_class = MagicMock() + mock_sdr_instance = MagicMock() + mock_sdr_class.return_value = mock_sdr_instance + + with patch.dict("sys.modules", {"src.ria_toolkit_oss.sdr.hackrf": MagicMock(HackRF=mock_sdr_class)}): + device = get_sdr_device("hackrf") + assert device is mock_sdr_instance + + def test_get_unknown_device(self): + """Test getting unknown device type.""" + from click.exceptions import ClickException + + with pytest.raises(ClickException) as exc_info: + get_sdr_device("unknown_device") + + assert "Unknown device type" in str(exc_info.value) + + +class TestAutoSelectDevice: + """Tests for auto_select_device function.""" + + def test_auto_select_no_devices(self): + """Test auto-select with no devices found.""" + from click.exceptions import ClickException + + with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.list_all_devices") as mock_discover: + mock_discover.return_value = [] + + with pytest.raises(ClickException) as exc_info: + auto_select_device() + + assert "No SDR devices found" in str(exc_info.value) + + def test_auto_select_single_device(self): + """Test auto-select with single device.""" + with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.list_all_devices") as mock_discover: + mock_discover.return_value = [{"type": "HackRF", "serial": "123456"}] + + device_type = auto_select_device(quiet=True) + assert device_type == "hackrf" + + def test_auto_select_single_device_with_warning(self): + """Test auto-select shows warning when not quiet.""" + with ( + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.list_all_devices") as mock_discover, + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.click.echo") as mock_echo, + ): + + mock_discover.return_value = [{"type": "PlutoSDR", "uri": "ip:pluto.local"}] + + device_type = auto_select_device(quiet=False) + + assert device_type == "pluto" + # Should have called echo twice (warning + hint) + assert mock_echo.call_count == 2 + + def test_auto_select_multiple_devices(self): + """Test auto-select with multiple devices raises error.""" + from click.exceptions import ClickException + + with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.list_all_devices") as mock_discover: + mock_discover.return_value = [ + {"type": "HackRF", "serial": "123456"}, + {"type": "PlutoSDR", "uri": "ip:pluto.local"}, + ] + + with pytest.raises(ClickException) as exc_info: + auto_select_device() + + assert "Multiple devices found" in str(exc_info.value) + + def test_auto_select_device_name_mapping(self): + """Test device name mapping.""" + with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.list_all_devices") as mock_discover: + # Test various device name formats + test_cases = [ + ("PlutoSDR", "pluto"), + ("HackRF", "hackrf"), + ("BladeRF", "bladerf"), + ("RTL-SDR", "rtlsdr"), + ] + + for device_name, expected_type in test_cases: + mock_discover.return_value = [{"type": device_name}] + device_type = auto_select_device(quiet=True) + assert device_type == expected_type + + +class TestSaveVisualization: + """Tests for save_visualization function.""" + + def test_save_visualization_success(self): + """Test successful visualization save.""" + mock_recording = MagicMock() + + with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig") as mock_view: + save_visualization(mock_recording, "test.png", quiet=True) + + mock_view.assert_called_once_with( + mock_recording, output_path="test.png", saveplot=True, fast_mode=False, labels_mode=True + ) + + def test_save_visualization_import_error(self): + """Test visualization save with import error.""" + mock_recording = MagicMock() + + with ( + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig", side_effect=ImportError("Module not found")), + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.click.echo") as mock_echo, + ): + + save_visualization(mock_recording, "test.png", quiet=True) + + # Should catch error and echo warning + mock_echo.assert_called_once() + assert "Warning" in str(mock_echo.call_args) + + def test_save_visualization_general_error(self): + """Test visualization save with general error.""" + mock_recording = MagicMock() + + with ( + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig", side_effect=Exception("Failed to plot")), + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.click.echo") as mock_echo, + ): + + save_visualization(mock_recording, "test.png", quiet=True) + + mock_echo.assert_called_once() + assert "Failed to save visualization" in str(mock_echo.call_args) diff --git a/tests/ria_toolkit_oss_cli/test_common.py b/tests/ria_toolkit_oss_cli/test_common.py new file mode 100644 index 0000000..cc58e88 --- /dev/null +++ b/tests/ria_toolkit_oss_cli/test_common.py @@ -0,0 +1,118 @@ +"""Tests for common CLI utilities.""" + +import os +import tempfile + +import pytest +import yaml + +from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import ( + format_frequency, + format_sample_rate, + load_yaml_config, + parse_frequency, + parse_metadata_args, +) + + +def test_load_yaml_config(): + """Test loading YAML configuration files.""" + config_data = { + "device": "pluto", + "sample_rate": 2e6, + "center_frequency": "915e6", + "gain": 30, + "metadata": {"location": "test_lab", "experiment": "test_001"}, + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config_data, f) + config_file = f.name + + try: + loaded_config = load_yaml_config(config_file) + assert loaded_config == config_data + assert loaded_config["device"] == "pluto" + assert loaded_config["sample_rate"] == 2e6 + assert loaded_config["metadata"]["location"] == "test_lab" + finally: + os.unlink(config_file) + + +def test_load_yaml_config_empty(): + """Test loading empty YAML file.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write("") + config_file = f.name + + try: + loaded_config = load_yaml_config(config_file) + assert loaded_config == {} + finally: + os.unlink(config_file) + + +def test_parse_metadata_args(): + """Test parsing metadata KEY=VALUE arguments.""" + metadata_args = ["location=test_lab", "experiment=001", "power=30", "frequency=2.4e9", "description=Test Signal"] + + result = parse_metadata_args(metadata_args) + + assert result["location"] == "test_lab" + assert result["experiment"] == "001" # String because doesn't parse as number + assert result["power"] == 30 # Integer + assert result["frequency"] == 2.4e9 # Float + assert result["description"] == "Test Signal" + + +def test_parse_metadata_args_invalid(): + """Test invalid metadata format raises error.""" + from click.exceptions import ClickException + + with pytest.raises(ClickException): + parse_metadata_args(["invalid_format"]) + + with pytest.raises(ClickException): + parse_metadata_args(["key1=value1", "invalid", "key2=value2"]) + + +def test_parse_frequency(): + """Test frequency parsing with different formats.""" + # Scientific notation + assert parse_frequency("915e6") == 915e6 + assert parse_frequency("2.4e9") == 2.4e9 + assert parse_frequency("433e6") == 433e6 + + # With suffixes + assert parse_frequency("915M") == 915e6 + assert parse_frequency("2.4G") == 2.4e9 + assert parse_frequency("433M") == 433e6 + assert parse_frequency("100k") == 100e3 + assert parse_frequency("100K") == 100e3 + + # Plain numbers + assert parse_frequency("915000000") == 915e6 + assert parse_frequency("2400000000") == 2.4e9 + + # Edge cases + assert parse_frequency("0.915G") == 915e6 + assert parse_frequency("915.0M") == 915e6 + + +def test_format_frequency(): + """Test frequency formatting.""" + assert format_frequency(915e6) == "915.00 MHz" + assert format_frequency(2.4e9) == "2.40 GHz" + assert format_frequency(433e6) == "433.00 MHz" + assert format_frequency(100e3) == "100.00 kHz" + assert format_frequency(1e3) == "1.00 kHz" + assert format_frequency(500) == "500.00 Hz" + + +def test_format_sample_rate(): + """Test sample rate formatting.""" + assert format_sample_rate(20e6) == "20.00 MS/s" + assert format_sample_rate(2e6) == "2.00 MS/s" + assert format_sample_rate(100e3) == "100.00 kS/s" + assert format_sample_rate(1e3) == "1.00 kS/s" + assert format_sample_rate(500) == "500.00 S/s" diff --git a/tests/ria_toolkit_oss_cli/test_convert.py b/tests/ria_toolkit_oss_cli/test_convert.py new file mode 100644 index 0000000..f5528cc --- /dev/null +++ b/tests/ria_toolkit_oss_cli/test_convert.py @@ -0,0 +1,190 @@ +"""Tests for convert command.""" + +import os +import tempfile +from pathlib import Path + +import pytest +from click.testing import CliRunner + +from ria_toolkit_oss.ria_toolkit_oss_cli.cli import cli + + +class TestConvert: + """Test convert command functionality.""" + + def test_convert_help(self): + """Test convert command help.""" + runner = CliRunner() + result = runner.invoke(cli, ["convert", "--help"]) + assert result.exit_code == 0 + assert "Convert recordings between file formats" in result.output + assert "--format" in result.output + assert "--legacy" in result.output + assert "--wav-sample-rate" in result.output + assert "--blue-format" in result.output + assert "--overwrite" in result.output + assert "--metadata" in result.output + + def test_missing_arguments(self): + """Test that missing arguments show error.""" + runner = CliRunner() + result = runner.invoke(cli, ["convert"]) + assert result.exit_code != 0 + assert "Missing argument" in result.output or "Error" in result.output + + def test_invalid_input_format(self): + """Test handling of invalid input format.""" + runner = CliRunner() + with tempfile.NamedTemporaryFile(suffix=".xyz", delete=False) as f: + try: + result = runner.invoke(cli, ["convert", f.name, "output.npy"]) + assert result.exit_code != 0 + assert "Unknown format" in result.output or "Supported" in result.output + finally: + os.unlink(f.name) + + def test_overwrite_protection(self): + """Test that overwrite protection works.""" + runner = CliRunner() + + # Create a dummy input file (will use actual test data if available) + test_input = "/home/qrf/workarea/ash/signal-testbed/recordings/iq2440MHz234233.npy" + if not os.path.exists(test_input): + pytest.skip("Test recording file not found") + + with tempfile.TemporaryDirectory() as tmpdir: + output_file = os.path.join(tmpdir, "test.sigmf") + + # First conversion should succeed + result = runner.invoke(cli, ["convert", test_input, output_file, "--legacy", "-q"]) + assert result.exit_code == 0 + + # Second conversion without --overwrite should fail + result = runner.invoke(cli, ["convert", test_input, output_file, "--legacy"]) + assert result.exit_code != 0 + assert "exist" in result.output.lower() + assert "--overwrite" in result.output + + # Third conversion with --overwrite should succeed + result = runner.invoke(cli, ["convert", test_input, output_file, "--legacy", "--overwrite", "-q"]) + assert result.exit_code == 0 + + def test_metadata_override(self): + """Test metadata override functionality.""" + runner = CliRunner() + + test_input = "/home/qrf/workarea/ash/signal-testbed/recordings/iq2440MHz234233.npy" + if not os.path.exists(test_input): + pytest.skip("Test recording file not found") + + with tempfile.TemporaryDirectory() as tmpdir: + output_file = os.path.join(tmpdir, "test.sigmf") + + result = runner.invoke( + cli, + [ + "convert", + test_input, + output_file, + "--legacy", + "--metadata", + "test_key=test_value", + "--metadata", + "number=42", + "--metadata", + "float_val=3.14", + "-v", + ], + ) + + assert result.exit_code == 0 + assert "test_key" in result.output + assert "number" in result.output + assert "float_val" in result.output + + def test_format_detection(self): + """Test that format detection works for different extensions.""" + runner = CliRunner() + + test_input = "/home/qrf/workarea/ash/signal-testbed/recordings/iq2440MHz234233.npy" + if not os.path.exists(test_input): + pytest.skip("Test recording file not found") + + with tempfile.TemporaryDirectory() as tmpdir: + # Test NPY to SigMF + sigmf_out = os.path.join(tmpdir, "test.sigmf") + result = runner.invoke(cli, ["convert", test_input, sigmf_out, "--legacy", "-q"]) + assert result.exit_code == 0 + assert Path(sigmf_out).with_suffix(".sigmf-data").exists() + assert Path(sigmf_out).with_suffix(".sigmf-meta").exists() + + # Test NPY to NPY + npy_out = os.path.join(tmpdir, "test.npy") + result = runner.invoke(cli, ["convert", test_input, npy_out, "--legacy", "-q"]) + assert result.exit_code == 0 + assert Path(npy_out).exists() + + def test_wav_conversion_with_decimation(self): + """Test WAV conversion with sample rate decimation.""" + runner = CliRunner() + + test_input = "/home/qrf/workarea/ash/signal-testbed/recordings/iq2440MHz234233.npy" + if not os.path.exists(test_input): + pytest.skip("Test recording file not found") + + with tempfile.TemporaryDirectory() as tmpdir: + wav_out = os.path.join(tmpdir, "test.wav") + result = runner.invoke( + cli, ["convert", test_input, wav_out, "--legacy", "--wav-sample-rate", "48000", "--wav-bits", "16"] + ) + + assert result.exit_code == 0 + assert "Decimation factor" in result.output + assert Path(wav_out).exists() + # Check file is non-empty + assert os.path.getsize(wav_out) > 0 + + def test_blue_format_conversion(self): + """Test MIDAS Blue format conversion.""" + runner = CliRunner() + + test_input = "/home/qrf/workarea/ash/signal-testbed/recordings/iq2440MHz234233.npy" + if not os.path.exists(test_input): + pytest.skip("Test recording file not found") + + with tempfile.TemporaryDirectory() as tmpdir: + # Test each Blue format + for blue_fmt in ["CI", "CF", "CD"]: + blue_out = os.path.join(tmpdir, f"test_{blue_fmt}.blue") + result = runner.invoke( + cli, ["convert", test_input, blue_out, "--legacy", "--blue-format", blue_fmt, "-q"] + ) + + assert result.exit_code == 0 + assert Path(blue_out).exists() + # Check file is non-empty + assert os.path.getsize(blue_out) > 0 + + def test_quiet_and_verbose_modes(self): + """Test quiet and verbose output modes.""" + runner = CliRunner() + + test_input = "/home/qrf/workarea/ash/signal-testbed/recordings/iq2440MHz234233.npy" + if not os.path.exists(test_input): + pytest.skip("Test recording file not found") + + with tempfile.TemporaryDirectory() as tmpdir: + # Test verbose mode + output_file = os.path.join(tmpdir, "test_verbose.sigmf") + result = runner.invoke(cli, ["convert", test_input, output_file, "--legacy", "-v"]) + assert result.exit_code == 0 + assert "Reading input" in result.output + assert "Metadata preserved" in result.output + + # Test quiet mode + output_file = os.path.join(tmpdir, "test_quiet.npy") + result = runner.invoke(cli, ["convert", test_input, output_file, "--legacy", "-q"]) + assert result.exit_code == 0 + # Should have minimal output + assert len(result.output) < 100 or result.output.strip() == "" diff --git a/tests/ria_toolkit_oss_cli/test_discover b/tests/ria_toolkit_oss_cli/test_discover new file mode 100644 index 0000000..fb9f869 --- /dev/null +++ b/tests/ria_toolkit_oss_cli/test_discover @@ -0,0 +1,287 @@ +"""Tests for discover command.""" + +import json +import re +from unittest.mock import MagicMock, patch + +from click.testing import CliRunner + +from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover import ( # find_bladerf_devices,; find_thinkrf_devices,; find_uhd_devices, + discover, + discover_all_devices, + find_hackrf_devices, + find_pluto_devices, + find_rtlsdr_devices, + load_sdr_drivers, +) + + +def test_discover_pluto_no_devices(): + """Test PlutoSDR discovery with no devices.""" + with ( + patch.dict("sys.modules", {"iio": MagicMock()}) as mock_modules, + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.get_usb_devices") as mock_usb, + ): + mock_iio = mock_modules["iio"] + mock_iio.scan_contexts.return_value = {} + mock_usb.return_value = [] + + devices = find_pluto_devices() + assert devices == [] + + +def test_discover_pluto_with_device(): + """Test PlutoSDR discovery with device present.""" + with patch.dict("sys.modules", {"iio": MagicMock()}) as mock_modules: + mock_iio = mock_modules["iio"] + mock_ctx = MagicMock() + mock_ctx.attrs = {"hw_serial": "123456", "fw_version": "1.0"} + mock_ctx._destroy = MagicMock() + + mock_iio.scan_contexts.return_value = {"ip:pluto.local": "PlutoSDR (ADALM-PLUTO)"} + mock_iio.Context.return_value = mock_ctx + + devices = find_pluto_devices() + + assert len(devices) == 1 + assert devices[0]["type"] == "PlutoSDR" + assert devices[0]["serial"] == "123456" + assert devices[0]["firmware"] == "1.0" + assert devices[0]["uri"] == "ip:pluto.local" + + +def test_discover_hackrf_no_devices(): + """Test HackRF discovery with no devices.""" + with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.subprocess") as mock_subprocess: + mock_subprocess.check_output.return_value = "" + devices = find_hackrf_devices() + assert devices == [] + + +def test_discover_hackrf_with_devices(): + """Test HackRF discovery with devices present.""" + with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.subprocess") as mock_subprocess: + mock_subprocess.check_output.return_value = """ + hackrf_info version: 2023.01.1 + libhackrf version: 2023.01.1 (0.8) + Found HackRF + Index: 0 + Serial number: serial123 + Board ID Number: 2 (HackRF One) + Firmware Version: v2.1.0 (API:1.08) + Part ID Number: 0xa000cb3c 0x005d4761 + Index: 1 + Serial number: serial456 + Board ID Number: 2 (HackRF One) + Firmware Version: v2.1.0 (API:1.08) + Part ID Number: 0xa000cb3c 0x005d4761 + """ + + devices = find_hackrf_devices() + + assert len(devices) == 2 + assert devices[0]["type"] == "HackRF One" + assert devices[0]["serial"] == "serial123" + assert devices[0]["device_index"] == 0 or devices[0]["device_index"] == "0" + assert devices[1]["serial"] == "serial456" + assert devices[1]["device_index"] == 1 or devices[1]["device_index"] == "1" + + +def test_discover_rtlsdr_no_devices(): + """Test RTL-SDR discovery with no devices.""" + with patch("ria_toolkit_oss.ria_toolkit_oss.ria_toolkit_oss.discover.subprocess") as mock_subprocess: + mock_subprocess.check_output.return_value = "" + devices = find_rtlsdr_devices() + assert devices == [] + + +def test_discover_rtlsdr_with_devices(): + """Test RTL-SDR discovery with devices present.""" + with patch("ria_toolkit_oss.ria_toolkit_oss.ria_toolkit_oss.discover.subprocess") as mock_subprocess: + mock_subprocess.check_output.return_value = """ + Found 2 device(s): + 0: RTLSDRBlog, Blog V4, SN: 00000001 + 1: RTLSDRBlog, Blog V4, SN: 00000002 + + Using device 0: Generic RTL2832U OEM + Found Rafael Micro R828D tuner + RTL-SDR Blog V4 Detected + """ + + devices = find_rtlsdr_devices() + + assert len(devices) == 2 + assert devices[0]["type"] == "RTL-SDR" + assert devices[0]["serial"] == "00000001" + assert devices[0]["device_index"] == 0 or devices[0]["device_index"] == "0" + + +def test_discover_all_devices_filter(): + """Test discovering devices with type filter.""" + with ( + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_pluto_devices") as mock_pluto, + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_hackrf_devices") as mock_hackrf, + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_bladerf_devices") as mock_bladerf, + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_uhd_devices") as mock_usrp, + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_rtlsdr_devices") as mock_rtlsdr, + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_thinkrf_devices") as mock_thinkrf, + ): + + mock_pluto.return_value = [{"type": "PlutoSDR", "uri": "ip:pluto.local"}] + mock_hackrf.return_value = [] + mock_bladerf.return_value = [] + mock_usrp.return_value = [] + mock_rtlsdr.return_value = [] + mock_thinkrf.return_value = [] + + # Test filtering by pluto + load_sdr_drivers(verbose=False) + devices = discover_all_devices() + mock_pluto.assert_called_once() + mock_hackrf.assert_called_once() + mock_bladerf.assert_called_once() + assert len(devices["devices"]) == 1 + assert len(devices["pluto_devices"]) == 1 + + +def test_discover_all_devices_no_filter(): + """Test discovering all device types.""" + with ( + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_pluto_devices") as mock_pluto, + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_hackrf_devices") as mock_hackrf, + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_bladerf_devices") as mock_bladerf, + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_uhd_devices") as mock_usrp, + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_rtlsdr_devices") as mock_rtlsdr, + patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_thinkrf_devices") as mock_thinkrf, + ): + + mock_pluto.return_value = [{"type": "PlutoSDR", "uri": "ip:pluto.local"}] + mock_hackrf.return_value = [{"type": "HackRF"}] + mock_bladerf.return_value = [] + mock_usrp.return_value = [] + mock_rtlsdr.return_value = [] + mock_thinkrf.return_value = [] + + load_sdr_drivers(verbose=False) + devices = discover_all_devices() + mock_pluto.assert_called_once() + mock_hackrf.assert_called_once() + mock_bladerf.assert_called_once() + mock_usrp.assert_called_once() + mock_rtlsdr.assert_called_once() + assert len(devices["devices"]) == 2 + assert len(devices["pluto_devices"]) == 1 + assert len(devices["hackrf_devices"]) == 1 + + +def test_discover_command_no_devices(): + """Test discover CLI command with no devices.""" + runner = CliRunner() + + with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.discover_all_devices") as mock_discover: + mock_discover.return_value = { + "loaded_drivers": [], + "failed_drivers": [], + "devices": [], + "total_devices": 0, + "uhd_devices": [], + "pluto_devices": [], + "rtlsdr_devices": [], + "bladerf_devices": [], + "hackrf_devices": [], + } + + result = runner.invoke(discover) + + assert result.exit_code == 0 + assert "No devices detected" in result.output + + +def test_discover_command(): + """Test discover CLI command.""" + runner = CliRunner() + result = runner.invoke(discover) + + radios = ["USRP/UHD", "PlutoSDR", "RTL-SDR", "BladeRF", "HackRF", "ThinkRF"] + match = re.search(r"Detected devices: (\d+)", result.output) + if match: + total_devices = int(match.group(1)) + else: + total_devices = 0 + + if result.exit_code == 0: + assert "Attached Devices" in result.output + assert "Discovery Summary" in result.output + if total_devices > 0: + assert any(radio in result.output for radio in radios) + else: + assert not any(radio in result.output for radio in radios) + else: + assert result.exit_code == 1 + assert isinstance(result.exception, AttributeError) + assert "undefined symbol: iio_get_backends_count" in str(result.exception) + + +def test_discover_command_json_output(): + """Test discover CLI command with JSON output.""" + runner = CliRunner() + + with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.discover_all_devices") as mock_discover: + mock_discover.return_value = { + "loaded_drivers": [], + "failed_drivers": [], + "devices": [{"type": "HackRF", "serial": "123456", "status": "available"}], + "total_devices": 1, + "uhd_devices": [], + "pluto_devices": [], + "rtlsdr_devices": [], + "bladerf_devices": [], + "hackrf_devices": [{"type": "HackRF", "serial": "123456", "status": "available"}], + } + + result = runner.invoke(discover, ["--json-output"]) + output_data = json.loads(result.output) + + assert result.exit_code == 0 + assert output_data["total_devices"] == 1 + assert len(output_data["devices"]) == 1 + assert output_data["devices"][0]["type"] == "HackRF" + + +def test_discover_command_verbose(): + """Test discover CLI command with verbose output.""" + runner = CliRunner() + + with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.discover_all_devices") as mock_discover: + mock_discover.return_value = { + "loaded_drivers": [], + "failed_drivers": [], + "devices": [ + { + "type": "PlutoSDR", + "serial": "123456", + "firmware": "1.0", + "uri": "ip:pluto.local", + "status": "available", + } + ], + "total_devices": 1, + "uhd_devices": [], + "pluto_devices": [], + "rtlsdr_devices": [], + "bladerf_devices": [], + "hackrf_devices": [ + { + "type": "PlutoSDR", + "serial": "123456", + "firmware": "1.0", + "uri": "ip:pluto.local", + "status": "available", + } + ], + } + + result = runner.invoke(discover, ["--verbose"]) + + assert result.exit_code == 0 + assert "RTL-SDR devices: None found" in result.output or "\n rtlsdr:" in result.output diff --git a/tests/ria_toolkit_oss_cli/test_generate.py b/tests/ria_toolkit_oss_cli/test_generate.py new file mode 100644 index 0000000..65742a6 --- /dev/null +++ b/tests/ria_toolkit_oss_cli/test_generate.py @@ -0,0 +1,1502 @@ +"""Tests for generate/synth command. + +This test suite covers the `ria generate` and `ria synth` (alias) commands for +generating synthetic RF signals. Tests are designed to work with the current +implementation status of the generate command. + +Note: Some impairment parameters that appear in common_options are not yet +fully implemented in individual command function signatures. Tests have been +designed to work with parameters that are currently supported. +""" + +import os +import tempfile +from pathlib import Path + +import pytest +from click.testing import CliRunner + +from ria_toolkit_oss.ria_toolkit_oss_cli.cli import cli + + +class TestGenerateCommandBasics: + """Test basic generate command functionality.""" + + def test_generate_help(self): + """Test generate command help.""" + runner = CliRunner() + result = runner.invoke(cli, ["generate", "--help"]) + assert result.exit_code == 0 + assert "generate" in result.output.lower() or "Generate signal" in result.output + # Check for some key subcommands + subcommands = ["chirp", "fsk", "gmsk", "noise", "psk", "qam", "tone"] + for cmd in subcommands: + assert cmd in result.output + + def test_synth_alias_help(self): + """Test synth alias for generate command.""" + runner = CliRunner() + result = runner.invoke(cli, ["synth", "--help"]) + assert result.exit_code == 0 + assert "synth" in result.output.lower() or "Generate signal" in result.output + + def test_missing_sample_rate(self): + """Test that sample rate is required.""" + runner = CliRunner() + result = runner.invoke(cli, ["generate", "tone", "-n", "1000", "-o", "/tmp/test.sigmf"]) + assert result.exit_code != 0 + # Should fail due to missing sample-rate + + +class TestToneCommand: + """Test tone (CW) signal generation.""" + + def test_tone_basic(self): + """Test basic tone generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone.sigmf") + result = runner.invoke( + cli, ["generate", "tone", "--sample-rate", "1e6", "--num-samples", "10000", "--output", output, "-q"] + ) + assert result.exit_code == 0 + # Check that output files were created + assert ( + Path(output.replace(".sigmf", ".sigmf-data")).exists() + or Path(output.replace(".sigmf", "") + ".sigmf-data").exists() + or Path(output).exists() + ) + + def test_tone_with_frequency(self): + """Test tone with custom frequency.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone_freq.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "10000", + "--frequency", + "100000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_tone_with_amplitude(self): + """Test tone with custom amplitude.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone_amp.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "10000", + "--amplitude", + "0.5", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_tone_duration_instead_of_samples(self): + """Test tone using duration instead of num-samples.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone_duration.sigmf") + result = runner.invoke( + cli, ["generate", "tone", "--sample-rate", "1e6", "--duration", "0.01", "--output", output, "-q"] + ) + assert result.exit_code == 0 + + def test_tone_with_phase(self): + """Test tone with phase offset.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone_phase.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "10000", + "--phase", + "1.57", # pi/2 + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_tone_with_center_frequency(self): + """Test tone with center frequency metadata.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone_cf.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "10000", + "--center-frequency", + "915e6", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + +class TestNoiseCommand: + """Test noise signal generation.""" + + def test_noise_gaussian(self): + """Test Gaussian noise generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "noise_gauss.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "noise", + "--sample-rate", + "1e6", + "--num-samples", + "10000", + "--noise-type", + "gaussian", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_noise_uniform(self): + """Test uniform noise generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "noise_uniform.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "noise", + "--sample-rate", + "1e6", + "--num-samples", + "10000", + "--noise-type", + "uniform", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_noise_with_power(self): + """Test noise with custom power.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "noise_power.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "noise", + "--sample-rate", + "1e6", + "--num-samples", + "10000", + "--power", + "0.5", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + +class TestChirpCommand: + """Test chirp/LFM signal generation.""" + + def test_chirp_up(self): + """Test upward chirp generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "chirp_up.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "chirp", + "--sample-rate", + "1e6", + "--num-samples", + "10000", + "--bandwidth", + "100000", + "--period", + "0.01", + "--type", + "up", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_chirp_down(self): + """Test downward chirp generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "chirp_down.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "chirp", + "--sample-rate", + "1e6", + "--num-samples", + "10000", + "--bandwidth", + "100000", + "--period", + "0.01", + "--type", + "down", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_chirp_up_down(self): + """Test up-down chirp generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "chirp_updown.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "chirp", + "--sample-rate", + "1e6", + "--num-samples", + "10000", + "--bandwidth", + "100000", + "--period", + "0.01", + "--type", + "up_down", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + +class TestWaveformCommands: + """Test square and sawtooth waveforms.""" + + def test_square_basic(self): + """Test square wave generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "square.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "square", + "--sample-rate", + "1e6", + "--num-samples", + "10000", + "--frequency", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_square_duty_cycle(self): + """Test square wave with custom duty cycle.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "square_duty.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "square", + "--sample-rate", + "1e6", + "--num-samples", + "10000", + "--frequency", + "10000", + "--duty-cycle", + "0.25", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_sawtooth_basic(self): + """Test sawtooth wave generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "sawtooth.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "sawtooth", + "--sample-rate", + "1e6", + "--num-samples", + "10000", + "--frequency", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + +class TestQAMCommand: + """Test QAM (Quadrature Amplitude Modulation) generation.""" + + def test_qam16(self): + """Test 16-QAM generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "qam16.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "qam", + "--sample-rate", + "1e6", + "--order", + "16", + "--symbol-rate", + "1e5", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_qam64(self): + """Test 64-QAM generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "qam64.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "qam", + "--sample-rate", + "1e6", + "--order", + "64", + "--symbol-rate", + "1e5", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_qam256(self): + """Test 256-QAM generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "qam256.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "qam", + "--sample-rate", + "1e6", + "--order", + "256", + "--symbol-rate", + "1e5", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_qam_with_filter(self): + """Test QAM with pulse shaping filter.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "qam_rrc.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "qam", + "--sample-rate", + "1e6", + "--order", + "16", + "--symbol-rate", + "1e5", + "--num-samples", + "10000", + "--filter", + "rrc", + "--filter-beta", + "0.35", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_qam_symbols_instead_of_samples(self): + """Test QAM using symbol count instead of sample count.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "qam_symbols.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "qam", + "--sample-rate", + "1e6", + "--order", + "16", + "--symbol-rate", + "1e5", + "--symbols", + "100", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_qam_with_gaussian_filter(self): + """Test QAM with Gaussian filter.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "qam_gauss.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "qam", + "--sample-rate", + "1e6", + "--order", + "16", + "--symbol-rate", + "1e5", + "--num-samples", + "10000", + "--filter", + "gaussian", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 1 + assert isinstance(result.exception, SystemExit) + + +class TestAPSKCommand: + """Test APSK (Amplitude Phase Shift Keying) generation.""" + + def test_apsk16(self): + """Test 16-APSK generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "apsk16.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "apsk", + "--sample-rate", + "1e6", + "--order", + "16", + "--symbol-rate", + "1e5", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_apsk32(self): + """Test 32-APSK generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "apsk32.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "apsk", + "--sample-rate", + "1e6", + "--order", + "32", + "--symbol-rate", + "1e5", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_apsk_with_rrc_filter(self): + """Test APSK with RRC filter.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "apsk_rrc.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "apsk", + "--sample-rate", + "1e6", + "--order", + "32", + "--symbol-rate", + "1e5", + "--num-samples", + "10000", + "--filter", + "rrc", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + +class TestPAMCommand: + """Test PAM (Pulse Amplitude Modulation) generation.""" + + def test_pam4(self): + """Test 4-PAM generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "pam4.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "pam", + "--sample-rate", + "1e6", + "--order", + "4", + "--symbol-rate", + "1e5", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_pam16(self): + """Test 16-PAM generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "pam16.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "pam", + "--sample-rate", + "1e6", + "--order", + "16", + "--symbol-rate", + "1e5", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + +class TestPSKCommand: + """Test PSK (Phase Shift Keying) generation.""" + + def test_bpsk(self): + """Test BPSK (2-PSK) generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "bpsk.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "psk", + "--sample-rate", + "1e6", + "--order", + "2", + "--symbol-rate", + "1e5", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_qpsk(self): + """Test QPSK (4-PSK) generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "qpsk.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "psk", + "--sample-rate", + "1e6", + "--order", + "4", + "--symbol-rate", + "1e5", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_8psk(self): + """Test 8-PSK generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "8psk.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "psk", + "--sample-rate", + "1e6", + "--order", + "8", + "--symbol-rate", + "1e5", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + +class TestFSKCommand: + """Test FSK (Frequency Shift Keying) generation.""" + + def test_fsk2(self): + """Test 2-FSK (binary FSK) generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "fsk2.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "fsk", + "--sample-rate", + "1e6", + "--order", + "2", + "--symbol-rate", + "1e5", + "--freq-spacing", + "1e5", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_fsk4(self): + """Test 4-FSK generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "fsk4.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "fsk", + "--sample-rate", + "1e6", + "--order", + "4", + "--symbol-rate", + "1e5", + "--freq-spacing", + "1e5", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_fsk_with_modulation_index(self): + """Test FSK with modulation index instead of frequency spacing.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "fsk_mi.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "fsk", + "--sample-rate", + "1e6", + "--order", + "2", + "--symbol-rate", + "1e5", + "--modulation-index", + "5.0", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + +class TestGMSKCommand: + """Test GMSK (Gaussian Minimum Shift Keying) generation.""" + + def test_gmsk_basic(self): + """Test basic GMSK generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "gmsk.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "gmsk", + "--sample-rate", + "1e6", + "--symbol-rate", + "270833", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_gmsk_custom_bt(self): + """Test GMSK with custom BT product.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "gmsk_bt.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "gmsk", + "--sample-rate", + "1e6", + "--symbol-rate", + "270833", + "--bt", + "0.5", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + +class TestOOKCommand: + """Test OOK (On-Off Keying) generation.""" + + def test_ook_basic(self): + """Test OOK generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "ook.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "ook", + "--sample-rate", + "1e6", + "--symbol-rate", + "1e5", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + +class TestOQPSKCommand: + """Test OQPSK (Offset QPSK) generation.""" + + def test_oqpsk_basic(self): + """Test OQPSK generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "oqpsk.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "oqpsk", + "--sample-rate", + "1e6", + "--symbol-rate", + "1e5", + "--num-samples", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + +class TestNR5GCommand: + """Test NR 5G frame generation.""" + + @pytest.mark.skip(reason="NR5G generation may not be available in all configurations") + def test_nr5g_basic(self): + """Test basic NR 5G frame generation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "nr5g.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "nr5g", + "--sample-rate", + "30.72e6", + "--bandwidth", + "20", + "--mu", + "1", + "--num-samples", + "30720", + "--output", + output, + "-q", + ], + ) + # NR5G may not be available, check accordingly + if result.exit_code != 0 and "not available" in result.output.lower(): + pytest.skip("NR5G not available") + assert result.exit_code == 0 + + +class TestOutputFormats: + """Test different output formats.""" + + def test_output_npy(self): + """Test saving as NPY format.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "signal.npy") + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "1000", + "--format", + "npy", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_output_wav(self): + """Test saving as WAV format.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "signal.wav") + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "1000", + "--format", + "wav", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_output_blue(self): + """Test saving as BLUE (Midas) format.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "signal.blue") + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "1000", + "--format", + "blue", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_format_detection_from_extension(self): + """Test that format is detected from file extension.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + # .npy extension should use NPY format + output = os.path.join(tmpdir, "signal.npy") + result = runner.invoke( + cli, ["generate", "tone", "--sample-rate", "1e6", "--num-samples", "1000", "--output", output, "-q"] + ) + assert result.exit_code == 0 + + +class TestChannelModels: + """Test channel models that are currently implemented.""" + + def test_frequency_shift(self): + """Test frequency shift application.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone_shifted.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "1000", + "--frequency-shift", + "10000", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_awgn_channel(self): + """Test AWGN channel.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone_awgn.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "1000", + "--channel-type", + "awgn", + "--noise-power", + "0.1", + "--output", + output, + "-q", + ], + ) + # May not be fully implemented yet + if result.exit_code == 0: + pass # Test passes + else: + # Document if AWGN not implemented + pytest.skip("AWGN channel not yet implemented") + + def test_rayleigh_channel(self): + """Test Rayleigh fading channel.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone_rayleigh.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "1000", + "--channel-type", + "rayleigh", + "--output", + output, + "-q", + ], + ) + # May not be fully implemented yet + if result.exit_code == 0: + pass # Test passes + else: + pytest.skip("Rayleigh channel not yet implemented") + + def test_rician_channel(self): + """Test Rician fading channel.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone_rician.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "1000", + "--channel-type", + "rician", + "--output", + output, + "-q", + ], + ) + # May not be fully implemented yet + if result.exit_code == 0: + pass # Test passes + else: + pytest.skip("Rician channel not yet implemented") + + +class TestMetadataAndConfig: + """Test metadata and configuration options.""" + + def test_custom_metadata(self): + """Test adding custom metadata.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone_meta.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "1000", + "--metadata", + "test_key=test_value", + "--metadata", + "experiment=001", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_center_frequency_metadata(self): + """Test that center frequency is included in metadata.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone_cf.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "1000", + "--center-frequency", + "915e6", + "--output", + output, + "-v", + ], + ) + assert result.exit_code == 0 + # In verbose mode, should show frequency + assert "915" in result.output + + def test_verbose_output(self): + """Test verbose output.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone_verbose.sigmf") + result = runner.invoke( + cli, ["generate", "tone", "--sample-rate", "1e6", "--num-samples", "1000", "--output", output, "-v"] + ) + assert result.exit_code == 0 + # Verbose output should contain more info + assert len(result.output) > 0 + + +class TestMessageSources: + """Test different message sources for modulation.""" + + def test_qam_random_bits(self): + """Test QAM with random bits (default).""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "qam_random.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "qam", + "--sample-rate", + "1e6", + "--order", + "16", + "--symbol-rate", + "1e5", + "--num-samples", + "1000", + "--message-source", + "random", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_qam_string_message(self): + """Test QAM with string message source.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "qam_string.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "qam", + "--sample-rate", + "1e6", + "--order", + "16", + "--symbol-rate", + "1e5", + "--num-samples", + "1000", + "--message-source", + "string", + "--message-content", + "Hello World", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_qam_file_message(self): + """Test QAM with file message source.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + # Create a test file + message_file = os.path.join(tmpdir, "message.bin") + with open(message_file, "wb") as f: + f.write(b"Test message content") + + output = os.path.join(tmpdir, "qam_file.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "qam", + "--sample-rate", + "1e6", + "--order", + "16", + "--symbol-rate", + "1e5", + "--num-samples", + "1000", + "--message-source", + "file", + "--message-content", + message_file, + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + +class TestOverwriteProtection: + """Test overwrite protection and file handling.""" + + def test_overwrite_protection_sigmf(self): + """Test that overwrite protection works for sigmf files.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone.sigmf") + + # First generation should succeed + result = runner.invoke( + cli, ["generate", "tone", "--sample-rate", "1e6", "--num-samples", "1000", "--output", output, "-q"] + ) + assert result.exit_code == 0 + + # Second generation without --overwrite should fail + result = runner.invoke( + cli, ["generate", "tone", "--sample-rate", "1e6", "--num-samples", "1000", "--output", output] + ) + assert result.exit_code != 0 + assert "exist" in result.output.lower() + + # Third generation with --overwrite should succeed + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "1000", + "--output", + output, + "--overwrite", + "-q", + ], + ) + assert result.exit_code == 0 + + def test_overwrite_protection_npy(self): + """Test that overwrite protection works for NPY files.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone.npy") + + # First generation should succeed + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "1000", + "--format", + "npy", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + # Second generation without --overwrite should fail + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "1000", + "--format", + "npy", + "--output", + output, + ], + ) + assert result.exit_code != 0 + assert "exist" in result.output.lower() + + +class TestParameterValidation: + """Test parameter validation and error handling.""" + + def test_invalid_sample_rate_type(self): + """Test that invalid sample rate type is rejected.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone.sigmf") + result = runner.invoke( + cli, ["generate", "tone", "--sample-rate", "invalid", "--num-samples", "1000", "--output", output] + ) + assert result.exit_code != 0 + + def test_frequency_shift_formatting(self): + """Test that frequency shift accepts scientific notation.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone_shift.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "1000", + "--frequency-shift", + "1e5", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 + + def test_both_num_samples_and_duration(self): + """Test that num-samples takes precedence when both provided.""" + runner = CliRunner() + with tempfile.TemporaryDirectory() as tmpdir: + output = os.path.join(tmpdir, "tone_both.sigmf") + result = runner.invoke( + cli, + [ + "generate", + "tone", + "--sample-rate", + "1e6", + "--num-samples", + "1000", + "--duration", + "0.01", + "--output", + output, + "-q", + ], + ) + assert result.exit_code == 0 diff --git a/tests/ria_toolkit_oss_cli/test_split.py b/tests/ria_toolkit_oss_cli/test_split.py new file mode 100644 index 0000000..102afcf --- /dev/null +++ b/tests/ria_toolkit_oss_cli/test_split.py @@ -0,0 +1,670 @@ +"""Tests for split CLI command.""" + +import tempfile +from pathlib import Path + +import numpy as np +import pytest +from click.testing import CliRunner + +from ria_toolkit_oss.datatypes import Annotation, Recording +from ria_toolkit_oss.io import load_recording, to_sigmf +from ria_toolkit_oss.ria_toolkit_oss_cli.cli import cli + + +class TestSplitHelp: + """Test split command help and basic functionality.""" + + def test_split_help(self): + """Test split command help.""" + runner = CliRunner() + result = runner.invoke(cli, ["split", "--help"]) + assert result.exit_code == 0 + assert "Split, trim, and extract portions of recordings" in result.output + assert "--split-at" in result.output + assert "--split-every" in result.output + assert "--split-duration" in result.output + assert "--trim" in result.output + assert "--extract-annotations" in result.output + + def test_missing_arguments(self): + """Test that missing arguments show error.""" + runner = CliRunner() + result = runner.invoke(cli, ["split"]) + assert result.exit_code != 0 + assert "Missing argument" in result.output or "Error" in result.output + + def test_no_operation_specified(self): + """Test error when no operation is specified.""" + runner = CliRunner() + + # Create a test file + with tempfile.TemporaryDirectory() as tmpdir: + signal = np.ones(1000, dtype=np.complex64) + recording = Recording(data=signal, metadata={"sample_rate": 1e6}) + to_sigmf(recording, filename="test", path=tmpdir, overwrite=True) + + test_file = str(Path(tmpdir) / "test.sigmf-data") + result = runner.invoke(cli, ["split", test_file]) + + assert result.exit_code != 0 + assert "No operation specified" in result.output + + +class TestSplitTrim: + """Test trim operations.""" + + @pytest.fixture + def test_recording(self): + """Create a test recording file.""" + with tempfile.TemporaryDirectory() as tmpdir: + signal = np.arange(10000, dtype=np.complex64) + recording = Recording(data=signal, metadata={"sample_rate": 2e6, "center_frequency": 915e6}) + to_sigmf(recording, filename="test", path=tmpdir, overwrite=True) + yield str(Path(tmpdir) / "test.sigmf-data") + + def test_trim_with_length(self, test_recording): + """Test trim with --start and --length.""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as outdir: + result = runner.invoke( + cli, + [ + "split", + test_recording, + "--trim", + "--start", + "1000", + "--length", + "5000", + "--output-dir", + outdir, + "-q", + ], + ) + + assert result.exit_code == 0 + + # Verify output file exists + output_files = list(Path(outdir).glob("*.sigmf-data")) + assert len(output_files) == 1 + + # Verify output has correct length + output_rec = load_recording(str(output_files[0])) + assert output_rec.data.shape[1] == 5000 + assert output_rec.metadata["original_start_sample"] == 1000 + assert output_rec.metadata["original_end_sample"] == 6000 + assert output_rec.metadata["split_operation"] == "trim" + + def test_trim_with_end(self, test_recording): + """Test trim with --start and --end.""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as outdir: + result = runner.invoke( + cli, + ["split", test_recording, "--trim", "--start", "2000", "--end", "7000", "--output-dir", outdir, "-q"], + ) + + assert result.exit_code == 0 + + output_files = list(Path(outdir).glob("*.sigmf-data")) + assert len(output_files) == 1 + + output_rec = load_recording(str(output_files[0])) + assert output_rec.data.shape[1] == 5000 + + def test_trim_without_length_or_end(self, test_recording): + """Test that trim requires --length or --end.""" + runner = CliRunner() + + result = runner.invoke(cli, ["split", test_recording, "--trim", "--start", "1000"]) + + assert result.exit_code != 0 + assert "requires either --length or --end" in result.output + + def test_trim_with_both_length_and_end(self, test_recording): + """Test that trim rejects both --length and --end.""" + runner = CliRunner() + + result = runner.invoke( + cli, ["split", test_recording, "--trim", "--start", "1000", "--length", "5000", "--end", "6000"] + ) + + assert result.exit_code != 0 + assert "Cannot specify both --length and --end" in result.output + + def test_trim_invalid_range(self, test_recording): + """Test trim with invalid range.""" + runner = CliRunner() + + result = runner.invoke( + cli, + ["split", test_recording, "--trim", "--start", "1000", "--length", "50000"], # Exceeds recording length + ) + + assert result.exit_code != 0 + assert "Invalid trim range" in result.output + + def test_trim_end_before_start(self, test_recording): + """Test trim with end < start.""" + runner = CliRunner() + + result = runner.invoke(cli, ["split", test_recording, "--trim", "--start", "5000", "--end", "1000"]) + + assert result.exit_code != 0 + assert "Invalid range" in result.output + + +class TestSplitAt: + """Test split-at operations.""" + + @pytest.fixture + def test_recording(self): + """Create a test recording file.""" + with tempfile.TemporaryDirectory() as tmpdir: + signal = np.arange(10000, dtype=np.complex64) + recording = Recording(data=signal, metadata={"sample_rate": 2e6, "center_frequency": 915e6}) + to_sigmf(recording, filename="test", path=tmpdir, overwrite=True) + yield str(Path(tmpdir) / "test.sigmf-data") + + def test_split_at_middle(self, test_recording): + """Test splitting at middle of recording.""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as outdir: + result = runner.invoke(cli, ["split", test_recording, "--split-at", "5000", "--output-dir", outdir, "-q"]) + + assert result.exit_code == 0 + + # Verify two output files exist + output_files = sorted(Path(outdir).glob("*.sigmf-data")) + assert len(output_files) == 2 + + # Verify part1 + part1 = load_recording(str(output_files[0])) + assert part1.data.shape[1] == 5000 + assert part1.metadata["original_start_sample"] == 0 + assert part1.metadata["original_end_sample"] == 5000 + + # Verify part2 + part2 = load_recording(str(output_files[1])) + assert part2.data.shape[1] == 5000 + assert part2.metadata["original_start_sample"] == 5000 + assert part2.metadata["original_end_sample"] == 10000 + + def test_split_at_invalid_point(self, test_recording): + """Test split-at with invalid sample point.""" + runner = CliRunner() + + result = runner.invoke(cli, ["split", test_recording, "--split-at", "50000"]) # Exceeds recording length + + assert result.exit_code != 0 + assert "Invalid split point" in result.output + + +class TestSplitEvery: + """Test split-every operations.""" + + @pytest.fixture + def test_recording(self): + """Create a test recording file.""" + with tempfile.TemporaryDirectory() as tmpdir: + signal = np.arange(10000, dtype=np.complex64) + recording = Recording(data=signal, metadata={"sample_rate": 2e6, "center_frequency": 915e6}) + to_sigmf(recording, filename="test", path=tmpdir, overwrite=True) + yield str(Path(tmpdir) / "test.sigmf-data") + + def test_split_every_equal_chunks(self, test_recording): + """Test splitting into equal chunks.""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as outdir: + result = runner.invoke( + cli, ["split", test_recording, "--split-every", "2500", "--output-dir", outdir, "-q"] + ) + + assert result.exit_code == 0 + + # Verify 4 chunks created + output_files = sorted(Path(outdir).glob("*.sigmf-data")) + assert len(output_files) == 4 + + # Verify all chunks have correct size + for i, file in enumerate(output_files): + chunk = load_recording(str(file)) + assert chunk.data.shape[1] == 2500 + assert chunk.metadata["chunk_index"] == i + 1 + assert chunk.metadata["total_chunks"] == 4 + + def test_split_every_unequal_chunks(self, test_recording): + """Test splitting with remainder chunk.""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as outdir: + result = runner.invoke( + cli, ["split", test_recording, "--split-every", "3000", "--output-dir", outdir, "-q"] + ) + + assert result.exit_code == 0 + + # Verify 4 chunks created (3x3000 + 1x1000) + output_files = sorted(Path(outdir).glob("*.sigmf-data")) + assert len(output_files) == 4 + + # Last chunk should be smaller + last_chunk = load_recording(str(output_files[-1])) + assert last_chunk.data.shape[1] == 1000 + + +class TestSplitDuration: + """Test split-duration operations.""" + + @pytest.fixture + def test_recording(self): + """Create a test recording file with known sample rate.""" + with tempfile.TemporaryDirectory() as tmpdir: + signal = np.arange(10000, dtype=np.complex64) + recording = Recording( + data=signal, metadata={"sample_rate": 10000, "center_frequency": 915e6} # 10kHz for easy math + ) + to_sigmf(recording, filename="test", path=tmpdir, overwrite=True) + yield str(Path(tmpdir) / "test.sigmf-data") + + def test_split_duration_basic(self, test_recording): + """Test splitting by duration.""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as outdir: + result = runner.invoke( + cli, + [ + "split", + test_recording, + "--split-duration", + "0.25", # 0.25s = 2500 samples at 10kHz + "--output-dir", + outdir, + "-q", + ], + ) + + assert result.exit_code == 0 + + # Verify chunks created + output_files = sorted(Path(outdir).glob("*.sigmf-data")) + assert len(output_files) == 4 + + # Verify chunk sizes + for file in output_files[:-1]: + chunk = load_recording(str(file)) + assert chunk.data.shape[1] == 2500 + + def test_split_duration_no_sample_rate(self): + """Test that split-duration requires sample_rate in metadata.""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as tmpdir: + # Create recording without sample_rate + signal = np.arange(1000, dtype=np.complex64) + recording = Recording(data=signal, metadata={}) + to_sigmf(recording, filename="test", path=tmpdir, overwrite=True) + + test_file = str(Path(tmpdir) / "test.sigmf-data") + result = runner.invoke(cli, ["split", test_file, "--split-duration", "1.0"]) + + assert result.exit_code != 0 + assert "Cannot split by duration" in result.output + assert "no sample_rate" in result.output + + +class TestExtractAnnotations: + """Test extract-annotations operations.""" + + @pytest.fixture + def annotated_recording(self): + """Create a test recording with annotations.""" + with tempfile.TemporaryDirectory() as tmpdir: + signal = np.arange(100000, dtype=np.complex64) + + annotations = [ + Annotation( + sample_start=0, sample_count=10000, freq_lower_edge=914e6, freq_upper_edge=916e6, label="preamble" + ), + Annotation( + sample_start=10000, + sample_count=50000, + freq_lower_edge=914e6, + freq_upper_edge=916e6, + label="payload", + ), + Annotation( + sample_start=60000, sample_count=5000, freq_lower_edge=914e6, freq_upper_edge=916e6, label="crc" + ), + ] + + recording = Recording( + data=signal, metadata={"sample_rate": 2e6, "center_frequency": 915e6}, annotations=annotations + ) + + to_sigmf(recording, filename="annotated", path=tmpdir, overwrite=True) + yield str(Path(tmpdir) / "annotated.sigmf-data") + + def test_extract_all_annotations(self, annotated_recording): + """Test extracting all annotations.""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as outdir: + result = runner.invoke( + cli, ["split", annotated_recording, "--extract-annotations", "--output-dir", outdir, "-q"] + ) + + assert result.exit_code == 0 + + # Verify 3 files created + output_files = sorted(Path(outdir).glob("*.sigmf-data")) + assert len(output_files) == 3 + + # Verify each annotation was extracted + preamble = [f for f in output_files if "preamble" in str(f)][0] + payload = [f for f in output_files if "payload" in str(f)][0] + crc = [f for f in output_files if "crc" in str(f)][0] + + preamble_rec = load_recording(str(preamble)) + assert preamble_rec.data.shape[1] == 10000 + assert preamble_rec.metadata["annotation_label"] == "preamble" + assert len(preamble_rec.annotations) == 0 # Annotations cleared + + payload_rec = load_recording(str(payload)) + assert payload_rec.data.shape[1] == 50000 + assert payload_rec.metadata["annotation_label"] == "payload" + + crc_rec = load_recording(str(crc)) + assert crc_rec.data.shape[1] == 5000 + assert crc_rec.metadata["annotation_label"] == "crc" + + def test_extract_annotation_by_label(self, annotated_recording): + """Test extracting annotations by label.""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as outdir: + result = runner.invoke( + cli, + [ + "split", + annotated_recording, + "--extract-annotations", + "--annotation-label", + "payload", + "--output-dir", + outdir, + "-q", + ], + ) + + assert result.exit_code == 0 + + # Verify only 1 file created + output_files = list(Path(outdir).glob("*.sigmf-data")) + assert len(output_files) == 1 + assert "payload" in str(output_files[0]) + + def test_extract_annotation_by_index(self, annotated_recording): + """Test extracting annotation by index.""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as outdir: + result = runner.invoke( + cli, + [ + "split", + annotated_recording, + "--extract-annotations", + "--annotation-index", + "1", + "--output-dir", + outdir, + "-q", + ], + ) + + assert result.exit_code == 0 + + # Verify only 1 file created (payload at index 1) + output_files = list(Path(outdir).glob("*.sigmf-data")) + assert len(output_files) == 1 + assert "payload" in str(output_files[0]) + + def test_extract_annotations_invalid_label(self, annotated_recording): + """Test error with non-existent label.""" + runner = CliRunner() + + result = runner.invoke( + cli, ["split", annotated_recording, "--extract-annotations", "--annotation-label", "nonexistent"] + ) + + assert result.exit_code != 0 + assert "No annotations with label" in result.output + + def test_extract_annotations_invalid_index(self, annotated_recording): + """Test error with invalid index.""" + runner = CliRunner() + + result = runner.invoke( + cli, ["split", annotated_recording, "--extract-annotations", "--annotation-index", "99"] + ) + + assert result.exit_code != 0 + assert "Invalid annotation index" in result.output + + def test_extract_annotations_no_annotations(self): + """Test error when recording has no annotations.""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as tmpdir: + signal = np.arange(1000, dtype=np.complex64) + recording = Recording(data=signal, metadata={"sample_rate": 1e6}) + to_sigmf(recording, filename="test", path=tmpdir, overwrite=True) + + test_file = str(Path(tmpdir) / "test.sigmf-data") + result = runner.invoke(cli, ["split", test_file, "--extract-annotations"]) + + assert result.exit_code != 0 + assert "No annotations found" in result.output + + +class TestOutputOptions: + """Test output-related options.""" + + @pytest.fixture + def test_recording(self): + """Create a test recording file.""" + with tempfile.TemporaryDirectory() as tmpdir: + signal = np.arange(10000, dtype=np.complex64) + recording = Recording(data=signal, metadata={"sample_rate": 2e6, "center_frequency": 915e6}) + to_sigmf(recording, filename="test", path=tmpdir, overwrite=True) + yield str(Path(tmpdir) / "test.sigmf-data") + + def test_output_prefix(self, test_recording): + """Test custom output prefix.""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as outdir: + result = runner.invoke( + cli, + [ + "split", + test_recording, + "--split-every", + "3000", + "--output-prefix", + "custom", + "--output-dir", + outdir, + "-q", + ], + ) + + assert result.exit_code == 0 + + output_files = list(Path(outdir).glob("*.sigmf-data")) + assert all("custom" in str(f) for f in output_files) + + def test_output_format_conversion(self, test_recording): + """Test format conversion during split.""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as outdir: + result = runner.invoke( + cli, + [ + "split", + test_recording, + "--split-every", + "5000", + "--output-format", + "npy", + "--output-dir", + outdir, + "-q", + ], + ) + + assert result.exit_code == 0 + + # Verify NPY files created + output_files = list(Path(outdir).glob("*.npy")) + assert len(output_files) == 2 + + def test_overwrite_protection(self, test_recording): + """Test overwrite protection.""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as outdir: + # First split should succeed + result = runner.invoke( + cli, + ["split", test_recording, "--trim", "--start", "0", "--length", "1000", "--output-dir", outdir, "-q"], + ) + assert result.exit_code == 0 + + # Second split without --overwrite should fail + result = runner.invoke( + cli, ["split", test_recording, "--trim", "--start", "0", "--length", "1000", "--output-dir", outdir] + ) + assert result.exit_code != 0 + assert "exist" in result.output.lower() + + # Third split with --overwrite should succeed + result = runner.invoke( + cli, + [ + "split", + test_recording, + "--trim", + "--start", + "0", + "--length", + "1000", + "--output-dir", + outdir, + "--overwrite", + "-q", + ], + ) + assert result.exit_code == 0 + + +class TestMultipleOperations: + """Test that multiple operations are rejected.""" + + @pytest.fixture + def test_recording(self): + """Create a test recording file.""" + with tempfile.TemporaryDirectory() as tmpdir: + signal = np.arange(10000, dtype=np.complex64) + recording = Recording(data=signal, metadata={"sample_rate": 2e6, "center_frequency": 915e6}) + to_sigmf(recording, filename="test", path=tmpdir, overwrite=True) + yield str(Path(tmpdir) / "test.sigmf-data") + + def test_trim_and_split_at(self, test_recording): + """Test that trim and split-at cannot be used together.""" + runner = CliRunner() + + result = runner.invoke(cli, ["split", test_recording, "--trim", "--split-at", "5000"]) + + assert result.exit_code != 0 + assert "Multiple operations specified" in result.output + + def test_split_every_and_extract(self, test_recording): + """Test that split-every and extract-annotations cannot be used together.""" + runner = CliRunner() + + result = runner.invoke(cli, ["split", test_recording, "--split-every", "1000", "--extract-annotations"]) + + assert result.exit_code != 0 + assert "Multiple operations specified" in result.output + + +class TestVerboseQuiet: + """Test verbose and quiet modes.""" + + @pytest.fixture + def test_recording(self): + """Create a test recording file.""" + with tempfile.TemporaryDirectory() as tmpdir: + signal = np.arange(10000, dtype=np.complex64) + recording = Recording(data=signal, metadata={"sample_rate": 2e6, "center_frequency": 915e6}) + to_sigmf(recording, filename="test", path=tmpdir, overwrite=True) + yield str(Path(tmpdir) / "test.sigmf-data") + + def test_verbose_mode(self, test_recording): + """Test verbose output.""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as outdir: + result = runner.invoke( + cli, + [ + "split", + test_recording, + "--trim", + "--start", + "0", + "--length", + "1000", + "--output-dir", + outdir, + "--verbose", + ], + ) + + assert result.exit_code == 0 + assert "Input format: SIGMF" in result.output + assert "Output format: SIGMF" in result.output + + def test_quiet_mode(self, test_recording): + """Test quiet output (minimal output).""" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as outdir: + result = runner.invoke( + cli, + [ + "split", + test_recording, + "--trim", + "--start", + "0", + "--length", + "1000", + "--output-dir", + outdir, + "--quiet", + ], + ) + + assert result.exit_code == 0 + # Output should be minimal in quiet mode + assert len(result.output.strip()) < 100 or result.output.strip() == "" diff --git a/tests/ria_toolkit_oss_cli/test_transmit.py b/tests/ria_toolkit_oss_cli/test_transmit.py index 4eacbe5..d2eaa71 100644 --- a/tests/ria_toolkit_oss_cli/test_transmit.py +++ b/tests/ria_toolkit_oss_cli/test_transmit.py @@ -21,23 +21,35 @@ from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit import ( class TestGetTxDevice: """Tests for get_sdr_device function.""" -from click.exceptions import ClickException -def get_sdr_device(name: str, tx: bool = False): - """Return an SDR device. If not connected, return a MagicMock instead of failing.""" - try: - if name == "pluto": - from ria_toolkit_oss.sdr.pluto import Pluto - return Pluto(tx=tx) - elif name == "hackrf": - from ria_toolkit_oss.sdr.hackrf import HackRF - return HackRF(tx=tx) - # other devices... - else: - raise ClickException(f"Unknown device {name}") - except Exception: - # If initialization fails, return a dummy/mock device - from unittest.mock import MagicMock - return MagicMock() + + def test_get_pluto_device(self): + """Test getting PlutoSDR device.""" + mock_sdr_class = MagicMock() + mock_sdr_instance = MagicMock() + mock_sdr_class.return_value = mock_sdr_instance + + with patch.dict("sys.modules", {"src.ria_toolkit_oss.sdr.pluto": MagicMock(Pluto=mock_sdr_class)}): + device = get_sdr_device("pluto") + assert device is mock_sdr_instance + + def test_get_hackrf_device(self): + """Test getting HackRF device.""" + mock_sdr_class = MagicMock() + mock_sdr_instance = MagicMock() + mock_sdr_class.return_value = mock_sdr_instance + + with patch.dict("sys.modules", {"src.ria_toolkit_oss.sdr.hackrf": MagicMock(HackRF=mock_sdr_class)}): + device = get_sdr_device("hackrf") + assert device is mock_sdr_instance + + def test_get_unknown_device(self): + """Test getting unknown device type.""" + from click.exceptions import ClickException + + with pytest.raises(ClickException) as exc_info: + get_sdr_device("unknown_device") + + assert "Unknown device type" in str(exc_info.value) class TestAutoSelectTxDevice: diff --git a/tests/ria_toolkit_oss_cli/test_transmit_generate.py b/tests/ria_toolkit_oss_cli/test_transmit_generate.py new file mode 100644 index 0000000..493d37e --- /dev/null +++ b/tests/ria_toolkit_oss_cli/test_transmit_generate.py @@ -0,0 +1,94 @@ +"""Tests for transmit command signal generation.""" + +from click.testing import CliRunner + +from ria_toolkit_oss.ria_toolkit_oss_cli.cli import cli + + +class TestTransmitGenerate: + """Test signal generation in transmit command.""" + + def test_transmit_help(self): + """Test transmit command help.""" + runner = CliRunner() + result = runner.invoke(cli, ["transmit", "--help"]) + assert result.exit_code == 0 + assert "Generate signal instead of loading from file" in result.output + assert "lfm" in result.output + assert "chirp" in result.output + assert "sine" in result.output + assert "pulse" in result.output + + def test_generate_lfm_chirp(self): + """Test LFM chirp generation (should fail without device).""" + runner = CliRunner() + result = runner.invoke(cli, ["transmit", "--generate", "lfm", "--device", "pluto", "-v"]) + # Should fail because no device is connected, but should show it's generating LFM + # Error will be about device initialization, not about missing input file + assert "Generating LFM chirp signal" in result.output or "Failed to initialize" in result.output + + def test_generate_sine_wave(self): + """Test sine wave generation (should fail without device).""" + runner = CliRunner() + result = runner.invoke(cli, ["transmit", "--generate", "sine", "--device", "pluto", "-v"]) + # Should fail because no device is connected, but should show it's generating sine + assert "Generating sine wave signal" in result.output or "Failed to initialize" in result.output + + def test_generate_chirp(self): + """Test simple chirp generation (should fail without device).""" + runner = CliRunner() + result = runner.invoke(cli, ["transmit", "--generate", "chirp", "--device", "pluto", "-v"]) + # Should fail because no device is connected, but should show it's generating chirp + assert "Generating chirp signal" in result.output or "Failed to initialize" in result.output + + def test_generate_pulse(self): + """Test pulse generation (should fail without device).""" + runner = CliRunner() + result = runner.invoke(cli, ["transmit", "--generate", "pulse", "--device", "pluto", "-v"]) + # Should fail because no device is connected, but should show it's generating pulse + assert "Generating pulse signal" in result.output or "Failed to initialize" in result.output + + def test_default_generates_lfm_when_no_input(self): + """Test that default generates LFM chirp when no input file specified.""" + runner = CliRunner() + result = runner.invoke(cli, ["transmit", "--device", "pluto", "-v"]) + # Should default to LFM chirp when no input file or --generate specified + assert "Generating LFM chirp signal" in result.output or "Failed to initialize" in result.output + + def test_generate_overrides_input_file(self): + """Test that --generate overrides --input file.""" + runner = CliRunner() + result = runner.invoke( + cli, ["transmit", "--device", "pluto", "--input", "nonexistent.sigmf", "--generate", "lfm", "-v"] + ) + # Should generate LFM, not try to load nonexistent.sigmf + assert "Generating LFM chirp signal" in result.output or "Failed to initialize" in result.output + # Should NOT say "Input file not found" + assert "Input file not found" not in result.output + + def test_signal_generation_parameters(self): + """Test that signal generation uses correct parameters from CLI.""" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "transmit", + "--device", + "pluto", + "--generate", + "lfm", + "--sample-rate", + "10e6", + "--center-frequency", + "915M", + "--gain", + "-20", + "-v", + ], + ) + # Check that parameters are shown in output + if "Failed to initialize" in result.output: + # Device initialization failed (expected without real device) + assert "10.00 MHz" in result.output or "10.000 MHz" in result.output or "10.00 MS/s" in result.output + assert "915" in result.output + assert "-20 dB" in result.output or "-20.0 dB" in result.output