From 3b9e6971cad51d724e35fa4c59c80e0e6a3eeb7a Mon Sep 17 00:00:00 2001 From: ben Date: Mon, 27 Oct 2025 11:57:10 -0400 Subject: [PATCH] removed pytorch --- src/ria_toolkit_oss/viz/pytorch_model.py | 416 ----------------------- 1 file changed, 416 deletions(-) delete mode 100644 src/ria_toolkit_oss/viz/pytorch_model.py diff --git a/src/ria_toolkit_oss/viz/pytorch_model.py b/src/ria_toolkit_oss/viz/pytorch_model.py deleted file mode 100644 index 2376ce3..0000000 --- a/src/ria_toolkit_oss/viz/pytorch_model.py +++ /dev/null @@ -1,416 +0,0 @@ -"""Visualization functions for PyTorch model (.py) files. - -This module provides visualization capabilities for PyTorch model Python files, -extracting architectural information through AST parsing and static analysis. -""" - -import ast -import re -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple - -import plotly.graph_objects as go -from plotly.graph_objects import Figure - - -def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure: - """Create a professional error figure with Qoherent dark theme styling.""" - fig = go.Figure() - - # Create a clean, centered text display using Plotly's text formatting - main_text = f"⚠️ {title}

" - main_text += f"{message}" - - if suggestion: - main_text += "

