"""Shared plotting primitives for signal visualization."""

from __future__ import annotations

import gc
from typing import Optional

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FuncFormatter
from scipy.fft import fft, fftfreq, fftshift
from scipy.signal.windows import hann

from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.view.tools import (
    COLORS,
    decimate,
    extract_metadata_fields,
    set_path,
)

# Dark background shared by the style sheet, the figure patch and saved images.
_FIGURE_FACECOLOR = "#0f172a"


def _get_nfft_size(signal, fast_mode):
    """Choose the spectrogram FFT size and overlap from the signal length.

    Longer signals get larger FFTs (128 up to 2048 points). Fast mode caps the
    FFT at 512 points and uses a smaller overlap (1/8 instead of 1/4 of nfft).

    :returns: ``(nfft, noverlap)``.
    """
    n_samples = len(signal)
    if n_samples < 1000:
        nfft = 128
    elif n_samples < 10_000:
        nfft = 256
    elif n_samples < 100_000:
        nfft = 512
    elif n_samples < 1_000_000:
        nfft = 1024
    else:
        nfft = 2048
    if fast_mode:
        nfft = min(nfft, 512)
    overlap = nfft // 8 if fast_mode else nfft // 4
    return nfft, overlap


def _get_plot_samples(signal, fast_mode, slow_max, fast_max):
    """Return at most ``fast_max``/``slow_max`` samples taken from the centre of ``signal``."""
    max_samples = fast_max if fast_mode else slow_max
    if len(signal) <= max_samples:
        return signal
    start_idx = len(signal) // 2 - max_samples // 2
    return signal[start_idx : start_idx + max_samples]


def _set_dpi(fast_mode, labels_mode, extension):
    """Return the DPI for a raster (``png``) output, or ``None`` for other formats."""
    if fast_mode:
        dpi = 75
    elif labels_mode:
        dpi = 200
    else:
        dpi = 150
    return dpi if extension == "png" else None


def setup_style(*, labels_mode: bool = False, compact_mode: bool = False) -> None:
    """Configure matplotlib with the signal-testbed styling.

    Note: this mutates the global ``matplotlib.rcParams`` and the active
    pyplot style, so it affects every figure created afterwards.

    :param labels_mode: Use larger fonts for labelled plots.
    :param compact_mode: Use smaller fonts for compact plots (takes precedence).
    """
    plt.style.use("dark_background")

    if compact_mode:
        base_font, title_font, label_font = 8, 10, 8
    elif labels_mode:
        base_font, title_font, label_font = 12, 16, 14
    else:
        base_font, title_font, label_font = 10, 12, 10

    matplotlib.rcParams.update(
        {
            "figure.facecolor": _FIGURE_FACECOLOR,
            "axes.facecolor": "#1e293b",
            "axes.edgecolor": COLORS["muted"],
            "axes.labelcolor": COLORS["light"],
            "text.color": COLORS["light"],
            "xtick.color": COLORS["muted"],
            "ytick.color": COLORS["muted"],
            "grid.color": COLORS["muted"],
            "grid.alpha": 0.3,
            "font.size": base_font,
            "axes.titlesize": title_font,
            "axes.labelsize": label_font,
            "figure.titlesize": title_font + 2,
            "legend.frameon": False,
            "legend.facecolor": "none",
            "xtick.labelsize": base_font,
            "ytick.labelsize": base_font,
        }
    )


def detect_constellation_symbols(signal: np.ndarray, method: str = "differential") -> np.ndarray:
    """Heuristic symbol detector used for constellation highlighting.

    Flags samples whose local rate of change falls below a low percentile,
    i.e. samples that likely sit on a symbol rather than in a transition.

    :param signal: Complex I/Q samples.
    :param method: ``"differential"`` (I/Q step magnitude, 15th percentile),
        ``"amplitude"`` (|x| step, 20th percentile), ``"phase"`` (unwrapped
        phase step, 20th percentile) or ``"combined"`` (at least 2 of the 3).
    :returns: Boolean mask, same length as ``signal``. Signals shorter than
        100 samples are returned as all-``True``.
    :raises ValueError: If ``method`` is not recognised.
    """
    if len(signal) < 100:
        return np.ones(len(signal), dtype=bool)

    if method == "differential":
        d_i = np.diff(signal.real)
        d_q = np.diff(signal.imag)
        # Pad with 0 so the mask lines up with the input length.
        step = np.append(np.sqrt(d_i**2 + d_q**2), 0)
        return step < np.percentile(step, 15)

    if method == "amplitude":
        amplitude_change = np.append(np.abs(np.diff(np.abs(signal))), 0)
        return amplitude_change < np.percentile(amplitude_change, 20)

    if method == "phase":
        phase_step = np.abs(np.append(np.diff(np.unwrap(np.angle(signal))), 0))
        return phase_step < np.percentile(phase_step, 20)

    if method == "combined":
        votes = (
            detect_constellation_symbols(signal, "differential").astype(int)
            + detect_constellation_symbols(signal, "amplitude").astype(int)
            + detect_constellation_symbols(signal, "phase").astype(int)
        )
        return votes >= 2

    raise ValueError(f"Unknown method: {method}")


