ria-toolkit-oss/src/ria_toolkit_oss/viz/pytorch_state_dict.py

222 lines
7.2 KiB
Python
Raw Normal View History

2025-10-20 14:44:51 -04:00
import numpy as np
2025-10-09 16:55:23 -04:00
import plotly.graph_objects as go
from plotly.graph_objects import Figure
2025-10-20 14:44:51 -04:00
2025-10-09 16:55:23 -04:00
def model_summary_plot(state_dict: dict) -> Figure:
    """Generate a bar chart of parameter counts for the weight entries of a state dict.

    Every entry whose key contains ``"weight"`` becomes one bar. The bar is
    labelled with the key minus a trailing ``".weight"`` suffix, its height is
    the element count of the entry, and its text shows the entry's shape.
    Entries that raise while being inspected are reported on stdout and skipped.

    Args:
        state_dict: Mapping from parameter names to tensor-like values.

    Returns:
        The bar chart, or an empty figure carrying an explanatory annotation
        when ``state_dict`` is empty or contains no usable weight entries.
    """
    suffix = ".weight"

    def _placeholder(message: str) -> Figure:
        # Empty chart with a centred message; keeps the same titles as the real chart.
        placeholder = go.Figure()
        placeholder.add_annotation(
            text=message,
            xref="paper",
            yref="paper",
            x=0.5,
            y=0.5,
            showarrow=False,
            font=dict(size=16),
        )
        placeholder.update_layout(
            title="Model Layer Parameter Counts",
            xaxis_title="Layer",
            yaxis_title="Number of Parameters",
            template="plotly_dark",
        )
        return placeholder

    if not state_dict:
        return _placeholder("No parameters found in state dict")

    # Collect one record per weight entry.
    layer_info = []
    for key, tensor in state_dict.items():
        if "weight" not in key:
            continue
        try:
            # Strip only a trailing ".weight". str.replace would also remove
            # ".weight" from the middle of a key (e.g. "enc.weight_norm.weight")
            # and produce a misleading label.
            layer_name = key[: -len(suffix)] if key.endswith(suffix) else key

            # Element count: prefer numel(), fall back to the flattened length,
            # and report 0 for values that offer neither.
            if hasattr(tensor, "numel"):
                param_count = tensor.numel()
            elif hasattr(tensor, "flatten"):
                param_count = len(tensor.flatten())
            else:
                param_count = 0

            # Shape: prefer .shape, fall back to a one-dimensional length.
            if hasattr(tensor, "shape"):
                shape = list(tensor.shape)
            elif hasattr(tensor, "__len__"):
                shape = [len(tensor)]
            else:
                shape = []

            layer_info.append({"layer": layer_name, "parameters": param_count, "shape": shape})
        except Exception as e:
            # Best effort: one malformed entry must not prevent the rest from plotting.
            print(f"Warning: Could not process layer {key}: {e}")
            continue

    if not layer_info:
        return _placeholder("No weight layers found in state dict")

    # Create bar chart of parameter counts
    fig = go.Figure(
        data=[
            go.Bar(
                x=[info["layer"] for info in layer_info],
                y=[info["parameters"] for info in layer_info],
                text=[f"Shape: {info['shape']}" for info in layer_info],
                textposition="auto",
            )
        ]
    )
    fig.update_layout(
        title="Model Layer Parameter Counts",
        xaxis_title="Layer",
        yaxis_title="Number of Parameters",
        template="plotly_dark",
    )
    return fig
