def remove_contained_boxes(annotations: list[Annotation]) -> list[Annotation]:
    """
    Remove all annotations (bounding boxes) that are entirely contained within other boxes in the list.

    Containment is *strict* (see :func:`is_annotation_contained`), so exact
    duplicates are kept: neither duplicate strictly contains the other.

    :param annotations: A list of Annotation objects.
    :type annotations: list[Annotation]

    :returns: A new list holding only the boxes that are not strictly contained
        in any other box. The input list is not modified.
    :rtype: list[Annotation]
    """
    return [
        box
        for i, box in enumerate(annotations)
        if not any(
            i != j and is_annotation_contained(box, other)
            for j, other in enumerate(annotations)
        )
    ]


def is_annotation_contained(inner: Annotation, outer: Annotation) -> bool:
    """
    Check if an annotation box is strictly contained within another annotation bounding box.

    Both the time extent (``sample_start .. sample_start + sample_count``) and
    the frequency extent (``freq_lower_edge .. freq_upper_edge``) of ``inner``
    must lie strictly inside those of ``outer``; boxes that merely share an
    edge are not considered contained.

    :param inner: The inner box.
    :type inner: Annotation
    :param outer: The outer box.
    :type outer: Annotation

    :returns: True if inner is strictly within outer, False otherwise.
    :rtype: bool
    """
    inner_sample_stop = inner.sample_start + inner.sample_count
    outer_sample_stop = outer.sample_start + outer.sample_count

    if inner.sample_start > outer.sample_start and inner_sample_stop < outer_sample_stop:
        if inner.freq_lower_edge > outer.freq_lower_edge and inner.freq_upper_edge < outer.freq_upper_edge:
            return True

    return False


def merge_annotations(annotations: list[Annotation], overlap_threshold) -> list[Annotation]:
    """
    Merge overlapping annotations into single boxes.

    Not implemented yet; see the module-level TODO about how to transfer
    labels when merging.

    :raises NotImplementedError: always, until implemented.
    """
    raise NotImplementedError
def annotate_with_cusum(
    recording: Recording,
    label: str = "segment",
    window_size: int = 1,
    min_duration: Optional[float] = None,
    tolerance: Optional[int] = None,
    annotation_type: str = "standalone",
):
    """
    Add annotations that divide the recording into distinct time segments.

    This algorithm computes the cumulative sum of the sample magnitudes and
    determines break points in the signal.

    This tool can be used to find points where a signal turns on or off, or
    changes between a low and high amplitude.

    :param recording: A ``Recording`` object to annotate.
    :type recording: ``utils.data.Recording``
    :param label: Label for the detected segments.
    :type label: str
    :param window_size: The length (in samples) of the moving average window.
    :type window_size: int
    :param min_duration: The minimum duration (in ms) of a segment.
        The algorithm will not produce annotations shorter than this length.
    :type min_duration: float
    :param tolerance: The minimum length (in samples) of a segment.
        NOTE(review): forwarded to TimeSegmenter but currently unused by the
        underlying CUSUM computation -- confirm intended behavior.
    :type tolerance: int
    :param annotation_type: Annotation type (standalone, parallel, intersection).
    :type annotation_type: str

    :returns: A new ``Recording`` with one annotation per detected time
        segment appended to the existing annotations; the input recording
        is not modified.
    :rtype: ``utils.data.Recording``
    """

    sample_rate = recording.metadata["sample_rate"]
    center_frequency = recording.metadata.get("center_frequency", 0)

    # Segment the first channel in time with the CUSUM-based segmenter.
    time_segmenter = TimeSegmenter(sample_rate, min_duration, window_size, tolerance)
    change_points = time_segmenter.apply(recording.data[0])

    # Bracket the change points with the recording boundaries so each
    # consecutive pair of indices delimits one segment.
    time_segments_indices = np.append(np.insert(change_points, 0, 0), len(recording.data[0]))
    annotations = []
    for i in range(len(time_segments_indices) - 1):
        # Build comment JSON with type metadata
        comment_data = {
            "type": annotation_type,
            "generator": "cusum_annotator",
            "params": {
                "window_size": window_size,
                "min_duration": min_duration,
                "tolerance": tolerance,
            },
        }
        # Estimate the frequency extent of this time segment.
        f_min, f_max = detect_frequency(
            signal=recording.data[0],
            start=time_segments_indices[i],
            stop=time_segments_indices[i + 1],
            sample_rate=sample_rate,
        )

        annotations.append(
            Annotation(
                sample_start=time_segments_indices[i],
                sample_count=time_segments_indices[i + 1] - time_segments_indices[i],
                freq_lower_edge=center_frequency + f_min,
                freq_upper_edge=center_frequency + f_max,
                label=label,
                comment=json.dumps(comment_data),
                detail={"generator": "cusum_annotator"},
            )
        )

    return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + annotations)
