added tests for cli

This commit is contained in:
G gillian 2025-12-11 15:59:08 -05:00
parent 155b13928b
commit 806fcf8293
10 changed files with 4020 additions and 18 deletions

View File

@ -69,7 +69,8 @@ all-sdr = [
[tool.poetry]
packages = [
{ include = "ria_toolkit_oss", from = "src" }
{ include = "ria_toolkit_oss", from = "src" },
{ include = "ria_toolkit_oss_cli", from = "src/ria_toolkit_oss" }
]
include = [
"**/*.so", # Required for Nuitkaification

View File

@ -0,0 +1,963 @@
"""Tests for the combine command."""
import tempfile
from pathlib import Path
import numpy as np
import pytest
from click.testing import CliRunner
from ria_toolkit_oss.datatypes import Annotation, Recording
from ria_toolkit_oss.io import load_recording, to_npy, to_sigmf
from ria_toolkit_oss_cli.cli import cli
class TestCombineHelp:
    """Test help and basic command structure."""

    def test_help(self):
        """`combine --help` exits cleanly and documents the key options."""
        result = CliRunner().invoke(cli, ["combine", "--help"])
        assert result.exit_code == 0
        for expected in ("Combine multiple recordings", "--mode", "--align-mode"):
            assert expected in result.output

    def test_no_inputs(self):
        """Invoking combine with no arguments is an error."""
        assert CliRunner().invoke(cli, ["combine"]).exit_code != 0

    def test_single_input(self):
        """A single input recording (plus output path) is rejected."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            rec = Recording(data=np.arange(1000, dtype=np.complex64), metadata={"sample_rate": 2e6})
            to_npy(rec, filename=str(base / "test.npy"), overwrite=True)
            result = CliRunner().invoke(cli, ["combine", str(base / "test.npy"), str(base / "output.npy")])
            assert result.exit_code != 0
class TestCombineConcat:
    """Test concatenate mode."""

    @pytest.fixture
    def test_recordings(self):
        """Write three 1000-sample chunk files holding a contiguous ramp."""
        with tempfile.TemporaryDirectory() as tmpdir:
            for idx in range(3):
                ramp = np.arange(idx * 1000, (idx + 1) * 1000, dtype=np.complex64)
                rec = Recording(data=ramp, metadata={"sample_rate": 2e6})
                to_npy(rec, filename=str(Path(tmpdir) / f"chunk{idx}.npy"), overwrite=True)
            yield tmpdir

    def test_concat_basic(self, test_recordings):
        """Default mode concatenates inputs in order and records metadata."""
        base = Path(test_recordings)
        out = str(base / "combined.npy")
        chunks = [str(base / f"chunk{i}.npy") for i in range(3)]
        result = CliRunner().invoke(cli, ["combine", *chunks, out])
        assert result.exit_code == 0
        assert Path(out).exists()
        combined = load_recording(out)
        assert combined.data.shape[1] == 3000
        assert combined._metadata["combine_mode"] == "concat"
        assert combined._metadata["num_inputs"] == 3
        # The ramp pieces must appear back-to-back in input order.
        for i in range(3):
            segment = combined.data[0, i * 1000:(i + 1) * 1000]
            assert np.allclose(segment, np.arange(i * 1000, (i + 1) * 1000))

    def test_concat_verbose(self, test_recordings):
        """--verbose reports each stage of the concat pipeline."""
        base = Path(test_recordings)
        out = str(base / "combined.npy")
        chunks = [str(base / f"chunk{i}.npy") for i in range(3)]
        result = CliRunner().invoke(cli, ["combine", *chunks, out, "--verbose"])
        assert result.exit_code == 0
        for message in ("Combining 3 recordings", "concat mode", "Concatenating..."):
            assert message in result.output

    def test_concat_with_annotations(self):
        """Annotations survive concat; later ones shift by the preceding length."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)

            def annotated(label, scale):
                # 1000 constant samples carrying one annotation at [100, 300).
                rec = Recording(data=np.ones(1000, dtype=np.complex64) * scale, metadata={"sample_rate": 2e6})
                rec._annotations.append(
                    Annotation(
                        sample_start=100, sample_count=200, freq_lower_edge=900e6, freq_upper_edge=920e6, label=label
                    )
                )
                return rec

            to_sigmf(annotated("test1", 1), filename="rec1", path=tmpdir, overwrite=True)
            to_sigmf(annotated("test2", 2), filename="rec2", path=tmpdir, overwrite=True)
            out = str(base / "combined.sigmf-data")
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "rec1.sigmf-data"), str(base / "rec2.sigmf-data"), out],
            )
            assert result.exit_code == 0
            combined = load_recording(out)
            assert len(combined._annotations) == 2
            first, second = combined._annotations
            # First annotation is untouched.
            assert (first.sample_start, first.label) == (100, "test1")
            # Second annotation is offset by the 1000 samples of rec1.
            assert (second.sample_start, second.label) == (1100, "test2")
class TestCombineAddSameLength:
    """Test add mode with same-length recordings."""

    def test_add_basic(self):
        """Adding two equal-length recordings sums them sample-wise."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            for name, scale in (("sig1", 1), ("sig2", 2)):
                rec = Recording(data=np.ones(1000, dtype=np.complex64) * scale, metadata={"sample_rate": 2e6})
                to_npy(rec, filename=str(base / f"{name}.npy"), overwrite=True)
            out = str(base / "added.npy")
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "sig1.npy"), str(base / "sig2.npy"), out, "--mode", "add"],
            )
            assert result.exit_code == 0
            combined = load_recording(out)
            assert combined.data.shape[1] == 1000
            assert np.allclose(combined.data, 3 + 0j)  # 1 + 2
            assert combined._metadata["combine_mode"] == "add"
            assert combined._metadata["align_mode"] == "error"

    def test_add_three_recordings(self):
        """Add mode accepts more than two inputs."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            inputs = []
            for scale in (1, 2, 3):
                rec = Recording(data=np.ones(1000, dtype=np.complex64) * scale, metadata={"sample_rate": 2e6})
                path = str(base / f"sig{scale}.npy")
                to_npy(rec, filename=path, overwrite=True)
                inputs.append(path)
            out = str(base / "added.npy")
            result = CliRunner().invoke(cli, ["combine", *inputs, out, "--mode", "add"])
            assert result.exit_code == 0
            # 1 + 2 + 3 = 6
            assert np.allclose(load_recording(out).data, 6 + 0j)
class TestCombineAddAlignError:
    """Test add mode with error alignment (default)."""

    def test_different_length_error(self):
        """Mismatched lengths abort with a hint pointing at --align-mode."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            long_rec = Recording(data=np.ones(10000, dtype=np.complex64), metadata={"sample_rate": 2e6})
            short_rec = Recording(data=np.ones(5000, dtype=np.complex64) * 2, metadata={"sample_rate": 2e6})
            to_npy(long_rec, filename=str(base / "long.npy"), overwrite=True)
            to_npy(short_rec, filename=str(base / "short.npy"), overwrite=True)
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "long.npy"), str(base / "short.npy"),
                 str(base / "output.npy"), "--mode", "add"],
            )
            assert result.exit_code != 0
            assert "different lengths" in result.output
            assert "--align-mode" in result.output
class TestCombineAddAlignTruncate:
    """Test add mode with truncate alignment."""

    def test_truncate(self):
        """The sum is truncated to the shortest input."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            long_rec = Recording(data=np.ones(10000, dtype=np.complex64), metadata={"sample_rate": 2e6})
            short_rec = Recording(data=np.ones(5000, dtype=np.complex64) * 2, metadata={"sample_rate": 2e6})
            to_npy(long_rec, filename=str(base / "long.npy"), overwrite=True)
            to_npy(short_rec, filename=str(base / "short.npy"), overwrite=True)
            out = str(base / "truncated.npy")
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "long.npy"), str(base / "short.npy"), out,
                 "--mode", "add", "--align-mode", "truncate"],
            )
            assert result.exit_code == 0
            combined = load_recording(out)
            assert combined.data.shape[1] == 5000
            assert np.allclose(combined.data, 3 + 0j)  # 1 + 2 over the overlap
