"""
ONNX model visualization utilities.

This module provides visualization functions for ONNX models following the same
pattern as other ria-toolkit-oss visualization modules.

Every public function returns a ``plotly.graph_objects.Figure``. Failures
(missing ``onnx`` package, unreadable model, ...) are reported as a styled
error figure instead of being raised.
"""

from collections import Counter
from pathlib import Path
from typing import Optional

import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np

try:
    import onnx
    import onnx.helper
    import onnx.numpy_helper
    ONNX_AVAILABLE = True
except ImportError:
    ONNX_AVAILABLE = False


# Display names for ONNX ``TensorProto.DataType`` codes.
_ELEM_TYPE_NAMES = {
    1: 'float32',
    2: 'uint8',
    3: 'int8',
    4: 'uint16',
    5: 'int16',
    6: 'int32',
    7: 'int64',
    8: 'string',
    9: 'bool',
    10: 'float16',
    11: 'double',
    12: 'uint32',
    13: 'uint64',
    14: 'complex64',
    15: 'complex128',
    16: 'bfloat16',
}

# Operator-name fragments used by the complexity heuristic in
# ``performance_metrics`` (matched as substrings of ``node.op_type``).
_COMPUTE_OPS = ('Conv', 'MatMul', 'Gemm', 'LSTM', 'GRU')
_EFFICIENT_OPS = ('Relu', 'Add', 'Mul', 'BatchNormalization', 'Dropout')

# How many graph inputs / outputs are listed in the metadata table.
_MAX_IO_ROWS = 5


def _cycled_palette(count: int) -> list:
    """Return ``count`` colors, repeating the Set3 palette when it runs out."""
    palette = px.colors.qualitative.Set3
    return [palette[i % len(palette)] for i in range(count)]


def _count_parameters(graph) -> int:
    """Sum the element counts of all initializer tensors in ``graph``."""
    total = 0
    for initializer in graph.initializer:
        try:
            total += int(onnx.numpy_helper.to_array(initializer).size)
        except Exception:
            # Best effort: skip tensors that cannot be loaded as arrays.
            continue
    return total


def _describe_value_info(kind: str, value_info) -> list:
    """Build one ``[kind, name, shape, dtype]`` table row for a graph input/output."""
    shape = "Unknown"
    dtype = "Unknown"
    # Protobuf sub-messages are always truthy, so presence must be tested
    # explicitly with HasField().
    if value_info.type.HasField('tensor_type'):
        tensor_type = value_info.type.tensor_type
        if tensor_type.HasField('shape'):
            dims = [
                str(d.dim_value) if d.dim_value > 0 else "?"
                for d in tensor_type.shape.dim
            ]
            shape = f"[{', '.join(dims)}]"
        elem_type = tensor_type.elem_type
        dtype = _ELEM_TYPE_NAMES.get(elem_type, f'type_{elem_type}')
    return [kind, value_info.name[:20], shape, dtype]


def _classify_operators(nodes) -> tuple:
    """Return ``(compute, efficient, other)`` counts, each node counted once.

    Compute fragments take precedence, so e.g. ``MatMul`` is a compute op even
    though its name also contains the efficient fragment ``Mul``.
    """
    compute = efficient = other = 0
    for node in nodes:
        op_type = node.op_type
        if any(fragment in op_type for fragment in _COMPUTE_OPS):
            compute += 1
        elif any(fragment in op_type for fragment in _EFFICIENT_OPS):
            efficient += 1
        else:
            other += 1
    return compute, efficient, other


def create_styled_error_figure(title: str, message: str, suggestion: Optional[str] = None) -> go.Figure:
    """Create a professional error figure with Qoherent dark theme styling.

    Args:
        title: Short headline shown at the top of the message box.
        message: Description of what went wrong.
        suggestion: Optional hint shown below the message.

    Returns:
        A figure with hidden axes containing a single centered annotation.
    """
    fig = go.Figure()

    # Plotly text only breaks lines on <br>; raw newlines are not rendered.
    main_text = f"⚠️ {title}<br><br>{message}"
    if suggestion:
        main_text += f"<br><br>💡 Suggestion:<br>{suggestion}"

    # Add the main text annotation
    fig.add_annotation(
        text=main_text,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        xanchor='center',
        yanchor='middle',
        showarrow=False,
        align="center",
        borderwidth=2,
        bordercolor="#4a5568",
        bgcolor="#2d3748",
        font=dict(
            family="Arial, sans-serif",
            size=14,
            color="#e2e8f0",
        ),
    )

    # Update layout with dark theme
    fig.update_layout(
        title="",
        height=400,
        template="plotly_dark",
        margin=dict(l=40, r=40, t=40, b=40),
        plot_bgcolor="#1a202c",
        paper_bgcolor="#1a202c",
        font=dict(color="#e2e8f0"),
    )

    # Remove axes and grid
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)

    return fig


