diff --git a/src/ria_toolkit_oss/viz/onnx.py b/src/ria_toolkit_oss/viz/onnx.py index b92c3e4..e260eeb 100644 --- a/src/ria_toolkit_oss/viz/onnx.py +++ b/src/ria_toolkit_oss/viz/onnx.py @@ -75,9 +75,7 @@ def graph_structure(file_path: Path) -> go.Figure: """ if not ONNX_AVAILABLE: return create_styled_error_figure( - "ONNX Not Available", - "ONNX library is required for model analysis.", - "Install with: pip install onnx" + "ONNX Not Available", "ONNX library is required for model analysis.", "Install with: pip install onnx" ) try: @@ -88,9 +86,7 @@ def graph_structure(file_path: Path) -> go.Figure: if len(nodes) == 0: return create_styled_error_figure( - "Empty Model", - "This ONNX model contains no operators.", - "Please check if the model file is valid." + "Empty Model", "This ONNX model contains no operators.", "Please check if the model file is valid." ) # Create network diagram data @@ -153,8 +149,10 @@ def graph_structure(file_path: Path) -> go.Figure: fig.update_layout( title={ - "text": ("ONNX Graph Structure
" - f"{len(nodes)} Operators"), + "text": ( + "ONNX Graph Structure
" + f"{len(nodes)} Operators" + ), "x": 0.5, "xanchor": "center", "font": {"size": 22}, @@ -173,9 +171,7 @@ def graph_structure(file_path: Path) -> go.Figure: except Exception as e: return create_styled_error_figure( - "Graph Analysis Error", - "Could not analyze ONNX model structure.", - f"Error: {str(e)}" + "Graph Analysis Error", "Could not analyze ONNX model structure.", f"Error: {str(e)}" ) @@ -186,9 +182,7 @@ def operator_analysis(file_path: Path) -> go.Figure: """ if not ONNX_AVAILABLE: return create_styled_error_figure( - "ONNX Not Available", - "ONNX library is required for operator analysis.", - "Install with: pip install onnx" + "ONNX Not Available", "ONNX library is required for operator analysis.", "Install with: pip install onnx" ) try: @@ -248,8 +242,10 @@ def operator_analysis(file_path: Path) -> go.Figure: fig.update_layout( title={ - "text": ("ONNX Operator Analysis
" - f"{len(op_counts)} Unique Types"), + "text": ( + "ONNX Operator Analysis
" + f"{len(op_counts)} Unique Types" + ), "x": 0.5, "xanchor": "center", "font": {"size": 22}, @@ -262,9 +258,7 @@ def operator_analysis(file_path: Path) -> go.Figure: except Exception as e: return create_styled_error_figure( - "Operator Analysis Error", - "Could not analyze ONNX operators.", - f"Error: {str(e)}" + "Operator Analysis Error", "Could not analyze ONNX operators.", f"Error: {str(e)}" ) @@ -275,9 +269,7 @@ def model_metadata(file_path: Path) -> go.Figure: """ if not ONNX_AVAILABLE: return create_styled_error_figure( - "ONNX Not Available", - "ONNX library is required for metadata analysis.", - "Install with: pip install onnx" + "ONNX Not Available", "ONNX library is required for metadata analysis.", "Install with: pip install onnx" ) try: @@ -336,12 +328,7 @@ def model_metadata(file_path: Path) -> go.Figure: arch_values = [total_nodes, total_inputs, total_outputs, total_initializers] fig.add_trace( - go.Bar( - x=arch_data, - y=arch_values, - marker_color=["blue", "green", "orange", "red"], - showlegend=False - ), + go.Bar(x=arch_data, y=arch_values, marker_color=["blue", "green", "orange", "red"], showlegend=False), row=1, col=2, ) @@ -402,16 +389,8 @@ def model_metadata(file_path: Path) -> go.Figure: if io_data: fig.add_trace( go.Table( - header=dict( - values=["Type", "Name", "Shape", "Data Type"], - fill_color="lightblue", - align="left" - ), - cells=dict( - values=list(zip(*io_data)), - fill_color="white", - align="left" - ), + header=dict(values=["Type", "Name", "Shape", "Data Type"], fill_color="lightblue", align="left"), + cells=dict(values=list(zip(*io_data)), fill_color="white", align="left"), ), row=2, col=1, @@ -432,8 +411,10 @@ def model_metadata(file_path: Path) -> go.Figure: fig.update_layout( title={ - "text": ("ONNX Model Metadata
" - f"{total_params/1e6:.2f}M Parameters"), + "text": ( + "ONNX Model Metadata
" + f"{total_params/1e6:.2f}M Parameters" + ), "x": 0.5, "xanchor": "center", "font": {"size": 22}, @@ -447,9 +428,7 @@ def model_metadata(file_path: Path) -> go.Figure: except Exception as e: return create_styled_error_figure( - "Metadata Analysis Error", - "Could not extract ONNX model metadata.", - f"Error: {str(e)}" + "Metadata Analysis Error", "Could not extract ONNX model metadata.", f"Error: {str(e)}" ) @@ -508,10 +487,7 @@ def performance_metrics(file_path: Path) -> go.Figure: fig.add_trace( go.Bar( - x=efficiency_metrics, - y=efficiency_values, - marker_color=["blue", "green", "orange"], - showlegend=False + x=efficiency_metrics, y=efficiency_values, marker_color=["blue", "green", "orange"], showlegend=False ), row=1, col=1, @@ -522,12 +498,7 @@ def performance_metrics(file_path: Path) -> go.Figure: memory_values = [param_memory_mb, param_memory_mb * 2] # Rough estimate fig.add_trace( - go.Bar( - x=memory_types, - y=memory_values, - marker_color=["purple", "red"], - showlegend=False - ), + go.Bar(x=memory_types, y=memory_values, marker_color=["purple", "red"], showlegend=False), row=1, col=2, ) @@ -569,9 +540,11 @@ def performance_metrics(file_path: Path) -> go.Figure: fig.update_layout( title={ - "text": ("ONNX Performance Metrics
" - f"" - f"Complexity Score: {complexity_score:.0f}/100"), + "text": ( + "ONNX Performance Metrics
" + f"" + f"Complexity Score: {complexity_score:.0f}/100" + ), "x": 0.5, "xanchor": "center", "font": {"size": 22}, @@ -585,7 +558,5 @@ def performance_metrics(file_path: Path) -> go.Figure: except Exception as e: return create_styled_error_figure( - "Performance Analysis Error", - "Could not analyze ONNX model performance.", - f"Error: {str(e)}" + "Performance Analysis Error", "Could not analyze ONNX model performance.", f"Error: {str(e)}" ) diff --git a/src/ria_toolkit_oss/viz/pytorch_state_dict.py b/src/ria_toolkit_oss/viz/pytorch_state_dict.py index 578ebd0..6c625bc 100644 --- a/src/ria_toolkit_oss/viz/pytorch_state_dict.py +++ b/src/ria_toolkit_oss/viz/pytorch_state_dict.py @@ -56,29 +56,25 @@ def model_summary_plot(state_dict: dict) -> Figure: return create_styled_error_figure( "Empty State Dict", "No parameters found in state dict", - "Ensure the model state dictionary contains weight parameters" + "Ensure the model state dictionary contains weight parameters", ) # Count parameters by layer type layer_info = [] for key, tensor in state_dict.items(): - if 'weight' in key: + if "weight" in key: try: - layer_name = key.replace('.weight', '') + layer_name = key.replace(".weight", "") param_count = ( - tensor.numel() if hasattr(tensor, 'numel') - else len(tensor.flatten()) if hasattr(tensor, 'flatten') - else 0 + tensor.numel() + if hasattr(tensor, "numel") + else len(tensor.flatten()) if hasattr(tensor, "flatten") else 0 ) shape = ( - list(tensor.shape) if hasattr(tensor, 'shape') - else [len(tensor)] if hasattr(tensor, '__len__') - else [] + list(tensor.shape) + if hasattr(tensor, "shape") + else [len(tensor)] if hasattr(tensor, "__len__") else [] ) - layer_info.append({ - 'layer': layer_name, - 'parameters': param_count, - 'shape': shape - }) + layer_info.append({"layer": layer_name, "parameters": param_count, "shape": shape}) except Exception as e: print(f"Warning: Could not process layer {key}: {e}") continue @@ -86,22 +82,24 @@ def model_summary_plot(state_dict: dict) -> Figure: return create_styled_error_figure( "No Weight Layers Found", "No weight layers found in state dict", - "Ensure the state dictionary contains layers with '.weight' parameters" + "Ensure the state dictionary contains layers with '.weight' parameters", ) # Create bar chart of parameter counts - fig = go.Figure(data=[ - go.Bar( - x=[info['layer'] for info in layer_info], - y=[info['parameters'] for info in layer_info], - text=[f"Shape: {info['shape']}" for info in layer_info], - textposition='auto', - ) - ]) + fig = go.Figure( + data=[ + go.Bar( + x=[info["layer"] for info in layer_info], + y=[info["parameters"] for info in layer_info], + text=[f"Shape: {info['shape']}" for info in layer_info], + textposition="auto", + ) + ] + ) fig.update_layout( title="Model Layer Parameter Counts", xaxis_title="Layer", yaxis_title="Number of Parameters", - template="plotly_dark" + template="plotly_dark", ) return fig @@ -110,48 +108,36 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure: """Visualize weights for a specific layer.""" if not state_dict: return create_styled_error_figure( - "Empty State Dict", - "No data in state dict", - "Ensure the model state dictionary contains data" + "Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data" ) if layer_name is None: # Get first weight tensor - weight_keys = [k for k in state_dict.keys() if 'weight' in k] + weight_keys = [k for k in state_dict.keys() if "weight" in k] if not weight_keys: return create_styled_error_figure( "No Weight Tensors Found", "No weight tensors found in state dict", - "Ensure the state dictionary contains layers with '.weight' parameters" + "Ensure the state dictionary contains layers with '.weight' parameters", ) layer_name = weight_keys[0] try: weights = state_dict[layer_name] # Convert to numpy if it's a torch tensor - if hasattr(weights, 'numpy'): - weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy() - elif hasattr(weights, 'cpu'): + if hasattr(weights, "numpy"): + weights_np = weights.detach().numpy() if hasattr(weights, "detach") else weights.numpy() + elif hasattr(weights, "cpu"): weights_np = weights.cpu().detach().numpy() else: weights_np = np.array(weights) # For 2D weights, create heatmap if len(weights_np.shape) == 2: - fig = go.Figure(data=go.Heatmap( - z=weights_np, - colorscale='RdBu', - zmid=0 - )) - fig.update_layout( - title=f"Weights Heatmap: {layer_name}", - template="plotly_dark" - ) + fig = go.Figure(data=go.Heatmap(z=weights_np, colorscale="RdBu", zmid=0)) + fig.update_layout(title=f"Weights Heatmap: {layer_name}", template="plotly_dark") else: # For other shapes, flatten and show histogram flat_weights = weights_np.flatten() fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)]) - fig.update_layout( - title=f"Weight Distribution: {layer_name}", - template="plotly_dark" - ) + fig.update_layout(title=f"Weight Distribution: {layer_name}", template="plotly_dark") return fig @@ -159,7 +145,7 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure: return create_styled_error_figure( "Layer Processing Error", f"Error processing layer {layer_name}: {str(e)}", - "Check that the layer name exists and contains valid tensor data" + "Check that the layer name exists and contains valid tensor data", ) @@ -167,21 +153,19 @@ def weight_distribution_plot(state_dict: dict) -> Figure: """Show distribution of weights across all layers.""" if not state_dict: return create_styled_error_figure( - "Empty State Dict", - "No data in state dict", - "Ensure the model state dictionary contains data" + "Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data" ) all_weights = [] layer_names = [] for key, tensor in state_dict.items(): - if 'weight' in key: + if "weight" in key: try: # Convert to numpy if it's a torch tensor - if hasattr(tensor, 'numpy'): - weights_np = tensor.detach().numpy() if hasattr(tensor, 'detach') else tensor.numpy() - elif hasattr(tensor, 'cpu'): + if hasattr(tensor, "numpy"): + weights_np = tensor.detach().numpy() if hasattr(tensor, "detach") else tensor.numpy() + elif hasattr(tensor, "cpu"): weights_np = tensor.cpu().detach().numpy() else: weights_np = np.array(tensor) @@ -196,21 +180,15 @@ def weight_distribution_plot(state_dict: dict) -> Figure: return create_styled_error_figure( "No Weight Data Found", "No weight data found in state dict", - "Ensure the state dictionary contains layers with '.weight' parameters" + "Ensure the state dictionary contains layers with '.weight' parameters", ) - fig = go.Figure(data=[ - go.Histogram( - x=all_weights, - nbinsx=100, - name="All Weights" - ) - ]) + fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=100, name="All Weights")]) fig.update_layout( title="Overall Weight Distribution", xaxis_title="Weight Value", yaxis_title="Frequency", - template="plotly_dark" + template="plotly_dark", ) return fig diff --git a/src/ria_toolkit_oss/viz/radio_dataset.py b/src/ria_toolkit_oss/viz/radio_dataset.py index cae4084..a96b4d2 100644 --- a/src/ria_toolkit_oss/viz/radio_dataset.py +++ b/src/ria_toolkit_oss/viz/radio_dataset.py @@ -133,8 +133,10 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure: return create_styled_error_figure( "No Class Labels Found", "This dataset contains numerical data without categorical labels.", - ("Try using the Dataset Overview widget for data analysis, " - "or check if your dataset has hidden categorical columns."), + ( + "Try using the Dataset Overview widget for data analysis, " + "or check if your dataset has hidden categorical columns." + ), ) # Count examples per class (limit to top 20 for performance) @@ -355,8 +357,16 @@ def _compute_spectrogram(sample_data, nperseg: int, hop_length: int, n_frames: i return Sxx -def _create_spectrogram_figure(Sxx, n_frames: int, hop_length: int, n_samples: int, freq_bins: int, - sample_idx: int, class_key: str, sample_metadata) -> Figure: +def _create_spectrogram_figure( + Sxx, + n_frames: int, + hop_length: int, + n_samples: int, + freq_bins: int, + sample_idx: int, + class_key: str, + sample_metadata, +) -> Figure: """Create the plotly figure for the spectrogram.""" # Convert to dB Sxx_db = 10 * np.log10(Sxx + 1e-10) @@ -410,8 +420,9 @@ def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx: Sxx = _compute_spectrogram(sample_data, nperseg, hop_length, n_frames, freq_bins) # Create and return the figure - return _create_spectrogram_figure(Sxx, n_frames, hop_length, n_samples, freq_bins, - sample_idx, class_key, sample_metadata) + return _create_spectrogram_figure( + Sxx, n_frames, hop_length, n_samples, freq_bins, sample_idx, class_key, sample_metadata + ) except Exception as e: return create_styled_error_figure(