class TestCombineAddAlignPad:
    """Test add mode with pad alignment."""

    def test_pad(self):
        """The shorter input is zero-padded up to the longest recording."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            long_rec = Recording(data=np.ones(10000, dtype=np.complex64), metadata={"sample_rate": 2e6})
            short_rec = Recording(data=np.ones(5000, dtype=np.complex64) * 2, metadata={"sample_rate": 2e6})
            to_npy(long_rec, filename=str(base / "long.npy"), overwrite=True)
            to_npy(short_rec, filename=str(base / "short.npy"), overwrite=True)
            out = str(base / "padded.npy")
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "long.npy"), str(base / "short.npy"), out,
                 "--mode", "add", "--align-mode", "pad"],
            )
            assert result.exit_code == 0
            combined = load_recording(out)
            assert combined.data.shape[1] == 10000
            assert np.allclose(combined.data[0, :5000], 3 + 0j)  # overlap: 1 + 2
            assert np.allclose(combined.data[0, 5000:], 1 + 0j)  # padded tail: 1 + 0
class TestCombineAddAlignPadStart:
    """Test add mode with pad-start alignment."""

    @staticmethod
    def _write_pair(base):
        """Write long.npy (10k ones) and short.npy (5k twos) at the same rate."""
        long_rec = Recording(data=np.ones(10000, dtype=np.complex64), metadata={"sample_rate": 2e6})
        short_rec = Recording(data=np.ones(5000, dtype=np.complex64) * 2, metadata={"sample_rate": 2e6})
        to_npy(long_rec, filename=str(base / "long.npy"), overwrite=True)
        to_npy(short_rec, filename=str(base / "short.npy"), overwrite=True)

    def test_pad_start(self):
        """The shorter input is placed at the requested start sample."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            self._write_pair(base)
            out = str(base / "pad_start.npy")
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "long.npy"), str(base / "short.npy"), out,
                 "--mode", "add", "--align-mode", "pad-start",
                 "--pad-start-sample", "3000"],
            )
            assert result.exit_code == 0
            combined = load_recording(out)
            assert combined.data.shape[1] == 10000
            assert np.allclose(combined.data[0, :3000], 1 + 0j)      # before insert: 1 + 0
            assert np.allclose(combined.data[0, 3000:8000], 3 + 0j)  # overlap: 1 + 2
            assert np.allclose(combined.data[0, 8000:], 1 + 0j)      # after insert: 1 + 0

    def test_pad_start_invalid(self):
        """A start sample that pushes the short input past the end is rejected."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            self._write_pair(base)
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "long.npy"), str(base / "short.npy"),
                 str(base / "output.npy"),
                 "--mode", "add", "--align-mode", "pad-start",
                 "--pad-start-sample", "7000"],  # 7000 + 5000 > 10000
            )
            assert result.exit_code != 0
            assert "exceeds max length" in result.output
class TestCombineAddAlignPadCenter:
    """Test add mode with pad-center alignment."""

    def test_pad_center(self):
        """The shorter input is centered inside the longer one."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            long_rec = Recording(data=np.ones(10000, dtype=np.complex64), metadata={"sample_rate": 2e6})
            short_rec = Recording(data=np.ones(5000, dtype=np.complex64) * 2, metadata={"sample_rate": 2e6})
            to_npy(long_rec, filename=str(base / "long.npy"), overwrite=True)
            to_npy(short_rec, filename=str(base / "short.npy"), overwrite=True)
            out = str(base / "pad_center.npy")
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "long.npy"), str(base / "short.npy"), out,
                 "--mode", "add", "--align-mode", "pad-center"],
            )
            assert result.exit_code == 0
            combined = load_recording(out)
            assert combined.data.shape[1] == 10000
            # Centered: 2500 lead-in, 5000 overlap, 2500 tail.
            assert np.allclose(combined.data[0, :2500], 1 + 0j)
            assert np.allclose(combined.data[0, 2500:7500], 3 + 0j)
            assert np.allclose(combined.data[0, 7500:], 1 + 0j)
class TestCombineAddAlignPadEnd:
    """Test add mode with pad-end alignment."""

    def test_pad_end(self):
        """The shorter input is aligned to the end of the longer one."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            long_rec = Recording(data=np.ones(10000, dtype=np.complex64), metadata={"sample_rate": 2e6})
            short_rec = Recording(data=np.ones(5000, dtype=np.complex64) * 2, metadata={"sample_rate": 2e6})
            to_npy(long_rec, filename=str(base / "long.npy"), overwrite=True)
            to_npy(short_rec, filename=str(base / "short.npy"), overwrite=True)
            out = str(base / "pad_end.npy")
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "long.npy"), str(base / "short.npy"), out,
                 "--mode", "add", "--align-mode", "pad-end"],
            )
            assert result.exit_code == 0
            combined = load_recording(out)
            assert combined.data.shape[1] == 10000
            assert np.allclose(combined.data[0, :5000], 1 + 0j)  # lead-in: 1 + 0
            assert np.allclose(combined.data[0, 5000:], 3 + 0j)  # overlap: 1 + 2
class TestCombineAddAlignRepeat:
    """Test add mode with repeat alignment."""

    @staticmethod
    def _invoke_repeat(base, out_name):
        """Run combine in add/repeat mode over long.npy + short.npy; return (result, out path)."""
        out = str(base / out_name)
        result = CliRunner().invoke(
            cli,
            ["combine", str(base / "long.npy"), str(base / "short.npy"), out,
             "--mode", "add", "--align-mode", "repeat"],
        )
        return result, out

    def test_repeat(self):
        """A shorter input that divides the long one evenly tiles it completely."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            long_rec = Recording(data=np.ones(10000, dtype=np.complex64), metadata={"sample_rate": 2e6})
            short_rec = Recording(data=np.ones(5000, dtype=np.complex64) * 2, metadata={"sample_rate": 2e6})
            to_npy(long_rec, filename=str(base / "long.npy"), overwrite=True)
            to_npy(short_rec, filename=str(base / "short.npy"), overwrite=True)
            result, out = self._invoke_repeat(base, "repeated.npy")
            assert result.exit_code == 0
            combined = load_recording(out)
            assert combined.data.shape[1] == 10000
            # Pattern repeats end to end: 1 + 2 everywhere.
            assert np.allclose(combined.data, 3 + 0j)

    def test_repeat_partial(self):
        """A non-exact multiple tiles fully, then finishes with a partial copy."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            long_rec = Recording(data=np.ones(10000, dtype=np.complex64), metadata={"sample_rate": 2e6})
            short_rec = Recording(data=np.arange(3000, dtype=np.complex64), metadata={"sample_rate": 2e6})
            to_npy(long_rec, filename=str(base / "long.npy"), overwrite=True)
            to_npy(short_rec, filename=str(base / "short.npy"), overwrite=True)
            result, out = self._invoke_repeat(base, "repeated.npy")
            assert result.exit_code == 0
            combined = load_recording(out)
            ramp = 1 + np.arange(3000)
            # Three full 3000-sample repetitions of 1 + [0..2999]...
            for start in (0, 3000, 6000):
                assert np.allclose(combined.data[0, start:start + 3000], ramp)
            # ...and a final partial repetition of the first 1000 samples.
            assert np.allclose(combined.data[0, 9000:10000], ramp[:1000])
class TestCombineAddAlignRepeatSpaced:
    """Test add mode with repeat-spaced alignment."""

    def test_repeat_spaced(self):
        """The short input repeats with a zero gap of --repeat-spacing samples."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            long_rec = Recording(data=np.ones(10000, dtype=np.complex64), metadata={"sample_rate": 2e6})
            short_rec = Recording(data=np.ones(2000, dtype=np.complex64) * 2, metadata={"sample_rate": 2e6})
            to_npy(long_rec, filename=str(base / "long.npy"), overwrite=True)
            to_npy(short_rec, filename=str(base / "short.npy"), overwrite=True)
            out = str(base / "repeat_spaced.npy")
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "long.npy"), str(base / "short.npy"), out,
                 "--mode", "add", "--align-mode", "repeat-spaced",
                 "--repeat-spacing", "1000"],
            )
            assert result.exit_code == 0
            combined = load_recording(out)
            assert combined.data.shape[1] == 10000
            # Alternating 2000-sample bursts (1 + 2) and 1000-sample gaps (1 + 0).
            assert np.allclose(combined.data[0, :2000], 3 + 0j)
            assert np.allclose(combined.data[0, 2000:3000], 1 + 0j)
            assert np.allclose(combined.data[0, 3000:5000], 3 + 0j)
            assert np.allclose(combined.data[0, 5000:6000], 1 + 0j)

    def test_repeat_spaced_missing_spacing(self):
        """repeat-spaced without --repeat-spacing is an error."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            long_rec = Recording(data=np.ones(10000, dtype=np.complex64), metadata={"sample_rate": 2e6})
            short_rec = Recording(data=np.ones(5000, dtype=np.complex64) * 2, metadata={"sample_rate": 2e6})
            to_npy(long_rec, filename=str(base / "long.npy"), overwrite=True)
            to_npy(short_rec, filename=str(base / "short.npy"), overwrite=True)
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "long.npy"), str(base / "short.npy"),
                 str(base / "output.npy"),
                 "--mode", "add", "--align-mode", "repeat-spaced"],
                # --repeat-spacing deliberately omitted.
            )
            assert result.exit_code != 0
            assert "requires --repeat-spacing" in result.output
class TestCombineValidation:
    """Test validation and error handling."""

    @staticmethod
    def _write(base, name, data, rate):
        """Persist one recording with the given data and sample rate."""
        rec = Recording(data=data, metadata={"sample_rate": rate})
        to_npy(rec, filename=str(base / name), overwrite=True)

    def test_sample_rate_mismatch(self):
        """Add mode refuses inputs with different sample rates."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            self._write(base, "sig1.npy", np.ones(1000, dtype=np.complex64), 2e6)
            self._write(base, "sig2.npy", np.ones(1000, dtype=np.complex64) * 2, 1e6)
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "sig1.npy"), str(base / "sig2.npy"),
                 str(base / "output.npy"), "--mode", "add"],
            )
            assert result.exit_code != 0
            assert "different sample rates" in result.output

    def test_channel_count_mismatch(self):
        """Add mode refuses inputs with different channel counts."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            self._write(base, "sig1.npy", np.ones((1, 1000), dtype=np.complex64), 2e6)  # 1 channel
            self._write(base, "sig2.npy", np.ones((2, 1000), dtype=np.complex64), 2e6)  # 2 channels
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "sig1.npy"), str(base / "sig2.npy"),
                 str(base / "output.npy"), "--mode", "add"],
            )
            assert result.exit_code != 0
            assert "different channel counts" in result.output

    def test_overwrite_protection(self):
        """An existing output blocks the command unless --overwrite is given."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            self._write(base, "sig1.npy", np.ones(1000, dtype=np.complex64), 2e6)
            self._write(base, "sig2.npy", np.ones(1000, dtype=np.complex64) * 2, 2e6)
            # Pre-create the output file so the first invocation collides.
            output_path = str(base / "output.npy")
            to_npy(Recording(data=np.zeros(100, dtype=np.complex64), metadata={}),
                   filename=output_path, overwrite=True)
            runner = CliRunner()
            args = ["combine", str(base / "sig1.npy"), str(base / "sig2.npy"),
                    output_path, "--mode", "add"]
            # Without --overwrite: refused.
            result = runner.invoke(cli, args)
            assert result.exit_code != 0
            assert "already exists" in result.output
            # With --overwrite: succeeds.
            assert runner.invoke(cli, args + ["--overwrite"]).exit_code == 0
