"""Plotly visualizations for the contents of a PyTorch ``state_dict``."""
from typing import Optional

import numpy as np
import plotly.graph_objects as go
import torch
from plotly.graph_objects import Figure

_WEIGHT_SUFFIX = ".weight"


def _is_weight_key(key: str) -> bool:
    """Return True if *key* names a tensor these plots treat as a weight."""
    return "weight" in key


def _layer_label(key: str) -> str:
    """Strip a trailing ``.weight`` from *key* to get a display label.

    Only the suffix is removed, so names such as ``enc.weight_proj.weight``
    become ``enc.weight_proj`` instead of being mangled to ``enc_proj``.
    """
    if key.endswith(_WEIGHT_SUFFIX):
        return key[: -len(_WEIGHT_SUFFIX)]
    return key


def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert *tensor* to a NumPy array that plotly can consume.

    ``Tensor.numpy()`` raises for tensors on a non-CPU device, tensors that
    require grad, and ``bfloat16`` tensors (NumPy has no such dtype), so the
    tensor is detached, moved to CPU and, for bfloat16, upcast to float32.
    """
    tensor = tensor.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
    return tensor.numpy()


def model_summary_plot(state_dict: dict) -> Figure:
    """Generate a bar chart of parameter counts for each weight tensor.

    Only entries whose key contains ``"weight"`` are included; other entries
    (e.g. biases) are skipped. Each bar is annotated with the tensor shape.

    Args:
        state_dict: Mapping of parameter name to tensor.

    Returns:
        A plotly ``Figure`` with one bar per weight tensor.
    """
    layer_info = [
        {
            "layer": _layer_label(key),
            "parameters": tensor.numel(),
            "shape": list(tensor.shape),
        }
        for key, tensor in state_dict.items()
        if _is_weight_key(key)
    ]

    fig = go.Figure(data=[
        go.Bar(
            x=[info["layer"] for info in layer_info],
            y=[info["parameters"] for info in layer_info],
            text=[f"Shape: {info['shape']}" for info in layer_info],
            textposition="auto",
        )
    ])
    fig.update_layout(
        title="Model Layer Parameter Counts",
        xaxis_title="Layer",
        yaxis_title="Number of Parameters",
    )
    return fig


def layer_weights_plot(state_dict: dict, layer_name: Optional[str] = None) -> Figure:
    """Visualize the weights of a single state-dict entry.

    2-D tensors are drawn as a heatmap centred at zero; tensors of any other
    rank are flattened and drawn as a 50-bin histogram.

    Args:
        state_dict: Mapping of parameter name to tensor.
        layer_name: Key of the tensor to plot. If omitted, the first key
            containing ``"weight"`` is used. A label as shown by
            ``model_summary_plot`` (without the ``.weight`` suffix) is also
            accepted when the exact key is absent.

    Returns:
        A plotly ``Figure`` with a heatmap or histogram.

    Raises:
        ValueError: If ``layer_name`` is None and no weight key exists.
        KeyError: If ``layer_name`` (or ``layer_name + ".weight"``) is not
            in ``state_dict``.
    """
    if layer_name is None:
        layer_name = next((k for k in state_dict if _is_weight_key(k)), None)
        if layer_name is None:
            raise ValueError("No weight tensors found in state dict")
    elif layer_name not in state_dict and layer_name + _WEIGHT_SUFFIX in state_dict:
        # Accept the suffix-less labels that model_summary_plot displays.
        layer_name = layer_name + _WEIGHT_SUFFIX

    weights = state_dict[layer_name]

    if weights.dim() == 2:
        fig = go.Figure(data=go.Heatmap(
            z=_to_numpy(weights),
            colorscale="RdBu",
            zmid=0,
        ))
        fig.update_layout(title=f"Weights Heatmap: {layer_name}")
    else:
        flat_weights = _to_numpy(weights).ravel()
        fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
        fig.update_layout(title=f"Weight Distribution: {layer_name}")

    return fig


def weight_distribution_plot(state_dict: dict) -> Figure:
    """Show a 100-bin histogram of all weight values across all layers.

    Values from every entry whose key contains ``"weight"`` are pooled.

    Args:
        state_dict: Mapping of parameter name to tensor.

    Returns:
        A plotly ``Figure`` containing a single histogram trace.
    """
    # Concatenate whole arrays instead of extending a Python list one scalar
    # at a time, which is very slow and memory-hungry for real models.
    arrays = [
        _to_numpy(tensor).ravel()
        for key, tensor in state_dict.items()
        if _is_weight_key(key)
    ]
    all_weights = np.concatenate(arrays) if arrays else np.empty(0)

    fig = go.Figure(data=[
        go.Histogram(
            x=all_weights,
            nbinsx=100,
            name="All Weights",
        )
    ])
    fig.update_layout(
        title="Overall Weight Distribution",
        xaxis_title="Weight Value",
        yaxis_title="Frequency",
    )
    return fig