Pytorch Widgets
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 54s
Test with tox / Test with tox (3.12) (pull_request) Successful in 38s
Test with tox / Test with tox (3.11) (pull_request) Successful in 1m9s
Test with tox / Test with tox (3.10) (pull_request) Successful in 2m3s
Build Project / Build Project (3.10) (pull_request) Successful in 2m15s
Build Project / Build Project (3.11) (pull_request) Successful in 2m13s
Build Project / Build Project (3.12) (pull_request) Successful in 2m13s

This commit is contained in:
ben 2025-10-31 12:12:24 -04:00
parent b8ccead21e
commit 48f6b303f5

View File

@ -5,7 +5,6 @@ extracting architectural information through AST parsing and static analysis.
"""
import ast
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
@ -283,9 +282,10 @@ def model_complexity_plot(file_path: Path) -> Figure:
def model_metadata_plot(file_path: Path) -> Figure:
"""Display model metadata and information extracted from the Python file."""
tree, error = _parse_model_file(file_path)
"""Display model metadata and information extracted from the Python file (clean, aligned layout)."""
import textwrap
tree, error = _parse_model_file(file_path)
if error:
return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code")
@ -301,56 +301,200 @@ def model_metadata_plot(file_path: Path) -> Figure:
if isinstance(node, ast.Import):
for alias in node.names:
imports.append(alias.name)
elif isinstance(node, ast.ImportFrom):
if node.module:
imports.append(node.module)
elif isinstance(node, ast.ImportFrom) and node.module:
imports.append(node.module)
# Get docstring
# Get docstring and wrap it
docstring = ast.get_docstring(model_class) or "No docstring available"
if len(docstring) > 200:
docstring = docstring[:200] + "..."
wrapped_doc = "<br>".join(textwrap.wrap(docstring, width=70))
# Build metadata display
metadata_text = f"""<b style='font-size:16px;color:#63b3ed'>Model: {model_class.name}</b><br><br>"""
metadata_text += f"<b>📝 Description:</b><br><span style='color:#cbd5e0'>{docstring}</span><br><br>"
metadata_text += f"<b>🔢 Number of Layers:</b> {len(layers)}<br>"
metadata_text += f"<b>📦 Estimated Parameters:</b> ~{_count_parameters(layers):,}<br><br>"
metadata_text += f"<b>📚 Key Imports:</b><br>"
relevant_imports = [imp for imp in imports if "torch" in imp or "nn" in imp][:4]
param_count = _count_parameters(layers)
relevant_imports = [imp for imp in imports if "torch" in imp or "nn" in imp][:5]
for imp in relevant_imports:
metadata_text += f"{imp}<br>"
# Define card grid (aligned 2x2)
cards = [
{"x": 0.05, "y": 0.93, "width": 0.43, "height": 0.38, "title": "📦 Model Overview", "color": "#2d5f8d"},
{"x": 0.52, "y": 0.93, "width": 0.43, "height": 0.38, "title": "🔢 Statistics", "color": "#2d6b5f"},
{"x": 0.05, "y": 0.46, "width": 0.43, "height": 0.38, "title": "📝 Description", "color": "#5d4b7a"},
{"x": 0.52, "y": 0.46, "width": 0.43, "height": 0.38, "title": "📚 Dependencies", "color": "#7a5b3d"},
]
fig = go.Figure()
# Draw background cards with consistent opacity
for card in cards:
fig.add_shape(
type="rect",
xref="paper",
yref="paper",
x0=card["x"],
y0=card["y"] - card["height"],
x1=card["x"] + card["width"],
y1=card["y"],
fillcolor=card["color"],
line=dict(color="#4a5568", width=2),
opacity=0.3,
layer="below",
)
# Header bar
fig.add_shape(
type="rect",
xref="paper",
yref="paper",
x0=card["x"],
y0=card["y"] - 0.07,
x1=card["x"] + card["width"],
y1=card["y"],
fillcolor=card["color"],
line=dict(width=0),
opacity=0.45,
layer="below",
)
# --- CARD 1: Model Overview ---
card = cards[0]
fig.add_annotation(
text=metadata_text,
text=f"<b>{card['title']}</b>",
xref="paper",
yref="paper",
x=0.05,
y=0.95,
x=card["x"] + 0.03,
y=card["y"] - 0.02,
xanchor="left",
yanchor="middle",
showarrow=False,
align="left",
font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
)
fig.add_annotation(
text=f"<b style='font-size:26px;color:#ffffff'>{model_class.name}</b><br>"
f"<span style='color:#94a3b8;font-size:15px'>PyTorch Neural Network</span>",
xref="paper",
yref="paper",
x=card["x"] + 0.04,
y=card["y"] - 0.13,
xanchor="left",
yanchor="top",
showarrow=False,
align="left",
borderwidth=2,
bordercolor="#4a5568",
bgcolor="#2d3748",
font=dict(family="Arial, sans-serif", size=13, color="#e2e8f0"),
font=dict(size=15, color="#cbd5e0", family="Inter, Arial, sans-serif"),
)
# --- CARD 2: Statistics ---
card = cards[1]
y_center = card["y"] - card["height"] / 2
fig.add_annotation(
text=f"<b>{card['title']}</b>",
xref="paper",
yref="paper",
x=card["x"] + 0.03,
y=card["y"] - 0.02,
xanchor="left",
yanchor="middle",
showarrow=False,
font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
)
fig.add_annotation(
text=f"<b style='font-size:44px;color:#63b3ed'>{len(layers)}</b><br>"
f"<span style='color:#94a3b8;font-size:13px;letter-spacing:1.5px'>LAYERS</span>",
xref="paper",
yref="paper",
x=card["x"] + card["width"] / 2,
y=y_center + 0.07,
xanchor="center",
yanchor="middle",
showarrow=False,
align="center",
)
fig.add_annotation(
text=f"<b style='font-size:36px;color:#48bb78'>~{param_count:,}</b><br>"
f"<span style='color:#94a3b8;font-size:13px;letter-spacing:1.5px'>PARAMETERS</span>",
xref="paper",
yref="paper",
x=card["x"] + card["width"] / 2,
y=y_center - 0.10,
xanchor="center",
yanchor="middle",
showarrow=False,
align="center",
)
# --- CARD 3: Description ---
card = cards[2]
fig.add_annotation(
text=f"<b>{card['title']}</b>",
xref="paper",
yref="paper",
x=card["x"] + 0.03,
y=card["y"] - 0.02,
xanchor="left",
yanchor="middle",
showarrow=False,
font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
)
fig.add_annotation(
text=f"<span style='color:#cbd5e0;font-size:14px;line-height:1.5'>{wrapped_doc}</span>",
xref="paper",
yref="paper",
x=card["x"] + 0.04,
y=card["y"] - 0.13,
xanchor="left",
yanchor="top",
showarrow=False,
align="left",
)
# --- CARD 4: Dependencies ---
card = cards[3]
fig.add_annotation(
text=f"<b>{card['title']}</b>",
xref="paper",
yref="paper",
x=card["x"] + 0.03,
y=card["y"] - 0.02,
xanchor="left",
yanchor="middle",
showarrow=False,
font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
)
imports_text = (
"<br>".join(
[
f"<span style='color:#48bb78;font-size:16px'>▸</span> "
f"<span style='color:#e2e8f0;font-family:\"Courier New\",monospace;font-size:14px'>{imp}</span>"
for imp in relevant_imports
]
)
if relevant_imports
else "<span style='color:#94a3b8;font-style:italic;font-size:14px'>No torch imports detected</span>"
)
fig.add_annotation(
text=imports_text,
xref="paper",
yref="paper",
x=card["x"] + 0.04,
y=card["y"] - 0.13,
xanchor="left",
yanchor="top",
showarrow=False,
align="left",
)
# Layout polish
fig.update_layout(
title="Model Metadata",
title=dict(
text="<b>Model Metadata</b>",
font=dict(size=20, color="#e2e8f0", family="Inter, Arial, sans-serif"),
x=0.5,
xanchor="center",
),
template="plotly_dark",
height=450,
margin=dict(l=40, r=40, t=60, b=40),
height=500,
margin=dict(l=20, r=20, t=70, b=20),
plot_bgcolor="#1a202c",
paper_bgcolor="#1a202c",
)
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)
return fig