ria-toolkit-oss/tests/orchestration/test_qa.py

193 lines
7.2 KiB
Python
Raw Normal View History

2026-03-11 10:27:18 -04:00
"""Tests for orchestration QA metrics."""
import numpy as np
import pytest
from ria_toolkit_oss.data.recording import Recording
2026-03-11 10:27:18 -04:00
from ria_toolkit_oss.orchestration.campaign import QAConfig
from ria_toolkit_oss.orchestration.qa import QAResult, check_recording, estimate_snr_db
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _make_recording(n_samples: int, sample_rate: float, signal: np.ndarray) -> Recording:
    """Wrap *signal* in a ``Recording`` carrying the standard test metadata.

    The samples are cast to ``complex64`` and tagged with *sample_rate* and a
    fixed 2.45 GHz center frequency.

    Args:
        n_samples: Expected number of samples along the last axis of *signal*.
            It is checked so that a fixture whose length disagrees with the
            duration it is meant to model fails loudly instead of silently
            skewing the QA duration check.
        sample_rate: Sample rate in Hz stored in the recording metadata.
        signal: Sample array to wrap.

    Returns:
        A ``Recording`` built from the complex64 samples and the metadata.

    Raises:
        ValueError: If ``signal.shape[-1]`` differs from *n_samples*.
    """
    if signal.shape[-1] != n_samples:
        raise ValueError(
            f"signal has {signal.shape[-1]} samples along its last axis, "
            f"expected n_samples={n_samples}"
        )
    return Recording(
        signal.astype(np.complex64),
        metadata={"sample_rate": sample_rate, "center_frequency": 2.45e9},
    )
def _tone(n: int, sr: float, freq_hz: float = 100e3, amplitude: float = 0.5) -> np.ndarray:
t = np.arange(n) / sr
return (np.exp(1j * 2 * np.pi * freq_hz * t) * amplitude).astype(np.complex64)
def _noise(n: int, amplitude: float = 0.001) -> np.ndarray:
rng = np.random.default_rng(42)
return ((rng.standard_normal(n) + 1j * rng.standard_normal(n)) * amplitude).astype(np.complex64)
# Shared QA policy for most tests: recordings below 10 dB SNR or shorter than
# 25 s are flagged, but flag_for_review=True keeps them marked as passed.
DEFAULT_QA = QAConfig(
    flag_for_review=True,
    min_duration_s=25.0,
    snr_threshold_db=10.0,
)
# ---------------------------------------------------------------------------
# estimate_snr_db
# ---------------------------------------------------------------------------
class TestEstimateSnrDb:
    """Behavioural checks for ``estimate_snr_db``."""

    def test_high_snr_tone(self):
        """A noiseless one-second tone should score above 20 dB."""
        rate = 1e6
        snr = estimate_snr_db(_tone(int(rate * 1), rate))
        assert snr > 20.0, f"Expected high SNR for clean tone, got {snr:.1f} dB"

    def test_pure_noise_low_snr(self):
        """Complex Gaussian noise alone should yield a low SNR estimate."""
        rate = 1e6
        count = int(rate)
        gen = np.random.default_rng(0)
        real_part = gen.standard_normal(count)
        imag_part = gen.standard_normal(count)
        snr = estimate_snr_db((real_part + 1j * imag_part).astype(np.complex64))
        # The estimate may even be negative; only an upper bound is asserted.
        assert snr < 15.0, f"Expected low SNR for noise, got {snr:.1f} dB"

    def test_snr_increases_with_amplitude(self):
        """Raising the tone level over a fixed noise floor must raise the estimate."""
        rate = 1e6
        count = int(rate)
        gen = np.random.default_rng(1)
        real_part = gen.standard_normal(count)
        imag_part = gen.standard_normal(count)
        floor = (real_part + 1j * imag_part).astype(np.complex64) * 0.01
        sample_times = np.arange(count) / rate
        carrier = np.exp(1j * 2 * np.pi * 100e3 * sample_times).astype(np.complex64)
        weak = estimate_snr_db(floor + carrier * 0.1)
        strong = estimate_snr_db(floor + carrier * 1.0)
        assert strong > weak

    def test_short_input_still_works(self):
        """A very short input must still produce a finite estimate."""
        # 512 samples is below the n_fft=4096 the estimator is assumed to use.
        snr = estimate_snr_db(_tone(512, 1e6))
        assert np.isfinite(snr)
# ---------------------------------------------------------------------------
# check_recording — pass cases
# ---------------------------------------------------------------------------
class TestCheckRecordingPass:
    """``check_recording`` cases that should pass without being flagged."""

    def test_clean_tone_passes(self):
        """A 30 s clean tone clears both the SNR and the duration thresholds."""
        sr = 1e6
        n = int(sr * 30)
        rec = _make_recording(n, sr, _tone(n, sr))
        result = check_recording(rec, DEFAULT_QA)
        assert result.passed is True
        assert result.flagged is False
        assert result.snr_db > 10.0
        assert abs(result.duration_s - 30.0) < 0.1

    def test_duration_exactly_at_threshold(self):
        """A recording exactly ``min_duration_s`` (25 s) long is not flagged."""
        sr = 1e6
        n = int(sr * 25)  # exactly at min_duration_s
        rec = _make_recording(n, sr, _tone(n, sr))
        result = check_recording(rec, DEFAULT_QA)
        # Guard against the fixture drifting off the boundary; otherwise the
        # test could pass without exercising the threshold comparison at all.
        assert abs(result.duration_s - 25.0) < 0.1
        assert result.flagged is False
        assert not any("Duration" in issue for issue in result.issues)

    def test_issues_empty_when_passing(self):
        """A passing recording reports no issues at all."""
        sr = 1e6
        n = int(sr * 30)
        rec = _make_recording(n, sr, _tone(n, sr))
        result = check_recording(rec, DEFAULT_QA)
        assert result.issues == []