class TestCombineOutputOptions:
    """Test output format and options."""

    @staticmethod
    def _write_pair(base, first_scale=1, second_scale=2):
        """Write sig1.npy / sig2.npy: 1000 constant samples each at 2 MS/s."""
        for name, scale in (("sig1", first_scale), ("sig2", second_scale)):
            rec = Recording(data=np.ones(1000, dtype=np.complex64) * scale, metadata={"sample_rate": 2e6})
            to_npy(rec, filename=str(base / f"{name}.npy"), overwrite=True)

    def test_output_formats(self):
        """A .sigmf-data output path produces both SigMF files."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            self._write_pair(base)
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "sig1.npy"), str(base / "sig2.npy"),
                 str(base / "output.sigmf-data"), "--mode", "add"],
            )
            assert result.exit_code == 0
            assert (base / "output.sigmf-data").exists()
            assert (base / "output.sigmf-meta").exists()

    def test_normalize(self):
        """--normalize scales the result to unit peak magnitude."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            self._write_pair(base, first_scale=10, second_scale=20)
            out = str(base / "normalized.npy")
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "sig1.npy"), str(base / "sig2.npy"), out,
                 "--mode", "add", "--normalize"],
            )
            assert result.exit_code == 0
            combined = load_recording(out)
            assert np.allclose(np.max(np.abs(combined.data)), 1.0)
            assert combined._metadata.get("normalized") is True

    def test_custom_metadata(self):
        """Repeated --metadata KEY=VALUE pairs land in the output metadata."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            self._write_pair(base)
            out = str(base / "output.npy")
            result = CliRunner().invoke(
                cli,
                ["combine", str(base / "sig1.npy"), str(base / "sig2.npy"), out,
                 "--mode", "add",
                 "--metadata", "test_id=test123",
                 "--metadata", "author=tester"],
            )
            assert result.exit_code == 0
            combined = load_recording(out)
            assert combined._metadata["test_id"] == "test123"
            assert combined._metadata["author"] == "tester"
class TestCombineVerboseQuiet:
    """Test verbose and quiet modes."""

    @staticmethod
    def _invoke_with_flag(tmpdir, flag):
        """Write two 1000-sample inputs and run combine in add mode with *flag*."""
        base = Path(tmpdir)
        for name, scale in (("sig1", 1), ("sig2", 2)):
            rec = Recording(data=np.ones(1000, dtype=np.complex64) * scale, metadata={"sample_rate": 2e6})
            to_npy(rec, filename=str(base / f"{name}.npy"), overwrite=True)
        return CliRunner().invoke(
            cli,
            ["combine", str(base / "sig1.npy"), str(base / "sig2.npy"),
             str(base / "output.npy"), "--mode", "add", flag],
        )

    def test_verbose(self):
        """--verbose narrates loading and completion."""
        with tempfile.TemporaryDirectory() as tmpdir:
            result = self._invoke_with_flag(tmpdir, "--verbose")
            assert result.exit_code == 0
            assert "Loading" in result.output
            assert "Done" in result.output

    def test_quiet(self):
        """--quiet suppresses all output."""
        with tempfile.TemporaryDirectory() as tmpdir:
            result = self._invoke_with_flag(tmpdir, "--quiet")
            assert result.exit_code == 0
            assert result.output == ""

View File

@ -0,0 +1,165 @@
# flake8: noqa
"""Tests for capture command."""
import os
import tempfile
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import yaml
from click.testing import CliRunner
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture import (
auto_select_device,
capture,
get_sdr_device,
save_visualization,
)
class TestGetSdrDevice:
    """Tests for get_sdr_device function."""

    @staticmethod
    def _resolve_with_mock_module(module_path, attr_name, device_type):
        """Patch *module_path* into sys.modules with a mock SDR class and resolve *device_type*.

        Returns (resolved_device, expected_instance).
        """
        sdr_cls = MagicMock()
        instance = sdr_cls.return_value
        # NOTE(review): the patched module path carries a "src." prefix — confirm this
        # matches the exact module string capture.py imports at runtime.
        with patch.dict("sys.modules", {module_path: MagicMock(**{attr_name: sdr_cls})}):
            return get_sdr_device(device_type), instance

    def test_get_pluto_device(self):
        """Test getting PlutoSDR device."""
        device, instance = self._resolve_with_mock_module("src.ria_toolkit_oss.sdr.pluto", "Pluto", "pluto")
        assert device is instance

    def test_get_hackrf_device(self):
        """Test getting HackRF device."""
        device, instance = self._resolve_with_mock_module("src.ria_toolkit_oss.sdr.hackrf", "HackRF", "hackrf")
        assert device is instance

    def test_get_unknown_device(self):
        """An unrecognized device type raises ClickException."""
        from click.exceptions import ClickException

        with pytest.raises(ClickException) as exc_info:
            get_sdr_device("unknown_device")
        assert "Unknown device type" in str(exc_info.value)
class TestAutoSelectDevice:
    """Tests for auto_select_device function."""

    # Patch targets inside the capture module, hoisted to avoid repetition.
    _DISCOVER = "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.list_all_devices"
    _ECHO = "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.click.echo"

    def test_auto_select_no_devices(self):
        """No discovered devices raises ClickException."""
        from click.exceptions import ClickException

        with patch(self._DISCOVER) as mock_discover:
            mock_discover.return_value = []
            with pytest.raises(ClickException) as exc_info:
                auto_select_device()
            assert "No SDR devices found" in str(exc_info.value)

    def test_auto_select_single_device(self):
        """A lone discovered device is selected automatically."""
        with patch(self._DISCOVER) as mock_discover:
            mock_discover.return_value = [{"type": "HackRF", "serial": "123456"}]
            assert auto_select_device(quiet=True) == "hackrf"

    def test_auto_select_single_device_with_warning(self):
        """When not quiet, selection emits a warning plus a hint."""
        with patch(self._DISCOVER) as mock_discover, patch(self._ECHO) as mock_echo:
            mock_discover.return_value = [{"type": "PlutoSDR", "uri": "ip:pluto.local"}]
            assert auto_select_device(quiet=False) == "pluto"
            # Two echo calls expected: the warning and the hint.
            assert mock_echo.call_count == 2

    def test_auto_select_multiple_devices(self):
        """More than one device is ambiguous and raises."""
        from click.exceptions import ClickException

        with patch(self._DISCOVER) as mock_discover:
            mock_discover.return_value = [
                {"type": "HackRF", "serial": "123456"},
                {"type": "PlutoSDR", "uri": "ip:pluto.local"},
            ]
            with pytest.raises(ClickException) as exc_info:
                auto_select_device()
            assert "Multiple devices found" in str(exc_info.value)

    def test_auto_select_device_name_mapping(self):
        """Discovered type strings map to the CLI's short device names."""
        mapping = {
            "PlutoSDR": "pluto",
            "HackRF": "hackrf",
            "BladeRF": "bladerf",
            "RTL-SDR": "rtlsdr",
        }
        with patch(self._DISCOVER) as mock_discover:
            for reported, expected in mapping.items():
                mock_discover.return_value = [{"type": reported}]
                assert auto_select_device(quiet=True) == expected
class TestSaveVisualization:
    """Tests for save_visualization function."""

    # Patch targets inside the capture module, hoisted to avoid repetition.
    _VIEW = "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig"
    _ECHO = "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.click.echo"

    def test_save_visualization_success(self):
        """The plot helper is invoked once with the expected arguments."""
        recording = MagicMock()
        with patch(self._VIEW) as mock_view:
            save_visualization(recording, "test.png", quiet=True)
            mock_view.assert_called_once_with(
                recording, output_path="test.png", saveplot=True, fast_mode=False, labels_mode=True
            )

    def test_save_visualization_import_error(self):
        """An ImportError is swallowed and surfaced as a warning echo."""
        recording = MagicMock()
        with patch(self._VIEW, side_effect=ImportError("Module not found")), patch(self._ECHO) as mock_echo:
            save_visualization(recording, "test.png", quiet=True)
            mock_echo.assert_called_once()
            assert "Warning" in str(mock_echo.call_args)

    def test_save_visualization_general_error(self):
        """Any other failure is reported without propagating."""
        recording = MagicMock()
        with patch(self._VIEW, side_effect=Exception("Failed to plot")), patch(self._ECHO) as mock_echo:
            save_visualization(recording, "test.png", quiet=True)
            mock_echo.assert_called_once()
            assert "Failed to save visualization" in str(mock_echo.call_args)

View File

@ -0,0 +1,118 @@
"""Tests for common CLI utilities."""
import os
import tempfile
import pytest
import yaml
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import (
format_frequency,
format_sample_rate,
load_yaml_config,
parse_frequency,
parse_metadata_args,
)
def test_load_yaml_config():
    """A populated YAML file round-trips through load_yaml_config."""
    expected = {
        "device": "pluto",
        "sample_rate": 2e6,
        "center_frequency": "915e6",
        "gain": 30,
        "metadata": {"location": "test_lab", "experiment": "test_001"},
    }
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as handle:
        yaml.dump(expected, handle)
        config_file = handle.name
    try:
        loaded = load_yaml_config(config_file)
        assert loaded == expected
        # Spot-check scalars and a nested value.
        assert loaded["device"] == "pluto"
        assert loaded["sample_rate"] == 2e6
        assert loaded["metadata"]["location"] == "test_lab"
    finally:
        os.unlink(config_file)
