format fix 3?
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 13s
Test with tox / Test with tox (3.11) (pull_request) Successful in 32s
Test with tox / Test with tox (3.12) (pull_request) Successful in 31s
Test with tox / Test with tox (3.10) (pull_request) Successful in 42s
Build Project / Build Project (3.10) (pull_request) Successful in 50s
Build Project / Build Project (3.11) (pull_request) Successful in 49s
Build Project / Build Project (3.12) (pull_request) Successful in 48s

This commit is contained in:
ben 2025-10-22 12:02:05 -04:00
parent 4872eea116
commit a0b46a35e2
3 changed files with 87 additions and 127 deletions

View File

@ -75,9 +75,7 @@ def graph_structure(file_path: Path) -> go.Figure:
""" """
if not ONNX_AVAILABLE: if not ONNX_AVAILABLE:
return create_styled_error_figure( return create_styled_error_figure(
"ONNX Not Available", "ONNX Not Available", "ONNX library is required for model analysis.", "Install with: pip install onnx"
"ONNX library is required for model analysis.",
"Install with: pip install onnx"
) )
try: try:
@ -88,9 +86,7 @@ def graph_structure(file_path: Path) -> go.Figure:
if len(nodes) == 0: if len(nodes) == 0:
return create_styled_error_figure( return create_styled_error_figure(
"Empty Model", "Empty Model", "This ONNX model contains no operators.", "Please check if the model file is valid."
"This ONNX model contains no operators.",
"Please check if the model file is valid."
) )
# Create network diagram data # Create network diagram data
@ -153,8 +149,10 @@ def graph_structure(file_path: Path) -> go.Figure:
fig.update_layout( fig.update_layout(
title={ title={
"text": ("ONNX Graph Structure<br>" "text": (
f"<span style='font-size:14px; color:#a0a0a0;'>{len(nodes)} Operators</span>"), "ONNX Graph Structure<br>"
f"<span style='font-size:14px; color:#a0a0a0;'>{len(nodes)} Operators</span>"
),
"x": 0.5, "x": 0.5,
"xanchor": "center", "xanchor": "center",
"font": {"size": 22}, "font": {"size": 22},
@ -173,9 +171,7 @@ def graph_structure(file_path: Path) -> go.Figure:
except Exception as e: except Exception as e:
return create_styled_error_figure( return create_styled_error_figure(
"Graph Analysis Error", "Graph Analysis Error", "Could not analyze ONNX model structure.", f"Error: {str(e)}"
"Could not analyze ONNX model structure.",
f"Error: {str(e)}"
) )
@ -186,9 +182,7 @@ def operator_analysis(file_path: Path) -> go.Figure:
""" """
if not ONNX_AVAILABLE: if not ONNX_AVAILABLE:
return create_styled_error_figure( return create_styled_error_figure(
"ONNX Not Available", "ONNX Not Available", "ONNX library is required for operator analysis.", "Install with: pip install onnx"
"ONNX library is required for operator analysis.",
"Install with: pip install onnx"
) )
try: try:
@ -248,8 +242,10 @@ def operator_analysis(file_path: Path) -> go.Figure:
fig.update_layout( fig.update_layout(
title={ title={
"text": ("ONNX Operator Analysis<br>" "text": (
f"<span style='font-size:14px; color:#a0a0a0;'>{len(op_counts)} Unique Types</span>"), "ONNX Operator Analysis<br>"
f"<span style='font-size:14px; color:#a0a0a0;'>{len(op_counts)} Unique Types</span>"
),
"x": 0.5, "x": 0.5,
"xanchor": "center", "xanchor": "center",
"font": {"size": 22}, "font": {"size": 22},
@ -262,9 +258,7 @@ def operator_analysis(file_path: Path) -> go.Figure:
except Exception as e: except Exception as e:
return create_styled_error_figure( return create_styled_error_figure(
"Operator Analysis Error", "Operator Analysis Error", "Could not analyze ONNX operators.", f"Error: {str(e)}"
"Could not analyze ONNX operators.",
f"Error: {str(e)}"
) )
@ -275,9 +269,7 @@ def model_metadata(file_path: Path) -> go.Figure:
""" """
if not ONNX_AVAILABLE: if not ONNX_AVAILABLE:
return create_styled_error_figure( return create_styled_error_figure(
"ONNX Not Available", "ONNX Not Available", "ONNX library is required for metadata analysis.", "Install with: pip install onnx"
"ONNX library is required for metadata analysis.",
"Install with: pip install onnx"
) )
try: try:
@ -336,12 +328,7 @@ def model_metadata(file_path: Path) -> go.Figure:
arch_values = [total_nodes, total_inputs, total_outputs, total_initializers] arch_values = [total_nodes, total_inputs, total_outputs, total_initializers]
fig.add_trace( fig.add_trace(
go.Bar( go.Bar(x=arch_data, y=arch_values, marker_color=["blue", "green", "orange", "red"], showlegend=False),
x=arch_data,
y=arch_values,
marker_color=["blue", "green", "orange", "red"],
showlegend=False
),
row=1, row=1,
col=2, col=2,
) )
@ -402,16 +389,8 @@ def model_metadata(file_path: Path) -> go.Figure:
if io_data: if io_data:
fig.add_trace( fig.add_trace(
go.Table( go.Table(
header=dict( header=dict(values=["Type", "Name", "Shape", "Data Type"], fill_color="lightblue", align="left"),
values=["Type", "Name", "Shape", "Data Type"], cells=dict(values=list(zip(*io_data)), fill_color="white", align="left"),
fill_color="lightblue",
align="left"
),
cells=dict(
values=list(zip(*io_data)),
fill_color="white",
align="left"
),
), ),
row=2, row=2,
col=1, col=1,
@ -432,8 +411,10 @@ def model_metadata(file_path: Path) -> go.Figure:
fig.update_layout( fig.update_layout(
title={ title={
"text": ("ONNX Model Metadata<br>" "text": (
f"<span style='font-size:14px; color:#a0a0a0;'>{total_params/1e6:.2f}M Parameters</span>"), "ONNX Model Metadata<br>"
f"<span style='font-size:14px; color:#a0a0a0;'>{total_params/1e6:.2f}M Parameters</span>"
),
"x": 0.5, "x": 0.5,
"xanchor": "center", "xanchor": "center",
"font": {"size": 22}, "font": {"size": 22},
@ -447,9 +428,7 @@ def model_metadata(file_path: Path) -> go.Figure:
except Exception as e: except Exception as e:
return create_styled_error_figure( return create_styled_error_figure(
"Metadata Analysis Error", "Metadata Analysis Error", "Could not extract ONNX model metadata.", f"Error: {str(e)}"
"Could not extract ONNX model metadata.",
f"Error: {str(e)}"
) )
@ -508,10 +487,7 @@ def performance_metrics(file_path: Path) -> go.Figure:
fig.add_trace( fig.add_trace(
go.Bar( go.Bar(
x=efficiency_metrics, x=efficiency_metrics, y=efficiency_values, marker_color=["blue", "green", "orange"], showlegend=False
y=efficiency_values,
marker_color=["blue", "green", "orange"],
showlegend=False
), ),
row=1, row=1,
col=1, col=1,
@ -522,12 +498,7 @@ def performance_metrics(file_path: Path) -> go.Figure:
memory_values = [param_memory_mb, param_memory_mb * 2] # Rough estimate memory_values = [param_memory_mb, param_memory_mb * 2] # Rough estimate
fig.add_trace( fig.add_trace(
go.Bar( go.Bar(x=memory_types, y=memory_values, marker_color=["purple", "red"], showlegend=False),
x=memory_types,
y=memory_values,
marker_color=["purple", "red"],
showlegend=False
),
row=1, row=1,
col=2, col=2,
) )
@ -569,9 +540,11 @@ def performance_metrics(file_path: Path) -> go.Figure:
fig.update_layout( fig.update_layout(
title={ title={
"text": ("ONNX Performance Metrics<br>" "text": (
"ONNX Performance Metrics<br>"
f"<span style='font-size:14px; color:#a0a0a0;'>" f"<span style='font-size:14px; color:#a0a0a0;'>"
f"Complexity Score: {complexity_score:.0f}/100</span>"), f"Complexity Score: {complexity_score:.0f}/100</span>"
),
"x": 0.5, "x": 0.5,
"xanchor": "center", "xanchor": "center",
"font": {"size": 22}, "font": {"size": 22},
@ -585,7 +558,5 @@ def performance_metrics(file_path: Path) -> go.Figure:
except Exception as e: except Exception as e:
return create_styled_error_figure( return create_styled_error_figure(
"Performance Analysis Error", "Performance Analysis Error", "Could not analyze ONNX model performance.", f"Error: {str(e)}"
"Could not analyze ONNX model performance.",
f"Error: {str(e)}"
) )

View File

@ -56,29 +56,25 @@ def model_summary_plot(state_dict: dict) -> Figure:
return create_styled_error_figure( return create_styled_error_figure(
"Empty State Dict", "Empty State Dict",
"No parameters found in state dict", "No parameters found in state dict",
"Ensure the model state dictionary contains weight parameters" "Ensure the model state dictionary contains weight parameters",
) )
# Count parameters by layer type # Count parameters by layer type
layer_info = [] layer_info = []
for key, tensor in state_dict.items(): for key, tensor in state_dict.items():
if 'weight' in key: if "weight" in key:
try: try:
layer_name = key.replace('.weight', '') layer_name = key.replace(".weight", "")
param_count = ( param_count = (
tensor.numel() if hasattr(tensor, 'numel') tensor.numel()
else len(tensor.flatten()) if hasattr(tensor, 'flatten') if hasattr(tensor, "numel")
else 0 else len(tensor.flatten()) if hasattr(tensor, "flatten") else 0
) )
shape = ( shape = (
list(tensor.shape) if hasattr(tensor, 'shape') list(tensor.shape)
else [len(tensor)] if hasattr(tensor, '__len__') if hasattr(tensor, "shape")
else [] else [len(tensor)] if hasattr(tensor, "__len__") else []
) )
layer_info.append({ layer_info.append({"layer": layer_name, "parameters": param_count, "shape": shape})
'layer': layer_name,
'parameters': param_count,
'shape': shape
})
except Exception as e: except Exception as e:
print(f"Warning: Could not process layer {key}: {e}") print(f"Warning: Could not process layer {key}: {e}")
continue continue
@ -86,22 +82,24 @@ def model_summary_plot(state_dict: dict) -> Figure:
return create_styled_error_figure( return create_styled_error_figure(
"No Weight Layers Found", "No Weight Layers Found",
"No weight layers found in state dict", "No weight layers found in state dict",
"Ensure the state dictionary contains layers with '.weight' parameters" "Ensure the state dictionary contains layers with '.weight' parameters",
) )
# Create bar chart of parameter counts # Create bar chart of parameter counts
fig = go.Figure(data=[ fig = go.Figure(
data=[
go.Bar( go.Bar(
x=[info['layer'] for info in layer_info], x=[info["layer"] for info in layer_info],
y=[info['parameters'] for info in layer_info], y=[info["parameters"] for info in layer_info],
text=[f"Shape: {info['shape']}" for info in layer_info], text=[f"Shape: {info['shape']}" for info in layer_info],
textposition='auto', textposition="auto",
)
]
) )
])
fig.update_layout( fig.update_layout(
title="Model Layer Parameter Counts", title="Model Layer Parameter Counts",
xaxis_title="Layer", xaxis_title="Layer",
yaxis_title="Number of Parameters", yaxis_title="Number of Parameters",
template="plotly_dark" template="plotly_dark",
) )
return fig return fig
@ -110,48 +108,36 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
"""Visualize weights for a specific layer.""" """Visualize weights for a specific layer."""
if not state_dict: if not state_dict:
return create_styled_error_figure( return create_styled_error_figure(
"Empty State Dict", "Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data"
"No data in state dict",
"Ensure the model state dictionary contains data"
) )
if layer_name is None: if layer_name is None:
# Get first weight tensor # Get first weight tensor
weight_keys = [k for k in state_dict.keys() if 'weight' in k] weight_keys = [k for k in state_dict.keys() if "weight" in k]
if not weight_keys: if not weight_keys:
return create_styled_error_figure( return create_styled_error_figure(
"No Weight Tensors Found", "No Weight Tensors Found",
"No weight tensors found in state dict", "No weight tensors found in state dict",
"Ensure the state dictionary contains layers with '.weight' parameters" "Ensure the state dictionary contains layers with '.weight' parameters",
) )
layer_name = weight_keys[0] layer_name = weight_keys[0]
try: try:
weights = state_dict[layer_name] weights = state_dict[layer_name]
# Convert to numpy if it's a torch tensor # Convert to numpy if it's a torch tensor
if hasattr(weights, 'numpy'): if hasattr(weights, "numpy"):
weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy() weights_np = weights.detach().numpy() if hasattr(weights, "detach") else weights.numpy()
elif hasattr(weights, 'cpu'): elif hasattr(weights, "cpu"):
weights_np = weights.cpu().detach().numpy() weights_np = weights.cpu().detach().numpy()
else: else:
weights_np = np.array(weights) weights_np = np.array(weights)
# For 2D weights, create heatmap # For 2D weights, create heatmap
if len(weights_np.shape) == 2: if len(weights_np.shape) == 2:
fig = go.Figure(data=go.Heatmap( fig = go.Figure(data=go.Heatmap(z=weights_np, colorscale="RdBu", zmid=0))
z=weights_np, fig.update_layout(title=f"Weights Heatmap: {layer_name}", template="plotly_dark")
colorscale='RdBu',
zmid=0
))
fig.update_layout(
title=f"Weights Heatmap: {layer_name}",
template="plotly_dark"
)
else: else:
# For other shapes, flatten and show histogram # For other shapes, flatten and show histogram
flat_weights = weights_np.flatten() flat_weights = weights_np.flatten()
fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)]) fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
fig.update_layout( fig.update_layout(title=f"Weight Distribution: {layer_name}", template="plotly_dark")
title=f"Weight Distribution: {layer_name}",
template="plotly_dark"
)
return fig return fig
@ -159,7 +145,7 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
return create_styled_error_figure( return create_styled_error_figure(
"Layer Processing Error", "Layer Processing Error",
f"Error processing layer {layer_name}: {str(e)}", f"Error processing layer {layer_name}: {str(e)}",
"Check that the layer name exists and contains valid tensor data" "Check that the layer name exists and contains valid tensor data",
) )
@ -167,21 +153,19 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
"""Show distribution of weights across all layers.""" """Show distribution of weights across all layers."""
if not state_dict: if not state_dict:
return create_styled_error_figure( return create_styled_error_figure(
"Empty State Dict", "Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data"
"No data in state dict",
"Ensure the model state dictionary contains data"
) )
all_weights = [] all_weights = []
layer_names = [] layer_names = []
for key, tensor in state_dict.items(): for key, tensor in state_dict.items():
if 'weight' in key: if "weight" in key:
try: try:
# Convert to numpy if it's a torch tensor # Convert to numpy if it's a torch tensor
if hasattr(tensor, 'numpy'): if hasattr(tensor, "numpy"):
weights_np = tensor.detach().numpy() if hasattr(tensor, 'detach') else tensor.numpy() weights_np = tensor.detach().numpy() if hasattr(tensor, "detach") else tensor.numpy()
elif hasattr(tensor, 'cpu'): elif hasattr(tensor, "cpu"):
weights_np = tensor.cpu().detach().numpy() weights_np = tensor.cpu().detach().numpy()
else: else:
weights_np = np.array(tensor) weights_np = np.array(tensor)
@ -196,21 +180,15 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
return create_styled_error_figure( return create_styled_error_figure(
"No Weight Data Found", "No Weight Data Found",
"No weight data found in state dict", "No weight data found in state dict",
"Ensure the state dictionary contains layers with '.weight' parameters" "Ensure the state dictionary contains layers with '.weight' parameters",
) )
fig = go.Figure(data=[ fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=100, name="All Weights")])
go.Histogram(
x=all_weights,
nbinsx=100,
name="All Weights"
)
])
fig.update_layout( fig.update_layout(
title="Overall Weight Distribution", title="Overall Weight Distribution",
xaxis_title="Weight Value", xaxis_title="Weight Value",
yaxis_title="Frequency", yaxis_title="Frequency",
template="plotly_dark" template="plotly_dark",
) )
return fig return fig

View File

@ -133,8 +133,10 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
return create_styled_error_figure( return create_styled_error_figure(
"No Class Labels Found", "No Class Labels Found",
"This dataset contains numerical data without categorical labels.", "This dataset contains numerical data without categorical labels.",
("Try using the Dataset Overview widget for data analysis, " (
"or check if your dataset has hidden categorical columns."), "Try using the Dataset Overview widget for data analysis, "
"or check if your dataset has hidden categorical columns."
),
) )
# Count examples per class (limit to top 20 for performance) # Count examples per class (limit to top 20 for performance)
@ -355,8 +357,16 @@ def _compute_spectrogram(sample_data, nperseg: int, hop_length: int, n_frames: i
return Sxx return Sxx
def _create_spectrogram_figure(Sxx, n_frames: int, hop_length: int, n_samples: int, freq_bins: int, def _create_spectrogram_figure(
sample_idx: int, class_key: str, sample_metadata) -> Figure: Sxx,
n_frames: int,
hop_length: int,
n_samples: int,
freq_bins: int,
sample_idx: int,
class_key: str,
sample_metadata,
) -> Figure:
"""Create the plotly figure for the spectrogram.""" """Create the plotly figure for the spectrogram."""
# Convert to dB # Convert to dB
Sxx_db = 10 * np.log10(Sxx + 1e-10) Sxx_db = 10 * np.log10(Sxx + 1e-10)
@ -410,8 +420,9 @@ def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx:
Sxx = _compute_spectrogram(sample_data, nperseg, hop_length, n_frames, freq_bins) Sxx = _compute_spectrogram(sample_data, nperseg, hop_length, n_frames, freq_bins)
# Create and return the figure # Create and return the figure
return _create_spectrogram_figure(Sxx, n_frames, hop_length, n_samples, freq_bins, return _create_spectrogram_figure(
sample_idx, class_key, sample_metadata) Sxx, n_frames, hop_length, n_samples, freq_bins, sample_idx, class_key, sample_metadata
)
except Exception as e: except Exception as e:
return create_styled_error_figure( return create_styled_error_figure(