# ---------------------------------------------------------------------------
# check_recording — flag cases
# ---------------------------------------------------------------------------
class TestCheckRecordingFlag:
    """``check_recording`` cases that must be flagged for review."""

    @staticmethod
    def _has_issue(result, keyword: str) -> bool:
        """Return True if any entry in ``result.issues`` contains *keyword*."""
        return any(keyword in issue for issue in result.issues)

    def test_short_recording_flagged(self):
        """A 10 s recording is below the 25 s minimum and gets a Duration issue."""
        sr = 1e6
        n = int(sr * 10)  # shorter than 25s min
        rec = _make_recording(n, sr, _tone(n, sr))
        result = check_recording(rec, DEFAULT_QA)
        assert result.flagged is True
        assert self._has_issue(result, "Duration")

    def test_low_snr_flagged(self):
        """Noise-only input falls below the SNR threshold and gets an SNR issue."""
        sr = 1e6
        n = int(sr * 30)
        rec = _make_recording(n, sr, _noise(n, amplitude=0.001))
        result = check_recording(rec, DEFAULT_QA)
        assert result.flagged is True
        assert self._has_issue(result, "SNR")

    def test_flag_for_review_still_passes(self):
        """With flag_for_review=True, flagged recordings are still marked passed."""
        sr = 1e6
        n = int(sr * 10)  # short → will be flagged
        rec = _make_recording(n, sr, _tone(n, sr))
        qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=True)
        result = check_recording(rec, qa)
        assert result.flagged is True
        assert result.passed is True  # human review, not auto-reject

    def test_flag_for_review_false_fails(self):
        """With flag_for_review=False, a flagged recording is also marked failed."""
        sr = 1e6
        n = int(sr * 10)
        rec = _make_recording(n, sr, _tone(n, sr))
        qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=False)
        result = check_recording(rec, qa)
        assert result.flagged is True
        assert result.passed is False

    def test_multiple_issues_reported(self):
        """Both short duration AND low SNR should both appear in issues list."""
        sr = 1e6
        n = int(sr * 5)  # very short
        rec = _make_recording(n, sr, _noise(n, amplitude=0.0001))
        result = check_recording(rec, DEFAULT_QA)
        assert result.flagged is True
        assert len(result.issues) >= 2
        # A bare count would also accept two unrelated issues; pin each
        # expected problem explicitly, as the docstring promises.
        assert self._has_issue(result, "Duration")
        assert self._has_issue(result, "SNR")
# ---------------------------------------------------------------------------
# check_recording — multichannel input
# ---------------------------------------------------------------------------
class TestCheckRecordingMultichannel:
    """``check_recording`` on recordings holding more than one channel."""

    def test_multichannel_recording(self):
        """A (2, N) two-channel recording is checked without error.

        Channel 0 is expected to be the one evaluated. Both channels here are
        clean tones, so this test does not by itself distinguish which
        channel is used.
        """
        rate = 1e6
        count = int(rate * 30)
        # Stack two clean tones at different frequencies as (channels, samples).
        stacked = np.stack([_tone(count, rate), _tone(count, rate, freq_hz=200e3)])
        meta = {"sample_rate": rate, "center_frequency": 2.45e9}
        outcome = check_recording(Recording(stacked, metadata=meta), DEFAULT_QA)
        assert outcome.passed is True
        assert outcome.flagged is False
# ---------------------------------------------------------------------------
# QAResult.to_dict
# ---------------------------------------------------------------------------
class TestQAResultToDict:
    """Serialisation of ``QAResult`` through ``to_dict``."""

    def test_to_dict_keys(self):
        """``to_dict`` exposes exactly the five result fields."""
        payload = QAResult(passed=True, flagged=False, snr_db=18.3, duration_s=30.0).to_dict()
        expected_keys = {"passed", "flagged", "snr_db", "duration_s", "issues"}
        assert set(payload.keys()) == expected_keys

    def test_to_dict_values(self):
        """Each field's value is carried into the dict (floats compared approximately)."""
        source = QAResult(
            passed=False,
            flagged=True,
            snr_db=7.5,
            duration_s=10.2,
            issues=["SNR below threshold"],
        )
        payload = source.to_dict()
        assert payload["passed"] is False
        assert payload["flagged"] is True
        assert payload["snr_db"] == pytest.approx(7.5, abs=0.01)
        assert payload["duration_s"] == pytest.approx(10.2, abs=0.01)
        assert payload["issues"] == ["SNR below threshold"]