pytorch-widgets #10
|
|
@ -5,7 +5,6 @@ extracting architectural information through AST parsing and static analysis.
|
|||
"""
|
||||
|
||||
import ast
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
|
@ -283,9 +282,10 @@ def model_complexity_plot(file_path: Path) -> Figure:
|
|||
|
||||
|
||||
def model_metadata_plot(file_path: Path) -> Figure:
|
||||
"""Display model metadata and information extracted from the Python file."""
|
||||
tree, error = _parse_model_file(file_path)
|
||||
"""Display model metadata and information extracted from the Python file (clean, aligned layout)."""
|
||||
import textwrap
|
||||
|
||||
tree, error = _parse_model_file(file_path)
|
||||
if error:
|
||||
return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code")
|
||||
|
||||
|
|
@ -301,56 +301,200 @@ def model_metadata_plot(file_path: Path) -> Figure:
|
|||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
imports.append(alias.name)
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module:
|
||||
imports.append(node.module)
|
||||
elif isinstance(node, ast.ImportFrom) and node.module:
|
||||
imports.append(node.module)
|
||||
|
||||
# Get docstring
|
||||
# Get docstring and wrap it
|
||||
docstring = ast.get_docstring(model_class) or "No docstring available"
|
||||
if len(docstring) > 200:
|
||||
docstring = docstring[:200] + "..."
|
||||
wrapped_doc = "<br>".join(textwrap.wrap(docstring, width=70))
|
||||
|
||||
# Build metadata display
|
||||
metadata_text = f"""<b style='font-size:16px;color:#63b3ed'>Model: {model_class.name}</b><br><br>"""
|
||||
metadata_text += f"<b>📝 Description:</b><br><span style='color:#cbd5e0'>{docstring}</span><br><br>"
|
||||
metadata_text += f"<b>🔢 Number of Layers:</b> {len(layers)}<br>"
|
||||
metadata_text += f"<b>📦 Estimated Parameters:</b> ~{_count_parameters(layers):,}<br><br>"
|
||||
metadata_text += f"<b>📚 Key Imports:</b><br>"
|
||||
relevant_imports = [imp for imp in imports if "torch" in imp or "nn" in imp][:4]
|
||||
param_count = _count_parameters(layers)
|
||||
|
||||
relevant_imports = [imp for imp in imports if "torch" in imp or "nn" in imp][:5]
|
||||
for imp in relevant_imports:
|
||||
metadata_text += f" • {imp}<br>"
|
||||
# Define card grid (aligned 2x2)
|
||||
cards = [
|
||||
{"x": 0.05, "y": 0.93, "width": 0.43, "height": 0.38, "title": "📦 Model Overview", "color": "#2d5f8d"},
|
||||
{"x": 0.52, "y": 0.93, "width": 0.43, "height": 0.38, "title": "🔢 Statistics", "color": "#2d6b5f"},
|
||||
{"x": 0.05, "y": 0.46, "width": 0.43, "height": 0.38, "title": "📝 Description", "color": "#5d4b7a"},
|
||||
{"x": 0.52, "y": 0.46, "width": 0.43, "height": 0.38, "title": "📚 Dependencies", "color": "#7a5b3d"},
|
||||
]
|
||||
|
||||
fig = go.Figure()
|
||||
|
||||
# Draw background cards with consistent opacity
|
||||
for card in cards:
|
||||
fig.add_shape(
|
||||
type="rect",
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x0=card["x"],
|
||||
y0=card["y"] - card["height"],
|
||||
x1=card["x"] + card["width"],
|
||||
y1=card["y"],
|
||||
fillcolor=card["color"],
|
||||
line=dict(color="#4a5568", width=2),
|
||||
opacity=0.3,
|
||||
layer="below",
|
||||
)
|
||||
# Header bar
|
||||
fig.add_shape(
|
||||
type="rect",
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x0=card["x"],
|
||||
y0=card["y"] - 0.07,
|
||||
x1=card["x"] + card["width"],
|
||||
y1=card["y"],
|
||||
fillcolor=card["color"],
|
||||
line=dict(width=0),
|
||||
opacity=0.45,
|
||||
layer="below",
|
||||
)
|
||||
|
||||
# --- CARD 1: Model Overview ---
|
||||
card = cards[0]
|
||||
fig.add_annotation(
|
||||
text=metadata_text,
|
||||
text=f"<b>{card['title']}</b>",
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=0.05,
|
||||
y=0.95,
|
||||
x=card["x"] + 0.03,
|
||||
y=card["y"] - 0.02,
|
||||
xanchor="left",
|
||||
yanchor="middle",
|
||||
showarrow=False,
|
||||
align="left",
|
||||
font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
|
||||
)
|
||||
fig.add_annotation(
|
||||
text=f"<b style='font-size:26px;color:#ffffff'>{model_class.name}</b><br>"
|
||||
f"<span style='color:#94a3b8;font-size:15px'>PyTorch Neural Network</span>",
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=card["x"] + 0.04,
|
||||
y=card["y"] - 0.13,
|
||||
xanchor="left",
|
||||
yanchor="top",
|
||||
showarrow=False,
|
||||
align="left",
|
||||
borderwidth=2,
|
||||
bordercolor="#4a5568",
|
||||
bgcolor="#2d3748",
|
||||
font=dict(family="Arial, sans-serif", size=13, color="#e2e8f0"),
|
||||
font=dict(size=15, color="#cbd5e0", family="Inter, Arial, sans-serif"),
|
||||
)
|
||||
|
||||
# --- CARD 2: Statistics ---
|
||||
card = cards[1]
|
||||
y_center = card["y"] - card["height"] / 2
|
||||
fig.add_annotation(
|
||||
text=f"<b>{card['title']}</b>",
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=card["x"] + 0.03,
|
||||
y=card["y"] - 0.02,
|
||||
xanchor="left",
|
||||
yanchor="middle",
|
||||
showarrow=False,
|
||||
font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
|
||||
)
|
||||
fig.add_annotation(
|
||||
text=f"<b style='font-size:44px;color:#63b3ed'>{len(layers)}</b><br>"
|
||||
f"<span style='color:#94a3b8;font-size:13px;letter-spacing:1.5px'>LAYERS</span>",
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=card["x"] + card["width"] / 2,
|
||||
y=y_center + 0.07,
|
||||
xanchor="center",
|
||||
yanchor="middle",
|
||||
showarrow=False,
|
||||
align="center",
|
||||
)
|
||||
fig.add_annotation(
|
||||
text=f"<b style='font-size:36px;color:#48bb78'>~{param_count:,}</b><br>"
|
||||
f"<span style='color:#94a3b8;font-size:13px;letter-spacing:1.5px'>PARAMETERS</span>",
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=card["x"] + card["width"] / 2,
|
||||
y=y_center - 0.10,
|
||||
xanchor="center",
|
||||
yanchor="middle",
|
||||
showarrow=False,
|
||||
align="center",
|
||||
)
|
||||
|
||||
# --- CARD 3: Description ---
|
||||
card = cards[2]
|
||||
fig.add_annotation(
|
||||
text=f"<b>{card['title']}</b>",
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=card["x"] + 0.03,
|
||||
y=card["y"] - 0.02,
|
||||
xanchor="left",
|
||||
yanchor="middle",
|
||||
showarrow=False,
|
||||
font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
|
||||
)
|
||||
fig.add_annotation(
|
||||
text=f"<span style='color:#cbd5e0;font-size:14px;line-height:1.5'>{wrapped_doc}</span>",
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=card["x"] + 0.04,
|
||||
y=card["y"] - 0.13,
|
||||
xanchor="left",
|
||||
yanchor="top",
|
||||
showarrow=False,
|
||||
align="left",
|
||||
)
|
||||
|
||||
# --- CARD 4: Dependencies ---
|
||||
card = cards[3]
|
||||
fig.add_annotation(
|
||||
text=f"<b>{card['title']}</b>",
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=card["x"] + 0.03,
|
||||
y=card["y"] - 0.02,
|
||||
xanchor="left",
|
||||
yanchor="middle",
|
||||
showarrow=False,
|
||||
font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
|
||||
)
|
||||
imports_text = (
|
||||
"<br>".join(
|
||||
[
|
||||
f"<span style='color:#48bb78;font-size:16px'>▸</span> "
|
||||
f"<span style='color:#e2e8f0;font-family:\"Courier New\",monospace;font-size:14px'>{imp}</span>"
|
||||
for imp in relevant_imports
|
||||
]
|
||||
)
|
||||
if relevant_imports
|
||||
else "<span style='color:#94a3b8;font-style:italic;font-size:14px'>No torch imports detected</span>"
|
||||
)
|
||||
fig.add_annotation(
|
||||
text=imports_text,
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=card["x"] + 0.04,
|
||||
y=card["y"] - 0.13,
|
||||
xanchor="left",
|
||||
yanchor="top",
|
||||
showarrow=False,
|
||||
align="left",
|
||||
)
|
||||
|
||||
# Layout polish
|
||||
fig.update_layout(
|
||||
title="Model Metadata",
|
||||
title=dict(
|
||||
text="<b>Model Metadata</b>",
|
||||
font=dict(size=20, color="#e2e8f0", family="Inter, Arial, sans-serif"),
|
||||
x=0.5,
|
||||
xanchor="center",
|
||||
),
|
||||
template="plotly_dark",
|
||||
height=450,
|
||||
margin=dict(l=40, r=40, t=60, b=40),
|
||||
height=500,
|
||||
margin=dict(l=20, r=20, t=70, b=20),
|
||||
plot_bgcolor="#1a202c",
|
||||
paper_bgcolor="#1a202c",
|
||||
)
|
||||
|
||||
fig.update_xaxes(visible=False)
|
||||
fig.update_yaxes(visible=False)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user