PyTorch state dict widget
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 17s
Test with tox / Test with tox (3.11) (pull_request) Successful in 34s
Test with tox / Test with tox (3.12) (pull_request) Successful in 32s
Test with tox / Test with tox (3.10) (pull_request) Failing after 42s
Build Project / Build Project (3.10) (pull_request) Successful in 52s
Build Project / Build Project (3.11) (pull_request) Successful in 51s
Build Project / Build Project (3.12) (pull_request) Successful in 51s
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 17s
Test with tox / Test with tox (3.11) (pull_request) Successful in 34s
Test with tox / Test with tox (3.12) (pull_request) Successful in 32s
Test with tox / Test with tox (3.10) (pull_request) Failing after 42s
Build Project / Build Project (3.10) (pull_request) Successful in 52s
Build Project / Build Project (3.11) (pull_request) Successful in 51s
Build Project / Build Project (3.12) (pull_request) Successful in 51s
This commit is contained in:
parent
1fb55607a2
commit
f430e626a6
|
|
@ -5,17 +5,56 @@ import numpy as np
|
|||
|
||||
def model_summary_plot(state_dict: dict) -> Figure:
|
||||
"""Generate a summary plot of the PyTorch model state dict."""
|
||||
if not state_dict:
|
||||
# Handle empty state dict
|
||||
fig = go.Figure()
|
||||
fig.add_annotation(
|
||||
text="No parameters found in state dict",
|
||||
xref="paper", yref="paper",
|
||||
x=0.5, y=0.5, showarrow=False,
|
||||
font=dict(size=16)
|
||||
)
|
||||
fig.update_layout(
|
||||
title="Model Layer Parameter Counts",
|
||||
xaxis_title="Layer",
|
||||
yaxis_title="Number of Parameters",
|
||||
template="plotly_dark"
|
||||
)
|
||||
return fig
|
||||
|
||||
# Count parameters by layer type
|
||||
layer_info = []
|
||||
for key, tensor in state_dict.items():
|
||||
if 'weight' in key:
|
||||
layer_name = key.replace('.weight', '')
|
||||
param_count = tensor.numel()
|
||||
layer_info.append({
|
||||
'layer': layer_name,
|
||||
'parameters': param_count,
|
||||
'shape': list(tensor.shape)
|
||||
})
|
||||
try:
|
||||
layer_name = key.replace('.weight', '')
|
||||
param_count = tensor.numel() if hasattr(tensor, 'numel') else len(tensor.flatten()) if hasattr(tensor, 'flatten') else 0
|
||||
shape = list(tensor.shape) if hasattr(tensor, 'shape') else [len(tensor)] if hasattr(tensor, '__len__') else []
|
||||
layer_info.append({
|
||||
'layer': layer_name,
|
||||
'parameters': param_count,
|
||||
'shape': shape
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not process layer {key}: {e}")
|
||||
continue
|
||||
|
||||
if not layer_info:
|
||||
# Handle case where no weight layers found
|
||||
fig = go.Figure()
|
||||
fig.add_annotation(
|
||||
text="No weight layers found in state dict",
|
||||
xref="paper", yref="paper",
|
||||
x=0.5, y=0.5, showarrow=False,
|
||||
font=dict(size=16)
|
||||
)
|
||||
fig.update_layout(
|
||||
title="Model Layer Parameter Counts",
|
||||
xaxis_title="Layer",
|
||||
yaxis_title="Number of Parameters",
|
||||
template="plotly_dark"
|
||||
)
|
||||
return fig
|
||||
|
||||
# Create bar chart of parameter counts
|
||||
fig = go.Figure(data=[
|
||||
|
|
@ -30,47 +69,147 @@ def model_summary_plot(state_dict: dict) -> Figure:
|
|||
fig.update_layout(
|
||||
title="Model Layer Parameter Counts",
|
||||
xaxis_title="Layer",
|
||||
yaxis_title="Number of Parameters"
|
||||
yaxis_title="Number of Parameters",
|
||||
template="plotly_dark"
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def _weights_message_figure(text: str, title: str, font_size: int = 16) -> Figure:
    """Return an empty dark-themed figure with *text* centred on the canvas."""
    fig = go.Figure()
    fig.add_annotation(
        text=text,
        xref="paper", yref="paper",
        x=0.5, y=0.5, showarrow=False,
        font=dict(size=font_size),
    )
    fig.update_layout(title=title, template="plotly_dark")
    return fig


def _weights_to_numpy(values):
    """Convert a torch tensor (on any device) or an array-like to a NumPy array."""
    # Order matters: drop autograd tracking first, then move to host memory,
    # and only then call ``numpy()``. Calling ``numpy()`` directly fails for
    # tensors that require grad or that live on a non-CPU device.
    if hasattr(values, 'detach'):
        values = values.detach()
    if hasattr(values, 'cpu'):
        values = values.cpu()
    if hasattr(values, 'numpy'):
        return values.numpy()
    return np.asarray(values)


def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
    """Visualize the weights stored under one key of a state dict.

    Args:
        state_dict: Mapping of parameter names to tensors or array-likes.
        layer_name: Key to plot. When ``None``, the first key containing
            ``'weight'`` is used.

    Returns:
        A heatmap for 2-D weights, or a 50-bin histogram of the flattened
        values for any other shape. If the state dict is empty, no weight key
        exists, or the selected entry cannot be processed (including an
        unknown ``layer_name``), a placeholder figure carrying an explanatory
        message is returned instead of raising.
    """
    if not state_dict:
        return _weights_message_figure("No data in state dict", "Layer Weights")

    if layer_name is None:
        # Default to the first weight tensor in insertion order.
        weight_keys = [k for k in state_dict if 'weight' in k]
        if not weight_keys:
            return _weights_message_figure(
                "No weight tensors found in state dict", "Layer Weights"
            )
        layer_name = weight_keys[0]

    # Best-effort rendering: any failure is shown in the figure rather than
    # propagated to the caller.
    try:
        weights_np = _weights_to_numpy(state_dict[layer_name])

        if len(weights_np.shape) == 2:
            # 2-D weights: diverging colour scale centred on zero.
            fig = go.Figure(data=go.Heatmap(
                z=weights_np,
                colorscale='RdBu',
                zmid=0
            ))
            fig.update_layout(
                title=f"Weights Heatmap: {layer_name}",
                template="plotly_dark"
            )
        else:
            # Any other rank: flatten and show the value distribution.
            flat_weights = weights_np.flatten()
            fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
            fig.update_layout(
                title=f"Weight Distribution: {layer_name}",
                template="plotly_dark"
            )

        return fig

    except Exception as e:
        return _weights_message_figure(
            f"Error processing layer {layer_name}: {str(e)}",
            "Layer Weights - Error",
            font_size=14,
        )
|
||||
|
||||
def weight_distribution_plot(state_dict: dict) -> Figure:
|
||||
"""Show distribution of weights across all layers."""
|
||||
if not state_dict:
|
||||
fig = go.Figure()
|
||||
fig.add_annotation(
|
||||
text="No data in state dict",
|
||||
xref="paper", yref="paper",
|
||||
x=0.5, y=0.5, showarrow=False,
|
||||
font=dict(size=16)
|
||||
)
|
||||
fig.update_layout(
|
||||
title="Overall Weight Distribution",
|
||||
xaxis_title="Weight Value",
|
||||
yaxis_title="Frequency",
|
||||
template="plotly_dark"
|
||||
)
|
||||
return fig
|
||||
|
||||
all_weights = []
|
||||
layer_names = []
|
||||
|
||||
for key, tensor in state_dict.items():
|
||||
if 'weight' in key:
|
||||
all_weights.extend(tensor.flatten().numpy())
|
||||
layer_names.extend([key] * tensor.numel())
|
||||
try:
|
||||
# Convert to numpy if it's a torch tensor
|
||||
if hasattr(tensor, 'numpy'):
|
||||
weights_np = tensor.detach().numpy() if hasattr(tensor, 'detach') else tensor.numpy()
|
||||
elif hasattr(tensor, 'cpu'):
|
||||
weights_np = tensor.cpu().detach().numpy()
|
||||
else:
|
||||
weights_np = np.array(tensor)
|
||||
|
||||
flat_weights = weights_np.flatten()
|
||||
all_weights.extend(flat_weights)
|
||||
layer_names.extend([key] * len(flat_weights))
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not process weights for layer {key}: {e}")
|
||||
continue
|
||||
|
||||
if not all_weights:
|
||||
fig = go.Figure()
|
||||
fig.add_annotation(
|
||||
text="No weight data found in state dict",
|
||||
xref="paper", yref="paper",
|
||||
x=0.5, y=0.5, showarrow=False,
|
||||
font=dict(size=16)
|
||||
)
|
||||
fig.update_layout(
|
||||
title="Overall Weight Distribution",
|
||||
xaxis_title="Weight Value",
|
||||
yaxis_title="Frequency",
|
||||
template="plotly_dark"
|
||||
)
|
||||
return fig
|
||||
|
||||
fig = go.Figure(data=[
|
||||
go.Histogram(
|
||||
|
|
@ -83,7 +222,8 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
|
|||
fig.update_layout(
|
||||
title="Overall Weight Distribution",
|
||||
xaxis_title="Weight Value",
|
||||
yaxis_title="Frequency"
|
||||
yaxis_title="Frequency",
|
||||
template="plotly_dark"
|
||||
)
|
||||
|
||||
return fig
|
||||
Loading…
Reference in New Issue
Block a user