+ + Returns: + - cusum (array): Array of the cumulative sum of the given list + - sample_rate (int): __description_ + - change_points (array): Array of the indices at which a change in the CUSUM direction happens. + - min_duration (float): The least acceptable time width of each segment (in ms). Defaults to -1. + """ + + # efficiently calculate the running sum of the signal + # cusum = list(itertools.accumulate((_signal - np.mean(_signal)))) + x = _signal - np.mean(_signal) + cusum = np.cumsum(x) + + # 'diff' computes the differences between the consecutive values, + # then 'sign' determines if it is +ve or -ve. + change_indicators = np.sign(np.diff(cusum)) + change_points = np.where(np.diff(change_indicators))[0] + 1 + + # Limit the change_points + # Reject those whose number of samples < minimum accepted #n of samples in (min duration) ms. + if min_duration is not None and min_duration > 0: + min_samples_wide = int(min_duration * sample_rate / 1000) + segments_lengths = np.diff(change_points) + segments_lengths = np.insert(segments_lengths, 0, change_points[0]) + change_points = change_points[np.where(segments_lengths > min_samples_wide)[0]] + return cusum, change_points + + +def detect_frequency(signal, start, stop, sample_rate): + signal_segment = signal[start:stop] + if len(signal_segment) > 0: + fft_data = np.abs(np.fft.fftshift(np.fft.fft(signal_segment))) + fft_freqs = np.fft.fftshift(np.fft.fftfreq(len(signal_segment), 1 / sample_rate)) + + # Use a spectral threshold to find the 'height' of the orange block + spectral_thresh = np.max(fft_data) * 0.15 + sig_indices = np.where(fft_data > spectral_thresh)[0] + + if len(sig_indices) > 4: + return fft_freqs[sig_indices[0]], fft_freqs[sig_indices[-1]] + else: + return -sample_rate / 4, sample_rate / 4 + else: + return -sample_rate / 4, sample_rate / 4 + + +class TimeSegmenter: + """Time Segmenter class, it creates a segmenter object with certain\ + characteristics to easily split an input signal to segments 
class TimeSegmenter:
    """Splits an input signal into time segments based on the cumulative sum
    of deviations from the signal mean (CUSUM).

    The signal magnitude is first smoothed with a moving-average filter, then
    ``_compute_cusum`` locates the indices where the CUSUM changes direction;
    those indices are the segment boundaries.
    """

    def __init__(
        self, sample_rate: int, min_duration: float = 1, moving_average_window: int = 3, tolerance: int = None
    ):
        """
        :param sample_rate: Sample rate of the signal in Hz.
        :param min_duration: Minimum acceptable segment duration in ms.
            Defaults to 1.
        :param moving_average_window: Length (in samples) of the moving
            average used to smooth the signal magnitude. Defaults to 3.
        :param tolerance: Minimum acceptable segment length in samples.
            Defaults to None. NOTE(review): forwarded to ``_compute_cusum``,
            which currently does not use it -- confirm intended behavior.
        """
        self.sample_rate = sample_rate
        self.min_duration = min_duration
        self.moving_average_window = moving_average_window
        self._moving_avg_filter = self._init_filter()
        self.tolerance = tolerance

    def _init_filter(self) -> np.ndarray:
        """Build the normalized moving-average kernel (coefficients sum to 1)."""
        return np.ones(self.moving_average_window) / self.moving_average_window

    def _apply_filter(self, iqsignal: np.ndarray) -> np.ndarray:
        """Smooth the magnitude of ``iqsignal`` with the moving-average kernel.

        :param iqsignal: Complex (or real) sample array.
        :returns: Smoothed magnitude envelope, same length as the input
            (``mode="same"`` convolution).
        """
        return np.convolve(abs(iqsignal), self._moving_avg_filter, mode="same")

    def _create_segments(self, iq_signal: np.ndarray, change_points: np.ndarray) -> list:
        """Split ``iq_signal`` into sub-arrays at ``change_points``.

        :param iq_signal: Sample array to split.
        :param change_points: Indices at which to split.
        :returns: List of np.ndarray segments (currently unused by apply()).
        """
        return np.split(iq_signal, change_points)

    def apply(self, iq_signal: np.ndarray) -> np.ndarray:
        """Locate segment boundaries in ``iq_signal``.

        :param iq_signal: Complex IQ sample array.
        :returns: Array of indices where a new segment starts.
        """
        smoothed_signal = self._apply_filter(iq_signal)
        _, change_points = _compute_cusum(smoothed_signal, self.sample_rate, self.tolerance, self.min_duration)
        return change_points
def detect_signals_energy(
    recording: Recording,
    k: int = 10,
    threshold_factor: float = 1.2,
    window_size: int = 200,
    min_distance: int = 5000,
    label: str = "signal",
    annotation_type: str = "standalone",
    freq_method: str = "nbw",
    nfft: int = None,
    obw_power: float = 0.99,
) -> Recording:
    """
    Detect signal bursts using energy-based method with adaptive noise floor estimation.

    This algorithm smooths the signal with a moving average filter, estimates the noise
    floor from k segments, applies a threshold to detect regions above noise, and merges
    nearby detections. Detected time boundaries are then assigned frequency bounds based
    on the selected frequency method.

    Time Detection Algorithm:
        1. Smooth signal using moving average (envelope detection)
        2. Divide smoothed signal into k segments
        3. Estimate noise floor as median of segment mean powers
        4. Detect regions where power exceeds threshold_factor * noise_floor
           (with hysteresis: a detection ends once power drops below 80% of
           the entry threshold)
        5. Merge regions closer than min_distance samples

    Frequency Bounding (freq_method):
        - 'nbw': Nominal bandwidth (OBW + center frequency) - DEFAULT
        - 'obw': Occupied bandwidth (obw_power fraction of power, includes sidelobes)
        - 'full-detected': Lowest to highest spectral component
        - 'full-bandwidth': Entire Nyquist span (center_freq +/- sample_rate/2)

    :param recording: Recording to analyze
    :type recording: Recording
    :param k: Number of segments for noise floor estimation (default: 10)
    :type k: int
    :param threshold_factor: Threshold multiplier above noise floor (typical: 1.2-2.0, default: 1.2)
    :type threshold_factor: float
    :param window_size: Moving average window size in samples (default: 200)
    :type window_size: int
    :param min_distance: Minimum distance between separate signals in samples (default: 5000)
    :type min_distance: int
    :param label: Label for detected annotations (default: "signal")
    :type label: str
    :param annotation_type: Annotation type (standalone, parallel, intersection, default: standalone)
    :type annotation_type: str
    :param freq_method: How to calculate frequency bounds (default: 'nbw')
    :type freq_method: str
    :param nfft: FFT size for frequency calculations (default: None)
    :type nfft: int
    :param obw_power: Fraction of power for OBW (default: 0.99 = 99%)
    :type obw_power: float

    :returns: New Recording with added annotations
    :rtype: Recording

    :raises ValueError: If ``freq_method`` is not a supported method.
    """
    # Validate the frequency method up front so a bad argument fails fast,
    # before any signal processing work is done.
    valid_freq_methods = ["nbw", "obw", "full-detected", "full-bandwidth"]
    if freq_method not in valid_freq_methods:
        raise ValueError(f"Invalid freq_method '{freq_method}'. " f"Must be one of: {', '.join(valid_freq_methods)}")

    # Extract signal data (use first channel only)
    signal = recording.data[0]

    # Smoothed instantaneous power (moving-average envelope detection).
    kernel = np.ones(window_size) / window_size
    smoothed_power = filtfilt(kernel, [1], np.abs(signal) ** 2)

    # Estimate noise floor using segment-based median (robust to signal presence)
    segments = np.array_split(smoothed_power, k)
    noise_floor = np.median([np.mean(s) for s in segments])

    # Hysteresis thresholds: enter above the threshold, leave at 80% of it.
    # (renamed from `exit` to avoid shadowing the builtin)
    enter = noise_floor * threshold_factor
    exit_level = enter * 0.8
    boundaries = []
    start = None
    active = False

    for i, p in enumerate(smoothed_power):
        if not active and p > enter:
            start = i
            active = True
        elif active and p < exit_level:
            boundaries.append((start, i - start))
            active = False

    # Close out a detection still active at the end of the recording.
    if active:
        boundaries.append((start, len(smoothed_power) - start))

    # Merge boundaries that are closer than min_distance
    merged_boundaries = []
    if boundaries:
        start, length = boundaries[0]
        for next_start, next_length in boundaries[1:]:
            if next_start - (start + length) < min_distance:
                # Merge with current boundary
                length = next_start + next_length - start
            else:
                # Save current and start new boundary
                merged_boundaries.append((start, length))
                start, length = next_start, next_length
        # Add final boundary
        merged_boundaries.append((start, length))

    # Create annotations from detected boundaries
    sample_rate = recording.metadata["sample_rate"]
    center_frequency = recording.metadata.get("center_frequency", 0)

    annotations = []
    for start_sample, sample_count in merged_boundaries:
        # Calculate frequency bounds based on method
        freq_lower, freq_upper = calculate_frequency_bounds(
            freq_method, center_frequency, sample_rate, nfft, signal, start_sample, sample_count, obw_power
        )
        # Build comment JSON with type metadata
        comment_data = {
            "type": annotation_type,
            "generator": "energy_detector",
            "freq_method": freq_method,
            "params": {
                "threshold_factor": threshold_factor,
                "window_size": window_size,
                "noise_floor": float(noise_floor),
                "threshold": float(enter),
            },
        }

        anno = Annotation(
            sample_start=start_sample,
            sample_count=sample_count,
            freq_lower_edge=freq_lower,
            freq_upper_edge=freq_upper,
            label=label,
            comment=json.dumps(comment_data),
            detail={"generator": "energy_detector", "freq_method": freq_method},
        )
        annotations.append(anno)

    return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + annotations)


