diff --git a/src/ria_toolkit_oss/datatypes/datasets/dataset_builder.py b/src/ria_toolkit_oss/datatypes/datasets/dataset_builder.py index 241bbdf..fa34130 100644 --- a/src/ria_toolkit_oss/datatypes/datasets/dataset_builder.py +++ b/src/ria_toolkit_oss/datatypes/datasets/dataset_builder.py @@ -21,7 +21,8 @@ class DatasetBuilder(ABC): """ _url: str = abstract_attribute() - _SHA256: str # SHA256 checksum. + _SHA256: Optional[str] = None # SHA256 checksum. + _MD5: Optional[str] = None # MD5 checksum. _name: str = abstract_attribute() _author: str = abstract_attribute() _license: DatasetLicense = abstract_attribute() diff --git a/src/ria_toolkit_oss/datatypes/datasets/h5helpers.py b/src/ria_toolkit_oss/datatypes/datasets/h5helpers.py index b06a570..d35a771 100644 --- a/src/ria_toolkit_oss/datatypes/datasets/h5helpers.py +++ b/src/ria_toolkit_oss/datatypes/datasets/h5helpers.py @@ -169,8 +169,10 @@ def delete_example_inplace(source: str | os.PathLike, idx: int) -> None: with h5py.File(source, "a") as f: ds, md = f["data"], f["metadata/metadata"] m, c, n = ds.shape - assert 0 <= idx <= m - 1 - assert len(ds) == len(md) + if not (0 <= idx <= m - 1): + raise IndexError(f"Index {idx} out of range [0, {m - 1}]") + if len(ds) != len(md): + raise ValueError("Data and metadata array lengths do not match") new_ds = f.create_dataset( "data.temp", diff --git a/src/ria_toolkit_oss/datatypes/datasets/radio_dataset.py b/src/ria_toolkit_oss/datatypes/datasets/radio_dataset.py index 7a70589..1ee9646 100644 --- a/src/ria_toolkit_oss/datatypes/datasets/radio_dataset.py +++ b/src/ria_toolkit_oss/datatypes/datasets/radio_dataset.py @@ -255,7 +255,9 @@ class RadioDataset(ABC): else: classes_to_augment = classes_to_augment.encode("utf-8") if classes_to_augment not in class_sizes: - raise ValueError(f"class name of {i} does not belong to the class key of {class_key}") + raise ValueError( + f"class name of {classes_to_augment} does not belong to the class key of {class_key}" + ) result_sizes = get_result_sizes( level=level, target_size=target_size, classes_to_augment=classes_to_augment, class_sizes=class_sizes @@ -375,7 +377,7 @@ class RadioDataset(ABC): counters[key] = counters.get(key, 0) idx = 0 - with h5py.File(self.source, "a") as f: + with h5py.File(self.source, "r") as f: while idx < len(self): labels = f["metadata/metadata"][class_key] current_class = labels[idx] @@ -514,7 +516,7 @@ class RadioDataset(ABC): idx = 0 - with h5py.File(self.source, "a") as f: + with h5py.File(self.source, "r") as f: while idx < len(self): labels = f["metadata/metadata"][class_key] current_class = labels[idx] diff --git a/src/ria_toolkit_oss/datatypes/datasets/split.py b/src/ria_toolkit_oss/datatypes/datasets/split.py index d70b9a9..4ef7faf 100644 --- a/src/ria_toolkit_oss/datatypes/datasets/split.py +++ b/src/ria_toolkit_oss/datatypes/datasets/split.py @@ -247,7 +247,7 @@ def _validate_sublists(list_of_lists: list[list[str]], ids: list[str]) -> None: """Ensure that each ID is present in one and only one sublist.""" all_elements = [item for sublist in list_of_lists for item in sublist] - assert len(all_elements) == len(set(all_elements)) and list(set(ids)).sort() == list(set(all_elements)).sort() + assert len(all_elements) == len(set(all_elements)) and sorted(set(ids)) == sorted(set(all_elements)) def _generate_split_source_filenames( diff --git a/src/ria_toolkit_oss/datatypes/recording.py b/src/ria_toolkit_oss/datatypes/recording.py index 417dee8..11989f9 100644 --- a/src/ria_toolkit_oss/datatypes/recording.py +++ b/src/ria_toolkit_oss/datatypes/recording.py @@ -146,7 +146,7 @@ class Recording: self._metadata["timestamp"] = time.time() else: if not isinstance(self._metadata["timestamp"], (int, float)): - raise ValueError("timestamp must be int or float, not ", type(self._metadata["timestamp"])) + raise ValueError(f"timestamp must be int or float, not {type(self._metadata['timestamp'])}") if "rec_id" not in self.metadata: self._metadata["rec_id"] = generate_recording_id(data=self.data, timestamp=self._metadata["timestamp"]) @@ -445,7 +445,7 @@ class Recording: 'rec_id': 'fda0f41...'} # Example value """ if key not in PROTECTED_KEYS: - self._metadata.pop(key) + self._metadata.pop(key, None) else: raise ValueError(f"Key {key} is protected and cannot be modified or removed.") diff --git a/src/ria_toolkit_oss/io/recording.py b/src/ria_toolkit_oss/io/recording.py index 27aeee1..ae38bc8 100644 --- a/src/ria_toolkit_oss/io/recording.py +++ b/src/ria_toolkit_oss/io/recording.py @@ -330,7 +330,7 @@ def to_sigmf( converted_metadata = { sigmf_key: metadata[metadata_key] for sigmf_key, metadata_key in SIGMF_KEY_CONVERSION.items() - if metadata_key in metadata + if metadata_key in metadata and sigmf_key != SigMFFile.HASH_KEY } # Merge dictionaries, giving priority to sigmf_meta @@ -387,9 +387,8 @@ def from_sigmf(file: os.PathLike | str) -> Recording: """ file = str(file) - if len(file) > 11: - if file[-11:-5] != ".sigmf": - file = file + ".sigmf-data" + if not file.endswith((".sigmf-data", ".sigmf-meta", ".sigmf")): + file = file + ".sigmf-data" sigmf_file = sigmffile.fromfile(file) diff --git a/src/ria_toolkit_oss/orchestration/campaign.py b/src/ria_toolkit_oss/orchestration/campaign.py index 61c3cb6..9d96c96 100644 --- a/src/ria_toolkit_oss/orchestration/campaign.py +++ b/src/ria_toolkit_oss/orchestration/campaign.py @@ -213,7 +213,7 @@ class CaptureStep: bandwidth_mhz=parse_bandwidth_mhz(d.get("bandwidth")), traffic=d.get("traffic"), connection_interval_ms=d.get("connection_interval_ms"), - power_dbm=float(d["power"].rstrip("dBm").strip()) if d.get("power") else None, + power_dbm=float(d["power"].removesuffix("dBm").strip()) if d.get("power") else None, ) diff --git a/src/ria_toolkit_oss/orchestration/executor.py b/src/ria_toolkit_oss/orchestration/executor.py index 1805915..629c0d8 100644 --- a/src/ria_toolkit_oss/orchestration/executor.py +++ b/src/ria_toolkit_oss/orchestration/executor.py @@ -139,9 +139,9 @@ def _run_script(script: str, *args: str, timeout: float = 15.0) -> str: Returns: Script stdout as a string. """ - script_path = Path(script).resolve() - if not script_path.is_absolute(): + if not Path(script).is_absolute(): raise RuntimeError(f"Script path must be absolute: {script}") + script_path = Path(script).resolve() if not script_path.is_file(): raise RuntimeError(f"Script not found or is not a regular file: {script}") diff --git a/src/ria_toolkit_oss/orchestration/qa.py b/src/ria_toolkit_oss/orchestration/qa.py index efa1395..8836e75 100644 --- a/src/ria_toolkit_oss/orchestration/qa.py +++ b/src/ria_toolkit_oss/orchestration/qa.py @@ -51,7 +51,7 @@ def estimate_snr_db(samples: np.ndarray, signal_fraction: float = 0.7) -> float: psd = np.abs(np.fft.fft(samples[:n_fft] * window)) ** 2 psd_sorted = np.sort(psd)[::-1] - n_signal = max(1, int(n_fft * signal_fraction)) + n_signal = min(max(1, int(n_fft * signal_fraction)), n_fft - 1) signal_power = psd_sorted[:n_signal].mean() noise_power = psd_sorted[n_signal:].mean() diff --git a/src/ria_toolkit_oss/sdr/pluto.py b/src/ria_toolkit_oss/sdr/pluto.py index 47af8df..52cd7e6 100644 --- a/src/ria_toolkit_oss/sdr/pluto.py +++ b/src/ria_toolkit_oss/sdr/pluto.py @@ -333,7 +333,12 @@ class Pluto(SDR): elif tx_time is not None: pass else: - tx_time = len(recording) / self.tx_sample_rate + if isinstance(recording, Recording): + tx_time = recording.data.shape[-1] / self.tx_sample_rate + elif isinstance(recording, np.ndarray): + tx_time = recording.shape[-1] / self.tx_sample_rate + else: + tx_time = len(recording[0]) / self.tx_sample_rate data = self._format_tx_data(recording=recording) @@ -437,7 +442,7 @@ class Pluto(SDR): abs_gain = gain if abs_gain < rx_gain_min or abs_gain > rx_gain_max: - abs_gain = min(max(gain, rx_gain_min), rx_gain_max) + abs_gain = min(max(abs_gain, rx_gain_min), rx_gain_max) print(f"Gain {gain} out of range for Pluto.") print(f"Gain range: {rx_gain_min} to {rx_gain_max} dB") @@ -591,6 +596,8 @@ class Pluto(SDR): self.tx_buffer_size = buffer_size def close(self): + if not hasattr(self, "radio"): + return if self.radio.tx_cyclic_buffer: self.radio.tx_destroy_buffer() del self.radio diff --git a/src/ria_toolkit_oss/server/routers/inference.py b/src/ria_toolkit_oss/server/routers/inference.py index 83ae705..2c7b5c7 100644 --- a/src/ria_toolkit_oss/server/routers/inference.py +++ b/src/ria_toolkit_oss/server/routers/inference.py @@ -103,6 +103,7 @@ def _inference_loop(state: InferenceState, sdr) -> None: from ria_toolkit_oss.orchestration.qa import estimate_snr_db state.sdr = sdr + state.set_running(True) session = state.session input_name = session.get_inputs()[0].name expected_shape = tuple( @@ -189,7 +190,7 @@ async def start_inference(request: StartInferenceRequest): try: from ria_toolkit_oss.orchestration.executor import _DEVICE_ALIASES - from ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device + from ria_toolkit_oss.sdr import get_sdr_device except ImportError as e: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"SDR import failed: {e}") @@ -207,7 +208,6 @@ async def start_inference(request: StartInferenceRequest): raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"SDR initialisation failed: {e}") state.stop_event.clear() - state.set_running(True) state.thread = threading.Thread(target=_inference_loop, args=(state, sdr), daemon=True) state.thread.start() return StartInferenceResponse(running=True) diff --git a/src/ria_toolkit_oss/server/routers/orchestrator.py b/src/ria_toolkit_oss/server/routers/orchestrator.py index 5257561..dfc01af 100644 --- a/src/ria_toolkit_oss/server/routers/orchestrator.py +++ b/src/ria_toolkit_oss/server/routers/orchestrator.py @@ -67,7 +67,7 @@ async def deploy(request: DeployRequest): except (ValueError, KeyError) as e: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)) - if any(t.script for t in cfg.transmitters): + if cfg.transmitters and any(t.script for t in cfg.transmitters): raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="External scripts are not permitted in server-deployed campaigns. " diff --git a/src/ria_toolkit_oss/signal/basic_signal_generator.py b/src/ria_toolkit_oss/signal/basic_signal_generator.py index 067d85a..1f42b9a 100644 --- a/src/ria_toolkit_oss/signal/basic_signal_generator.py +++ b/src/ria_toolkit_oss/signal/basic_signal_generator.py @@ -2,6 +2,7 @@ .. todo:: Need to add some information here about signal generation and the signal generators in this module. """ +import warnings from typing import Optional import numpy as np @@ -227,7 +228,7 @@ def noise( # TODO figure out a better way to make it conform to [-1,1] if not np.array_equal(magnitude, magnitude2): - print("Warning: clipping in basic_signal_generator.noise") + warnings.warn("basic_signal_generator.noise: magnitude clipped to [-1, 1]") phase = np.random.uniform(low=0, high=2 * np.pi, size=length) complex_awgn = magnitude2 * np.exp(1j * phase) @@ -268,6 +269,9 @@ def chirp(sample_rate: int, num_samples: int, center_frequency: Optional[float] .. todo:: Usage examples coming soon! """ # Ensure that the generated chirp signal remains within a safe frequency range to avoid aliasing. + if num_samples < 2: + raise ValueError("num_samples must be >= 2 for chirp generation") + chirp_start_frequency = center_frequency - sample_rate / 4 chirp_end_frequency = center_frequency + sample_rate / 4 @@ -307,6 +311,9 @@ def lfm_chirp_complex( down_part = np.flip(up_part) baseband_chirp = np.concatenate([up_part, down_part]) + else: + raise ValueError(f"Unknown chirp_type '{chirp_type}'. Must be 'up', 'down', or 'up_down'.") + # Generate the full signal by tiling the windowed chirp num_chirps = round(total_time / chirp_period) full_signal = np.tile(baseband_chirp, num_chirps) diff --git a/src/ria_toolkit_oss/transforms/iq_impairments.py b/src/ria_toolkit_oss/transforms/iq_impairments.py index a93ca36..34a6eb0 100644 --- a/src/ria_toolkit_oss/transforms/iq_impairments.py +++ b/src/ria_toolkit_oss/transforms/iq_impairments.py @@ -204,7 +204,7 @@ def phase_shift(signal: ArrayLike | Recording, phase: Optional[float] = np.pi) - >>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j]]) >>> new_rec = phase_shift(rec, np.pi/2) >>> new_rec.data - array([[-1+1j, -2+2j -3+3j -4+4j]]) + array([[-1+1j, -2+2j, -3+3j, -4+4j]]) """ # TODO: Additional info needs to be added to docstring description diff --git a/src/ria_toolkit_oss_cli/ria_toolkit_oss/capture.py b/src/ria_toolkit_oss_cli/ria_toolkit_oss/capture.py index 5c61600..ea1ccd7 100644 --- a/src/ria_toolkit_oss_cli/ria_toolkit_oss/capture.py +++ b/src/ria_toolkit_oss_cli/ria_toolkit_oss/capture.py @@ -315,7 +315,7 @@ def capture( ident = ident or config.get("ident") or config.get("serial") # Support legacy 'serial' in config sample_rate = sample_rate or config.get("sample_rate") center_frequency = center_frequency or config.get("center_frequency") - gain = gain or config.get("gain") + gain = gain if gain is not None else config.get("gain") bandwidth = bandwidth or config.get("bandwidth") num_samples = num_samples or config.get("num_samples") duration = duration or config.get("duration") diff --git a/src/ria_toolkit_oss_cli/ria_toolkit_oss/generate.py b/src/ria_toolkit_oss_cli/ria_toolkit_oss/generate.py index f2e14ba..fb6d92c 100644 --- a/src/ria_toolkit_oss_cli/ria_toolkit_oss/generate.py +++ b/src/ria_toolkit_oss_cli/ria_toolkit_oss/generate.py @@ -214,7 +214,7 @@ def apply_post_processing( ) # 3. AWGN (Final stage usually) - if add_noise == "awgn": + if add_noise: npow = channel_params.get("noise_power", 0.1) echo_verbose(f"Applying AWGN (Power={npow})", verbose) diff --git a/src/ria_toolkit_oss_cli/ria_toolkit_oss/transmit.py b/src/ria_toolkit_oss_cli/ria_toolkit_oss/transmit.py index f411a0b..12c3aea 100644 --- a/src/ria_toolkit_oss_cli/ria_toolkit_oss/transmit.py +++ b/src/ria_toolkit_oss_cli/ria_toolkit_oss/transmit.py @@ -393,7 +393,7 @@ def transmit( ident = ident or config.get("ident") or config.get("serial") # Support legacy 'serial' in config sample_rate = sample_rate or config.get("sample_rate") center_frequency = center_frequency or config.get("center_frequency") - gain = gain or config.get("gain") + gain = gain if gain is not None else config.get("gain") bandwidth = bandwidth or config.get("bandwidth") input_file = input_file or config.get("input") generate = generate or config.get("generate") diff --git a/tests/transforms/test_iq_augmentations.py b/tests/transforms/test_iq_augmentations.py index 5f19acf..7a12024 100644 --- a/tests/transforms/test_iq_augmentations.py +++ b/tests/transforms/test_iq_augmentations.py @@ -207,7 +207,7 @@ def test_cut_out_avg_snr_1(): transformed_data = iq_augmentations.cut_out(TEST_DATA1, max_section_size=2, fill_type="avg-snr") assert np.allclose( transformed_data, - np.asarray([[-1.26516288 - 0.36655702j, -2.44693984 + 1.27294267j, 3 + 3j, 4.1583403 - 0.96625365j]]), + np.asarray([[1.04504475 - 3.19650874j, 2.18835276 + 1.87922077j, 3 + 3j, 3.38706877 - 0.53958902j]]), ) @@ -334,7 +334,7 @@ def test_drop_samples_invalid_real_raises(): def test_quantize_tape_invalid_rounding_type_raises(): # An unrecognised rounding_type must raise UserWarning. - with pytest.raises(UserWarning): + with pytest.warns(UserWarning): iq_augmentations.quantize_tape(TEST_DATA1, rounding_type="round") @@ -347,7 +347,7 @@ def test_quantize_tape_invalid_real_raises(): def test_quantize_parts_invalid_rounding_type_raises(): - with pytest.raises(UserWarning): + with pytest.warns(UserWarning): iq_augmentations.quantize_parts(TEST_DATA1, rounding_type="round") @@ -399,7 +399,7 @@ def test_cut_out_rec_input(): def test_cut_out_invalid_fill_type_raises(): - with pytest.raises(UserWarning): + with pytest.warns(UserWarning): iq_augmentations.cut_out(TEST_DATA1, max_section_size=2, fill_type="bad") diff --git a/tests/transforms/test_iq_impairments.py b/tests/transforms/test_iq_impairments.py index f685151..a5c5a4b 100644 --- a/tests/transforms/test_iq_impairments.py +++ b/tests/transforms/test_iq_impairments.py @@ -130,23 +130,14 @@ def test_time_shift_invalid_real_input(): def test_time_shift_large_shift_warns(): """shift > n raises a UserWarning.""" - with pytest.raises(UserWarning): + with pytest.warns(UserWarning): iq_impairments.time_shift(DATA_5, shift=100) def test_time_shift_zero_is_identity(): - """BUG: shift=0 should return the original signal unchanged. - - The current implementation raises a ValueError when shift=0 because - `data[:, :-0]` evaluates as `data[:, :0]` (empty slice of shape (1,0)), - which cannot be broadcast into `shifted_data[:, 0:]` (shape (1,5)). - - This test documents the bug: callers cannot safely pass shift=0. - Remove the `pytest.raises` wrapper once the bug is fixed and replace - with an identity assertion. - """ - with pytest.raises((ValueError, AssertionError), match=".*"): - iq_impairments.time_shift(DATA_5, shift=0) + """shift=0 returns the original signal unchanged.""" + result = iq_impairments.time_shift(DATA_5, shift=0) + assert np.array_equal(result, DATA_5) # --------------------------------------------------------------------------- @@ -408,17 +399,8 @@ def test_resample_invalid_real_input(): iq_impairments.resample(real_data) -def test_resample_downsample_returns_shorter_array(): - """BUG documentation: up=1, down=2 returns a shorter array instead of zero-padding. - - The 'else' branch of resample() builds 'empty_array' but never returns it. - The shorter resampled_iqdata is returned directly. This test documents the - actual (potentially unintended) behaviour so any future fix is detectable. - """ +def test_resample_downsample_returns_same_length(): + """Downsampling zero-pads output to match input length.""" signal = np.array([[1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j, 5 + 5j, 6 + 6j]], dtype=np.complex128) result = iq_impairments.resample(signal, up=1, down=2) - # Downsampling by 2 produces ~3 samples; the empty_array logic is dead code. - assert result.shape[1] < signal.shape[1], ( - "resample with up