"""Plotly visualizations for the tensors stored in a PyTorch state dict."""

from typing import Optional

import numpy as np
import plotly.graph_objects as go
import torch
from plotly.graph_objects import Figure

# Suffix PyTorch appends to a module's path for its main weight tensor.
_WEIGHT_SUFFIX = ".weight"


def _is_weight_key(key: str) -> bool:
    """Return True when *key* names an entry treated as a weight tensor."""
    return "weight" in key


def _display_name(key: str) -> str:
    """Return *key* without a trailing ``.weight`` suffix.

    Only a suffix is removed, so keys such as ``lstm.weight_ih_l0`` keep
    their full name instead of having the substring cut out of the middle.
    """
    if key.endswith(_WEIGHT_SUFFIX):
        return key[: -len(_WEIGHT_SUFFIX)]
    return key


def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert *tensor* to a NumPy array suitable for plotting.

    ``Tensor.numpy()`` refuses tensors that require grad or live on a
    non-CPU device, and NumPy has no bfloat16 dtype, so the tensor is
    detached, moved to the CPU and (for bfloat16) upcast first.
    """
    tensor = tensor.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.to(torch.float32)
    return tensor.numpy()


def model_summary_plot(state_dict: dict) -> Figure:
    """Generate a bar chart of parameter counts per weight tensor.

    Args:
        state_dict: Mapping of parameter names to tensors. Every entry whose
            key contains ``"weight"`` contributes one bar.

    Returns:
        A figure with one bar per weight tensor, labelled with its shape.
    """
    layer_info = [
        {
            "layer": _display_name(key),
            "parameters": tensor.numel(),
            "shape": list(tensor.shape),
        }
        for key, tensor in state_dict.items()
        if _is_weight_key(key)
    ]

    fig = go.Figure(
        data=[
            go.Bar(
                x=[info["layer"] for info in layer_info],
                y=[info["parameters"] for info in layer_info],
                text=[f"Shape: {info['shape']}" for info in layer_info],
                textposition="auto",
            )
        ]
    )

    fig.update_layout(
        title="Model Layer Parameter Counts",
        xaxis_title="Layer",
        yaxis_title="Number of Parameters",
    )

    return fig


def layer_weights_plot(state_dict: dict, layer_name: Optional[str] = None) -> Figure:
    """Visualize the weights stored under a single state-dict key.

    Two-dimensional tensors are drawn as a heatmap centred on zero; tensors
    of any other rank are flattened and drawn as a 50-bin histogram.

    Args:
        state_dict: Mapping of parameter names to tensors.
        layer_name: Key to plot. The name with the ``.weight`` suffix
            omitted (as shown on the axis of ``model_summary_plot``) is
            accepted too. When ``None``, the first key containing
            ``"weight"`` is used.

    Returns:
        A heatmap or histogram figure for the selected tensor.

    Raises:
        ValueError: If ``layer_name`` is ``None`` and no key contains
            ``"weight"``.
        KeyError: If ``layer_name`` is given but neither it nor
            ``layer_name + ".weight"`` is a key of ``state_dict``.
    """
    if layer_name is None:
        layer_name = next((key for key in state_dict if _is_weight_key(key)), None)
        if layer_name is None:
            raise ValueError("No weight tensors found in state dict")
    elif layer_name not in state_dict:
        # Fall back to the full key so suffix-stripped layer names work.
        suffixed = layer_name + _WEIGHT_SUFFIX
        if suffixed not in state_dict:
            raise KeyError(layer_name)
        layer_name = suffixed

    weights = _to_numpy(state_dict[layer_name])

    if weights.ndim == 2:
        fig = go.Figure(data=go.Heatmap(z=weights, colorscale="RdBu", zmid=0))
        fig.update_layout(title=f"Weights Heatmap: {layer_name}")
    else:
        fig = go.Figure(data=[go.Histogram(x=weights.ravel(), nbinsx=50)])
        fig.update_layout(title=f"Weight Distribution: {layer_name}")

    return fig


def weight_distribution_plot(state_dict: dict) -> Figure:
    """Show the distribution of weight values pooled across all layers.

    Args:
        state_dict: Mapping of parameter names to tensors. Every entry whose
            key contains ``"weight"`` is included.

    Returns:
        A 100-bin histogram figure over all pooled weight values; the
        histogram is empty when no key contains ``"weight"``.
    """
    chunks = [
        _to_numpy(tensor).ravel()
        for key, tensor in state_dict.items()
        if _is_weight_key(key)
    ]
    # np.concatenate rejects an empty sequence, so handle "no weights" here.
    all_weights = np.concatenate(chunks) if chunks else np.empty(0)

    fig = go.Figure(
        data=[
            go.Histogram(
                x=all_weights,
                nbinsx=100,
                name="All Weights",
            )
        ]
    )

    fig.update_layout(
        title="Overall Weight Distribution",
        xaxis_title="Weight Value",
        yaxis_title="Frequency",
    )

    return fig