optimiztions and fixes
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 16s
Test with tox / Test with tox (3.10) (pull_request) Failing after 17m6s
Build Project / Build Project (3.10) (pull_request) Successful in 17m26s
Build Project / Build Project (3.11) (pull_request) Successful in 17m25s
Build Project / Build Project (3.12) (pull_request) Successful in 17m27s
Test with tox / Test with tox (3.12) (pull_request) Successful in 17m21s
Test with tox / Test with tox (3.11) (pull_request) Failing after 21m50s

This commit is contained in:
ben 2026-04-01 11:57:59 -04:00
parent 9a960e2f29
commit c36fdcf607
19 changed files with 56 additions and 56 deletions

View File

@ -21,7 +21,8 @@ class DatasetBuilder(ABC):
"""
_url: str = abstract_attribute()
_SHA256: str # SHA256 checksum.
_SHA256: Optional[str] = None # SHA256 checksum.
_MD5: Optional[str] = None # MD5 checksum.
_name: str = abstract_attribute()
_author: str = abstract_attribute()
_license: DatasetLicense = abstract_attribute()

View File

@ -169,8 +169,10 @@ def delete_example_inplace(source: str | os.PathLike, idx: int) -> None:
with h5py.File(source, "a") as f:
ds, md = f["data"], f["metadata/metadata"]
m, c, n = ds.shape
assert 0 <= idx <= m - 1
assert len(ds) == len(md)
if not (0 <= idx <= m - 1):
raise IndexError(f"Index {idx} out of range [0, {m - 1}]")
if len(ds) != len(md):
raise ValueError("Data and metadata array lengths do not match")
new_ds = f.create_dataset(
"data.temp",

View File

@ -255,7 +255,9 @@ class RadioDataset(ABC):
else:
classes_to_augment = classes_to_augment.encode("utf-8")
if classes_to_augment not in class_sizes:
raise ValueError(f"class name of {i} does not belong to the class key of {class_key}")
raise ValueError(
f"class name of {classes_to_augment} does not belong to the class key of {class_key}"
)
result_sizes = get_result_sizes(
level=level, target_size=target_size, classes_to_augment=classes_to_augment, class_sizes=class_sizes
@ -375,7 +377,7 @@ class RadioDataset(ABC):
counters[key] = counters.get(key, 0)
idx = 0
with h5py.File(self.source, "a") as f:
with h5py.File(self.source, "r") as f:
while idx < len(self):
labels = f["metadata/metadata"][class_key]
current_class = labels[idx]
@ -514,7 +516,7 @@ class RadioDataset(ABC):
idx = 0
with h5py.File(self.source, "a") as f:
with h5py.File(self.source, "r") as f:
while idx < len(self):
labels = f["metadata/metadata"][class_key]
current_class = labels[idx]

View File

@ -247,7 +247,7 @@ def _validate_sublists(list_of_lists: list[list[str]], ids: list[str]) -> None:
"""Ensure that each ID is present in one and only one sublist."""
all_elements = [item for sublist in list_of_lists for item in sublist]
assert len(all_elements) == len(set(all_elements)) and list(set(ids)).sort() == list(set(all_elements)).sort()
assert len(all_elements) == len(set(all_elements)) and sorted(set(ids)) == sorted(set(all_elements))
def _generate_split_source_filenames(

View File

@ -146,7 +146,7 @@ class Recording:
self._metadata["timestamp"] = time.time()
else:
if not isinstance(self._metadata["timestamp"], (int, float)):
raise ValueError("timestamp must be int or float, not ", type(self._metadata["timestamp"]))
raise ValueError(f"timestamp must be int or float, not {type(self._metadata['timestamp'])}")
if "rec_id" not in self.metadata:
self._metadata["rec_id"] = generate_recording_id(data=self.data, timestamp=self._metadata["timestamp"])
@ -445,7 +445,7 @@ class Recording:
'rec_id': 'fda0f41...'} # Example value
"""
if key not in PROTECTED_KEYS:
self._metadata.pop(key)
self._metadata.pop(key, None)
else:
raise ValueError(f"Key {key} is protected and cannot be modified or removed.")