def graph_structure(file_path: Path) -> go.Figure:
    """
    Visualize the ONNX model graph structure showing nodes and connections.

    Nodes are laid out on a single row in the order they appear in the graph
    and joined by a dotted line.

    Matches layout ID: graph_structure

    Args:
        file_path: Path to the ONNX model file.

    Returns:
        The graph figure, or a styled error figure if ``onnx`` is missing, the
        model has no nodes, or loading/analysis fails.
    """
    if not ONNX_AVAILABLE:
        return create_styled_error_figure(
            "ONNX Not Available",
            "ONNX library is required for model analysis.",
            "Install with: pip install onnx"
        )

    try:
        # Load ONNX model
        model = onnx.load(str(file_path))
        nodes = model.graph.node

        if not nodes:
            return create_styled_error_figure(
                "Empty Model",
                "This ONNX model contains no operators.",
                "Please check if the model file is valid."
            )

        # Create network diagram data
        node_info = [
            {
                'id': i,
                'name': node.name or f"{node.op_type}_{i}",
                'op_type': node.op_type,
                'inputs': len(node.input),
                'outputs': len(node.output),
            }
            for i, node in enumerate(nodes)
        ]

        # Create visualization
        fig = go.Figure()

        # Simple linear layout for now
        x_positions = list(range(len(node_info)))
        y_positions = [0] * len(node_info)

        # Add nodes as scatter points; marker size grows with the number of
        # inputs + outputs, clamped to the range [20, 50].
        fig.add_trace(go.Scatter(
            x=x_positions,
            y=y_positions,
            mode='markers+text',
            marker=dict(
                size=[min(max(info['inputs'] + info['outputs'] + 15, 20), 50) for info in node_info],
                # Cycle the palette so every node gets a color, however many there are.
                color=_cycled_palette(len(node_info)),
                opacity=0.8,
                line=dict(width=2, color='white'),
            ),
            text=[info['op_type'] for info in node_info],
            textposition="middle center",
            textfont=dict(size=10, color="white"),
            hovertemplate=(
                "%{text}<br>"
                "Name: %{customdata[0]}<br>"
                "Inputs: %{customdata[1]}<br>"
                "Outputs: %{customdata[2]}"
            ),
            customdata=[[info['name'], info['inputs'], info['outputs']] for info in node_info],
            name="Operators",
        ))

        # Add connecting lines as one polyline trace (added after the markers,
        # so it is drawn on top of them) rather than one trace per segment.
        if len(node_info) > 1:
            fig.add_trace(go.Scatter(
                x=x_positions,
                y=y_positions,
                mode='lines',
                line=dict(color='gray', width=1, dash='dot'),
                showlegend=False,
                hoverinfo='skip',
            ))

        fig.update_layout(
            title={
                'text': f"ONNX Graph Structure<br>{len(nodes)} Operators",
                'x': 0.5,
                'xanchor': 'center',
                'font': {'size': 22},
            },
            xaxis_title="Execution Order",
            yaxis_title="",
            showlegend=False,
            height=500,
            template="plotly_dark",
            yaxis=dict(showticklabels=False, showgrid=False),
            xaxis=dict(showgrid=False),
            margin=dict(l=50, r=50, t=80, b=50),
        )

        return fig

    except Exception as e:
        # Boundary handler: any failure is turned into a displayable figure.
        return create_styled_error_figure(
            "Graph Analysis Error",
            "Could not analyze ONNX model structure.",
            f"Error: {e}"
        )


