diff --git a/src/ria_toolkit_oss/viz/onnx.py b/src/ria_toolkit_oss/viz/onnx.py
index 04e86b9..b92c3e4 100644
--- a/src/ria_toolkit_oss/viz/onnx.py
+++ b/src/ria_toolkit_oss/viz/onnx.py
@@ -6,18 +6,16 @@ as other ria-toolkit-oss visualization modules.
"""
from pathlib import Path
-from typing import Optional
-import plotly.graph_objects as go
import plotly.express as px
+import plotly.graph_objects as go
from plotly.subplots import make_subplots
-import pandas as pd
-import numpy as np
try:
import onnx
import onnx.helper
import onnx.numpy_helper
+
ONNX_AVAILABLE = True
except ImportError:
ONNX_AVAILABLE = False
@@ -26,33 +24,32 @@ except ImportError:
def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure:
"""Create a professional error figure with Qoherent dark theme styling."""
fig = go.Figure()
-
+
# Create a clean, centered text display using Plotly's text formatting
main_text = f"⚠️ {title}
"
main_text += f"{message}"
-
+
if suggestion:
- main_text += f"
💡 Suggestion:
"
+ main_text += "
💡 Suggestion:
"
main_text += f"{suggestion}"
-
+
# Add the main text annotation
fig.add_annotation(
text=main_text,
- xref="paper", yref="paper",
- x=0.5, y=0.5,
- xanchor='center', yanchor='middle',
+ xref="paper",
+ yref="paper",
+ x=0.5,
+ y=0.5,
+ xanchor="center",
+ yanchor="middle",
showarrow=False,
align="center",
borderwidth=2,
bordercolor="#4a5568",
bgcolor="#2d3748",
- font=dict(
- family="Arial, sans-serif",
- size=14,
- color="#e2e8f0"
- )
+ font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
)
-
+
# Update layout with dark theme
fig.update_layout(
title="",
@@ -61,13 +58,13 @@ def create_styled_error_figure(title: str, message: str, suggestion: str = None)
margin=dict(l=40, r=40, t=40, b=40),
plot_bgcolor="#1a202c",
paper_bgcolor="#1a202c",
- font=dict(color="#e2e8f0")
+ font=dict(color="#e2e8f0"),
)
-
+
# Remove axes and grid
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)
-
+
return fig
@@ -82,78 +79,85 @@ def graph_structure(file_path: Path) -> go.Figure:
"ONNX library is required for model analysis.",
"Install with: pip install onnx"
)
-
+
try:
# Load ONNX model
model = onnx.load(str(file_path))
graph = model.graph
nodes = graph.node
-
+
if len(nodes) == 0:
return create_styled_error_figure(
"Empty Model",
"This ONNX model contains no operators.",
"Please check if the model file is valid."
)
-
+
# Create network diagram data
node_info = []
for i, node in enumerate(nodes):
- node_info.append({
- 'id': i,
- 'name': node.name or f"{node.op_type}_{i}",
- 'op_type': node.op_type,
- 'inputs': len(node.input),
- 'outputs': len(node.output)
- })
-
+ node_info.append(
+ {
+ "id": i,
+ "name": node.name or f"{node.op_type}_{i}",
+ "op_type": node.op_type,
+ "inputs": len(node.input),
+ "outputs": len(node.output),
+ }
+ )
+
# Create visualization
fig = go.Figure()
-
+
# Simple linear layout for now
x_positions = list(range(len(node_info)))
y_positions = [0] * len(node_info)
-
+
# Add nodes as scatter points
- fig.add_trace(go.Scatter(
- x=x_positions,
- y=y_positions,
- mode='markers+text',
- marker=dict(
- size=[min(max(info['inputs'] + info['outputs'] + 15, 20), 50) for info in node_info],
- color=px.colors.qualitative.Set3[:len(node_info)],
- opacity=0.8,
- line=dict(width=2, color='white')
- ),
- text=[f"{info['op_type']}" for info in node_info],
- textposition="middle center",
- textfont=dict(size=10, color="white"),
- hovertemplate="%{text}
" +
- "Name: %{customdata[0]}
" +
- "Inputs: %{customdata[1]}
" +
- "Outputs: %{customdata[2]}
" +
- "",
- customdata=[[info['name'], info['inputs'], info['outputs']] for info in node_info],
- name="Operators"
- ))
-
+ fig.add_trace(
+ go.Scatter(
+ x=x_positions,
+ y=y_positions,
+ mode="markers+text",
+ marker=dict(
+ size=[min(max(info["inputs"] + info["outputs"] + 15, 20), 50) for info in node_info],
+ color=px.colors.qualitative.Set3[: len(node_info)],
+ opacity=0.8,
+ line=dict(width=2, color="white"),
+ ),
+ text=[f"{info['op_type']}" for info in node_info],
+ textposition="middle center",
+ textfont=dict(size=10, color="white"),
+ hovertemplate="%{text}
"
+ + "Name: %{customdata[0]}
"
+ + "Inputs: %{customdata[1]}
"
+ + "Outputs: %{customdata[2]}
"
+ + "",
+ customdata=[[info["name"], info["inputs"], info["outputs"]] for info in node_info],
+ name="Operators",
+ )
+ )
+
# Add connecting lines
for i in range(len(node_info) - 1):
- fig.add_trace(go.Scatter(
- x=[x_positions[i], x_positions[i+1]],
- y=[y_positions[i], y_positions[i+1]],
- mode='lines',
- line=dict(color='gray', width=1, dash='dot'),
- showlegend=False,
- hoverinfo='skip'
- ))
-
+ fig.add_trace(
+ go.Scatter(
+ x=[x_positions[i], x_positions[i + 1]],
+ y=[y_positions[i], y_positions[i + 1]],
+ mode="lines",
+ line=dict(color="gray", width=1, dash="dot"),
+ showlegend=False,
+ hoverinfo="skip",
+ )
+ )
+
fig.update_layout(
title={
- 'text': f"ONNX Graph Structure
{len(nodes)} Operators",
- 'x': 0.5,
- 'xanchor': 'center',
- 'font': {'size': 22}
+ "text": ("ONNX Graph Structure
"
+ f"{len(nodes)} Operators"),
+ "x": 0.5,
+ "xanchor": "center",
+ "font": {"size": 22},
},
xaxis_title="Execution Order",
yaxis_title="",
@@ -162,15 +166,15 @@ def graph_structure(file_path: Path) -> go.Figure:
template="plotly_dark",
yaxis=dict(showticklabels=False, showgrid=False),
xaxis=dict(showgrid=False),
- margin=dict(l=50, r=50, t=80, b=50)
+ margin=dict(l=50, r=50, t=80, b=50),
)
-
+
return fig
-
+
except Exception as e:
return create_styled_error_figure(
- "Graph Analysis Error",
- f"Could not analyze ONNX model structure.",
+ "Graph Analysis Error",
+ "Could not analyze ONNX model structure.",
f"Error: {str(e)}"
)
@@ -186,76 +190,80 @@ def operator_analysis(file_path: Path) -> go.Figure:
"ONNX library is required for operator analysis.",
"Install with: pip install onnx"
)
-
+
try:
model = onnx.load(str(file_path))
graph = model.graph
-
+
# Count operators
op_counts = {}
for node in graph.node:
op_type = node.op_type
op_counts[op_type] = op_counts.get(op_type, 0) + 1
-
+
if not op_counts:
return create_styled_error_figure(
"No Operators",
"This ONNX model contains no operators to analyze.",
- "Please verify the model file is valid."
+ "Please verify the model file is valid.",
)
-
+
# Sort by frequency
sorted_ops = sorted(op_counts.items(), key=lambda x: x[1], reverse=True)
-
+
# Create pie chart and bar chart
fig = make_subplots(
- rows=2, cols=1,
+ rows=2,
+ cols=1,
subplot_titles=("Operator Distribution", "Operator Frequency"),
- specs=[[{"type": "pie"}], [{"type": "bar"}]]
+ specs=[[{"type": "pie"}], [{"type": "bar"}]],
)
-
+
# Pie chart for operator distribution
op_names, op_values = zip(*sorted_ops) if sorted_ops else ([], [])
-
+
fig.add_trace(
go.Pie(
labels=list(op_names),
values=list(op_values),
textinfo="label+percent",
textposition="auto",
- showlegend=False
+ showlegend=False,
),
- row=1, col=1
+ row=1,
+ col=1,
)
-
+
# Bar chart for frequency
fig.add_trace(
go.Bar(
x=list(op_names),
y=list(op_values),
- marker_color=px.colors.qualitative.Set3[:len(op_names)],
- showlegend=False
+ marker_color=px.colors.qualitative.Set3[: len(op_names)],
+ showlegend=False,
),
- row=2, col=1
+ row=2,
+ col=1,
)
-
+
fig.update_layout(
title={
- 'text': f"ONNX Operator Analysis
{len(op_counts)} Unique Types",
- 'x': 0.5,
- 'xanchor': 'center',
- 'font': {'size': 22}
+ "text": ("ONNX Operator Analysis
"
+ f"{len(op_counts)} Unique Types"),
+ "x": 0.5,
+ "xanchor": "center",
+ "font": {"size": 22},
},
height=700,
- template="plotly_dark"
+ template="plotly_dark",
)
-
+
return fig
-
+
except Exception as e:
return create_styled_error_figure(
"Operator Analysis Error",
- f"Could not analyze ONNX operators.",
+ "Could not analyze ONNX operators.",
f"Error: {str(e)}"
)
@@ -271,74 +279,76 @@ def model_metadata(file_path: Path) -> go.Figure:
"ONNX library is required for metadata analysis.",
"Install with: pip install onnx"
)
-
+
try:
model = onnx.load(str(file_path))
graph = model.graph
-
+
# Calculate basic statistics
total_nodes = len(graph.node)
total_inputs = len(graph.input)
total_outputs = len(graph.output)
total_initializers = len(graph.initializer)
-
+
# Calculate parameter count
total_params = 0
for initializer in graph.initializer:
try:
tensor = onnx.numpy_helper.to_array(initializer)
total_params += tensor.size
- except:
+ except Exception:
pass # Skip if tensor can't be loaded
-
+
# Get model file size
file_size_mb = file_path.stat().st_size / (1024 * 1024)
-
+
# Create metadata display
fig = make_subplots(
- rows=2, cols=2,
+ rows=2,
+ cols=2,
subplot_titles=("Model Size", "Architecture", "Inputs/Outputs", "Parameters"),
- specs=[[{"type": "indicator"}, {"type": "bar"}],
- [{"type": "table"}, {"type": "indicator"}]]
+ specs=[[{"type": "indicator"}, {"type": "bar"}], [{"type": "table"}, {"type": "indicator"}]],
)
-
+
# Model size indicator
fig.add_trace(
go.Indicator(
mode="number+gauge",
value=file_size_mb,
- title={'text': "Model Size (MB)"},
- number={'suffix': ' MB', 'valueformat': '.2f'},
+ title={"text": "Model Size (MB)"},
+ number={"suffix": " MB", "valueformat": ".2f"},
gauge={
- 'axis': {'range': [0, max(100, file_size_mb * 1.5)]},
- 'bar': {'color': "darkblue"},
- 'steps': [
- {'range': [0, 10], 'color': "lightgreen"},
- {'range': [10, 50], 'color': "yellow"},
- {'range': [50, 100], 'color': "orange"}
- ]
- }
+ "axis": {"range": [0, max(100, file_size_mb * 1.5)]},
+ "bar": {"color": "darkblue"},
+ "steps": [
+ {"range": [0, 10], "color": "lightgreen"},
+ {"range": [10, 50], "color": "yellow"},
+ {"range": [50, 100], "color": "orange"},
+ ],
+ },
),
- row=1, col=1
+ row=1,
+ col=1,
)
-
+
# Architecture components
arch_data = ["Nodes", "Inputs", "Outputs", "Initializers"]
arch_values = [total_nodes, total_inputs, total_outputs, total_initializers]
-
+
fig.add_trace(
go.Bar(
x=arch_data,
y=arch_values,
- marker_color=['blue', 'green', 'orange', 'red'],
+ marker_color=["blue", "green", "orange", "red"],
showlegend=False
),
- row=1, col=2
+ row=1,
+ col=2,
)
-
+
# I/O Table
io_data = []
-
+
# Add input info
for inp in graph.input[:5]: # Limit to first 5
shape = "Unknown"
@@ -346,82 +356,99 @@ def model_metadata(file_path: Path) -> go.Figure:
if inp.type and inp.type.tensor_type:
# Get shape
if inp.type.tensor_type.shape:
- dims = [str(d.dim_value) if d.dim_value > 0 else "?"
- for d in inp.type.tensor_type.shape.dim]
+ dims = [str(d.dim_value) if d.dim_value > 0 else "?" for d in inp.type.tensor_type.shape.dim]
shape = f"[{', '.join(dims)}]"
-
+
# Get data type
elem_type = inp.type.tensor_type.elem_type
- type_map = {1: 'float32', 2: 'uint8', 3: 'int8', 6: 'int32',
- 7: 'int64', 9: 'bool', 10: 'float16', 11: 'double'}
- dtype = type_map.get(elem_type, f'type_{elem_type}')
-
- io_data.append(['Input', inp.name[:20], shape, dtype])
-
+ type_map = {
+ 1: "float32",
+ 2: "uint8",
+ 3: "int8",
+ 6: "int32",
+ 7: "int64",
+ 9: "bool",
+ 10: "float16",
+ 11: "double",
+ }
+ dtype = type_map.get(elem_type, f"type_{elem_type}")
+
+ io_data.append(["Input", inp.name[:20], shape, dtype])
+
# Add output info
for out in graph.output[:5]: # Limit to first 5
shape = "Unknown"
dtype = "Unknown"
if out.type and out.type.tensor_type:
if out.type.tensor_type.shape:
- dims = [str(d.dim_value) if d.dim_value > 0 else "?"
- for d in out.type.tensor_type.shape.dim]
+ dims = [str(d.dim_value) if d.dim_value > 0 else "?" for d in out.type.tensor_type.shape.dim]
shape = f"[{', '.join(dims)}]"
-
+
elem_type = out.type.tensor_type.elem_type
- type_map = {1: 'float32', 2: 'uint8', 3: 'int8', 6: 'int32',
- 7: 'int64', 9: 'bool', 10: 'float16', 11: 'double'}
- dtype = type_map.get(elem_type, f'type_{elem_type}')
-
- io_data.append(['Output', out.name[:20], shape, dtype])
-
+ type_map = {
+ 1: "float32",
+ 2: "uint8",
+ 3: "int8",
+ 6: "int32",
+ 7: "int64",
+ 9: "bool",
+ 10: "float16",
+ 11: "double",
+ }
+ dtype = type_map.get(elem_type, f"type_{elem_type}")
+
+ io_data.append(["Output", out.name[:20], shape, dtype])
+
if io_data:
fig.add_trace(
go.Table(
header=dict(
- values=['Type', 'Name', 'Shape', 'Data Type'],
- fill_color='lightblue',
- align='left'
+ values=["Type", "Name", "Shape", "Data Type"],
+ fill_color="lightblue",
+ align="left"
),
cells=dict(
values=list(zip(*io_data)),
- fill_color='white',
- align='left'
- )
+ fill_color="white",
+ align="left"
+ ),
),
- row=2, col=1
+ row=2,
+ col=1,
)
-
+
# Parameters indicator
fig.add_trace(
go.Indicator(
mode="number",
value=total_params,
- title={'text': "Total Parameters"},
- number={'suffix': 'M', 'valueformat': '.2f'},
- number_font_size=30
+ title={"text": "Total Parameters"},
+ number={"suffix": "M", "valueformat": ".2f"},
+ number_font_size=30,
),
- row=2, col=2
+ row=2,
+ col=2,
)
-
+
fig.update_layout(
title={
- 'text': f"ONNX Model Metadata
{total_params/1e6:.2f}M Parameters",
- 'x': 0.5,
- 'xanchor': 'center',
- 'font': {'size': 22}
+ "text": ("ONNX Model Metadata
"
+ f"{total_params/1e6:.2f}M Parameters"),
+ "x": 0.5,
+ "xanchor": "center",
+ "font": {"size": 22},
},
height=600,
template="plotly_dark",
- showlegend=False
+ showlegend=False,
)
-
+
return fig
-
+
except Exception as e:
return create_styled_error_figure(
"Metadata Analysis Error",
- f"Could not extract ONNX model metadata.",
+ "Could not extract ONNX model metadata.",
f"Error: {str(e)}"
)
@@ -435,124 +462,130 @@ def performance_metrics(file_path: Path) -> go.Figure:
return create_styled_error_figure(
"ONNX Not Available",
"ONNX library is required for performance analysis.",
- "Install with: pip install onnx"
+ "Install with: pip install onnx",
)
-
+
try:
model = onnx.load(str(file_path))
graph = model.graph
-
+
# Calculate metrics
model_size_bytes = file_path.stat().st_size
model_size_mb = model_size_bytes / (1024 * 1024)
-
+
# Count parameters
total_params = 0
for initializer in graph.initializer:
try:
tensor = onnx.numpy_helper.to_array(initializer)
total_params += tensor.size
- except:
+ except Exception:
pass
-
+
# Estimate memory usage (rough approximation)
param_memory_mb = (total_params * 4) / (1024 * 1024) # Assume float32
-
+
# Count operations by complexity
- compute_ops = ['Conv', 'MatMul', 'Gemm', 'LSTM', 'GRU']
- efficient_ops = ['Relu', 'Add', 'Mul', 'BatchNormalization', 'Dropout']
-
- compute_count = sum(1 for node in graph.node
- if any(op in node.op_type for op in compute_ops))
- efficient_count = sum(1 for node in graph.node
- if any(op in node.op_type for op in efficient_ops))
+ compute_ops = ["Conv", "MatMul", "Gemm", "LSTM", "GRU"]
+ efficient_ops = ["Relu", "Add", "Mul", "BatchNormalization", "Dropout"]
+
+ compute_count = sum(1 for node in graph.node if any(op in node.op_type for op in compute_ops))
+ efficient_count = sum(1 for node in graph.node if any(op in node.op_type for op in efficient_ops))
total_ops = len(graph.node)
other_count = total_ops - compute_count - efficient_count
-
+
# Create performance dashboard
fig = make_subplots(
- rows=2, cols=2,
+ rows=2,
+ cols=2,
subplot_titles=("Model Efficiency", "Memory Usage", "Operation Types", "Complexity Score"),
- specs=[[{"type": "bar"}, {"type": "bar"}],
- [{"type": "pie"}, {"type": "indicator"}]]
+ specs=[[{"type": "bar"}, {"type": "bar"}], [{"type": "pie"}, {"type": "indicator"}]],
)
-
+
# Model efficiency metrics
efficiency_metrics = ["Model Size (MB)", "Parameters (M)", "Total Ops"]
- efficiency_values = [model_size_mb, total_params/1e6, total_ops]
-
+ efficiency_values = [model_size_mb, total_params / 1e6, total_ops]
+
fig.add_trace(
go.Bar(
x=efficiency_metrics,
y=efficiency_values,
- marker_color=['blue', 'green', 'orange'],
+ marker_color=["blue", "green", "orange"],
showlegend=False
),
- row=1, col=1
+ row=1,
+ col=1,
)
-
+
# Memory usage
memory_types = ["Parameters", "Est. Inference"]
memory_values = [param_memory_mb, param_memory_mb * 2] # Rough estimate
-
+
fig.add_trace(
go.Bar(
x=memory_types,
y=memory_values,
- marker_color=['purple', 'red'],
+ marker_color=["purple", "red"],
showlegend=False
),
- row=1, col=2
+ row=1,
+ col=2,
)
-
+
# Operation types pie chart
fig.add_trace(
go.Pie(
- labels=['Compute Ops', 'Efficient Ops', 'Other Ops'],
+ labels=["Compute Ops", "Efficient Ops", "Other Ops"],
values=[compute_count, efficient_count, other_count],
- marker_colors=['red', 'green', 'gray']
+ marker_colors=["red", "green", "gray"],
),
- row=2, col=1
+ row=2,
+ col=1,
)
-
+
# Complexity score (simple heuristic)
complexity_score = min(100, (model_size_mb * 10 + total_params / 1e6 * 20 + compute_count))
-
+
fig.add_trace(
go.Indicator(
mode="gauge+number",
value=complexity_score,
- title={'text': "Complexity Score"},
+ title={"text": "Complexity Score"},
gauge={
- 'axis': {'range': [0, 100]},
- 'bar': {'color': "darkred" if complexity_score > 70 else "orange" if complexity_score > 40 else "green"},
- 'steps': [
- {'range': [0, 40], 'color': "lightgreen"},
- {'range': [40, 70], 'color': "yellow"},
- {'range': [70, 100], 'color': "lightcoral"}
- ]
- }
+ "axis": {"range": [0, 100]},
+ "bar": {
+ "color": "darkred" if complexity_score > 70 else "orange" if complexity_score > 40 else "green"
+ },
+ "steps": [
+ {"range": [0, 40], "color": "lightgreen"},
+ {"range": [40, 70], "color": "yellow"},
+ {"range": [70, 100], "color": "lightcoral"},
+ ],
+ },
),
- row=2, col=2
+ row=2,
+ col=2,
)
-
+
fig.update_layout(
title={
- 'text': f"ONNX Performance Metrics
Complexity Score: {complexity_score:.0f}/100",
- 'x': 0.5,
- 'xanchor': 'center',
- 'font': {'size': 22}
+ "text": ("ONNX Performance Metrics
"
+ f""
+ f"Complexity Score: {complexity_score:.0f}/100"),
+ "x": 0.5,
+ "xanchor": "center",
+ "font": {"size": 22},
},
height=600,
template="plotly_dark",
- showlegend=False
+ showlegend=False,
)
-
+
return fig
-
+
except Exception as e:
return create_styled_error_figure(
"Performance Analysis Error",
- f"Could not analyze ONNX model performance.",
+ "Could not analyze ONNX model performance.",
f"Error: {str(e)}"
- )
\ No newline at end of file
+ )
diff --git a/src/ria_toolkit_oss/viz/pytorch_state_dict.py b/src/ria_toolkit_oss/viz/pytorch_state_dict.py
index 7db7528..b549662 100644
--- a/src/ria_toolkit_oss/viz/pytorch_state_dict.py
+++ b/src/ria_toolkit_oss/viz/pytorch_state_dict.py
@@ -1,7 +1,7 @@
-import torch
+import numpy as np
import plotly.graph_objects as go
from plotly.graph_objects import Figure
-import numpy as np
+
def model_summary_plot(state_dict: dict) -> Figure:
"""Generate a summary plot of the PyTorch model state dict."""
@@ -10,220 +10,212 @@ def model_summary_plot(state_dict: dict) -> Figure:
fig = go.Figure()
fig.add_annotation(
text="No parameters found in state dict",
- xref="paper", yref="paper",
- x=0.5, y=0.5, showarrow=False,
- font=dict(size=16)
+ xref="paper",
+ yref="paper",
+ x=0.5,
+ y=0.5,
+ showarrow=False,
+ font=dict(size=16),
)
fig.update_layout(
title="Model Layer Parameter Counts",
xaxis_title="Layer",
yaxis_title="Number of Parameters",
- template="plotly_dark"
+ template="plotly_dark",
)
return fig
-
+
# Count parameters by layer type
layer_info = []
for key, tensor in state_dict.items():
- if 'weight' in key:
+ if "weight" in key:
try:
- layer_name = key.replace('.weight', '')
- param_count = tensor.numel() if hasattr(tensor, 'numel') else len(tensor.flatten()) if hasattr(tensor, 'flatten') else 0
- shape = list(tensor.shape) if hasattr(tensor, 'shape') else [len(tensor)] if hasattr(tensor, '__len__') else []
- layer_info.append({
- 'layer': layer_name,
- 'parameters': param_count,
- 'shape': shape
- })
+ layer_name = key.replace(".weight", "")
+ param_count = (
+ tensor.numel()
+ if hasattr(tensor, "numel")
+ else len(tensor.flatten()) if hasattr(tensor, "flatten") else 0
+ )
+ shape = (
+ list(tensor.shape)
+ if hasattr(tensor, "shape")
+ else [len(tensor)] if hasattr(tensor, "__len__") else []
+ )
+ layer_info.append({"layer": layer_name, "parameters": param_count, "shape": shape})
except Exception as e:
print(f"Warning: Could not process layer {key}: {e}")
continue
-
+
if not layer_info:
# Handle case where no weight layers found
fig = go.Figure()
fig.add_annotation(
text="No weight layers found in state dict",
- xref="paper", yref="paper",
- x=0.5, y=0.5, showarrow=False,
- font=dict(size=16)
+ xref="paper",
+ yref="paper",
+ x=0.5,
+ y=0.5,
+ showarrow=False,
+ font=dict(size=16),
)
fig.update_layout(
title="Model Layer Parameter Counts",
xaxis_title="Layer",
yaxis_title="Number of Parameters",
- template="plotly_dark"
+ template="plotly_dark",
)
return fig
-
+
# Create bar chart of parameter counts
- fig = go.Figure(data=[
- go.Bar(
- x=[info['layer'] for info in layer_info],
- y=[info['parameters'] for info in layer_info],
- text=[f"Shape: {info['shape']}" for info in layer_info],
- textposition='auto',
- )
- ])
-
+ fig = go.Figure(
+ data=[
+ go.Bar(
+ x=[info["layer"] for info in layer_info],
+ y=[info["parameters"] for info in layer_info],
+ text=[f"Shape: {info['shape']}" for info in layer_info],
+ textposition="auto",
+ )
+ ]
+ )
+
fig.update_layout(
title="Model Layer Parameter Counts",
xaxis_title="Layer",
yaxis_title="Number of Parameters",
- template="plotly_dark"
+ template="plotly_dark",
)
-
+
return fig
+
def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
"""Visualize weights for a specific layer."""
if not state_dict:
fig = go.Figure()
fig.add_annotation(
- text="No data in state dict",
- xref="paper", yref="paper",
- x=0.5, y=0.5, showarrow=False,
- font=dict(size=16)
- )
- fig.update_layout(
- title="Layer Weights",
- template="plotly_dark"
+ text="No data in state dict", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=16)
)
+ fig.update_layout(title="Layer Weights", template="plotly_dark")
return fig
-
+
if layer_name is None:
# Get first weight tensor
- weight_keys = [k for k in state_dict.keys() if 'weight' in k]
+ weight_keys = [k for k in state_dict.keys() if "weight" in k]
if not weight_keys:
fig = go.Figure()
fig.add_annotation(
text="No weight tensors found in state dict",
- xref="paper", yref="paper",
- x=0.5, y=0.5, showarrow=False,
- font=dict(size=16)
- )
- fig.update_layout(
- title="Layer Weights",
- template="plotly_dark"
+ xref="paper",
+ yref="paper",
+ x=0.5,
+ y=0.5,
+ showarrow=False,
+ font=dict(size=16),
)
+ fig.update_layout(title="Layer Weights", template="plotly_dark")
return fig
layer_name = weight_keys[0]
-
+
try:
weights = state_dict[layer_name]
-
+
# Convert to numpy if it's a torch tensor
- if hasattr(weights, 'numpy'):
- weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy()
- elif hasattr(weights, 'cpu'):
+ if hasattr(weights, "numpy"):
+ weights_np = weights.detach().numpy() if hasattr(weights, "detach") else weights.numpy()
+ elif hasattr(weights, "cpu"):
weights_np = weights.cpu().detach().numpy()
else:
weights_np = np.array(weights)
-
+
# For 2D weights, create heatmap
if len(weights_np.shape) == 2:
- fig = go.Figure(data=go.Heatmap(
- z=weights_np,
- colorscale='RdBu',
- zmid=0
- ))
- fig.update_layout(
- title=f"Weights Heatmap: {layer_name}",
- template="plotly_dark"
- )
+ fig = go.Figure(data=go.Heatmap(z=weights_np, colorscale="RdBu", zmid=0))
+ fig.update_layout(title=f"Weights Heatmap: {layer_name}", template="plotly_dark")
else:
# For other shapes, flatten and show histogram
flat_weights = weights_np.flatten()
fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
- fig.update_layout(
- title=f"Weight Distribution: {layer_name}",
- template="plotly_dark"
- )
-
+ fig.update_layout(title=f"Weight Distribution: {layer_name}", template="plotly_dark")
+
return fig
-
+
except Exception as e:
fig = go.Figure()
fig.add_annotation(
text=f"Error processing layer {layer_name}: {str(e)}",
- xref="paper", yref="paper",
- x=0.5, y=0.5, showarrow=False,
- font=dict(size=14)
- )
- fig.update_layout(
- title="Layer Weights - Error",
- template="plotly_dark"
+ xref="paper",
+ yref="paper",
+ x=0.5,
+ y=0.5,
+ showarrow=False,
+ font=dict(size=14),
)
+ fig.update_layout(title="Layer Weights - Error", template="plotly_dark")
return fig
+
def weight_distribution_plot(state_dict: dict) -> Figure:
"""Show distribution of weights across all layers."""
if not state_dict:
fig = go.Figure()
fig.add_annotation(
- text="No data in state dict",
- xref="paper", yref="paper",
- x=0.5, y=0.5, showarrow=False,
- font=dict(size=16)
+ text="No data in state dict", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=16)
)
fig.update_layout(
title="Overall Weight Distribution",
xaxis_title="Weight Value",
yaxis_title="Frequency",
- template="plotly_dark"
+ template="plotly_dark",
)
return fig
-
+
all_weights = []
layer_names = []
-
+
for key, tensor in state_dict.items():
- if 'weight' in key:
+ if "weight" in key:
try:
# Convert to numpy if it's a torch tensor
- if hasattr(tensor, 'numpy'):
- weights_np = tensor.detach().numpy() if hasattr(tensor, 'detach') else tensor.numpy()
- elif hasattr(tensor, 'cpu'):
+ if hasattr(tensor, "numpy"):
+ weights_np = tensor.detach().numpy() if hasattr(tensor, "detach") else tensor.numpy()
+ elif hasattr(tensor, "cpu"):
weights_np = tensor.cpu().detach().numpy()
else:
weights_np = np.array(tensor)
-
+
flat_weights = weights_np.flatten()
all_weights.extend(flat_weights)
layer_names.extend([key] * len(flat_weights))
except Exception as e:
print(f"Warning: Could not process weights for layer {key}: {e}")
continue
-
+
if not all_weights:
fig = go.Figure()
fig.add_annotation(
text="No weight data found in state dict",
- xref="paper", yref="paper",
- x=0.5, y=0.5, showarrow=False,
- font=dict(size=16)
+ xref="paper",
+ yref="paper",
+ x=0.5,
+ y=0.5,
+ showarrow=False,
+ font=dict(size=16),
)
fig.update_layout(
title="Overall Weight Distribution",
xaxis_title="Weight Value",
yaxis_title="Frequency",
- template="plotly_dark"
+ template="plotly_dark",
)
return fig
-
- fig = go.Figure(data=[
- go.Histogram(
- x=all_weights,
- nbinsx=100,
- name="All Weights"
- )
- ])
-
+
+ fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=100, name="All Weights")])
+
fig.update_layout(
title="Overall Weight Distribution",
xaxis_title="Weight Value",
yaxis_title="Frequency",
- template="plotly_dark"
+ template="plotly_dark",
)
-
- return fig
\ No newline at end of file
+
+ return fig
diff --git a/src/ria_toolkit_oss/viz/radio_dataset.py b/src/ria_toolkit_oss/viz/radio_dataset.py
index edc5004..cae4084 100644
--- a/src/ria_toolkit_oss/viz/radio_dataset.py
+++ b/src/ria_toolkit_oss/viz/radio_dataset.py
@@ -6,7 +6,6 @@ import random
from typing import Optional
import numpy as np
-import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.graph_objects import Figure
@@ -16,33 +15,32 @@ from plotly.subplots import make_subplots
def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> Figure:
"""Create a professional error figure with Qoherent dark theme styling."""
fig = go.Figure()
-
+
# Create a clean, centered text display using Plotly's text formatting
main_text = f"⚠️ {title}
"
main_text += f"{message}"
-
+
if suggestion:
- main_text += f"
💡 Suggestion:
"
+ main_text += "
💡 Suggestion:
"
main_text += f"{suggestion}"
-
+
# Add the main text annotation
fig.add_annotation(
text=main_text,
- xref="paper", yref="paper",
- x=0.5, y=0.5,
- xanchor='center', yanchor='middle',
+ xref="paper",
+ yref="paper",
+ x=0.5,
+ y=0.5,
+ xanchor="center",
+ yanchor="middle",
showarrow=False,
align="center",
borderwidth=2,
bordercolor="#4a5568",
bgcolor="#2d3748",
- font=dict(
- family="Arial, sans-serif",
- size=14,
- color="#e2e8f0"
- )
+ font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
)
-
+
# Update layout with dark theme
fig.update_layout(
title="",
@@ -51,13 +49,13 @@ def create_styled_error_figure(title: str, message: str, suggestion: str = None)
margin=dict(l=40, r=40, t=40, b=40),
plot_bgcolor="#1a202c",
paper_bgcolor="#1a202c",
- font=dict(color="#e2e8f0")
+ font=dict(color="#e2e8f0"),
)
-
+
# Remove axes and grid
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)
-
+
return fig
@@ -67,37 +65,37 @@ def _check_dataset_compatibility(dataset, plot_type: str) -> tuple[bool, str]:
"""
try:
metadata = dataset.metadata
-
+
if len(metadata) == 0:
return False, "Dataset is empty"
-
+
if plot_type == "class_distribution":
# Check if we have any categorical columns
- categorical_cols = [col for col in metadata.columns if metadata[col].dtype == 'object']
+ categorical_cols = [col for col in metadata.columns if metadata[col].dtype == "object"]
alternatives = ["class", "label", "modulation", "impairment", "use_case", "category", "labels"]
-
+
has_class_col = any(alt in metadata.columns for alt in alternatives)
has_categorical = len(categorical_cols) > 0
-
+
if not has_class_col and not has_categorical:
return False, "No categorical columns found for class distribution"
-
+
elif plot_type == "sample_spectrogram":
# Check if we can generate a valid spectrogram
if len(metadata) < 1:
return False, "No samples available for spectrogram"
-
+
# Check if we can access sample data (basic test)
try:
- sample_data = dataset[0] if hasattr(dataset, '__getitem__') else None
+ sample_data = dataset[0] if hasattr(dataset, "__getitem__") else None
if sample_data is None or len(sample_data) < 32:
return False, "Insufficient sample data for spectrogram (need at least 32 points)"
except Exception:
# If we can't access data, we'll rely on synthetic data generation
pass
-
+
return True, ""
-
+
except Exception as e:
return False, f"Dataset compatibility check failed: {str(e)}"
@@ -111,11 +109,11 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
return create_styled_error_figure(
"Dataset Not Compatible",
"This dataset doesn't have categorical labels needed for class distribution analysis.",
- "Try using the Dataset Overview widget to explore the available data columns."
+ "Try using the Dataset Overview widget to explore the available data columns.",
)
-
+
metadata = dataset.metadata
-
+
# Find the class column
if class_key not in metadata.columns:
# Try common alternatives
@@ -127,47 +125,44 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
else:
# Use first categorical column
for col in metadata.columns:
- if metadata[col].dtype == 'object' or metadata[col].nunique() < 50:
+ if metadata[col].dtype == "object" or metadata[col].nunique() < 50:
class_key = col
break
-
+
if class_key not in metadata.columns:
return create_styled_error_figure(
"No Class Labels Found",
"This dataset contains numerical data without categorical labels.",
- "Try using the Dataset Overview widget for data analysis, or check if your dataset has hidden categorical columns."
+ ("Try using the Dataset Overview widget for data analysis, "
+ "or check if your dataset has hidden categorical columns."),
)
-
+
# Count examples per class (limit to top 20 for performance)
class_counts = metadata[class_key].value_counts()
if len(class_counts) > 20:
class_counts = class_counts.head(20)
-
+
class_counts = class_counts.sort_index()
-
+
# Create simple bar plot
- fig = px.bar(
- x=class_counts.index,
- y=class_counts.values,
- title=f'Class Distribution: {class_key.title()}'
- )
-
- fig.update_traces(texttemplate='%{y}', textposition='outside')
+ fig = px.bar(x=class_counts.index, y=class_counts.values, title=f"Class Distribution: {class_key.title()}")
+
+ fig.update_traces(texttemplate="%{y}", textposition="outside")
fig.update_layout(
xaxis_title=class_key.title(),
- yaxis_title='Number of Examples',
+ yaxis_title="Number of Examples",
showlegend=False,
height=400,
- template="plotly_dark"
+ template="plotly_dark",
)
-
+
return fig
-
+
except Exception as e:
return create_styled_error_figure(
"Class Distribution Error",
- f"An error occurred while generating the class distribution plot.",
- f"Technical details: {str(e)}"
+ "An error occurred while generating the class distribution plot.",
+ f"Technical details: {str(e)}",
)
@@ -176,96 +171,84 @@ def dataset_overview_plot(dataset) -> Figure:
try:
metadata = dataset.metadata
total_examples = len(metadata)
-
+
# Create subplot with multiple charts
-
+
# Determine subplot titles based on data type
- categorical_cols = [col for col in metadata.columns if metadata[col].dtype == 'object']
- numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ['int64', 'float64']]
-
+ categorical_cols = [col for col in metadata.columns if metadata[col].dtype == "object"]
+ numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ["int64", "float64"]]
+
dist_title = "Value Distribution" if categorical_cols else "Data Distribution"
-
+
fig = make_subplots(
- rows=2, cols=2,
+ rows=2,
+ cols=2,
subplot_titles=("Dataset Size", "Data Types", dist_title, "Statistics Summary"),
- specs=[[{"type": "indicator"}, {"type": "bar"}],
- [{"type": "histogram" if not categorical_cols else "bar"}, {"type": "table"}]]
+ specs=[
+ [{"type": "indicator"}, {"type": "bar"}],
+ [{"type": "histogram" if not categorical_cols else "bar"}, {"type": "table"}],
+ ],
)
-
+
# Top left: Dataset size indicator
fig.add_trace(
go.Indicator(
- mode="number",
- value=total_examples,
- title={"text": "Total Examples"},
- number={"font": {"size": 40}}
+ mode="number", value=total_examples, title={"text": "Total Examples"}, number={"font": {"size": 40}}
),
- row=1, col=1
+ row=1,
+ col=1,
)
-
+
# Top right: Data types distribution
dtype_counts = metadata.dtypes.value_counts()
fig.add_trace(
go.Bar(
- x=[str(dt) for dt in dtype_counts.index],
- y=dtype_counts.values,
- name="Data Types",
- showlegend=False
+ x=[str(dt) for dt in dtype_counts.index], y=dtype_counts.values, name="Data Types", showlegend=False
),
- row=1, col=2
+ row=1,
+ col=2,
)
-
+
# Bottom left: Show distribution of numeric columns or categorical if available
- categorical_cols = [col for col in metadata.columns if metadata[col].dtype == 'object']
- numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ['int64', 'float64']]
-
+ categorical_cols = [col for col in metadata.columns if metadata[col].dtype == "object"]
+ numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ["int64", "float64"]]
+
if categorical_cols:
col = categorical_cols[0] # Show first categorical column
value_counts = metadata[col].value_counts().head(10)
fig.add_trace(
- go.Bar(
- x=value_counts.index,
- y=value_counts.values,
- name=f"{col} Distribution",
- showlegend=False
- ),
- row=2, col=1
+ go.Bar(x=value_counts.index, y=value_counts.values, name=f"{col} Distribution", showlegend=False),
+ row=2,
+ col=1,
)
elif numeric_cols:
# Show histogram of first numeric column
col = numeric_cols[0]
fig.add_trace(
- go.Histogram(
- x=metadata[col],
- name=f"{col} Distribution",
- showlegend=False,
- nbinsx=20
- ),
- row=2, col=1
+ go.Histogram(x=metadata[col], name=f"{col} Distribution", showlegend=False, nbinsx=20), row=2, col=1
)
-
+
# Bottom right: Basic statistics table
stats_data = []
- display_cols = (numeric_cols[:5] if len(numeric_cols) > 0 else metadata.columns[:5])
-
+ display_cols = numeric_cols[:5] if len(numeric_cols) > 0 else metadata.columns[:5]
+
for col in display_cols:
- if metadata[col].dtype in ['int64', 'float64']:
- stats_data.append([
- col[:15] + "..." if len(col) > 15 else col, # Truncate long column names
- f"{metadata[col].mean():.3f}",
- f"{metadata[col].std():.3f}",
- f"{metadata[col].min():.3f}",
- f"{metadata[col].max():.3f}"
- ])
+ if metadata[col].dtype in ["int64", "float64"]:
+ stats_data.append(
+ [
+ col[:15] + "..." if len(col) > 15 else col, # Truncate long column names
+ f"{metadata[col].mean():.3f}",
+ f"{metadata[col].std():.3f}",
+ f"{metadata[col].min():.3f}",
+ f"{metadata[col].max():.3f}",
+ ]
+ )
else:
unique_count = metadata[col].nunique()
- stats_data.append([
- col[:15] + "..." if len(col) > 15 else col,
- "N/A", "N/A",
- f"{unique_count} unique",
- "N/A"
- ])
-
+ stats_data.append(
+ [col[:15] + "..." if len(col) > 15 else col, "N/A", "N/A", f"{unique_count} unique", "N/A"]
+ )
+
if stats_data:
fig.add_trace(
go.Table(
@@ -273,41 +256,127 @@ def dataset_overview_plot(dataset) -> Figure:
values=["Column", "Mean", "Std", "Min/Unique", "Max"],
fill_color="rgba(30, 30, 30, 0.8)",
align="center",
- font=dict(color="white", size=12)
+ font=dict(color="white", size=12),
),
cells=dict(
values=list(zip(*stats_data)),
fill_color="rgba(50, 50, 50, 0.6)",
align="center",
- font=dict(color="white", size=11)
- )
+ font=dict(color="white", size=11),
+ ),
),
- row=2, col=2
+ row=2,
+ col=2,
)
-
+
# Create informative title
total_cols = len(metadata.columns)
title = f"Dataset Overview - {total_examples} samples, {total_cols} columns"
if total_cols > 5:
- title += f" (showing first 5)"
-
- fig.update_layout(
- title=title,
- height=600,
- showlegend=False,
- template="plotly_dark"
- )
-
+ title += " (showing first 5)"
+
+ fig.update_layout(title=title, height=600, showlegend=False, template="plotly_dark")
+
return fig
-
+
except Exception as e:
return create_styled_error_figure(
"Dataset Overview Error",
"An error occurred while generating the dataset overview.",
- f"Technical details: {str(e)}"
+ f"Technical details: {str(e)}",
)
+def _find_class_column(metadata, class_key: str) -> str:
+    """Return class_key if metadata has that column, else the first fallback name found (or class_key)."""
+    if class_key in metadata.columns:
+        return class_key
+
+    alternatives = ["class", "label", "modulation", "impairment", "use_case"]  # tried in this order
+    for alt in alternatives:
+        if alt in metadata.columns:
+            return alt
+    return class_key  # nothing matched: the requested key is returned even though it is not a column
+
+
+def _get_sample_data(dataset, sample_idx: int):
+    """Return dataset[sample_idx]; if indexing raises an Exception, return a synthetic noisy complex tone."""
+    try:
+        return dataset[sample_idx]
+    except Exception:  # NOTE(review): broad catch; any dataset error is silently replaced by fake data
+        # Synthetic fallback: a 1024-point complex exponential whose frequency depends on sample_idx
+        n_samples = 1024
+        t = np.linspace(0, 1, n_samples)
+        freq = 0.1 + 0.05 * sample_idx % 5  # NOTE(review): parses as (0.05 * sample_idx) % 5 - confirm intent
+        sample_data = np.exp(1j * 2 * np.pi * freq * t)
+        # Add complex Gaussian noise (standard normal scaled by 0.1 on the real and imaginary parts)
+        sample_data += 0.1 * (np.random.randn(n_samples) + 1j * np.random.randn(n_samples))
+        return sample_data
+
+
+def _calculate_spectrogram_params(n_samples: int) -> tuple[int, int, int, int]:
+    """Return (nperseg, hop_length, n_frames, freq_bins) for n_samples; raise ValueError if n_samples < 32."""
+    if n_samples < 32:
+        raise ValueError(f"Insufficient data: need at least 32 samples, got {n_samples}")
+
+    nperseg = min(256, max(32, n_samples // 4))  # segment length: a quarter of the signal, clamped to [32, 256]
+    hop_length = max(1, nperseg // 2)  # consecutive segments start half a segment apart
+
+    # Defensive only: with n_samples >= 32 guaranteed above, nperseg never exceeds n_samples here
+    if n_samples < nperseg:
+        nperseg = n_samples
+        hop_length = 1
+
+    n_frames = max(1, (n_samples - nperseg) // hop_length + 1)  # count of full segments that fit
+    freq_bins = max(1, nperseg // 2)  # only the first half of each FFT is kept
+
+    return nperseg, hop_length, n_frames, freq_bins
+
+
+def _compute_spectrogram(sample_data, nperseg: int, hop_length: int, n_frames: int, freq_bins: int):
+    """Return a (freq_bins, n_frames) array of squared FFT magnitudes, one column per segment."""
+    n_samples = len(sample_data)
+    Sxx = np.zeros((freq_bins, n_frames))
+
+    for i in range(n_frames):
+        start_idx = i * hop_length
+        end_idx = min(start_idx + nperseg, n_samples)  # clamp the segment to the end of the signal
+
+        if end_idx > start_idx:  # an empty segment is skipped and its column stays zero
+            windowed = sample_data[start_idx:end_idx]  # plain slice; no window taper is applied
+
+            # Zero-pad a short trailing segment so every FFT has length nperseg
+            if len(windowed) < nperseg:
+                windowed = np.pad(windowed, (0, nperseg - len(windowed)), mode="constant")
+
+            fft_result = np.fft.fft(windowed)
+            Sxx[:, i] = np.abs(fft_result[:freq_bins]) ** 2  # power of the first freq_bins FFT bins
+
+    return Sxx
+
+
+def _create_spectrogram_figure(Sxx, n_frames: int, hop_length: int, n_samples: int, freq_bins: int,
+                               sample_idx: int, class_key: str, sample_metadata) -> Figure:
+    """Build a plotly_dark heatmap of Sxx in dB, titled with the sample index and, if present, its class."""
+    # Convert power to dB; the 1e-10 offset keeps log10 away from zero
+    Sxx_db = 10 * np.log10(Sxx + 1e-10)
+
+    # Time axis: each frame's start index as a fraction of n_samples; frequency axis: 0 to 0.5
+    t = np.arange(n_frames) * hop_length / max(1, n_samples)
+    f = np.linspace(0, 0.5, freq_bins)  # presumably normalized frequency (cycles/sample) - TODO confirm
+
+    # Heatmap with frames on x and frequency bins on y
+    fig = go.Figure(data=go.Heatmap(z=Sxx_db, x=t, y=f, colorscale="viridis", colorbar=dict(title="Power (dB)")))
+
+    # Title always shows the index; the class label is appended only when sample_metadata has class_key
+    title = f"Sample Spectrogram (Index: {sample_idx})"
+    if class_key in sample_metadata:
+        title += f" - {class_key}: {sample_metadata[class_key]}"
+
+    fig.update_layout(title=title, xaxis_title="Time", yaxis_title="Frequency", height=400, template="plotly_dark")
+    return fig
+
+
def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx: Optional[int] = None) -> Figure:
"""Generate a spectrogram plot from a sample in the dataset."""
try:
@@ -317,114 +386,36 @@ def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx:
return create_styled_error_figure(
"Spectrogram Not Available",
"This dataset doesn't have sufficient signal data for spectrogram visualization.",
- "Ensure your dataset contains complex-valued signal samples with at least 32 data points per sample."
+ "Ensure your dataset contains complex-valued signal samples with at least 32 data points per sample.",
)
-
+
metadata = dataset.metadata
-
if len(metadata) == 0:
raise ValueError("Dataset is empty")
-
- # Find class column
- if class_key not in metadata.columns:
- alternatives = ["class", "label", "modulation", "impairment", "use_case"]
- for alt in alternatives:
- if alt in metadata.columns:
- class_key = alt
- break
-
- # Select sample
+
+ # Find class column and select sample
+ class_key = _find_class_column(metadata, class_key)
if sample_idx is None:
sample_idx = random.randint(0, len(metadata) - 1)
-
sample_metadata = metadata.iloc[sample_idx]
-
- # Try to get actual sample data, fall back to synthetic
- try:
- sample_data = dataset[sample_idx]
- except:
- # Generate synthetic signal based on class
- n_samples = 1024
- t = np.linspace(0, 1, n_samples)
- freq = 0.1 + 0.05 * sample_idx % 5 # Vary frequency by sample
- sample_data = np.exp(1j * 2 * np.pi * freq * t)
- # Add some noise
- sample_data += 0.1 * (np.random.randn(n_samples) + 1j * np.random.randn(n_samples))
-
- # Ensure complex data
+
+ # Get sample data and ensure it's complex
+ sample_data = _get_sample_data(dataset, sample_idx)
if not np.iscomplexobj(sample_data):
sample_data = sample_data.astype(complex)
-
- # Simple FFT-based spectrogram
+
+ # Calculate spectrogram parameters and compute spectrogram
n_samples = len(sample_data)
-
- # Ensure minimum viable data size
- if n_samples < 32:
- raise ValueError(f"Insufficient data: need at least 32 samples, got {n_samples}")
-
- nperseg = min(256, max(32, n_samples // 4))
-
- # Create spectrogram using numpy (no scipy dependency)
- hop_length = max(1, nperseg // 2) # Prevent zero hop_length
-
- # Ensure we can create at least one frame
- if n_samples < nperseg:
- nperseg = n_samples
- hop_length = 1
-
- n_frames = max(1, (n_samples - nperseg) // hop_length + 1)
-
- freq_bins = max(1, nperseg // 2) # Prevent zero frequency bins
- Sxx = np.zeros((freq_bins, n_frames))
-
- for i in range(n_frames):
- start_idx = i * hop_length
- end_idx = min(start_idx + nperseg, n_samples) # Prevent index overflow
-
- if end_idx > start_idx: # Ensure we have data to process
- windowed = sample_data[start_idx:end_idx]
-
- # Pad if necessary to maintain nperseg size
- if len(windowed) < nperseg:
- windowed = np.pad(windowed, (0, nperseg - len(windowed)), mode='constant')
-
- fft_result = np.fft.fft(windowed)
- Sxx[:, i] = np.abs(fft_result[:freq_bins]) ** 2
-
- # Convert to dB
- Sxx_db = 10 * np.log10(Sxx + 1e-10)
-
- # Create time and frequency vectors
- t = np.arange(n_frames) * hop_length / max(1, n_samples) # Prevent division by zero
- f = np.linspace(0, 0.5, freq_bins)
-
- # Create plot
- fig = go.Figure(data=go.Heatmap(
- z=Sxx_db,
- x=t,
- y=f,
- colorscale='viridis',
- colorbar=dict(title="Power (dB)")
- ))
-
- # Add title with metadata
- title = f"Sample Spectrogram (Index: {sample_idx})"
- if class_key in sample_metadata:
- title += f" - {class_key}: {sample_metadata[class_key]}"
-
- fig.update_layout(
- title=title,
- xaxis_title="Time",
- yaxis_title="Frequency",
- height=400,
- template="plotly_dark"
- )
-
- return fig
-
+ nperseg, hop_length, n_frames, freq_bins = _calculate_spectrogram_params(n_samples)
+ Sxx = _compute_spectrogram(sample_data, nperseg, hop_length, n_frames, freq_bins)
+
+ # Create and return the figure
+ return _create_spectrogram_figure(Sxx, n_frames, hop_length, n_samples, freq_bins,
+ sample_idx, class_key, sample_metadata)
+
except Exception as e:
return create_styled_error_figure(
"Spectrogram Error",
"An error occurred while generating the spectrogram plot.",
- f"Technical details: {str(e)}"
- )
\ No newline at end of file
+ f"Technical details: {str(e)}",
+ )