View File

@ -330,7 +330,7 @@ def to_sigmf(
converted_metadata = {
sigmf_key: metadata[metadata_key]
for sigmf_key, metadata_key in SIGMF_KEY_CONVERSION.items()
if metadata_key in metadata
if metadata_key in metadata and sigmf_key != SigMFFile.HASH_KEY
}
# Merge dictionaries, giving priority to sigmf_meta
@ -387,9 +387,8 @@ def from_sigmf(file: os.PathLike | str) -> Recording:
"""
file = str(file)
if len(file) > 11:
if file[-11:-5] != ".sigmf":
file = file + ".sigmf-data"
if not file.endswith((".sigmf-data", ".sigmf-meta", ".sigmf")):
file = file + ".sigmf-data"
sigmf_file = sigmffile.fromfile(file)

View File

@ -213,7 +213,7 @@ class CaptureStep:
bandwidth_mhz=parse_bandwidth_mhz(d.get("bandwidth")),
traffic=d.get("traffic"),
connection_interval_ms=d.get("connection_interval_ms"),
power_dbm=float(d["power"].rstrip("dBm").strip()) if d.get("power") else None,
power_dbm=float(d["power"].removesuffix("dBm").strip()) if d.get("power") else None,
)

View File

@ -139,9 +139,9 @@ def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
Returns:
Script stdout as a string.
"""
script_path = Path(script).resolve()
if not script_path.is_absolute():
if not Path(script).is_absolute():
raise RuntimeError(f"Script path must be absolute: {script}")
script_path = Path(script).resolve()
if not script_path.is_file():
raise RuntimeError(f"Script not found or is not a regular file: {script}")

View File

@ -51,7 +51,7 @@ def estimate_snr_db(samples: np.ndarray, signal_fraction: float = 0.7) -> float:
psd = np.abs(np.fft.fft(samples[:n_fft] * window)) ** 2
psd_sorted = np.sort(psd)[::-1]
n_signal = max(1, int(n_fft * signal_fraction))
n_signal = min(max(1, int(n_fft * signal_fraction)), n_fft - 1)
signal_power = psd_sorted[:n_signal].mean()
noise_power = psd_sorted[n_signal:].mean()

View File

@ -333,7 +333,12 @@ class Pluto(SDR):
elif tx_time is not None:
pass
else:
tx_time = len(recording) / self.tx_sample_rate
if isinstance(recording, Recording):
tx_time = recording.data.shape[-1] / self.tx_sample_rate
elif isinstance(recording, np.ndarray):
tx_time = recording.shape[-1] / self.tx_sample_rate
else:
tx_time = len(recording[0]) / self.tx_sample_rate
data = self._format_tx_data(recording=recording)
@ -437,7 +442,7 @@ class Pluto(SDR):
abs_gain = gain
if abs_gain < rx_gain_min or abs_gain > rx_gain_max:
abs_gain = min(max(gain, rx_gain_min), rx_gain_max)
abs_gain = min(max(abs_gain, rx_gain_min), rx_gain_max)
print(f"Gain {gain} out of range for Pluto.")
print(f"Gain range: {rx_gain_min} to {rx_gain_max} dB")
@ -591,6 +596,8 @@ class Pluto(SDR):
self.tx_buffer_size = buffer_size
def close(self):
if not hasattr(self, "radio"):
return
if self.radio.tx_cyclic_buffer:
self.radio.tx_destroy_buffer()
del self.radio

View File

@ -103,6 +103,7 @@ def _inference_loop(state: InferenceState, sdr) -> None:
from ria_toolkit_oss.orchestration.qa import estimate_snr_db
state.sdr = sdr
state.set_running(True)
session = state.session
input_name = session.get_inputs()[0].name
expected_shape = tuple(
@ -189,7 +190,7 @@ async def start_inference(request: StartInferenceRequest):
try:
from ria_toolkit_oss.orchestration.executor import _DEVICE_ALIASES
from ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
from ria_toolkit_oss.sdr import get_sdr_device
except ImportError as e:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"SDR import failed: {e}")
@ -207,7 +208,6 @@ async def start_inference(request: StartInferenceRequest):
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"SDR initialisation failed: {e}")
state.stop_event.clear()
state.set_running(True)
state.thread = threading.Thread(target=_inference_loop, args=(state, sdr), daemon=True)
state.thread.start()
return StartInferenceResponse(running=True)

