89 lines
2.7 KiB
Python
89 lines
2.7 KiB
Python
|
|
import torch
|
||
|
|
import plotly.graph_objects as go
|
||
|
|
from plotly.graph_objects import Figure
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
def model_summary_plot(state_dict: dict) -> Figure:
    """Generate a bar chart of parameter counts for each weight tensor.

    Every entry whose key contains ``'weight'`` contributes one bar. The bar
    label is the key with a trailing ``'.weight'`` removed (``'fc.weight'``
    is shown as ``'fc'``); keys that do not end in ``'.weight'`` (e.g.
    ``'rnn.weight_ih_l0'``) are shown unchanged. Entries whose key does not
    contain ``'weight'`` (such as biases) are not counted.

    Args:
        state_dict: Mapping of parameter names to tensors, e.g. the result of
            ``model.state_dict()``.

    Returns:
        A plotly ``Figure`` with one bar per weight tensor; each bar is
        annotated with the tensor's shape.
    """
    suffix = '.weight'
    layer_info = []
    for key, tensor in state_dict.items():
        if 'weight' not in key:
            continue
        # Strip only the trailing suffix: str.replace would also delete
        # '.weight' from the middle of keys like 'rnn.weight_ih_l0',
        # producing a name that matches no real module.
        layer_name = key[:-len(suffix)] if key.endswith(suffix) else key
        layer_info.append({
            'layer': layer_name,
            'parameters': tensor.numel(),
            'shape': list(tensor.shape),
        })

    # Create bar chart of parameter counts
    fig = go.Figure(data=[
        go.Bar(
            x=[info['layer'] for info in layer_info],
            y=[info['parameters'] for info in layer_info],
            text=[f"Shape: {info['shape']}" for info in layer_info],
            textposition='auto',
        )
    ])

    fig.update_layout(
        title="Model Layer Parameter Counts",
        xaxis_title="Layer",
        yaxis_title="Number of Parameters"
    )

    return fig
|
||
|
|
|
||
|
|
def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
    """Visualize the values of a single weight tensor.

    2-D tensors are drawn as a zero-centered ``RdBu`` heatmap; tensors of any
    other rank are flattened and drawn as a 50-bin histogram.

    Args:
        state_dict: Mapping of parameter names to tensors.
        layer_name: Key of the tensor to plot. A bare layer name such as
            ``'fc'`` is also accepted when ``'fc.weight'`` is present in
            ``state_dict``. If ``None`` (the default), the first key that
            contains ``'weight'`` is used.

    Returns:
        A plotly ``Figure`` whose title includes the resolved key.

    Raises:
        ValueError: If ``layer_name`` is ``None`` and no key contains
            ``'weight'``.
        KeyError: If neither ``layer_name`` nor ``layer_name + '.weight'`` is
            a key of ``state_dict``.
    """
    if layer_name is None:
        # Get first weight tensor
        weight_keys = [k for k in state_dict.keys() if 'weight' in k]
        if not weight_keys:
            raise ValueError("No weight tensors found in state dict")
        layer_name = weight_keys[0]
    elif layer_name not in state_dict:
        # Allow bare layer names (as labelled on the parameter-count chart),
        # which omit the '.weight' suffix.
        candidate = f"{layer_name}.weight"
        if candidate not in state_dict:
            raise KeyError(f"{layer_name!r} not found in state dict")
        layer_name = candidate

    # Tensor.numpy() only works for CPU tensors that do not require grad and
    # whose dtype NumPy supports (bfloat16 is not), so normalize first.
    values = state_dict[layer_name].detach().cpu().float().numpy()

    if values.ndim == 2:
        # For 2D weights, create heatmap
        fig = go.Figure(data=go.Heatmap(
            z=values,
            colorscale='RdBu',
            zmid=0
        ))
        fig.update_layout(title=f"Weights Heatmap: {layer_name}")
    else:
        # For other shapes, flatten and show histogram
        fig = go.Figure(data=[go.Histogram(x=values.ravel(), nbinsx=50)])
        fig.update_layout(title=f"Weight Distribution: {layer_name}")

    return fig
|
||
|
|
|
||
|
|
def weight_distribution_plot(state_dict: dict) -> Figure:
    """Show a single histogram of all weight values across the state dict.

    Every tensor whose key contains ``'weight'`` is flattened and its values
    are pooled into one 100-bin histogram.

    Args:
        state_dict: Mapping of parameter names to tensors.

    Returns:
        A plotly ``Figure`` containing one histogram trace named
        ``"All Weights"``. If no key contains ``'weight'``, the histogram is
        empty.
    """
    # Keep values in NumPy arrays and concatenate once instead of extending a
    # Python list element by element, which boxes every value individually.
    # detach/cpu/float make .numpy() work for grad-tracking, GPU and bfloat16
    # tensors.
    chunks = [
        tensor.detach().cpu().float().reshape(-1).numpy()
        for key, tensor in state_dict.items()
        if 'weight' in key
    ]
    all_weights = (
        np.concatenate(chunks) if chunks else np.empty(0, dtype=np.float32)
    )

    fig = go.Figure(data=[
        go.Histogram(
            x=all_weights,
            nbinsx=100,
            name="All Weights"
        )
    ])

    fig.update_layout(
        title="Overall Weight Distribution",
        xaxis_title="Weight Value",
        yaxis_title="Frequency"
    )

    return fig
|