def calculate_occupied_bandwidth(
    signal: np.ndarray,
    sampling_rate: float,
    nfft: int = None,
    power_percentage: float = 0.99,
):
    """
    Calculate occupied bandwidth (OBW) following ITU-R SM.328.

    Finds the frequency band containing ``power_percentage`` of the total
    signal power, splitting the excluded power equally between the two tails.

    :param signal: Complex IQ signal samples.
    :param sampling_rate: Sample rate in Hz.
    :param nfft: FFT size; defaults to the larger of 65536 and the next lower
        power of two of the signal length.
    :param power_percentage: Fraction of total power to contain (default 0.99).

    :returns: Tuple ``(obw_hz, lower_freq_hz, upper_freq_hz)`` with baseband
        frequencies relative to 0 Hz.
    """
    if nfft is None:
        nfft = max(65536, 2 ** int(np.floor(np.log2(len(signal)))))

    # Blackman window to suppress spectral leakage before the FFT.
    window = np.blackman(len(signal))
    spec = np.fft.fftshift(np.fft.fft(signal * window, n=nfft))

    psd = np.abs(spec) ** 2
    psd = psd / psd.sum()  # normalize to a probability distribution

    freqs = np.fft.fftshift(np.fft.fftfreq(nfft, 1 / sampling_rate))

    cdf = np.cumsum(psd)

    # Split the excluded power equally between the two tails.
    tail = (1 - power_percentage) / 2

    lower_idx = np.searchsorted(cdf, tail)
    upper_idx = np.searchsorted(cdf, 1 - tail)

    return freqs[upper_idx] - freqs[lower_idx], freqs[lower_idx], freqs[upper_idx]
def calculate_nominal_bandwidth(
    signal: np.ndarray,
    sampling_rate: float,
    nfft: int = None,
    power_percentage: float = 0.99,
) -> Tuple[float, float, float]:
    """
    Calculate nominal bandwidth edges and center frequency.

    Nominal bandwidth (NBW) is the occupied bandwidth along with the center
    frequency of the signal's spectral occupancy. Useful for characterizing
    signals with unknown or drifting center frequencies.

    :param signal: Complex IQ signal samples
    :type signal: np.ndarray
    :param sampling_rate: Sample rate in Hz
    :type sampling_rate: float
    :param nfft: FFT size
    :type nfft: int
    :param power_percentage: Fraction of power to contain
    :type power_percentage: float

    :returns: Tuple of (lower_freq_hz, upper_freq_hz, center_frequency_hz).
        Note: a 3-tuple -- the original docstring advertised a 2-tuple, but
        callers unpack ``lower, upper, center``.
    :rtype: Tuple[float, float, float]

    **Example**::

        >>> from utils.annotations import calculate_nominal_bandwidth
        >>> lower, upper, center = calculate_nominal_bandwidth(signal, sampling_rate=10e6)
        >>> print(f"NBW: {(upper - lower)/1e6:.3f} MHz, Center: {center/1e6:.3f} MHz")
    """
    bw, lower_freq, upper_freq = calculate_occupied_bandwidth(signal, sampling_rate, nfft, power_percentage)

    # Center frequency is the midpoint of the occupied band.
    center_freq = (lower_freq + upper_freq) / 2

    return lower_freq, upper_freq, center_freq
def calculate_full_detected_bandwidth(
    signal: np.ndarray,
    sampling_rate: float,
    nfft: int = None,
    start_offset: int = 1000,
) -> Tuple[float, float, float]:
    """
    Calculate frequency range from lowest to highest spectral component.

    Unlike OBW/NBW which define a power-based bandwidth, this calculates
    the absolute frequency span from the lowest non-zero spectral component
    to the highest non-zero component.

    Useful for:
    - Signals with spectral gaps
    - Multiple parallel signals (captures all of them)
    - Understanding total occupied spectrum vs. actual bandwidth

    :param signal: Complex IQ signal samples
    :type signal: np.ndarray
    :param sampling_rate: Sample rate in Hz
    :type sampling_rate: float
    :param nfft: FFT size. If None, defaults to the largest power of two that
        fits in ``len(signal) - start_offset``.
    :type nfft: int
    :param start_offset: Skip samples at start
    :type start_offset: int

    :returns: Tuple of (bandwidth_hz, lower_freq_hz, upper_freq_hz)
    :rtype: Tuple[float, float, float]

    :raises ValueError: If the signal is too short for ``nfft`` plus
        ``start_offset`` samples.

    **Example**::

        >>> # Signal with two components at different frequencies
        >>> bw, f_low, f_high = calculate_full_detected_bandwidth(
        ...     signal, sampling_rate=10e6, nfft=65536
        ... )
        >>> print(f"Full span: {f_low/1e6:.3f} to {f_high/1e6:.3f} MHz")
    """
    # Default nfft: largest power of two fitting the available samples.
    # The original code compared len(signal) < nfft + start_offset with
    # nfft=None, which raised TypeError instead of a useful error.
    if nfft is None:
        available = len(signal) - start_offset
        if available < 2:
            raise ValueError(
                f"Signal too short: need more than {start_offset} samples "
                f"to auto-size nfft, got {len(signal)}."
            )
        nfft = 2 ** int(np.floor(np.log2(available)))

    # Validate input
    if len(signal) < nfft + start_offset:
        raise ValueError(
            f"Signal too short: need {nfft + start_offset} samples, " f"got {len(signal)}. Reduce nfft or start_offset."
        )

    # Extract segment
    signal_segment = signal[start_offset : nfft + start_offset]

    # Compute FFT and power spectral density
    freq_spectrum = np.fft.fft(signal_segment, n=nfft)
    psd = np.abs(freq_spectrum) ** 2

    # Shift to center DC
    psd_shifted = np.fft.fftshift(psd)
    freq_bins = np.fft.fftshift(np.fft.fftfreq(nfft, 1 / sampling_rate))

    # Find noise floor (mean of lowest 10% of bins) and all bins above noise floor
    noise_floor = np.mean(np.sort(psd_shifted)[: int(len(psd_shifted) * 0.1)])
    above_noise = np.where(psd_shifted > noise_floor * 1.5)[0]

    if len(above_noise) == 0:
        # No signal above noise, return zero bandwidth
        return 0.0, 0.0, 0.0

    # Get frequency range of signal components
    lower_idx = above_noise[0]
    upper_idx = above_noise[-1]

    lower_freq = freq_bins[lower_idx]
    upper_freq = freq_bins[upper_idx]

    bandwidth = upper_freq - lower_freq

    return bandwidth, lower_freq, upper_freq