def test_load_yaml_config_empty():
    """An empty YAML file loads as an empty dict, not None."""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as handle:
        handle.write("")
        config_file = handle.name
    try:
        assert load_yaml_config(config_file) == {}
    finally:
        os.unlink(config_file)
def test_parse_metadata_args():
    """KEY=VALUE pairs parse with numeric coercion where possible."""
    parsed = parse_metadata_args(
        ["location=test_lab", "experiment=001", "power=30", "frequency=2.4e9", "description=Test Signal"]
    )
    assert parsed["location"] == "test_lab"
    assert parsed["experiment"] == "001"  # String because doesn't parse as number
    assert parsed["power"] == 30  # Integer
    assert parsed["frequency"] == 2.4e9  # Float
    assert parsed["description"] == "Test Signal"
def test_parse_metadata_args_invalid():
    """Any entry lacking the KEY=VALUE shape raises a ClickException."""
    from click.exceptions import ClickException

    bad_inputs = (
        ["invalid_format"],
        ["key1=value1", "invalid", "key2=value2"],  # one bad entry poisons the batch
    )
    for args in bad_inputs:
        with pytest.raises(ClickException):
            parse_metadata_args(args)
def test_parse_frequency():
    """parse_frequency accepts scientific notation, k/M/G suffixes, and plain Hz."""
    cases = {
        # Scientific notation
        "915e6": 915e6,
        "2.4e9": 2.4e9,
        "433e6": 433e6,
        # With suffixes (case-insensitive k)
        "915M": 915e6,
        "2.4G": 2.4e9,
        "433M": 433e6,
        "100k": 100e3,
        "100K": 100e3,
        # Plain numbers
        "915000000": 915e6,
        "2400000000": 2.4e9,
        # Edge cases: fractional values with suffixes
        "0.915G": 915e6,
        "915.0M": 915e6,
    }
    for text, hertz in cases.items():
        assert parse_frequency(text) == hertz
def test_format_frequency():
    """format_frequency renders two decimals with the appropriate Hz/kHz/MHz/GHz unit."""
    expectations = [
        (915e6, "915.00 MHz"),
        (2.4e9, "2.40 GHz"),
        (433e6, "433.00 MHz"),
        (100e3, "100.00 kHz"),
        (1e3, "1.00 kHz"),
        (500, "500.00 Hz"),
    ]
    for hertz, rendered in expectations:
        assert format_frequency(hertz) == rendered
def test_format_sample_rate():
    """format_sample_rate renders two decimals with the appropriate S/s unit prefix."""
    expectations = [
        (20e6, "20.00 MS/s"),
        (2e6, "2.00 MS/s"),
        (100e3, "100.00 kS/s"),
        (1e3, "1.00 kS/s"),
        (500, "500.00 S/s"),
    ]
    for rate, rendered in expectations:
        assert format_sample_rate(rate) == rendered

View File

@ -0,0 +1,190 @@
"""Tests for convert command."""
import os
import tempfile
from pathlib import Path
import pytest
from click.testing import CliRunner
from ria_toolkit_oss.ria_toolkit_oss_cli.cli import cli
class TestConvert:
    """Test convert command functionality."""

    # Machine-specific fixture recording shared by most tests below; tests
    # that need it skip when it is absent.
    # TODO(review): replace with a generated fixture so these tests can run
    # on any machine.
    TEST_INPUT = "/home/qrf/workarea/ash/signal-testbed/recordings/iq2440MHz234233.npy"

    @classmethod
    def _require_test_input(cls) -> str:
        """Return the shared input recording path, skipping the test when absent."""
        if not os.path.exists(cls.TEST_INPUT):
            pytest.skip("Test recording file not found")
        return cls.TEST_INPUT

    def test_convert_help(self):
        """Test convert command help."""
        runner = CliRunner()
        result = runner.invoke(cli, ["convert", "--help"])
        assert result.exit_code == 0
        assert "Convert recordings between file formats" in result.output
        # Every documented option must appear in the help text.
        for option in ("--format", "--legacy", "--wav-sample-rate", "--blue-format", "--overwrite", "--metadata"):
            assert option in result.output

    def test_missing_arguments(self):
        """Test that missing arguments show error."""
        runner = CliRunner()
        result = runner.invoke(cli, ["convert"])
        assert result.exit_code != 0
        assert "Missing argument" in result.output or "Error" in result.output

    def test_invalid_input_format(self):
        """Test handling of invalid input format."""
        runner = CliRunner()
        with tempfile.NamedTemporaryFile(suffix=".xyz", delete=False) as f:
            try:
                result = runner.invoke(cli, ["convert", f.name, "output.npy"])
                assert result.exit_code != 0
                assert "Unknown format" in result.output or "Supported" in result.output
            finally:
                os.unlink(f.name)

    def test_overwrite_protection(self):
        """Test that overwrite protection works."""
        runner = CliRunner()
        test_input = self._require_test_input()
        with tempfile.TemporaryDirectory() as tmpdir:
            output_file = os.path.join(tmpdir, "test.sigmf")
            # First conversion should succeed
            result = runner.invoke(cli, ["convert", test_input, output_file, "--legacy", "-q"])
            assert result.exit_code == 0
            # Second conversion without --overwrite should fail
            result = runner.invoke(cli, ["convert", test_input, output_file, "--legacy"])
            assert result.exit_code != 0
            assert "exist" in result.output.lower()
            assert "--overwrite" in result.output
            # Third conversion with --overwrite should succeed
            result = runner.invoke(cli, ["convert", test_input, output_file, "--legacy", "--overwrite", "-q"])
            assert result.exit_code == 0

    def test_metadata_override(self):
        """Test metadata override functionality."""
        runner = CliRunner()
        test_input = self._require_test_input()
        with tempfile.TemporaryDirectory() as tmpdir:
            output_file = os.path.join(tmpdir, "test.sigmf")
            result = runner.invoke(
                cli,
                [
                    "convert",
                    test_input,
                    output_file,
                    "--legacy",
                    "--metadata",
                    "test_key=test_value",
                    "--metadata",
                    "number=42",
                    "--metadata",
                    "float_val=3.14",
                    "-v",
                ],
            )
            assert result.exit_code == 0
            # Verbose output should echo every overridden key.
            assert "test_key" in result.output
            assert "number" in result.output
            assert "float_val" in result.output

    def test_format_detection(self):
        """Test that format detection works for different extensions."""
        runner = CliRunner()
        test_input = self._require_test_input()
        with tempfile.TemporaryDirectory() as tmpdir:
            # Test NPY to SigMF: both the data and meta sidecar must exist.
            sigmf_out = os.path.join(tmpdir, "test.sigmf")
            result = runner.invoke(cli, ["convert", test_input, sigmf_out, "--legacy", "-q"])
            assert result.exit_code == 0
            assert Path(sigmf_out).with_suffix(".sigmf-data").exists()
            assert Path(sigmf_out).with_suffix(".sigmf-meta").exists()
            # Test NPY to NPY
            npy_out = os.path.join(tmpdir, "test.npy")
            result = runner.invoke(cli, ["convert", test_input, npy_out, "--legacy", "-q"])
            assert result.exit_code == 0
            assert Path(npy_out).exists()

    def test_wav_conversion_with_decimation(self):
        """Test WAV conversion with sample rate decimation."""
        runner = CliRunner()
        test_input = self._require_test_input()
        with tempfile.TemporaryDirectory() as tmpdir:
            wav_out = os.path.join(tmpdir, "test.wav")
            result = runner.invoke(
                cli, ["convert", test_input, wav_out, "--legacy", "--wav-sample-rate", "48000", "--wav-bits", "16"]
            )
            assert result.exit_code == 0
            assert "Decimation factor" in result.output
            assert Path(wav_out).exists()
            # Check file is non-empty
            assert os.path.getsize(wav_out) > 0

    def test_blue_format_conversion(self):
        """Test MIDAS Blue format conversion."""
        runner = CliRunner()
        test_input = self._require_test_input()
        with tempfile.TemporaryDirectory() as tmpdir:
            # Test each Blue format variant.
            for blue_fmt in ["CI", "CF", "CD"]:
                blue_out = os.path.join(tmpdir, f"test_{blue_fmt}.blue")
                result = runner.invoke(
                    cli, ["convert", test_input, blue_out, "--legacy", "--blue-format", blue_fmt, "-q"]
                )
                assert result.exit_code == 0
                assert Path(blue_out).exists()
                # Check file is non-empty
                assert os.path.getsize(blue_out) > 0

    def test_quiet_and_verbose_modes(self):
        """Test quiet and verbose output modes."""
        runner = CliRunner()
        test_input = self._require_test_input()
        with tempfile.TemporaryDirectory() as tmpdir:
            # Test verbose mode
            output_file = os.path.join(tmpdir, "test_verbose.sigmf")
            result = runner.invoke(cli, ["convert", test_input, output_file, "--legacy", "-v"])
            assert result.exit_code == 0
            assert "Reading input" in result.output
            assert "Metadata preserved" in result.output
            # Test quiet mode
            output_file = os.path.join(tmpdir, "test_quiet.npy")
            result = runner.invoke(cli, ["convert", test_input, output_file, "--legacy", "-q"])
            assert result.exit_code == 0
            # Should have minimal output
            assert len(result.output) < 100 or result.output.strip() == ""

View File

@ -0,0 +1,287 @@
"""Tests for discover command."""
import json
import re
from unittest.mock import MagicMock, patch
from click.testing import CliRunner
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover import ( # find_bladerf_devices,; find_thinkrf_devices,; find_uhd_devices,
discover,
discover_all_devices,
find_hackrf_devices,
find_pluto_devices,
find_rtlsdr_devices,
load_sdr_drivers,
)
def test_discover_pluto_no_devices():
    """With no IIO scan contexts and no USB matches, no Plutos are found."""
    with (
        patch.dict("sys.modules", {"iio": MagicMock()}) as patched_modules,
        patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.get_usb_devices") as usb_mock,
    ):
        patched_modules["iio"].scan_contexts.return_value = {}
        usb_mock.return_value = []
        assert find_pluto_devices() == []
