diff --git a/src/ria_toolkit_oss/viz/pytorch_model.py b/src/ria_toolkit_oss/viz/pytorch_model.py new file mode 100644 index 0000000..2376ce3 --- /dev/null +++ b/src/ria_toolkit_oss/viz/pytorch_model.py @@ -0,0 +1,416 @@ +"""Visualization functions for PyTorch model (.py) files. + +This module provides visualization capabilities for PyTorch model Python files, +extracting architectural information through AST parsing and static analysis. +""" + +import ast +import re +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import plotly.graph_objects as go +from plotly.graph_objects import Figure + + +def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure: + """Create a professional error figure with Qoherent dark theme styling.""" + fig = go.Figure() + + # Create a clean, centered text display using Plotly's text formatting + main_text = f"⚠️ {title}

" + main_text += f"{message}" + + if suggestion: + main_text += "

💡 Suggestion:
" + main_text += f"{suggestion}" + + # Add the main text annotation + fig.add_annotation( + text=main_text, + xref="paper", + yref="paper", + x=0.5, + y=0.5, + xanchor="center", + yanchor="middle", + showarrow=False, + align="center", + borderwidth=2, + bordercolor="#4a5568", + bgcolor="#2d3748", + font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"), + ) + + # Update layout with dark theme + fig.update_layout( + title="", + height=400, + template="plotly_dark", + margin=dict(l=40, r=40, t=40, b=40), + plot_bgcolor="#1a202c", + paper_bgcolor="#1a202c", + font=dict(color="#e2e8f0"), + ) + + # Remove axes and grid + fig.update_xaxes(visible=False) + fig.update_yaxes(visible=False) + + return fig + + +def _parse_model_file(file_path: Path) -> Tuple[Optional[ast.Module], Optional[str]]: + """Parse a Python model file and return the AST and any error message.""" + try: + with open(file_path, "r", encoding="utf-8") as f: + code = f.read() + tree = ast.parse(code, filename=str(file_path)) + return tree, None + except SyntaxError as e: + return None, f"Syntax error in file: {e}" + except Exception as e: + return None, f"Failed to parse file: {e}" + + +def _find_model_class(tree: ast.Module) -> Optional[ast.ClassDef]: + """Find the main model class (subclass of nn.Module) in the AST.""" + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + # Check if it inherits from nn.Module or torch.nn.Module + for base in node.bases: + base_name = "" + if isinstance(base, ast.Name): + base_name = base.id + elif isinstance(base, ast.Attribute): + if isinstance(base.value, ast.Name): + base_name = f"{base.value.id}.{base.attr}" + + if "Module" in base_name or "nn.Module" in base_name: + return node + return None + + +def _extract_layer_info(model_class: ast.ClassDef) -> List[Dict[str, Any]]: + """Extract layer information from the model's __init__ method.""" + layers = [] + + # Find __init__ method + init_method = None + for node in model_class.body: + if isinstance(node, ast.FunctionDef) and node.name == "__init__": + init_method = node + break + + if not init_method: + return layers + + # Parse assignments in __init__ + for node in ast.walk(init_method): + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Attribute): + layer_name = target.attr + layer_type = _extract_layer_type(node.value) + if layer_type: + layers.append( + {"name": layer_name, "type": layer_type, "details": _extract_layer_params(node.value)} + ) + + return layers + + +def _extract_layer_type(node: ast.expr) -> Optional[str]: + """Extract the layer type from an AST node.""" + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + return node.func.id + elif isinstance(node.func, ast.Attribute): + return node.func.attr + return None + + +def _extract_layer_params(node: ast.Call) -> str: + """Extract layer parameters as a string.""" + params = [] + + # Extract positional arguments + for arg in node.args: + if isinstance(arg, ast.Constant): + params.append(str(arg.value)) + elif isinstance(arg, ast.Name): + params.append(arg.id) + + # Extract keyword arguments + for keyword in node.keywords: + if isinstance(keyword.value, ast.Constant): + params.append(f"{keyword.arg}={keyword.value.value}") + elif isinstance(keyword.value, ast.Name): + params.append(f"{keyword.arg}={keyword.value.id}") + + return ", ".join(params) + + +def _count_parameters(layers: List[Dict[str, Any]]) -> int: + """Estimate parameter count from layer definitions (rough estimate).""" + # This is a very rough estimate - actual counts would require instantiating the model + param_estimates = { + "Linear": 1000, + "Conv1d": 500, + "Conv2d": 5000, + "Conv3d": 10000, + "LSTM": 4000, + "GRU": 3000, + "TransformerEncoder": 50000, + "Embedding": 10000, + } + + total = 0 + for layer in layers: + layer_type = layer["type"] + total += param_estimates.get(layer_type, 100) + + return total + + +def model_architecture_plot(file_path: Path) -> Figure: + """Visualize the architecture of a PyTorch model from its .py file. + + Parses the model file using AST to extract layers and their connections. + """ + tree, error = _parse_model_file(file_path) + + if error: + return create_styled_error_figure( + "Parse Error", error, "Ensure the .py file contains valid Python code with a PyTorch nn.Module class" + ) + + model_class = _find_model_class(tree) + if not model_class: + return create_styled_error_figure( + "No Model Found", + "Could not find a PyTorch nn.Module class in the file", + "Ensure your model class inherits from torch.nn.Module or nn.Module", + ) + + layers = _extract_layer_info(model_class) + + if not layers: + return create_styled_error_figure( + "No Layers Found", + "Could not extract layer information from the model", + "Ensure your model defines layers in the __init__ method", + ) + + # Create a hierarchical visualization + layer_names = [f"{i+1}. {layer['name']}" for i, layer in enumerate(layers)] + layer_types = [layer["type"] for layer in layers] + layer_details = [layer["details"] for layer in layers] + + # Create a bar chart showing layers + fig = go.Figure() + + fig.add_trace( + go.Bar( + y=layer_names, + x=[1] * len(layer_names), + orientation="h", + text=layer_types, + textposition="inside", + hovertext=[ + f"{name}
Type: {type_}
Params: {details}" + for name, type_, details in zip(layer_names, layer_types, layer_details) + ], + hoverinfo="text", + marker=dict(color="rgba(99, 179, 237, 0.8)", line=dict(color="rgba(99, 179, 237, 1.0)", width=2)), + ) + ) + + fig.update_layout( + title=f"Model Architecture: {model_class.name}", + xaxis=dict(visible=False), + yaxis=dict(title="Layers", autorange="reversed"), + template="plotly_dark", + height=max(400, len(layers) * 40), + showlegend=False, + margin=dict(l=200, r=40, t=60, b=40), + ) + + return fig + + +def model_complexity_plot(file_path: Path) -> Figure: + """Analyze and visualize model complexity metrics.""" + tree, error = _parse_model_file(file_path) + + if error: + return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code") + + model_class = _find_model_class(tree) + if not model_class: + return create_styled_error_figure("No Model Found", "Could not find a PyTorch nn.Module class in the file") + + layers = _extract_layer_info(model_class) + + if not layers: + return create_styled_error_figure("No Layers Found", "Could not extract layer information from the model") + + # Count layer types + layer_type_counts = {} + for layer in layers: + layer_type = layer["type"] + layer_type_counts[layer_type] = layer_type_counts.get(layer_type, 0) + 1 + + # Create pie chart of layer types + fig = go.Figure( + data=[ + go.Pie( + labels=list(layer_type_counts.keys()), + values=list(layer_type_counts.values()), + hole=0.3, + marker=dict(colors=["#5c79ff", "#63b3ed", "#48bb78", "#f6ad55", "#fc8181"]), + ) + ] + ) + + fig.update_layout( + title="Layer Type Distribution", + template="plotly_dark", + height=400, + ) + + return fig + + +def model_metadata_plot(file_path: Path) -> Figure: + """Display model metadata and information extracted from the Python file.""" + tree, error = _parse_model_file(file_path) + + if error: + return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code") + + model_class = _find_model_class(tree) + if not model_class: + return create_styled_error_figure("No Model Found", "Could not find a PyTorch nn.Module class in the file") + + layers = _extract_layer_info(model_class) + + # Extract imports + imports = [] + for node in tree.body: + if isinstance(node, ast.Import): + for alias in node.names: + imports.append(alias.name) + elif isinstance(node, ast.ImportFrom): + if node.module: + imports.append(node.module) + + # Get docstring + docstring = ast.get_docstring(model_class) or "No docstring available" + if len(docstring) > 200: + docstring = docstring[:200] + "..." + + # Build metadata display + metadata_text = f"""Model: {model_class.name}

