format fix 3?
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 13s
Test with tox / Test with tox (3.11) (pull_request) Successful in 32s
Test with tox / Test with tox (3.12) (pull_request) Successful in 31s
Test with tox / Test with tox (3.10) (pull_request) Successful in 42s
Build Project / Build Project (3.10) (pull_request) Successful in 50s
Build Project / Build Project (3.11) (pull_request) Successful in 49s
Build Project / Build Project (3.12) (pull_request) Successful in 48s

This commit is contained in:
ben 2025-10-22 12:02:05 -04:00
parent 4872eea116
commit a0b46a35e2
3 changed files with 87 additions and 127 deletions

View File

@ -75,9 +75,7 @@ def graph_structure(file_path: Path) -> go.Figure:
"""
if not ONNX_AVAILABLE:
return create_styled_error_figure(
"ONNX Not Available",
"ONNX library is required for model analysis.",
"Install with: pip install onnx"
"ONNX Not Available", "ONNX library is required for model analysis.", "Install with: pip install onnx"
)
try:
@ -88,9 +86,7 @@ def graph_structure(file_path: Path) -> go.Figure:
if len(nodes) == 0:
return create_styled_error_figure(
"Empty Model",
"This ONNX model contains no operators.",
"Please check if the model file is valid."
"Empty Model", "This ONNX model contains no operators.", "Please check if the model file is valid."
)
# Create network diagram data
@ -153,8 +149,10 @@ def graph_structure(file_path: Path) -> go.Figure:
fig.update_layout(
title={
"text": ("ONNX Graph Structure<br>"
f"<span style='font-size:14px; color:#a0a0a0;'>{len(nodes)} Operators</span>"),
"text": (
"ONNX Graph Structure<br>"
f"<span style='font-size:14px; color:#a0a0a0;'>{len(nodes)} Operators</span>"
),
"x": 0.5,
"xanchor": "center",
"font": {"size": 22},
@ -173,9 +171,7 @@ def graph_structure(file_path: Path) -> go.Figure:
except Exception as e:
return create_styled_error_figure(
"Graph Analysis Error",
"Could not analyze ONNX model structure.",
f"Error: {str(e)}"
"Graph Analysis Error", "Could not analyze ONNX model structure.", f"Error: {str(e)}"
)
@ -186,9 +182,7 @@ def operator_analysis(file_path: Path) -> go.Figure:
"""
if not ONNX_AVAILABLE:
return create_styled_error_figure(
"ONNX Not Available",
"ONNX library is required for operator analysis.",
"Install with: pip install onnx"
"ONNX Not Available", "ONNX library is required for operator analysis.", "Install with: pip install onnx"
)
try:
@ -248,8 +242,10 @@ def operator_analysis(file_path: Path) -> go.Figure:
fig.update_layout(
title={
"text": ("ONNX Operator Analysis<br>"
f"<span style='font-size:14px; color:#a0a0a0;'>{len(op_counts)} Unique Types</span>"),
"text": (
"ONNX Operator Analysis<br>"
f"<span style='font-size:14px; color:#a0a0a0;'>{len(op_counts)} Unique Types</span>"
),
"x": 0.5,
"xanchor": "center",
"font": {"size": 22},
@ -262,9 +258,7 @@ def operator_analysis(file_path: Path) -> go.Figure:
except Exception as e:
return create_styled_error_figure(
"Operator Analysis Error",
"Could not analyze ONNX operators.",
f"Error: {str(e)}"
"Operator Analysis Error", "Could not analyze ONNX operators.", f"Error: {str(e)}"
)
@ -275,9 +269,7 @@ def model_metadata(file_path: Path) -> go.Figure:
"""
if not ONNX_AVAILABLE:
return create_styled_error_figure(
"ONNX Not Available",
"ONNX library is required for metadata analysis.",
"Install with: pip install onnx"
"ONNX Not Available", "ONNX library is required for metadata analysis.", "Install with: pip install onnx"
)
try:
@ -336,12 +328,7 @@ def model_metadata(file_path: Path) -> go.Figure:
arch_values = [total_nodes, total_inputs, total_outputs, total_initializers]
fig.add_trace(
go.Bar(
x=arch_data,
y=arch_values,
marker_color=["blue", "green", "orange", "red"],
showlegend=False
),
go.Bar(x=arch_data, y=arch_values, marker_color=["blue", "green", "orange", "red"], showlegend=False),
row=1,
col=2,
)
@ -402,16 +389,8 @@ def model_metadata(file_path: Path) -> go.Figure:
if io_data:
fig.add_trace(
go.Table(
header=dict(
values=["Type", "Name", "Shape", "Data Type"],
fill_color="lightblue",
align="left"
),
cells=dict(
values=list(zip(*io_data)),
fill_color="white",
align="left"
),
header=dict(values=["Type", "Name", "Shape", "Data Type"], fill_color="lightblue", align="left"),
cells=dict(values=list(zip(*io_data)), fill_color="white", align="left"),
),
row=2,
col=1,
@ -432,8 +411,10 @@ def model_metadata(file_path: Path) -> go.Figure:
fig.update_layout(
title={
"text": ("ONNX Model Metadata<br>"
f"<span style='font-size:14px; color:#a0a0a0;'>{total_params/1e6:.2f}M Parameters</span>"),
"text": (
"ONNX Model Metadata<br>"
f"<span style='font-size:14px; color:#a0a0a0;'>{total_params/1e6:.2f}M Parameters</span>"
),
"x": 0.5,
"xanchor": "center",
"font": {"size": 22},
@ -447,9 +428,7 @@ def model_metadata(file_path: Path) -> go.Figure:
except Exception as e:
return create_styled_error_figure(
"Metadata Analysis Error",
"Could not extract ONNX model metadata.",
f"Error: {str(e)}"
"Metadata Analysis Error", "Could not extract ONNX model metadata.", f"Error: {str(e)}"
)
@ -508,10 +487,7 @@ def performance_metrics(file_path: Path) -> go.Figure:
fig.add_trace(
go.Bar(
x=efficiency_metrics,
y=efficiency_values,
marker_color=["blue", "green", "orange"],
showlegend=False
x=efficiency_metrics, y=efficiency_values, marker_color=["blue", "green", "orange"], showlegend=False
),
row=1,
col=1,
@ -522,12 +498,7 @@ def performance_metrics(file_path: Path) -> go.Figure:
memory_values = [param_memory_mb, param_memory_mb * 2] # Rough estimate
fig.add_trace(
go.Bar(
x=memory_types,
y=memory_values,
marker_color=["purple", "red"],
showlegend=False
),
go.Bar(x=memory_types, y=memory_values, marker_color=["purple", "red"], showlegend=False),
row=1,
col=2,
)
@ -569,9 +540,11 @@ def performance_metrics(file_path: Path) -> go.Figure:
fig.update_layout(
title={
"text": ("ONNX Performance Metrics<br>"
"text": (
"ONNX Performance Metrics<br>"
f"<span style='font-size:14px; color:#a0a0a0;'>"
f"Complexity Score: {complexity_score:.0f}/100</span>"),
f"Complexity Score: {complexity_score:.0f}/100</span>"
),
"x": 0.5,
"xanchor": "center",
"font": {"size": 22},
@ -585,7 +558,5 @@ def performance_metrics(file_path: Path) -> go.Figure:
except Exception as e:
return create_styled_error_figure(
"Performance Analysis Error",
"Could not analyze ONNX model performance.",
f"Error: {str(e)}"
"Performance Analysis Error", "Could not analyze ONNX model performance.", f"Error: {str(e)}"
)

View File

@ -56,29 +56,25 @@ def model_summary_plot(state_dict: dict) -> Figure:
return create_styled_error_figure(
"Empty State Dict",
"No parameters found in state dict",
"Ensure the model state dictionary contains weight parameters"
"Ensure the model state dictionary contains weight parameters",
)
# Count parameters by layer type
layer_info = []
for key, tensor in state_dict.items():
if 'weight' in key:
if "weight" in key:
try:
layer_name = key.replace('.weight', '')
layer_name = key.replace(".weight", "")
param_count = (
tensor.numel() if hasattr(tensor, 'numel')
else len(tensor.flatten()) if hasattr(tensor, 'flatten')
else 0
tensor.numel()
if hasattr(tensor, "numel")
else len(tensor.flatten()) if hasattr(tensor, "flatten") else 0
)
shape = (
list(tensor.shape) if hasattr(tensor, 'shape')
else [len(tensor)] if hasattr(tensor, '__len__')
else []
list(tensor.shape)
if hasattr(tensor, "shape")
else [len(tensor)] if hasattr(tensor, "__len__") else []
)
layer_info.append({
'layer': layer_name,
'parameters': param_count,
'shape': shape
})
layer_info.append({"layer": layer_name, "parameters": param_count, "shape": shape})
except Exception as e:
print(f"Warning: Could not process layer {key}: {e}")
continue
@ -86,22 +82,24 @@ def model_summary_plot(state_dict: dict) -> Figure:
return create_styled_error_figure(
"No Weight Layers Found",
"No weight layers found in state dict",
"Ensure the state dictionary contains layers with '.weight' parameters"
"Ensure the state dictionary contains layers with '.weight' parameters",
)
# Create bar chart of parameter counts
fig = go.Figure(data=[
fig = go.Figure(
data=[
go.Bar(
x=[info['layer'] for info in layer_info],
y=[info['parameters'] for info in layer_info],
x=[info["layer"] for info in layer_info],
y=[info["parameters"] for info in layer_info],
text=[f"Shape: {info['shape']}" for info in layer_info],
textposition='auto',
textposition="auto",
)
]
)
])
fig.update_layout(
title="Model Layer Parameter Counts",
xaxis_title="Layer",
yaxis_title="Number of Parameters",
template="plotly_dark"
template="plotly_dark",
)
return fig
@ -110,48 +108,36 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
"""Visualize weights for a specific layer."""
if not state_dict:
return create_styled_error_figure(
"Empty State Dict",
"No data in state dict",
"Ensure the model state dictionary contains data"
"Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data"
)
if layer_name is None:
# Get first weight tensor
weight_keys = [k for k in state_dict.keys() if 'weight' in k]
weight_keys = [k for k in state_dict.keys() if "weight" in k]
if not weight_keys:
return create_styled_error_figure(
"No Weight Tensors Found",
"No weight tensors found in state dict",
"Ensure the state dictionary contains layers with '.weight' parameters"
"Ensure the state dictionary contains layers with '.weight' parameters",
)
layer_name = weight_keys[0]
try:
weights = state_dict[layer_name]
# Convert to numpy if it's a torch tensor
if hasattr(weights, 'numpy'):
weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy()
elif hasattr(weights, 'cpu'):
if hasattr(weights, "numpy"):
weights_np = weights.detach().numpy() if hasattr(weights, "detach") else weights.numpy()
elif hasattr(weights, "cpu"):
weights_np = weights.cpu().detach().numpy()
else:
weights_np = np.array(weights)
# For 2D weights, create heatmap
if len(weights_np.shape) == 2:
fig = go.Figure(data=go.Heatmap(
z=weights_np,
colorscale='RdBu',
zmid=0
))
fig.update_layout(
title=f"Weights Heatmap: {layer_name}",
template="plotly_dark"
)
fig = go.Figure(data=go.Heatmap(z=weights_np, colorscale="RdBu", zmid=0))
fig.update_layout(title=f"Weights Heatmap: {layer_name}", template="plotly_dark")
else:
# For other shapes, flatten and show histogram
flat_weights = weights_np.flatten()
fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
fig.update_layout(
title=f"Weight Distribution: {layer_name}",
template="plotly_dark"
)
fig.update_layout(title=f"Weight Distribution: {layer_name}", template="plotly_dark")
return fig
@ -159,7 +145,7 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
return create_styled_error_figure(
"Layer Processing Error",
f"Error processing layer {layer_name}: {str(e)}",
"Check that the layer name exists and contains valid tensor data"
"Check that the layer name exists and contains valid tensor data",
)
@ -167,21 +153,19 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
"""Show distribution of weights across all layers."""
if not state_dict:
return create_styled_error_figure(
"Empty State Dict",
"No data in state dict",
"Ensure the model state dictionary contains data"
"Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data"
)
all_weights = []
layer_names = []
for key, tensor in state_dict.items():
if 'weight' in key:
if "weight" in key:
try:
# Convert to numpy if it's a torch tensor
if hasattr(tensor, 'numpy'):
weights_np = tensor.detach().numpy() if hasattr(tensor, 'detach') else tensor.numpy()
elif hasattr(tensor, 'cpu'):
if hasattr(tensor, "numpy"):
weights_np = tensor.detach().numpy() if hasattr(tensor, "detach") else tensor.numpy()
elif hasattr(tensor, "cpu"):
weights_np = tensor.cpu().detach().numpy()
else:
weights_np = np.array(tensor)
@ -196,21 +180,15 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
return create_styled_error_figure(
"No Weight Data Found",
"No weight data found in state dict",
"Ensure the state dictionary contains layers with '.weight' parameters"
"Ensure the state dictionary contains layers with '.weight' parameters",
)
fig = go.Figure(data=[
go.Histogram(
x=all_weights,
nbinsx=100,
name="All Weights"
)
])
fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=100, name="All Weights")])
fig.update_layout(
title="Overall Weight Distribution",
xaxis_title="Weight Value",
yaxis_title="Frequency",
template="plotly_dark"
template="plotly_dark",
)
return fig

View File

@ -133,8 +133,10 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
return create_styled_error_figure(
"No Class Labels Found",
"This dataset contains numerical data without categorical labels.",
("Try using the Dataset Overview widget for data analysis, "
"or check if your dataset has hidden categorical columns."),
(
"Try using the Dataset Overview widget for data analysis, "
"or check if your dataset has hidden categorical columns."
),
)
# Count examples per class (limit to top 20 for performance)
@ -355,8 +357,16 @@ def _compute_spectrogram(sample_data, nperseg: int, hop_length: int, n_frames: i
return Sxx
def _create_spectrogram_figure(Sxx, n_frames: int, hop_length: int, n_samples: int, freq_bins: int,
sample_idx: int, class_key: str, sample_metadata) -> Figure:
def _create_spectrogram_figure(
Sxx,
n_frames: int,
hop_length: int,
n_samples: int,
freq_bins: int,
sample_idx: int,
class_key: str,
sample_metadata,
) -> Figure:
"""Create the plotly figure for the spectrogram."""
# Convert to dB
Sxx_db = 10 * np.log10(Sxx + 1e-10)
@ -410,8 +420,9 @@ def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx:
Sxx = _compute_spectrogram(sample_data, nperseg, hop_length, n_frames, freq_bins)
# Create and return the figure
return _create_spectrogram_figure(Sxx, n_frames, hop_length, n_samples, freq_bins,
sample_idx, class_key, sample_metadata)
return _create_spectrogram_figure(
Sxx, n_frames, hop_length, n_samples, freq_bins, sample_idx, class_key, sample_metadata
)
except Exception as e:
return create_styled_error_figure(