2025-10-20 14:44:51 -04:00
2025-10-09 16:55:23 -04:00
def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
    """Visualize the weights stored under a single state-dict entry.

    Two-dimensional weights are drawn as a heatmap centred on zero; weights of
    any other rank are flattened and drawn as a 50-bin histogram.

    Args:
        state_dict: Mapping from parameter names to tensor-like values.
        layer_name: Key of the entry to plot. When ``None``, the first key
            containing ``"weight"`` is used. A name that is not itself a key
            is also tried with ``".weight"`` appended, so a bare module name
            such as ``"fc1"`` resolves to ``"fc1.weight"``.

    Returns:
        The heatmap or histogram. When there is nothing to plot, or the entry
        cannot be processed, an empty figure carrying an explanatory
        annotation is returned instead of raising.
    """

    def _message_figure(message: str, title: str = "Layer Weights", font_size: int = 16) -> Figure:
        # Empty chart with a centred message.
        placeholder = go.Figure()
        placeholder.add_annotation(
            text=message,
            xref="paper",
            yref="paper",
            x=0.5,
            y=0.5,
            showarrow=False,
            font=dict(size=font_size),
        )
        placeholder.update_layout(title=title, template="plotly_dark")
        return placeholder

    if not state_dict:
        return _message_figure("No data in state dict")

    if layer_name is None:
        # Default to the first weight tensor.
        weight_keys = [k for k in state_dict if "weight" in k]
        if not weight_keys:
            return _message_figure("No weight tensors found in state dict")
        layer_name = weight_keys[0]

    try:
        # Accept a bare module name by falling back to its ".weight" entry.
        if layer_name not in state_dict and f"{layer_name}.weight" in state_dict:
            layer_name = f"{layer_name}.weight"
        weights = state_dict[layer_name]

        # Convert to numpy. For torch tensors the order matters: detach from
        # the autograd graph, then move to host memory, because Tensor.numpy()
        # only works on CPU tensors.
        if hasattr(weights, "detach"):
            weights = weights.detach()
        if hasattr(weights, "cpu"):
            weights = weights.cpu()
        weights_np = weights.numpy() if hasattr(weights, "numpy") else np.array(weights)

        if weights_np.ndim == 2:
            # For 2D weights, create heatmap
            fig = go.Figure(data=go.Heatmap(z=weights_np, colorscale="RdBu", zmid=0))
            fig.update_layout(title=f"Weights Heatmap: {layer_name}", template="plotly_dark")
        else:
            # For other shapes, flatten and show histogram
            flat_weights = weights_np.flatten()
            fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
            fig.update_layout(title=f"Weight Distribution: {layer_name}", template="plotly_dark")
        return fig

    except Exception as e:
        # Report the failure inside the figure so callers always get something to render.
        return _message_figure(
            f"Error processing layer {layer_name}: {str(e)}",
            title="Layer Weights - Error",
            font_size=14,
        )
2025-10-09 16:55:23 -04:00
2025-10-20 14:44:51 -04:00
2025-10-09 16:55:23 -04:00
def weight_distribution_plot(state_dict: dict) -> Figure:
    """Show the distribution of weight values across all layers.

    Every entry whose key contains ``"weight"`` is converted to a NumPy array
    and flattened; the values of all such entries are pooled into a single
    100-bin histogram. Entries that raise during conversion are reported on
    stdout and skipped.

    Args:
        state_dict: Mapping from parameter names to tensor-like values.

    Returns:
        The histogram, or an empty figure carrying an explanatory annotation
        when ``state_dict`` is empty or yields no weight values.
    """

    def _placeholder(message: str) -> Figure:
        # Empty chart with a centred message; keeps the same titles as the real chart.
        placeholder = go.Figure()
        placeholder.add_annotation(
            text=message,
            xref="paper",
            yref="paper",
            x=0.5,
            y=0.5,
            showarrow=False,
            font=dict(size=16),
        )
        placeholder.update_layout(
            title="Overall Weight Distribution",
            xaxis_title="Weight Value",
            yaxis_title="Frequency",
            template="plotly_dark",
        )
        return placeholder

    if not state_dict:
        return _placeholder("No data in state dict")

    # One flattened array per layer. They are joined once at the end rather
    # than extending a Python list element by element, which would box every
    # value into its own Python object.
    chunks = []
    for key, tensor in state_dict.items():
        if "weight" not in key:
            continue
        try:
            # Convert to numpy. For torch tensors the order matters: detach
            # from the autograd graph, then move to host memory, because
            # Tensor.numpy() only works on CPU tensors.
            if hasattr(tensor, "detach"):
                tensor = tensor.detach()
            if hasattr(tensor, "cpu"):
                tensor = tensor.cpu()
            weights_np = tensor.numpy() if hasattr(tensor, "numpy") else np.array(tensor)
            chunks.append(weights_np.flatten())
        except Exception as e:
            # Best effort: one malformed entry must not prevent the rest from plotting.
            print(f"Warning: Could not process weights for layer {key}: {e}")
            continue

    # Zero-size arrays contribute nothing, so treat "only empty layers" the
    # same as "no layers".
    if sum(chunk.size for chunk in chunks) == 0:
        return _placeholder("No weight data found in state dict")

    all_weights = np.concatenate(chunks)

    fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=100, name="All Weights")])
    fig.update_layout(
        title="Overall Weight Distribution",
        xaxis_title="Weight Value",
        yaxis_title="Frequency",
        template="plotly_dark",
    )
    return fig