# ria-toolkit-oss/src/ria_toolkit_oss/view/view_signal_simple.py

"""Shared plotting primitives for signal visualization."""
from __future__ import annotations
import gc
M
2026-02-23 14:12:34 -05:00
import json
from typing import Optional
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from scipy.fft import fft, fftshift
from scipy.signal.windows import hann
from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.view.tools import COLORS, decimate, extract_metadata_fields, set_path
M
2026-02-23 14:12:34 -05:00
def _annotation_color(annotation):
    """Return the overlay colour for one annotation, chosen from its ``core:comment``.

    The comment may be a JSON string or an already-decoded mapping.
    - Its ``"type"`` value selects the colour: ``"intersection"`` -> success,
      ``"parallel"`` -> primary, ``"standalone"`` -> warning.
    - Anything else falls back to ``COLORS["error"]``. This includes a missing
      comment, malformed JSON, or a comment that is not a mapping.
    """
    comment = annotation.get("core:comment", "{}")
    try:
        comment_data = json.loads(comment) if isinstance(comment, str) else comment
        ann_type = comment_data.get("type", "unknown")
    except (ValueError, AttributeError, TypeError):
        # ValueError: invalid JSON text (json.JSONDecodeError subclasses it).
        # AttributeError / TypeError: the decoded comment has no usable
        # ``.get`` (e.g. a list, number, string or None).
        return COLORS["error"]
    if ann_type == "intersection":
        return COLORS["success"]
    if ann_type == "parallel":
        return COLORS["primary"]
    if ann_type == "standalone":
        return COLORS["warning"]
    return COLORS["error"]


def _add_annotations(annotations, compact_mode, show_labels, sample_rate_hz, center_freq_hz, ax2):
    """Draw annotation rectangles (and optional text labels) on the spectrogram axes.

    Each annotation is a dict read through ``core:*`` keys.
    - ``core:sample_start`` / ``core:sample_count`` give the time span in samples.
    - ``core:freq_lower_edge`` / ``core:freq_upper_edge`` give the frequency span in Hz.
    - ``core:comment`` selects the colour.
    - ``core:label`` is the text label.

    :param annotations: Iterable of annotation dicts; falsy means nothing is drawn.
    :param compact_mode: When true, annotations are not drawn at all.
    :param show_labels: When true, write each annotation's label at its top-left corner.
    :param sample_rate_hz: Sample rate used to convert sample indices to seconds.
    :param center_freq_hz: Centre frequency used for default frequency edges.
    :param ax2: The spectrogram axes to draw on.
    """
    if not annotations or compact_mode:
        return
    for annotation in annotations:
        start_idx = annotation.get("core:sample_start", 0)
        length = annotation.get("core:sample_count", 0)
        start_time = start_idx / sample_rate_hz
        end_time = (start_idx + length) / sample_rate_hz
        # Without explicit edges, default to the central half of the captured band.
        freq_low = annotation.get("core:freq_lower_edge", center_freq_hz - sample_rate_hz / 4)
        freq_high = annotation.get("core:freq_upper_edge", center_freq_hz + sample_rate_hz / 4)
        color = _annotation_color(annotation)
        rect = plt.Rectangle(
            (start_time, freq_low),
            end_time - start_time,
            freq_high - freq_low,
            color=color,
            alpha=0.4,
            linewidth=2,
        )
        ax2.add_patch(rect)
        if show_labels:
            label = annotation.get("core:label", "Signal")
            ax2.text(
                start_time,
                freq_high,
                label,
                color=COLORS["light"],
                fontsize=10,
                bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.7),
            )
