ria-toolkit-oss/src/ria_toolkit_oss/view/view_signal_simple.py
G gillian e5a3d327e5
Some checks failed
Test with tox / Test with tox (3.11) (pull_request) Successful in 3m37s
Test with tox / Test with tox (3.12) (pull_request) Successful in 3m44s
Build Project / Build Project (3.10) (pull_request) Successful in 5m55s
Build Project / Build Project (3.11) (pull_request) Successful in 5m35s
Build Project / Build Project (3.12) (pull_request) Successful in 6m27s
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 6m22s
Test with tox / Test with tox (3.10) (pull_request) Failing after 5m33s
refactor: unify signal viewer styling and update docs screenshots
- Align view_simple and view_full on background colour (#161616), title
  size (25pt), subtitle size (15pt), base font/tick/label sizes, grid
  style (alpha=0.2), and legend fontsize (10pt)
- Spectrogram placed above IQ plot in view_simple; subplot renamed from
  "Time Series" to "IQ Sample Plot"
- Frequency and spectrogram Y-axes formatted in MHz across both viewers
- Added xlabel/ylabel, subtle grids, and IQ legend to view_full subplots
- Fixed spectrogram right-side clipping in view_simple by syncing xlim
  from specgram output rather than total signal duration
- Updated getting_started.rst to reference both simple and full viewer
  screenshots; replaced doc images with latest renders
2026-04-28 14:08:44 -04:00

400 lines
14 KiB
Python

"""Shared plotting primitives for signal visualization."""
from __future__ import annotations
import gc
import json
from typing import Optional
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from scipy.fft import fft, fftshift
from scipy.signal.windows import hann
from ria_toolkit_oss.data.recording import Recording
from ria_toolkit_oss.view.tools import (
COLORS,
decimate,
extract_metadata_fields,
set_path,
)
def _add_annotations(annotations, compact_mode, show_labels, sample_rate_hz, center_freq_hz, ax2):
if annotations and not compact_mode:
for annotation in annotations:
start_idx = annotation.get("core:sample_start", 0)
length = annotation.get("core:sample_count", 0)
start_time = start_idx / sample_rate_hz
end_time = (start_idx + length) / sample_rate_hz
freq_low = annotation.get("core:freq_lower_edge", center_freq_hz - sample_rate_hz / 4)
freq_high = annotation.get("core:freq_upper_edge", center_freq_hz + sample_rate_hz / 4)
comment = annotation.get("core:comment", "{}")
try:
comment_data = json.loads(comment) if isinstance(comment, str) else comment
ann_type = comment_data.get("type", "unknown")
if ann_type == "intersection":
color = COLORS["success"]
elif ann_type == "parallel":
color = COLORS["primary"]
elif ann_type == "standalone":
color = COLORS["warning"]
else:
color = COLORS["error"]
except Exception:
color = COLORS["error"]
rect = plt.Rectangle(
(start_time, freq_low),
end_time - start_time,
freq_high - freq_low,
color=color,
alpha=0.4,
linewidth=2,
)
ax2.add_patch(rect)
if show_labels:
label = annotation.get("core:label", "Signal")
ax2.text(
start_time,
freq_high,
label,
color=COLORS["light"],
fontsize=10,
bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.7),
)
def _get_nfft_size(signal, fast_mode):
if len(signal) < 1000:
nfft = 128
elif len(signal) < 10_000:
nfft = 256
elif len(signal) < 100_000:
nfft = 512
elif len(signal) < 1_000_000:
nfft = 1024
else:
nfft = 2048
if fast_mode:
nfft = min(nfft, 512)
overlap = nfft // 8 if fast_mode else nfft // 4
return nfft, overlap
def _get_plot_samples(signal, fast_mode, slow_max, fast_max):
max_samples = fast_max if fast_mode else slow_max
if len(signal) > max_samples:
start_idx = len(signal) // 2 - max_samples // 2
return signal[start_idx : start_idx + max_samples]
else:
return signal
def _set_dpi(fast_mode, labels_mode, extension):
if fast_mode:
dpi = 75
elif labels_mode:
dpi = 200
else:
dpi = 150
return dpi if extension == "png" else None
def setup_style(*, labels_mode: bool = False, compact_mode: bool = False) -> None:
    """Apply the dark plotting theme and font sizes to matplotlib's global rcParams.

    ``compact_mode`` takes precedence over ``labels_mode`` when both are set.

    :param labels_mode: Use the larger font sizes (12/16/14 pt).
    :param compact_mode: Use the smallest font sizes (8/10/8 pt).
    """
    plt.style.use("dark_background")

    # (base, title, label) font sizes in points.
    if compact_mode:
        sizes = (8, 10, 8)
    elif labels_mode:
        sizes = (12, 16, 14)
    else:
        sizes = (10, 15, 10)
    base_pt, title_pt, label_pt = sizes

    # One background colour for figure, axes and saved output.
    background = "#161616"

    overrides = {
        "figure.facecolor": background,
        "axes.facecolor": background,
        "savefig.facecolor": background,
        "savefig.edgecolor": background,
        "font.size": base_pt,
        "axes.titlesize": title_pt,
        "axes.labelsize": label_pt,
        "figure.titlesize": title_pt + 4,
        "legend.frameon": False,
        "legend.facecolor": "none",
        "xtick.labelsize": base_pt,
        "ytick.labelsize": base_pt,
    }
    matplotlib.rcParams.update(overrides)