def test_discover_pluto_with_device():
    """A scanned IIO context is reported with its serial, firmware, and URI."""
    with patch.dict("sys.modules", {"iio": MagicMock()}) as patched_modules:
        fake_iio = patched_modules["iio"]
        fake_ctx = MagicMock()
        fake_ctx.attrs = {"hw_serial": "123456", "fw_version": "1.0"}
        fake_ctx._destroy = MagicMock()
        fake_iio.scan_contexts.return_value = {"ip:pluto.local": "PlutoSDR (ADALM-PLUTO)"}
        fake_iio.Context.return_value = fake_ctx

        found = find_pluto_devices()

        assert len(found) == 1
        device = found[0]
        assert device["type"] == "PlutoSDR"
        assert device["serial"] == "123456"
        assert device["firmware"] == "1.0"
        assert device["uri"] == "ip:pluto.local"
def test_discover_hackrf_no_devices():
    """Empty hackrf_info output yields an empty device list."""
    with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.subprocess") as subprocess_mock:
        subprocess_mock.check_output.return_value = ""
        assert find_hackrf_devices() == []
def test_discover_hackrf_with_devices():
    """Test HackRF discovery with devices present."""
    # Patch the module-level subprocess so hackrf_info is never executed;
    # the canned output below mimics two attached HackRF One boards.
    with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.subprocess") as mock_subprocess:
        mock_subprocess.check_output.return_value = """
hackrf_info version: 2023.01.1
libhackrf version: 2023.01.1 (0.8)
Found HackRF
Index: 0
Serial number: serial123
Board ID Number: 2 (HackRF One)
Firmware Version: v2.1.0 (API:1.08)
Part ID Number: 0xa000cb3c 0x005d4761
Index: 1
Serial number: serial456
Board ID Number: 2 (HackRF One)
Firmware Version: v2.1.0 (API:1.08)
Part ID Number: 0xa000cb3c 0x005d4761
"""
        devices = find_hackrf_devices()
        assert len(devices) == 2
        assert devices[0]["type"] == "HackRF One"
        assert devices[0]["serial"] == "serial123"
        # device_index may be parsed as int or kept as str; both are accepted.
        assert devices[0]["device_index"] == 0 or devices[0]["device_index"] == "0"
        assert devices[1]["serial"] == "serial456"
        assert devices[1]["device_index"] == 1 or devices[1]["device_index"] == "1"
def test_discover_rtlsdr_no_devices():
    """Test RTL-SDR discovery with no devices."""
    # Fixed patch target: was "ria_toolkit_oss.ria_toolkit_oss.ria_toolkit_oss.discover.subprocess",
    # which does not match the CLI module path every other test in this file
    # patches, so the mock would never intercept the subprocess call.
    with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.subprocess") as mock_subprocess:
        mock_subprocess.check_output.return_value = ""
        devices = find_rtlsdr_devices()
        assert devices == []
def test_discover_rtlsdr_with_devices():
    """Test RTL-SDR discovery with devices present."""
    # Fixed patch target: was "ria_toolkit_oss.ria_toolkit_oss.ria_toolkit_oss.discover.subprocess",
    # which does not match the CLI module path every other test in this file
    # patches, so the canned rtl_test output would never be substituted.
    with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.subprocess") as mock_subprocess:
        mock_subprocess.check_output.return_value = """
Found 2 device(s):
0: RTLSDRBlog, Blog V4, SN: 00000001
1: RTLSDRBlog, Blog V4, SN: 00000002
Using device 0: Generic RTL2832U OEM
Found Rafael Micro R828D tuner
RTL-SDR Blog V4 Detected
"""
        devices = find_rtlsdr_devices()
        assert len(devices) == 2
        assert devices[0]["type"] == "RTL-SDR"
        assert devices[0]["serial"] == "00000001"
        # device_index may be parsed as int or kept as str; both are accepted.
        assert devices[0]["device_index"] == 0 or devices[0]["device_index"] == "0"
def test_discover_all_devices_filter():
    """Test discovering devices with type filter."""
    # NOTE(review): despite the name/docstring, no filter argument is passed to
    # discover_all_devices() below — all six finders are still invoked. Confirm
    # whether a filter parameter exists and should be exercised here.
    with (
        patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_pluto_devices") as mock_pluto,
        patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_hackrf_devices") as mock_hackrf,
        patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_bladerf_devices") as mock_bladerf,
        patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_uhd_devices") as mock_usrp,
        patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_rtlsdr_devices") as mock_rtlsdr,
        patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_thinkrf_devices") as mock_thinkrf,
    ):
        # Only the Pluto finder reports a device; the rest return nothing.
        mock_pluto.return_value = [{"type": "PlutoSDR", "uri": "ip:pluto.local"}]
        mock_hackrf.return_value = []
        mock_bladerf.return_value = []
        mock_usrp.return_value = []
        mock_rtlsdr.return_value = []
        mock_thinkrf.return_value = []
        # Test filtering by pluto
        load_sdr_drivers(verbose=False)
        devices = discover_all_devices()
        mock_pluto.assert_called_once()
        mock_hackrf.assert_called_once()
        mock_bladerf.assert_called_once()
        assert len(devices["devices"]) == 1
        assert len(devices["pluto_devices"]) == 1
def test_discover_all_devices_no_filter():
    """Test discovering all device types."""
    with (
        patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_pluto_devices") as mock_pluto,
        patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_hackrf_devices") as mock_hackrf,
        patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_bladerf_devices") as mock_bladerf,
        patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_uhd_devices") as mock_usrp,
        patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_rtlsdr_devices") as mock_rtlsdr,
        patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.find_thinkrf_devices") as mock_thinkrf,
    ):
        # Two finders report one device each; the other four report none.
        mock_pluto.return_value = [{"type": "PlutoSDR", "uri": "ip:pluto.local"}]
        mock_hackrf.return_value = [{"type": "HackRF"}]
        mock_bladerf.return_value = []
        mock_usrp.return_value = []
        mock_rtlsdr.return_value = []
        mock_thinkrf.return_value = []
        load_sdr_drivers(verbose=False)
        devices = discover_all_devices()
        # Every finder must be consulted exactly once.
        mock_pluto.assert_called_once()
        mock_hackrf.assert_called_once()
        mock_bladerf.assert_called_once()
        mock_usrp.assert_called_once()
        mock_rtlsdr.assert_called_once()
        # Aggregate list contains both devices; per-type buckets hold one each.
        assert len(devices["devices"]) == 2
        assert len(devices["pluto_devices"]) == 1
        assert len(devices["hackrf_devices"]) == 1
def test_discover_command_no_devices():
    """The discover command reports when discovery finds nothing."""
    empty_summary = {
        "loaded_drivers": [],
        "failed_drivers": [],
        "devices": [],
        "total_devices": 0,
        "uhd_devices": [],
        "pluto_devices": [],
        "rtlsdr_devices": [],
        "bladerf_devices": [],
        "hackrf_devices": [],
    }
    runner = CliRunner()
    with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.discover_all_devices") as discover_mock:
        discover_mock.return_value = empty_summary
        outcome = runner.invoke(discover)
    assert outcome.exit_code == 0
    assert "No devices detected" in outcome.output
def test_discover_command():
    """Test discover CLI command."""
    # NOTE(review): this test runs real discovery (nothing is mocked), so the
    # outcome depends on the host's installed SDR drivers and attached hardware.
    runner = CliRunner()
    result = runner.invoke(discover)
    radios = ["USRP/UHD", "PlutoSDR", "RTL-SDR", "BladeRF", "HackRF", "ThinkRF"]
    # Pull the reported device count out of the summary line, if present.
    match = re.search(r"Detected devices: (\d+)", result.output)
    if match:
        total_devices = int(match.group(1))
    else:
        total_devices = 0
    if result.exit_code == 0:
        assert "Attached Devices" in result.output
        assert "Discovery Summary" in result.output
        # Radio type names should appear iff at least one device was detected.
        if total_devices > 0:
            assert any(radio in result.output for radio in radios)
        else:
            assert not any(radio in result.output for radio in radios)
    else:
        # Known failure mode: a libiio/python binding mismatch surfaces as an
        # AttributeError naming the missing C symbol.
        assert result.exit_code == 1
        assert isinstance(result.exception, AttributeError)
        assert "undefined symbol: iio_get_backends_count" in str(result.exception)
def test_discover_command_json_output():
    """--json-output produces machine-readable discovery results."""
    hackrf_entry = {"type": "HackRF", "serial": "123456", "status": "available"}
    summary = {
        "loaded_drivers": [],
        "failed_drivers": [],
        "devices": [hackrf_entry],
        "total_devices": 1,
        "uhd_devices": [],
        "pluto_devices": [],
        "rtlsdr_devices": [],
        "bladerf_devices": [],
        "hackrf_devices": [hackrf_entry],
    }
    runner = CliRunner()
    with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.discover_all_devices") as discover_mock:
        discover_mock.return_value = summary
        outcome = runner.invoke(discover, ["--json-output"])
    parsed = json.loads(outcome.output)
    assert outcome.exit_code == 0
    assert parsed["total_devices"] == 1
    assert len(parsed["devices"]) == 1
    assert parsed["devices"][0]["type"] == "HackRF"