def operator_analysis(file_path: Path) -> go.Figure:
    """
    Analyze the distribution and types of operators in the ONNX model.

    Shows a pie chart and a bar chart of operator-type counts, most frequent
    first.

    Matches layout ID: operator_analysis

    Args:
        file_path: Path to the ONNX model file.

    Returns:
        The analysis figure, or a styled error figure if ``onnx`` is missing,
        the model has no nodes, or loading/analysis fails.
    """
    if not ONNX_AVAILABLE:
        return create_styled_error_figure(
            "ONNX Not Available",
            "ONNX library is required for operator analysis.",
            "Install with: pip install onnx"
        )

    try:
        model = onnx.load(str(file_path))

        # Count operators
        op_counts = Counter(node.op_type for node in model.graph.node)

        if not op_counts:
            return create_styled_error_figure(
                "No Operators",
                "This ONNX model contains no operators to analyze.",
                "Please verify the model file is valid."
            )

        # Sort by frequency (ties keep first-seen order)
        op_names, op_values = zip(*op_counts.most_common())

        # Create pie chart and bar chart
        fig = make_subplots(
            rows=2, cols=1,
            subplot_titles=("Operator Distribution", "Operator Frequency"),
            specs=[[{"type": "pie"}], [{"type": "bar"}]],
        )

        # Pie chart for operator distribution
        fig.add_trace(
            go.Pie(
                labels=list(op_names),
                values=list(op_values),
                textinfo="label+percent",
                textposition="auto",
                showlegend=False,
            ),
            row=1, col=1,
        )

        # Bar chart for frequency
        fig.add_trace(
            go.Bar(
                x=list(op_names),
                y=list(op_values),
                # Cycle the palette so every bar gets a color.
                marker_color=_cycled_palette(len(op_names)),
                showlegend=False,
            ),
            row=2, col=1,
        )

        fig.update_layout(
            title={
                'text': f"ONNX Operator Analysis<br>{len(op_counts)} Unique Types",
                'x': 0.5,
                'xanchor': 'center',
                'font': {'size': 22},
            },
            height=700,
            template="plotly_dark",
        )

        return fig

    except Exception as e:
        # Boundary handler: any failure is turned into a displayable figure.
        return create_styled_error_figure(
            "Operator Analysis Error",
            "Could not analyze ONNX operators.",
            f"Error: {e}"
        )


def model_metadata(file_path: Path) -> go.Figure:
    """
    Display comprehensive metadata about the ONNX model.

    The dashboard contains a file-size gauge, a bar chart of node / input /
    output / initializer counts, a table describing the first few graph inputs
    and outputs, and the total parameter count in millions.

    Matches layout ID: model_metadata

    Args:
        file_path: Path to the ONNX model file.

    Returns:
        The metadata figure, or a styled error figure if ``onnx`` is missing or
        loading/analysis fails.
    """
    if not ONNX_AVAILABLE:
        return create_styled_error_figure(
            "ONNX Not Available",
            "ONNX library is required for metadata analysis.",
            "Install with: pip install onnx"
        )

    try:
        model = onnx.load(str(file_path))
        graph = model.graph

        # Calculate basic statistics
        total_nodes = len(graph.node)
        total_inputs = len(graph.input)
        total_outputs = len(graph.output)
        total_initializers = len(graph.initializer)

        # Calculate parameter count
        total_params = _count_parameters(graph)
        params_millions = total_params / 1e6

        # Get model file size (Path() also accepts a plain string path)
        file_size_mb = Path(file_path).stat().st_size / (1024 * 1024)

        # Create metadata display
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=("Model Size", "Architecture", "Inputs/Outputs", "Parameters"),
            specs=[[{"type": "indicator"}, {"type": "bar"}],
                   [{"type": "table"}, {"type": "indicator"}]],
        )

        # Model size indicator
        fig.add_trace(
            go.Indicator(
                mode="number+gauge",
                value=file_size_mb,
                title={'text': "Model Size (MB)"},
                number={'suffix': ' MB', 'valueformat': '.2f'},
                gauge={
                    'axis': {'range': [0, max(100, file_size_mb * 1.5)]},
                    'bar': {'color': "darkblue"},
                    'steps': [
                        {'range': [0, 10], 'color': "lightgreen"},
                        {'range': [10, 50], 'color': "yellow"},
                        {'range': [50, 100], 'color': "orange"},
                    ],
                },
            ),
            row=1, col=1,
        )

        # Architecture components
        arch_data = ["Nodes", "Inputs", "Outputs", "Initializers"]
        arch_values = [total_nodes, total_inputs, total_outputs, total_initializers]

        fig.add_trace(
            go.Bar(
                x=arch_data,
                y=arch_values,
                marker_color=['blue', 'green', 'orange', 'red'],
                showlegend=False,
            ),
            row=1, col=2,
        )

        # I/O table: limited to the first few inputs and outputs
        io_data = [_describe_value_info('Input', inp) for inp in graph.input[:_MAX_IO_ROWS]]
        io_data += [_describe_value_info('Output', out) for out in graph.output[:_MAX_IO_ROWS]]

        if io_data:
            fig.add_trace(
                go.Table(
                    header=dict(
                        values=['Type', 'Name', 'Shape', 'Data Type'],
                        fill_color='lightblue',
                        align='left',
                    ),
                    cells=dict(
                        # Table cells are column-major: transpose the rows.
                        values=list(zip(*io_data)),
                        fill_color='white',
                        align='left',
                    ),
                ),
                row=2, col=1,
            )

        # Parameters indicator: the value is expressed in millions to match
        # the 'M' suffix.
        fig.add_trace(
            go.Indicator(
                mode="number",
                value=params_millions,
                title={'text': "Total Parameters"},
                number={'suffix': 'M', 'valueformat': '.2f', 'font': {'size': 30}},
            ),
            row=2, col=2,
        )

        fig.update_layout(
            title={
                'text': f"ONNX Model Metadata<br>{params_millions:.2f}M Parameters",
                'x': 0.5,
                'xanchor': 'center',
                'font': {'size': 22},
            },
            height=600,
            template="plotly_dark",
            showlegend=False,
        )

        return fig

    except Exception as e:
        # Boundary handler: any failure is turned into a displayable figure.
        return create_styled_error_figure(
            "Metadata Analysis Error",
            "Could not extract ONNX model metadata.",
            f"Error: {e}"
        )


