diff --git a/src/ria_toolkit_oss/viz/onnx.py b/src/ria_toolkit_oss/viz/onnx.py
index b92c3e4..e260eeb 100644
--- a/src/ria_toolkit_oss/viz/onnx.py
+++ b/src/ria_toolkit_oss/viz/onnx.py
@@ -75,9 +75,7 @@ def graph_structure(file_path: Path) -> go.Figure:
"""
if not ONNX_AVAILABLE:
return create_styled_error_figure(
- "ONNX Not Available",
- "ONNX library is required for model analysis.",
- "Install with: pip install onnx"
+ "ONNX Not Available", "ONNX library is required for model analysis.", "Install with: pip install onnx"
)
try:
@@ -88,9 +86,7 @@ def graph_structure(file_path: Path) -> go.Figure:
if len(nodes) == 0:
return create_styled_error_figure(
- "Empty Model",
- "This ONNX model contains no operators.",
- "Please check if the model file is valid."
+ "Empty Model", "This ONNX model contains no operators.", "Please check if the model file is valid."
)
# Create network diagram data
@@ -153,8 +149,10 @@ def graph_structure(file_path: Path) -> go.Figure:
fig.update_layout(
title={
-            "text": ("ONNX Graph Structure<br>"
-                     f"{len(nodes)} Operators"),
+            "text": (
+                "ONNX Graph Structure<br>"
+                f"{len(nodes)} Operators"
+            ),
"x": 0.5,
"xanchor": "center",
"font": {"size": 22},
@@ -173,9 +171,7 @@ def graph_structure(file_path: Path) -> go.Figure:
except Exception as e:
return create_styled_error_figure(
- "Graph Analysis Error",
- "Could not analyze ONNX model structure.",
- f"Error: {str(e)}"
+ "Graph Analysis Error", "Could not analyze ONNX model structure.", f"Error: {str(e)}"
)
@@ -186,9 +182,7 @@ def operator_analysis(file_path: Path) -> go.Figure:
"""
if not ONNX_AVAILABLE:
return create_styled_error_figure(
- "ONNX Not Available",
- "ONNX library is required for operator analysis.",
- "Install with: pip install onnx"
+ "ONNX Not Available", "ONNX library is required for operator analysis.", "Install with: pip install onnx"
)
try:
@@ -248,8 +242,10 @@ def operator_analysis(file_path: Path) -> go.Figure:
fig.update_layout(
title={
-            "text": ("ONNX Operator Analysis<br>"
-                     f"{len(op_counts)} Unique Types"),
+            "text": (
+                "ONNX Operator Analysis<br>"
+                f"{len(op_counts)} Unique Types"
+            ),
"x": 0.5,
"xanchor": "center",
"font": {"size": 22},
@@ -262,9 +258,7 @@ def operator_analysis(file_path: Path) -> go.Figure:
except Exception as e:
return create_styled_error_figure(
- "Operator Analysis Error",
- "Could not analyze ONNX operators.",
- f"Error: {str(e)}"
+ "Operator Analysis Error", "Could not analyze ONNX operators.", f"Error: {str(e)}"
)
@@ -275,9 +269,7 @@ def model_metadata(file_path: Path) -> go.Figure:
"""
if not ONNX_AVAILABLE:
return create_styled_error_figure(
- "ONNX Not Available",
- "ONNX library is required for metadata analysis.",
- "Install with: pip install onnx"
+ "ONNX Not Available", "ONNX library is required for metadata analysis.", "Install with: pip install onnx"
)
try:
@@ -336,12 +328,7 @@ def model_metadata(file_path: Path) -> go.Figure:
arch_values = [total_nodes, total_inputs, total_outputs, total_initializers]
fig.add_trace(
- go.Bar(
- x=arch_data,
- y=arch_values,
- marker_color=["blue", "green", "orange", "red"],
- showlegend=False
- ),
+ go.Bar(x=arch_data, y=arch_values, marker_color=["blue", "green", "orange", "red"], showlegend=False),
row=1,
col=2,
)
@@ -402,16 +389,8 @@ def model_metadata(file_path: Path) -> go.Figure:
if io_data:
fig.add_trace(
go.Table(
- header=dict(
- values=["Type", "Name", "Shape", "Data Type"],
- fill_color="lightblue",
- align="left"
- ),
- cells=dict(
- values=list(zip(*io_data)),
- fill_color="white",
- align="left"
- ),
+ header=dict(values=["Type", "Name", "Shape", "Data Type"], fill_color="lightblue", align="left"),
+ cells=dict(values=list(zip(*io_data)), fill_color="white", align="left"),
),
row=2,
col=1,
@@ -432,8 +411,10 @@ def model_metadata(file_path: Path) -> go.Figure:
fig.update_layout(
title={
-            "text": ("ONNX Model Metadata<br>"
-                     f"{total_params/1e6:.2f}M Parameters"),
+            "text": (
+                "ONNX Model Metadata<br>"
+                f"{total_params/1e6:.2f}M Parameters"
+            ),
"x": 0.5,
"xanchor": "center",
"font": {"size": 22},
@@ -447,9 +428,7 @@ def model_metadata(file_path: Path) -> go.Figure:
except Exception as e:
return create_styled_error_figure(
- "Metadata Analysis Error",
- "Could not extract ONNX model metadata.",
- f"Error: {str(e)}"
+ "Metadata Analysis Error", "Could not extract ONNX model metadata.", f"Error: {str(e)}"
)
@@ -508,10 +487,7 @@ def performance_metrics(file_path: Path) -> go.Figure:
fig.add_trace(
go.Bar(
- x=efficiency_metrics,
- y=efficiency_values,
- marker_color=["blue", "green", "orange"],
- showlegend=False
+ x=efficiency_metrics, y=efficiency_values, marker_color=["blue", "green", "orange"], showlegend=False
),
row=1,
col=1,
@@ -522,12 +498,7 @@ def performance_metrics(file_path: Path) -> go.Figure:
memory_values = [param_memory_mb, param_memory_mb * 2] # Rough estimate
fig.add_trace(
- go.Bar(
- x=memory_types,
- y=memory_values,
- marker_color=["purple", "red"],
- showlegend=False
- ),
+ go.Bar(x=memory_types, y=memory_values, marker_color=["purple", "red"], showlegend=False),
row=1,
col=2,
)
@@ -569,9 +540,11 @@ def performance_metrics(file_path: Path) -> go.Figure:
fig.update_layout(
title={
-            "text": ("ONNX Performance Metrics<br>"
-                     f""
-                     f"Complexity Score: {complexity_score:.0f}/100"),
+            "text": (
+                "ONNX Performance Metrics<br>"
+                f""
+                f"Complexity Score: {complexity_score:.0f}/100"
+            ),
"x": 0.5,
"xanchor": "center",
"font": {"size": 22},
@@ -585,7 +558,5 @@ def performance_metrics(file_path: Path) -> go.Figure:
except Exception as e:
return create_styled_error_figure(
- "Performance Analysis Error",
- "Could not analyze ONNX model performance.",
- f"Error: {str(e)}"
+ "Performance Analysis Error", "Could not analyze ONNX model performance.", f"Error: {str(e)}"
)
diff --git a/src/ria_toolkit_oss/viz/pytorch_state_dict.py b/src/ria_toolkit_oss/viz/pytorch_state_dict.py
index 578ebd0..6c625bc 100644
--- a/src/ria_toolkit_oss/viz/pytorch_state_dict.py
+++ b/src/ria_toolkit_oss/viz/pytorch_state_dict.py
@@ -56,29 +56,25 @@ def model_summary_plot(state_dict: dict) -> Figure:
return create_styled_error_figure(
"Empty State Dict",
"No parameters found in state dict",
- "Ensure the model state dictionary contains weight parameters"
+ "Ensure the model state dictionary contains weight parameters",
)
# Count parameters by layer type
layer_info = []
for key, tensor in state_dict.items():
- if 'weight' in key:
+ if "weight" in key:
try:
- layer_name = key.replace('.weight', '')
+ layer_name = key.replace(".weight", "")
param_count = (
- tensor.numel() if hasattr(tensor, 'numel')
- else len(tensor.flatten()) if hasattr(tensor, 'flatten')
- else 0
+ tensor.numel()
+ if hasattr(tensor, "numel")
+ else len(tensor.flatten()) if hasattr(tensor, "flatten") else 0
)
shape = (
- list(tensor.shape) if hasattr(tensor, 'shape')
- else [len(tensor)] if hasattr(tensor, '__len__')
- else []
+ list(tensor.shape)
+ if hasattr(tensor, "shape")
+ else [len(tensor)] if hasattr(tensor, "__len__") else []
)
- layer_info.append({
- 'layer': layer_name,
- 'parameters': param_count,
- 'shape': shape
- })
+ layer_info.append({"layer": layer_name, "parameters": param_count, "shape": shape})
except Exception as e:
print(f"Warning: Could not process layer {key}: {e}")
continue
@@ -86,22 +82,24 @@ def model_summary_plot(state_dict: dict) -> Figure:
return create_styled_error_figure(
"No Weight Layers Found",
"No weight layers found in state dict",
- "Ensure the state dictionary contains layers with '.weight' parameters"
+ "Ensure the state dictionary contains layers with '.weight' parameters",
)
# Create bar chart of parameter counts
- fig = go.Figure(data=[
- go.Bar(
- x=[info['layer'] for info in layer_info],
- y=[info['parameters'] for info in layer_info],
- text=[f"Shape: {info['shape']}" for info in layer_info],
- textposition='auto',
- )
- ])
+ fig = go.Figure(
+ data=[
+ go.Bar(
+ x=[info["layer"] for info in layer_info],
+ y=[info["parameters"] for info in layer_info],
+ text=[f"Shape: {info['shape']}" for info in layer_info],
+ textposition="auto",
+ )
+ ]
+ )
fig.update_layout(
title="Model Layer Parameter Counts",
xaxis_title="Layer",
yaxis_title="Number of Parameters",
- template="plotly_dark"
+ template="plotly_dark",
)
return fig
@@ -110,48 +108,36 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
"""Visualize weights for a specific layer."""
if not state_dict:
return create_styled_error_figure(
- "Empty State Dict",
- "No data in state dict",
- "Ensure the model state dictionary contains data"
+ "Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data"
)
if layer_name is None:
# Get first weight tensor
- weight_keys = [k for k in state_dict.keys() if 'weight' in k]
+ weight_keys = [k for k in state_dict.keys() if "weight" in k]
if not weight_keys:
return create_styled_error_figure(
"No Weight Tensors Found",
"No weight tensors found in state dict",
- "Ensure the state dictionary contains layers with '.weight' parameters"
+ "Ensure the state dictionary contains layers with '.weight' parameters",
)
layer_name = weight_keys[0]
try:
weights = state_dict[layer_name]
# Convert to numpy if it's a torch tensor
- if hasattr(weights, 'numpy'):
- weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy()
- elif hasattr(weights, 'cpu'):
+ if hasattr(weights, "numpy"):
+ weights_np = weights.detach().numpy() if hasattr(weights, "detach") else weights.numpy()
+ elif hasattr(weights, "cpu"):
weights_np = weights.cpu().detach().numpy()
else:
weights_np = np.array(weights)
# For 2D weights, create heatmap
if len(weights_np.shape) == 2:
- fig = go.Figure(data=go.Heatmap(
- z=weights_np,
- colorscale='RdBu',
- zmid=0
- ))
- fig.update_layout(
- title=f"Weights Heatmap: {layer_name}",
- template="plotly_dark"
- )
+ fig = go.Figure(data=go.Heatmap(z=weights_np, colorscale="RdBu", zmid=0))
+ fig.update_layout(title=f"Weights Heatmap: {layer_name}", template="plotly_dark")
else:
# For other shapes, flatten and show histogram
flat_weights = weights_np.flatten()
fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
- fig.update_layout(
- title=f"Weight Distribution: {layer_name}",
- template="plotly_dark"
- )
+ fig.update_layout(title=f"Weight Distribution: {layer_name}", template="plotly_dark")
return fig
@@ -159,7 +145,7 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
return create_styled_error_figure(
"Layer Processing Error",
f"Error processing layer {layer_name}: {str(e)}",
- "Check that the layer name exists and contains valid tensor data"
+ "Check that the layer name exists and contains valid tensor data",
)
@@ -167,21 +153,19 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
"""Show distribution of weights across all layers."""
if not state_dict:
return create_styled_error_figure(
- "Empty State Dict",
- "No data in state dict",
- "Ensure the model state dictionary contains data"
+ "Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data"
)
all_weights = []
layer_names = []
for key, tensor in state_dict.items():
- if 'weight' in key:
+ if "weight" in key:
try:
# Convert to numpy if it's a torch tensor
- if hasattr(tensor, 'numpy'):
- weights_np = tensor.detach().numpy() if hasattr(tensor, 'detach') else tensor.numpy()
- elif hasattr(tensor, 'cpu'):
+ if hasattr(tensor, "numpy"):
+ weights_np = tensor.detach().numpy() if hasattr(tensor, "detach") else tensor.numpy()
+ elif hasattr(tensor, "cpu"):
weights_np = tensor.cpu().detach().numpy()
else:
weights_np = np.array(tensor)
@@ -196,21 +180,15 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
return create_styled_error_figure(
"No Weight Data Found",
"No weight data found in state dict",
- "Ensure the state dictionary contains layers with '.weight' parameters"
+ "Ensure the state dictionary contains layers with '.weight' parameters",
)
- fig = go.Figure(data=[
- go.Histogram(
- x=all_weights,
- nbinsx=100,
- name="All Weights"
- )
- ])
+ fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=100, name="All Weights")])
fig.update_layout(
title="Overall Weight Distribution",
xaxis_title="Weight Value",
yaxis_title="Frequency",
- template="plotly_dark"
+ template="plotly_dark",
)
return fig
diff --git a/src/ria_toolkit_oss/viz/radio_dataset.py b/src/ria_toolkit_oss/viz/radio_dataset.py
index cae4084..a96b4d2 100644
--- a/src/ria_toolkit_oss/viz/radio_dataset.py
+++ b/src/ria_toolkit_oss/viz/radio_dataset.py
@@ -133,8 +133,10 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
return create_styled_error_figure(
"No Class Labels Found",
"This dataset contains numerical data without categorical labels.",
- ("Try using the Dataset Overview widget for data analysis, "
- "or check if your dataset has hidden categorical columns."),
+ (
+ "Try using the Dataset Overview widget for data analysis, "
+ "or check if your dataset has hidden categorical columns."
+ ),
)
# Count examples per class (limit to top 20 for performance)
@@ -355,8 +357,16 @@ def _compute_spectrogram(sample_data, nperseg: int, hop_length: int, n_frames: i
return Sxx
-def _create_spectrogram_figure(Sxx, n_frames: int, hop_length: int, n_samples: int, freq_bins: int,
- sample_idx: int, class_key: str, sample_metadata) -> Figure:
+def _create_spectrogram_figure(
+ Sxx,
+ n_frames: int,
+ hop_length: int,
+ n_samples: int,
+ freq_bins: int,
+ sample_idx: int,
+ class_key: str,
+ sample_metadata,
+) -> Figure:
"""Create the plotly figure for the spectrogram."""
# Convert to dB
Sxx_db = 10 * np.log10(Sxx + 1e-10)
@@ -410,8 +420,9 @@ def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx:
Sxx = _compute_spectrogram(sample_data, nperseg, hop_length, n_frames, freq_bins)
# Create and return the figure
- return _create_spectrogram_figure(Sxx, n_frames, hop_length, n_samples, freq_bins,
- sample_idx, class_key, sample_metadata)
+ return _create_spectrogram_figure(
+ Sxx, n_frames, hop_length, n_samples, freq_bins, sample_idx, class_key, sample_metadata
+ )
except Exception as e:
return create_styled_error_figure(