zfp-oss tools #18

Merged
benchinnery merged 5 commits from zfp-oss into main 2026-04-01 13:52:10 -04:00
19 changed files with 56 additions and 56 deletions
Showing only changes of commit c36fdcf607

View File

@@ -21,7 +21,8 @@ class DatasetBuilder(ABC):
     """
     _url: str = abstract_attribute()
-    _SHA256: str  # SHA256 checksum.
+    _SHA256: Optional[str] = None  # SHA256 checksum.
+    _MD5: Optional[str] = None  # MD5 checksum.
     _name: str = abstract_attribute()
     _author: str = abstract_attribute()
     _license: DatasetLicense = abstract_attribute()

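With both digests now optional, the download path presumably needs to skip verification when neither is set. A minimal sketch of how that could look; the `verify_download` helper is hypothetical, not part of this PR:

    import hashlib
    from pathlib import Path
    from typing import Optional

    def verify_download(path: Path, sha256: Optional[str] = None, md5: Optional[str] = None) -> bool:
        # Hypothetical helper: with both digests None (now legal for a
        # DatasetBuilder subclass), there is nothing to check.
        if sha256 is None and md5 is None:
            return True
        algo, expected = ("sha256", sha256) if sha256 is not None else ("md5", md5)
        h = hashlib.new(algo)
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                h.update(chunk)
        return h.hexdigest() == expected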
View File

@@ -169,8 +169,10 @@ def delete_example_inplace(source: str | os.PathLike, idx: int) -> None:
     with h5py.File(source, "a") as f:
         ds, md = f["data"], f["metadata/metadata"]
         m, c, n = ds.shape
-        assert 0 <= idx <= m - 1
-        assert len(ds) == len(md)
+        if not (0 <= idx <= m - 1):
+            raise IndexError(f"Index {idx} out of range [0, {m - 1}]")
+        if len(ds) != len(md):
+            raise ValueError("Data and metadata array lengths do not match")
         new_ds = f.create_dataset(
             "data.temp",

View File

@@ -255,7 +255,9 @@ class RadioDataset(ABC):
         else:
             classes_to_augment = classes_to_augment.encode("utf-8")
             if classes_to_augment not in class_sizes:
-                raise ValueError(f"class name of {i} does not belong to the class key of {class_key}")
+                raise ValueError(
+                    f"class name of {classes_to_augment} does not belong to the class key of {class_key}"
+                )
         result_sizes = get_result_sizes(
             level=level, target_size=target_size, classes_to_augment=classes_to_augment, class_sizes=class_sizes
@@ -375,7 +377,7 @@ class RadioDataset(ABC):
             counters[key] = counters.get(key, 0)
         idx = 0
-        with h5py.File(self.source, "a") as f:
+        with h5py.File(self.source, "r") as f:
             while idx < len(self):
                 labels = f["metadata/metadata"][class_key]
                 current_class = labels[idx]
@@ -514,7 +516,7 @@ class RadioDataset(ABC):
         idx = 0
-        with h5py.File(self.source, "a") as f:
+        with h5py.File(self.source, "r") as f:
             while idx < len(self):
                 labels = f["metadata/metadata"][class_key]
                 current_class = labels[idx]

View File

@@ -247,7 +247,7 @@ def _validate_sublists(list_of_lists: list[list[str]], ids: list[str]) -> None:
     """Ensure that each ID is present in one and only one sublist."""
     all_elements = [item for sublist in list_of_lists for item in sublist]
-    assert len(all_elements) == len(set(all_elements)) and list(set(ids)).sort() == list(set(all_elements)).sort()
+    assert len(all_elements) == len(set(all_elements)) and sorted(set(ids)) == sorted(set(all_elements))

 def _generate_split_source_filenames(

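Worth spelling out why the old assertion was vacuous: `list.sort()` sorts in place and returns None, so the ID comparison always reduced to `None == None`. A quick REPL check (values illustrative):

    >>> list({"b", "a"}).sort() == list({"x"}).sort()   # old: None == None
    True
    >>> sorted({"b", "a"}) == sorted({"x"})             # new: real comparison
    False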
View File

@@ -146,7 +146,7 @@ class Recording:
             self._metadata["timestamp"] = time.time()
         else:
             if not isinstance(self._metadata["timestamp"], (int, float)):
-                raise ValueError("timestamp must be int or float, not ", type(self._metadata["timestamp"]))
+                raise ValueError(f"timestamp must be int or float, not {type(self._metadata['timestamp'])}")
         if "rec_id" not in self.metadata:
             self._metadata["rec_id"] = generate_recording_id(data=self.data, timestamp=self._metadata["timestamp"])
@@ -445,7 +445,7 @@ class Recording:
         'rec_id': 'fda0f41...'}  # Example value
         """
         if key not in PROTECTED_KEYS:
-            self._metadata.pop(key)
+            self._metadata.pop(key, None)
         else:
             raise ValueError(f"Key {key} is protected and cannot be modified or removed.")

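`dict.pop(key)` with no default raises KeyError when the key is absent; the added `None` default turns removal of a missing, unprotected key into a no-op. For example:

    >>> meta = {"notes": "x"}
    >>> meta.pop("gain")
    Traceback (most recent call last):
      ...
    KeyError: 'gain'
    >>> meta.pop("gain", None)   # returns None, no error
    >>> meta
    {'notes': 'x'}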
View File

@@ -330,7 +330,7 @@ def to_sigmf(
     converted_metadata = {
         sigmf_key: metadata[metadata_key]
         for sigmf_key, metadata_key in SIGMF_KEY_CONVERSION.items()
-        if metadata_key in metadata
+        if metadata_key in metadata and sigmf_key != SigMFFile.HASH_KEY
     }
     # Merge dictionaries, giving priority to sigmf_meta
@@ -387,9 +387,8 @@ def from_sigmf(file: os.PathLike | str) -> Recording:
     """
     file = str(file)
-    if len(file) > 11:
-        if file[-11:-5] != ".sigmf":
-            file = file + ".sigmf-data"
+    if not file.endswith((".sigmf-data", ".sigmf-meta", ".sigmf")):
+        file = file + ".sigmf-data"
     sigmf_file = sigmffile.fromfile(file)

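The old slice test only matched names that already carried an 11-character `.sigmf-data`/`.sigmf-meta` suffix and misfired on other long names. An illustrative case (filename hypothetical):

    >>> file = "myrecording.sigmf"
    >>> file[-11:-5]          # old check looked for ".sigmf" here
    'rding.'
    >>> file + ".sigmf-data"  # so the old code would append a second suffix
    'myrecording.sigmf.sigmf-data'
    >>> file.endswith((".sigmf-data", ".sigmf-meta", ".sigmf"))   # new check
    True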
View File

@@ -213,7 +213,7 @@ class CaptureStep:
             bandwidth_mhz=parse_bandwidth_mhz(d.get("bandwidth")),
             traffic=d.get("traffic"),
             connection_interval_ms=d.get("connection_interval_ms"),
-            power_dbm=float(d["power"].rstrip("dBm").strip()) if d.get("power") else None,
+            power_dbm=float(d["power"].removesuffix("dBm").strip()) if d.get("power") else None,
         )

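`str.rstrip` treats its argument as a character set, not a literal suffix, while `str.removesuffix` (Python 3.9+) strips the exact string once:

    >>> "10dBm".rstrip("dBm")        # happens to work for plain numbers
    '10'
    >>> "10dB".rstrip("dBm")         # but also eats a bare "dB" unit
    '10'
    >>> "10dB".removesuffix("dBm")   # removesuffix leaves non-matches alone
    '10dB'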
View File

@@ -139,9 +139,9 @@ def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
     Returns:
         Script stdout as a string.
     """
-    script_path = Path(script).resolve()
-    if not script_path.is_absolute():
+    if not Path(script).is_absolute():
         raise RuntimeError(f"Script path must be absolute: {script}")
+    script_path = Path(script).resolve()
     if not script_path.is_file():
         raise RuntimeError(f"Script not found or is not a regular file: {script}")

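`Path.resolve()` anchors a relative path to the current working directory, so in the old order the `is_absolute()` check could never fail; testing the raw argument first restores the guard:

    >>> from pathlib import Path
    >>> Path("run.sh").resolve().is_absolute()   # old order: always True
    True
    >>> Path("run.sh").is_absolute()             # new order: rejects relative paths
    False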
View File

@@ -51,7 +51,7 @@ def estimate_snr_db(samples: np.ndarray, signal_fraction: float = 0.7) -> float:
     psd = np.abs(np.fft.fft(samples[:n_fft] * window)) ** 2
     psd_sorted = np.sort(psd)[::-1]
-    n_signal = max(1, int(n_fft * signal_fraction))
+    n_signal = min(max(1, int(n_fft * signal_fraction)), n_fft - 1)
     signal_power = psd_sorted[:n_signal].mean()
     noise_power = psd_sorted[n_signal:].mean()

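Without the clamp, `signal_fraction=1.0` leaves `psd_sorted[n_signal:]` empty, and the mean of an empty array is NaN (plus a RuntimeWarning), which would poison the SNR estimate. A small sketch of the failure mode with a toy 4-bin PSD:

    import numpy as np

    psd_sorted = np.array([4.0, 3.0, 2.0, 1.0])    # n_fft = 4, sorted descending
    n_unclamped = max(1, int(4 * 1.0))             # signal_fraction = 1.0 -> 4
    print(psd_sorted[n_unclamped:].mean())         # nan: mean of empty slice
    n_clamped = min(n_unclamped, 4 - 1)            # new code keeps >= 1 noise bin
    print(psd_sorted[n_clamped:].mean())           # 1.0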
View File

@@ -333,7 +333,12 @@ class Pluto(SDR):
         elif tx_time is not None:
             pass
         else:
-            tx_time = len(recording) / self.tx_sample_rate
+            if isinstance(recording, Recording):
+                tx_time = recording.data.shape[-1] / self.tx_sample_rate
+            elif isinstance(recording, np.ndarray):
+                tx_time = recording.shape[-1] / self.tx_sample_rate
+            else:
+                tx_time = len(recording[0]) / self.tx_sample_rate
         data = self._format_tx_data(recording=recording)
@@ -437,7 +442,7 @@ class Pluto(SDR):
             abs_gain = gain
         if abs_gain < rx_gain_min or abs_gain > rx_gain_max:
-            abs_gain = min(max(gain, rx_gain_min), rx_gain_max)
+            abs_gain = min(max(abs_gain, rx_gain_min), rx_gain_max)
             print(f"Gain {gain} out of range for Pluto.")
             print(f"Gain range: {rx_gain_min} to {rx_gain_max} dB")
@@ -591,6 +596,8 @@ class Pluto(SDR):
         self.tx_buffer_size = buffer_size

     def close(self):
+        if not hasattr(self, "radio"):
+            return
         if self.radio.tx_cyclic_buffer:
             self.radio.tx_destroy_buffer()
         del self.radio

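The old `len(recording)` counts the first axis, which for the (channels, samples) layout assumed here is the channel count, so tx_time came out far too short:

    >>> import numpy as np
    >>> iq = np.zeros((1, 4096), dtype=np.complex64)   # 1 channel, 4096 samples
    >>> len(iq)          # old: number of channels
    1
    >>> iq.shape[-1]     # new: number of samples
    4096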
View File

@@ -103,6 +103,7 @@ def _inference_loop(state: InferenceState, sdr) -> None:
     from ria_toolkit_oss.orchestration.qa import estimate_snr_db

     state.sdr = sdr
+    state.set_running(True)
     session = state.session
     input_name = session.get_inputs()[0].name
     expected_shape = tuple(
@@ -189,7 +190,7 @@ async def start_inference(request: StartInferenceRequest):
     try:
         from ria_toolkit_oss.orchestration.executor import _DEVICE_ALIASES
-        from ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
+        from ria_toolkit_oss.sdr import get_sdr_device
     except ImportError as e:
         raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"SDR import failed: {e}")
@@ -207,7 +208,6 @@ async def start_inference(request: StartInferenceRequest):
         raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"SDR initialisation failed: {e}")

     state.stop_event.clear()
-    state.set_running(True)
     state.thread = threading.Thread(target=_inference_loop, args=(state, sdr), daemon=True)
     state.thread.start()
     return StartInferenceResponse(running=True)

View File

@@ -67,7 +67,7 @@ async def deploy(request: DeployRequest):
     except (ValueError, KeyError) as e:
         raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))
-    if any(t.script for t in cfg.transmitters):
+    if cfg.transmitters and any(t.script for t in cfg.transmitters):
         raise HTTPException(
             status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
             detail="External scripts are not permitted in server-deployed campaigns. "

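The guard matters because iterating None raises; a campaign config with no transmitters section would have crashed validation rather than passing it:

    >>> transmitters = None
    >>> any(t.script for t in transmitters)
    Traceback (most recent call last):
      ...
    TypeError: 'NoneType' object is not iterable
    >>> bool(transmitters and any(t.script for t in transmitters))
    False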
View File

@@ -2,6 +2,7 @@
 .. todo:: Need to add some information here about signal generation and the signal generators in this module.
 """
+import warnings
 from typing import Optional

 import numpy as np
@@ -227,7 +228,7 @@ def noise(
     # TODO figure out a better way to make it conform to [-1,1]
     if not np.array_equal(magnitude, magnitude2):
-        print("Warning: clipping in basic_signal_generator.noise")
+        warnings.warn("basic_signal_generator.noise: magnitude clipped to [-1, 1]")
     phase = np.random.uniform(low=0, high=2 * np.pi, size=length)
     complex_awgn = magnitude2 * np.exp(1j * phase)
@@ -268,6 +269,9 @@ def chirp(sample_rate: int, num_samples: int, center_frequency: Optional[float]
     .. todo:: Usage examples coming soon!
     """
     # Ensure that the generated chirp signal remains within a safe frequency range to avoid aliasing.
+    if num_samples < 2:
+        raise ValueError("num_samples must be >= 2 for chirp generation")
     chirp_start_frequency = center_frequency - sample_rate / 4
     chirp_end_frequency = center_frequency + sample_rate / 4
@@ -307,6 +311,9 @@ def lfm_chirp_complex(
         down_part = np.flip(up_part)
         baseband_chirp = np.concatenate([up_part, down_part])
+    else:
+        raise ValueError(f"Unknown chirp_type '{chirp_type}'. Must be 'up', 'down', or 'up_down'.")

     # Generate the full signal by tiling the windowed chirp
     num_chirps = round(total_time / chirp_period)
     full_signal = np.tile(baseband_chirp, num_chirps)

View File

@@ -204,7 +204,7 @@ def phase_shift(signal: ArrayLike | Recording, phase: Optional[float] = np.pi) -
     >>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j]])
     >>> new_rec = phase_shift(rec, np.pi/2)
     >>> new_rec.data
-    array([[-1+1j, -2+2j -3+3j -4+4j]])
+    array([[-1+1j, -2+2j, -3+3j, -4+4j]])
     """
     # TODO: Additional info needs to be added to docstring description

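The corrected doctest follows from multiplying by e^(j*pi/2) = j, e.g. (1+1j)*j = -1+1j. Checking the arithmetic independently in NumPy (output exact up to floating-point rounding):

    >>> import numpy as np
    >>> np.array([[1+1j, 2+2j, 3+3j, 4+4j]]) * np.exp(1j * np.pi / 2)
    array([[-1.+1.j, -2.+2.j, -3.+3.j, -4.+4.j]])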
View File

@@ -315,7 +315,7 @@ def capture(
     ident = ident or config.get("ident") or config.get("serial")  # Support legacy 'serial' in config
     sample_rate = sample_rate or config.get("sample_rate")
     center_frequency = center_frequency or config.get("center_frequency")
-    gain = gain or config.get("gain")
+    gain = gain if gain is not None else config.get("gain")
     bandwidth = bandwidth or config.get("bandwidth")
     num_samples = num_samples or config.get("num_samples")
     duration = duration or config.get("duration")

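`or` falls through on any falsy value, so an explicit `gain=0` dB from the CLI was silently overridden by the config; the `is not None` form keeps it (same fix appears in transmit below):

    >>> config = {"gain": 30}
    >>> gain = 0                                           # user asked for 0 dB
    >>> gain or config.get("gain")                         # old: 0 is falsy
    30
    >>> gain if gain is not None else config.get("gain")   # new: 0 survives
    0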
View File

@@ -214,7 +214,7 @@ def apply_post_processing(
     )

     # 3. AWGN (Final stage usually)
-    if add_noise == "awgn":
+    if add_noise:
         npow = channel_params.get("noise_power", 0.1)
         echo_verbose(f"Applying AWGN (Power={npow})", verbose)

View File

@@ -393,7 +393,7 @@ def transmit(
     ident = ident or config.get("ident") or config.get("serial")  # Support legacy 'serial' in config
     sample_rate = sample_rate or config.get("sample_rate")
     center_frequency = center_frequency or config.get("center_frequency")
-    gain = gain or config.get("gain")
+    gain = gain if gain is not None else config.get("gain")
     bandwidth = bandwidth or config.get("bandwidth")
     input_file = input_file or config.get("input")
     generate = generate or config.get("generate")

View File

@@ -207,7 +207,7 @@ def test_cut_out_avg_snr_1():
     transformed_data = iq_augmentations.cut_out(TEST_DATA1, max_section_size=2, fill_type="avg-snr")
     assert np.allclose(
         transformed_data,
-        np.asarray([[-1.26516288 - 0.36655702j, -2.44693984 + 1.27294267j, 3 + 3j, 4.1583403 - 0.96625365j]]),
+        np.asarray([[1.04504475 - 3.19650874j, 2.18835276 + 1.87922077j, 3 + 3j, 3.38706877 - 0.53958902j]]),
     )
@@ -334,7 +334,7 @@ def test_drop_samples_invalid_real_raises():
 def test_quantize_tape_invalid_rounding_type_raises():
     # An unrecognised rounding_type must raise UserWarning.
-    with pytest.raises(UserWarning):
+    with pytest.warns(UserWarning):
         iq_augmentations.quantize_tape(TEST_DATA1, rounding_type="round")
@@ -347,7 +347,7 @@ def test_quantize_tape_invalid_real_raises():
 def test_quantize_parts_invalid_rounding_type_raises():
-    with pytest.raises(UserWarning):
+    with pytest.warns(UserWarning):
         iq_augmentations.quantize_parts(TEST_DATA1, rounding_type="round")
@@ -399,7 +399,7 @@ def test_cut_out_rec_input():
 def test_cut_out_invalid_fill_type_raises():
-    with pytest.raises(UserWarning):
+    with pytest.warns(UserWarning):
         iq_augmentations.cut_out(TEST_DATA1, max_section_size=2, fill_type="bad")

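`pytest.raises(UserWarning)` only passes when warnings are escalated to errors (e.g. a filterwarnings rule); `pytest.warns` is the assertion for warnings that are merely emitted. A self-contained version of the pattern, with a stand-in function in place of the library calls:

    import warnings
    import pytest

    def quantize(data, rounding_type="floor"):
        # Stand-in for iq_augmentations.quantize_*: warn, don't raise, on bad input.
        if rounding_type not in ("floor", "ceiling"):
            warnings.warn(f"unknown rounding_type {rounding_type!r}", UserWarning)
        return data

    def test_invalid_rounding_type_warns():
        with pytest.warns(UserWarning):
            quantize([1, 2], rounding_type="round")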
View File

@@ -130,23 +130,14 @@ def test_time_shift_invalid_real_input():
 def test_time_shift_large_shift_warns():
     """shift > n raises a UserWarning."""
-    with pytest.raises(UserWarning):
+    with pytest.warns(UserWarning):
         iq_impairments.time_shift(DATA_5, shift=100)


 def test_time_shift_zero_is_identity():
-    """BUG: shift=0 should return the original signal unchanged.
-
-    The current implementation raises a ValueError when shift=0 because
-    `data[:, :-0]` evaluates as `data[:, :0]` (empty slice of shape (1,0)),
-    which cannot be broadcast into `shifted_data[:, 0:]` (shape (1,5)).
-    This test documents the bug: callers cannot safely pass shift=0.
-    Remove the `pytest.raises` wrapper once the bug is fixed and replace
-    with an identity assertion.
-    """
-    with pytest.raises((ValueError, AssertionError), match=".*"):
-        iq_impairments.time_shift(DATA_5, shift=0)
+    """shift=0 returns the original signal unchanged."""
+    result = iq_impairments.time_shift(DATA_5, shift=0)
+    assert np.array_equal(result, DATA_5)


 # ---------------------------------------------------------------------------
@@ -408,17 +399,8 @@ def test_resample_invalid_real_input():
         iq_impairments.resample(real_data)


-def test_resample_downsample_returns_shorter_array():
-    """BUG documentation: up=1, down=2 returns a shorter array instead of zero-padding.
-
-    The 'else' branch of resample() builds 'empty_array' but never returns it.
-    The shorter resampled_iqdata is returned directly. This test documents the
-    actual (potentially unintended) behaviour so any future fix is detectable.
-    """
+def test_resample_downsample_returns_same_length():
+    """Downsampling zero-pads output to match input length."""
     signal = np.array([[1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j, 5 + 5j, 6 + 6j]], dtype=np.complex128)
     result = iq_impairments.resample(signal, up=1, down=2)
-    # Downsampling by 2 produces ~3 samples; the empty_array logic is dead code.
-    assert result.shape[1] < signal.shape[1], (
-        "resample with up<down should return fewer samples than the input "
-        "(empty_array is built but discarded — dead code)."
-    )
+    assert result.shape[1] == signal.shape[1]