def _get_nfft_size(signal, fast_mode):
if len(signal) < 1000:
nfft = 128
elif len(signal) < 10_000:
nfft = 256
elif len(signal) < 100_000:
nfft = 512
elif len(signal) < 1_000_000:
nfft = 1024
else:
nfft = 2048
if fast_mode:
nfft = min(nfft, 512)
overlap = nfft // 8 if fast_mode else nfft // 4
return nfft, overlap
def _get_plot_samples(signal, fast_mode, slow_max, fast_max):
max_samples = fast_max if fast_mode else slow_max
if len(signal) > max_samples:
start_idx = len(signal) // 2 - max_samples // 2
return signal[start_idx : start_idx + max_samples]
else:
return signal
def _set_dpi(fast_mode, labels_mode, extension):
if fast_mode:
dpi = 75
elif labels_mode:
dpi = 200
else:
dpi = 150
return dpi if extension == "png" else None
def setup_style(*, labels_mode: bool = False, compact_mode: bool = False) -> None:
    """Apply the signal-testbed dark theme to matplotlib's global ``rcParams``.

    This updates ``matplotlib.rcParams`` in place, so it affects figures
    created afterwards.

    :param labels_mode: Use the larger font set for fully labelled plots.
    :param compact_mode: Use the smallest font set; takes precedence over ``labels_mode``.
    """
    plt.style.use("dark_background")

    if compact_mode:
        base_size, title_size, axis_label_size = 8, 10, 8
    elif labels_mode:
        base_size, title_size, axis_label_size = 12, 16, 14
    else:
        base_size, title_size, axis_label_size = 10, 12, 10

    muted = COLORS["muted"]
    light = COLORS["light"]
    theme = {
        # Background and foreground colours.
        "figure.facecolor": "#0f172a",
        "axes.facecolor": "#1e293b",
        "axes.edgecolor": muted,
        "axes.labelcolor": light,
        "text.color": light,
        "xtick.color": muted,
        "ytick.color": muted,
        "grid.color": muted,
        "grid.alpha": 0.3,
        # Frameless, transparent legends.
        "legend.frameon": False,
        "legend.facecolor": "none",
        # Font sizes for the selected mode.
        "font.size": base_size,
        "xtick.labelsize": base_size,
        "ytick.labelsize": base_size,
        "axes.titlesize": title_size,
        "figure.titlesize": title_size + 2,
        "axes.labelsize": axis_label_size,
    }
    matplotlib.rcParams.update(theme)
def detect_constellation_symbols(signal: np.ndarray, method: str = "differential") -> np.ndarray:
    """Heuristically flag samples that sit on settled constellation points.

    Each method scores the sample-to-sample change of the signal, pads the
    score with a trailing 0 so it has one entry per sample, and keeps samples
    whose score is strictly below a low percentile of all scores.

    :param signal: Complex sample array.
    :param method: One of:
        - ``"differential"``: magnitude of the complex step, 15th percentile.
        - ``"amplitude"``: absolute change of ``|x|``, 20th percentile.
        - ``"phase"``: absolute change of the unwrapped phase, 20th percentile.
        - ``"combined"``: at least two of the three above agree.
    :returns: Boolean mask with one entry per sample. Signals shorter than 100
        samples yield an all-``True`` mask; ``method`` is not checked then.
    :raises ValueError: If ``method`` is unrecognised and the signal has at
        least 100 samples.
    """
    if len(signal) < 100:
        return np.ones(len(signal), dtype=bool)

    def below_percentile(step, pct):
        # Pad so the mask aligns with the input samples.
        score = np.append(step, 0)
        return score < np.percentile(score, pct)

    if method == "differential":
        d_re = np.diff(signal.real)
        d_im = np.diff(signal.imag)
        return below_percentile(np.sqrt(d_im**2 + d_re**2), 15)
    if method == "amplitude":
        return below_percentile(np.abs(np.diff(np.abs(signal))), 20)
    if method == "phase":
        return below_percentile(np.abs(np.diff(np.unwrap(np.angle(signal)))), 20)
    if method == "combined":
        votes = sum(
            detect_constellation_symbols(signal, name).astype(int)
            for name in ("differential", "amplitude", "phase")
        )
        return votes >= 2
    raise ValueError(f"Unknown method: {method}")
def _create_layout(compact_mode, horizontal_mode, constellation_mode, labels_mode):
    """Create the figure and axes for the requested layout.

    :returns: ``(fig, ax_time, ax_spec, ax_constellation, ax_psd, show_title, show_labels)``.
        ``ax_constellation`` / ``ax_psd`` are ``None`` when the layout has no such panel.
    """
    if compact_mode:
        # Thin time strip above a tall spectrogram; no titles or labels.
        fig, (ax_time, ax_spec) = plt.subplots(2, 1, figsize=(12, 6), gridspec_kw={"height_ratios": [1, 5]})
        return fig, ax_time, ax_spec, None, None, False, False
    if horizontal_mode:
        if constellation_mode:
            fig, (ax_time, ax_spec, ax_constellation) = plt.subplots(1, 3, figsize=(18, 6))
        else:
            fig, (ax_time, ax_spec) = plt.subplots(1, 2, figsize=(16, 6))
            ax_constellation = None
        # The horizontal layout never includes a PSD panel.
        return fig, ax_time, ax_spec, ax_constellation, None, True, labels_mode
    if constellation_mode:
        fig, ((ax_time, ax_spec), (ax_constellation, ax_psd)) = plt.subplots(2, 2, figsize=(16, 12))
    else:
        fig, (ax_time, ax_spec) = plt.subplots(2, 1, figsize=(14, 10))
        ax_constellation = ax_psd = None
    return fig, ax_time, ax_spec, ax_constellation, ax_psd, True, labels_mode