View File

@ -67,7 +67,7 @@ async def deploy(request: DeployRequest):
except (ValueError, KeyError) as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))
if any(t.script for t in cfg.transmitters):
if cfg.transmitters and any(t.script for t in cfg.transmitters):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="External scripts are not permitted in server-deployed campaigns. "

View File

@ -2,6 +2,7 @@
.. todo:: Need to add some information here about signal generation and the signal generators in this module.
"""
import warnings
from typing import Optional
import numpy as np
@ -227,7 +228,7 @@ def noise(
# TODO figure out a better way to make it conform to [-1,1]
if not np.array_equal(magnitude, magnitude2):
print("Warning: clipping in basic_signal_generator.noise")
warnings.warn("basic_signal_generator.noise: magnitude clipped to [-1, 1]")
phase = np.random.uniform(low=0, high=2 * np.pi, size=length)
complex_awgn = magnitude2 * np.exp(1j * phase)
@ -268,6 +269,9 @@ def chirp(sample_rate: int, num_samples: int, center_frequency: Optional[float]
.. todo:: Usage examples coming soon!
"""
# Ensure that the generated chirp signal remains within a safe frequency range to avoid aliasing.
if num_samples < 2:
raise ValueError("num_samples must be >= 2 for chirp generation")
chirp_start_frequency = center_frequency - sample_rate / 4
chirp_end_frequency = center_frequency + sample_rate / 4
@ -307,6 +311,9 @@ def lfm_chirp_complex(
down_part = np.flip(up_part)
baseband_chirp = np.concatenate([up_part, down_part])
else:
raise ValueError(f"Unknown chirp_type '{chirp_type}'. Must be 'up', 'down', or 'up_down'.")
# Generate the full signal by tiling the windowed chirp
num_chirps = round(total_time / chirp_period)
full_signal = np.tile(baseband_chirp, num_chirps)

View File

@ -204,7 +204,7 @@ def phase_shift(signal: ArrayLike | Recording, phase: Optional[float] = np.pi) -
>>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j]])
>>> new_rec = phase_shift(rec, np.pi/2)
>>> new_rec.data
array([[-1+1j, -2+2j -3+3j -4+4j]])
array([[-1+1j, -2+2j, -3+3j, -4+4j]])
"""
# TODO: Additional info needs to be added to docstring description

View File

@ -315,7 +315,7 @@ def capture(
ident = ident or config.get("ident") or config.get("serial") # Support legacy 'serial' in config
sample_rate = sample_rate or config.get("sample_rate")
center_frequency = center_frequency or config.get("center_frequency")
gain = gain or config.get("gain")
gain = gain if gain is not None else config.get("gain")
bandwidth = bandwidth or config.get("bandwidth")
num_samples = num_samples or config.get("num_samples")
duration = duration or config.get("duration")

View File

@ -214,7 +214,7 @@ def apply_post_processing(
)
# 3. AWGN (Final stage usually)
if add_noise == "awgn":
if add_noise:
npow = channel_params.get("noise_power", 0.1)
echo_verbose(f"Applying AWGN (Power={npow})", verbose)

View File

@ -393,7 +393,7 @@ def transmit(
ident = ident or config.get("ident") or config.get("serial") # Support legacy 'serial' in config
sample_rate = sample_rate or config.get("sample_rate")
center_frequency = center_frequency or config.get("center_frequency")
gain = gain or config.get("gain")
gain = gain if gain is not None else config.get("gain")
bandwidth = bandwidth or config.get("bandwidth")
input_file = input_file or config.get("input")
generate = generate or config.get("generate")

View File

@ -207,7 +207,7 @@ def test_cut_out_avg_snr_1():
transformed_data = iq_augmentations.cut_out(TEST_DATA1, max_section_size=2, fill_type="avg-snr")
assert np.allclose(
transformed_data,
np.asarray([[-1.26516288 - 0.36655702j, -2.44693984 + 1.27294267j, 3 + 3j, 4.1583403 - 0.96625365j]]),
np.asarray([[1.04504475 - 3.19650874j, 2.18835276 + 1.87922077j, 3 + 3j, 3.38706877 - 0.53958902j]]),
)
@ -334,7 +334,7 @@ def test_drop_samples_invalid_real_raises():
def test_quantize_tape_invalid_rounding_type_raises():
# An unrecognised rounding_type must raise UserWarning.
with pytest.raises(UserWarning):
with pytest.warns(UserWarning):
iq_augmentations.quantize_tape(TEST_DATA1, rounding_type="round")
@ -347,7 +347,7 @@ def test_quantize_tape_invalid_real_raises():
def test_quantize_parts_invalid_rounding_type_raises():
with pytest.raises(UserWarning):
with pytest.warns(UserWarning):
iq_augmentations.quantize_parts(TEST_DATA1, rounding_type="round")
@ -399,7 +399,7 @@ def test_cut_out_rec_input():
def test_cut_out_invalid_fill_type_raises():
with pytest.raises(UserWarning):
with pytest.warns(UserWarning):
iq_augmentations.cut_out(TEST_DATA1, max_section_size=2, fill_type="bad")

View File

@ -130,23 +130,14 @@ def test_time_shift_invalid_real_input():
def test_time_shift_large_shift_warns():
"""shift > n raises a UserWarning."""
with pytest.raises(UserWarning):
with pytest.warns(UserWarning):
iq_impairments.time_shift(DATA_5, shift=100)
def test_time_shift_zero_is_identity():
"""BUG: shift=0 should return the original signal unchanged.
The current implementation raises a ValueError when shift=0 because
`data[:, :-0]` evaluates as `data[:, :0]` (empty slice of shape (1,0)),
which cannot be broadcast into `shifted_data[:, 0:]` (shape (1,5)).
This test documents the bug: callers cannot safely pass shift=0.
Remove the `pytest.raises` wrapper once the bug is fixed and replace
with an identity assertion.
"""
with pytest.raises((ValueError, AssertionError), match=".*"):
iq_impairments.time_shift(DATA_5, shift=0)
"""shift=0 returns the original signal unchanged."""
result = iq_impairments.time_shift(DATA_5, shift=0)
assert np.array_equal(result, DATA_5)
# ---------------------------------------------------------------------------
@ -408,17 +399,8 @@ def test_resample_invalid_real_input():
iq_impairments.resample(real_data)
def test_resample_downsample_returns_shorter_array():
"""BUG documentation: up=1, down=2 returns a shorter array instead of zero-padding.
The 'else' branch of resample() builds 'empty_array' but never returns it.
The shorter resampled_iqdata is returned directly. This test documents the
actual (potentially unintended) behaviour so any future fix is detectable.
"""
def test_resample_downsample_returns_same_length():
"""Downsampling zero-pads output to match input length."""
signal = np.array([[1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j, 5 + 5j, 6 + 6j]], dtype=np.complex128)
result = iq_impairments.resample(signal, up=1, down=2)
# Downsampling by 2 produces ~3 samples; the empty_array logic is dead code.
assert result.shape[1] < signal.shape[1], (
"resample with up<down should return fewer samples than the input "
"(empty_array is built but discarded — dead code)."
)
assert result.shape[1] == signal.shape[1]