def detect_constellation_symbols(signal: np.ndarray, method: str = "differential") -> np.ndarray:
    """Heuristic symbol detector used for constellation highlighting.

    Samples are flagged as "stable" when a per-sample change metric falls below
    a low percentile of that metric over the whole signal.

    :param signal: 1-D complex sample array.
    :type signal: np.ndarray
    :param method: One of ``"differential"`` (IQ step magnitude, 15th percentile),
        ``"amplitude"`` (envelope change, 20th percentile), ``"phase"``
        (unwrapped phase change, 20th percentile) or ``"combined"`` (at least
        two of the three detectors agree).
    :type method: str, optional
    :return: Boolean mask with the same length as ``signal``; all ``True`` when
        the signal has fewer than 100 samples.
    :rtype: np.ndarray
    :raises ValueError: If ``method`` is not one of the supported names.
    """
    # Validate first so a bad method name is reported regardless of signal
    # length (short signals used to accept any name silently).
    if method not in ("differential", "amplitude", "phase", "combined"):
        raise ValueError(f"Unknown method: {method}")

    # Too few samples for percentile thresholds to mean anything: keep them all.
    if len(signal) < 100:
        return np.ones(len(signal), dtype=bool)

    if method == "differential":
        d_real = np.diff(signal.real)
        d_imag = np.diff(signal.imag)
        step = np.sqrt(d_real**2 + d_imag**2)
        # np.diff drops one element; pad with 0 so the mask matches the signal length.
        step = np.append(step, 0)
        return step < np.percentile(step, 15)

    if method == "amplitude":
        envelope_change = np.abs(np.diff(np.abs(signal)))
        envelope_change = np.append(envelope_change, 0)
        return envelope_change < np.percentile(envelope_change, 20)

    if method == "phase":
        phase_step = np.diff(np.unwrap(np.angle(signal)))
        phase_step = np.abs(np.append(phase_step, 0))
        return phase_step < np.percentile(phase_step, 20)

    # method == "combined": majority vote of the three individual detectors.
    votes = sum(
        detect_constellation_symbols(signal, name).astype(int)
        for name in ("differential", "amplitude", "phase")
    )
    return votes >= 2