def test_discover_command_verbose():
    """Test discover CLI command with verbose output."""
    runner = CliRunner()
    with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.discover.discover_all_devices") as mock_discover:
        mock_discover.return_value = {
            "loaded_drivers": [],
            "failed_drivers": [],
            "devices": [
                {
                    "type": "PlutoSDR",
                    "serial": "123456",
                    "firmware": "1.0",
                    "uri": "ip:pluto.local",
                    "status": "available",
                }
            ],
            "total_devices": 1,
            "uhd_devices": [],
            "pluto_devices": [],
            "rtlsdr_devices": [],
            "bladerf_devices": [],
            # NOTE(review): this bucket holds a PlutoSDR record under the
            # "hackrf_devices" key, and "pluto_devices" is left empty — looks
            # like a copy-paste slip in the fixture; confirm intent.
            "hackrf_devices": [
                {
                    "type": "PlutoSDR",
                    "serial": "123456",
                    "firmware": "1.0",
                    "uri": "ip:pluto.local",
                    "status": "available",
                }
            ],
        }
        result = runner.invoke(discover, ["--verbose"])
        assert result.exit_code == 0
        # Verbose mode should report the RTL-SDR bucket even when it is empty.
        assert "RTL-SDR devices: None found" in result.output or "\n  rtlsdr:" in result.output

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,670 @@
"""Tests for split CLI command."""
import tempfile
from pathlib import Path
import numpy as np
import pytest
from click.testing import CliRunner
from ria_toolkit_oss.datatypes import Annotation, Recording
from ria_toolkit_oss.io import load_recording, to_sigmf
from ria_toolkit_oss.ria_toolkit_oss_cli.cli import cli
class TestSplitHelp:
    """Smoke tests for the split command's CLI surface."""

    def test_split_help(self):
        """All documented operation flags appear in --help output."""
        outcome = CliRunner().invoke(cli, ["split", "--help"])
        assert outcome.exit_code == 0
        assert "Split, trim, and extract portions of recordings" in outcome.output
        for flag in ("--split-at", "--split-every", "--split-duration", "--trim", "--extract-annotations"):
            assert flag in outcome.output

    def test_missing_arguments(self):
        """Invoking split with no arguments fails with a usage error."""
        outcome = CliRunner().invoke(cli, ["split"])
        assert outcome.exit_code != 0
        assert "Missing argument" in outcome.output or "Error" in outcome.output

    def test_no_operation_specified(self):
        """A valid input file without any operation flag is rejected."""
        runner = CliRunner()
        with tempfile.TemporaryDirectory() as tmpdir:
            rec = Recording(data=np.ones(1000, dtype=np.complex64), metadata={"sample_rate": 1e6})
            to_sigmf(rec, filename="test", path=tmpdir, overwrite=True)
            data_path = str(Path(tmpdir) / "test.sigmf-data")
            outcome = runner.invoke(cli, ["split", data_path])
            assert outcome.exit_code != 0
            assert "No operation specified" in outcome.output
class TestSplitTrim:
    """Test trim operations."""

    @pytest.fixture
    def test_recording(self):
        """Create a test recording file.

        Yields the path to a 10000-sample SigMF data file; the backing
        temporary directory lives for the duration of the test.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            signal = np.arange(10000, dtype=np.complex64)
            recording = Recording(data=signal, metadata={"sample_rate": 2e6, "center_frequency": 915e6})
            to_sigmf(recording, filename="test", path=tmpdir, overwrite=True)
            yield str(Path(tmpdir) / "test.sigmf-data")

    def test_trim_with_length(self, test_recording):
        """Test trim with --start and --length."""
        runner = CliRunner()
        with tempfile.TemporaryDirectory() as outdir:
            result = runner.invoke(
                cli,
                [
                    "split",
                    test_recording,
                    "--trim",
                    "--start",
                    "1000",
                    "--length",
                    "5000",
                    "--output-dir",
                    outdir,
                    "-q",
                ],
            )
            assert result.exit_code == 0
            # Verify output file exists
            output_files = list(Path(outdir).glob("*.sigmf-data"))
            assert len(output_files) == 1
            # Verify output has correct length and provenance metadata
            output_rec = load_recording(str(output_files[0]))
            assert output_rec.data.shape[1] == 5000
            assert output_rec.metadata["original_start_sample"] == 1000
            assert output_rec.metadata["original_end_sample"] == 6000
            assert output_rec.metadata["split_operation"] == "trim"

    def test_trim_with_end(self, test_recording):
        """Test trim with --start and --end."""
        runner = CliRunner()
        with tempfile.TemporaryDirectory() as outdir:
            result = runner.invoke(
                cli,
                ["split", test_recording, "--trim", "--start", "2000", "--end", "7000", "--output-dir", outdir, "-q"],
            )
            assert result.exit_code == 0
            output_files = list(Path(outdir).glob("*.sigmf-data"))
            assert len(output_files) == 1
            # 7000 - 2000 samples remain after the trim.
            output_rec = load_recording(str(output_files[0]))
            assert output_rec.data.shape[1] == 5000

    def test_trim_without_length_or_end(self, test_recording):
        """Test that trim requires --length or --end."""
        runner = CliRunner()
        result = runner.invoke(cli, ["split", test_recording, "--trim", "--start", "1000"])
        assert result.exit_code != 0
        assert "requires either --length or --end" in result.output

    def test_trim_with_both_length_and_end(self, test_recording):
        """Test that trim rejects both --length and --end."""
        runner = CliRunner()
        result = runner.invoke(
            cli, ["split", test_recording, "--trim", "--start", "1000", "--length", "5000", "--end", "6000"]
        )
        assert result.exit_code != 0
        assert "Cannot specify both --length and --end" in result.output

    def test_trim_invalid_range(self, test_recording):
        """Test trim with invalid range."""
        runner = CliRunner()
        result = runner.invoke(
            cli,
            ["split", test_recording, "--trim", "--start", "1000", "--length", "50000"],  # Exceeds recording length
        )
        assert result.exit_code != 0
        assert "Invalid trim range" in result.output

    def test_trim_end_before_start(self, test_recording):
        """Test trim with end < start."""
        runner = CliRunner()
        result = runner.invoke(cli, ["split", test_recording, "--trim", "--start", "5000", "--end", "1000"])
        assert result.exit_code != 0
        assert "Invalid range" in result.output
class TestSplitAt:
    """Tests for the --split-at operation."""

    @pytest.fixture
    def test_recording(self):
        """Write a 10000-sample SigMF recording and yield its data-file path."""
        with tempfile.TemporaryDirectory() as tmpdir:
            rec = Recording(
                data=np.arange(10000, dtype=np.complex64),
                metadata={"sample_rate": 2e6, "center_frequency": 915e6},
            )
            to_sigmf(rec, filename="test", path=tmpdir, overwrite=True)
            yield str(Path(tmpdir) / "test.sigmf-data")

    def test_split_at_middle(self, test_recording):
        """Splitting at sample 5000 yields two 5000-sample halves with provenance metadata."""
        runner = CliRunner()
        with tempfile.TemporaryDirectory() as outdir:
            outcome = runner.invoke(cli, ["split", test_recording, "--split-at", "5000", "--output-dir", outdir, "-q"])
            assert outcome.exit_code == 0
            parts = sorted(Path(outdir).glob("*.sigmf-data"))
            assert len(parts) == 2
            # Each half carries the sample range it covered in the original.
            for index, part_path in enumerate(parts):
                part = load_recording(str(part_path))
                assert part.data.shape[1] == 5000
                assert part.metadata["original_start_sample"] == index * 5000
                assert part.metadata["original_end_sample"] == (index + 1) * 5000

    def test_split_at_invalid_point(self, test_recording):
        """A split point beyond the recording length is rejected."""
        outcome = CliRunner().invoke(cli, ["split", test_recording, "--split-at", "50000"])
        assert outcome.exit_code != 0
        assert "Invalid split point" in outcome.output
class TestSplitEvery:
    """Test split-every operations."""

    @pytest.fixture
    def test_recording(self):
        """Create a test recording file.

        Yields the path to a 10000-sample SigMF data file; the backing
        temporary directory lives for the duration of the test.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            signal = np.arange(10000, dtype=np.complex64)
            recording = Recording(data=signal, metadata={"sample_rate": 2e6, "center_frequency": 915e6})
            to_sigmf(recording, filename="test", path=tmpdir, overwrite=True)
            yield str(Path(tmpdir) / "test.sigmf-data")

    def test_split_every_equal_chunks(self, test_recording):
        """Test splitting into equal chunks."""
        runner = CliRunner()
        with tempfile.TemporaryDirectory() as outdir:
            result = runner.invoke(
                cli, ["split", test_recording, "--split-every", "2500", "--output-dir", outdir, "-q"]
            )
            assert result.exit_code == 0
            # Verify 4 chunks created
            output_files = sorted(Path(outdir).glob("*.sigmf-data"))
            assert len(output_files) == 4
            # Verify all chunks have correct size; chunk_index is 1-based.
            for i, file in enumerate(output_files):
                chunk = load_recording(str(file))
                assert chunk.data.shape[1] == 2500
                assert chunk.metadata["chunk_index"] == i + 1
                assert chunk.metadata["total_chunks"] == 4

    def test_split_every_unequal_chunks(self, test_recording):
        """Test splitting with remainder chunk."""
        runner = CliRunner()
        with tempfile.TemporaryDirectory() as outdir:
            result = runner.invoke(
                cli, ["split", test_recording, "--split-every", "3000", "--output-dir", outdir, "-q"]
            )
            assert result.exit_code == 0
            # Verify 4 chunks created (3x3000 + 1x1000)
            output_files = sorted(Path(outdir).glob("*.sigmf-data"))
            assert len(output_files) == 4
            # Last chunk should be smaller (the 1000-sample remainder)
            last_chunk = load_recording(str(output_files[-1]))
            assert last_chunk.data.shape[1] == 1000