def annotate_with_obw(
    recording: Recording,
    label: str = "signal",
    annotation_type: str = "standalone",
    nfft: int = None,
    power_percentage: float = 0.99,
) -> Recording:
    """
    Create a single annotation spanning the occupied bandwidth of the entire recording.

    The whole recording is analyzed once to find its occupied bandwidth; one
    annotation covering that frequency range and the full time duration is
    appended to the recording's annotations.

    :param recording: Recording to analyze
    :type recording: Recording
    :param label: Annotation label
    :type label: str
    :param annotation_type: Annotation type
    :type annotation_type: str
    :param nfft: FFT size
    :type nfft: int
    :param power_percentage: Power percentage for OBW calculation
    :type power_percentage: float

    :returns: Recording with OBW annotation added
    :rtype: Recording

    **Example**::

        >>> from utils.annotations import annotate_with_obw
        >>> annotated = annotate_with_obw(recording, label="signal_obw")
    """
    iq = recording.data[0]
    fs = recording.metadata["sample_rate"]
    fc = recording.metadata.get("center_frequency", 0)

    # Occupied bandwidth of the full capture, as baseband offsets.
    obw, f_lo_rel, f_hi_rel = calculate_occupied_bandwidth(iq, fs, nfft, power_percentage)

    # Single annotation covering the whole time span, with the baseband
    # offsets shifted to absolute RF frequencies.
    annotation = Annotation(
        sample_start=0,
        sample_count=len(iq),
        freq_lower_edge=fc + f_lo_rel,
        freq_upper_edge=fc + f_hi_rel,
        label=label,
        comment=json.dumps(
            {
                "type": annotation_type,
                "generator": "obw_annotator",
                "obw_hz": float(obw),
                "power_percentage": power_percentage,
                "params": {"nfft": nfft},
            }
        ),
        detail={"generator": "obw_annotator", "obw_hz": float(obw)},
    )

    return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + [annotation])
def calculate_frequency_bounds(
    freq_method, center_frequency, sample_rate, nfft, signal, start_sample, sample_count, obw_power
):
    """
    Compute absolute frequency bounds for a detected time segment.

    :param freq_method: One of 'nbw', 'obw', 'full-detected', 'full-bandwidth'.
    :param center_frequency: RF center frequency in Hz.
    :param sample_rate: Sample rate in Hz.
    :param nfft: FFT size for the spectral methods (None = let the method choose).
    :param signal: Full IQ sample array.
    :param start_sample: Segment start index.
    :param sample_count: Segment length in samples.
    :param obw_power: Power fraction for the OBW/NBW methods.

    :returns: Tuple ``(freq_lower, freq_upper)`` in absolute Hz.
    :raises ValueError: If ``freq_method`` is not recognized (previously this
        fell through and raised UnboundLocalError).
    """
    # Entire Nyquist span; also the fallback whenever spectral analysis is
    # impossible (segment too short) or fails.
    full_span = (center_frequency - sample_rate / 2, center_frequency + sample_rate / 2)

    if freq_method == "full-bandwidth":
        return full_span

    if freq_method not in ("nbw", "obw", "full-detected"):
        raise ValueError(f"Invalid freq_method '{freq_method}'")

    # Extract the segment for frequency analysis (clamped to the signal end).
    segment = signal[start_sample : min(start_sample + sample_count, len(signal))]

    if nfft is not None and len(segment) < nfft:
        # Segment too short for the requested FFT size.
        return full_span

    try:
        if freq_method == "nbw":
            # Nominal bandwidth (OBW + center frequency).
            lower_freq, upper_freq, _ = calculate_nominal_bandwidth(segment, sample_rate, nfft, obw_power)
        elif freq_method == "obw":
            # Occupied bandwidth.
            _, lower_freq, upper_freq = calculate_occupied_bandwidth(segment, sample_rate, nfft, obw_power)
        else:
            # Full detected range (lowest to highest spectral component).
            _, lower_freq, upper_freq = calculate_full_detected_bandwidth(segment, sample_rate, nfft)
    except (ValueError, IndexError):
        # Spectral estimate failed; fall back to the full span.
        return full_span

    return center_frequency + lower_freq, center_frequency + upper_freq