def _create_axes(*, compact_mode, horizontal_mode, constellation_mode):
    """Create the figure for the requested layout.

    :returns: ``(fig, ax_time, ax_spec, ax_constellation, ax_psd)``. The last
        two are ``None`` when the layout has no slot for them (compact mode;
        horizontal mode never has a PSD slot).
    """
    ax_constellation = ax_psd = None
    if compact_mode:
        fig, (ax_time, ax_spec) = plt.subplots(2, 1, figsize=(12, 6), gridspec_kw={"height_ratios": [1, 5]})
    elif horizontal_mode:
        if constellation_mode:
            fig, (ax_time, ax_spec, ax_constellation) = plt.subplots(1, 3, figsize=(18, 6))
        else:
            fig, (ax_time, ax_spec) = plt.subplots(1, 2, figsize=(16, 6))
    elif constellation_mode:
        fig, ((ax_time, ax_spec), (ax_constellation, ax_psd)) = plt.subplots(2, 2, figsize=(16, 12))
    else:
        fig, (ax_time, ax_spec) = plt.subplots(2, 1, figsize=(14, 10))
    return fig, ax_time, ax_spec, ax_constellation, ax_psd


def _plot_constellation(ax, signal, fast_mode):
    """Scatter the I/Q samples, highlighting those flagged as stable symbol points."""
    samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=50_000, fast_max=20_000)
    method = "differential" if fast_mode else "combined"
    stable = detect_constellation_symbols(samples, method=method)

    ax.scatter(samples.real[~stable], samples.imag[~stable], c=COLORS["muted"], s=0.5, alpha=0.2)
    ax.scatter(samples.real[stable], samples.imag[stable], c=COLORS["purple"], s=3, alpha=0.8)
    ax.set_xlabel("In-phase (I)")
    ax.set_ylabel("Quadrature (Q)")
    ax.set_title("Constellation")
    ax.grid(True, alpha=0.3)
    ax.set_aspect("equal")


def _plot_psd(ax, signal, fast_mode, sample_rate_hz, center_freq_hz):
    """Plot a Hann-windowed periodogram (in dB) of the centre of ``signal``."""
    samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=65_536, fast_max=16_384)
    n_samples = len(samples)
    spectrum = np.abs(fftshift(fft(samples * hann(n_samples)))) ** 2
    # Exact FFT bin centres (-fs/2 .. fs/2 - fs/N), shifted to match the spectrum.
    freqs_hz = fftshift(fftfreq(n_samples, d=1.0 / sample_rate_hz)) + center_freq_hz
    spectrum_db = 10 * np.log10(spectrum + 1e-12)  # epsilon avoids log10(0)

    ax.plot(freqs_hz / 1e6, spectrum_db, color=COLORS["accent"], linewidth=1.0)
    ax.set_xlabel("Frequency (MHz)")
    ax.set_ylabel("Power (dB)")
    ax.set_title("Power Spectral Density")
    ax.grid(True, alpha=0.3)