class TestSplitDuration:
    """Test split-duration operations."""

    @pytest.fixture
    def test_recording(self):
        """Create a test recording file with known sample rate.

        Yields the path to a 10000-sample SigMF data file whose 10 kHz
        sample rate makes duration-to-samples math trivial (1 s = 10000).
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            signal = np.arange(10000, dtype=np.complex64)
            recording = Recording(
                data=signal, metadata={"sample_rate": 10000, "center_frequency": 915e6}  # 10kHz for easy math
            )
            to_sigmf(recording, filename="test", path=tmpdir, overwrite=True)
            yield str(Path(tmpdir) / "test.sigmf-data")

    def test_split_duration_basic(self, test_recording):
        """Test splitting by duration."""
        runner = CliRunner()
        with tempfile.TemporaryDirectory() as outdir:
            result = runner.invoke(
                cli,
                [
                    "split",
                    test_recording,
                    "--split-duration",
                    "0.25",  # 0.25s = 2500 samples at 10kHz
                    "--output-dir",
                    outdir,
                    "-q",
                ],
            )
            assert result.exit_code == 0
            # Verify chunks created
            output_files = sorted(Path(outdir).glob("*.sigmf-data"))
            assert len(output_files) == 4
            # Verify chunk sizes (all but the last chunk must be full-size)
            for file in output_files[:-1]:
                chunk = load_recording(str(file))
                assert chunk.data.shape[1] == 2500

    def test_split_duration_no_sample_rate(self):
        """Test that split-duration requires sample_rate in metadata."""
        runner = CliRunner()
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create recording without sample_rate
            signal = np.arange(1000, dtype=np.complex64)
            recording = Recording(data=signal, metadata={})
            to_sigmf(recording, filename="test", path=tmpdir, overwrite=True)
            test_file = str(Path(tmpdir) / "test.sigmf-data")
            result = runner.invoke(cli, ["split", test_file, "--split-duration", "1.0"])
            assert result.exit_code != 0
            assert "Cannot split by duration" in result.output
            assert "no sample_rate" in result.output
class TestExtractAnnotations:
    """Test extract-annotations operations."""

    @pytest.fixture
    def annotated_recording(self):
        """Create a test recording with annotations.

        Yields the path to a 100000-sample SigMF data file carrying three
        labelled annotations (preamble / payload / crc) at known offsets.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            signal = np.arange(100000, dtype=np.complex64)
            annotations = [
                Annotation(
                    sample_start=0, sample_count=10000, freq_lower_edge=914e6, freq_upper_edge=916e6, label="preamble"
                ),
                Annotation(
                    sample_start=10000,
                    sample_count=50000,
                    freq_lower_edge=914e6,
                    freq_upper_edge=916e6,
                    label="payload",
                ),
                Annotation(
                    sample_start=60000, sample_count=5000, freq_lower_edge=914e6, freq_upper_edge=916e6, label="crc"
                ),
            ]
            recording = Recording(
                data=signal, metadata={"sample_rate": 2e6, "center_frequency": 915e6}, annotations=annotations
            )
            to_sigmf(recording, filename="annotated", path=tmpdir, overwrite=True)
            yield str(Path(tmpdir) / "annotated.sigmf-data")

    def test_extract_all_annotations(self, annotated_recording):
        """Test extracting all annotations."""
        runner = CliRunner()
        with tempfile.TemporaryDirectory() as outdir:
            result = runner.invoke(
                cli, ["split", annotated_recording, "--extract-annotations", "--output-dir", outdir, "-q"]
            )
            assert result.exit_code == 0
            # Verify 3 files created
            output_files = sorted(Path(outdir).glob("*.sigmf-data"))
            assert len(output_files) == 3
            # Verify each annotation was extracted (label appears in filename)
            preamble = [f for f in output_files if "preamble" in str(f)][0]
            payload = [f for f in output_files if "payload" in str(f)][0]
            crc = [f for f in output_files if "crc" in str(f)][0]
            preamble_rec = load_recording(str(preamble))
            assert preamble_rec.data.shape[1] == 10000
            assert preamble_rec.metadata["annotation_label"] == "preamble"
            assert len(preamble_rec.annotations) == 0  # Annotations cleared
            payload_rec = load_recording(str(payload))
            assert payload_rec.data.shape[1] == 50000
            assert payload_rec.metadata["annotation_label"] == "payload"
            crc_rec = load_recording(str(crc))
            assert crc_rec.data.shape[1] == 5000
            assert crc_rec.metadata["annotation_label"] == "crc"

    def test_extract_annotation_by_label(self, annotated_recording):
        """Test extracting annotations by label."""
        runner = CliRunner()
        with tempfile.TemporaryDirectory() as outdir:
            result = runner.invoke(
                cli,
                [
                    "split",
                    annotated_recording,
                    "--extract-annotations",
                    "--annotation-label",
                    "payload",
                    "--output-dir",
                    outdir,
                    "-q",
                ],
            )
            assert result.exit_code == 0
            # Verify only 1 file created
            output_files = list(Path(outdir).glob("*.sigmf-data"))
            assert len(output_files) == 1
            assert "payload" in str(output_files[0])

    def test_extract_annotation_by_index(self, annotated_recording):
        """Test extracting annotation by index."""
        runner = CliRunner()
        with tempfile.TemporaryDirectory() as outdir:
            result = runner.invoke(
                cli,
                [
                    "split",
                    annotated_recording,
                    "--extract-annotations",
                    "--annotation-index",
                    "1",
                    "--output-dir",
                    outdir,
                    "-q",
                ],
            )
            assert result.exit_code == 0
            # Verify only 1 file created (payload at index 1)
            output_files = list(Path(outdir).glob("*.sigmf-data"))
            assert len(output_files) == 1
            assert "payload" in str(output_files[0])

    def test_extract_annotations_invalid_label(self, annotated_recording):
        """Test error with non-existent label."""
        runner = CliRunner()
        result = runner.invoke(
            cli, ["split", annotated_recording, "--extract-annotations", "--annotation-label", "nonexistent"]
        )
        assert result.exit_code != 0
        assert "No annotations with label" in result.output

    def test_extract_annotations_invalid_index(self, annotated_recording):
        """Test error with invalid index."""
        runner = CliRunner()
        result = runner.invoke(
            cli, ["split", annotated_recording, "--extract-annotations", "--annotation-index", "99"]
        )
        assert result.exit_code != 0
        assert "Invalid annotation index" in result.output

    def test_extract_annotations_no_annotations(self):
        """Test error when recording has no annotations."""
        runner = CliRunner()
        with tempfile.TemporaryDirectory() as tmpdir:
            signal = np.arange(1000, dtype=np.complex64)
            recording = Recording(data=signal, metadata={"sample_rate": 1e6})
            to_sigmf(recording, filename="test", path=tmpdir, overwrite=True)
            test_file = str(Path(tmpdir) / "test.sigmf-data")
            result = runner.invoke(cli, ["split", test_file, "--extract-annotations"])
            assert result.exit_code != 0
            assert "No annotations found" in result.output