0000000..b75a28f --- /dev/null +++ b/src/ria_toolkit_oss/annotations/parallel_signal_separator.py @@ -0,0 +1,435 @@ +""" +Parallel signal separation for multi-component frequency-offset signals. + +Provides methods to detect and separate overlapping frequency-domain signals +that occupy the same time window but different frequency bands. + +This module implements **spectral peak detection** to identify distinct frequency +components and split single time-domain annotations into frequency-specific +sub-annotations. + +**Key Design Decisions** (per Codex review): + +1. **Complex IQ Support**: Uses `scipy.signal.welch` with `return_onesided=False` + for proper complex signal handling. Window length automatically adapts to + signal length via `nperseg=min(nfft, len(signal))` to handle bursts >> from utils.annotations import find_spectral_components + >>> # Detect the two distinct channels (returns relative frequencies) + >>> components = find_spectral_components(signal, sampling_rate=20e6) + >>> print(f"Found {len(components)} components") + Found 2 components + +The module is designed to work with detected time-domain annotations, +allowing splitting of overlapping signals into separate training samples. +""" + +import json +from typing import List, Optional, Tuple + +import numpy as np +from scipy import ndimage +from scipy import signal as scipy_signal + +from utils.data import Annotation, Recording + + +def find_spectral_components( + signal_data: np.ndarray, + sampling_rate: float, + nfft: int = 65536, + noise_threshold_db: Optional[float] = None, + min_component_bw: float = 50e3, + time_percentile: float = 70.0, +) -> List[Tuple[float, float, float]]: + """ + Find distinct frequency components using spectral peak detection. + + Identifies separate frequency components in a signal by analyzing the power + spectral density and finding peaks corresponding to distinct signals. This is + useful for separating parallel signals that occupy different frequency bands. 
+ + **Frequency Representation**: Returns frequencies in **baseband/relative** Hz + (centered at 0). To get absolute RF frequencies, add center_frequency_hz from + recording metadata to all returned values. + + Algorithm: + 1. Compute power spectral density using Welch (properly handles complex IQ) + 2. Auto-estimate noise floor from data if not specified + 3. Smooth PSD to reduce spurious peaks + 4. Find local maxima above noise floor + 5. Estimate bandwidth per peak using -3dB (fallback: cumulative power) + 6. Filter components below minimum bandwidth threshold + + :param signal_data: Complex IQ signal samples (np.complex64/128) + :type signal_data: np.ndarray + :param sampling_rate: Sample rate in Hz + :type sampling_rate: float + :param nfft: FFT size / window length for Welch. Automatically capped at + signal length to handle bursts (default: 65536) + :type nfft: int + :param noise_threshold_db: Minimum SNR threshold in dB. If None (default), + auto-estimates as np.percentile(psd_db, 10). + Adapt this across hardware (Pluto: ~-100, ThinkRF: ~-60). + :type noise_threshold_db: Optional[float] + :param min_component_bw: Minimum component bandwidth in Hz (default: 50 kHz) + :type min_component_bw: float + :param power_threshold: Cumulative power threshold for fallback bandwidth + estimation (default: 0.99 = 99% power, like OBW) + :type power_threshold: float + + :returns: List of (center_freq_hz, lower_freq_hz, upper_freq_hz) tuples. + **All frequencies are relative (baseband, 0-centered).** + Add recording metadata['center_frequency'] to get absolute RF frequencies. 
+ :rtype: List[Tuple[float, float, float]] + + :raises ValueError: If signal has fewer than 256 samples + + **Example**:: + + >>> from utils.io import load_recording + >>> from utils.annotations import find_spectral_components + >>> recording = load_recording("capture.sigmf") + >>> segment = recording.data[0][start:end] + >>> # Components in relative (baseband) frequency + >>> components = find_spectral_components(segment, sampling_rate=20e6) + >>> for center_rel, lower_rel, upper_rel in components: + ... # Convert to absolute RF frequency + ... center_abs = recording.metadata['center_frequency'] + center_rel + ... print(f"Component @ {center_abs/1e9:.3f} GHz") + """ + # Validate input + min_samples = 256 + if len(signal_data) < min_samples: + raise ValueError(f"Signal too short: need at least {min_samples} samples, " f"got {len(signal_data)}.") + + # Compute PSD using Welch method for complex IQ signals + # CRITICAL: return_onesided=False for proper complex signal handling + nperseg = min(nfft, len(signal_data)) + noverlap = nperseg // 2 + + # --- STFT --- + freqs, times, Zxx = scipy_signal.stft( + signal_data, + fs=sampling_rate, + window="blackman", + nperseg=nperseg, + noverlap=noverlap, + return_onesided=False, + boundary=None, + ) + + # Shift zero freq to center + Zxx = np.fft.fftshift(Zxx, axes=0) + freqs = np.fft.fftshift(freqs) + + # Power spectrogram + power = np.abs(Zxx) ** 2 + power_db = 10 * np.log10(power + 1e-12) + + # --- Aggregate across time robustly --- + # Using percentile instead of mean prevents short signals from being diluted + freq_profile_db = np.percentile(power_db, time_percentile, axis=1) + + # --- Noise floor estimation --- + if noise_threshold_db is None: + noise_threshold_db = np.percentile(freq_profile_db, 20) + + threshold = noise_threshold_db + 3 # 3 dB above noise floor + + # --- Smooth lightly (avoid merging nearby signals) --- + freq_profile_db = ndimage.gaussian_filter1d(freq_profile_db, sigma=1.5) + + # --- Binary mask of 
significant frequencies --- + mask = freq_profile_db > threshold + + # --- Find contiguous frequency regions --- + labeled, num_features = ndimage.label(mask) + + components = [] + + for region_label in range(1, num_features + 1): + region_indices = np.where(labeled == region_label)[0] + + if len(region_indices) == 0: + continue + + lower_idx = region_indices[0] + upper_idx = region_indices[-1] + + lower_freq = freqs[lower_idx] + upper_freq = freqs[upper_idx] + bw = upper_freq - lower_freq + + if bw < min_component_bw: + continue + + center_freq = (lower_freq + upper_freq) / 2 + components.append((center_freq, lower_freq, upper_freq)) + + return components + + +def split_annotation_by_components( + annotation: Annotation, + signal: np.ndarray, + sampling_rate: float, + center_frequency_hz: float = 0.0, + nfft: int = 65536, + noise_threshold_db: Optional[float] = None, + min_component_bw: float = 50e3, +) -> List[Annotation]: + """ + Split an annotation into multiple annotations by detected frequency components. + + Takes an existing annotation spanning multiple frequency components and + analyzes the frequency content to create separate sub-annotations for + each distinct frequency component. + + **Use case**: Energy detection found a time window with 2-3 parallel WiFi + channels. This function splits it into separate annotations per channel. + + **Frequency Handling**: `find_spectral_components` returns relative (baseband) + frequencies. This function adds `center_frequency_hz` to convert to absolute + RF frequencies for SigMF annotation bounds. This ensures correct frequency + context across baseband and RF domains. 
+ + :param annotation: Original annotation to split + :type annotation: Annotation + :param signal: Full signal array (complex IQ) + :type signal: np.ndarray + :param sampling_rate: Sample rate in Hz + :type sampling_rate: float + :param center_frequency_hz: RF center frequency to add to relative frequencies + from peak detection (default: 0.0 = baseband) + :type center_frequency_hz: float + :param nfft: FFT size for analysis (default: 65536, auto-capped at signal length) + :type nfft: int + :param noise_threshold_db: Noise floor threshold in dB. If None (default), + auto-estimates from data. + :type noise_threshold_db: Optional[float] + :param min_component_bw: Minimum component bandwidth in Hz (default: 50 kHz) + :type min_component_bw: float + + :returns: List of new annotations (one per detected component). + Returns empty list if no components found or segment too short. + :rtype: List[Annotation] + + **Example**:: + + >>> from utils.io import load_recording + >>> from utils.annotations import split_annotation_by_components + >>> recording = load_recording("capture.sigmf") + >>> # Original annotation spans multiple channels + >>> original = recording.annotations[0] + >>> # Split using RF center frequency from metadata + >>> components = split_annotation_by_components( + ... original, + ... recording.data[0], + ... recording.metadata['sample_rate'], + ... center_frequency_hz=recording.metadata.get('center_frequency', 0.0) + ... ) + >>> print(f"Split into {len(components)} components") + Split into 2 components + + **Algorithm**: + 1. Extract segment corresponding to annotation time bounds + 2. Find frequency components in that segment (returns relative frequencies) + 3. Add center_frequency_hz to get absolute RF frequencies + 4. Create new annotation for each component + 5. Preserve original metadata (label, type, etc.) + 6. 
Add component info to comment JSON + + **Notes**: + - Original annotation is not modified + - Returns empty list if segment too short (<256 samples) + - Segments Recording: + """ + Split multiple annotations in a recording by frequency components. + + Processes specified annotations (or all if indices=None), replacing each + with its frequency-separated components. Uses RF center_frequency from + recording metadata for proper absolute frequency conversion. + + :param recording: Recording to process + :type recording: Recording + :param indices: Annotation indices to split (None = all, default: None). + Use indices=[] to skip splitting (returns unchanged recording). + :type indices: Optional[List[int]] + :param nfft: FFT size for spectral analysis (default: 65536, + auto-capped at signal segment length) + :type nfft: int + :param noise_threshold_db: Noise floor threshold in dB. If None (default), + auto-estimates from each segment. + :type noise_threshold_db: Optional[float] + :param min_component_bw: Minimum component bandwidth in Hz (default: 50 kHz). + Components narrower than this are filtered out. + :type min_component_bw: float + + :returns: New Recording with split annotations + :rtype: Recording + + **Example**:: + + >>> from utils.io import load_recording + >>> from utils.annotations import split_recording_annotations + >>> recording = load_recording("capture.sigmf") + >>> # Split all annotations + >>> split_rec = split_recording_annotations(recording) + >>> print(f"Original: {len(recording.annotations)} annotations") + >>> print(f"Split: {len(split_rec.annotations)} annotations") + Original: 5 annotations + Split: 9 annotations + + **Algorithm**: + 1. For each annotation in indices (or all if None): + 2. Call split_annotation_by_components with RF center_frequency + 3. If components found, replace annotation with components + 4. If no components found, keep original annotation + 5. 
def qualify_slice_from_annotations(recording: "Recording", slice_length: int) -> list:
    """
    Slice a recording into many smaller recordings, discarding any slice
    that no annotation overlaps.

    Used together with an annotation based qualifier.

    :param recording: The recording to slice.
    :type recording: Recording
    :param slice_length: The length in samples of a slice.
    :type slice_length: int

    :returns: One Recording per slice that overlaps at least one annotation.
    :rtype: list[Recording]
    """
    if len(recording.annotations) == 0:
        print("Warning, no annotations.")

    num_samples = len(recording.data[0])

    # Boolean mask marking every sample covered by at least one annotation.
    annotation_mask = np.zeros(num_samples, dtype=bool)
    for annotation in recording.annotations:
        annotation_mask[annotation.sample_start : annotation.sample_start + annotation.sample_count] = True

    output_recordings = []

    # BUG FIX: the previous version iterated `num_samples // slice_length - 1`
    # times, silently dropping the last full slice of the recording.
    for i in range(num_samples // slice_length):
        start_index = slice_length * i
        end_index = start_index + slice_length

        # np.any over a boolean slice is a single vectorized pass,
        # unlike the former Python-level `1 in mask[...]` membership test.
        if np.any(annotation_mask[start_index:end_index]):
            sl = recording.data[:, start_index:end_index]
            output_recordings.append(Recording(data=sl, metadata=recording.metadata))

    return output_recordings
def frequency_shift_iq_samples(iq_samples, sample_rate, shift_frequency):
    """Mix *iq_samples* with a complex tone, shifting the spectrum by *shift_frequency* Hz.

    :param iq_samples: Complex (or real) sample array.
    :param sample_rate: Sample rate in Hz.
    :param shift_frequency: Frequency offset in Hz; positive shifts the spectrum up.

    :returns: The frequency-shifted sample array.
    """
    # Sample instants in seconds: n / fs for n = 0 .. N-1.
    t = np.arange(len(iq_samples)) / sample_rate

    # Multiplying by exp(j*2*pi*f*t) rotates each sample in the complex plane,
    # which shifts the whole spectrum by f Hz.
    return iq_samples * np.exp(2j * np.pi * shift_frequency * t)
def design_complex_lowpass_filter(cutoff_frequency, sample_rate, order=5):
    """
    Design a Butterworth lowpass filter for complex baseband signals.

    scipy's ``butter`` normalizes ``Wn`` to the real Nyquist rate
    (``sample_rate / 2``), so ``Wn = cutoff_frequency / sample_rate`` gives a
    one-sided cutoff of ``cutoff_frequency / 2``. Applied to complex IQ data,
    the filter therefore retains a total two-sided bandwidth of
    ``cutoff_frequency`` Hz centered at 0 Hz.

    :param cutoff_frequency: Total (two-sided) bandwidth to retain, in Hz.
        Must satisfy ``0 < cutoff_frequency < sample_rate``.
    :param sample_rate: Sample rate in Hz.
    :param order: Filter order (default 5).

    :returns: Numerator and denominator coefficients ``(b, a)``.
    :raises ValueError: If ``cutoff_frequency`` is outside ``(0, sample_rate)``.
    """
    # For complex sampling the usable bandwidth spans the full sample rate.
    nyquist = sample_rate

    # BUG FIX: the previous check allowed cutoff == sample_rate, which produced
    # a normalized Wn of exactly 1.0 and made butter() fail with a confusing
    # internal error instead of this explicit ValueError.
    if cutoff_frequency <= 0 or cutoff_frequency >= nyquist:
        raise ValueError("Cutoff frequency must be between 0 and the Nyquist frequency.")

    # Normalize the cutoff frequency to the Nyquist frequency.
    cutoff_normalized = cutoff_frequency / nyquist

    # Create a Butterworth lowpass filter.
    b, a = butter(order, cutoff_normalized, btype="low")
    return b, a
This prevents high-frequency noise spikes from causing + fragmented annotations and provides a more stable estimate of the + signal's power envelope. + +3. **Spectral Profiling**: Once a temporal segment is isolated, the module + - performs an automated FFT analysis. It identifies the **90% spectral + occupancy** to define the frequency boundaries (`f_min`, `f_max`), + allowing the detector to work on narrowband and wideband signals without + manual frequency tuning. + +4. **Baseband/RF Mapping**: Automatically handles the conversion from + - relative FFT bin frequencies to absolute RF frequencies by referencing + `recording.metadata["center_frequency"]`. + +5. **False Positive Mitigation**: Implements a hard minimum duration check + - (10ms) to ignore transient hardware spikes or noise floor fluctuations + that do not constitute a valid signal burst. + +The module is designed to be the primary "first-pass" detector for pulsed +waveforms (like ADS-B, Lora, or bursty FSK) before passing them to +classification or demodulation stages. +""" + +import json +from typing import Optional + +import numpy as np + +from utils.data import Annotation, Recording + + +def _find_ranges(indices, window_size): + """ + Groups individual indices into continuous temporal ranges. + + Args: + indices: Array of indices where the signal exceeded a threshold. + window_size: Maximum gap allowed between indices to consider them part + of the same range. + + Returns: + A list of (start, stop) tuples representing detected signal segments. + """ + + if len(indices) == 0: + return [] + + ranges = [] + + start = indices[0] + in_range = False + + for i in range(1, len(indices)): + # If the gap between current and previous index is within window_size, + # keep the range alive. + if indices[i] - indices[i - 1] <= window_size: + if not in_range: + # Start a new range + start = indices[i - 1] + in_range = True + else: + # Gap is too large; close the current range if one was active. 
+ if in_range: + ranges.append((start, indices[i - 1])) + in_range = False + + # Ensure the final segment is captured if the loop ends while in_range. + if in_range: + ranges.append((start, indices[-1])) + + return ranges + + +def threshold_qualifier( + recording: Recording, + threshold: float, + window_size: Optional[int] = 1024, + label: Optional[str] = None, + annotation_type: Optional[str] = "standalone", +) -> Recording: + """ + Annotate a recording with bounding boxes for regions above a threshold. + Threshold is defined as a fraction of the maximum sample magnitude. + This algorithm searches for samples above the threshold and combines them into ranges if they + are within window_size of each other. + Detects and annotates signals using energy thresholding and spectral analysis. + + The algorithm follows these steps: + 1. Smooths power data using a moving average. + 2. Identifies 'peak' regions exceeding a high trigger threshold. + 3. Uses hysteresis to expand boundaries until power drops below a lower threshold. + 4. Performs an FFT on each segment to determine frequency occupancy. + + Args: + recording: The Recording object containing IQ or real signal data. + threshold: Sensitivity multiplier (0.0 to 1.0) applied to max power. + window_size: Size of the smoothing filter and max gap for merging hits. + label: Custom string label for annotations. + annotation_type: Metadata string for the 'type' field in the annotation. + + Returns: + A new Recording object populated with detected Annotations. + """ + # Extract signal and metadata + sample_data = recording.data[0] + sample_rate = recording.metadata["sample_rate"] + center_frequency = recording.metadata.get("center_frequency", 0) + + # --- 1. 
SIGNAL CONDITIONING --- + # Convert to power (Magnitude squared) + power_data = np.abs(sample_data) ** 2 + smoothing_window = np.ones(window_size) / window_size + smoothed_power = np.convolve(power_data, smoothing_window, mode="same") + + # Define thresholds based on the global peak of the smoothed signal + max_power = np.max(smoothed_power) + trigger_val = threshold * max_power # High threshold to trigger detection + boundary_val = (threshold / 2) * max_power # Low threshold to define signal edges + + # --- 2. INITIAL DETECTION --- + # Identify indices that strictly exceed the high trigger + indices = np.where(smoothed_power > trigger_val)[0] + initial_ranges = _find_ranges(indices=indices, window_size=window_size) + + annotations = [] + + threshold_base = min(sample_rate, len(sample_data)) + + for start, stop in initial_ranges: + if (stop - start) < (threshold_base * 0.01): + continue + + # --- 3. HYSTERESIS (Boundary Expansion) --- + # Search backward from 'start' until power drops below the low boundary_val + true_start = start + while true_start > 0 and smoothed_power[true_start] > boundary_val: + true_start -= 1 + + # Search forward from 'stop' until power drops below the low boundary_val + true_stop = stop + while true_stop < len(smoothed_power) - 1 and smoothed_power[true_stop] > boundary_val: + true_stop += 1 + + # --- 4. 
SPECTRAL ANALYSIS (Frequency Detection) --- + signal_segment = sample_data[true_start:true_stop] + if len(signal_segment) > 0: + fft_data = np.abs(np.fft.fftshift(np.fft.fft(signal_segment))) + fft_freqs = np.fft.fftshift(np.fft.fftfreq(len(signal_segment), 1 / sample_rate)) + + # Determine frequency bounds where spectral energy is > 15% of segment peak + spectral_thresh = np.max(fft_data) * 0.15 + sig_indices = np.where(fft_data > spectral_thresh)[0] + + # Ensure the signal has some spectral width before annotating + if len(sig_indices) < 5: + continue + + if len(sig_indices) > 0: + f_min, f_max = fft_freqs[sig_indices[0]], fft_freqs[sig_indices[-1]] + else: + # Default to middle half of bandwidth if no clear peaks found + f_min, f_max = -sample_rate / 4, sample_rate / 4 + else: + f_min, f_max = -sample_rate / 4, sample_rate / 4 + + # --- 5. ANNOTATION GENERATION --- + if label is None: + label = f"{int(threshold*100)}%" + + # Pack metadata for the UI/Downstream processing + comment_data = { + "type": annotation_type, + "generator": "threshold_qualifier", + "params": { + "threshold": threshold, + "window_size": window_size, + }, + } + + anno = Annotation( + sample_start=true_start, + sample_count=true_stop - true_start, + freq_lower_edge=center_frequency + f_min, + freq_upper_edge=center_frequency + f_max, + label=label, + comment=json.dumps(comment_data), + detail={"generator": "hysteresis_qualifier"}, + ) + annotations.append(anno) + + # Return a new Recording object including the new annotations + return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + annotations) diff --git a/src/ria_toolkit_oss/view/view_signal.py b/src/ria_toolkit_oss/view/view_signal.py index f8d5731..0f2ed33 100644 --- a/src/ria_toolkit_oss/view/view_signal.py +++ b/src/ria_toolkit_oss/view/view_signal.py @@ -6,18 +6,14 @@ from typing import Optional import matplotlib.pyplot as plt import numpy as np from matplotlib import gridspec +from 
def view_annotations(
    recording: Recording,
    channel: Optional[int] = 0,
    output_path: Optional[str] = "images/annotations.png",
    title: Optional[str] = "Annotated Spectrogram",
    dpi: Optional[int] = 300,
    title_fontsize: Optional[int] = 15,
    dark: Optional[bool] = True,
) -> None:
    """
    Render a spectrogram of one channel with annotation bounding boxes overlaid.

    Each distinct annotation label is assigned a color from a fixed palette and
    listed in a legend; unlabeled annotations are drawn in gray. The figure is
    written to ``output_path``.

    :param recording: Recording whose data and annotations are plotted.
    :param channel: Index of the data channel to display.
    :param output_path: Destination image path.
    :param title: Figure title.
    :param dpi: Output resolution.
    :param title_fontsize: Font size of the title.
    :param dark: Use matplotlib's dark background style when True.
    """
    # 1. Set up the plotting environment.
    plt.close("all")
    plt.style.use("dark_background" if dark else "default")

    fig, ax = plt.subplots(figsize=(12, 8))

    complex_signal = recording.data[channel]
    sample_rate, center_frequency, _ = extract_metadata_fields(recording.metadata)
    annotations = recording.annotations

    # 2. Map each distinct label to a palette color (cycled when labels outnumber colors).
    palette = ["#FF00FF", "#00FF00", "#00FFFF", "#FFFF00", "#FF8000"]
    unique_labels = sorted({ann.label for ann in annotations if ann.label})
    label_to_color = {label: palette[i % len(palette)] for i, label in enumerate(unique_labels)}

    # 3. Generate the spectrogram (frequency axis in Hz, centered on center_frequency).
    ax.specgram(complex_signal, NFFT=256, Fs=sample_rate, Fc=center_frequency, noverlap=128, cmap="twilight")

    # 4. Draw one rectangle per annotation in time/frequency coordinates.
    for annotation in annotations:
        t_start = annotation.sample_start / sample_rate
        t_width = annotation.sample_count / sample_rate
        f_start = annotation.freq_lower_edge
        f_height = annotation.freq_upper_edge - annotation.freq_lower_edge

        # Unlabeled annotations fall back to gray.
        ann_color = label_to_color.get(annotation.label, "gray")

        rect = plt.Rectangle(
            (t_start, f_start), t_width, f_height, linewidth=1.5, edgecolor=ann_color, facecolor="none", alpha=0.8
        )
        ax.add_patch(rect)

    if unique_labels:
        legend_elements = [
            Patch(facecolor=label_to_color[label], alpha=0.3, edgecolor=label_to_color[label], label=label)
            for label in unique_labels
        ]
        ax.legend(handles=legend_elements, loc="upper right", framealpha=0.2)

    ax.set_title(title, fontsize=title_fontsize, pad=20)
    ax.set_xlabel("Time (s)", fontsize=12)
    # BUG FIX: specgram is driven with Fs in Hz, so the frequency axis is in Hz, not MHz.
    ax.set_ylabel("Frequency (Hz)", fontsize=12)
    ax.grid(alpha=0.1)  # Faint grid for readability.

    output_path, _ = set_path(output_path=output_path)
    plt.savefig(output_path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)
    print(f"Professional annotation plot saved to {output_path}")
