diff --git a/CHANGELOG.md b/CHANGELOG.md index 6fec3c2..1bc13c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,21 @@ # Changelog +## [0.1.0] - 2026-02-20 + +### Added +- **Dual-Threshold Detection:** Logic to capture the start and end of signals, not just the peak. +- **Signal Smoothing & Noise Filters:** Prevents detections from breaking into fragments and ignores short interference spikes. +- **Auto-Frequency Calculation:** Automatically adjusts bounding boxes to fit signal frequency ranges tightly. + +### Changed +- **Signal Power Detection:** Switched from raw signal strength to power for improved accuracy. +- **CLI Workflow:** `Clear` and `Remove` commands now modify files directly (in-place) to avoid redundant copies. +- **Metadata Logic:** Updated labels to show detection percentages and overhauled internal metadata cleaning. +- **Viewer UI:** Moved legend outside the plot, added a black background, and adjusted transparency for better spectrogram visibility. + +### Fixed +- Prevented redundant `_annotated` suffixes in file naming patterns. +- Simplified internal math to increase processing speed and precision. All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) and [Semantic Versioning](https://semver.org/spec/v2.0.0.html). diff --git a/docs/source/ria_toolkit_oss/datatypes/radio_datasets.rst b/docs/source/ria_toolkit_oss/datatypes/radio_datasets.rst index 149fbaf..95d47e2 100644 --- a/docs/source/ria_toolkit_oss/datatypes/radio_datasets.rst +++ b/docs/source/ria_toolkit_oss/datatypes/radio_datasets.rst @@ -11,15 +11,15 @@ The Radio Dataset Framework provides a software interface to access and manipula the need for users to interface with the source files directly. Instead, users initialize and interact with a Python object, while the complexities of efficient data retrieval and source file manipulation are managed behind the scenes. -Utils includes an abstract class called :py:obj:`ria_toolkit_oss.datatypes.datasets.RadioDataset`, which defines common properties and +Ria Toolkit OSS includes an abstract class called :py:obj:`ria_toolkit_oss.datatypes.datasets.RadioDataset`, which defines common properties and behaviors for all radio datasets. :py:obj:`ria_toolkit_oss.datatypes.datasets.RadioDataset` can be considered a blueprint for all other radio dataset classes. This class is then subclassed to define more specific blueprints for different types of radio datasets. For example, :py:obj:`ria_toolkit_oss.datatypes.datasets.IQDataset`, which is tailored for machine learning tasks involving the processing of signals represented as IQ (In-phase and Quadrature) samples. -Then, in the various project backends, there are concrete dataset classes, which inherit from both Utils and the base +Then, in the various project backends, there are concrete dataset classes, which inherit from both Ria Toolkit OSS and the base dataset class from the respective backend. For example, the :py:obj:`TorchIQDataset` class extends both -:py:obj:`ria_toolkit_oss.datatypes.datasets.IQDataset` from Utils and :py:obj:`torch.ria_toolkit_oss.datatypes.IterableDataset` from +:py:obj:`ria_toolkit_oss.datatypes.datasets.IQDataset` from Ria Toolkit OSS and :py:obj:`torch.ria_toolkit_oss.datatypes.IterableDataset` from PyTorch, providing a concrete dataset class tailored for IQ datasets and optimized for the PyTorch backend. 
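+The pattern is ordinary multiple inheritance. A minimal sketch follows (illustrative only — the
+real ``TorchIQDataset`` lives in the backend package and adds far more than shown here, and
+PyTorch's iterable dataset is assumed to come from ``torch.utils.data``)::
+
+    import torch
+    from ria_toolkit_oss.datatypes.datasets import IQDataset
+
+    class TorchIQDataset(IQDataset, torch.utils.data.IterableDataset):
+        """Concrete IQ dataset specialized for the PyTorch backend."""
+
+        def __iter__(self):
+            # Stream examples using the backend-independent accessors
+            # defined on RadioDataset / IQDataset.
+            ...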
Dataset initialization @@ -130,7 +130,7 @@ Dataset processing and manipulation ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ All radio datasets support methods tailored specifically for radio processing. These methods are backend-independent, -inherited from the blueprints in Utils like :py:obj:`ria_toolkit_oss.datatypes.datasets.RadioDataset`. +inherited from the blueprints in Ria Toolkit OSS like :py:obj:`ria_toolkit_oss.datatypes.datasets.RadioDataset`. For example, we can trim down the length of the examples from 1,024 to 512 samples, and then augment the dataset: diff --git a/poetry.lock b/poetry.lock index 8614127..92569ed 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.3.3 and should not be changed by hand. [[package]] name = "alabaster" @@ -1271,7 +1271,7 @@ files = [ [package.dependencies] attrs = ">=22.2.0" -jsonschema-specifications = ">=2023.03.6" +jsonschema-specifications = ">=2023.3.6" referencing = ">=0.28.4" rpds-py = ">=0.25.0" diff --git a/src/ria_toolkit_oss/annotations/__init__.py b/src/ria_toolkit_oss/annotations/__init__.py new file mode 100644 index 0000000..1542dcf --- /dev/null +++ b/src/ria_toolkit_oss/annotations/__init__.py @@ -0,0 +1,54 @@ +""" +The annotations package contains tools and utilities for creating, managing, and processing annotations. + +Provides automatic annotation generation using various signal detection algorithms: +- Energy-based detection (detect_signals_energy) +- CUSUM-based segmentation (annotate_with_cusum) +- Threshold-based qualification (threshold_qualifier) +- Signal isolation and extraction (isolate_signal) +- Occupied bandwidth analysis (calculate_occupied_bandwidth, calculate_nominal_bandwidth) + +All detection functions return Recording objects with added annotations. 
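+
+A typical first pass (illustrative sketch; the file path is a placeholder and the threshold is a
+starting point rather than a tuned value)::
+
+    >>> from ria.io import load_recording
+    >>> from ria_toolkit_oss.annotations import detect_signals_energy, threshold_qualifier
+    >>> recording = load_recording("capture.sigmf")
+    >>> annotated = detect_signals_energy(recording, label="burst")
+    >>> annotated = threshold_qualifier(annotated, threshold=0.5, label="burst")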
+""" + +__all__ = [ + # Energy-based detection + "detect_signals_energy", + "calculate_occupied_bandwidth", + "calculate_nominal_bandwidth", + "calculate_full_detected_bandwidth", + "annotate_with_obw", + # CUSUM detection + "annotate_with_cusum", + # Threshold detection + "threshold_qualifier", + # Parallel signal separation (Phase 2) + "find_spectral_components", + "split_annotation_by_components", + "split_recording_annotations", + # Signal isolation + "isolate_signal", + # Annotation transforms + "remove_contained_boxes", + "is_annotation_contained", + # Dataset creation + "qualify_slice_from_annotations", +] + +from .annotation_transforms import is_annotation_contained, remove_contained_boxes +from .cusum_annotator import annotate_with_cusum +from .energy_detector import ( + annotate_with_obw, + calculate_full_detected_bandwidth, + calculate_nominal_bandwidth, + calculate_occupied_bandwidth, + detect_signals_energy, +) +from .parallel_signal_separator import ( + find_spectral_components, + split_annotation_by_components, + split_recording_annotations, +) +from .qualify_slice import qualify_slice_from_annotations +from .signal_isolation import isolate_signal +from .threshold_qualifier import threshold_qualifier diff --git a/src/ria_toolkit_oss/annotations/annotation_transforms.py b/src/ria_toolkit_oss/annotations/annotation_transforms.py new file mode 100644 index 0000000..47300c1 --- /dev/null +++ b/src/ria_toolkit_oss/annotations/annotation_transforms.py @@ -0,0 +1,55 @@ +from ria_toolkit_oss.datatypes.annotation import Annotation + +# TODO figure out how to transfer labels in the merge case + + +def remove_contained_boxes(annotations: list[Annotation]): + """ + Remove all annotations (bounding boxes) that are entirely contained within other boxes in the list. + + :param annotations: A list of Annotation objects. + :type annotations: list[Annotation] + + :returns: A new list of Annotation objects. + :rtype: list[Annotation]""" + + output_boxes = [] + + for i in range(len(annotations)): + contained = False + for j in range(len(annotations)): + if i != j and is_annotation_contained(annotations[i], annotations[j]): + contained = True + break + + if not contained: + output_boxes.append(annotations[i]) + + return output_boxes + + +def is_annotation_contained(inner: Annotation, outer: Annotation) -> bool: + """ + Check if an annotation box is entirely contained within another annotation bounding box. + + :param inner: The inner box. + :type inner: Annotation. + :param outer: The outer box. + :type outer: Annotation. + + :returns: True if inner is within outer, false otherwise. 
+ :rtype: bool + """ + + inner_sample_stop = inner.sample_start + inner.sample_count + outer_sample_stop = outer.sample_start + outer.sample_count + + if inner.sample_start > outer.sample_start and inner_sample_stop < outer_sample_stop: + if inner.freq_lower_edge > outer.freq_lower_edge and inner.freq_upper_edge < outer.freq_upper_edge: + return True + + return False + + +def merge_annotations(annotations: list[Annotation], overlap_threshold) -> list[Annotation]: + raise NotImplementedError diff --git a/src/ria_toolkit_oss/annotations/cusum_annotator.py b/src/ria_toolkit_oss/annotations/cusum_annotator.py new file mode 100644 index 0000000..d37186c --- /dev/null +++ b/src/ria_toolkit_oss/annotations/cusum_annotator.py @@ -0,0 +1,203 @@ +import json +from typing import Optional + +import numpy as np + +from ria_toolkit_oss.datatypes import Annotation, Recording + + +def annotate_with_cusum( + recording: Recording, + label: Optional[str] = "segment", + window_size: Optional[int] = 1, + min_duration: Optional[float] = None, + tolerance: Optional[int] = None, + annotation_type: Optional[str] = "standalone", +): + """ + Add annotations that divide the recording into distinct time segments. + + This algorithm computes the cumulative sum of the sample magnitudes and + determines break points in the signal. + + This tool can be used to find points where a signal turns on or off, or + changes between a low and high amplitude. + + :param recording: A ``Recording`` object to annotate. + :type recording: ``ria_toolkit_oss.datatypes.Recording`` + :param label: Label for the detected segments. + :type label: str + :param window_size: The length (in samples) of the moving average window. + :type window_size: int + :param min_duration: The minimum duration (in ms) of a segment. + The algorithm will not produce annotations shorter than this length. + :type min_duration: float + :param tolerance: The minimum length (in samples) of a segment. + :type tolerance: int + :param annotation_type: Annotation type (standalone, parallel, intersection). 
+ :type annotation_type: str + """ + + sample_rate = recording.metadata["sample_rate"] + center_frequency = recording.metadata.get("center_frequency", 0) + + # Create an object of the time segmenter + time_segmenter = TimeSegmenter(sample_rate, min_duration, window_size, tolerance) + + change_points = time_segmenter.apply(recording.data[0]) + + time_segments_indices = np.append(np.insert(change_points, 0, 0), len(recording.data[0])) + annotations = [] + for i in range(len(time_segments_indices) - 1): + # Build comment JSON with type metadata + comment_data = { + "type": annotation_type, + "generator": "cusum_annotator", + "params": { + "window_size": window_size, + "min_duration": min_duration, + "tolerance": tolerance, + }, + } + f_min, f_max = detect_frequency( + signal=recording.data[0], + start=time_segments_indices[i], + stop=time_segments_indices[i + 1], + sample_rate=sample_rate, + ) + + annotations.append( + Annotation( + sample_start=time_segments_indices[i], + sample_count=time_segments_indices[i + 1] - time_segments_indices[i], + freq_lower_edge=center_frequency + f_min, + freq_upper_edge=center_frequency + f_max, + label=label, + comment=json.dumps(comment_data), + detail={"generator": "cusum_annotator"}, + ) + ) + + return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + annotations) + + +def _compute_cusum(_signal, sample_rate: int, tolerance: int = None, min_duration: float = -1): + """ + This function efficiently computes the cumulative sum of a give list (_signal), with an optional tolerance. + + Args: + - _signal: array of iq samples. + - Tolerance: the least acceptable length of a block, Defaults to None. + + Returns: + - cusum (array): Array of the cumulative sum of the given list + - sample_rate (int): __description_ + - change_points (array): Array of the indices at which a change in the CUSUM direction happens. + - min_duration (float): The least acceptable time width of each segment (in ms). Defaults to -1. + """ + + # efficiently calculate the running sum of the signal + # cusum = list(itertools.accumulate((_signal - np.mean(_signal)))) + x = _signal - np.mean(_signal) + cusum = np.cumsum(x) + + # 'diff' computes the differences between the consecutive values, + # then 'sign' determines if it is +ve or -ve. + change_indicators = np.sign(np.diff(cusum)) + change_points = np.where(np.diff(change_indicators))[0] + 1 + + # Limit the change_points + # Reject those whose number of samples < minimum accepted #n of samples in (min duration) ms. 
+ if min_duration is not None and min_duration > 0: + min_samples_wide = int(min_duration * sample_rate / 1000) + segments_lengths = np.diff(change_points) + segments_lengths = np.insert(segments_lengths, 0, change_points[0]) + change_points = change_points[np.where(segments_lengths > min_samples_wide)[0]] + return cusum, change_points + + +def detect_frequency(signal, start, stop, sample_rate): + signal_segment = signal[start:stop] + if len(signal_segment) > 0: + fft_data = np.abs(np.fft.fftshift(np.fft.fft(signal_segment))) + fft_freqs = np.fft.fftshift(np.fft.fftfreq(len(signal_segment), 1 / sample_rate)) + + # Use a spectral threshold to find the 'height' of the orange block + spectral_thresh = np.max(fft_data) * 0.15 + sig_indices = np.where(fft_data > spectral_thresh)[0] + + if len(sig_indices) > 4: + return fft_freqs[sig_indices[0]], fft_freqs[sig_indices[-1]] + else: + return -sample_rate / 4, sample_rate / 4 + else: + return -sample_rate / 4, sample_rate / 4 + + +class TimeSegmenter: + """Time Segmenter class, it creates a segmenter object with certain\ + characteristics to easily split an input signal to segments based on\ + the cumulative sum of deviations (of the signal mean) + """ + + def __init__( + self, sample_rate: int, min_duration: float = 1, moving_average_window: int = 3, tolerance: int = None + ): + """_summary_ + + Args: + sample_rate (int): _description_ + min_duration (float, optional): _description_. Defaults to 1. + moving_average_window (int, optional): _description_. Defaults to 3. + tolerance (int, optional): _description_. Defaults to None. + """ + self.sample_rate = sample_rate + self.min_duration = min_duration + self.moving_average_window = moving_average_window + self._moving_avg_filter = self._init_filter() + self.tolerance = tolerance + + def _init_filter(self): + """_summary_ + + Returns: + _type_: _description_ + """ + return np.ones(self.moving_average_window) / self.moving_average_window + + def _apply_filter(self, iqsignal: np.array): + """_summary_ + + Args: + iqsignal (np.array): _description_ + + Returns: + _type_: _description_ + """ + return np.convolve(abs(iqsignal), self._moving_avg_filter, mode="same") + + def _create_segments(self, iq_signal: np.array, change_points: np.array): + """_summary_ + + Args: + iq_signal (np.array): _description_ + change_points (np.array): _description_ + + Returns: + _type_: _description_ + """ + return np.split(iq_signal, change_points) + + def apply(self, iq_signal: np.array): + """_summary_ + + Args: + iq_signal (np.array): _description_ + + Returns: + _type_: _description_ + """ + smoothed_signal = self._apply_filter(iq_signal) + _, change_points = _compute_cusum(smoothed_signal, self.sample_rate, self.tolerance, self.min_duration) + # segments = self._create_segments(iq_signal, change_points) + return change_points diff --git a/src/ria_toolkit_oss/annotations/energy_detector.py b/src/ria_toolkit_oss/annotations/energy_detector.py new file mode 100644 index 0000000..109fe6e --- /dev/null +++ b/src/ria_toolkit_oss/annotations/energy_detector.py @@ -0,0 +1,438 @@ +""" +Energy-based signal detection and bandwidth analysis. + +Provides automatic annotation generation using energy-based signal detection +and occupied bandwidth calculation following ITU-R SM.328 standard. 
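+
+Here, occupied bandwidth (OBW) is the frequency span containing a given fraction of the total
+signal power, with the remainder split evenly between the two spectral tails.
+A quick synthetic check (illustrative; a clean complex tone should yield a narrow OBW that
+brackets the tone frequency)::
+
+    >>> import numpy as np
+    >>> t = np.arange(100_000) / 1e6
+    >>> tone = np.exp(2j * np.pi * 50e3 * t)
+    >>> obw, f_lo, f_hi = calculate_occupied_bandwidth(tone, sampling_rate=1e6)
+    >>> # f_lo and f_hi should tightly bracket the 50 kHz tone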
+""" + +import json +from typing import Tuple + +import numpy as np +from scipy.signal import filtfilt + +from ria_toolkit_oss.datatypes import Annotation, Recording + + +def detect_signals_energy( + recording: Recording, + k: int = 10, + threshold_factor: float = 1.2, + window_size: int = 200, + min_distance: int = 5000, + label: str = "signal", + annotation_type: str = "standalone", + freq_method: str = "nbw", + nfft: int = None, + obw_power: float = 0.99, +) -> Recording: + """ + Detect signal bursts using energy-based method with adaptive noise floor estimation. + + This algorithm smooths the signal with a moving average filter, estimates the noise + floor from k segments, applies a threshold to detect regions above noise, and merges + nearby detections. Detected time boundaries are then assigned frequency bounds based + on the selected frequency method. + + Time Detection Algorithm: + 1. Smooth signal using moving average (envelope detection) + 2. Divide smoothed signal into k segments + 3. Estimate noise floor as median of segment mean powers + 4. Detect regions where power exceeds threshold_factor * noise_floor + 5. Merge regions closer than min_distance samples + + Frequency Bounding (freq_method): + - 'nbw': Nominal bandwidth (OBW + center frequency) - DEFAULT + - 'obw': Occupied bandwidth (99.99% power, includes siedelobes) + - 'full-detected': Lowest to highest spectral component + - 'full-bandwidth': Entire Nyquist span (center_freq ± sample_rate/2) + + :param recording: Recording to analyze + :type recording: Recording + :param k: Number of segments for noise floor estimation (default: 10) + :type k: int + :param threshold_factor: Threshold multiplier above noise floor (typical: 1.2-2.0, default: 1.2) + :type threshold_factor: float + :param window_size: Moving average window size in samples (default: 200) + :type window_size: int + :param min_distance: Minimum distance between separate signals in samples (default: 5000) + :type min_distance: int + :param label: Label for detected annotations (default: "signal") + :type label: str + :param annotation_type: Annotation type (standalone, parallel, intersection, default: standalone) + :type annotation_type: str + :param freq_method: How to calculate frequency bounds (default: 'nbw') + :type freq_method: str + :param nfft: FFT size for frequency calculations (default: None) + :type nfft: int + :param obw_power: Power percentage for OBW (0.9999 = 99.99%, default: 0.99) + :type obw_power: float + + :returns: New Recording with added annotations + :rtype: Recording + + **Example**:: + + >>> from ria.io import load_recording + >>> from ria_toolkit_oss.annotations import detect_signals_energy + >>> recording = load_recording("capture.sigmf") + + >>> # Detect with NBW frequency bounds (default, best for real signals) + >>> annotated = detect_signals_energy(recording, label="burst") + + >>> # Detect with OBW (more conservative, includes siedelobes) + >>> annotated = detect_signals_energy( + ... recording, label="burst", freq_method="obw" + ... ) + + >>> # Detect with full detected range (captures all spectral components) + >>> annotated = detect_signals_energy( + ... recording, label="burst", freq_method="full-detected" + ... 
) + """ + # Extract signal data (use first channel only) + signal = recording.data[0] + + # Calculate smoothed signal power + kernel = np.ones(window_size) / window_size + smoothed_power = filtfilt(kernel, [1], np.abs(signal) ** 2) + + # Estimate noise floor using segment-based median (robust to signal presence) + segments = np.array_split(smoothed_power, k) + noise_floor = np.median([np.mean(s) for s in segments]) + + # Detect signal boundaries (regions above threshold) + enter = noise_floor * threshold_factor + exit = enter * 0.8 + boundaries = [] + start = None + active = False + + for i, p in enumerate(smoothed_power): + if not active and p > enter: + start = i + active = True + elif active and p < exit: + boundaries.append((start, i - start)) + active = False + + if active: + boundaries.append((start, len(smoothed_power) - start)) + + # Merge boundaries that are closer than min_distance + merged_boundaries = [] + if boundaries: + start, length = boundaries[0] + for next_start, next_length in boundaries[1:]: + if next_start - (start + length) < min_distance: + # Merge with current boundary + length = next_start + next_length - start + else: + # Save current and start new boundary + merged_boundaries.append((start, length)) + start, length = next_start, next_length + # Add final boundary + merged_boundaries.append((start, length)) + + # Create annotations from detected boundaries + sample_rate = recording.metadata["sample_rate"] + center_frequency = recording.metadata.get("center_frequency", 0) + + # Validate frequency method + valid_freq_methods = ["nbw", "obw", "full-detected", "full-bandwidth"] + if freq_method not in valid_freq_methods: + raise ValueError(f"Invalid freq_method '{freq_method}'. " f"Must be one of: {', '.join(valid_freq_methods)}") + + annotations = [] + for start_sample, sample_count in merged_boundaries: + # Calculate frequency bounds based on method + freq_lower, freq_upper = calculate_frequency_bounds( + freq_method, center_frequency, sample_rate, nfft, signal, start_sample, sample_count, obw_power + ) + # Build comment JSON with type metadata + comment_data = { + "type": annotation_type, + "generator": "energy_detector", + "freq_method": freq_method, + "params": { + "threshold_factor": threshold_factor, + "window_size": window_size, + "noise_floor": float(noise_floor), + "threshold": float(enter), + }, + } + + anno = Annotation( + sample_start=start_sample, + sample_count=sample_count, + freq_lower_edge=freq_lower, + freq_upper_edge=freq_upper, + label=label, + comment=json.dumps(comment_data), + detail={"generator": "energy_detector", "freq_method": freq_method}, + ) + annotations.append(anno) + + return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + annotations) + + +def calculate_occupied_bandwidth( + signal: np.ndarray, + sampling_rate: float, + nfft: int = None, + power_percentage: float = 0.99, +): + if nfft is None: + nfft = max(65536, 2 ** int(np.floor(np.log2(len(signal))))) + + window = np.blackman(len(signal)) + spec = np.fft.fftshift(np.fft.fft(signal * window, n=nfft)) + + psd = np.abs(spec) ** 2 + psd = psd / psd.sum() # normalize + + freqs = np.fft.fftshift(np.fft.fftfreq(nfft, 1 / sampling_rate)) + + cdf = np.cumsum(psd) + + tail = (1 - power_percentage) / 2 + + lower_idx = np.searchsorted(cdf, tail) + upper_idx = np.searchsorted(cdf, 1 - tail) + + return freqs[upper_idx] - freqs[lower_idx], freqs[lower_idx], freqs[upper_idx] + + +def calculate_nominal_bandwidth( + signal: np.ndarray, + 
sampling_rate: float,
+    nfft: int = None,
+    power_percentage: float = 0.99,
+) -> Tuple[float, float, float]:
+    """
+    Calculate the nominal band edges and center frequency.
+
+    Nominal bandwidth (NBW) is the occupied bandwidth along with the center
+    frequency of the signal's spectral occupancy. Useful for characterizing
+    signals with unknown or drifting center frequencies.
+
+    :param signal: Complex IQ signal samples
+    :type signal: np.ndarray
+    :param sampling_rate: Sample rate in Hz
+    :type sampling_rate: float
+    :param nfft: FFT size
+    :type nfft: int
+    :param power_percentage: Fraction of power to contain
+    :type power_percentage: float
+
+    :returns: Tuple of (lower_freq_hz, upper_freq_hz, center_freq_hz)
+    :rtype: Tuple[float, float, float]
+
+    **Example**::
+
+        >>> from ria_toolkit_oss.annotations import calculate_nominal_bandwidth
+        >>> f_lo, f_hi, center = calculate_nominal_bandwidth(signal, sampling_rate=10e6)
+        >>> print(f"NBW: {(f_hi - f_lo)/1e6:.3f} MHz, Center: {center/1e6:.3f} MHz")
+    """
+    bw, lower_freq, upper_freq = calculate_occupied_bandwidth(signal, sampling_rate, nfft, power_percentage)
+
+    # Center frequency is midpoint of occupied band
+    center_freq = (lower_freq + upper_freq) / 2
+
+    return lower_freq, upper_freq, center_freq
+
+
+def calculate_full_detected_bandwidth(
+    signal: np.ndarray,
+    sampling_rate: float,
+    nfft: int = None,
+    start_offset: int = 1000,
+) -> Tuple[float, float, float]:
+    """
+    Calculate frequency range from lowest to highest spectral component.
+
+    Unlike OBW/NBW which define a power-based bandwidth, this calculates
+    the absolute frequency span from the lowest non-zero spectral component
+    to the highest non-zero component.
+
+    Useful for:
+    - Signals with spectral gaps
+    - Multiple parallel signals (captures all of them)
+    - Understanding total occupied spectrum vs. actual bandwidth
+
+    :param signal: Complex IQ signal samples
+    :type signal: np.ndarray
+    :param sampling_rate: Sample rate in Hz
+    :type sampling_rate: float
+    :param nfft: FFT size (defaults to 65536 when not given)
+    :type nfft: int
+    :param start_offset: Skip samples at start
+    :type start_offset: int
+
+    :returns: Tuple of (bandwidth_hz, lower_freq_hz, upper_freq_hz)
+    :rtype: Tuple[float, float, float]
+
+    **Example**::
+
+        >>> # Signal with two components at different frequencies
+        >>> bw, f_low, f_high = calculate_full_detected_bandwidth(
+        ...     signal, sampling_rate=10e6, nfft=65536
+        ... )
+        >>> print(f"Full span: {f_low/1e6:.3f} to {f_high/1e6:.3f} MHz")
+    """
+    # Default FFT size when not specified, so the validation below does not
+    # fail on the None default.
+    if nfft is None:
+        nfft = 65536
+
+    # Validate input
+    if len(signal) < nfft + start_offset:
+        raise ValueError(
+            f"Signal too short: need {nfft + start_offset} samples, "
+            f"got {len(signal)}. Reduce nfft or start_offset."
+ ) + + # Extract segment + signal_segment = signal[start_offset : nfft + start_offset] + + # Compute FFT and power spectral density + freq_spectrum = np.fft.fft(signal_segment, n=nfft) + psd = np.abs(freq_spectrum) ** 2 + + # Shift to center DC + psd_shifted = np.fft.fftshift(psd) + freq_bins = np.fft.fftshift(np.fft.fftfreq(nfft, 1 / sampling_rate)) + + # Find noise floor (mean of lowest 10% of bins) and all bins above noise floor + noise_floor = np.mean(np.sort(psd_shifted)[: int(len(psd_shifted) * 0.1)]) + above_noise = np.where(psd_shifted > noise_floor * 1.5)[0] + + if len(above_noise) == 0: + # No signal above noise, return zero bandwidth + return 0.0, 0.0, 0.0 + + # Get frequency range of signal components + lower_idx = above_noise[0] + upper_idx = above_noise[-1] + + lower_freq = freq_bins[lower_idx] + upper_freq = freq_bins[upper_idx] + + bandwidth = upper_freq - lower_freq + + return bandwidth, lower_freq, upper_freq + + +def annotate_with_obw( + recording: Recording, + label: str = "signal", + annotation_type: str = "standalone", + nfft: int = None, + power_percentage: float = 0.99, +) -> Recording: + """ + Create a single annotation spanning the occupied bandwidth of the entire recording. + + Analyzes the full recording to find its occupied bandwidth and creates an annotation + covering that frequency range for the entire time duration. + + :param recording: Recording to analyze + :type recording: Recording + :param label: Annotation label + :type label: str + :param annotation_type: Annotation type + :type annotation_type: str + :param nfft: FFT size + :type nfft: int + :param power_percentage: Power percentage for OBW calculation + :type power_percentage: float + + :returns: Recording with OBW annotation added + :rtype: Recording + + **Example**:: + + >>> from ria_toolkit_oss.annotations import annotate_with_obw + >>> annotated = annotate_with_obw(recording, label="signal_obw") + """ + signal = recording.data[0] + sample_rate = recording.metadata["sample_rate"] + center_freq = recording.metadata.get("center_frequency", 0) + + # Calculate OBW + obw, lower_offset, upper_offset = calculate_occupied_bandwidth(signal, sample_rate, nfft, power_percentage) + + # Convert baseband offsets to absolute frequencies + freq_lower = center_freq + lower_offset + freq_upper = center_freq + upper_offset + + # Create comment JSON + comment_data = { + "type": annotation_type, + "generator": "obw_annotator", + "obw_hz": float(obw), + "power_percentage": power_percentage, + "params": {"nfft": nfft}, + } + + # Create annotation spanning entire recording + anno = Annotation( + sample_start=0, + sample_count=len(signal), + freq_lower_edge=freq_lower, + freq_upper_edge=freq_upper, + label=label, + comment=json.dumps(comment_data), + detail={"generator": "obw_annotator", "obw_hz": float(obw)}, + ) + + return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + [anno]) + + +def calculate_frequency_bounds( + freq_method, center_frequency, sample_rate, nfft, signal, start_sample, sample_count, obw_power +): + if freq_method == "full-bandwidth": + # Full Nyquist span + freq_lower = center_frequency - (sample_rate / 2) + freq_upper = center_frequency + (sample_rate / 2) + else: + # Extract segment for frequency analysis + segment_start = start_sample + segment_end = min(start_sample + sample_count, len(signal)) + segment = signal[segment_start:segment_end] + + if nfft is None or len(segment) >= nfft: + if freq_method == "nbw": + # Nominal bandwidth (OBW + center 
frequency) + try: + lower_freq, upper_freq, _ = calculate_nominal_bandwidth(segment, sample_rate, nfft, obw_power) + freq_lower = center_frequency + lower_freq + freq_upper = center_frequency + upper_freq + except (ValueError, IndexError): + # Fallback if calculation fails + freq_lower = center_frequency - (sample_rate / 2) + freq_upper = center_frequency + (sample_rate / 2) + + elif freq_method == "obw": + # Occupied bandwidth + try: + _, f_lower, f_upper = calculate_occupied_bandwidth(segment, sample_rate, nfft, obw_power) + freq_lower = center_frequency + f_lower + freq_upper = center_frequency + f_upper + except (ValueError, IndexError): + # Fallback if calculation fails + freq_lower = center_frequency - (sample_rate / 2) + freq_upper = center_frequency + (sample_rate / 2) + + elif freq_method == "full-detected": + # Full detected range (lowest to highest component) + try: + _, f_lower, f_upper = calculate_full_detected_bandwidth(segment, sample_rate, nfft) + freq_lower = center_frequency + f_lower + freq_upper = center_frequency + f_upper + except (ValueError, IndexError): + # Fallback if calculation fails + freq_lower = center_frequency - (sample_rate / 2) + freq_upper = center_frequency + (sample_rate / 2) + else: + # Segment too short for FFT, use full bandwidth + freq_lower = center_frequency - (sample_rate / 2) + freq_upper = center_frequency + (sample_rate / 2) + + return freq_lower, freq_upper diff --git a/src/ria_toolkit_oss/annotations/parallel_signal_separator.py b/src/ria_toolkit_oss/annotations/parallel_signal_separator.py new file mode 100644 index 0000000..957cf58 --- /dev/null +++ b/src/ria_toolkit_oss/annotations/parallel_signal_separator.py @@ -0,0 +1,435 @@ +""" +Parallel signal separation for multi-component frequency-offset signals. + +Provides methods to detect and separate overlapping frequency-domain signals +that occupy the same time window but different frequency bands. + +This module implements **spectral peak detection** to identify distinct frequency +components and split single time-domain annotations into frequency-specific +sub-annotations. + +**Key Design Decisions** (per Codex review): + +1. **Complex IQ Support**: Uses `scipy.signal.welch` with `return_onesided=False` + for proper complex signal handling. Window length automatically adapts to + signal length via `nperseg=min(nfft, len(signal))` to handle bursts >> from ria_toolkit_oss.annotations import find_spectral_components + >>> # Detect the two distinct channels (returns relative frequencies) + >>> components = find_spectral_components(signal, sampling_rate=20e6) + >>> print(f"Found {len(components)} components") + Found 2 components + +The module is designed to work with detected time-domain annotations, +allowing splitting of overlapping signals into separate training samples. +""" + +import json +from typing import List, Optional, Tuple + +import numpy as np +from scipy import ndimage +from scipy import signal as scipy_signal + +from ria_toolkit_oss.datatypes import Annotation, Recording + + +def find_spectral_components( + signal_data: np.ndarray, + sampling_rate: float, + nfft: int = 65536, + noise_threshold_db: Optional[float] = None, + min_component_bw: float = 50e3, + time_percentile: float = 70.0, +) -> List[Tuple[float, float, float]]: + """ + Find distinct frequency components using spectral peak detection. + + Identifies separate frequency components in a signal by analyzing the power + spectral density and finding peaks corresponding to distinct signals. 
This is + useful for separating parallel signals that occupy different frequency bands. + + **Frequency Representation**: Returns frequencies in **baseband/relative** Hz + (centered at 0). To get absolute RF frequencies, add center_frequency_hz from + recording metadata to all returned values. + + Algorithm: + 1. Compute power spectral density using Welch (properly handles complex IQ) + 2. Auto-estimate noise floor from data if not specified + 3. Smooth PSD to reduce spurious peaks + 4. Find local maxima above noise floor + 5. Estimate bandwidth per peak using -3dB (fallback: cumulative power) + 6. Filter components below minimum bandwidth threshold + + :param signal_data: Complex IQ signal samples (np.complex64/128) + :type signal_data: np.ndarray + :param sampling_rate: Sample rate in Hz + :type sampling_rate: float + :param nfft: FFT size / window length for Welch. Automatically capped at + signal length to handle bursts (default: 65536) + :type nfft: int + :param noise_threshold_db: Minimum SNR threshold in dB. If None (default), + auto-estimates as np.percentile(psd_db, 10). + Adapt this across hardware (Pluto: ~-100, ThinkRF: ~-60). + :type noise_threshold_db: Optional[float] + :param min_component_bw: Minimum component bandwidth in Hz (default: 50 kHz) + :type min_component_bw: float + :param power_threshold: Cumulative power threshold for fallback bandwidth + estimation (default: 0.99 = 99% power, like OBW) + :type power_threshold: float + + :returns: List of (center_freq_hz, lower_freq_hz, upper_freq_hz) tuples. + **All frequencies are relative (baseband, 0-centered).** + Add recording metadata['center_frequency'] to get absolute RF frequencies. + :rtype: List[Tuple[float, float, float]] + + :raises ValueError: If signal has fewer than 256 samples + + **Example**:: + + >>> from ria.io import load_recording + >>> from ria_toolkit_oss.annotations import find_spectral_components + >>> recording = load_recording("capture.sigmf") + >>> segment = recording.data[0][start:end] + >>> # Components in relative (baseband) frequency + >>> components = find_spectral_components(segment, sampling_rate=20e6) + >>> for center_rel, lower_rel, upper_rel in components: + ... # Convert to absolute RF frequency + ... center_abs = recording.metadata['center_frequency'] + center_rel + ... 
print(f"Component @ {center_abs/1e9:.3f} GHz") + """ + # Validate input + min_samples = 256 + if len(signal_data) < min_samples: + raise ValueError(f"Signal too short: need at least {min_samples} samples, " f"got {len(signal_data)}.") + + # Compute PSD using Welch method for complex IQ signals + # CRITICAL: return_onesided=False for proper complex signal handling + nperseg = min(nfft, len(signal_data)) + noverlap = nperseg // 2 + + # --- STFT --- + freqs, times, Zxx = scipy_signal.stft( + signal_data, + fs=sampling_rate, + window="blackman", + nperseg=nperseg, + noverlap=noverlap, + return_onesided=False, + boundary=None, + ) + + # Shift zero freq to center + Zxx = np.fft.fftshift(Zxx, axes=0) + freqs = np.fft.fftshift(freqs) + + # Power spectrogram + power = np.abs(Zxx) ** 2 + power_db = 10 * np.log10(power + 1e-12) + + # --- Aggregate across time robustly --- + # Using percentile instead of mean prevents short signals from being diluted + freq_profile_db = np.percentile(power_db, time_percentile, axis=1) + + # --- Noise floor estimation --- + if noise_threshold_db is None: + noise_threshold_db = np.percentile(freq_profile_db, 20) + + threshold = noise_threshold_db + 3 # 3 dB above noise floor + + # --- Smooth lightly (avoid merging nearby signals) --- + freq_profile_db = ndimage.gaussian_filter1d(freq_profile_db, sigma=1.5) + + # --- Binary mask of significant frequencies --- + mask = freq_profile_db > threshold + + # --- Find contiguous frequency regions --- + labeled, num_features = ndimage.label(mask) + + components = [] + + for region_label in range(1, num_features + 1): + region_indices = np.where(labeled == region_label)[0] + + if len(region_indices) == 0: + continue + + lower_idx = region_indices[0] + upper_idx = region_indices[-1] + + lower_freq = freqs[lower_idx] + upper_freq = freqs[upper_idx] + bw = upper_freq - lower_freq + + if bw < min_component_bw: + continue + + center_freq = (lower_freq + upper_freq) / 2 + components.append((center_freq, lower_freq, upper_freq)) + + return components + + +def split_annotation_by_components( + annotation: Annotation, + signal: np.ndarray, + sampling_rate: float, + center_frequency_hz: float = 0.0, + nfft: int = 65536, + noise_threshold_db: Optional[float] = None, + min_component_bw: float = 50e3, +) -> List[Annotation]: + """ + Split an annotation into multiple annotations by detected frequency components. + + Takes an existing annotation spanning multiple frequency components and + analyzes the frequency content to create separate sub-annotations for + each distinct frequency component. + + **Use case**: Energy detection found a time window with 2-3 parallel WiFi + channels. This function splits it into separate annotations per channel. + + **Frequency Handling**: `find_spectral_components` returns relative (baseband) + frequencies. This function adds `center_frequency_hz` to convert to absolute + RF frequencies for SigMF annotation bounds. This ensures correct frequency + context across baseband and RF domains. 
+ + :param annotation: Original annotation to split + :type annotation: Annotation + :param signal: Full signal array (complex IQ) + :type signal: np.ndarray + :param sampling_rate: Sample rate in Hz + :type sampling_rate: float + :param center_frequency_hz: RF center frequency to add to relative frequencies + from peak detection (default: 0.0 = baseband) + :type center_frequency_hz: float + :param nfft: FFT size for analysis (default: 65536, auto-capped at signal length) + :type nfft: int + :param noise_threshold_db: Noise floor threshold in dB. If None (default), + auto-estimates from data. + :type noise_threshold_db: Optional[float] + :param min_component_bw: Minimum component bandwidth in Hz (default: 50 kHz) + :type min_component_bw: float + + :returns: List of new annotations (one per detected component). + Returns empty list if no components found or segment too short. + :rtype: List[Annotation] + + **Example**:: + + >>> from ria.io import load_recording + >>> from ria_toolkit_oss.annotations import split_annotation_by_components + >>> recording = load_recording("capture.sigmf") + >>> # Original annotation spans multiple channels + >>> original = recording.annotations[0] + >>> # Split using RF center frequency from metadata + >>> components = split_annotation_by_components( + ... original, + ... recording.data[0], + ... recording.metadata['sample_rate'], + ... center_frequency_hz=recording.metadata.get('center_frequency', 0.0) + ... ) + >>> print(f"Split into {len(components)} components") + Split into 2 components + + **Algorithm**: + 1. Extract segment corresponding to annotation time bounds + 2. Find frequency components in that segment (returns relative frequencies) + 3. Add center_frequency_hz to get absolute RF frequencies + 4. Create new annotation for each component + 5. Preserve original metadata (label, type, etc.) + 6. Add component info to comment JSON + + **Notes**: + - Original annotation is not modified + - Returns empty list if segment too short (<256 samples) + - Segments Recording: + """ + Split multiple annotations in a recording by frequency components. + + Processes specified annotations (or all if indices=None), replacing each + with its frequency-separated components. Uses RF center_frequency from + recording metadata for proper absolute frequency conversion. + + :param recording: Recording to process + :type recording: Recording + :param indices: Annotation indices to split (None = all, default: None). + Use indices=[] to skip splitting (returns unchanged recording). + :type indices: Optional[List[int]] + :param nfft: FFT size for spectral analysis (default: 65536, + auto-capped at signal segment length) + :type nfft: int + :param noise_threshold_db: Noise floor threshold in dB. If None (default), + auto-estimates from each segment. + :type noise_threshold_db: Optional[float] + :param min_component_bw: Minimum component bandwidth in Hz (default: 50 kHz). + Components narrower than this are filtered out. 
+ :type min_component_bw: float + + :returns: New Recording with split annotations + :rtype: Recording + + **Example**:: + + >>> from ria.io import load_recording + >>> from ria_toolkit_oss.annotations import split_recording_annotations + >>> recording = load_recording("capture.sigmf") + >>> # Split all annotations + >>> split_rec = split_recording_annotations(recording) + >>> print(f"Original: {len(recording.annotations)} annotations") + >>> print(f"Split: {len(split_rec.annotations)} annotations") + Original: 5 annotations + Split: 9 annotations + + **Algorithm**: + 1. For each annotation in indices (or all if None): + 2. Call split_annotation_by_components with RF center_frequency + 3. If components found, replace annotation with components + 4. If no components found, keep original annotation + 5. Annotations not in indices are kept unchanged + + **Notes**: + - Original recording is not modified + - Returns empty Recording.annotations if recording has no annotations + - RF center_frequency from metadata ensures correct absolute frequencies + - If an annotation can't be split (too short, wrong format), original kept + """ + if indices is None: + # Split all annotations + indices = list(range(len(recording.annotations))) + + if not recording.annotations: + # No annotations to split + return recording + + signal = recording.data[0] + sample_rate = recording.metadata["sample_rate"] + center_frequency = recording.metadata.get("center_frequency", 0.0) + + # Build new annotation list + new_annotations = [] + for i, anno in enumerate(recording.annotations): + if i in indices: + # Attempt to split this annotation + try: + components = split_annotation_by_components( + anno, + signal, + sample_rate, + center_frequency_hz=center_frequency, + nfft=nfft, + noise_threshold_db=noise_threshold_db, + min_component_bw=min_component_bw, + ) + if components: + # Split successful, use components + new_annotations.extend(components) + else: + # No components found, keep original + new_annotations.append(anno) + except Exception: + # Split failed for any reason, keep original + new_annotations.append(anno) + else: + # Not in split list, keep as-is + new_annotations.append(anno) + + return Recording(data=recording.data, metadata=recording.metadata, annotations=new_annotations) diff --git a/src/ria_toolkit_oss/annotations/qualify_slice.py b/src/ria_toolkit_oss/annotations/qualify_slice.py new file mode 100644 index 0000000..2336fe5 --- /dev/null +++ b/src/ria_toolkit_oss/annotations/qualify_slice.py @@ -0,0 +1,35 @@ +import numpy as np + +from ria_toolkit_oss.datatypes import Recording + + +def qualify_slice_from_annotations(recording: Recording, slice_length: int): + """ + Slice a recording into many smaller recordings, + discarding any slices which do not have annotations that apply to those samples. + Used together with an annotation based qualifier. + + :param recording: The recording to slice. + :type recording: Recording + :param slice_length: The length in samples of a slice. 
+ :type slice_length: int""" + + if len(recording.annotations) == 0: + print("Warning, no annotations.") + + annotation_mask = np.zeros(len(recording.data[0])) + + for annotation in recording.annotations: + annotation_mask[annotation.sample_start : annotation.sample_start + annotation.sample_count] = 1 + + output_recordings = [] + + for i in range((len(recording.data[0]) // slice_length) - 1): + start_index = slice_length * i + end_index = slice_length * (i + 1) + + if 1 in annotation_mask[start_index:end_index]: + sl = recording.data[:, start_index:end_index] + output_recordings.append(Recording(data=sl, metadata=recording.metadata)) + + return output_recordings diff --git a/src/ria_toolkit_oss/annotations/signal_isolation.py b/src/ria_toolkit_oss/annotations/signal_isolation.py new file mode 100644 index 0000000..47852ae --- /dev/null +++ b/src/ria_toolkit_oss/annotations/signal_isolation.py @@ -0,0 +1,97 @@ +import numpy as np +from scipy.signal import butter, lfilter + +from ria_toolkit_oss.datatypes.annotation import Annotation +from ria_toolkit_oss.datatypes.recording import Recording + + +def isolate_signal(recording: Recording, annotation: Annotation) -> Recording: + """ + Slice, filter and frequency shift the input recording according to the bounding box defined by the annotation. + + :param recording: The input Recording to be sliced. + :type recording: Recording + :param annotation: The Annotation object defining the area of the recording to isolate. + :type annotation: Annotation + :param decimate: Decimate the input signal after filtering to reduce the sample rate. + :type decimate: bool + + :returns: The subsection of the original recording defined by the annotation. + :rtype: Recording""" + + sample_start = max(0, annotation.sample_start) + sample_stop = min(len(recording), annotation.sample_start + annotation.sample_count) + + anno_base_center_freq = (annotation.freq_lower_edge + annotation.freq_upper_edge) / 2 - recording.metadata.get( + "center_frequency", 0 + ) + + anno_bw = annotation.freq_upper_edge - annotation.freq_lower_edge + + signal_slice = recording.data[0, sample_start:sample_stop] + + # normalize + signal_slice = signal_slice / np.max(np.abs(signal_slice)) + + isolation_bw = anno_bw + + # frequency shift the center of the box about zero + shifted_signal_slice = frequency_shift_iq_samples( + iq_samples=signal_slice, + sample_rate=recording.metadata["sample_rate"], + shift_frequency=-1 * anno_base_center_freq, + ) + + # filter + if isolation_bw < recording.metadata["sample_rate"] - 1: + filtered_signal = apply_complex_lowpass_filter( + signal=shifted_signal_slice, cutoff_frequency=isolation_bw, sample_rate=recording.metadata["sample_rate"] + ) + + else: + filtered_signal = shifted_signal_slice + + output = Recording(data=[filtered_signal], metadata=recording.metadata) + return output + + +def frequency_shift_iq_samples(iq_samples, sample_rate, shift_frequency): + # Number of samples + num_samples = len(iq_samples) + + # Create a time vector from 0 to the total duration in seconds + time_vector = np.arange(num_samples) / sample_rate + + # Generate the complex exponential for the frequency shift + complex_exponential = np.exp(1j * 2 * np.pi * shift_frequency * time_vector) + + # Apply the frequency shift to the IQ samples + shifted_samples = iq_samples * complex_exponential + + return shifted_samples + + +# Function to apply a lowpass Butterworth filter to a complex signal +def apply_complex_lowpass_filter(signal, cutoff_frequency, sample_rate, order=5): + # 
Design the lowpass filter + b, a = design_complex_lowpass_filter(cutoff_frequency, sample_rate, order) + + # Apply the lowpass filter + filtered_signal = lfilter(b, a, signal) + return filtered_signal + + +def design_complex_lowpass_filter(cutoff_frequency, sample_rate, order=5): + # Nyquist frequency for complex signals is the sample rate + nyquist = sample_rate + + # Ensure the cutoff frequency is positive and within the Nyquist limit + if cutoff_frequency <= 0 or cutoff_frequency > nyquist: + raise ValueError("Cutoff frequency must be between 0 and the Nyquist frequency.") + + # Normalize the cutoff frequency to the Nyquist frequency + cutoff_normalized = cutoff_frequency / nyquist + + # Create a Butterworth lowpass filter + b, a = butter(order, cutoff_normalized, btype="low") + return b, a diff --git a/src/ria_toolkit_oss/annotations/threshold_qualifier.py b/src/ria_toolkit_oss/annotations/threshold_qualifier.py new file mode 100644 index 0000000..804c5e1 --- /dev/null +++ b/src/ria_toolkit_oss/annotations/threshold_qualifier.py @@ -0,0 +1,359 @@ +""" +Temporal signal detection and boundary refinement via Hysteresis Thresholding. + +Provides methods to detect signal bursts in the time domain by triggering on +smoothed power peaks and expanding boundaries to capture the full energy envelope. + +This module implements a **dual-threshold trigger** to solve the 'chatter' +problem in noisy environments, ensuring that signal annotations encapsulate +the entire rise and fall of a burst rather than just the peak. + +**Key Design Decisions**: + +1. **Hysteresis Logic (Dual-Threshold)**: + - **Trigger**: High threshold (`threshold * max_power`) ensures high confidence + in signal presence. + - **Boundary**: Low threshold (`0.5 * trigger`) allows the annotation to + "crawl" outward, capturing the lower-energy start and end of the burst + often missed by simple single-threshold detectors. + +2. **Temporal Smoothing**: Uses a moving average window (`window_size`) prior + - to thresholding. This prevents high-frequency noise spikes from causing + fragmented annotations and provides a more stable estimate of the + signal's power envelope. + +3. **Spectral Profiling**: Once a temporal segment is isolated, the module + - performs an automated FFT analysis. It identifies the **90% spectral + occupancy** to define the frequency boundaries (`f_min`, `f_max`), + allowing the detector to work on narrowband and wideband signals without + manual frequency tuning. + +4. **Baseband/RF Mapping**: Automatically handles the conversion from + - relative FFT bin frequencies to absolute RF frequencies by referencing + `recording.metadata["center_frequency"]`. + +5. **False Positive Mitigation**: Implements a hard minimum duration check + - (10ms) to ignore transient hardware spikes or noise floor fluctuations + that do not constitute a valid signal burst. + +The module is designed to be the primary "first-pass" detector for pulsed +waveforms (like ADS-B, Lora, or bursty FSK) before passing them to +classification or demodulation stages. +""" + +import json +from typing import Optional + +import numpy as np + +from ria_toolkit_oss.datatypes import Annotation, Recording + + +def _find_ranges(indices, max_gap): + """ + Groups individual indices into continuous temporal ranges. + + Args: + indices: Array of indices where the signal exceeded a threshold. + max_gap: Maximum gap allowed between indices to consider them part + of the same range. 
+ + Returns: + A list of (start, stop) tuples representing detected signal segments. + """ + + if len(indices) == 0: + return [] + + start = indices[0] + prev = indices[0] + ranges = [] + + for i in range(1, len(indices)): + if indices[i] - prev > max_gap: + ranges.append((start, prev)) + start = indices[i] + prev = indices[i] + + ranges.append((start, prev)) + + return ranges + + +def _expand_and_filter_ranges( + smoothed_power: np.ndarray, + initial_ranges: list[tuple[int, int]], + boundary_val: float, + min_duration_samples: int, +) -> list[tuple[int, int]]: + """Apply hysteresis expansion and minimum-duration filtering.""" + out: list[tuple[int, int]] = [] + n = len(smoothed_power) + for start, stop in initial_ranges: + if (stop - start) < min_duration_samples: + continue + + true_start = start + while true_start > 0 and smoothed_power[true_start] > boundary_val: + true_start -= 1 + + true_stop = stop + while true_stop < n - 1 and smoothed_power[true_stop] > boundary_val: + true_stop += 1 + + if (true_stop - true_start) >= min_duration_samples: + out.append((true_start, true_stop)) + return out + + +def _merge_ranges(ranges: list[tuple[int, int]], max_gap: int) -> list[tuple[int, int]]: + """Merge overlapping or near-adjacent ranges.""" + if not ranges: + return [] + ranges = sorted(ranges, key=lambda r: r[0]) + merged = [ranges[0]] + for s, e in ranges[1:]: + last_s, last_e = merged[-1] + if s <= last_e + max_gap: + merged[-1] = (last_s, max(last_e, e)) + else: + merged.append((s, e)) + return merged + + +def _estimate_noise_floor(power: np.ndarray, quantile: float = 20.0) -> float: + """Estimate baseline from the quieter portion of the envelope.""" + return float(np.percentile(power, quantile)) + + +def _estimate_group_gap(sample_rate: float) -> int: + """Use a fixed temporal grouping gap instead of reusing the smoothing window.""" + return max(1, int(0.001 * sample_rate)) + + +def _estimate_spectral_bounds(signal_segment: np.ndarray, sample_rate: float) -> tuple[float, float]: + """Estimate occupied bandwidth from a smoothed magnitude spectrum.""" + if len(signal_segment) == 0: + return -sample_rate / 4, sample_rate / 4 + + window = np.hanning(len(signal_segment)) + windowed = signal_segment * window + + fft_data = np.abs(np.fft.fftshift(np.fft.fft(windowed))) + fft_freqs = np.fft.fftshift(np.fft.fftfreq(len(signal_segment), 1 / sample_rate)) + + # Smooth the spectrum so noise-like wideband bursts form a contiguous mask + # instead of thousands of tiny isolated runs. 
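+    # Kernel length grows with segment size but is clamped to [5, 257] bins; the
+    # trailing "| 1" forces an odd length so the moving average stays centered.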
+ spectral_smooth_bins = max(5, min(257, (len(signal_segment) // 512) | 1)) + spectral_kernel = np.ones(spectral_smooth_bins, dtype=np.float64) / spectral_smooth_bins + smoothed_fft = np.convolve(fft_data, spectral_kernel, mode="same") + + spectral_floor = float(np.percentile(smoothed_fft, 20)) + spectral_peak = float(np.max(smoothed_fft)) + spectral_ratio = spectral_peak / max(spectral_floor, 1e-12) + + if spectral_ratio < 1.2: + return -sample_rate / 4, sample_rate / 4 + + spectral_thresh = spectral_floor + 0.1 * (spectral_peak - spectral_floor) + sig_indices = np.where(smoothed_fft > spectral_thresh)[0] + + if len(sig_indices) == 0: + peak_idx = int(np.argmax(smoothed_fft)) + bin_hz = sample_rate / len(signal_segment) + half_bins = max(1, int(np.ceil(10_000.0 / bin_hz))) + lo_idx = max(0, peak_idx - half_bins) + hi_idx = min(len(smoothed_fft) - 1, peak_idx + half_bins) + else: + runs = _find_ranges(sig_indices, max_gap=max(1, spectral_smooth_bins // 2)) + peak_idx = int(np.argmax(smoothed_fft)) + lo_idx, hi_idx = min( + runs, + key=lambda run: 0 if run[0] <= peak_idx <= run[1] else min(abs(run[0] - peak_idx), abs(run[1] - peak_idx)), + ) + + # Prevent extremely narrow tone boxes from collapsing to just a few bins. + min_total_bw_hz = 20_000.0 + min_half_bins = max(1, int(np.ceil((min_total_bw_hz / 2) / (sample_rate / len(signal_segment))))) + center_idx = int(round((lo_idx + hi_idx) / 2)) + lo_idx = max(0, min(lo_idx, center_idx - min_half_bins)) + hi_idx = min(len(smoothed_fft) - 1, max(hi_idx, center_idx + min_half_bins)) + + return float(fft_freqs[lo_idx]), float(fft_freqs[hi_idx]) + + +def threshold_qualifier( + recording: Recording, + threshold: float, + window_size: Optional[int] = None, + label: Optional[str] = None, + annotation_type: Optional[str] = "standalone", + channel: int = 0, +) -> Recording: + """ + Annotate a recording with bounding boxes for regions above a threshold. + Threshold is defined as a fraction of the maximum sample magnitude. + This algorithm searches for samples above the threshold and combines them into ranges if they + are within window_size of each other. + Detects and annotates signals using energy thresholding and spectral analysis. + + The algorithm follows these steps: + 1. Smooths power data using a moving average. + 2. Identifies 'peak' regions exceeding a high trigger threshold. + 3. Uses hysteresis to expand boundaries until power drops below a lower threshold. + 4. Performs an FFT on each segment to determine frequency occupancy. + + Args: + recording: The Recording object containing IQ or real signal data. + threshold: Sensitivity multiplier (0.0 to 1.0) applied to max power. + window_size: Size of the smoothing filter in samples. Defaults to 1ms worth of samples. + label: Custom string label for annotations. + annotation_type: Metadata string for the 'type' field in the annotation. + channel: Index of the channel to annotate. Defaults to 0. + + Returns: + A new Recording object populated with detected Annotations. + """ + # Extract signal and metadata + sample_data = recording.data[channel] + sample_rate = recording.metadata["sample_rate"] + center_frequency = recording.metadata.get("center_frequency", 0) + + if window_size is None: + window_size = max(64, int(sample_rate * 0.001)) + + # --- 1. 
SIGNAL CONDITIONING --- + # Convert to power (Magnitude squared) + power_data = np.abs(sample_data) ** 2 + smoothing_window = np.ones(window_size) / window_size + smoothed_power = np.convolve(power_data, smoothing_window, mode="same") + group_gap_samples = _estimate_group_gap(sample_rate) + + # Define thresholds using peak relative to baseline. + max_power = np.max(smoothed_power) + noise_floor = _estimate_noise_floor(smoothed_power) + dynamic_range_ratio = max_power / max(noise_floor, 1e-12) + + # Soft early exit: keep a guard for low-contrast noise, but compute it from + # the quieter tail of the envelope so burst-heavy captures are not rejected. + if dynamic_range_ratio < 1.5: + return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations) + + trigger_val = noise_floor + threshold * (max_power - noise_floor) + boundary_val = noise_floor + 0.5 * threshold * (max_power - noise_floor) + + # --- 2. INITIAL DETECTION --- + # Enforce an explicit minimum duration in seconds; this is stable across + # varying capture lengths and avoids over-fitting to recording length. + min_duration_samples = max(1, int(0.005 * sample_rate)) + annotations = [] + + # Pass 1: Detect stronger bursts. + indices = np.where(smoothed_power > trigger_val)[0] + pass1_initial = _find_ranges(indices=indices, max_gap=group_gap_samples) + pass1_ranges = _expand_and_filter_ranges( + smoothed_power=smoothed_power, + initial_ranges=pass1_initial, + boundary_val=boundary_val, + min_duration_samples=min_duration_samples, + ) + + # Pass 2: Recover weaker bursts on residual power not already covered. + # This improves recall in mixed-amplitude captures. + # Expand each Pass-1 range by the smoothing window on both sides so the + # smoothing skirts of a strong burst are not re-detected as a weak burst + # immediately adjacent to it (mirrors the guard used in Pass 3). + mask = np.ones_like(smoothed_power, dtype=np.float32) + pass2_mask_expand = window_size + for s, e in pass1_ranges: + mask[max(0, s - pass2_mask_expand) : min(len(mask), e + pass2_mask_expand)] = 0.0 + residual_power = smoothed_power * mask + + residual_max = float(np.max(residual_power)) + residual_ratio = residual_max / max(noise_floor, 1e-12) + + pass2_ranges: list[tuple[int, int]] = [] + if residual_ratio >= 2.0: + weak_threshold = max(0.3, threshold * 0.7) + weak_trigger = noise_floor + weak_threshold * (residual_max - noise_floor) + weak_boundary = noise_floor + 0.5 * weak_threshold * (residual_max - noise_floor) + weak_indices = np.where(residual_power > weak_trigger)[0] + pass2_initial = _find_ranges(indices=weak_indices, max_gap=group_gap_samples) + pass2_ranges = _expand_and_filter_ranges( + smoothed_power=residual_power, + initial_ranges=pass2_initial, + boundary_val=weak_boundary, + min_duration_samples=min_duration_samples, + ) + + # Pass 3: Detect sustained faint bursts via macro-window averaging. + # Targets bursts whose peak power is near the trigger level but whose + # *average* power is consistently elevated above the noise floor — these + # are missed by peak-based detection because only a few short spikes exceed + # the trigger, all too brief to pass the minimum-duration filter. + # + # The mask is applied to power_data *before* convolving so that bright + # burst energy does not bleed through the long window into adjacent regions, + # which would inflate macro_residual_max and push the trigger above the + # faint burst's average power. 
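+    # For example, at 1 MS/s with the default window_size of 1000 samples (1 ms),
+    # the macro window is max(16_000, 20_000) = 20_000 samples, i.e. a 20 ms average.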
+ macro_window_size = max(window_size * 16, int(sample_rate * 0.02)) + macro_kernel = np.ones(macro_window_size, dtype=np.float64) / macro_window_size + # Expand each annotated range by half the macro window on both sides so that + # the long convolution cannot "see" the leading/trailing edges of already- + # annotated bursts, which would produce spurious short fragments in Pass 3. + macro_expand = macro_window_size * 2 + masked_power_for_macro = power_data.copy() + n = len(masked_power_for_macro) + for s, e in pass1_ranges + pass2_ranges: + masked_power_for_macro[max(0, s - macro_expand) : min(n, e + macro_expand)] = 0.0 + macro_residual = np.convolve(masked_power_for_macro, macro_kernel, mode="same") + + macro_residual_max = float(np.max(macro_residual)) + + pass3_ranges: list[tuple[int, int]] = [] + if macro_residual_max / max(noise_floor, 1e-12) >= 1.3: + macro_trigger = noise_floor + threshold * (macro_residual_max - noise_floor) + macro_boundary = noise_floor + 0.5 * threshold * (macro_residual_max - noise_floor) + macro_indices = np.where(macro_residual > macro_trigger)[0] + macro_initial = _find_ranges(indices=macro_indices, max_gap=group_gap_samples) + pass3_ranges = _expand_and_filter_ranges( + smoothed_power=macro_residual, + initial_ranges=macro_initial, + boundary_val=macro_boundary, + min_duration_samples=min_duration_samples, + ) + + all_ranges = _merge_ranges(pass1_ranges + pass2_ranges + pass3_ranges, max_gap=group_gap_samples) + + for true_start, true_stop in all_ranges: + + # --- 4. SPECTRAL ANALYSIS (Frequency Detection) --- + signal_segment = sample_data[true_start:true_stop] + f_min, f_max = _estimate_spectral_bounds(signal_segment, sample_rate) + + # --- 5. ANNOTATION GENERATION --- + ann_label = label if label is not None else f"{int(threshold*100)}%" + + # Pack metadata for the UI/Downstream processing + comment_data = { + "type": annotation_type, + "generator": "threshold_qualifier", + "params": { + "threshold": threshold, + "window_size": window_size, + }, + } + + anno = Annotation( + sample_start=true_start, + sample_count=true_stop - true_start, + freq_lower_edge=center_frequency + f_min, + freq_upper_edge=center_frequency + f_max, + label=ann_label, + comment=json.dumps(comment_data), + detail={"generator": "hysteresis_qualifier"}, + ) + annotations.append(anno) + + # Return a new Recording object including the new annotations + return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + annotations) diff --git a/src/ria_toolkit_oss/data/__init__.py b/src/ria_toolkit_oss/data/__init__.py new file mode 100644 index 0000000..b72f469 --- /dev/null +++ b/src/ria_toolkit_oss/data/__init__.py @@ -0,0 +1,8 @@ +""" +The Data package contains abstract data types tailored for radio machine learning, such as ``Recording``, as well +as the abstract interfaces for the radio dataset and radio dataset builder framework. +""" + +__all__ = ["Annotation", "Recording"] +from .annotation import Annotation +from .recording import Recording diff --git a/src/ria_toolkit_oss/data/annotation.py b/src/ria_toolkit_oss/data/annotation.py new file mode 100644 index 0000000..1182480 --- /dev/null +++ b/src/ria_toolkit_oss/data/annotation.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import json +from typing import Any, Optional + +from sigmf import SigMFFile + + +class Annotation: + """Signal annotations are labels or additional information associated with specific data points or segments within + a signal. 
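+    For example, a single annotation might mark the time span (sample_start and sample_count) and the
+    frequency extent (freq_lower_edge to freq_upper_edge) of one burst within a longer capture.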
These annotations could be used for tasks like supervised learning, where the goal is to train a model + to recognize patterns or characteristics in the signal associated with these annotations. + + Annotations can be used to label interesting points in your recording. + + :param sample_start: The index of the starting sample of the annotation. + :type sample_start: int + :param sample_count: The index of the ending sample of the annotation, inclusive. + :type sample_count: int + :param freq_lower_edge: The lower frequency of the annotation. + :type freq_lower_edge: float + :param freq_upper_edge: The upper frequency of the annotation. + :type freq_upper_edge: float + :param label: The label that will be displayed with the bounding box in compatible viewers including IQEngine. + Defaults to an emtpy string. + :type label: str, optional + :param comment: A human-readable comment. Defaults to an empty string. + :type comment: str, optional + :param detail: A dictionary of user defined annotation-specific metadata. Defaults to None. + :type detail: dict, optional + """ + + def __init__( + self, + sample_start: int, + sample_count: int, + freq_lower_edge: float, + freq_upper_edge: float, + label: Optional[str] = "", + comment: Optional[str] = "", + detail: Optional[dict] = None, + ): + """Initialize a new Annotation instance.""" + self.sample_start = int(sample_start) + self.sample_count = int(sample_count) + self.freq_lower_edge = float(freq_lower_edge) + self.freq_upper_edge = float(freq_upper_edge) + self.label = str(label) + self.comment = str(comment) + + if detail is None: + self.detail = {} + elif not _is_jsonable(detail): + raise ValueError(f"Detail object is not json serializable: {detail}") + else: + self.detail = detail + + def is_valid(self) -> bool: + """ + Check that the annotation sample count is > 0 and the freq_lower_edge 0 and self.freq_lower_edge < self.freq_upper_edge + + def overlap(self, other): + """ + Quantify how much the bounding box in this annotation overlaps with another annotation. + + :param other: The other annotation. + :type other: Annotation + + :returns: The area of the overlap in samples*frequency, or 0 if they do not overlap.""" + + sample_overlap_start = max(self.sample_start, other.sample_start) + sample_overlap_end = min(self.sample_start + self.sample_count, other.sample_start + other.sample_count) + + freq_overlap_start = max(self.freq_lower_edge, other.freq_lower_edge) + freq_overlap_end = min(self.freq_upper_edge, other.freq_upper_edge) + + if freq_overlap_start >= freq_overlap_end or sample_overlap_start >= sample_overlap_end: + return 0 + else: + return (sample_overlap_end - sample_overlap_start) * (freq_overlap_end - freq_overlap_start) + + def area(self): + """ + The 'area' of the bounding box, samples*frequency. + Useful to quantify annotation size. + + :returns: sample length multiplied by bandwidth.""" + + return self.sample_count * (self.freq_upper_edge - self.freq_lower_edge) + + def __eq__(self, other: Annotation) -> bool: + return self.__dict__ == other.__dict__ + + def to_sigmf_format(self): + """ + Returns a JSON dictionary representing this annotation formatted to be saved in a .sigmf-meta file. 
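+
+        Keys are resolved through the ``SigMFFile`` constants; the result has roughly the shape
+        ``{"core:sample_start": ..., "core:sample_count": ..., "metadata": {"core:label": ...,
+        "core:comment": ..., "core:freq_upper_edge": ..., "core:freq_lower_edge": ..., "ria:detail": {...}}}``.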
+ """ + + annotation_dict = {SigMFFile.START_INDEX_KEY: self.sample_start, SigMFFile.LENGTH_INDEX_KEY: self.sample_count} + + annotation_dict["metadata"] = { + SigMFFile.LABEL_KEY: self.label, + SigMFFile.COMMENT_KEY: self.comment, + SigMFFile.FHI_KEY: self.freq_upper_edge, + SigMFFile.FLO_KEY: self.freq_lower_edge, + "ria:detail": self.detail, + } + + if _is_jsonable(annotation_dict): + return annotation_dict + else: + raise ValueError("Annotation dictionary was not json serializable.") + + +def _is_jsonable(x: Any) -> bool: + """ + :return: True if x is JSON serializable, False otherwise. + """ + try: + json.dumps(x) + return True + except (TypeError, OverflowError): + return False diff --git a/src/ria_toolkit_oss/data/recording.py b/src/ria_toolkit_oss/data/recording.py new file mode 100644 index 0000000..20939bd --- /dev/null +++ b/src/ria_toolkit_oss/data/recording.py @@ -0,0 +1,853 @@ +from __future__ import annotations + +import copy +import hashlib +import json +import os +import re +import time +import warnings +from typing import Any, Iterator, Optional + +import numpy as np +from numpy.typing import ArrayLike + +from ria_toolkit_oss.datatypes.annotation import Annotation + +PROTECTED_KEYS = ["rec_id", "timestamp"] + + +class Recording: + """Tape of complex IQ (in-phase and quadrature) samples with associated metadata and annotations. + + Recording data is a complex array of shape C x N, where C is the number of channels + and N is the number of samples in each channel. + + Metadata is stored in a dictionary of key value pairs, + to include information such as sample_rate and center_frequency. + + Annotations are a list of :ref:`Annotation `, + defining bounding boxes in time and frequency with labels and metadata. + + Here, signal data is represented as a NumPy array. This class is then extended in the RIA Backends to provide + support for different data structures, such as Tensors. + + Recordings are long-form tapes can be obtained either from a software-defined radio (SDR) or generated + synthetically. Then, machine learning datasets are curated from collection of recordings by segmenting these + longer-form tapes into shorter units called slices. + + All recordings are assigned a unique 64-character recording ID, ``rec_id``. If this field is missing from the + provided metadata, a new ID will be generated upon object instantiation. + + :param data: Signal data as a tape IQ samples, either C x N complex, where C is the number of + channels and N is number of samples in the signal. If data is a one-dimensional array of complex samples with + length N, it will be reshaped to a two-dimensional array with dimensions 1 x N. + :type data: array_like + + :param metadata: Additional information associated with the recording. + :type metadata: dict, optional + :param annotations: A collection of ``Annotation`` objects defining bounding boxes. + :type annotations: list of Annotations, optional + + :param dtype: Explicitly specify the data-type of the complex samples. Must be a complex NumPy type, such as + ``np.complex64`` or ``np.complex128``. Default is None, in which case the type is determined implicitly. If + ``data`` is a NumPy array, the Recording will use the dtype of ``data`` directly without any conversion. + :type dtype: numpy dtype object, optional + :param timestamp: The timestamp when the recording data was generated. If provided, it should be a float or integer + representing the time in seconds since epoch (e.g., ``time.time()``). 
Only used if the `timestamp` field is not + present in the provided metadata. + :type dtype: float or int, optional + + :raises ValueError: If data is not complex 1xN or CxN. + :raises ValueError: If metadata is not a python dict. + :raises ValueError: If metadata is not json serializable. + :raises ValueError: If annotations is not a list of valid annotation objects. + + **Examples:** + + >>> import numpy + >>> from ria_toolkit_oss.datatypes import Recording, Annotation + + >>> # Create an array of complex samples, just 1s in this case. + >>> samples = numpy.ones(10000, dtype=numpy.complex64) + + >>> # Create a dictionary of relevant metadata. + >>> sample_rate = 1e6 + >>> center_frequency = 2.44e9 + >>> metadata = { + ... "sample_rate": sample_rate, + ... "center_frequency": center_frequency, + ... "author": "me", + ... } + + >>> # Create an annotation for the annotations list. + >>> annotations = [ + ... Annotation( + ... sample_start=0, + ... sample_count=1000, + ... freq_lower_edge=center_frequency - (sample_rate / 2), + ... freq_upper_edge=center_frequency + (sample_rate / 2), + ... label="example", + ... ) + ... ] + + >>> # Store samples, metadata, and annotations together in a convenient object. + >>> recording = Recording(data=samples, metadata=metadata, annotations=annotations) + >>> print(recording.metadata) + {'sample_rate': 1000000.0, 'center_frequency': 2440000000.0, 'author': 'me'} + >>> print(recording.annotations[0].label) + 'example' + """ + + def __init__( # noqa C901 + self, + data: ArrayLike | list[list], + metadata: Optional[dict[str, any]] = None, + dtype: Optional[np.dtype] = None, + timestamp: Optional[float | int] = None, + annotations: Optional[list[Annotation]] = None, + ): + + data_arr = np.asarray(data) + + if np.iscomplexobj(data_arr): + # Expect C x N + if data_arr.ndim == 1: + self._data = np.expand_dims(data_arr, axis=0) # N -> 1 x N + elif data_arr.ndim == 2: + self._data = data_arr + else: + raise ValueError("Complex data must be C x N.") + + else: + raise ValueError("Input data must be complex.") + + if dtype is not None: + self._data = self._data.astype(dtype) + + assert np.iscomplexobj(self._data) + + if metadata is None: + self._metadata = {} + elif isinstance(metadata, dict): + self._metadata = metadata + else: + raise ValueError(f"Metadata must be a python dict, but was {type(metadata)}.") + + if not _is_jsonable(metadata): + raise ValueError("Value must be JSON serializable.") + + if "timestamp" not in self.metadata: + if timestamp is not None: + if not isinstance(timestamp, (int, float)): + raise ValueError(f"timestamp must be int or float, not {type(timestamp)}") + self._metadata["timestamp"] = timestamp + else: + self._metadata["timestamp"] = time.time() + else: + if not isinstance(self._metadata["timestamp"], (int, float)): + raise ValueError("timestamp must be int or float, not ", type(self._metadata["timestamp"])) + + if "rec_id" not in self.metadata: + self._metadata["rec_id"] = generate_recording_id(data=self.data, timestamp=self._metadata["timestamp"]) + + if annotations is None: + self._annotations = [] + elif isinstance(annotations, list): + self._annotations = annotations + else: + raise ValueError("Annotations must be a list or None.") + + if not all(isinstance(annotation, Annotation) for annotation in self._annotations): + raise ValueError("All elements in self._annotations must be of type Annotation.") + + self._index = 0 + + @property + def data(self) -> np.ndarray: + """ + :return: Recording data, as a complex array. 
+ :type: np.ndarray + + .. note:: + + For recordings with more than 1,024 samples, this property returns a read-only view of the data. + + .. note:: + + To access specific samples, consider indexing the object directly with ``rec[c, n]``. + """ + if self._data.size > 1024: + # Returning a read-only view prevents mutation at a distance while maintaining performance. + v = self._data.view() + v.setflags(write=False) + return v + else: + return self._data.copy() + + @property + def metadata(self) -> dict: + """ + :return: Dictionary of recording metadata. + :type: dict + """ + return self._metadata.copy() + + @property + def annotations(self) -> list[Annotation]: + """ + :return: List of recording annotations + :type: list of Annotation objects + """ + return self._annotations.copy() + + @property + def shape(self) -> tuple[int]: + """ + :return: The shape of the data array. + :type: tuple of ints + """ + return np.shape(self.data) + + @property + def n_chan(self) -> int: + """ + :return: The number of channels in the recording. + :type: int + """ + return self.shape[0] + + @property + def rec_id(self) -> str: + """ + :return: Recording ID. + :type: str + """ + return self.metadata["rec_id"] + + @property + def dtype(self) -> str: + """ + :return: Data-type of the data array's elements. + :type: numpy dtype object + """ + return self.data.dtype + + @property + def timestamp(self) -> float | int: + """ + :return: Recording timestamp (time in seconds since epoch). + :type: float or int + """ + return self.metadata["timestamp"] + + @property + def sample_rate(self) -> float | None: + """ + :return: Sample rate of the recording, or None if 'sample_rate' is not in metadata. + :type: str + """ + return self.metadata.get("sample_rate") + + @sample_rate.setter + def sample_rate(self, sample_rate: float | int) -> None: + """Set the sample rate of the recording. + + :param sample_rate: The sample rate of the recording. + :type sample_rate: float or int + + :return: None + """ + self.add_to_metadata(key="sample_rate", value=sample_rate) + + def astype(self, dtype: np.dtype) -> Recording: + """Copy of the recording, data cast to a specified type. + + .. todo: This method is not yet implemented. + + :param dtype: Data-type to which the array is cast. Must be a complex scalar type, such as ``np.complex64`` or + ``np.complex128``. + :type dtype: NumPy data type, optional + + .. note: Casting to a data type with less precision can risk losing data by truncating or rounding values, + potentially resulting in a loss of accuracy and significant information. + + :return: A new recording with the same metadata and data, with dtype. + + TODO: Add example usage. + """ + # Rather than check for a valid datatype, let's cast and check the result. This makes it easier to provide + # cross-platform support where the types are aliased across platforms. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") # Casting may generate user warnings. E.g., complex -> real + data = self.data.astype(dtype) + + if np.iscomplexobj(data): + return Recording(data=data, metadata=self.metadata, annotations=self.annotations) + else: + raise ValueError("dtype must be a complex number scalar type.") + + def add_to_metadata(self, key: str, value: Any) -> None: + """Add a new key-value pair to the recording metadata. + + :param key: New metadata key, must be snake_case. + :type key: str + :param value: Corresponding metadata value. 
+ :type value: any + + :raises ValueError: If key is already in metadata or if key is not a valid metadata key. + :raises ValueError: If value is not JSON serializable. + + :return: None. + + **Examples:** + + Create a recording and add metadata: + + >>> import numpy + >>> from ria_toolkit_oss.datatypes import Recording + >>> + >>> samples = numpy.ones(10000, dtype=numpy.complex64) + >>> metadata = { + >>> "sample_rate": 1e6, + >>> "center_frequency": 2.44e9, + >>> } + >>> + >>> recording = Recording(data=samples, metadata=metadata) + >>> print(recording.metadata) + {'sample_rate': 1000000.0, + 'center_frequency': 2440000000.0, + 'timestamp': 17369..., + 'rec_id': 'fda0f41...'} + >>> + >>> recording.add_to_metadata(key="author", value="me") + >>> print(recording.metadata) + {'sample_rate': 1000000.0, + 'center_frequency': 2440000000.0, + 'author': 'me', + 'timestamp': 17369..., + 'rec_id': 'fda0f41...'} + """ + if key in self.metadata: + raise ValueError( + f"Key {key} already in metadata. Use Recording.update_metadata() to modify existing fields." + ) + + if not _is_valid_metadata_key(key): + raise ValueError(f"Invalid metadata key: {key}.") + + if not _is_jsonable(value): + raise ValueError("Value must be JSON serializable.") + + self._metadata[key] = value + + def update_metadata(self, key: str, value: Any) -> None: + """Update the value of an existing metadata key, + or add the key value pair if it does not already exist. + + :param key: Existing metadata key. + :type key: str + :param value: New value to enter at key. + :type value: any + + :raises ValueError: If value is not JSON serializable + :raises ValueError: If key is protected. + + :return: None. + + **Examples:** + + Create a recording and update metadata: + + >>> import numpy + >>> from ria_toolkit_oss.datatypes import Recording + + >>> samples = numpy.ones(10000, dtype=numpy.complex64) + >>> metadata = { + >>> "sample_rate": 1e6, + >>> "center_frequency": 2.44e9, + >>> "author": "me" + >>> } + + >>> recording = Recording(data=samples, metadata=metadata) + >>> print(recording.metadata) + {'sample_rate': 1000000.0, + 'center_frequency': 2440000000.0, + 'author': "me", + 'timestamp': 17369... + 'rec_id': 'fda0f41...'} + + >>> recording.update_metadata(key="author", value=you") + >>> print(recording.metadata) + {'sample_rate': 1000000.0, + 'center_frequency': 2440000000.0, + 'author': "you", + 'timestamp': 17369... + 'rec_id': 'fda0f41...'} + """ + if key not in self.metadata: + self.add_to_metadata(key=key, value=value) + + if not _is_jsonable(value): + raise ValueError("Value must be JSON serializable.") + + if key in PROTECTED_KEYS: # Check protected keys. + raise ValueError(f"Key {key} is protected and cannot be modified or removed.") + + else: + self._metadata[key] = value + + def remove_from_metadata(self, key: str): + """ + Remove a key from the recording metadata. + Does not remove key if it is protected. + + :param key: The key to remove. + :type key: str + + :raises ValueError: If key is protected. + + :return: None. + + **Examples:** + + Create a recording and add metadata: + + >>> import numpy + >>> from ria_toolkit_oss.datatypes import Recording + + >>> samples = numpy.ones(10000, dtype=numpy.complex64) + >>> metadata = { + ... "sample_rate": 1e6, + ... "center_frequency": 2.44e9, + ... 
} + + >>> recording = Recording(data=samples, metadata=metadata) + >>> print(recording.metadata) + {'sample_rate': 1000000.0, + 'center_frequency': 2440000000.0, + 'timestamp': 17369..., # Example value + 'rec_id': 'fda0f41...'} # Example value + + >>> recording.add_to_metadata(key="author", value="me") + >>> print(recording.metadata) + {'sample_rate': 1000000.0, + 'center_frequency': 2440000000.0, + 'author': 'me', + 'timestamp': 17369..., # Example value + 'rec_id': 'fda0f41...'} # Example value + """ + if key not in PROTECTED_KEYS: + self._metadata.pop(key) + else: + raise ValueError(f"Key {key} is protected and cannot be modified or removed.") + + def view(self, output_path: Optional[str] = "images/signal.png", **kwargs) -> None: + """Create a plot of various signal visualizations as a PNG image. + + :param output_path: The output image path. Defaults to "images/signal.png". + :type output_path: str, optional + :param kwargs: Keyword arguments passed on to utils.view.view_sig. + :type: dict of keyword arguments + + **Examples:** + + Create a recording and view it as a plot in a .png image: + + >>> import numpy + >>> from ria_toolkit_oss.datatypes import Recording + + >>> samples = numpy.ones(10000, dtype=numpy.complex64) + >>> metadata = { + >>> "sample_rate": 1e6, + >>> "center_frequency": 2.44e9, + >>> } + + >>> recording = Recording(data=samples, metadata=metadata) + >>> recording.view() + """ + from ria_toolkit_oss.view import view_sig + + view_sig(recording=self, output_path=output_path, **kwargs) + + def simple_view(self, **kwargs) -> None: + """Create a plot of various signal visualizations as a PNG or SVG image. + + :param kwargs: Keyword arguments passed on to ria_toolkit_oss.view.view_signal_simple.create_plots. + :type: dict of keyword arguments + + **Examples:** + + Create a recording and view it as a plot in a .png image: + + >>> import numpy + >>> from ria_toolkit_oss.datatypes import Recording + + >>> samples = numpy.ones(10000, dtype=numpy.complex64) + >>> metadata = { + >>> "sample_rate": 1e6, + >>> "center_frequency": 2.44e9, + >>> } + + >>> recording = Recording(data=samples, metadata=metadata) + >>> recording.simple_view() + """ + from ria_toolkit_oss.view.view_signal_simple import view_simple_sig + + view_simple_sig(recording=self, **kwargs) + + def to_sigmf( + self, filename: Optional[str] = None, path: Optional[os.PathLike | str] = None, overwrite: bool = False + ) -> None: + """Write recording to a set of SigMF files. + + The SigMF io format is defined by the `SigMF Specification Project `_ + + :param recording: The recording to be written to file. + :type recording: utils.data.Recording + :param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename. + :type filename: os.PathLike or str, optional + :param path: The directory path to where the recording is to be saved. Defaults to recordings/. + :type path: os.PathLike or str, optional + + :raises IOError: If there is an issue encountered during the file writing process. + + :return: None + + **Examples:** + + Create a recording and view it as a plot in a `.png` image: + + >>> import numpy + >>> from utils.data import Recording + + >>> samples = numpy.ones(10000, dtype=numpy.complex64) + >>> metadata = { + ... "sample_rate": 1e6, + ... "center_frequency": 2.44e9, + ... 
} + + >>> recording = Recording(data=samples, metadata=metadata) + >>> recording.view() + """ + from ria_toolkit_oss.io.recording import to_sigmf + + to_sigmf(filename=filename, path=path, recording=self, overwrite=overwrite) + + def to_npy( + self, filename: Optional[str] = None, path: Optional[os.PathLike | str] = None, overwrite: bool = False + ) -> str: + """Write recording to ``.npy`` binary file. + + :param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename. + :type filename: os.PathLike or str, optional + :param path: The directory path to where the recording is to be saved. Defaults to recordings/. + :type path: os.PathLike or str, optional + + :raises IOError: If there is an issue encountered during the file writing process. + + :return: Path where the file was saved. + :rtype: str + + **Examples:** + + Create a recording and save it to a .npy file: + + >>> import numpy + >>> from utils.data import Recording + + >>> samples = numpy.ones(10000, dtype=numpy.complex64) + >>> metadata = { + >>> "sample_rate": 1e6, + >>> "center_frequency": 2.44e9, + >>> } + + >>> recording = Recording(data=samples, metadata=metadata) + >>> recording.to_npy() + """ + from ria_toolkit_oss.io.recording import to_npy + + to_npy(recording=self, filename=filename, path=path, overwrite=overwrite) + + def to_wav( + self, + filename: Optional[str] = None, + path: Optional[os.PathLike | str] = None, + target_sample_rate: Optional[int] = 48000, + bits_per_sample: int = 32, + overwrite: bool = False, + ) -> str: + """Write recording to WAV file with embedded YAML metadata. + + WAV format uses stereo audio with I (in-phase) in left channel and Q (quadrature) in right channel. + Metadata is stored in standard LIST INFO chunks with RF-specific metadata encoded as YAML + in the ICMT (comment) field for human readability. + + :param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename. + :type filename: os.PathLike or str, optional + :param path: The directory path to where the recording is to be saved. Defaults to recordings/. + :type path: os.PathLike or str, optional + :param target_sample_rate: Sample rate stored in the WAV header when no sample_rate metadata + is present. IQ samples are written without decimation or interpolation. Default is 48000 Hz. + :type target_sample_rate: int, optional + :param bits_per_sample: Bits per sample (32 for float32, 16 for int16). Default is 32. + :type bits_per_sample: int, optional + :param overwrite: Whether to overwrite existing files. Default is False. + :type overwrite: bool, optional + + :raises IOError: If there is an issue encountered during the file writing process. + + :return: Path where the file was saved. 
+ :rtype: str + + **Examples:** + + Create a recording and save it to a .wav file: + + >>> import numpy + >>> from utils.data import Recording + >>> samples = numpy.exp(1j * 2 * numpy.pi * 0.1 * numpy.arange(10000)) + >>> metadata = {"sample_rate": 1e6, "center_frequency": 915e6} + >>> recording = Recording(data=samples, metadata=metadata) + >>> recording.to_wav() + """ + from ria_toolkit_oss.io.recording import to_wav + + return to_wav( + recording=self, + filename=filename, + path=path, + target_sample_rate=target_sample_rate, + bits_per_sample=bits_per_sample, + overwrite=overwrite, + ) + + def to_blue( + self, + filename: Optional[str] = None, + path: Optional[os.PathLike | str] = None, + data_format: str = "CI", + overwrite: bool = False, + ) -> str: + """Write recording to MIDAS Blue file format. + + MIDAS Blue is a legacy RF file format with a 512-byte binary header. + Commonly used with X-Midas and other RF/radar signal processing tools. + + :param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename. + :type filename: os.PathLike or str, optional + :param path: The directory path to where the recording is to be saved. Defaults to recordings/. + :type path: os.PathLike or str, optional + :param data_format: Format code (default 'CI' = complex int16). + Common formats: 'CI' (complex int16), 'CF' (complex float32), 'CD' (complex float64). + Integer formats require the IQ samples to already be scaled within [-1, 1). + :type data_format: str, optional + :param overwrite: Whether to overwrite existing files. Default is False. + :type overwrite: bool, optional + + :raises IOError: If there is an issue encountered during the file writing process. + + :return: Path where the file was saved. + :rtype: str + + **Examples:** + + Create a recording and save it to a .blue file: + + >>> import numpy + >>> from utils.data import Recording + >>> samples = numpy.ones(10000, dtype=numpy.complex64) + >>> metadata = {"sample_rate": 1e6, "center_frequency": 2.44e9} + >>> recording = Recording(data=samples, metadata=metadata) + >>> recording.to_blue() + """ + from ria_toolkit_oss.io.recording import to_blue + + return to_blue(recording=self, filename=filename, path=path, data_format=data_format, overwrite=overwrite) + + def trim(self, num_samples: int, start_sample: Optional[int] = 0) -> Recording: + """Trim Recording samples to a desired length, shifting annotations to maintain alignment. + + :param start_sample: The start index of the desired trimmed recording. Defaults to 0. + :type start_sample: int, optional + :param num_samples: The number of samples that the output trimmed recording will have. + :type num_samples: int + :raises IndexError: If start_sample + num_samples is greater than the length of the recording. + :raises IndexError: If sample_start < 0 or num_samples < 0. + + :return: The trimmed Recording. + :rtype: Recording + + **Examples:** + + Create a recording and trim it: + + >>> import numpy + >>> from utils.data import Recording + + >>> samples = numpy.ones(10000, dtype=numpy.complex64) + >>> metadata = { + ... "sample_rate": 1e6, + ... "center_frequency": 2.44e9, + ... 
} + + >>> recording = Recording(data=samples, metadata=metadata) + >>> print(len(recording)) + 10000 + + >>> trimmed_recording = recording.trim(start_sample=1000, num_samples=1000) + >>> print(len(trimmed_recording)) + 1000 + """ + + if start_sample < 0: + raise IndexError("start_sample cannot be < 0.") + elif start_sample + num_samples > len(self): + raise IndexError( + f"start_sample {start_sample} + num_samples {num_samples} > recording length {len(self)}." + ) + + end_sample = start_sample + num_samples + + data = self.data[:, start_sample:end_sample] + + new_annotations = copy.deepcopy(self.annotations) + for annotation in new_annotations: + # trim annotation if it goes outside the trim boundaries + if annotation.sample_start < start_sample: + annotation.sample_count = annotation.sample_count - (start_sample - annotation.sample_start) + annotation.sample_start = start_sample + + if annotation.sample_start + annotation.sample_count > end_sample: + annotation.sample_count = end_sample - annotation.sample_start + + # shift annotation to align with the new start point + annotation.sample_start = annotation.sample_start - start_sample + + return Recording(data=data, metadata=self.metadata, annotations=new_annotations) + + def normalize(self) -> Recording: + """Scale the recording data, relative to its maximum value, so that the magnitude of the maximum sample is 1. + + :return: Recording where the maximum sample amplitude is 1. + :rtype: Recording + + **Examples:** + + Create a recording with maximum amplitude 0.5 and normalize to a maximum amplitude of 1: + + >>> import numpy + >>> from utils.data import Recording + + >>> samples = numpy.ones(10000, dtype=numpy.complex64) * 0.5 + >>> metadata = { + ... "sample_rate": 1e6, + ... "center_frequency": 2.44e9, + ... } + + >>> recording = Recording(data=samples, metadata=metadata) + >>> print(numpy.max(numpy.abs(recording.data))) + 0.5 + + >>> normalized_recording = recording.normalize() + >>> print(numpy.max(numpy.abs(normalized_recording.data))) + 1 + """ + scaled_data = self.data / np.max(abs(self.data)) + return Recording(data=scaled_data, metadata=self.metadata, annotations=self.annotations) + + def __len__(self) -> int: + """The length of a recording is defined by the number of complex samples in each channel of the recording.""" + return self.shape[1] + + def __eq__(self, other: Recording) -> bool: + """Two Recordings are equal if all data, metadata, and annotations are the same.""" + + # counter used to allow for differently ordered annotation lists + return ( + np.array_equal(self.data, other.data) + and self.metadata == other.metadata + and self.annotations == other.annotations + ) + + def __ne__(self, other: Recording) -> bool: + """Two Recordings are equal if all data, and metadata, and annotations are the same.""" + return not self.__eq__(other=other) + + def __iter__(self) -> Iterator: + self._index = 0 + return self + + def __next__(self) -> np.ndarray: + if self._index < self.n_chan: + to_ret = self.data[self._index] + self._index += 1 + return to_ret + else: + raise StopIteration + + def __getitem__(self, key: int | tuple[int] | slice) -> np.ndarray | np.complexfloating: + """If key is an integer, tuple of integers, or a slice, return the corresponding samples. + + For arrays with 1,024 or fewer samples, return a copy of the recording data. For larger arrays, return a + read-only view. This prevents mutation at a distance while maintaining performance. 
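+
+        For example, ``rec[0]`` returns the samples of channel 0, ``rec[0, 5]`` returns the single
+        complex sample at index 5 of channel 0, and ``rec[:, :1000]`` returns the first 1,000
+        samples of every channel.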
+ """ + if isinstance(key, (int, tuple, slice)): + v = self._data[key] + if isinstance(v, np.complexfloating): + return v + elif v.size > 1024: + v.setflags(write=False) # Make view read-only. + return v + else: + return v.copy() + + else: + raise ValueError(f"Key must be an integer, tuple, or slice but was {type(key)}.") + + def __setitem__(self, *args, **kwargs) -> None: + """Raise an error if an attempt is made to assign to the recording.""" + raise ValueError("Assignment to Recording is not allowed.") + + +def generate_recording_id(data: np.ndarray, timestamp: Optional[float | int] = None) -> str: + """Generate unique 64-character recording ID. The recording ID is generated by hashing the recording data with + the datetime that the recording data was generated. If no datatime is provided, the current datatime is used. + + :param data: Tape of IQ samples, as a NumPy array. + :type data: np.ndarray + :param timestamp: Unix timestamp in seconds. Defaults to None. + :type timestamp: float or int, optional + + :return: 256-character hash, to be used as the recording ID. + :rtype: str + """ + if timestamp is None: + timestamp = time.time() + + byte_sequence = data.tobytes() + str(timestamp).encode("utf-8") + sha256_hash = hashlib.sha256(byte_sequence) + + return sha256_hash.hexdigest() + + +def _is_jsonable(x: Any) -> bool: + """ + :return: True if x is JSON serializable, False otherwise. + """ + try: + json.dumps(x) + return True + except (TypeError, OverflowError): + return False + + +def _is_valid_metadata_key(key: Any) -> bool: + """ + :return: True if key is a valid metadata key, False otherwise. + """ + if isinstance(key, str) and key.islower() and re.match(pattern=r"^[a-z_]+$", string=key) is not None: + return True + + else: + return False diff --git a/src/ria_toolkit_oss/io/recording.py b/src/ria_toolkit_oss/io/recording.py index ae38bc8..1a81a04 100644 --- a/src/ria_toolkit_oss/io/recording.py +++ b/src/ria_toolkit_oss/io/recording.py @@ -367,9 +367,7 @@ def to_sigmf( meta_dict = sigMF_metafile.ordered_metadata() meta_dict["ria"] = metadata - if overwrite and os.path.isfile(meta_file_path): - os.remove(meta_file_path) - sigMF_metafile.tofile(meta_file_path) + sigMF_metafile.tofile(meta_file_path, overwrite=overwrite) def from_sigmf(file: os.PathLike | str) -> Recording: diff --git a/src/ria_toolkit_oss/view/recording.py b/src/ria_toolkit_oss/view/recording.py index 381f07e..b9c413b 100644 --- a/src/ria_toolkit_oss/view/recording.py +++ b/src/ria_toolkit_oss/view/recording.py @@ -11,7 +11,7 @@ def spectrogram(rec: Recording, thumbnail: bool = False) -> Figure: """Create a spectrogram for the recording. :param rec: Signal to plot. - :type rec: utils.data.Recording + :type rec: ria_toolkit_oss.datatypes.Recording :param thumbnail: Whether to return a small thumbnail version or full plot. :type thumbnail: bool @@ -95,7 +95,7 @@ def iq_time_series(rec: Recording) -> Figure: """Create a time series plot of the real and imaginary parts of signal. :param rec: Signal to plot. - :type rec: utils.data.Recording + :type rec: ria_toolkit_oss.datatypes.Recording :return: Time series plot as a Plotly figure. """ @@ -125,7 +125,7 @@ def frequency_spectrum(rec: Recording) -> Figure: """Create a frequency spectrum plot from the recording. :param rec: Input signal to plot. - :type rec: utils.data.Recording + :type rec: ria_toolkit_oss.datatypes.Recording :return: Frequency spectrum as a Plotly figure. 
""" @@ -160,7 +160,7 @@ def constellation(rec: Recording) -> Figure: """Create a constellation plot from the recording. :param rec: Input signal to plot. - :type rec: utils.data.Recording + :type rec: ria_toolkit_oss.datatypes.Recording :return: Constellation as a Plotly figure. """ diff --git a/src/ria_toolkit_oss/view/view_signal.py b/src/ria_toolkit_oss/view/view_signal.py index f8d5731..ded3c8c 100644 --- a/src/ria_toolkit_oss/view/view_signal.py +++ b/src/ria_toolkit_oss/view/view_signal.py @@ -6,6 +6,7 @@ from typing import Optional import matplotlib.pyplot as plt import numpy as np from matplotlib import gridspec +from matplotlib.patches import Patch from PIL import Image from scipy.fft import fft, fftshift from scipy.signal import spectrogram @@ -39,6 +40,76 @@ def set_spines(ax, spines): ax.spines["left"].set_visible(False) +def view_annotations( + recording: Recording, + channel: Optional[int] = 0, + output_path: Optional[str] = "images/annotations.png", + title: Optional[str] = "Annotated Spectrogram", + dpi: Optional[int] = 300, + title_fontsize: Optional[int] = 15, + dark: Optional[bool] = True, +) -> None: + # 1. Setup Plotting Environment + plt.close("all") + if dark: + plt.style.use("dark_background") + else: + plt.style.use("default") + + fig, ax = plt.subplots(figsize=(12, 8)) + + complex_signal = recording.data[channel] + sample_rate, center_frequency, _ = extract_metadata_fields(recording.metadata) + annotations = recording.annotations + + # 2. Setup Color Mapping + palette = ["#2196F3", "#9C27B0", "#64B5F6", "#7B1FA2", "#5C6BC0", "#CE93D8", "#1565C0", "#7C4DFF"] + unique_labels = sorted(list(set(ann.label for ann in annotations if ann.label))) + label_to_color = {label: palette[i % len(palette)] for i, label in enumerate(unique_labels)} + + # 3. Generate Spectrogram + Pxx, freqs, times, im = ax.specgram( + complex_signal, NFFT=256, Fs=sample_rate, Fc=center_frequency, noverlap=128, cmap="twilight" + ) + + # 4. 
Draw Annotations (highest threshold % first so lower % renders on top) + def _threshold_sort_key(ann): + try: + return int(ann.label.rstrip("%")) + except (ValueError, AttributeError): + return 0 + + for annotation in sorted(annotations, key=_threshold_sort_key, reverse=True): + t_start = annotation.sample_start / sample_rate + t_width = annotation.sample_count / sample_rate + f_start = annotation.freq_lower_edge + f_height = annotation.freq_upper_edge - annotation.freq_lower_edge + + ann_color = label_to_color.get(annotation.label, "gray") + + rect = plt.Rectangle( + (t_start, f_start), t_width, f_height, linewidth=1.5, edgecolor=ann_color, facecolor="none", alpha=0.8 + ) + ax.add_patch(rect) + + if unique_labels: + legend_elements = [ + Patch(facecolor=label_to_color[label], alpha=0.3, edgecolor=label_to_color[label], label=label) + for label in unique_labels + ] + ax.legend(handles=legend_elements, loc="upper right", framealpha=0.2) + + ax.set_title(title, fontsize=title_fontsize, pad=20) + ax.set_xlabel("Time (s)", fontsize=12) + ax.set_ylabel("Frequency (MHz)", fontsize=12) + ax.grid(alpha=0.1) + + output_path, _ = set_path(output_path=output_path) + plt.savefig(output_path, dpi=dpi, bbox_inches="tight") + plt.close(fig) + print(f"Professional annotation plot saved to {output_path}") + + def view_channels( recording: Recording, output_path: Optional[str] = "images/signal.png", @@ -209,9 +280,7 @@ def view_sig( ) set_spines(spec_ax, spines) - spec_ax.set_title("Spectrogram", fontsize=subtitle_fontsize) - spec_ax.set_ylabel("Frequency (Hz)") - spec_ax.set_xlabel("Time (s)") + spec_ax.set_title("Spectrogram", loc="center", fontsize=subtitle_fontsize) if iq: iq_ax = plt.subplot(gs[plot_y_indx : plot_y_indx + 2, :]) @@ -295,7 +364,11 @@ def view_sig( set_spines(meta_ax, spines) if logo and os.path.isfile(logo_path): - logo_ax = plt.subplot(gs[plot_y_indx + 2 :, 2]) + # logo_ax = plt.subplot(gs[plot_y_indx:, 2]) + logo_pos = [0.75, 0.05, 0.2, 0.08] + logo_ax = fig.add_axes(logo_pos, anchor="SE", zorder=10) + plot_x_indx = plot_x_indx + 1 + logo_ax.axis("off") try: @@ -314,7 +387,6 @@ def view_sig( hspace=2.5, # Vertical space between subplots ) - # save path handling output_path, _ = set_path(output_path=output_path) plt.savefig(output_path, dpi=dpi) print(f"Saved signal plot to {output_path}") diff --git a/src/ria_toolkit_oss/view/view_signal_simple.py b/src/ria_toolkit_oss/view/view_signal_simple.py index ab56f7d..1b847ab 100644 --- a/src/ria_toolkit_oss/view/view_signal_simple.py +++ b/src/ria_toolkit_oss/view/view_signal_simple.py @@ -3,6 +3,7 @@ from __future__ import annotations import gc +import json from typing import Optional import matplotlib @@ -20,6 +21,52 @@ from ria_toolkit_oss.view.tools import ( ) +def _add_annotations(annotations, compact_mode, show_labels, sample_rate_hz, center_freq_hz, ax2): + if annotations and not compact_mode: + for annotation in annotations: + start_idx = annotation.get("core:sample_start", 0) + length = annotation.get("core:sample_count", 0) + start_time = start_idx / sample_rate_hz + end_time = (start_idx + length) / sample_rate_hz + freq_low = annotation.get("core:freq_lower_edge", center_freq_hz - sample_rate_hz / 4) + freq_high = annotation.get("core:freq_upper_edge", center_freq_hz + sample_rate_hz / 4) + comment = annotation.get("core:comment", "{}") + + try: + comment_data = json.loads(comment) if isinstance(comment, str) else comment + ann_type = comment_data.get("type", "unknown") + if ann_type == "intersection": + color = COLORS["success"] 
+ elif ann_type == "parallel": + color = COLORS["primary"] + elif ann_type == "standalone": + color = COLORS["warning"] + else: + color = COLORS["error"] + except Exception: + color = COLORS["error"] + + rect = plt.Rectangle( + (start_time, freq_low), + end_time - start_time, + freq_high - freq_low, + color=color, + alpha=0.4, + linewidth=2, + ) + ax2.add_patch(rect) + if show_labels: + label = annotation.get("core:label", "Signal") + ax2.text( + start_time, + freq_high, + label, + color=COLORS["light"], + fontsize=10, + bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.7), + ) + + def _get_nfft_size(signal, fast_mode): if len(signal) < 1000: nfft = 128 @@ -138,6 +185,7 @@ def detect_constellation_symbols(signal: np.ndarray, method: str = "differential def view_simple_sig( recording: Recording, + annotations: Optional[list] = None, output_path: Optional[str] = "images/signal.png", saveplot: Optional[bool] = True, fast_mode: Optional[bool] = False, @@ -261,6 +309,15 @@ def view_simple_sig( ax2.set_title("Spectrogram", loc="left", pad=10) + _add_annotations( + annotations=annotations, + compact_mode=compact_mode, + show_labels=show_labels, + sample_rate_hz=sample_rate_hz, + center_freq_hz=center_freq_hz, + ax2=ax2, + ) + if ax_constellation is not None: constellation_samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=50_000, fast_max=20_000) method = "differential" if fast_mode else "combined" @@ -310,7 +367,7 @@ def view_simple_sig( else: plt.tight_layout() if show_title: - plt.subplots_adjust(top=0.90) + plt.subplots_adjust(top=0.92) if saveplot: output_path, extension = set_path(output_path=output_path) diff --git a/src/ria_toolkit_oss_cli/ria_toolkit_oss/annotate.py b/src/ria_toolkit_oss_cli/ria_toolkit_oss/annotate.py new file mode 100644 index 0000000..4a8d6ac --- /dev/null +++ b/src/ria_toolkit_oss_cli/ria_toolkit_oss/annotate.py @@ -0,0 +1,828 @@ +"""Annotate command - Automatic detection and manual annotation management.""" + +import json +from pathlib import Path + +import click + +from ria_toolkit_oss.annotations import ( + annotate_with_cusum, + detect_signals_energy, + split_recording_annotations, + threshold_qualifier, +) +from ria_toolkit_oss.datatypes import Annotation +from ria_toolkit_oss.datatypes.recording import Recording +from ria_toolkit_oss.io import load_recording, to_blue, to_npy, to_sigmf, to_wav +from ria_toolkit_oss_cli.ria_toolkit_oss.common import ( + format_frequency, + format_sample_count, +) + + +def normalize_sigmf_path(filepath): + """Normalize SigMF path to base name without extension.""" + path = Path(filepath) + + # Handle .sigmf-data, .sigmf-meta, or .sigmf + if ".sigmf" in path.suffix: + # Remove the suffix to get base name + return path.with_suffix("") + else: + return path + + +def detect_input_format(filepath): + """Detect file format from extension.""" + path = Path(filepath) + ext = path.suffix.lower() + + if ext in [".sigmf-data", ".sigmf-meta"]: + return "sigmf" + elif path.name.endswith(".sigmf"): + return "sigmf" + elif ext == ".npy": + return "npy" + elif ext == ".wav": + return "wav" + elif ext == ".blue": + return "blue" + else: + raise click.ClickException(f"Unknown format for '{filepath}'. 
Supported: .sigmf, .npy, .wav, .blue") + + +def determine_output_path(input_path, output_path, fmt, quiet, overwrite): + input_path = Path(input_path) + input_is_annotated = input_path.stem.endswith("_annotated") + + if output_path: + target = Path(output_path) + elif overwrite and input_is_annotated: + # Write back in-place only when the input is already an _annotated file + target = input_path + else: + target = input_path.with_name(f"{input_path.stem}_annotated{input_path.suffix}") + + if fmt == "sigmf": + final_path = normalize_sigmf_path(target) + if not quiet: + click.echo(f"Saving SigMF metadata to: {final_path}") + else: + final_path = target + if not quiet: + click.echo(f"Saving to: {final_path}") + + # Always allow writing to _annotated files; guard against overwriting originals + target_is_annotated = final_path.stem.endswith("_annotated") + if final_path.exists() and not target_is_annotated and final_path != input_path: + click.echo(f"Error: {final_path} is not an annotated file and cannot be overwritten.", err=True) + return None + + return final_path + + +def save_recording_auto(recording, output_path, input_path, quiet=False, overwrite=False): + """Save recording, auto-detecting format from extension. + + For SigMF: Only overwrites metadata file, data file is unchanged + For other formats: Creates _annotated copy by default, unless overwrite=True + """ + input_path = Path(input_path) + fmt = detect_input_format(input_path) + + # Determine output path + output_path = determine_output_path( + input_path=input_path, output_path=output_path, fmt=fmt, quiet=quiet, overwrite=overwrite + ) + + if fmt == "sigmf": + # Normalize path for SigMF + base_path = output_path + stem = base_path.name + parent = base_path.parent + + # For SigMF: only save metadata, copy data if needed + meta_path = parent / f"{stem}.sigmf-meta" + data_path = parent / f"{stem}.sigmf-data" + + # If output is different from input, copy data file + input_base = normalize_sigmf_path(input_path) + if input_base != base_path: + import shutil + + # Construct input data path correctly + # input_base is like /path/to/recording or /path/to/recording.sigmf + # We need /path/to/recording.sigmf-data + if str(input_base).endswith(".sigmf"): + input_data = Path(str(input_base).replace(".sigmf", ".sigmf-data")) + else: + input_data = input_base.parent / f"{input_base.name}.sigmf-data" + if not quiet: + click.echo(f" Copying: {data_path}") + shutil.copy2(input_data, data_path) + + # Always save metadata (this is the whole point) + to_sigmf(recording, filename=stem, path=parent, overwrite=True) + + if not quiet: + click.echo(f" Updated: {meta_path}") + if input_base != base_path: + click.echo(f" Created: {data_path}") + + elif fmt == "npy": + to_npy(recording, filename=output_path.stem, path=output_path.parent, overwrite=True) + if not quiet: + click.echo(f" Created: {output_path}") + elif fmt == "wav": + to_wav(recording, filename=output_path.stem, path=output_path.parent, overwrite=True) + if not quiet: + click.echo(f" Created: {output_path}") + elif fmt == "blue": + to_blue(recording, filename=output_path.stem, path=output_path.parent, overwrite=True) + if not quiet: + click.echo(f" Created: {output_path}") + + +def determine_frequency_bounds(recording: Recording, freq_lower, freq_upper): + # Handle frequency bounds + if (freq_lower is None) != (freq_upper is None): + raise click.ClickException("Must specify both --freq-lower and --freq-upper, or neither") + + if freq_lower is None: + # Default to full bandwidth + sample_rate 
= recording.metadata.get("sample_rate", 1) + center_freq = recording.metadata.get("center_frequency", 0) + freq_lower = center_freq - (sample_rate / 2) + freq_upper = center_freq + (sample_rate / 2) + freq_default = True + else: + freq_default = False + if freq_lower >= freq_upper: + raise click.ClickException( + f"Invalid frequency range: lower ({format_frequency(freq_lower)}) " + f"must be < upper ({format_frequency(freq_upper)})" + ) + + return freq_lower, freq_upper, freq_default + + +def get_indices_list(indices, recording: Recording): + if indices: + try: + indices_list = [int(idx.strip()) for idx in indices.split(",")] + # Validate indices + for idx in indices_list: + if idx < 0 or idx >= len(recording.annotations): + raise click.ClickException( + f"Invalid index {idx}. Recording has {len(recording.annotations)} annotation(s)" + ) + except ValueError as e: + raise click.ClickException(f"Invalid indices format. Expected comma-separated integers: {e}") + + return indices_list + else: + return None + + +# ============================================================================ +# Main command group +# ============================================================================ + + +@click.group() +def annotate(): + """Manage and auto-detect annotations on RF recordings. + + \b + MANUAL MANAGEMENT: + list - List all current annotations + add - Manually add a specific annotation + remove - Delete an annotation by its index + clear - Remove all annotations from the recording + + \b + DETECTION & SEPARATION: + energy - Auto-detect using energy-based thresholding + cusum - Auto-detect segments using signal state changes + threshold - Auto-detect samples above magnitude percentage + separate - Auto-detect parallel frequency-offset signals, split into sub-bands + + \b + File Path Handling: + - SigMF files: Pass .sigmf-data, .sigmf-meta, or base name + - Other formats: .npy, .wav, .blue files + + \b + Output Behavior: + - SigMF: Updates .sigmf-meta only (data unchanged), in-place + - Other: Creates _annotated copy unless --overwrite specified + """ + pass + + +# ============================================================================ +# List subcommand +# ============================================================================ + + +@annotate.command() +@click.argument("input", type=click.Path(exists=True)) +@click.option("--verbose", is_flag=True, help="Show detailed annotation info") +def list(input, verbose): + """List all annotations in a recording. 
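+
+    Each annotation is printed with its index, sample range, label, and type; pass --verbose to also
+    show the comment, frequency range, and any detail metadata.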
+ + \b + Examples: + ria annotate list recording.sigmf-data + ria annotate list signal.npy --verbose + """ + try: + recording = load_recording(input) + except Exception as e: + raise click.ClickException(f"Failed to load recording: {e}") + + if len(recording.annotations) == 0: + click.echo(f"No annotations in {Path(input).name}") + return + + click.echo(f"\nAnnotations in {Path(input).name}:") + for i, ann in enumerate(recording.annotations): + # Parse type from comment JSON + try: + comment_data = json.loads(ann.comment) + ann_type = comment_data.get("type", "unknown") + user_comment = comment_data.get("user_comment", "") + except (json.JSONDecodeError, TypeError): + ann_type = "unknown" + user_comment = ann.comment or "" + + # Basic info + freq_range = f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}" + click.echo( + f" [{i}] Samples {format_sample_count(ann.sample_start)}-" + f"{format_sample_count(ann.sample_start + ann.sample_count)}: {ann.label}" + ) + click.echo(f" Type: {ann_type}") + + if verbose: + if user_comment: + click.echo(f" Comment: {user_comment}") + click.echo(f" Frequency: {freq_range}") + if ann.detail: + click.echo(f" Detail: {ann.detail}") + + click.echo(f"\nTotal: {len(recording.annotations)} annotation(s)") + + +# ============================================================================ +# Add subcommand +# ============================================================================ + + +@annotate.command(context_settings={"max_content_width": 200}) +@click.argument("input", type=click.Path(exists=True)) +@click.option("--start", type=int, required=True, help="Start sample index") +@click.option("--count", type=int, required=True, help="Sample count") +@click.option("--label", type=str, required=True, help="Annotation label") +@click.option("--freq-lower", type=float, help="Lower frequency edge (Hz)") +@click.option("--freq-upper", type=float, help="Upper frequency edge (Hz)") +@click.option("--comment", type=str, help="Human-readable comment") +@click.option( + "--type", + "annotation_type", + type=click.Choice(["standalone", "parallel", "intersection"]), + default="standalone", + help="Annotation type", +) +@click.option("--output", "-o", type=click.Path(), help="Output file path") +@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)") +@click.option("--quiet", is_flag=True, help="Quiet mode") +def add(input, start, count, label, freq_lower, freq_upper, comment, annotation_type, output, overwrite, quiet): + """Add a manual annotation. 
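+
+    The annotation covers --count samples starting at --start. If --freq-lower and --freq-upper are
+    omitted, the bounding box defaults to the recording's full bandwidth around its center frequency.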
+ + \b + Examples: + ria annotate add file.npy --start 1000 --count 500 --label wifi + ria annotate add signal.sigmf-data --start 0 --count 1000 --label burst --comment "Strong signal" + """ + try: + recording = load_recording(input) + if not quiet: + click.echo(f"Loaded: {input}") + except Exception as e: + raise click.ClickException(f"Failed to load recording: {e}") + + # Validate sample range + n_samples = len(recording.data[0]) + if start < 0: + raise click.ClickException(f"--start must be >= 0, got {start}") + if count <= 0: + raise click.ClickException(f"--count must be > 0, got {count}") + if start + count > n_samples: + raise click.ClickException( + f"Invalid annotation range:\n" + f" Start: {start:,}\n" + f" Count: {count:,}\n" + f" End: {start + count:,}\n" + f"Recording only has {n_samples:,} samples" + ) + + # Handle frequency bounds + freq_lower, freq_upper, freq_default = determine_frequency_bounds( + recording=recording, freq_lower=freq_lower, freq_upper=freq_upper + ) + + # Build comment JSON + comment_data = {"type": annotation_type} + if comment: + comment_data["user_comment"] = comment + + # Create annotation + ann = Annotation( + sample_start=start, + sample_count=count, + freq_lower_edge=freq_lower, + freq_upper_edge=freq_upper, + label=label, + comment=json.dumps(comment_data), + detail={}, + ) + + recording._annotations.append(ann) + + if not quiet: + click.echo("\nAdding annotation:") + click.echo(f" Start: {format_sample_count(start)}") + click.echo(f" Count: {format_sample_count(count)} samples") + freq_str = ( + "full bandwidth" if freq_default else f"{format_frequency(freq_lower)} - {format_frequency(freq_upper)}" + ) + click.echo(f" Frequency: {freq_str}") + click.echo(f" Label: {label}") + click.echo(f" Type: {annotation_type}") + if comment: + click.echo(f" Comment: {comment}") + + try: + save_recording_auto(recording, output, input, quiet, overwrite) + if not quiet: + click.echo(" ✓ Saved") + except Exception as e: + raise click.ClickException(f"Failed to save: {e}") + + +# ============================================================================ +# Remove subcommand +# ============================================================================ + + +@annotate.command(context_settings={"max_content_width": 200}) +@click.argument("input", type=click.Path(exists=True)) +@click.argument("index", type=int) +@click.option("--output", "-o", type=click.Path(), help="Output file path") +@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)") +@click.option("--quiet", is_flag=True, help="Quiet mode") +def remove(input, index, output, overwrite, quiet): + """Remove annotation by index. + + Use 'ria annotate list' to see annotation indices. 
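+    Remaining annotations shift down by one index after a removal, so re-run 'ria annotate list'
+    before removing another annotation.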
+
+    \b
+    Examples:
+        ria annotate remove signal.sigmf-data 2
+        ria annotate remove file.npy 0
+    """
+    try:
+        recording = load_recording(input)
+        if not quiet:
+            click.echo(f"Loaded: {input}")
+    except Exception as e:
+        raise click.ClickException(f"Failed to load recording: {e}")
+
+    if index < 0 or index >= len(recording.annotations):
+        raise click.ClickException(
+            f"Cannot remove annotation at index {index}\n"
+            f"Recording has {len(recording.annotations)} annotation(s) (indices 0-{len(recording.annotations)-1})"
+        )
+
+    removed_ann = recording.annotations[index]
+    recording._annotations.pop(index)
+
+    if not quiet:
+        click.echo(f"\nRemoving annotation [{index}]:")
+        click.echo(
+            f" Removed: samples {format_sample_count(removed_ann.sample_start)}-"
+            f"{format_sample_count(removed_ann.sample_start + removed_ann.sample_count)} ({removed_ann.label})"
+        )
+
+    try:
+        save_recording_auto(recording, output_path=input, input_path=input, quiet=quiet, overwrite=True)
+        if not quiet:
+            click.echo(" ✓ Saved")
+    except Exception as e:
+        raise click.ClickException(f"Failed to save: {e}")
+
+
+# ============================================================================
+# Clear subcommand
+# ============================================================================
+
+
+@annotate.command(context_settings={"max_content_width": 175})
+@click.argument("input", type=click.Path(exists=True))
+@click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
+@click.option("--force", is_flag=True, help="Skip confirmation")
+@click.option("--quiet", is_flag=True, help="Quiet mode")
+def clear(input, output, overwrite, force, quiet):
+    """Clear all annotations.
+
+    \b
+    Examples:
+        ria annotate clear signal.sigmf-data
+        ria annotate clear file.npy --force
+    """
+    try:
+        recording = load_recording(input)
+        if not quiet:
+            click.echo(f"Loaded: {input}")
+    except Exception as e:
+        raise click.ClickException(f"Failed to load recording: {e}")
+
+    count_before = len(recording.annotations)
+
+    if count_before == 0:
+        if not quiet:
+            click.echo("No annotations to clear")
+        return
+
+    # Confirm unless --force
+    if not force and not quiet:
+        click.echo(f"\nWarning: This will remove all {count_before} annotation(s)")
+        click.confirm("Continue?", abort=True)
+
+    recording._annotations = []
+
+    if not quiet:
+        click.echo(f"\nCleared {count_before} annotation(s)")
+
+    try:
+        save_recording_auto(recording, output_path=input, input_path=input, quiet=quiet, overwrite=True)
+        if not quiet:
+            click.echo(" ✓ Saved")
+    except Exception as e:
+        raise click.ClickException(f"Failed to save: {e}")
+
+
+# ============================================================================
+# Energy detection subcommand
+# ============================================================================
+
+
+@annotate.command(context_settings={"max_content_width": 200})
+@click.argument("input", type=click.Path(exists=True))
+@click.option("--label", type=str, default="signal", help="Annotation label")
+@click.option("--threshold", type=float, default=1.2, help="Threshold multiplier above noise floor")
+@click.option("--segments", type=int, default=10, help="Number of segments for noise estimation")
+@click.option("--window-size", type=int, default=200, help="Smoothing window size")
+@click.option("--min-distance", type=int, default=5000, help="Min distance between detections")
+@click.option(
+    "--freq-method",
+    
type=click.Choice(["nbw", "obw", "full-detected", "full-bandwidth"]), + default="nbw", + help="Frequency bounding method", +) +@click.option("--nfft", type=int, default=None, help="FFT size for frequency calculation") +@click.option("--obw-power", type=float, default=0.99, help="Power percentage for OBW/NBW (0.98-0.9999)") +@click.option( + "--type", + "annotation_type", + type=click.Choice(["standalone", "parallel", "intersection"]), + default="standalone", + help="Annotation type", +) +@click.option("--output", "-o", type=click.Path(), help="Output file path") +@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)") +@click.option("--quiet", is_flag=True, help="Quiet mode") +def energy( + input, + label, + threshold, + segments, + window_size, + min_distance, + freq_method, + nfft, + obw_power, + annotation_type, + output, + overwrite, + quiet, +): + """Auto-detect signals using energy-based method. + + Detects bursts based on energy above noise floor. Best for bursty signals + and intermittent transmissions. + + \b + Frequency Bounding Methods: + nbw - Nominal bandwidth (default, best for real signals) + obw - Occupied bandwidth (more conservative, includes sidelobes) + full-detected - Lowest to highest spectral component + full-bandwidth - Entire Nyquist span + + \b + Examples: + ria annotate energy capture.sigmf-data --label burst + ria annotate energy signal.npy --threshold 1.5 --min-distance 10000 + ria annotate energy signal.sigmf-data --freq-method obw + ria annotate energy signal.sigmf-data --freq-method full-detected + + """ + try: + recording = load_recording(input) + if not quiet: + click.echo(f"Loaded: {input}") + except Exception as e: + raise click.ClickException(f"Failed to load recording: {e}") + + if not quiet: + click.echo("\nDetecting signals using energy-based method...") + click.echo(" Time detection:") + click.echo(f" Segments: {segments}") + click.echo(f" Threshold: {threshold}x noise floor") + click.echo(f" Window size: {window_size} samples") + click.echo(f" Min distance: {min_distance} samples") + click.echo(f" Frequency bounds: {freq_method}") + + try: + initial_count = len(recording.annotations) + recording = detect_signals_energy( + recording, + k=segments, + threshold_factor=threshold, + window_size=window_size, + min_distance=min_distance, + label=label, + annotation_type=annotation_type, + freq_method=freq_method, + nfft=nfft, + obw_power=obw_power, + ) + added = len(recording.annotations) - initial_count + + if not quiet: + click.echo(f" ✓ Added {added} annotation(s)") + + save_recording_auto(recording, output, input, quiet, overwrite) + if not quiet: + click.echo(" ✓ Saved") + except Exception as e: + raise click.ClickException(f"Energy detection failed: {e}") + + +# ============================================================================ +# CUSUM detection subcommand +# ============================================================================ + + +@annotate.command() +@click.argument("input", type=click.Path(exists=True)) +@click.option("--label", type=str, default="segment", help="Annotation label") +@click.option("--min-duration", type=float, default=5.0, help="Min duration in ms (prevents over-segmentation)") +@click.option("--window-size", type=int, default=1, help="Smoothing window size") +@click.option("--tolerance", type=int, default=-1, help="Sample tolerance for merging") +@click.option( + "--type", + "annotation_type", + type=click.Choice(["standalone", "parallel", "intersection"]), + 
default="standalone", + help="Annotation type", +) +@click.option("--output", "-o", type=click.Path(), help="Output file path") +@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)") +@click.option("--quiet", is_flag=True, help="Quiet mode") +def cusum(input, label, min_duration, window_size, tolerance, annotation_type, output, overwrite, quiet): + """Auto-detect segments using CUSUM method. + + Detects signal state changes (on/off, amplitude transitions). Best for + segmenting continuous signals. + + IMPORTANT: Always specify --min-duration to prevent excessive segmentation. + + \b + Examples: + ria annotate cusum signal.sigmf-data --min-duration 5.0 + ria annotate cusum data.npy --min-duration 10.0 --label state + """ + try: + recording = load_recording(input) + if not quiet: + click.echo(f"Loaded: {input}") + except Exception as e: + raise click.ClickException(f"Failed to load recording: {e}") + + if not quiet: + click.echo("\nDetecting segments using CUSUM...") + click.echo(f" Min duration: {min_duration} ms") + if window_size != 1: + click.echo(f" Window size: {window_size} samples") + + try: + initial_count = len(recording.annotations) + recording = annotate_with_cusum( + recording, + label=label, + window_size=window_size, + min_duration=min_duration, + tolerance=tolerance, + annotation_type=annotation_type, + ) + added = len(recording.annotations) - initial_count + + if not quiet: + click.echo(f" ✓ Added {added} annotation(s)") + + save_recording_auto(recording, output, input, quiet, overwrite) + if not quiet: + click.echo(" ✓ Saved") + except Exception as e: + raise click.ClickException(f"CUSUM detection failed: {e}") + + +# ============================================================================ +# Threshold detection subcommand +# ============================================================================ + + +@annotate.command() +@click.argument("input", type=click.Path(exists=True)) +@click.option("--threshold", type=float, required=True, help="Threshold (0.0-1.0, fraction of max magnitude)") +@click.option("--label", type=str, default=None, help="Annotation label") +@click.option( + "--window-size", + type=int, + default=None, + help="Smoothing window size in samples (default: 1ms at recording sample rate)", +) +@click.option( + "--type", + "annotation_type", + type=click.Choice(["standalone", "parallel", "intersection"]), + default="standalone", + help="Annotation type", +) +@click.option("--channel", type=int, default=0, help="Channel index to annotate (default: 0)") +@click.option("--output", "-o", type=click.Path(), help="Output file path") +@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)") +@click.option("--quiet", is_flag=True, help="Quiet mode") +def threshold(input, threshold, label, window_size, annotation_type, channel, output, overwrite, quiet): + """Auto-detect signals using threshold method. + + Detects samples above a percentage of maximum magnitude. Best for simple + power-based detection. 
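+    If --window-size is not given, a smoothing window of about 1 ms at the
+    recording sample rate is used.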
+ + \b + Examples: + ria annotate threshold signal.sigmf-data --threshold 0.7 --label wifi + ria annotate threshold data.npy --threshold 0.5 --window-size 2048 + """ + if not (0.0 <= threshold <= 1.0): + raise click.ClickException(f"--threshold must be between 0.0 and 1.0, got {threshold}") + + try: + recording = load_recording(input) + if not quiet: + click.echo(f"Loaded: {input}") + except Exception as e: + raise click.ClickException(f"Failed to load recording: {e}") + + if not quiet: + click.echo("\nDetecting signals using threshold qualifier...") + click.echo(f" Threshold: {threshold * 100:.1f}% of max magnitude") + click.echo(f" Window size: {'auto (1ms)' if window_size is None else f'{window_size} samples'}") + click.echo(f" Channel: {channel}") + + try: + initial_count = len(recording.annotations) + recording = threshold_qualifier( + recording, + threshold=threshold, + window_size=window_size, + label=label, + annotation_type=annotation_type, + channel=channel, + ) + added = len(recording.annotations) - initial_count + + if not quiet: + click.echo(f" ✓ Added {added} annotation(s)") + + save_recording_auto(recording, output, input, quiet, overwrite) + if not quiet: + click.echo(" ✓ Saved") + except Exception as e: + raise click.ClickException(f"Threshold detection failed: {e}") + + +# ============================================================================ +# Separate subcommand (Phase 2: Parallel signal separation) +# ============================================================================ + + +@annotate.command() +@click.argument("input", type=click.Path(exists=True)) +@click.option("--indices", type=str, help="Comma-separated annotation indices to split (default: all)") +@click.option("--nfft", type=int, default=65536, help="FFT size for spectral analysis") +@click.option("--noise-threshold-db", type=float, help="Noise floor threshold in dB (auto-estimated if not specified)") +@click.option("--min-component-bw", type=float, default=50e3, help="Min component bandwidth in Hz") +@click.option("--output", "-o", type=click.Path(), help="Output file path") +@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)") +@click.option("--quiet", is_flag=True, help="Quiet mode") +@click.option("--verbose", is_flag=True, help="Verbose output (show detected components)") +def separate(input, indices, nfft, noise_threshold_db, min_component_bw, output, overwrite, quiet, verbose): + """ + Auto-detect parallel frequency-offset signals and split into sub-bands. + + Provides methods to detect and separate overlapping frequency-domain signals + that occupy the same time window but different frequency bands. + + Detects multiple frequency components within single annotations and splits + them into separate annotations. Uses spectral peak detection with dual + bandwidth estimation. 
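+    By default every existing annotation is analyzed; use --indices to limit
+    the split to specific annotations.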
+ + \b + Key Features: + - Spectral peak detection for frequency components + - Auto noise floor estimation (or user-specified) + - Dual bandwidth estimation: -3dB primary, cumulative power fallback + - Handles narrowband and wide signals (OFDM) + + \b + Examples: + ria annotate separate capture.sigmf-data + ria annotate separate signal.npy --indices 0,1,2 + ria annotate separate data.sigmf-data --noise-threshold-db -70 + ria annotate separate signal.npy --min-component-bw 100000 + + """ + try: + recording = load_recording(input) + if not quiet: + click.echo(f"Loaded: {input}") + except Exception as e: + raise click.ClickException(f"Failed to load recording: {e}") + + # Parse indices if specified + indices_list = get_indices_list(indices=indices, recording=recording) + + if len(recording.annotations) == 0: + if not quiet: + click.echo("No annotations to split") + return + + if not quiet: + click.echo("\nSplitting annotations by frequency components...") + click.echo(f" Input annotations: {len(recording.annotations)}") + if indices_list: + click.echo(f" Splitting indices: {indices_list}") + click.echo(f" FFT size: {nfft}") + if noise_threshold_db is not None: + click.echo(f" Noise threshold: {noise_threshold_db} dB") + else: + click.echo(" Noise threshold: auto-estimated") + click.echo(f" Min component BW: {format_frequency(min_component_bw)}") + + try: + initial_count = len(recording.annotations) + + recording = split_recording_annotations( + recording, + indices=indices_list, + nfft=nfft, + noise_threshold_db=noise_threshold_db, + min_component_bw=min_component_bw, + ) + + final_count = len(recording.annotations) + added = final_count - initial_count + + if not quiet: + click.echo(f" ✓ Output annotations: {final_count} ({'+' if added >= 0 else ''}{added} change)") + if verbose and added > 0: + click.echo("\n Details:") + for i in range(initial_count, final_count): + ann = recording.annotations[i] + freq_range = f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}" + click.echo( + f" [{i}] samples {format_sample_count(ann.sample_start)}-" + f"{format_sample_count(ann.sample_start + ann.sample_count)}: {freq_range}" + ) + + save_recording_auto(recording, output, input, quiet, overwrite) + if not quiet: + click.echo(" ✓ Saved") + except Exception as e: + raise click.ClickException(f"Spectral separation failed: {e}") diff --git a/src/ria_toolkit_oss_cli/ria_toolkit_oss/commands.py b/src/ria_toolkit_oss_cli/ria_toolkit_oss/commands.py index 6ad4af0..174a5f4 100644 --- a/src/ria_toolkit_oss_cli/ria_toolkit_oss/commands.py +++ b/src/ria_toolkit_oss_cli/ria_toolkit_oss/commands.py @@ -3,6 +3,7 @@ This module contains all the CLI bindings for the ria package. 
""" +from .annotate import annotate from .campaign import campaign from .capture import capture from .combine import combine diff --git a/src/ria_toolkit_oss_cli/ria_toolkit_oss/generate.py b/src/ria_toolkit_oss_cli/ria_toolkit_oss/generate.py index fb6d92c..370d27a 100644 --- a/src/ria_toolkit_oss_cli/ria_toolkit_oss/generate.py +++ b/src/ria_toolkit_oss_cli/ria_toolkit_oss/generate.py @@ -232,8 +232,8 @@ def generate(): \b Examples: - utils synth chirp -b 1e6 -p 0.01 -s 10e6 -o chirp_basic.sigmf - utils synth fsk -M 2 -r 100e3 -s 2e6 -o fsk2_basic.sigmf + ria synth chirp -b 1e6 -p 0.01 -s 10e6 -o chirp_basic.sigmf + ria synth fsk -M 2 -r 100e3 -s 2e6 -o fsk2_basic.sigmf """ pass diff --git a/src/ria_toolkit_oss_cli/ria_toolkit_oss/transform.py b/src/ria_toolkit_oss_cli/ria_toolkit_oss/transform.py index 9d2bade..4d131c1 100644 --- a/src/ria_toolkit_oss_cli/ria_toolkit_oss/transform.py +++ b/src/ria_toolkit_oss_cli/ria_toolkit_oss/transform.py @@ -270,13 +270,13 @@ def transform(): Examples:\n \b # List available augmentations - utils transform augment --list + ria transform augment --list \b # Apply channel swap - utils transform augment channel_swap input.npy + ria transform augment channel_swap input.npy \b # Apply AWGN impairment - utils transform impair awgn input.npy --snr-db 15 + ria transform impair awgn input.npy --snr-db 15 """ pass diff --git a/src/ria_toolkit_oss_cli/ria_toolkit_oss/view.py b/src/ria_toolkit_oss_cli/ria_toolkit_oss/view.py index 8e0b51f..cac8ceb 100644 --- a/src/ria_toolkit_oss_cli/ria_toolkit_oss/view.py +++ b/src/ria_toolkit_oss_cli/ria_toolkit_oss/view.py @@ -7,7 +7,7 @@ from typing import Optional import click from ria_toolkit_oss.io.recording import from_npy, load_recording -from ria_toolkit_oss.view.view_signal import view_channels, view_sig +from ria_toolkit_oss.view.view_signal import view_annotations, view_channels, view_sig from ria_toolkit_oss.view.view_signal_simple import view_simple_sig from .common import echo_progress, echo_verbose, load_yaml_config @@ -34,6 +34,11 @@ VISUALIZATION_TYPES = { "spines", ], }, + "annotations": { + "function": view_annotations, + "description": "Annotation-focused spectrogram view", + "options": ["channel", "dark"], + }, "channels": {"function": view_channels, "description": "Multi-channel IQ and spectrogram view", "options": []}, } @@ -194,7 +199,7 @@ def print_metadata(recording, quiet): @click.option( "--type", "viz_type", - type=click.Choice(list(VISUALIZATION_TYPES.keys())), + type=click.Choice(list(VISUALIZATION_TYPES.keys()) + ["annotate", "annotation"]), default="simple", show_default=True, help="Visualization type", @@ -238,7 +243,7 @@ def print_metadata(recording, quiet): @click.option("--verbose", "-v", is_flag=True, help="Verbose output") @click.option("--quiet", "-q", is_flag=True, help="Suppress output") @click.option("--overwrite", is_flag=True, help="Overwrite existing output file") -def view( +def view( # noqa: C901 input, viz_type, output, @@ -297,6 +302,9 @@ def view( # Legacy NPY file ria view old_capture.npy --legacy --type simple """ + if viz_type in ["annotate", "annotation"]: + viz_type = "annotations" + # Load config file if specified if config: _ = load_yaml_config(config) diff --git a/tests/ria_toolkit_oss_cli/README.md b/tests/ria_toolkit_oss_cli/README.md index 1c4cc8e..06a5258 100644 --- a/tests/ria_toolkit_oss_cli/README.md +++ b/tests/ria_toolkit_oss_cli/README.md @@ -1,6 +1,6 @@ # CLI Tests -Comprehensive test suite for the utils CLI commands. 
+Comprehensive test suite for the ria CLI commands. ## Test Structure @@ -13,25 +13,25 @@ Comprehensive test suite for the utils CLI commands. ### Run all CLI tests: ```bash -poetry run pytest tests/utils_cli/ -v +poetry run pytest tests/ria_toolkit_oss_cli/ -v ``` ### Run specific test file: ```bash -poetry run pytest tests/utils_cli/test_common.py -v -poetry run pytest tests/utils_cli/test_discover.py -v -poetry run pytest tests/utils_cli/test_capture.py -v +poetry run pytest tests/ria_toolkit_oss_cli/test_common.py -v +poetry run pytest tests/ria_toolkit_oss_cli/test_discover.py -v +poetry run pytest tests/ria_toolkit_oss_cli/test_capture.py -v ``` ### Run specific test class or function: ```bash -poetry run pytest tests/utils_cli/test_capture.py::TestCaptureCommand::test_capture_basic -v -poetry run pytest tests/utils_cli/test_common.py::test_parse_frequency -v +poetry run pytest tests/ria_toolkit_oss_cli/test_capture.py::TestCaptureCommand::test_capture_basic -v +poetry run pytest tests/ria_toolkit_oss_cli/test_common.py::test_parse_frequency -v ``` ### Run with coverage: ```bash -poetry run pytest tests/utils_cli/ --cov=utils_cli --cov-report=html +poetry run pytest tests/ria_toolkit_oss_cli/ --cov=utils_cli --cov-report=html ``` ## Test Coverage diff --git a/tests/ria_toolkit_oss_cli/__init__.py b/tests/ria_toolkit_oss_cli/__init__.py index 77c8a64..26d94ee 100644 --- a/tests/ria_toolkit_oss_cli/__init__.py +++ b/tests/ria_toolkit_oss_cli/__init__.py @@ -1 +1 @@ -"""Tests for utils CLI commands.""" +"""Tests for ria CLI commands."""