class TestOutputOptions:
    """Test output-related options."""

    @pytest.fixture
    def test_recording(self):
        """Create a test recording file.

        Yields the path to a 10000-sample SigMF data file; the backing
        temporary directory lives for the duration of the test.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            signal = np.arange(10000, dtype=np.complex64)
            recording = Recording(data=signal, metadata={"sample_rate": 2e6, "center_frequency": 915e6})
            to_sigmf(recording, filename="test", path=tmpdir, overwrite=True)
            yield str(Path(tmpdir) / "test.sigmf-data")

    def test_output_prefix(self, test_recording):
        """Test custom output prefix."""
        runner = CliRunner()
        with tempfile.TemporaryDirectory() as outdir:
            result = runner.invoke(
                cli,
                [
                    "split",
                    test_recording,
                    "--split-every",
                    "3000",
                    "--output-prefix",
                    "custom",
                    "--output-dir",
                    outdir,
                    "-q",
                ],
            )
            assert result.exit_code == 0
            # Every generated file must carry the custom prefix.
            output_files = list(Path(outdir).glob("*.sigmf-data"))
            assert all("custom" in str(f) for f in output_files)

    def test_output_format_conversion(self, test_recording):
        """Test format conversion during split."""
        runner = CliRunner()
        with tempfile.TemporaryDirectory() as outdir:
            result = runner.invoke(
                cli,
                [
                    "split",
                    test_recording,
                    "--split-every",
                    "5000",
                    "--output-format",
                    "npy",
                    "--output-dir",
                    outdir,
                    "-q",
                ],
            )
            assert result.exit_code == 0
            # Verify NPY files created (SigMF input split into two NPY halves)
            output_files = list(Path(outdir).glob("*.npy"))
            assert len(output_files) == 2

    def test_overwrite_protection(self, test_recording):
        """Test overwrite protection."""
        runner = CliRunner()
        with tempfile.TemporaryDirectory() as outdir:
            # First split should succeed
            result = runner.invoke(
                cli,
                ["split", test_recording, "--trim", "--start", "0", "--length", "1000", "--output-dir", outdir, "-q"],
            )
            assert result.exit_code == 0
            # Second split without --overwrite should fail
            result = runner.invoke(
                cli, ["split", test_recording, "--trim", "--start", "0", "--length", "1000", "--output-dir", outdir]
            )
            assert result.exit_code != 0
            assert "exist" in result.output.lower()
            # Third split with --overwrite should succeed
            result = runner.invoke(
                cli,
                [
                    "split",
                    test_recording,
                    "--trim",
                    "--start",
                    "0",
                    "--length",
                    "1000",
                    "--output-dir",
                    outdir,
                    "--overwrite",
                    "-q",
                ],
            )
            assert result.exit_code == 0
class TestMultipleOperations:
    """The split command must reject mutually exclusive operations."""

    @pytest.fixture
    def test_recording(self):
        """Yield the path to a disposable 10k-sample SigMF recording."""
        with tempfile.TemporaryDirectory() as workdir:
            samples = np.arange(10000, dtype=np.complex64)
            rec = Recording(data=samples, metadata={"sample_rate": 2e6, "center_frequency": 915e6})
            to_sigmf(rec, filename="test", path=workdir, overwrite=True)
            yield str(Path(workdir) / "test.sigmf-data")

    def test_trim_and_split_at(self, test_recording):
        """Combining --trim with --split-at is an error."""
        outcome = CliRunner().invoke(cli, ["split", test_recording, "--trim", "--split-at", "5000"])
        assert outcome.exit_code != 0
        assert "Multiple operations specified" in outcome.output

    def test_split_every_and_extract(self, test_recording):
        """Combining --split-every with --extract-annotations is an error."""
        outcome = CliRunner().invoke(
            cli, ["split", test_recording, "--split-every", "1000", "--extract-annotations"]
        )
        assert outcome.exit_code != 0
        assert "Multiple operations specified" in outcome.output
class TestVerboseQuiet:
    """Check how much the split command prints in verbose and quiet modes."""

    @pytest.fixture
    def test_recording(self):
        """Yield the path to a disposable 10k-sample SigMF recording."""
        with tempfile.TemporaryDirectory() as workdir:
            samples = np.arange(10000, dtype=np.complex64)
            rec = Recording(data=samples, metadata={"sample_rate": 2e6, "center_frequency": 915e6})
            to_sigmf(rec, filename="test", path=workdir, overwrite=True)
            yield str(Path(workdir) / "test.sigmf-data")

    def test_verbose_mode(self, test_recording):
        """--verbose should report the detected input/output formats."""
        cli_runner = CliRunner()
        with tempfile.TemporaryDirectory() as outdir:
            trim_args = [
                "split",
                test_recording,
                "--trim",
                "--start",
                "0",
                "--length",
                "1000",
                "--output-dir",
                outdir,
                "--verbose",
            ]
            outcome = cli_runner.invoke(cli, trim_args)
            assert outcome.exit_code == 0
            assert "Input format: SIGMF" in outcome.output
            assert "Output format: SIGMF" in outcome.output

    def test_quiet_mode(self, test_recording):
        """--quiet should suppress nearly all console output."""
        cli_runner = CliRunner()
        with tempfile.TemporaryDirectory() as outdir:
            trim_args = [
                "split",
                test_recording,
                "--trim",
                "--start",
                "0",
                "--length",
                "1000",
                "--output-dir",
                outdir,
                "--quiet",
            ]
            outcome = cli_runner.invoke(cli, trim_args)
            assert outcome.exit_code == 0
            # A quiet run prints (almost) nothing.
            assert len(outcome.output.strip()) < 100 or outcome.output.strip() == ""

View File

@ -21,23 +21,35 @@ from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit import (
class TestGetTxDevice:
"""Tests for get_sdr_device function."""
from click.exceptions import ClickException
def get_sdr_device(name: str, tx: bool = False):
    """Return an SDR device handle for *name*.

    Unknown device names raise ``ClickException`` immediately. If a known
    device fails to initialize (e.g. hardware not connected, driver import
    fails), a ``MagicMock`` stands in so tests can run without radios.

    :param name: Device identifier, e.g. ``"pluto"`` or ``"hackrf"``.
    :param tx: Open the device in transmit mode when True.
    :raises ClickException: If *name* is not a known device type.
    """
    # Resolve the device name *before* entering the try block. In the
    # previous version the "unknown device" ClickException was raised
    # inside the try and immediately swallowed by `except Exception`,
    # so callers received a MagicMock instead of the error. The message
    # also now matches the test expectation ("Unknown device type").
    if name == "pluto":
        def _build():
            from ria_toolkit_oss.sdr.pluto import Pluto
            return Pluto(tx=tx)
    elif name == "hackrf":
        def _build():
            from ria_toolkit_oss.sdr.hackrf import HackRF
            return HackRF(tx=tx)
    # other devices...
    else:
        raise ClickException(f"Unknown device type: {name}")

    try:
        return _build()
    except Exception:
        # Hardware missing or driver import failed: return a dummy device.
        from unittest.mock import MagicMock
        return MagicMock()
def test_get_pluto_device(self):
    """Test getting PlutoSDR device.

    Patches the module path the implementation actually imports
    ("ria_toolkit_oss.sdr.pluto"). The previous patch key carried a
    spurious "src." prefix, so the real import ran, failed, and the
    fallback MagicMock (not ``mock_sdr_instance``) was returned.
    """
    mock_sdr_class = MagicMock()
    mock_sdr_instance = MagicMock()
    mock_sdr_class.return_value = mock_sdr_instance
    with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.pluto": MagicMock(Pluto=mock_sdr_class)}):
        device = get_sdr_device("pluto")
        assert device is mock_sdr_instance
def test_get_hackrf_device(self):
    """Test getting HackRF device.

    Patches the module path the implementation actually imports
    ("ria_toolkit_oss.sdr.hackrf"). The previous patch key carried a
    spurious "src." prefix, so the real import ran, failed, and the
    fallback MagicMock (not ``mock_sdr_instance``) was returned.
    """
    mock_sdr_class = MagicMock()
    mock_sdr_instance = MagicMock()
    mock_sdr_class.return_value = mock_sdr_instance
    with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.hackrf": MagicMock(HackRF=mock_sdr_class)}):
        device = get_sdr_device("hackrf")
        assert device is mock_sdr_instance
def test_get_unknown_device(self):
    """An unrecognized device name must raise ClickException."""
    from click.exceptions import ClickException

    with pytest.raises(ClickException) as exc_info:
        get_sdr_device("unknown_device")
    # The error text must identify the problem as an unknown device type.
    assert "Unknown device type" in str(exc_info.value)
class TestAutoSelectTxDevice:

View File

@ -0,0 +1,94 @@
"""Tests for transmit command signal generation."""
from click.testing import CliRunner
from ria_toolkit_oss.ria_toolkit_oss_cli.cli import cli
class TestTransmitGenerate:
"""Test signal generation in transmit command."""
def test_transmit_help(self):
    """The help text should advertise --generate and every waveform name."""
    outcome = CliRunner().invoke(cli, ["transmit", "--help"])
    assert outcome.exit_code == 0
    assert "Generate signal instead of loading from file" in outcome.output
    # All supported generator names must be listed in the help output.
    for waveform in ("lfm", "chirp", "sine", "pulse"):
        assert waveform in outcome.output
def test_generate_lfm_chirp(self):
    """LFM generation is attempted even with no device attached."""
    outcome = CliRunner().invoke(cli, ["transmit", "--generate", "lfm", "--device", "pluto", "-v"])
    # Without hardware the command fails at device init, but the verbose
    # log must show the LFM generator was selected — not a missing input file.
    assert "Generating LFM chirp signal" in outcome.output or "Failed to initialize" in outcome.output
def test_generate_sine_wave(self):
    """Sine generation is attempted even with no device attached."""
    outcome = CliRunner().invoke(cli, ["transmit", "--generate", "sine", "--device", "pluto", "-v"])
    # Device init fails without hardware, but the sine generator must be chosen.
    assert "Generating sine wave signal" in outcome.output or "Failed to initialize" in outcome.output
def test_generate_chirp(self):
    """Simple chirp generation is attempted even with no device attached."""
    outcome = CliRunner().invoke(cli, ["transmit", "--generate", "chirp", "--device", "pluto", "-v"])
    # Device init fails without hardware, but the chirp generator must be chosen.
    assert "Generating chirp signal" in outcome.output or "Failed to initialize" in outcome.output
def test_generate_pulse(self):
    """Pulse generation is attempted even with no device attached."""
    outcome = CliRunner().invoke(cli, ["transmit", "--generate", "pulse", "--device", "pluto", "-v"])
    # Device init fails without hardware, but the pulse generator must be chosen.
    assert "Generating pulse signal" in outcome.output or "Failed to initialize" in outcome.output
def test_default_generates_lfm_when_no_input(self):
    """With neither --input nor --generate, transmit defaults to LFM."""
    outcome = CliRunner().invoke(cli, ["transmit", "--device", "pluto", "-v"])
    # Either the default LFM generator runs, or device init fails first.
    assert "Generating LFM chirp signal" in outcome.output or "Failed to initialize" in outcome.output
def test_generate_overrides_input_file(self):
    """--generate takes precedence over a (nonexistent) --input file."""
    args = ["transmit", "--device", "pluto", "--input", "nonexistent.sigmf", "--generate", "lfm", "-v"]
    outcome = CliRunner().invoke(cli, args)
    # The generator runs instead of loading nonexistent.sigmf ...
    assert "Generating LFM chirp signal" in outcome.output or "Failed to initialize" in outcome.output
    # ... so the missing-file error must never appear.
    assert "Input file not found" not in outcome.output
def test_signal_generation_parameters(self):
"""Test that signal generation uses correct parameters from CLI."""
runner = CliRunner()
result = runner.invoke(
cli,
[
"transmit",
"--device",
"pluto",
"--generate",
"lfm",
"--sample-rate",
"10e6",
"--center-frequency",
"915M",
"--gain",
"-20",
"-v",
],
)
# Check that parameters are shown in output
if "Failed to initialize" in result.output:
# Device initialization failed (expected without real device)
assert "10.00 MHz" in result.output or "10.000 MHz" in result.output or "10.00 MS/s" in result.output
assert "915" in result.output
assert "-20 dB" in result.output or "-20.0 dB" in result.output