💡 Suggestion:
" - main_text += f"{suggestion}" - - # Add the main text annotation - fig.add_annotation( - text=main_text, - xref="paper", - yref="paper", - x=0.5, - y=0.5, - xanchor="center", - yanchor="middle", - showarrow=False, - align="center", - borderwidth=2, - bordercolor="#4a5568", - bgcolor="#2d3748", - font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"), - ) - - # Update layout with dark theme - fig.update_layout( - title="", - height=400, - template="plotly_dark", - margin=dict(l=40, r=40, t=40, b=40), - plot_bgcolor="#1a202c", - paper_bgcolor="#1a202c", - font=dict(color="#e2e8f0"), - ) - - # Remove axes and grid - fig.update_xaxes(visible=False) - fig.update_yaxes(visible=False) - - return fig - - -def _parse_model_file(file_path: Path) -> Tuple[Optional[ast.Module], Optional[str]]: - """Parse a Python model file and return the AST and any error message.""" - try: - with open(file_path, "r", encoding="utf-8") as f: - code = f.read() - tree = ast.parse(code, filename=str(file_path)) - return tree, None - except SyntaxError as e: - return None, f"Syntax error in file: {e}" - except Exception as e: - return None, f"Failed to parse file: {e}" - - -def _find_model_class(tree: ast.Module) -> Optional[ast.ClassDef]: - """Find the main model class (subclass of nn.Module) in the AST.""" - for node in ast.walk(tree): - if isinstance(node, ast.ClassDef): - # Check if it inherits from nn.Module or torch.nn.Module - for base in node.bases: - base_name = "" - if isinstance(base, ast.Name): - base_name = base.id - elif isinstance(base, ast.Attribute): - if isinstance(base.value, ast.Name): - base_name = f"{base.value.id}.{base.attr}" - - if "Module" in base_name or "nn.Module" in base_name: - return node - return None - - -def _extract_layer_info(model_class: ast.ClassDef) -> List[Dict[str, Any]]: - """Extract layer information from the model's __init__ method.""" - layers = [] - - # Find __init__ method - init_method = None - for node in model_class.body: - if isinstance(node, ast.FunctionDef) and node.name == "__init__": - init_method = node - break - - if not init_method: - return layers - - # Parse assignments in __init__ - for node in ast.walk(init_method): - if isinstance(node, ast.Assign): - for target in node.targets: - if isinstance(target, ast.Attribute): - layer_name = target.attr - layer_type = _extract_layer_type(node.value) - if layer_type: - layers.append( - {"name": layer_name, "type": layer_type, "details": _extract_layer_params(node.value)} - ) - - return layers - - -def _extract_layer_type(node: ast.expr) -> Optional[str]: - """Extract the layer type from an AST node.""" - if isinstance(node, ast.Call): - if isinstance(node.func, ast.Name): - return node.func.id - elif isinstance(node.func, ast.Attribute): - return node.func.attr - return None - - -def _extract_layer_params(node: ast.Call) -> str: - """Extract layer parameters as a string.""" - params = [] - - # Extract positional arguments - for arg in node.args: - if isinstance(arg, ast.Constant): - params.append(str(arg.value)) - elif isinstance(arg, ast.Name): - params.append(arg.id) - - # Extract keyword arguments - for keyword in node.keywords: - if isinstance(keyword.value, ast.Constant): - params.append(f"{keyword.arg}={keyword.value.value}") - elif isinstance(keyword.value, ast.Name): - params.append(f"{keyword.arg}={keyword.value.id}") - - return ", ".join(params) - - -def _count_parameters(layers: List[Dict[str, Any]]) -> int: - """Estimate parameter count from layer definitions (rough estimate).""" - # This is a very rough estimate - actual counts would require instantiating the model - param_estimates = { - "Linear": 1000, - "Conv1d": 500, - "Conv2d": 5000, - "Conv3d": 10000, - "LSTM": 4000, - "GRU": 3000, - "TransformerEncoder": 50000, - "Embedding": 10000, - } - - total = 0 - for layer in layers: - layer_type = layer["type"] - total += param_estimates.get(layer_type, 100) - - return total - - -def model_architecture_plot(file_path: Path) -> Figure: - """Visualize the architecture of a PyTorch model from its .py file. - - Parses the model file using AST to extract layers and their connections. - """ - tree, error = _parse_model_file(file_path) - - if error: - return create_styled_error_figure( - "Parse Error", error, "Ensure the .py file contains valid Python code with a PyTorch nn.Module class" - ) - - model_class = _find_model_class(tree) - if not model_class: - return create_styled_error_figure( - "No Model Found", - "Could not find a PyTorch nn.Module class in the file", - "Ensure your model class inherits from torch.nn.Module or nn.Module", - ) - - layers = _extract_layer_info(model_class) - - if not layers: - return create_styled_error_figure( - "No Layers Found", - "Could not extract layer information from the model", - "Ensure your model defines layers in the __init__ method", - ) - - # Create a hierarchical visualization - layer_names = [f"{i+1}. {layer['name']}" for i, layer in enumerate(layers)] - layer_types = [layer["type"] for layer in layers] - layer_details = [layer["details"] for layer in layers] - - # Create a bar chart showing layers - fig = go.Figure() - - fig.add_trace( - go.Bar( - y=layer_names, - x=[1] * len(layer_names), - orientation="h", - text=layer_types, - textposition="inside", - hovertext=[ - f"{name}
Type: {type_}
Params: {details}" - for name, type_, details in zip(layer_names, layer_types, layer_details) - ], - hoverinfo="text", - marker=dict(color="rgba(99, 179, 237, 0.8)", line=dict(color="rgba(99, 179, 237, 1.0)", width=2)), - ) - ) - - fig.update_layout( - title=f"Model Architecture: {model_class.name}", - xaxis=dict(visible=False), - yaxis=dict(title="Layers", autorange="reversed"), - template="plotly_dark", - height=max(400, len(layers) * 40), - showlegend=False, - margin=dict(l=200, r=40, t=60, b=40), - ) - - return fig - - -def model_complexity_plot(file_path: Path) -> Figure: - """Analyze and visualize model complexity metrics.""" - tree, error = _parse_model_file(file_path) - - if error: - return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code") - - model_class = _find_model_class(tree) - if not model_class: - return create_styled_error_figure("No Model Found", "Could not find a PyTorch nn.Module class in the file") - - layers = _extract_layer_info(model_class) - - if not layers: - return create_styled_error_figure("No Layers Found", "Could not extract layer information from the model") - - # Count layer types - layer_type_counts = {} - for layer in layers: - layer_type = layer["type"] - layer_type_counts[layer_type] = layer_type_counts.get(layer_type, 0) + 1 - - # Create pie chart of layer types - fig = go.Figure( - data=[ - go.Pie( - labels=list(layer_type_counts.keys()), - values=list(layer_type_counts.values()), - hole=0.3, - marker=dict(colors=["#5c79ff", "#63b3ed", "#48bb78", "#f6ad55", "#fc8181"]), - ) - ] - ) - - fig.update_layout( - title="Layer Type Distribution", - template="plotly_dark", - height=400, - ) - - return fig - - -def model_metadata_plot(file_path: Path) -> Figure: - """Display model metadata and information extracted from the Python file.""" - tree, error = _parse_model_file(file_path) - - if error: - return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code") - - model_class = _find_model_class(tree) - if not model_class: - return create_styled_error_figure("No Model Found", "Could not find a PyTorch nn.Module class in the file") - - layers = _extract_layer_info(model_class) - - # Extract imports - imports = [] - for node in tree.body: - if isinstance(node, ast.Import): - for alias in node.names: - imports.append(alias.name) - elif isinstance(node, ast.ImportFrom): - if node.module: - imports.append(node.module) - - # Get docstring - docstring = ast.get_docstring(model_class) or "No docstring available" - if len(docstring) > 200: - docstring = docstring[:200] + "..." - - # Build metadata display - metadata_text = f"""Model: {model_class.name}