def performance_metrics(file_path: Path) -> go.Figure:
    """
    Display performance and computational metrics for the ONNX model.

    All numbers are static heuristics derived from the model file (size,
    parameter count, operator mix); nothing is executed or timed.

    Matches layout ID: performance_metrics

    Args:
        file_path: Path to the ONNX model file.

    Returns:
        The metrics figure, or a styled error figure if ``onnx`` is missing or
        loading/analysis fails.
    """
    if not ONNX_AVAILABLE:
        return create_styled_error_figure(
            "ONNX Not Available",
            "ONNX library is required for performance analysis.",
            "Install with: pip install onnx"
        )

    try:
        model = onnx.load(str(file_path))
        graph = model.graph

        # Calculate metrics (Path() also accepts a plain string path)
        model_size_bytes = Path(file_path).stat().st_size
        model_size_mb = model_size_bytes / (1024 * 1024)

        # Count parameters
        total_params = _count_parameters(graph)

        # Estimate memory usage (rough approximation)
        param_memory_mb = (total_params * 4) / (1024 * 1024)  # Assume float32

        # Count operations by complexity; each node lands in exactly one
        # bucket so the three counts always sum to the node total.
        compute_count, efficient_count, other_count = _classify_operators(graph.node)
        total_ops = len(graph.node)

        # Create performance dashboard
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=("Model Efficiency", "Memory Usage", "Operation Types", "Complexity Score"),
            specs=[[{"type": "bar"}, {"type": "bar"}],
                   [{"type": "pie"}, {"type": "indicator"}]],
        )

        # Model efficiency metrics
        efficiency_metrics = ["Model Size (MB)", "Parameters (M)", "Total Ops"]
        efficiency_values = [model_size_mb, total_params / 1e6, total_ops]

        fig.add_trace(
            go.Bar(
                x=efficiency_metrics,
                y=efficiency_values,
                marker_color=['blue', 'green', 'orange'],
                showlegend=False,
            ),
            row=1, col=1,
        )

        # Memory usage
        memory_types = ["Parameters", "Est. Inference"]
        memory_values = [param_memory_mb, param_memory_mb * 2]  # Rough estimate

        fig.add_trace(
            go.Bar(
                x=memory_types,
                y=memory_values,
                marker_color=['purple', 'red'],
                showlegend=False,
            ),
            row=1, col=2,
        )

        # Operation types pie chart
        fig.add_trace(
            go.Pie(
                labels=['Compute Ops', 'Efficient Ops', 'Other Ops'],
                values=[compute_count, efficient_count, other_count],
                marker_colors=['red', 'green', 'gray'],
            ),
            row=2, col=1,
        )

        # Complexity score (simple heuristic), capped at 100
        complexity_score = min(100, (model_size_mb * 10 + total_params / 1e6 * 20 + compute_count))

        if complexity_score > 70:
            bar_color = "darkred"
        elif complexity_score > 40:
            bar_color = "orange"
        else:
            bar_color = "green"

        fig.add_trace(
            go.Indicator(
                mode="gauge+number",
                value=complexity_score,
                title={'text': "Complexity Score"},
                gauge={
                    'axis': {'range': [0, 100]},
                    'bar': {'color': bar_color},
                    'steps': [
                        {'range': [0, 40], 'color': "lightgreen"},
                        {'range': [40, 70], 'color': "yellow"},
                        {'range': [70, 100], 'color': "lightcoral"},
                    ],
                },
            ),
            row=2, col=2,
        )

        fig.update_layout(
            title={
                'text': f"ONNX Performance Metrics<br>Complexity Score: {complexity_score:.0f}/100",
                'x': 0.5,
                'xanchor': 'center',
                'font': {'size': 22},
            },
            height=600,
            template="plotly_dark",
            showlegend=False,
        )

        return fig

    except Exception as e:
        # Boundary handler: any failure is turned into a displayable figure.
        return create_styled_error_figure(
            "Performance Analysis Error",
            "Could not analyze ONNX model performance.",
            f"Error: {e}"
        )