From 19a86e2a67e5e80a03b25bc1ba5205d46b66f212 Mon Sep 17 00:00:00 2001 From: ben Date: Wed, 19 Nov 2025 16:07:29 -0500 Subject: [PATCH] New stylings for onnx and pytorch --- src/ria_toolkit_oss/viz/onnx.py | 132 +++++++++++++++++------ src/ria_toolkit_oss/viz/pytorch_model.py | 28 ++--- 2 files changed, 108 insertions(+), 52 deletions(-) diff --git a/src/ria_toolkit_oss/viz/onnx.py b/src/ria_toolkit_oss/viz/onnx.py index e260eeb..aed6c6f 100644 --- a/src/ria_toolkit_oss/viz/onnx.py +++ b/src/ria_toolkit_oss/viz/onnx.py @@ -155,7 +155,7 @@ def graph_structure(file_path: Path) -> go.Figure: ), "x": 0.5, "xanchor": "center", - "font": {"size": 22}, + "font": {"size": 20, "family": "Inter, system-ui, sans-serif"}, }, xaxis_title="Execution Order", yaxis_title="", @@ -163,8 +163,11 @@ def graph_structure(file_path: Path) -> go.Figure: height=500, template="plotly_dark", yaxis=dict(showticklabels=False, showgrid=False), - xaxis=dict(showgrid=False), - margin=dict(l=50, r=50, t=80, b=50), + xaxis=dict(showgrid=True, gridcolor="#374151", gridwidth=1), + margin=dict(l=60, r=60, t=100, b=60), + plot_bgcolor="#111827", + paper_bgcolor="#1f2937", + font=dict(color="#e5e7eb", family="Inter, system-ui, sans-serif"), ) return fig @@ -211,6 +214,7 @@ def operator_analysis(file_path: Path) -> go.Figure: cols=1, subplot_titles=("Operator Distribution", "Operator Frequency"), specs=[[{"type": "pie"}], [{"type": "bar"}]], + vertical_spacing=0.15, ) # Pie chart for operator distribution @@ -248,10 +252,14 @@ def operator_analysis(file_path: Path) -> go.Figure: ), "x": 0.5, "xanchor": "center", - "font": {"size": 22}, + "font": {"size": 20, "family": "Inter, system-ui, sans-serif"}, }, - height=700, + height=750, template="plotly_dark", + margin=dict(l=60, r=60, t=100, b=60), + plot_bgcolor="#111827", + paper_bgcolor="#1f2937", + font=dict(color="#e5e7eb", family="Inter, system-ui, sans-serif"), ) return fig @@ -300,6 +308,8 @@ def model_metadata(file_path: Path) -> go.Figure: cols=2, subplot_titles=("Model Size", "Architecture", "Inputs/Outputs", "Parameters"), specs=[[{"type": "indicator"}, {"type": "bar"}], [{"type": "table"}, {"type": "indicator"}]], + vertical_spacing=0.15, + horizontal_spacing=0.12, ) # Model size indicator @@ -307,16 +317,21 @@ def model_metadata(file_path: Path) -> go.Figure: go.Indicator( mode="number+gauge", value=file_size_mb, - title={"text": "Model Size (MB)"}, - number={"suffix": " MB", "valueformat": ".2f"}, + title={"text": "Model Size (MB)", "font": {"size": 14}}, + number={"suffix": " MB", "valueformat": ".2f", "font": {"size": 24}}, gauge={ "axis": {"range": [0, max(100, file_size_mb * 1.5)]}, - "bar": {"color": "darkblue"}, + "bar": {"color": "#3b82f6"}, "steps": [ - {"range": [0, 10], "color": "lightgreen"}, - {"range": [10, 50], "color": "yellow"}, - {"range": [50, 100], "color": "orange"}, + {"range": [0, 10], "color": "#10b981"}, + {"range": [10, 50], "color": "#f59e0b"}, + {"range": [50, 100], "color": "#ef4444"}, ], + "threshold": { + "line": {"color": "white", "width": 2}, + "thickness": 0.75, + "value": file_size_mb, + }, }, ), row=1, @@ -328,7 +343,15 @@ def model_metadata(file_path: Path) -> go.Figure: arch_values = [total_nodes, total_inputs, total_outputs, total_initializers] fig.add_trace( - go.Bar(x=arch_data, y=arch_values, marker_color=["blue", "green", "orange", "red"], showlegend=False), + go.Bar( + x=arch_data, + y=arch_values, + marker_color=["#3b82f6", "#10b981", "#f59e0b", "#ef4444"], + showlegend=False, + text=arch_values, + textposition="outside", + textfont=dict(size=12, color="#e5e7eb"), + ), row=1, col=2, ) @@ -389,8 +412,20 @@ def model_metadata(file_path: Path) -> go.Figure: if io_data: fig.add_trace( go.Table( - header=dict(values=["Type", "Name", "Shape", "Data Type"], fill_color="lightblue", align="left"), - cells=dict(values=list(zip(*io_data)), fill_color="white", align="left"), + header=dict( + values=["Type", "Name", "Shape", "Data Type"], + fill_color="#374151", + align="left", + font=dict(color="#f3f4f6", size=12, family="Inter, system-ui, sans-serif"), + height=30, + ), + cells=dict( + values=list(zip(*io_data)), + fill_color="#1f2937", + align="left", + font=dict(color="#e5e7eb", size=11, family="Menlo, Consolas, monospace"), + height=25, + ), ), row=2, col=1, @@ -400,10 +435,9 @@ def model_metadata(file_path: Path) -> go.Figure: fig.add_trace( go.Indicator( mode="number", - value=total_params, - title={"text": "Total Parameters"}, - number={"suffix": "M", "valueformat": ".2f"}, - number_font_size=30, + value=total_params / 1e6, + title={"text": "Total Parameters", "font": {"size": 14}}, + number={"suffix": "M", "valueformat": ".2f", "font": {"size": 32}}, ), row=2, col=2, @@ -417,11 +451,15 @@ def model_metadata(file_path: Path) -> go.Figure: ), "x": 0.5, "xanchor": "center", - "font": {"size": 22}, + "font": {"size": 20, "family": "Inter, system-ui, sans-serif"}, }, - height=600, + height=700, template="plotly_dark", showlegend=False, + margin=dict(l=60, r=60, t=100, b=60), + plot_bgcolor="#111827", + paper_bgcolor="#1f2937", + font=dict(color="#e5e7eb", family="Inter, system-ui, sans-serif"), ) return fig @@ -479,6 +517,8 @@ def performance_metrics(file_path: Path) -> go.Figure: cols=2, subplot_titles=("Model Efficiency", "Memory Usage", "Operation Types", "Complexity Score"), specs=[[{"type": "bar"}, {"type": "bar"}], [{"type": "pie"}, {"type": "indicator"}]], + vertical_spacing=0.15, + horizontal_spacing=0.12, ) # Model efficiency metrics @@ -487,7 +527,13 @@ def performance_metrics(file_path: Path) -> go.Figure: fig.add_trace( go.Bar( - x=efficiency_metrics, y=efficiency_values, marker_color=["blue", "green", "orange"], showlegend=False + x=efficiency_metrics, + y=efficiency_values, + marker_color=["#3b82f6", "#10b981", "#f59e0b"], + showlegend=False, + text=[f"{v:.2f}" for v in efficiency_values], + textposition="outside", + textfont=dict(size=12, color="#e5e7eb"), ), row=1, col=1, @@ -498,7 +544,15 @@ def performance_metrics(file_path: Path) -> go.Figure: memory_values = [param_memory_mb, param_memory_mb * 2] # Rough estimate fig.add_trace( - go.Bar(x=memory_types, y=memory_values, marker_color=["purple", "red"], showlegend=False), + go.Bar( + x=memory_types, + y=memory_values, + marker_color=["#8b5cf6", "#ef4444"], + showlegend=False, + text=[f"{v:.2f} MB" for v in memory_values], + textposition="outside", + textfont=dict(size=12, color="#e5e7eb"), + ), row=1, col=2, ) @@ -508,7 +562,8 @@ def performance_metrics(file_path: Path) -> go.Figure: go.Pie( labels=["Compute Ops", "Efficient Ops", "Other Ops"], values=[compute_count, efficient_count, other_count], - marker_colors=["red", "green", "gray"], + marker_colors=["#ef4444", "#10b981", "#6b7280"], + textfont=dict(size=12, color="#ffffff"), ), row=2, col=1, @@ -521,17 +576,23 @@ def performance_metrics(file_path: Path) -> go.Figure: go.Indicator( mode="gauge+number", value=complexity_score, - title={"text": "Complexity Score"}, + title={"text": "Complexity Score", "font": {"size": 14}}, + number={"font": {"size": 28}}, gauge={ "axis": {"range": [0, 100]}, "bar": { - "color": "darkred" if complexity_score > 70 else "orange" if complexity_score > 40 else "green" + "color": "#ef4444" if complexity_score > 70 else "#f59e0b" if complexity_score > 40 else "#10b981" }, "steps": [ - {"range": [0, 40], "color": "lightgreen"}, - {"range": [40, 70], "color": "yellow"}, - {"range": [70, 100], "color": "lightcoral"}, + {"range": [0, 40], "color": "rgba(16, 185, 129, 0.12)"}, + {"range": [40, 70], "color": "rgba(245, 158, 11, 0.12)"}, + {"range": [70, 100], "color": "rgba(239, 68, 68, 0.12)"}, ], + "threshold": { + "line": {"color": "white", "width": 2}, + "thickness": 0.75, + "value": complexity_score, + }, }, ), row=2, @@ -547,16 +608,25 @@ def performance_metrics(file_path: Path) -> go.Figure: ), "x": 0.5, "xanchor": "center", - "font": {"size": 22}, + "font": {"size": 20, "family": "Inter, system-ui, sans-serif"}, }, - height=600, + height=700, template="plotly_dark", showlegend=False, + margin=dict(l=60, r=60, t=100, b=60), + plot_bgcolor="#111827", + paper_bgcolor="#1f2937", + font=dict(color="#e5e7eb", family="Inter, system-ui, sans-serif"), ) return fig except Exception as e: + import traceback + error_details = f"Error: {str(e)}\n\nTraceback: {traceback.format_exc()}" + print(f"[ONNX Performance Metrics] Error: {error_details}") return create_styled_error_figure( - "Performance Analysis Error", "Could not analyze ONNX model performance.", f"Error: {str(e)}" + "Performance Analysis Error", + f"Could not analyze ONNX model performance: {str(e)}", + "Check the server logs for more details" ) diff --git a/src/ria_toolkit_oss/viz/pytorch_model.py b/src/ria_toolkit_oss/viz/pytorch_model.py index ef9733d..0bac2c7 100644 --- a/src/ria_toolkit_oss/viz/pytorch_model.py +++ b/src/ria_toolkit_oss/viz/pytorch_model.py @@ -336,20 +336,6 @@ def model_metadata_plot(file_path: Path) -> Figure: opacity=0.3, layer="below", ) - # Header bar - fig.add_shape( - type="rect", - xref="paper", - yref="paper", - x0=card["x"], - y0=card["y"] - 0.07, - x1=card["x"] + card["width"], - y1=card["y"], - fillcolor=card["color"], - line=dict(width=0), - opacity=0.45, - layer="below", - ) # --- CARD 1: Model Overview --- card = cards[0] @@ -358,7 +344,7 @@ def model_metadata_plot(file_path: Path) -> Figure: xref="paper", yref="paper", x=card["x"] + 0.03, - y=card["y"] - 0.02, + y=card["y"] - 0.04, xanchor="left", yanchor="middle", showarrow=False, @@ -371,7 +357,7 @@ def model_metadata_plot(file_path: Path) -> Figure: xref="paper", yref="paper", x=card["x"] + 0.04, - y=card["y"] - 0.13, + y=card["y"] - 0.15, xanchor="left", yanchor="top", showarrow=False, @@ -387,7 +373,7 @@ def model_metadata_plot(file_path: Path) -> Figure: xref="paper", yref="paper", x=card["x"] + 0.03, - y=card["y"] - 0.02, + y=card["y"] - 0.04, xanchor="left", yanchor="middle", showarrow=False, @@ -425,7 +411,7 @@ def model_metadata_plot(file_path: Path) -> Figure: xref="paper", yref="paper", x=card["x"] + 0.03, - y=card["y"] - 0.02, + y=card["y"] - 0.04, xanchor="left", yanchor="middle", showarrow=False, @@ -436,7 +422,7 @@ def model_metadata_plot(file_path: Path) -> Figure: xref="paper", yref="paper", x=card["x"] + 0.04, - y=card["y"] - 0.13, + y=card["y"] - 0.15, xanchor="left", yanchor="top", showarrow=False, @@ -450,7 +436,7 @@ def model_metadata_plot(file_path: Path) -> Figure: xref="paper", yref="paper", x=card["x"] + 0.03, - y=card["y"] - 0.02, + y=card["y"] - 0.04, xanchor="left", yanchor="middle", showarrow=False, @@ -472,7 +458,7 @@ def model_metadata_plot(file_path: Path) -> Figure: xref="paper", yref="paper", x=card["x"] + 0.04, - y=card["y"] - 0.13, + y=card["y"] - 0.15, xanchor="left", yanchor="top", showarrow=False,