-314,7 +386,6 @@ def view_sig( hspace=2.5, # Vertical space between subplots ) - # save path handling output_path, _ = set_path(output_path=output_path) plt.savefig(output_path, dpi=dpi) print(f"Saved signal plot to {output_path}") diff --git a/src/ria_toolkit_oss/view/view_signal_simple.py b/src/ria_toolkit_oss/view/view_signal_simple.py index ab56f7d..248486f 100644 --- a/src/ria_toolkit_oss/view/view_signal_simple.py +++ b/src/ria_toolkit_oss/view/view_signal_simple.py @@ -3,6 +3,7 @@ from __future__ import annotations import gc +import json from typing import Optional import matplotlib @@ -11,13 +12,54 @@ import numpy as np from scipy.fft import fft, fftshift from scipy.signal.windows import hann -from ria_toolkit_oss.datatypes.recording import Recording -from ria_toolkit_oss.view.tools import ( - COLORS, - decimate, - extract_metadata_fields, - set_path, -) +from utils.data.recording import Recording +from utils.view.tools import COLORS, decimate, extract_metadata_fields, set_path + + +def _add_annotations(annotations, compact_mode, show_labels, sample_rate_hz, center_freq_hz, ax2): + if annotations and not compact_mode: + for annotation in annotations: + start_idx = annotation.get("core:sample_start", 0) + length = annotation.get("core:sample_count", 0) + start_time = start_idx / sample_rate_hz + end_time = (start_idx + length) / sample_rate_hz + freq_low = annotation.get("core:freq_lower_edge", center_freq_hz - sample_rate_hz / 4) + freq_high = annotation.get("core:freq_upper_edge", center_freq_hz + sample_rate_hz / 4) + comment = annotation.get("core:comment", "{}") + + try: + comment_data = json.loads(comment) if isinstance(comment, str) else comment + ann_type = comment_data.get("type", "unknown") + if ann_type == "intersection": + color = COLORS["success"] + elif ann_type == "parallel": + color = COLORS["primary"] + elif ann_type == "standalone": + color = COLORS["warning"] + else: + color = COLORS["error"] + except Exception: + color = COLORS["error"] + 
+ rect = plt.Rectangle( + (start_time, freq_low), + end_time - start_time, + freq_high - freq_low, + color=color, + alpha=0.4, + linewidth=2, + ) + ax2.add_patch(rect) + if show_labels: + label = annotation.get("core:label", "Signal") + ax2.text( + start_time, + freq_high, + label, + color=COLORS["light"], + fontsize=10, + bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.7), + ) def _get_nfft_size(signal, fast_mode): @@ -138,6 +180,7 @@ def detect_constellation_symbols(signal: np.ndarray, method: str = "differential def view_simple_sig( recording: Recording, + annotations: Optional[list] = None, output_path: Optional[str] = "images/signal.png", saveplot: Optional[bool] = True, fast_mode: Optional[bool] = False, @@ -261,6 +304,15 @@ def view_simple_sig( ax2.set_title("Spectrogram", loc="left", pad=10) + _add_annotations( + annotations=annotations, + compact_mode=compact_mode, + show_labels=show_labels, + sample_rate_hz=sample_rate_hz, + center_freq_hz=center_freq_hz, + ax2=ax2, + ) + if ax_constellation is not None: constellation_samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=50_000, fast_max=20_000) method = "differential" if fast_mode else "combined" @@ -310,7 +362,7 @@ def view_simple_sig( else: plt.tight_layout() if show_title: - plt.subplots_adjust(top=0.90) + plt.subplots_adjust(top=0.92) if saveplot: output_path, extension = set_path(output_path=output_path)