""" - metadata_text += f"📝 Description:
{docstring}

" - metadata_text += f"🔢 Number of Layers: {len(layers)}
" - metadata_text += f"📦 Estimated Parameters: ~{_count_parameters(layers):,}

" - metadata_text += f"📚 Key Imports:
" - - relevant_imports = [imp for imp in imports if "torch" in imp or "nn" in imp][:5] - for imp in relevant_imports: - metadata_text += f" • {imp}
" - - fig = go.Figure() - - fig.add_annotation( - text=metadata_text, - xref="paper", - yref="paper", - x=0.05, - y=0.95, - xanchor="left", - yanchor="top", - showarrow=False, - align="left", - borderwidth=2, - bordercolor="#4a5568", - bgcolor="#2d3748", - font=dict(family="Arial, sans-serif", size=13, color="#e2e8f0"), - ) - - fig.update_layout( - title="Model Metadata", - template="plotly_dark", - height=450, - margin=dict(l=40, r=40, t=60, b=40), - plot_bgcolor="#1a202c", - paper_bgcolor="#1a202c", - ) - - fig.update_xaxes(visible=False) - fig.update_yaxes(visible=False) - - return fig - - -def code_structure_plot(file_path: Path) -> Figure: - """Visualize the code structure and method definitions in the model.""" - tree, error = _parse_model_file(file_path) - - if error: - return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code") - - model_class = _find_model_class(tree) - if not model_class: - return create_styled_error_figure("No Model Found", "Could not find a PyTorch nn.Module class in the file") - - # Extract methods - methods = [] - for node in model_class.body: - if isinstance(node, ast.FunctionDef): - # Count lines in method - if hasattr(node, "end_lineno") and hasattr(node, "lineno"): - lines = node.end_lineno - node.lineno + 1 - else: - lines = 1 - - methods.append({"name": node.name, "lines": lines, "args": len(node.args.args) - 1}) # Exclude self - - if not methods: - return create_styled_error_figure( - "No Methods Found", "Could not extract method information from the model class" - ) - - # Create visualization of methods - method_names = [m["name"] for m in methods] - method_lines = [m["lines"] for m in methods] - method_args = [m["args"] for m in methods] - - fig = go.Figure() - - # Bar chart for method complexity (lines of code) - fig.add_trace( - go.Bar( - x=method_names, - y=method_lines, - name="Lines of Code", - marker=dict(color="rgba(99, 179, 237, 0.8)"), - hovertext=[ - f"{name}
Lines: {lines}
Arguments: {args}" - for name, lines, args in zip(method_names, method_lines, method_args) - ], - hoverinfo="text", - ) - ) - - fig.update_layout( - title=f"Method Complexity - {model_class.name}", - xaxis_title="Methods", - yaxis_title="Lines of Code", - template="plotly_dark", - height=400, - showlegend=False, - ) - - return fig