def view_simple_sig(
    recording: Recording,
    annotations: Optional[list] = None,
    output_path: Optional[str] = "images/signal.png",
    saveplot: Optional[bool] = True,
    fast_mode: Optional[bool] = False,
    compact_mode: Optional[bool] = False,
    horizontal_mode: Optional[bool] = False,
    constellation_mode: Optional[bool] = False,
    labels_mode: Optional[bool] = False,
    slice: Optional[tuple] = None,
    title: Optional[str] = "Signal Plot",
):
    """
    Create a simple plot of various signal visualizations as a png or svg image.

    The figure always contains an IQ sample plot and a spectrogram. Depending on
    the layout flags it may also contain a constellation plot and a power
    spectral density (PSD) plot.

    :param recording: The recording object to plot. ``recording.data[0]`` is
        plotted and ``recording.metadata`` supplies sample rate, center
        frequency and SDR name.
    :type recording: Recording
    :param annotations: Annotation dicts (``core:*`` keys) drawn as rectangles on
        the spectrogram. Ignored in compact mode. Defaults to None.
    :type annotations: list, optional
    :param output_path: The output image path. Defaults to "images/signal.png"
    :type output_path: str, optional
    :param saveplot: Whether or not to save the plot. When False the plot is
        shown with ``plt.show()`` instead. Defaults to True.
    :type saveplot: bool, optional
    :param fast_mode: Use fast mode for faster render (fewer samples, smaller
        FFT, lower DPI). Defaults to False.
    :type fast_mode: bool, optional
    :param compact_mode: Use compact mode: spectrogram over IQ plot with no
        titles, labels, ticks or annotations. Takes precedence over
        ``horizontal_mode``. Defaults to False.
    :type compact_mode: bool, optional
    :param horizontal_mode: Display plots side by side. Defaults to False.
    :type horizontal_mode: bool, optional
    :param constellation_mode: Display a constellation plot (and, in the
        default vertical layout, a PSD plot) if not using compact mode.
        Defaults to False.
    :type constellation_mode: bool, optional
    :param labels_mode: Display more thorough labels. Defaults to False.
    :type labels_mode: bool, optional
    :param slice: ``(start, end)`` sample indices of the signal to display.
        Defaults to None.
    :type slice: tuple[int, int], optional
    :param title: Title of plot. Defaults to "Signal Plot".
    :type title: str, optional
    :return: The path the image was written to when ``saveplot`` is truthy,
        otherwise ``None``.
    """
    # Imported explicitly instead of relying on matplotlib.pyplot having loaded
    # the ``matplotlib.ticker`` submodule as a side effect.
    from matplotlib.ticker import FuncFormatter

    signal = recording.data[0]
    sample_rate_hz, center_freq_hz, sdr = extract_metadata_fields(recording.metadata)
    # Style must be applied before the figure is created so rcParams take effect.
    setup_style(labels_mode=labels_mode, compact_mode=compact_mode)

    if slice:
        start_idx, end_idx = slice
        signal = signal[start_idx:end_idx]
        print(f"Using slice: samples {start_idx} to {end_idx} ({len(signal):,} samples)")

    # Only the time-domain trace is decimated; the spectrogram uses every sample.
    max_display_pixels = 100_000 if fast_mode else 250_000
    display_signal = decimate(signal, max_display_pixels) if len(signal) > max_display_pixels else signal

    # Layout: ax1 is always the IQ plot, ax2 always the spectrogram.
    if compact_mode:
        fig, (ax2, ax1) = plt.subplots(2, 1, figsize=(12, 6), gridspec_kw={"height_ratios": [5, 1]})
        show_title = False
        show_labels = False
        ax_constellation = ax_psd = None
    elif horizontal_mode:
        if constellation_mode:
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
            ax_constellation = ax3
        else:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
            ax_constellation = None
        show_title = True
        show_labels = labels_mode
        ax_psd = None
    else:
        if constellation_mode:
            fig, ((ax2, ax1), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))
            ax_constellation, ax_psd = ax3, ax4
        else:
            fig, (ax2, ax1) = plt.subplots(2, 1, figsize=(14, 10))
            ax_constellation = ax_psd = None
        show_title = True
        show_labels = labels_mode

    if show_title:
        fig.suptitle(title, fontsize=25)
    fig.patch.set_facecolor(matplotlib.rcParams["figure.facecolor"])

    # --- IQ sample plot -----------------------------------------------------
    total_duration_s = len(signal) / sample_rate_hz if sample_rate_hz else 0.0
    t_s = np.linspace(0, total_duration_s, len(display_signal)) if len(display_signal) else np.array([])
    ax1.plot(t_s, display_signal.real, color=COLORS["purple"], linewidth=0.6, alpha=0.8, label="I")
    ax1.plot(t_s, display_signal.imag, color=COLORS["magenta"], linewidth=0.6, alpha=0.8, label="Q")
    ax1.grid(True, alpha=0.2, linewidth=0.5)

    # --- Spectrogram --------------------------------------------------------
    nfft, overlap = _get_nfft_size(signal=signal, fast_mode=fast_mode)
    ax2.specgram(
        signal,
        NFFT=nfft,
        Fc=center_freq_hz,
        Fs=sample_rate_hz,
        noverlap=overlap,
        cmap="twilight",
    )
    ax2.set_ylim(center_freq_hz - sample_rate_hz / 2, center_freq_hz + sample_rate_hz / 2)
    # Take the time limits from the spectrogram so both plots share one extent
    # and the spectrogram is not clipped on the right.
    ax1.set_xlim(ax2.get_xlim())

    # Tick values stay in Hz (annotations are placed in Hz); only the tick
    # labels are rendered in MHz.
    mhz_formatter = FuncFormatter(lambda x, _: f"{x / 1e6:.1f}")

    if show_labels:
        if horizontal_mode:
            ax1.set_xlabel("Time (s)")
        else:
            ax2.set_xlabel("Time (s)")
        ax1.set_ylabel("Amplitude")
        ax1.set_title(f"IQ Sample Plot - {sdr} SDR", loc="left", pad=10, fontsize=15)
        ax1.legend(loc="upper right", fontsize=10)
        ax2.set_ylabel("Frequency (MHz)")
        ax2.set_title(
            f"Spectrogram - {center_freq_hz / 1e6:.1f} MHz ± {sample_rate_hz / 2e6:.1f} MHz",
            loc="left",
            pad=10,
            fontsize=15,
        )
        ax2.yaxis.set_major_formatter(mhz_formatter)
    elif not compact_mode:
        ax1.set_title("IQ Sample Plot", loc="left", pad=10, fontsize=15)
        ax1.legend(loc="upper right", fontsize=10)
        ax2.set_xlabel("Time (s)")
        ax2.set_ylabel("Frequency (MHz)")
        ax2.set_title("Spectrogram", loc="left", pad=10, fontsize=15)
        ax2.yaxis.set_major_formatter(mhz_formatter)

    _add_annotations(
        annotations=annotations,
        compact_mode=compact_mode,
        show_labels=show_labels,
        sample_rate_hz=sample_rate_hz,
        center_freq_hz=center_freq_hz,
        ax2=ax2,
    )

    # --- Constellation (optional) -------------------------------------------
    if ax_constellation is not None:
        constellation_samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=50_000, fast_max=20_000)
        method = "differential" if fast_mode else "combined"
        stable_points = detect_constellation_symbols(constellation_samples, method=method)
        # Transitional samples first (faint), then the stable ones on top.
        ax_constellation.scatter(
            constellation_samples.real[~stable_points],
            constellation_samples.imag[~stable_points],
            c=COLORS["muted"],
            s=0.5,
            alpha=0.2,
        )
        ax_constellation.scatter(
            constellation_samples.real[stable_points],
            constellation_samples.imag[stable_points],
            c=COLORS["purple"],
            s=3,
            alpha=0.8,
        )
        ax_constellation.set_xlabel("In-phase (I)")
        ax_constellation.set_ylabel("Quadrature (Q)")
        ax_constellation.set_title("Constellation", loc="left", fontsize=15)
        ax_constellation.grid(True, alpha=0.2, linewidth=0.5)
        ax_constellation.set_aspect("equal")

    # --- Power spectral density (optional) ----------------------------------
    if ax_psd is not None:
        psd_samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=65_536, fast_max=16_384)
        window = hann(len(psd_samples))
        spectrum = np.abs(fftshift(fft(psd_samples * window))) ** 2
        freqs = np.linspace(-sample_rate_hz / 2, sample_rate_hz / 2, len(psd_samples))
        freqs = freqs + center_freq_hz
        # Small offset avoids log10(0) for empty bins.
        spectrum_db = 10 * np.log10(spectrum + 1e-12)
        ax_psd.plot(freqs / 1e6, spectrum_db, color=COLORS["accent"], linewidth=0.8)
        ax_psd.set_xlabel("Frequency (MHz)")
        ax_psd.set_ylabel("Power (dB)")
        ax_psd.set_title("Power Spectral Density", loc="left", fontsize=15)
        ax_psd.grid(True, alpha=0.2, linewidth=0.5)

    # --- Final layout -------------------------------------------------------
    if compact_mode:
        ax1.set_xticks([])
        ax1.set_yticks([])
        ax2.set_xticks([])
        ax2.set_yticks([])
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0, hspace=0)
    else:
        plt.tight_layout()
        if show_title:
            # Leave room for the suptitle above the subplots.
            plt.subplots_adjust(top=0.9)

    if saveplot:
        output_path, extension = set_path(output_path=output_path)
        dpi_value = _set_dpi(fast_mode=fast_mode, labels_mode=labels_mode, extension=extension)
        try:
            plt.savefig(
                output_path,
                dpi=dpi_value,
                bbox_inches="tight",
                pad_inches=0.3,
                facecolor=matplotlib.rcParams["savefig.facecolor"],
                edgecolor=matplotlib.rcParams["savefig.edgecolor"],
            )
        finally:
            # Release the figure even if saving fails; otherwise every call
            # leaves an open figure behind and memory grows across batch runs.
            plt.close(fig)
            gc.collect()
        print(f"Saved signal plot to {output_path}")
        return output_path

    plt.show()
    # Garbage collection and clean up to prevent memory overloading
    plt.close("all")
    gc.collect()
    return None
# Names exported by ``from <module> import *``; underscore-prefixed helpers are not listed.
__all__ = [
"setup_style",
"detect_constellation_symbols",
"view_simple_sig",
]