def _plot_constellation(ax, signal, fast_mode):
    """Scatter a centred excerpt of ``signal`` on ``ax``, highlighting detected symbol points."""
    samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=50_000, fast_max=20_000)
    # "combined" runs three detectors; fast mode uses the single differential one.
    method = "differential" if fast_mode else "combined"
    stable = detect_constellation_symbols(samples, method=method)
    ax.scatter(samples.real[~stable], samples.imag[~stable], c=COLORS["muted"], s=0.5, alpha=0.2)
    ax.scatter(samples.real[stable], samples.imag[stable], c=COLORS["purple"], s=3, alpha=0.8)
    ax.set_xlabel("In-phase (I)")
    ax.set_ylabel("Quadrature (Q)")
    ax.set_title("Constellation")
    ax.grid(True, alpha=0.3)
    ax.set_aspect("equal")


def _plot_psd(ax, signal, fast_mode, sample_rate_hz, center_freq_hz):
    """Plot the Hann-windowed power spectrum (dB) of a centred excerpt of ``signal`` on ``ax``."""
    samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=65_536, fast_max=16_384)
    n = len(samples)
    spectrum = np.abs(fftshift(fft(samples * hann(n)))) ** 2
    # Bin frequencies matching the fftshift-ed FFT output (for even n:
    # -fs/2 up to fs/2 - fs/n). A linspace(-fs/2, fs/2, n) axis would be
    # stretched by up to one bin.
    freqs_hz = fftshift(np.fft.fftfreq(n, d=1.0 / sample_rate_hz)) + center_freq_hz
    spectrum_db = 10 * np.log10(spectrum + 1e-12)  # small floor avoids log10(0)
    ax.plot(freqs_hz / 1e6, spectrum_db, color=COLORS["accent"], linewidth=1.0)
    ax.set_xlabel("Frequency (MHz)")
    ax.set_ylabel("Power (dB)")
    ax.set_title("Power Spectral Density")
    ax.grid(True, alpha=0.3)