""" + metadata_text += f"📝 Description:
{docstring}

" + metadata_text += f"🔢 Number of Layers: {len(layers)}
" + metadata_text += f"📦 Estimated Parameters: ~{_count_parameters(layers):,}

" + metadata_text += f"📚 Key Imports:
" + + relevant_imports = [imp for imp in imports if "torch" in imp or "nn" in imp][:5] + for imp in relevant_imports: + metadata_text += f" • {imp}
" + + fig = go.Figure() + + fig.add_annotation( + text=metadata_text, + xref="paper", + yref="paper", + x=0.05, + y=0.95, + xanchor="left", + yanchor="top", + showarrow=False, + align="left", + borderwidth=2, + bordercolor="#4a5568", + bgcolor="#2d3748", + font=dict(family="Arial, sans-serif", size=13, color="#e2e8f0"), + ) + + fig.update_layout( + title="Model Metadata", + template="plotly_dark", + height=450, + margin=dict(l=40, r=40, t=60, b=40), + plot_bgcolor="#1a202c", + paper_bgcolor="#1a202c", + ) + + fig.update_xaxes(visible=False) + fig.update_yaxes(visible=False) + + return fig + + +def code_structure_plot(file_path: Path) -> Figure: + """Visualize the code structure and method definitions in the model.""" + tree, error = _parse_model_file(file_path) + + if error: + return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code") + + model_class = _find_model_class(tree) + if not model_class: + return create_styled_error_figure("No Model Found", "Could not find a PyTorch nn.Module class in the file") + + # Extract methods + methods = [] + for node in model_class.body: + if isinstance(node, ast.FunctionDef): + # Count lines in method + if hasattr(node, "end_lineno") and hasattr(node, "lineno"): + lines = node.end_lineno - node.lineno + 1 + else: + lines = 1 + + methods.append({"name": node.name, "lines": lines, "args": len(node.args.args) - 1}) # Exclude self + + if not methods: + return create_styled_error_figure( + "No Methods Found", "Could not extract method information from the model class" + ) + + # Create visualization of methods + method_names = [m["name"] for m in methods] + method_lines = [m["lines"] for m in methods] + method_args = [m["args"] for m in methods] + + fig = go.Figure() + + # Bar chart for method complexity (lines of code) + fig.add_trace( + go.Bar( + x=method_names, + y=method_lines, + name="Lines of Code", + marker=dict(color="rgba(99, 179, 237, 0.8)"), + hovertext=[ + f"{name}
Lines: {lines}
Arguments: {args}" + for name, lines, args in zip(method_names, method_lines, method_args) + ], + hoverinfo="text", + ) + ) + + fig.update_layout( + title=f"Method Complexity - {model_class.name}", + xaxis_title="Methods", + yaxis_title="Lines of Code", + template="plotly_dark", + height=400, + showlegend=False, + ) + + return fig