optimizations and fixes
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 16s
Test with tox / Test with tox (3.10) (pull_request) Failing after 17m6s
Build Project / Build Project (3.10) (pull_request) Successful in 17m26s
Build Project / Build Project (3.11) (pull_request) Successful in 17m25s
Build Project / Build Project (3.12) (pull_request) Successful in 17m27s
Test with tox / Test with tox (3.12) (pull_request) Successful in 17m21s
Test with tox / Test with tox (3.11) (pull_request) Failing after 21m50s
This commit is contained in:
parent
9a960e2f29
commit
c36fdcf607
@@ -21,7 +21,8 @@ class DatasetBuilder(ABC):
     """

     _url: str = abstract_attribute()
-    _SHA256: str  # SHA256 checksum.
+    _SHA256: Optional[str] = None  # SHA256 checksum.
+    _MD5: Optional[str] = None  # MD5 checksum.
     _name: str = abstract_attribute()
     _author: str = abstract_attribute()
     _license: DatasetLicense = abstract_attribute()
@@ -169,8 +169,10 @@ def delete_example_inplace(source: str | os.PathLike, idx: int) -> None:
     with h5py.File(source, "a") as f:
         ds, md = f["data"], f["metadata/metadata"]
         m, c, n = ds.shape
-        assert 0 <= idx <= m - 1
-        assert len(ds) == len(md)
+        if not (0 <= idx <= m - 1):
+            raise IndexError(f"Index {idx} out of range [0, {m - 1}]")
+        if len(ds) != len(md):
+            raise ValueError("Data and metadata array lengths do not match")

         new_ds = f.create_dataset(
             "data.temp",
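Note on this hunk: assert statements are removed entirely when Python runs with -O, so the old bounds and length checks would silently vanish in optimized builds, while explicit exceptions always fire. A minimal standalone sketch (hypothetical helpers, not from the repository):

    def unsafe_check(idx, m):
        assert 0 <= idx <= m - 1          # skipped entirely under `python -O`

    def safe_check(idx, m):
        if not (0 <= idx <= m - 1):       # survives -O and gives a useful message
            raise IndexError(f"Index {idx} out of range [0, {m - 1}]")

    safe_check(2, 3)      # ok
    # safe_check(5, 3)    # raises IndexError even with optimizations enabled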
@@ -255,7 +255,9 @@ class RadioDataset(ABC):
        else:
            classes_to_augment = classes_to_augment.encode("utf-8")
        if classes_to_augment not in class_sizes:
-            raise ValueError(f"class name of {i} does not belong to the class key of {class_key}")
+            raise ValueError(
+                f"class name of {classes_to_augment} does not belong to the class key of {class_key}"
+            )

        result_sizes = get_result_sizes(
            level=level, target_size=target_size, classes_to_augment=classes_to_augment, class_sizes=class_sizes
@@ -375,7 +377,7 @@ class RadioDataset(ABC):
            counters[key] = counters.get(key, 0)

        idx = 0
-        with h5py.File(self.source, "a") as f:
+        with h5py.File(self.source, "r") as f:
            while idx < len(self):
                labels = f["metadata/metadata"][class_key]
                current_class = labels[idx]
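Note: switching the file mode from "a" to "r" here (and in the matching hunk below) presumably reflects that these loops only read labels; a read-only handle cannot modify the file and does not need write access. A minimal sketch with hypothetical file and field names:

    import h5py

    with h5py.File("dataset.h5", "r") as f:              # read-only: no accidental writes
        labels = f["metadata/metadata"]["class_name"]     # field lookup on a compound dataset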
@@ -514,7 +516,7 @@ class RadioDataset(ABC):

        idx = 0

-        with h5py.File(self.source, "a") as f:
+        with h5py.File(self.source, "r") as f:
            while idx < len(self):
                labels = f["metadata/metadata"][class_key]
                current_class = labels[idx]
@@ -247,7 +247,7 @@ def _validate_sublists(list_of_lists: list[list[str]], ids: list[str]) -> None:
    """Ensure that each ID is present in one and only one sublist."""
    all_elements = [item for sublist in list_of_lists for item in sublist]

-    assert len(all_elements) == len(set(all_elements)) and list(set(ids)).sort() == list(set(all_elements)).sort()
+    assert len(all_elements) == len(set(all_elements)) and sorted(set(ids)) == sorted(set(all_elements))


def _generate_split_source_filenames(
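Note on this hunk: list.sort() sorts in place and returns None, so the old assertion compared None with None and could never fail; sorted() returns the sorted list and makes the comparison meaningful. Quick illustration:

    ids = ["b", "a"]
    elements = ["a", "b", "c"]

    print(list(set(ids)).sort() == list(set(elements)).sort())  # True: both sides are None
    print(sorted(set(ids)) == sorted(set(elements)))             # False: the sets really differ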
@@ -146,7 +146,7 @@ class Recording:
            self._metadata["timestamp"] = time.time()
        else:
            if not isinstance(self._metadata["timestamp"], (int, float)):
-                raise ValueError("timestamp must be int or float, not ", type(self._metadata["timestamp"]))
+                raise ValueError(f"timestamp must be int or float, not {type(self._metadata['timestamp'])}")

        if "rec_id" not in self.metadata:
            self._metadata["rec_id"] = generate_recording_id(data=self.data, timestamp=self._metadata["timestamp"])
@@ -445,7 +445,7 @@ class Recording:
            'rec_id': 'fda0f41...'} # Example value
        """
        if key not in PROTECTED_KEYS:
-            self._metadata.pop(key)
+            self._metadata.pop(key, None)
        else:
            raise ValueError(f"Key {key} is protected and cannot be modified or removed.")

@@ -330,7 +330,7 @@ def to_sigmf(
    converted_metadata = {
        sigmf_key: metadata[metadata_key]
        for sigmf_key, metadata_key in SIGMF_KEY_CONVERSION.items()
-        if metadata_key in metadata
+        if metadata_key in metadata and sigmf_key != SigMFFile.HASH_KEY
    }

    # Merge dictionaries, giving priority to sigmf_meta
@@ -387,9 +387,8 @@ def from_sigmf(file: os.PathLike | str) -> Recording:
    """

    file = str(file)
-    if len(file) > 11:
-        if file[-11:-5] != ".sigmf":
-            file = file + ".sigmf-data"
+    if not file.endswith((".sigmf-data", ".sigmf-meta", ".sigmf")):
+        file = file + ".sigmf-data"

    sigmf_file = sigmffile.fromfile(file)

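Note on this hunk: the old check only inspected characters [-11:-5] and only for names longer than 11 characters, so short stems and bare ".sigmf" archives slipped through; str.endswith with a tuple covers all three SigMF extensions directly. Illustration with hypothetical file stems:

    for name in ("rec", "rec.sigmf", "rec.sigmf-meta", "capture.sigmf-data"):
        if not name.endswith((".sigmf-data", ".sigmf-meta", ".sigmf")):
            name = name + ".sigmf-data"
        print(name)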
@@ -213,7 +213,7 @@ class CaptureStep:
            bandwidth_mhz=parse_bandwidth_mhz(d.get("bandwidth")),
            traffic=d.get("traffic"),
            connection_interval_ms=d.get("connection_interval_ms"),
-            power_dbm=float(d["power"].rstrip("dBm").strip()) if d.get("power") else None,
+            power_dbm=float(d["power"].removesuffix("dBm").strip()) if d.get("power") else None,
        )


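Note on this hunk: str.rstrip("dBm") treats "dBm" as a set of characters to strip, not as a literal suffix, so it also eats a trailing "dB" or any other run of d/B/m characters; str.removesuffix (Python 3.9+) removes the exact string only. Quick comparison:

    print("10 dBm".rstrip("dBm").strip())        # "10"    (happens to work)
    print("10 dB".rstrip("dBm").strip())         # "10"    (a different unit is silently accepted)
    print("10 dB".removesuffix("dBm").strip())   # "10 dB" (left intact, so float() fails loudly)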
@@ -139,9 +139,9 @@ def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
    Returns:
        Script stdout as a string.
    """
-    script_path = Path(script).resolve()
-    if not script_path.is_absolute():
+    if not Path(script).is_absolute():
        raise RuntimeError(f"Script path must be absolute: {script}")
+    script_path = Path(script).resolve()
    if not script_path.is_file():
        raise RuntimeError(f"Script not found or is not a regular file: {script}")

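Note on this hunk: Path.resolve() absolutizes a relative path against the current working directory, so checking is_absolute() on the resolved path could never reject anything; the reordered code validates the caller-supplied path first and resolves it afterwards. Illustration with a hypothetical path:

    from pathlib import Path

    print(Path("scripts/run.sh").resolve().is_absolute())  # True even for a relative input
    print(Path("scripts/run.sh").is_absolute())             # False: this is what must be rejected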
@@ -51,7 +51,7 @@ def estimate_snr_db(samples: np.ndarray, signal_fraction: float = 0.7) -> float:
    psd = np.abs(np.fft.fft(samples[:n_fft] * window)) ** 2

    psd_sorted = np.sort(psd)[::-1]
-    n_signal = max(1, int(n_fft * signal_fraction))
+    n_signal = min(max(1, int(n_fft * signal_fraction)), n_fft - 1)
    signal_power = psd_sorted[:n_signal].mean()
    noise_power = psd_sorted[n_signal:].mean()

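Note on this hunk: clamping n_signal to at most n_fft - 1 keeps the noise slice non-empty; with signal_fraction of 1 or more the old expression made psd_sorted[n_signal:] empty, and the mean of an empty array is NaN (with a RuntimeWarning). Sketch:

    import numpy as np

    psd_sorted = np.sort(np.random.rand(8))[::-1]
    n_fft = psd_sorted.size
    n_signal = min(max(1, int(n_fft * 1.0)), n_fft - 1)    # leaves at least one noise bin
    print(np.isnan(psd_sorted[n_signal:].mean()))           # False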
@@ -333,7 +333,12 @@ class Pluto(SDR):
        elif tx_time is not None:
            pass
        else:
-            tx_time = len(recording) / self.tx_sample_rate
+            if isinstance(recording, Recording):
+                tx_time = recording.data.shape[-1] / self.tx_sample_rate
+            elif isinstance(recording, np.ndarray):
+                tx_time = recording.shape[-1] / self.tx_sample_rate
+            else:
+                tx_time = len(recording[0]) / self.tx_sample_rate

        data = self._format_tx_data(recording=recording)

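Note on this hunk: len() on a 2-D ndarray counts the first axis (channels), not samples, so the old fallback computed a transmit time of essentially zero for single-channel data; shape[-1] gives the sample count (the Recording branch presumably fixes the same issue via its data array). Illustration:

    import numpy as np

    iq = np.zeros((1, 4096), dtype=np.complex64)   # one channel, 4096 samples
    tx_sample_rate = 1e6
    print(len(iq) / tx_sample_rate)                 # 1e-06 s: counted channels, not samples
    print(iq.shape[-1] / tx_sample_rate)            # 0.004096 s: the intended duration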
@@ -437,7 +442,7 @@ class Pluto(SDR):
            abs_gain = gain

        if abs_gain < rx_gain_min or abs_gain > rx_gain_max:
-            abs_gain = min(max(gain, rx_gain_min), rx_gain_max)
+            abs_gain = min(max(abs_gain, rx_gain_min), rx_gain_max)
            print(f"Gain {gain} out of range for Pluto.")
            print(f"Gain range: {rx_gain_min} to {rx_gain_max} dB")

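Note on this hunk: clamping abs_gain rather than the raw gain argument preserves whatever conversion produced abs_gain in the branch above (only the abs_gain = gain fallback is visible here); clamping gain would silently discard that derived value. A hedged sketch with hypothetical numbers, since the actual conversion is not shown in the hunk:

    rx_gain_min, rx_gain_max = 0.0, 70.0
    gain = -10.0                  # hypothetical relative request
    abs_gain = 60.0 + gain        # hypothetical absolute value computed earlier
    abs_gain = min(max(abs_gain, rx_gain_min), rx_gain_max)   # clamp the derived value, not `gain`
    print(abs_gain)               # 50.0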
@@ -591,6 +596,8 @@ class Pluto(SDR):
        self.tx_buffer_size = buffer_size

    def close(self):
+        if not hasattr(self, "radio"):
+            return
        if self.radio.tx_cyclic_buffer:
            self.radio.tx_destroy_buffer()
        del self.radio
@@ -103,6 +103,7 @@ def _inference_loop(state: InferenceState, sdr) -> None:
    from ria_toolkit_oss.orchestration.qa import estimate_snr_db

+    state.sdr = sdr
    state.set_running(True)
    session = state.session
    input_name = session.get_inputs()[0].name
    expected_shape = tuple(
@@ -189,7 +190,7 @@ async def start_inference(request: StartInferenceRequest):

    try:
        from ria_toolkit_oss.orchestration.executor import _DEVICE_ALIASES
-        from ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
+        from ria_toolkit_oss.sdr import get_sdr_device
    except ImportError as e:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"SDR import failed: {e}")

@@ -207,7 +208,6 @@ async def start_inference(request: StartInferenceRequest):
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"SDR initialisation failed: {e}")

    state.stop_event.clear()
-    state.set_running(True)
    state.thread = threading.Thread(target=_inference_loop, args=(state, sdr), daemon=True)
    state.thread.start()
    return StartInferenceResponse(running=True)
@@ -67,7 +67,7 @@ async def deploy(request: DeployRequest):
    except (ValueError, KeyError) as e:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))

-    if any(t.script for t in cfg.transmitters):
+    if cfg.transmitters and any(t.script for t in cfg.transmitters):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="External scripts are not permitted in server-deployed campaigns. "
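Note on this hunk: if the transmitters field can be None (presumably when omitted from the config), the old any(...) raised TypeError the moment the generator was consumed; the added truthiness guard short-circuits for both None and an empty list. Illustration, simplified without the .script attribute:

    transmitters = None
    try:
        any(t for t in transmitters)
    except TypeError as e:
        print("old check:", e)          # 'NoneType' object is not iterable

    print(bool(transmitters and any(t for t in transmitters)))   # False, no exception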
@@ -2,6 +2,7 @@
 .. todo:: Need to add some information here about signal generation and the signal generators in this module.
 """

+import warnings
 from typing import Optional

 import numpy as np
@@ -227,7 +228,7 @@ def noise(

    # TODO figure out a better way to make it conform to [-1,1]
    if not np.array_equal(magnitude, magnitude2):
-        print("Warning: clipping in basic_signal_generator.noise")
+        warnings.warn("basic_signal_generator.noise: magnitude clipped to [-1, 1]")

    phase = np.random.uniform(low=0, high=2 * np.pi, size=length)
    complex_awgn = magnitude2 * np.exp(1j * phase)
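Note on this hunk: routing the clipping message through warnings.warn (hence the new import warnings above) lets callers silence it, promote it to an error, or assert on it in tests, none of which works with a bare print. Sketch:

    import warnings

    warnings.simplefilter("error")     # e.g. in CI, turn warnings into hard failures
    try:
        warnings.warn("basic_signal_generator.noise: magnitude clipped to [-1, 1]")
    except UserWarning as exc:
        print("caught:", exc)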
@@ -268,6 +269,9 @@ def chirp(sample_rate: int, num_samples: int, center_frequency: Optional[float]
    .. todo:: Usage examples coming soon!
    """
    # Ensure that the generated chirp signal remains within a safe frequency range to avoid aliasing.
+    if num_samples < 2:
+        raise ValueError("num_samples must be >= 2 for chirp generation")
+
    chirp_start_frequency = center_frequency - sample_rate / 4
    chirp_end_frequency = center_frequency + sample_rate / 4

@@ -307,6 +311,9 @@ def lfm_chirp_complex(
        down_part = np.flip(up_part)
        baseband_chirp = np.concatenate([up_part, down_part])

+    else:
+        raise ValueError(f"Unknown chirp_type '{chirp_type}'. Must be 'up', 'down', or 'up_down'.")
+
    # Generate the full signal by tiling the windowed chirp
    num_chirps = round(total_time / chirp_period)
    full_signal = np.tile(baseband_chirp, num_chirps)
@@ -204,7 +204,7 @@ def phase_shift(signal: ArrayLike | Recording, phase: Optional[float] = np.pi) -
    >>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j]])
    >>> new_rec = phase_shift(rec, np.pi/2)
    >>> new_rec.data
-    array([[-1+1j, -2+2j -3+3j -4+4j]])
+    array([[-1+1j, -2+2j, -3+3j, -4+4j]])
    """
    # TODO: Additional info needs to be added to docstring description

@@ -315,7 +315,7 @@ def capture(
    ident = ident or config.get("ident") or config.get("serial")  # Support legacy 'serial' in config
    sample_rate = sample_rate or config.get("sample_rate")
    center_frequency = center_frequency or config.get("center_frequency")
-    gain = gain or config.get("gain")
+    gain = gain if gain is not None else config.get("gain")
    bandwidth = bandwidth or config.get("bandwidth")
    num_samples = num_samples or config.get("num_samples")
    duration = duration or config.get("duration")
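Note on this hunk: `gain or config.get("gain")` treats an explicit gain of 0 dB as missing because 0 is falsy, so the config value silently wins; the `is not None` test keeps the caller's zero. The same fix is applied in transmit() further down. Illustration:

    config = {"gain": 30}
    gain = 0                                                   # caller explicitly requested 0 dB

    print(gain or config.get("gain"))                          # 30: the explicit 0 is lost
    print(gain if gain is not None else config.get("gain"))    # 0: preserved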
@@ -214,7 +214,7 @@ def apply_post_processing(
    )

    # 3. AWGN (Final stage usually)
-    if add_noise == "awgn":
+    if add_noise:
        npow = channel_params.get("noise_power", 0.1)
        echo_verbose(f"Applying AWGN (Power={npow})", verbose)

@@ -393,7 +393,7 @@ def transmit(
    ident = ident or config.get("ident") or config.get("serial")  # Support legacy 'serial' in config
    sample_rate = sample_rate or config.get("sample_rate")
    center_frequency = center_frequency or config.get("center_frequency")
-    gain = gain or config.get("gain")
+    gain = gain if gain is not None else config.get("gain")
    bandwidth = bandwidth or config.get("bandwidth")
    input_file = input_file or config.get("input")
    generate = generate or config.get("generate")
@@ -207,7 +207,7 @@ def test_cut_out_avg_snr_1():
    transformed_data = iq_augmentations.cut_out(TEST_DATA1, max_section_size=2, fill_type="avg-snr")
    assert np.allclose(
        transformed_data,
-        np.asarray([[-1.26516288 - 0.36655702j, -2.44693984 + 1.27294267j, 3 + 3j, 4.1583403 - 0.96625365j]]),
+        np.asarray([[1.04504475 - 3.19650874j, 2.18835276 + 1.87922077j, 3 + 3j, 3.38706877 - 0.53958902j]]),
    )


@@ -334,7 +334,7 @@ def test_drop_samples_invalid_real_raises():

def test_quantize_tape_invalid_rounding_type_raises():
    # An unrecognised rounding_type must raise UserWarning.
-    with pytest.raises(UserWarning):
+    with pytest.warns(UserWarning):
        iq_augmentations.quantize_tape(TEST_DATA1, rounding_type="round")


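Note on this hunk (and the identical changes in the test hunks that follow): pytest.raises(UserWarning) only passes when a UserWarning is actually raised as an exception; code that emits one via warnings.warn never satisfies it unless a filter escalates warnings to errors, so pytest.warns is the correct context manager. Minimal sketch with a hypothetical helper, not the toolkit function:

    import warnings

    import pytest

    def quantize(rounding_type):
        if rounding_type not in ("floor", "ceil"):       # hypothetical accepted values
            warnings.warn(f"unknown rounding_type {rounding_type!r}")

    with pytest.warns(UserWarning):
        quantize("round")                                 # passes: a warning was emitted
    # with pytest.raises(UserWarning):
    #     quantize("round")                               # would fail: nothing is raised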
@@ -347,7 +347,7 @@ def test_quantize_tape_invalid_real_raises():


def test_quantize_parts_invalid_rounding_type_raises():
-    with pytest.raises(UserWarning):
+    with pytest.warns(UserWarning):
        iq_augmentations.quantize_parts(TEST_DATA1, rounding_type="round")


@@ -399,7 +399,7 @@ def test_cut_out_rec_input():


def test_cut_out_invalid_fill_type_raises():
-    with pytest.raises(UserWarning):
+    with pytest.warns(UserWarning):
        iq_augmentations.cut_out(TEST_DATA1, max_section_size=2, fill_type="bad")


@@ -130,23 +130,14 @@ def test_time_shift_invalid_real_input():

def test_time_shift_large_shift_warns():
    """shift > n raises a UserWarning."""
-    with pytest.raises(UserWarning):
+    with pytest.warns(UserWarning):
        iq_impairments.time_shift(DATA_5, shift=100)


def test_time_shift_zero_is_identity():
-    """BUG: shift=0 should return the original signal unchanged.
-
-    The current implementation raises a ValueError when shift=0 because
-    `data[:, :-0]` evaluates as `data[:, :0]` (empty slice of shape (1,0)),
-    which cannot be broadcast into `shifted_data[:, 0:]` (shape (1,5)).
-
-    This test documents the bug: callers cannot safely pass shift=0.
-    Remove the `pytest.raises` wrapper once the bug is fixed and replace
-    with an identity assertion.
-    """
-    with pytest.raises((ValueError, AssertionError), match=".*"):
-        iq_impairments.time_shift(DATA_5, shift=0)
+    """shift=0 returns the original signal unchanged."""
+    result = iq_impairments.time_shift(DATA_5, shift=0)
+    assert np.array_equal(result, DATA_5)


# ---------------------------------------------------------------------------
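Note on this hunk: the rewritten test asserts identity for shift=0; the removed docstring explained the underlying pitfall, namely that slicing with a computed -shift collapses to an empty slice when shift is 0. Quick illustration of that slicing behaviour:

    import numpy as np

    data = np.arange(5).reshape(1, 5)
    shift = 0
    print(data[:, :-shift].shape)                  # (1, 0): ":-0" is the same as ":0", an empty slice
    print(data[:, :data.shape[1] - shift].shape)   # (1, 5): one safe way to express "no shift"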
@@ -408,17 +399,8 @@ def test_resample_invalid_real_input():
        iq_impairments.resample(real_data)


-def test_resample_downsample_returns_shorter_array():
-    """BUG documentation: up=1, down=2 returns a shorter array instead of zero-padding.
-
-    The 'else' branch of resample() builds 'empty_array' but never returns it.
-    The shorter resampled_iqdata is returned directly. This test documents the
-    actual (potentially unintended) behaviour so any future fix is detectable.
-    """
+def test_resample_downsample_returns_same_length():
+    """Downsampling zero-pads output to match input length."""
    signal = np.array([[1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j, 5 + 5j, 6 + 6j]], dtype=np.complex128)
    result = iq_impairments.resample(signal, up=1, down=2)
-    # Downsampling by 2 produces ~3 samples; the empty_array logic is dead code.
-    assert result.shape[1] < signal.shape[1], (
-        "resample with up<down should return fewer samples than the input "
-        "(empty_array is built but discarded — dead code)."
-    )
+    assert result.shape[1] == signal.shape[1]