def view_simple_sig(
    recording: Recording,
    annotations: Optional[list] = None,
    output_path: Optional[str] = "images/signal.png",
    saveplot: Optional[bool] = True,
    fast_mode: Optional[bool] = False,
    compact_mode: Optional[bool] = False,
    horizontal_mode: Optional[bool] = False,
    constellation_mode: Optional[bool] = False,
    labels_mode: Optional[bool] = False,
    slice: Optional[tuple] = None,
    title: Optional[str] = "Signal",
):
    """
    Create a simple plot of various signal visualizations as a png or svg image.

    Only ``recording.data[0]`` (the first row of the recording data) is plotted.

    :param recording: The recording object to plot.
    :type recording: Recording
    :param annotations: Annotation dicts to overlay on the spectrogram (ignored in compact mode).
        Defaults to None.
    :type annotations: list, optional
    :param output_path: The output image path. Defaults to "images/signal.png"
    :type output_path: str, optional
    :param saveplot: Whether to save the plot (True) or show it interactively (False). Defaults to True.
    :type saveplot: bool, optional
    :param fast_mode: Use fast mode for faster render. Defaults to False.
    :type fast_mode: bool, optional
    :param compact_mode: Use compact mode for compact plot. Defaults to False.
    :type compact_mode: bool, optional
    :param horizontal_mode: Display plots horizontally. Defaults to False.
    :type horizontal_mode: bool, optional
    :param constellation_mode: Display constellation plot and PSD if not using compact mode. Defaults to False.
    :type constellation_mode: bool, optional
    :param labels_mode: Display more thorough labels. Defaults to False.
    :type labels_mode: bool, optional
    :param slice: ``(start, end)`` sample range of the signal to display. Defaults to None.
    :type slice: tuple[int, int], optional
    :param title: Title of plot. Defaults to "Signal".
    :type title: str, optional
    :returns: The path the image was saved to when ``saveplot`` is true, otherwise None.
    """
    from matplotlib.ticker import FuncFormatter

    signal = recording.data[0]
    sample_rate_hz, center_freq_hz, sdr = extract_metadata_fields(recording.metadata)
    setup_style(labels_mode=labels_mode, compact_mode=compact_mode)

    # ``slice`` shadows the builtin but is part of the public signature.
    if slice:
        start_idx, end_idx = slice
        signal = signal[start_idx:end_idx]
        print(f"Using slice: samples {start_idx} to {end_idx} ({len(signal):,} samples)")

    max_display_pixels = 100_000 if fast_mode else 250_000
    display_signal = decimate(signal, max_display_pixels) if len(signal) > max_display_pixels else signal

    fig, ax1, ax2, ax_constellation, ax_psd, show_title, show_labels = _create_layout(
        compact_mode, horizontal_mode, constellation_mode, labels_mode
    )
    try:
        if show_title:
            fig.suptitle(title, fontsize=16, color=COLORS["light"], y=0.96)
        fig.patch.set_facecolor("#0f172a")

        # Time series (possibly decimated) stretched over the full capture duration.
        total_duration_s = len(signal) / sample_rate_hz if sample_rate_hz else 0.0
        t_s = np.linspace(0, total_duration_s, len(display_signal)) if len(display_signal) else np.array([])
        ax1.plot(t_s, display_signal.real, color=COLORS["purple"], linewidth=0.8, alpha=0.8, label="I")
        ax1.plot(t_s, display_signal.imag, color=COLORS["magenta"], linewidth=0.8, alpha=0.8, label="Q")
        ax1.set_xlim(0, total_duration_s)
        ax1.grid(True, alpha=0.3)

        # Spectrogram of the full-resolution signal.
        nfft, overlap = _get_nfft_size(signal=signal, fast_mode=fast_mode)
        ax2.specgram(
            signal,
            NFFT=nfft,
            Fc=center_freq_hz,
            Fs=sample_rate_hz,
            noverlap=overlap,
            cmap="twilight",
        )
        ax2.set_ylim(center_freq_hz - sample_rate_hz / 2, center_freq_hz + sample_rate_hz / 2)
        ax2.set_xlim(0, total_duration_s)

        if show_labels:
            (ax1 if horizontal_mode else ax2).set_xlabel("Time (s)")
            ax1.set_ylabel("Amplitude")
            ax1.set_title(f"Time Series - {sdr} SDR", loc="left", pad=10)
            ax1.legend(loc="upper right")
            # Tick values are rendered in MHz, so the axis label says MHz. A
            # formatter (unlike fixed tick labels) stays correct if the tick
            # locator later picks different positions.
            ax2.set_ylabel("Frequency (MHz)")
            ax2.set_title(
                f"Spectrogram - {center_freq_hz / 1e6:.1f} MHz ± {sample_rate_hz / 2e6:.1f} MHz", loc="left", pad=10
            )
            ax2.yaxis.set_major_formatter(FuncFormatter(lambda y, _pos: f"{y / 1e6:.1f}"))
        elif not compact_mode:
            ax1.set_title("Time Series", loc="left", pad=10)
            ax1.legend(loc="upper right", fontsize=8)
            ax2.set_title("Spectrogram", loc="left", pad=10)

        _add_annotations(
            annotations=annotations,
            compact_mode=compact_mode,
            show_labels=show_labels,
            sample_rate_hz=sample_rate_hz,
            center_freq_hz=center_freq_hz,
            ax2=ax2,
        )

        if ax_constellation is not None:
            _plot_constellation(ax_constellation, signal, fast_mode)
        if ax_psd is not None:
            _plot_psd(ax_psd, signal, fast_mode, sample_rate_hz, center_freq_hz)

        if compact_mode:
            for ax in (ax1, ax2):
                ax.set_xticks([])
                ax.set_yticks([])
            fig.subplots_adjust(left=0, right=1, top=1, bottom=0, hspace=0)
        else:
            fig.tight_layout()
            if show_title:
                # Leave room for the suptitle above the subplots.
                fig.subplots_adjust(top=0.92)

        if saveplot:
            output_path, extension = set_path(output_path=output_path)
            dpi_value = _set_dpi(fast_mode=fast_mode, labels_mode=labels_mode, extension=extension)
            fig.savefig(output_path, dpi=dpi_value, bbox_inches="tight", facecolor="#0f172a", edgecolor="none")
            print(f"Saved signal plot to {output_path}")
            return output_path

        plt.show()
        return None
    finally:
        # Release the figure on every path (saved, shown, or an exception) so
        # repeated calls do not accumulate open figures and their memory.
        plt.close(fig)
        gc.collect()
# Public names exported by ``from ... import *``.
__all__ = ["setup_style", "detect_constellation_symbols", "view_simple_sig"]