def view_simple_sig(
    recording: Recording,
    output_path: Optional[str] = "images/signal.png",
    saveplot: Optional[bool] = True,
    fast_mode: Optional[bool] = False,
    compact_mode: Optional[bool] = False,
    horizontal_mode: Optional[bool] = False,
    constellation_mode: Optional[bool] = False,
    labels_mode: Optional[bool] = False,
    slice: Optional[tuple] = None,  # noqa: A002 - public keyword; shadows the builtin
    title: Optional[str] = "Signal",
) -> Optional[str]:
    """
    Create a simple plot of various signal visualizations as a png or svg image.

    :param recording: The recording object to plot. Only the first channel
        (``recording.data[0]``) is plotted.
    :type recording: Recording
    :param output_path: The output image path. Defaults to "images/signal.png"
    :type output_path: str, optional
    :param saveplot: Whether or not to save the plot. If False the plot is
        shown with ``plt.show()`` instead. Defaults to True.
    :type saveplot: bool, optional
    :param fast_mode: Use fast mode for faster render. Defaults to False.
    :type fast_mode: bool, optional
    :param compact_mode: Use compact mode for compact plot. Defaults to False.
    :type compact_mode: bool, optional
    :param horizontal_mode: Display plots horizontally. Defaults to False.
    :type horizontal_mode: bool, optional
    :param constellation_mode: Display a constellation plot, plus a PSD in the
        vertical layout (the horizontal layout has no PSD panel). Ignored in
        compact mode. Defaults to False.
    :type constellation_mode: bool, optional
    :param labels_mode: Display more thorough labels. Defaults to False.
    :type labels_mode: bool, optional
    :param slice: ``(start, end)`` sample indices of the signal to display. Defaults to None.
    :type slice: tuple[int, int], optional
    :param title: Title of plot. Defaults to "Signal".
    :type title: str, optional
    :returns: The path the image was written to when ``saveplot`` is true, else None.
    """
    signal = recording.data[0]
    sample_rate_hz, center_freq_hz, sdr = extract_metadata_fields(recording.metadata)

    setup_style(labels_mode=labels_mode, compact_mode=compact_mode)

    if slice:
        start_idx, end_idx = slice
        signal = signal[start_idx:end_idx]
        print(f"Using slice: samples {start_idx} to {end_idx} ({len(signal):,} samples)")

    # The time-series trace is decimated for rendering speed; the spectrogram
    # and other panels use the full-rate signal.
    max_display_pixels = 100_000 if fast_mode else 250_000
    display_signal = decimate(signal, max_display_pixels) if len(signal) > max_display_pixels else signal

    fig, ax1, ax2, ax_constellation, ax_psd = _create_axes(
        compact_mode=compact_mode,
        horizontal_mode=horizontal_mode,
        constellation_mode=constellation_mode,
    )
    show_title = not compact_mode
    show_labels = labels_mode and not compact_mode

    try:
        if show_title:
            fig.suptitle(title, fontsize=16, color=COLORS["light"], y=0.96)
        fig.patch.set_facecolor(_FIGURE_FACECOLOR)

        # --- Time series -------------------------------------------------
        total_duration_s = len(signal) / sample_rate_hz if sample_rate_hz else 0.0
        t_s = np.linspace(0, total_duration_s, len(display_signal)) if len(display_signal) else np.array([])
        ax1.plot(t_s, display_signal.real, color=COLORS["purple"], linewidth=0.8, alpha=0.8, label="I")
        ax1.plot(t_s, display_signal.imag, color=COLORS["magenta"], linewidth=0.8, alpha=0.8, label="Q")
        ax1.set_xlim(0, total_duration_s)
        ax1.grid(True, alpha=0.3)

        # --- Spectrogram -------------------------------------------------
        nfft, overlap = _get_nfft_size(signal=signal, fast_mode=fast_mode)
        ax2.specgram(
            signal,
            NFFT=nfft,
            Fc=center_freq_hz,
            Fs=sample_rate_hz,
            noverlap=overlap,
            cmap="twilight",
        )
        ax2.set_ylim(center_freq_hz - sample_rate_hz / 2, center_freq_hz + sample_rate_hz / 2)
        ax2.set_xlim(0, total_duration_s)

        if show_labels:
            # Only the vertically stacked 2x1 layout shares a single time axis;
            # in the side-by-side layouts both panels need their own label.
            stacked = not horizontal_mode and ax_constellation is None
            if not stacked:
                ax1.set_xlabel("Time (s)")
            ax2.set_xlabel("Time (s)")
            ax1.set_ylabel("Amplitude")
            ax1.set_title(f"Time Series - {sdr} SDR", loc="left", pad=10)
            ax1.legend(loc="upper right")
            ax2.set_ylabel("Frequency (MHz)")
            ax2.set_title(
                f"Spectrogram - {center_freq_hz / 1e6:.1f} MHz ± {sample_rate_hz / 2e6:.1f} MHz",
                loc="left",
                pad=10,
            )
            # A formatter (unlike fixed tick labels) stays correct on zoom/pan.
            ax2.yaxis.set_major_formatter(FuncFormatter(lambda y, _pos: f"{y / 1e6:.1f}"))
        elif not compact_mode:
            ax1.set_title("Time Series", loc="left", pad=10)
            ax1.legend(loc="upper right", fontsize=8)
            ax2.set_title("Spectrogram", loc="left", pad=10)

        # --- Optional panels ----------------------------------------------
        if ax_constellation is not None:
            _plot_constellation(ax_constellation, signal, fast_mode)
        if ax_psd is not None:
            _plot_psd(ax_psd, signal, fast_mode, sample_rate_hz, center_freq_hz)

        # --- Layout ------------------------------------------------------
        if compact_mode:
            for ax in (ax1, ax2):
                ax.set_xticks([])
                ax.set_yticks([])
            plt.subplots_adjust(left=0, right=1, top=1, bottom=0, hspace=0)
        else:
            plt.tight_layout()
            if show_title:
                plt.subplots_adjust(top=0.90)  # leave room for the suptitle

        if saveplot:
            output_path, extension = set_path(output_path=output_path)
            dpi_value = _set_dpi(fast_mode=fast_mode, labels_mode=labels_mode, extension=extension)
            plt.savefig(output_path, dpi=dpi_value, bbox_inches="tight", facecolor=_FIGURE_FACECOLOR, edgecolor="none")
            print(f"Saved signal plot to {output_path}")
            return output_path

        plt.show()
        return None
    finally:
        # Release the figure on every exit path (including the save path and
        # errors) so repeated calls do not accumulate open figures in memory.
        plt.close(fig)
        gc.collect()


__all__ = [
    "setup_style",
    "detect_constellation_symbols",
    "view_simple_sig",
]