From 1fb55607a24bab6b2bf7e0ef249d99fe312d110d Mon Sep 17 00:00:00 2001 From: ben Date: Thu, 9 Oct 2025 16:55:23 -0400 Subject: [PATCH 1/8] Pytorch widgets --- src/ria_toolkit_oss/viz/pytorch_state_dict.py | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 src/ria_toolkit_oss/viz/pytorch_state_dict.py diff --git a/src/ria_toolkit_oss/viz/pytorch_state_dict.py b/src/ria_toolkit_oss/viz/pytorch_state_dict.py new file mode 100644 index 0000000..05c96f4 --- /dev/null +++ b/src/ria_toolkit_oss/viz/pytorch_state_dict.py @@ -0,0 +1,89 @@ +import torch +import plotly.graph_objects as go +from plotly.graph_objects import Figure +import numpy as np + +def model_summary_plot(state_dict: dict) -> Figure: + """Generate a summary plot of the PyTorch model state dict.""" + # Count parameters by layer type + layer_info = [] + for key, tensor in state_dict.items(): + if 'weight' in key: + layer_name = key.replace('.weight', '') + param_count = tensor.numel() + layer_info.append({ + 'layer': layer_name, + 'parameters': param_count, + 'shape': list(tensor.shape) + }) + + # Create bar chart of parameter counts + fig = go.Figure(data=[ + go.Bar( + x=[info['layer'] for info in layer_info], + y=[info['parameters'] for info in layer_info], + text=[f"Shape: {info['shape']}" for info in layer_info], + textposition='auto', + ) + ]) + + fig.update_layout( + title="Model Layer Parameter Counts", + xaxis_title="Layer", + yaxis_title="Number of Parameters" + ) + + return fig + +def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure: + """Visualize weights for a specific layer.""" + if layer_name is None: + # Get first weight tensor + weight_keys = [k for k in state_dict.keys() if 'weight' in k] + if not weight_keys: + raise ValueError("No weight tensors found in state dict") + layer_name = weight_keys[0] + + weights = state_dict[layer_name] + + # For 2D weights, create heatmap + if len(weights.shape) == 2: + fig = go.Figure(data=go.Heatmap( + z=weights.numpy(), + colorscale='RdBu', + zmid=0 + )) + fig.update_layout(title=f"Weights Heatmap: {layer_name}") + else: + # For other shapes, flatten and show histogram + flat_weights = weights.flatten().numpy() + fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)]) + fig.update_layout(title=f"Weight Distribution: {layer_name}") + + return fig + +def weight_distribution_plot(state_dict: dict) -> Figure: + """Show distribution of weights across all layers.""" + all_weights = [] + layer_names = [] + + for key, tensor in state_dict.items(): + if 'weight' in key: + all_weights.extend(tensor.flatten().numpy()) + layer_names.extend([key] * tensor.numel()) + + fig = go.Figure(data=[ + go.Histogram( + x=all_weights, + nbinsx=100, + name="All Weights" + ) + ]) + + fig.update_layout( + title="Overall Weight Distribution", + xaxis_title="Weight Value", + yaxis_title="Frequency" + ) + + return fig \ No newline at end of file -- 2.34.1 From f430e626a61163743298f218f79b1735a8869155 Mon Sep 17 00:00:00 2001 From: ben Date: Tue, 14 Oct 2025 14:22:37 -0400 Subject: [PATCH 2/8] Pytorch state dict widget --- src/ria_toolkit_oss/viz/pytorch_state_dict.py | 198 +++++++++++++++--- 1 file changed, 169 insertions(+), 29 deletions(-) diff --git a/src/ria_toolkit_oss/viz/pytorch_state_dict.py b/src/ria_toolkit_oss/viz/pytorch_state_dict.py index 05c96f4..7db7528 100644 --- a/src/ria_toolkit_oss/viz/pytorch_state_dict.py +++ b/src/ria_toolkit_oss/viz/pytorch_state_dict.py @@ -5,17 +5,56 @@ import numpy as np def model_summary_plot(state_dict: dict) -> Figure: """Generate a summary plot of the PyTorch model state dict.""" + if not state_dict: + # Handle empty state dict + fig = go.Figure() + fig.add_annotation( + text="No parameters found in state dict", + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False, + font=dict(size=16) + ) + fig.update_layout( + title="Model Layer Parameter Counts", + xaxis_title="Layer", + yaxis_title="Number of Parameters", + template="plotly_dark" + ) + return fig + # Count parameters by layer type layer_info = [] for key, tensor in state_dict.items(): if 'weight' in key: - layer_name = key.replace('.weight', '') - param_count = tensor.numel() - layer_info.append({ - 'layer': layer_name, - 'parameters': param_count, - 'shape': list(tensor.shape) - }) + try: + layer_name = key.replace('.weight', '') + param_count = tensor.numel() if hasattr(tensor, 'numel') else len(tensor.flatten()) if hasattr(tensor, 'flatten') else 0 + shape = list(tensor.shape) if hasattr(tensor, 'shape') else [len(tensor)] if hasattr(tensor, '__len__') else [] + layer_info.append({ + 'layer': layer_name, + 'parameters': param_count, + 'shape': shape + }) + except Exception as e: + print(f"Warning: Could not process layer {key}: {e}") + continue + + if not layer_info: + # Handle case where no weight layers found + fig = go.Figure() + fig.add_annotation( + text="No weight layers found in state dict", + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False, + font=dict(size=16) + ) + fig.update_layout( + title="Model Layer Parameter Counts", + xaxis_title="Layer", + yaxis_title="Number of Parameters", + template="plotly_dark" + ) + return fig # Create bar chart of parameter counts fig = go.Figure(data=[ @@ -30,47 +69,147 @@ def model_summary_plot(state_dict: dict) -> Figure: fig.update_layout( title="Model Layer Parameter Counts", xaxis_title="Layer", - yaxis_title="Number of Parameters" + yaxis_title="Number of Parameters", + template="plotly_dark" ) return fig def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure: """Visualize weights for a specific layer.""" + if not state_dict: + fig = go.Figure() + fig.add_annotation( + text="No data in state dict", + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False, + font=dict(size=16) + ) + fig.update_layout( + title="Layer Weights", + template="plotly_dark" + ) + return fig + if layer_name is None: # Get first weight tensor weight_keys = [k for k in state_dict.keys() if 'weight' in k] if not weight_keys: - raise ValueError("No weight tensors found in state dict") + fig = go.Figure() + fig.add_annotation( + text="No weight tensors found in state dict", + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False, + font=dict(size=16) + ) + fig.update_layout( + title="Layer Weights", + template="plotly_dark" + ) + return fig layer_name = weight_keys[0] - weights = state_dict[layer_name] - - # For 2D weights, create heatmap - if len(weights.shape) == 2: - fig = go.Figure(data=go.Heatmap( - z=weights.numpy(), - colorscale='RdBu', - zmid=0 - )) - fig.update_layout(title=f"Weights Heatmap: {layer_name}") - else: - # For other shapes, flatten and show histogram - flat_weights = weights.flatten().numpy() - fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)]) - fig.update_layout(title=f"Weight Distribution: {layer_name}") - - return fig + try: + weights = state_dict[layer_name] + + # Convert to numpy if it's a torch tensor + if hasattr(weights, 'numpy'): + weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy() + elif hasattr(weights, 'cpu'): + weights_np = weights.cpu().detach().numpy() + else: + weights_np = np.array(weights) + + # For 2D weights, create heatmap + if len(weights_np.shape) == 2: + fig = go.Figure(data=go.Heatmap( + z=weights_np, + colorscale='RdBu', + zmid=0 + )) + fig.update_layout( + title=f"Weights Heatmap: {layer_name}", + template="plotly_dark" + ) + else: + # For other shapes, flatten and show histogram + flat_weights = weights_np.flatten() + fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)]) + fig.update_layout( + title=f"Weight Distribution: {layer_name}", + template="plotly_dark" + ) + + return fig + + except Exception as e: + fig = go.Figure() + fig.add_annotation( + text=f"Error processing layer {layer_name}: {str(e)}", + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False, + font=dict(size=14) + ) + fig.update_layout( + title="Layer Weights - Error", + template="plotly_dark" + ) + return fig def weight_distribution_plot(state_dict: dict) -> Figure: """Show distribution of weights across all layers.""" + if not state_dict: + fig = go.Figure() + fig.add_annotation( + text="No data in state dict", + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False, + font=dict(size=16) + ) + fig.update_layout( + title="Overall Weight Distribution", + xaxis_title="Weight Value", + yaxis_title="Frequency", + template="plotly_dark" + ) + return fig + all_weights = [] layer_names = [] for key, tensor in state_dict.items(): if 'weight' in key: - all_weights.extend(tensor.flatten().numpy()) - layer_names.extend([key] * tensor.numel()) + try: + # Convert to numpy if it's a torch tensor + if hasattr(tensor, 'numpy'): + weights_np = tensor.detach().numpy() if hasattr(tensor, 'detach') else tensor.numpy() + elif hasattr(tensor, 'cpu'): + weights_np = tensor.cpu().detach().numpy() + else: + weights_np = np.array(tensor) + + flat_weights = weights_np.flatten() + all_weights.extend(flat_weights) + layer_names.extend([key] * len(flat_weights)) + except Exception as e: + print(f"Warning: Could not process weights for layer {key}: {e}") + continue + + if not all_weights: + fig = go.Figure() + fig.add_annotation( + text="No weight data found in state dict", + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False, + font=dict(size=16) + ) + fig.update_layout( + title="Overall Weight Distribution", + xaxis_title="Weight Value", + yaxis_title="Frequency", + template="plotly_dark" + ) + return fig fig = go.Figure(data=[ go.Histogram( @@ -83,7 +222,8 @@ def weight_distribution_plot(state_dict: dict) -> Figure: fig.update_layout( title="Overall Weight Distribution", xaxis_title="Weight Value", - yaxis_title="Frequency" + yaxis_title="Frequency", + template="plotly_dark" ) return fig \ No newline at end of file -- 2.34.1 From 2721ed866ce5a6c6a651f4603af366e7e90b6984 Mon Sep 17 00:00:00 2001 From: ben Date: Fri, 17 Oct 2025 09:35:27 -0400 Subject: [PATCH 3/8] Radio-dataset widgets --- src/ria_toolkit_oss/viz/radio_dataset.py | 430 +++++++++++++++++++++++ 1 file changed, 430 insertions(+) create mode 100644 src/ria_toolkit_oss/viz/radio_dataset.py diff --git a/src/ria_toolkit_oss/viz/radio_dataset.py b/src/ria_toolkit_oss/viz/radio_dataset.py new file mode 100644 index 0000000..edc5004 --- /dev/null +++ b/src/ria_toolkit_oss/viz/radio_dataset.py @@ -0,0 +1,430 @@ +""" +Simple, clean visualization utilities for RadioDataset analysis. +""" + +import random +from typing import Optional + +import numpy as np +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +from plotly.graph_objects import Figure +from plotly.subplots import make_subplots + + +def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> Figure: + """Create a professional error figure with Qoherent dark theme styling.""" + fig = go.Figure() + + # Create a clean, centered text display using Plotly's text formatting + main_text = f"⚠️ {title}

" + main_text += f"{message}" + + if suggestion: + main_text += f"

💡 Suggestion:
" + main_text += f"{suggestion}" + + # Add the main text annotation + fig.add_annotation( + text=main_text, + xref="paper", yref="paper", + x=0.5, y=0.5, + xanchor='center', yanchor='middle', + showarrow=False, + align="center", + borderwidth=2, + bordercolor="#4a5568", + bgcolor="#2d3748", + font=dict( + family="Arial, sans-serif", + size=14, + color="#e2e8f0" + ) + ) + + # Update layout with dark theme + fig.update_layout( + title="", + height=400, + template="plotly_dark", + margin=dict(l=40, r=40, t=40, b=40), + plot_bgcolor="#1a202c", + paper_bgcolor="#1a202c", + font=dict(color="#e2e8f0") + ) + + # Remove axes and grid + fig.update_xaxes(visible=False) + fig.update_yaxes(visible=False) + + return fig + + +def _check_dataset_compatibility(dataset, plot_type: str) -> tuple[bool, str]: + """Check if dataset is compatible with a specific plot type. + Returns (is_compatible, error_message) + """ + try: + metadata = dataset.metadata + + if len(metadata) == 0: + return False, "Dataset is empty" + + if plot_type == "class_distribution": + # Check if we have any categorical columns + categorical_cols = [col for col in metadata.columns if metadata[col].dtype == 'object'] + alternatives = ["class", "label", "modulation", "impairment", "use_case", "category", "labels"] + + has_class_col = any(alt in metadata.columns for alt in alternatives) + has_categorical = len(categorical_cols) > 0 + + if not has_class_col and not has_categorical: + return False, "No categorical columns found for class distribution" + + elif plot_type == "sample_spectrogram": + # Check if we can generate a valid spectrogram + if len(metadata) < 1: + return False, "No samples available for spectrogram" + + # Check if we can access sample data (basic test) + try: + sample_data = dataset[0] if hasattr(dataset, '__getitem__') else None + if sample_data is None or len(sample_data) < 32: + return False, "Insufficient sample data for spectrogram (need at least 32 points)" + except Exception: + # If we can't access data, we'll rely on synthetic data generation + pass + + return True, "" + + except Exception as e: + return False, f"Dataset compatibility check failed: {str(e)}" + + +def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure: + """Generate a bar plot showing the distribution of examples across classes.""" + try: + # Check dataset compatibility first + is_compatible, error_msg = _check_dataset_compatibility(dataset, "class_distribution") + if not is_compatible: + return create_styled_error_figure( + "Dataset Not Compatible", + "This dataset doesn't have categorical labels needed for class distribution analysis.", + "Try using the Dataset Overview widget to explore the available data columns." + ) + + metadata = dataset.metadata + + # Find the class column + if class_key not in metadata.columns: + # Try common alternatives + alternatives = ["class", "label", "modulation", "impairment", "use_case", "category", "labels"] + for alt in alternatives: + if alt in metadata.columns: + class_key = alt + break + else: + # Use first categorical column + for col in metadata.columns: + if metadata[col].dtype == 'object' or metadata[col].nunique() < 50: + class_key = col + break + + if class_key not in metadata.columns: + return create_styled_error_figure( + "No Class Labels Found", + "This dataset contains numerical data without categorical labels.", + "Try using the Dataset Overview widget for data analysis, or check if your dataset has hidden categorical columns." + ) + + # Count examples per class (limit to top 20 for performance) + class_counts = metadata[class_key].value_counts() + if len(class_counts) > 20: + class_counts = class_counts.head(20) + + class_counts = class_counts.sort_index() + + # Create simple bar plot + fig = px.bar( + x=class_counts.index, + y=class_counts.values, + title=f'Class Distribution: {class_key.title()}' + ) + + fig.update_traces(texttemplate='%{y}', textposition='outside') + fig.update_layout( + xaxis_title=class_key.title(), + yaxis_title='Number of Examples', + showlegend=False, + height=400, + template="plotly_dark" + ) + + return fig + + except Exception as e: + return create_styled_error_figure( + "Class Distribution Error", + f"An error occurred while generating the class distribution plot.", + f"Technical details: {str(e)}" + ) + + +def dataset_overview_plot(dataset) -> Figure: + """Generate an overview plot with key dataset statistics.""" + try: + metadata = dataset.metadata + total_examples = len(metadata) + + # Create subplot with multiple charts + + # Determine subplot titles based on data type + categorical_cols = [col for col in metadata.columns if metadata[col].dtype == 'object'] + numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ['int64', 'float64']] + + dist_title = "Value Distribution" if categorical_cols else "Data Distribution" + + fig = make_subplots( + rows=2, cols=2, + subplot_titles=("Dataset Size", "Data Types", dist_title, "Statistics Summary"), + specs=[[{"type": "indicator"}, {"type": "bar"}], + [{"type": "histogram" if not categorical_cols else "bar"}, {"type": "table"}]] + ) + + # Top left: Dataset size indicator + fig.add_trace( + go.Indicator( + mode="number", + value=total_examples, + title={"text": "Total Examples"}, + number={"font": {"size": 40}} + ), + row=1, col=1 + ) + + # Top right: Data types distribution + dtype_counts = metadata.dtypes.value_counts() + fig.add_trace( + go.Bar( + x=[str(dt) for dt in dtype_counts.index], + y=dtype_counts.values, + name="Data Types", + showlegend=False + ), + row=1, col=2 + ) + + # Bottom left: Show distribution of numeric columns or categorical if available + categorical_cols = [col for col in metadata.columns if metadata[col].dtype == 'object'] + numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ['int64', 'float64']] + + if categorical_cols: + col = categorical_cols[0] # Show first categorical column + value_counts = metadata[col].value_counts().head(10) + fig.add_trace( + go.Bar( + x=value_counts.index, + y=value_counts.values, + name=f"{col} Distribution", + showlegend=False + ), + row=2, col=1 + ) + elif numeric_cols: + # Show histogram of first numeric column + col = numeric_cols[0] + fig.add_trace( + go.Histogram( + x=metadata[col], + name=f"{col} Distribution", + showlegend=False, + nbinsx=20 + ), + row=2, col=1 + ) + + # Bottom right: Basic statistics table + stats_data = [] + display_cols = (numeric_cols[:5] if len(numeric_cols) > 0 else metadata.columns[:5]) + + for col in display_cols: + if metadata[col].dtype in ['int64', 'float64']: + stats_data.append([ + col[:15] + "..." if len(col) > 15 else col, # Truncate long column names + f"{metadata[col].mean():.3f}", + f"{metadata[col].std():.3f}", + f"{metadata[col].min():.3f}", + f"{metadata[col].max():.3f}" + ]) + else: + unique_count = metadata[col].nunique() + stats_data.append([ + col[:15] + "..." if len(col) > 15 else col, + "N/A", "N/A", + f"{unique_count} unique", + "N/A" + ]) + + if stats_data: + fig.add_trace( + go.Table( + header=dict( + values=["Column", "Mean", "Std", "Min/Unique", "Max"], + fill_color="rgba(30, 30, 30, 0.8)", + align="center", + font=dict(color="white", size=12) + ), + cells=dict( + values=list(zip(*stats_data)), + fill_color="rgba(50, 50, 50, 0.6)", + align="center", + font=dict(color="white", size=11) + ) + ), + row=2, col=2 + ) + + # Create informative title + total_cols = len(metadata.columns) + title = f"Dataset Overview - {total_examples} samples, {total_cols} columns" + if total_cols > 5: + title += f" (showing first 5)" + + fig.update_layout( + title=title, + height=600, + showlegend=False, + template="plotly_dark" + ) + + return fig + + except Exception as e: + return create_styled_error_figure( + "Dataset Overview Error", + "An error occurred while generating the dataset overview.", + f"Technical details: {str(e)}" + ) + + +def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx: Optional[int] = None) -> Figure: + """Generate a spectrogram plot from a sample in the dataset.""" + try: + # Check dataset compatibility first + is_compatible, error_msg = _check_dataset_compatibility(dataset, "sample_spectrogram") + if not is_compatible: + return create_styled_error_figure( + "Spectrogram Not Available", + "This dataset doesn't have sufficient signal data for spectrogram visualization.", + "Ensure your dataset contains complex-valued signal samples with at least 32 data points per sample." + ) + + metadata = dataset.metadata + + if len(metadata) == 0: + raise ValueError("Dataset is empty") + + # Find class column + if class_key not in metadata.columns: + alternatives = ["class", "label", "modulation", "impairment", "use_case"] + for alt in alternatives: + if alt in metadata.columns: + class_key = alt + break + + # Select sample + if sample_idx is None: + sample_idx = random.randint(0, len(metadata) - 1) + + sample_metadata = metadata.iloc[sample_idx] + + # Try to get actual sample data, fall back to synthetic + try: + sample_data = dataset[sample_idx] + except: + # Generate synthetic signal based on class + n_samples = 1024 + t = np.linspace(0, 1, n_samples) + freq = 0.1 + 0.05 * sample_idx % 5 # Vary frequency by sample + sample_data = np.exp(1j * 2 * np.pi * freq * t) + # Add some noise + sample_data += 0.1 * (np.random.randn(n_samples) + 1j * np.random.randn(n_samples)) + + # Ensure complex data + if not np.iscomplexobj(sample_data): + sample_data = sample_data.astype(complex) + + # Simple FFT-based spectrogram + n_samples = len(sample_data) + + # Ensure minimum viable data size + if n_samples < 32: + raise ValueError(f"Insufficient data: need at least 32 samples, got {n_samples}") + + nperseg = min(256, max(32, n_samples // 4)) + + # Create spectrogram using numpy (no scipy dependency) + hop_length = max(1, nperseg // 2) # Prevent zero hop_length + + # Ensure we can create at least one frame + if n_samples < nperseg: + nperseg = n_samples + hop_length = 1 + + n_frames = max(1, (n_samples - nperseg) // hop_length + 1) + + freq_bins = max(1, nperseg // 2) # Prevent zero frequency bins + Sxx = np.zeros((freq_bins, n_frames)) + + for i in range(n_frames): + start_idx = i * hop_length + end_idx = min(start_idx + nperseg, n_samples) # Prevent index overflow + + if end_idx > start_idx: # Ensure we have data to process + windowed = sample_data[start_idx:end_idx] + + # Pad if necessary to maintain nperseg size + if len(windowed) < nperseg: + windowed = np.pad(windowed, (0, nperseg - len(windowed)), mode='constant') + + fft_result = np.fft.fft(windowed) + Sxx[:, i] = np.abs(fft_result[:freq_bins]) ** 2 + + # Convert to dB + Sxx_db = 10 * np.log10(Sxx + 1e-10) + + # Create time and frequency vectors + t = np.arange(n_frames) * hop_length / max(1, n_samples) # Prevent division by zero + f = np.linspace(0, 0.5, freq_bins) + + # Create plot + fig = go.Figure(data=go.Heatmap( + z=Sxx_db, + x=t, + y=f, + colorscale='viridis', + colorbar=dict(title="Power (dB)") + )) + + # Add title with metadata + title = f"Sample Spectrogram (Index: {sample_idx})" + if class_key in sample_metadata: + title += f" - {class_key}: {sample_metadata[class_key]}" + + fig.update_layout( + title=title, + xaxis_title="Time", + yaxis_title="Frequency", + height=400, + template="plotly_dark" + ) + + return fig + + except Exception as e: + return create_styled_error_figure( + "Spectrogram Error", + "An error occurred while generating the spectrogram plot.", + f"Technical details: {str(e)}" + ) \ No newline at end of file -- 2.34.1 From e863040e1948bdb992c93e96ff34544808146c3c Mon Sep 17 00:00:00 2001 From: ben Date: Mon, 20 Oct 2025 12:16:30 -0400 Subject: [PATCH 4/8] onnx visualizers --- src/ria_toolkit_oss/viz/onnx.py | 558 ++++++++++++++++++++++++++++++++ 1 file changed, 558 insertions(+) create mode 100644 src/ria_toolkit_oss/viz/onnx.py diff --git a/src/ria_toolkit_oss/viz/onnx.py b/src/ria_toolkit_oss/viz/onnx.py new file mode 100644 index 0000000..04e86b9 --- /dev/null +++ b/src/ria_toolkit_oss/viz/onnx.py @@ -0,0 +1,558 @@ +""" +ONNX model visualization utilities. + +This module provides visualization functions for ONNX models following the same pattern +as other ria-toolkit-oss visualization modules. +""" + +from pathlib import Path +from typing import Optional + +import plotly.graph_objects as go +import plotly.express as px +from plotly.subplots import make_subplots +import pandas as pd +import numpy as np + +try: + import onnx + import onnx.helper + import onnx.numpy_helper + ONNX_AVAILABLE = True +except ImportError: + ONNX_AVAILABLE = False + + +def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure: + """Create a professional error figure with Qoherent dark theme styling.""" + fig = go.Figure() + + # Create a clean, centered text display using Plotly's text formatting + main_text = f"⚠️ {title}

" + main_text += f"{message}" + + if suggestion: + main_text += f"

💡 Suggestion:
" + main_text += f"{suggestion}" + + # Add the main text annotation + fig.add_annotation( + text=main_text, + xref="paper", yref="paper", + x=0.5, y=0.5, + xanchor='center', yanchor='middle', + showarrow=False, + align="center", + borderwidth=2, + bordercolor="#4a5568", + bgcolor="#2d3748", + font=dict( + family="Arial, sans-serif", + size=14, + color="#e2e8f0" + ) + ) + + # Update layout with dark theme + fig.update_layout( + title="", + height=400, + template="plotly_dark", + margin=dict(l=40, r=40, t=40, b=40), + plot_bgcolor="#1a202c", + paper_bgcolor="#1a202c", + font=dict(color="#e2e8f0") + ) + + # Remove axes and grid + fig.update_xaxes(visible=False) + fig.update_yaxes(visible=False) + + return fig + + +def graph_structure(file_path: Path) -> go.Figure: + """ + Visualize the ONNX model graph structure showing nodes and connections. + Matches layout ID: graph_structure + """ + if not ONNX_AVAILABLE: + return create_styled_error_figure( + "ONNX Not Available", + "ONNX library is required for model analysis.", + "Install with: pip install onnx" + ) + + try: + # Load ONNX model + model = onnx.load(str(file_path)) + graph = model.graph + nodes = graph.node + + if len(nodes) == 0: + return create_styled_error_figure( + "Empty Model", + "This ONNX model contains no operators.", + "Please check if the model file is valid." + ) + + # Create network diagram data + node_info = [] + for i, node in enumerate(nodes): + node_info.append({ + 'id': i, + 'name': node.name or f"{node.op_type}_{i}", + 'op_type': node.op_type, + 'inputs': len(node.input), + 'outputs': len(node.output) + }) + + # Create visualization + fig = go.Figure() + + # Simple linear layout for now + x_positions = list(range(len(node_info))) + y_positions = [0] * len(node_info) + + # Add nodes as scatter points + fig.add_trace(go.Scatter( + x=x_positions, + y=y_positions, + mode='markers+text', + marker=dict( + size=[min(max(info['inputs'] + info['outputs'] + 15, 20), 50) for info in node_info], + color=px.colors.qualitative.Set3[:len(node_info)], + opacity=0.8, + line=dict(width=2, color='white') + ), + text=[f"{info['op_type']}" for info in node_info], + textposition="middle center", + textfont=dict(size=10, color="white"), + hovertemplate="%{text}
" + + "Name: %{customdata[0]}
" + + "Inputs: %{customdata[1]}
" + + "Outputs: %{customdata[2]}
" + + "", + customdata=[[info['name'], info['inputs'], info['outputs']] for info in node_info], + name="Operators" + )) + + # Add connecting lines + for i in range(len(node_info) - 1): + fig.add_trace(go.Scatter( + x=[x_positions[i], x_positions[i+1]], + y=[y_positions[i], y_positions[i+1]], + mode='lines', + line=dict(color='gray', width=1, dash='dot'), + showlegend=False, + hoverinfo='skip' + )) + + fig.update_layout( + title={ + 'text': f"ONNX Graph Structure
{len(nodes)} Operators", + 'x': 0.5, + 'xanchor': 'center', + 'font': {'size': 22} + }, + xaxis_title="Execution Order", + yaxis_title="", + showlegend=False, + height=500, + template="plotly_dark", + yaxis=dict(showticklabels=False, showgrid=False), + xaxis=dict(showgrid=False), + margin=dict(l=50, r=50, t=80, b=50) + ) + + return fig + + except Exception as e: + return create_styled_error_figure( + "Graph Analysis Error", + f"Could not analyze ONNX model structure.", + f"Error: {str(e)}" + ) + + +def operator_analysis(file_path: Path) -> go.Figure: + """ + Analyze the distribution and types of operators in the ONNX model. + Matches layout ID: operator_analysis + """ + if not ONNX_AVAILABLE: + return create_styled_error_figure( + "ONNX Not Available", + "ONNX library is required for operator analysis.", + "Install with: pip install onnx" + ) + + try: + model = onnx.load(str(file_path)) + graph = model.graph + + # Count operators + op_counts = {} + for node in graph.node: + op_type = node.op_type + op_counts[op_type] = op_counts.get(op_type, 0) + 1 + + if not op_counts: + return create_styled_error_figure( + "No Operators", + "This ONNX model contains no operators to analyze.", + "Please verify the model file is valid." + ) + + # Sort by frequency + sorted_ops = sorted(op_counts.items(), key=lambda x: x[1], reverse=True) + + # Create pie chart and bar chart + fig = make_subplots( + rows=2, cols=1, + subplot_titles=("Operator Distribution", "Operator Frequency"), + specs=[[{"type": "pie"}], [{"type": "bar"}]] + ) + + # Pie chart for operator distribution + op_names, op_values = zip(*sorted_ops) if sorted_ops else ([], []) + + fig.add_trace( + go.Pie( + labels=list(op_names), + values=list(op_values), + textinfo="label+percent", + textposition="auto", + showlegend=False + ), + row=1, col=1 + ) + + # Bar chart for frequency + fig.add_trace( + go.Bar( + x=list(op_names), + y=list(op_values), + marker_color=px.colors.qualitative.Set3[:len(op_names)], + showlegend=False + ), + row=2, col=1 + ) + + fig.update_layout( + title={ + 'text': f"ONNX Operator Analysis
{len(op_counts)} Unique Types", + 'x': 0.5, + 'xanchor': 'center', + 'font': {'size': 22} + }, + height=700, + template="plotly_dark" + ) + + return fig + + except Exception as e: + return create_styled_error_figure( + "Operator Analysis Error", + f"Could not analyze ONNX operators.", + f"Error: {str(e)}" + ) + + +def model_metadata(file_path: Path) -> go.Figure: + """ + Display comprehensive metadata about the ONNX model. + Matches layout ID: model_metadata + """ + if not ONNX_AVAILABLE: + return create_styled_error_figure( + "ONNX Not Available", + "ONNX library is required for metadata analysis.", + "Install with: pip install onnx" + ) + + try: + model = onnx.load(str(file_path)) + graph = model.graph + + # Calculate basic statistics + total_nodes = len(graph.node) + total_inputs = len(graph.input) + total_outputs = len(graph.output) + total_initializers = len(graph.initializer) + + # Calculate parameter count + total_params = 0 + for initializer in graph.initializer: + try: + tensor = onnx.numpy_helper.to_array(initializer) + total_params += tensor.size + except: + pass # Skip if tensor can't be loaded + + # Get model file size + file_size_mb = file_path.stat().st_size / (1024 * 1024) + + # Create metadata display + fig = make_subplots( + rows=2, cols=2, + subplot_titles=("Model Size", "Architecture", "Inputs/Outputs", "Parameters"), + specs=[[{"type": "indicator"}, {"type": "bar"}], + [{"type": "table"}, {"type": "indicator"}]] + ) + + # Model size indicator + fig.add_trace( + go.Indicator( + mode="number+gauge", + value=file_size_mb, + title={'text': "Model Size (MB)"}, + number={'suffix': ' MB', 'valueformat': '.2f'}, + gauge={ + 'axis': {'range': [0, max(100, file_size_mb * 1.5)]}, + 'bar': {'color': "darkblue"}, + 'steps': [ + {'range': [0, 10], 'color': "lightgreen"}, + {'range': [10, 50], 'color': "yellow"}, + {'range': [50, 100], 'color': "orange"} + ] + } + ), + row=1, col=1 + ) + + # Architecture components + arch_data = ["Nodes", "Inputs", "Outputs", "Initializers"] + arch_values = [total_nodes, total_inputs, total_outputs, total_initializers] + + fig.add_trace( + go.Bar( + x=arch_data, + y=arch_values, + marker_color=['blue', 'green', 'orange', 'red'], + showlegend=False + ), + row=1, col=2 + ) + + # I/O Table + io_data = [] + + # Add input info + for inp in graph.input[:5]: # Limit to first 5 + shape = "Unknown" + dtype = "Unknown" + if inp.type and inp.type.tensor_type: + # Get shape + if inp.type.tensor_type.shape: + dims = [str(d.dim_value) if d.dim_value > 0 else "?" + for d in inp.type.tensor_type.shape.dim] + shape = f"[{', '.join(dims)}]" + + # Get data type + elem_type = inp.type.tensor_type.elem_type + type_map = {1: 'float32', 2: 'uint8', 3: 'int8', 6: 'int32', + 7: 'int64', 9: 'bool', 10: 'float16', 11: 'double'} + dtype = type_map.get(elem_type, f'type_{elem_type}') + + io_data.append(['Input', inp.name[:20], shape, dtype]) + + # Add output info + for out in graph.output[:5]: # Limit to first 5 + shape = "Unknown" + dtype = "Unknown" + if out.type and out.type.tensor_type: + if out.type.tensor_type.shape: + dims = [str(d.dim_value) if d.dim_value > 0 else "?" + for d in out.type.tensor_type.shape.dim] + shape = f"[{', '.join(dims)}]" + + elem_type = out.type.tensor_type.elem_type + type_map = {1: 'float32', 2: 'uint8', 3: 'int8', 6: 'int32', + 7: 'int64', 9: 'bool', 10: 'float16', 11: 'double'} + dtype = type_map.get(elem_type, f'type_{elem_type}') + + io_data.append(['Output', out.name[:20], shape, dtype]) + + if io_data: + fig.add_trace( + go.Table( + header=dict( + values=['Type', 'Name', 'Shape', 'Data Type'], + fill_color='lightblue', + align='left' + ), + cells=dict( + values=list(zip(*io_data)), + fill_color='white', + align='left' + ) + ), + row=2, col=1 + ) + + # Parameters indicator + fig.add_trace( + go.Indicator( + mode="number", + value=total_params, + title={'text': "Total Parameters"}, + number={'suffix': 'M', 'valueformat': '.2f'}, + number_font_size=30 + ), + row=2, col=2 + ) + + fig.update_layout( + title={ + 'text': f"ONNX Model Metadata
{total_params/1e6:.2f}M Parameters", + 'x': 0.5, + 'xanchor': 'center', + 'font': {'size': 22} + }, + height=600, + template="plotly_dark", + showlegend=False + ) + + return fig + + except Exception as e: + return create_styled_error_figure( + "Metadata Analysis Error", + f"Could not extract ONNX model metadata.", + f"Error: {str(e)}" + ) + + +def performance_metrics(file_path: Path) -> go.Figure: + """ + Display performance and computational metrics for the ONNX model. + Matches layout ID: performance_metrics + """ + if not ONNX_AVAILABLE: + return create_styled_error_figure( + "ONNX Not Available", + "ONNX library is required for performance analysis.", + "Install with: pip install onnx" + ) + + try: + model = onnx.load(str(file_path)) + graph = model.graph + + # Calculate metrics + model_size_bytes = file_path.stat().st_size + model_size_mb = model_size_bytes / (1024 * 1024) + + # Count parameters + total_params = 0 + for initializer in graph.initializer: + try: + tensor = onnx.numpy_helper.to_array(initializer) + total_params += tensor.size + except: + pass + + # Estimate memory usage (rough approximation) + param_memory_mb = (total_params * 4) / (1024 * 1024) # Assume float32 + + # Count operations by complexity + compute_ops = ['Conv', 'MatMul', 'Gemm', 'LSTM', 'GRU'] + efficient_ops = ['Relu', 'Add', 'Mul', 'BatchNormalization', 'Dropout'] + + compute_count = sum(1 for node in graph.node + if any(op in node.op_type for op in compute_ops)) + efficient_count = sum(1 for node in graph.node + if any(op in node.op_type for op in efficient_ops)) + total_ops = len(graph.node) + other_count = total_ops - compute_count - efficient_count + + # Create performance dashboard + fig = make_subplots( + rows=2, cols=2, + subplot_titles=("Model Efficiency", "Memory Usage", "Operation Types", "Complexity Score"), + specs=[[{"type": "bar"}, {"type": "bar"}], + [{"type": "pie"}, {"type": "indicator"}]] + ) + + # Model efficiency metrics + efficiency_metrics = ["Model Size (MB)", "Parameters (M)", "Total Ops"] + efficiency_values = [model_size_mb, total_params/1e6, total_ops] + + fig.add_trace( + go.Bar( + x=efficiency_metrics, + y=efficiency_values, + marker_color=['blue', 'green', 'orange'], + showlegend=False + ), + row=1, col=1 + ) + + # Memory usage + memory_types = ["Parameters", "Est. Inference"] + memory_values = [param_memory_mb, param_memory_mb * 2] # Rough estimate + + fig.add_trace( + go.Bar( + x=memory_types, + y=memory_values, + marker_color=['purple', 'red'], + showlegend=False + ), + row=1, col=2 + ) + + # Operation types pie chart + fig.add_trace( + go.Pie( + labels=['Compute Ops', 'Efficient Ops', 'Other Ops'], + values=[compute_count, efficient_count, other_count], + marker_colors=['red', 'green', 'gray'] + ), + row=2, col=1 + ) + + # Complexity score (simple heuristic) + complexity_score = min(100, (model_size_mb * 10 + total_params / 1e6 * 20 + compute_count)) + + fig.add_trace( + go.Indicator( + mode="gauge+number", + value=complexity_score, + title={'text': "Complexity Score"}, + gauge={ + 'axis': {'range': [0, 100]}, + 'bar': {'color': "darkred" if complexity_score > 70 else "orange" if complexity_score > 40 else "green"}, + 'steps': [ + {'range': [0, 40], 'color': "lightgreen"}, + {'range': [40, 70], 'color': "yellow"}, + {'range': [70, 100], 'color': "lightcoral"} + ] + } + ), + row=2, col=2 + ) + + fig.update_layout( + title={ + 'text': f"ONNX Performance Metrics
Complexity Score: {complexity_score:.0f}/100", + 'x': 0.5, + 'xanchor': 'center', + 'font': {'size': 22} + }, + height=600, + template="plotly_dark", + showlegend=False + ) + + return fig + + except Exception as e: + return create_styled_error_figure( + "Performance Analysis Error", + f"Could not analyze ONNX model performance.", + f"Error: {str(e)}" + ) \ No newline at end of file -- 2.34.1 From c7c7100d465866d30c08b4a8d4a12bd80b28b026 Mon Sep 17 00:00:00 2001 From: ben Date: Mon, 20 Oct 2025 14:44:51 -0400 Subject: [PATCH 5/8] Formatting fixes --- src/ria_toolkit_oss/viz/onnx.py | 467 ++++++++++-------- src/ria_toolkit_oss/viz/pytorch_state_dict.py | 206 ++++---- src/ria_toolkit_oss/viz/radio_dataset.py | 429 ++++++++-------- 3 files changed, 559 insertions(+), 543 deletions(-) diff --git a/src/ria_toolkit_oss/viz/onnx.py b/src/ria_toolkit_oss/viz/onnx.py index 04e86b9..b92c3e4 100644 --- a/src/ria_toolkit_oss/viz/onnx.py +++ b/src/ria_toolkit_oss/viz/onnx.py @@ -6,18 +6,16 @@ as other ria-toolkit-oss visualization modules. """ from pathlib import Path -from typing import Optional -import plotly.graph_objects as go import plotly.express as px +import plotly.graph_objects as go from plotly.subplots import make_subplots -import pandas as pd -import numpy as np try: import onnx import onnx.helper import onnx.numpy_helper + ONNX_AVAILABLE = True except ImportError: ONNX_AVAILABLE = False @@ -26,33 +24,32 @@ except ImportError: def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure: """Create a professional error figure with Qoherent dark theme styling.""" fig = go.Figure() - + # Create a clean, centered text display using Plotly's text formatting main_text = f"⚠️ {title}

" main_text += f"{message}" - + if suggestion: - main_text += f"

💡 Suggestion:
" + main_text += "

💡 Suggestion:
" main_text += f"{suggestion}" - + # Add the main text annotation fig.add_annotation( text=main_text, - xref="paper", yref="paper", - x=0.5, y=0.5, - xanchor='center', yanchor='middle', + xref="paper", + yref="paper", + x=0.5, + y=0.5, + xanchor="center", + yanchor="middle", showarrow=False, align="center", borderwidth=2, bordercolor="#4a5568", bgcolor="#2d3748", - font=dict( - family="Arial, sans-serif", - size=14, - color="#e2e8f0" - ) + font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"), ) - + # Update layout with dark theme fig.update_layout( title="", @@ -61,13 +58,13 @@ def create_styled_error_figure(title: str, message: str, suggestion: str = None) margin=dict(l=40, r=40, t=40, b=40), plot_bgcolor="#1a202c", paper_bgcolor="#1a202c", - font=dict(color="#e2e8f0") + font=dict(color="#e2e8f0"), ) - + # Remove axes and grid fig.update_xaxes(visible=False) fig.update_yaxes(visible=False) - + return fig @@ -82,78 +79,85 @@ def graph_structure(file_path: Path) -> go.Figure: "ONNX library is required for model analysis.", "Install with: pip install onnx" ) - + try: # Load ONNX model model = onnx.load(str(file_path)) graph = model.graph nodes = graph.node - + if len(nodes) == 0: return create_styled_error_figure( "Empty Model", "This ONNX model contains no operators.", "Please check if the model file is valid." ) - + # Create network diagram data node_info = [] for i, node in enumerate(nodes): - node_info.append({ - 'id': i, - 'name': node.name or f"{node.op_type}_{i}", - 'op_type': node.op_type, - 'inputs': len(node.input), - 'outputs': len(node.output) - }) - + node_info.append( + { + "id": i, + "name": node.name or f"{node.op_type}_{i}", + "op_type": node.op_type, + "inputs": len(node.input), + "outputs": len(node.output), + } + ) + # Create visualization fig = go.Figure() - + # Simple linear layout for now x_positions = list(range(len(node_info))) y_positions = [0] * len(node_info) - + # Add nodes as scatter points - fig.add_trace(go.Scatter( - x=x_positions, - y=y_positions, - mode='markers+text', - marker=dict( - size=[min(max(info['inputs'] + info['outputs'] + 15, 20), 50) for info in node_info], - color=px.colors.qualitative.Set3[:len(node_info)], - opacity=0.8, - line=dict(width=2, color='white') - ), - text=[f"{info['op_type']}" for info in node_info], - textposition="middle center", - textfont=dict(size=10, color="white"), - hovertemplate="%{text}
" + - "Name: %{customdata[0]}
" + - "Inputs: %{customdata[1]}
" + - "Outputs: %{customdata[2]}
" + - "", - customdata=[[info['name'], info['inputs'], info['outputs']] for info in node_info], - name="Operators" - )) - + fig.add_trace( + go.Scatter( + x=x_positions, + y=y_positions, + mode="markers+text", + marker=dict( + size=[min(max(info["inputs"] + info["outputs"] + 15, 20), 50) for info in node_info], + color=px.colors.qualitative.Set3[: len(node_info)], + opacity=0.8, + line=dict(width=2, color="white"), + ), + text=[f"{info['op_type']}" for info in node_info], + textposition="middle center", + textfont=dict(size=10, color="white"), + hovertemplate="%{text}
" + + "Name: %{customdata[0]}
" + + "Inputs: %{customdata[1]}
" + + "Outputs: %{customdata[2]}
" + + "", + customdata=[[info["name"], info["inputs"], info["outputs"]] for info in node_info], + name="Operators", + ) + ) + # Add connecting lines for i in range(len(node_info) - 1): - fig.add_trace(go.Scatter( - x=[x_positions[i], x_positions[i+1]], - y=[y_positions[i], y_positions[i+1]], - mode='lines', - line=dict(color='gray', width=1, dash='dot'), - showlegend=False, - hoverinfo='skip' - )) - + fig.add_trace( + go.Scatter( + x=[x_positions[i], x_positions[i + 1]], + y=[y_positions[i], y_positions[i + 1]], + mode="lines", + line=dict(color="gray", width=1, dash="dot"), + showlegend=False, + hoverinfo="skip", + ) + ) + fig.update_layout( title={ - 'text': f"ONNX Graph Structure
{len(nodes)} Operators", - 'x': 0.5, - 'xanchor': 'center', - 'font': {'size': 22} + "text": ("ONNX Graph Structure
" + f"{len(nodes)} Operators"), + "x": 0.5, + "xanchor": "center", + "font": {"size": 22}, }, xaxis_title="Execution Order", yaxis_title="", @@ -162,15 +166,15 @@ def graph_structure(file_path: Path) -> go.Figure: template="plotly_dark", yaxis=dict(showticklabels=False, showgrid=False), xaxis=dict(showgrid=False), - margin=dict(l=50, r=50, t=80, b=50) + margin=dict(l=50, r=50, t=80, b=50), ) - + return fig - + except Exception as e: return create_styled_error_figure( - "Graph Analysis Error", - f"Could not analyze ONNX model structure.", + "Graph Analysis Error", + "Could not analyze ONNX model structure.", f"Error: {str(e)}" ) @@ -186,76 +190,80 @@ def operator_analysis(file_path: Path) -> go.Figure: "ONNX library is required for operator analysis.", "Install with: pip install onnx" ) - + try: model = onnx.load(str(file_path)) graph = model.graph - + # Count operators op_counts = {} for node in graph.node: op_type = node.op_type op_counts[op_type] = op_counts.get(op_type, 0) + 1 - + if not op_counts: return create_styled_error_figure( "No Operators", "This ONNX model contains no operators to analyze.", - "Please verify the model file is valid." + "Please verify the model file is valid.", ) - + # Sort by frequency sorted_ops = sorted(op_counts.items(), key=lambda x: x[1], reverse=True) - + # Create pie chart and bar chart fig = make_subplots( - rows=2, cols=1, + rows=2, + cols=1, subplot_titles=("Operator Distribution", "Operator Frequency"), - specs=[[{"type": "pie"}], [{"type": "bar"}]] + specs=[[{"type": "pie"}], [{"type": "bar"}]], ) - + # Pie chart for operator distribution op_names, op_values = zip(*sorted_ops) if sorted_ops else ([], []) - + fig.add_trace( go.Pie( labels=list(op_names), values=list(op_values), textinfo="label+percent", textposition="auto", - showlegend=False + showlegend=False, ), - row=1, col=1 + row=1, + col=1, ) - + # Bar chart for frequency fig.add_trace( go.Bar( x=list(op_names), y=list(op_values), - marker_color=px.colors.qualitative.Set3[:len(op_names)], - showlegend=False + marker_color=px.colors.qualitative.Set3[: len(op_names)], + showlegend=False, ), - row=2, col=1 + row=2, + col=1, ) - + fig.update_layout( title={ - 'text': f"ONNX Operator Analysis
{len(op_counts)} Unique Types", - 'x': 0.5, - 'xanchor': 'center', - 'font': {'size': 22} + "text": ("ONNX Operator Analysis
" + f"{len(op_counts)} Unique Types"), + "x": 0.5, + "xanchor": "center", + "font": {"size": 22}, }, height=700, - template="plotly_dark" + template="plotly_dark", ) - + return fig - + except Exception as e: return create_styled_error_figure( "Operator Analysis Error", - f"Could not analyze ONNX operators.", + "Could not analyze ONNX operators.", f"Error: {str(e)}" ) @@ -271,74 +279,76 @@ def model_metadata(file_path: Path) -> go.Figure: "ONNX library is required for metadata analysis.", "Install with: pip install onnx" ) - + try: model = onnx.load(str(file_path)) graph = model.graph - + # Calculate basic statistics total_nodes = len(graph.node) total_inputs = len(graph.input) total_outputs = len(graph.output) total_initializers = len(graph.initializer) - + # Calculate parameter count total_params = 0 for initializer in graph.initializer: try: tensor = onnx.numpy_helper.to_array(initializer) total_params += tensor.size - except: + except Exception: pass # Skip if tensor can't be loaded - + # Get model file size file_size_mb = file_path.stat().st_size / (1024 * 1024) - + # Create metadata display fig = make_subplots( - rows=2, cols=2, + rows=2, + cols=2, subplot_titles=("Model Size", "Architecture", "Inputs/Outputs", "Parameters"), - specs=[[{"type": "indicator"}, {"type": "bar"}], - [{"type": "table"}, {"type": "indicator"}]] + specs=[[{"type": "indicator"}, {"type": "bar"}], [{"type": "table"}, {"type": "indicator"}]], ) - + # Model size indicator fig.add_trace( go.Indicator( mode="number+gauge", value=file_size_mb, - title={'text': "Model Size (MB)"}, - number={'suffix': ' MB', 'valueformat': '.2f'}, + title={"text": "Model Size (MB)"}, + number={"suffix": " MB", "valueformat": ".2f"}, gauge={ - 'axis': {'range': [0, max(100, file_size_mb * 1.5)]}, - 'bar': {'color': "darkblue"}, - 'steps': [ - {'range': [0, 10], 'color': "lightgreen"}, - {'range': [10, 50], 'color': "yellow"}, - {'range': [50, 100], 'color': "orange"} - ] - } + "axis": {"range": [0, max(100, file_size_mb * 1.5)]}, + "bar": {"color": "darkblue"}, + "steps": [ + {"range": [0, 10], "color": "lightgreen"}, + {"range": [10, 50], "color": "yellow"}, + {"range": [50, 100], "color": "orange"}, + ], + }, ), - row=1, col=1 + row=1, + col=1, ) - + # Architecture components arch_data = ["Nodes", "Inputs", "Outputs", "Initializers"] arch_values = [total_nodes, total_inputs, total_outputs, total_initializers] - + fig.add_trace( go.Bar( x=arch_data, y=arch_values, - marker_color=['blue', 'green', 'orange', 'red'], + marker_color=["blue", "green", "orange", "red"], showlegend=False ), - row=1, col=2 + row=1, + col=2, ) - + # I/O Table io_data = [] - + # Add input info for inp in graph.input[:5]: # Limit to first 5 shape = "Unknown" @@ -346,82 +356,99 @@ def model_metadata(file_path: Path) -> go.Figure: if inp.type and inp.type.tensor_type: # Get shape if inp.type.tensor_type.shape: - dims = [str(d.dim_value) if d.dim_value > 0 else "?" - for d in inp.type.tensor_type.shape.dim] + dims = [str(d.dim_value) if d.dim_value > 0 else "?" for d in inp.type.tensor_type.shape.dim] shape = f"[{', '.join(dims)}]" - + # Get data type elem_type = inp.type.tensor_type.elem_type - type_map = {1: 'float32', 2: 'uint8', 3: 'int8', 6: 'int32', - 7: 'int64', 9: 'bool', 10: 'float16', 11: 'double'} - dtype = type_map.get(elem_type, f'type_{elem_type}') - - io_data.append(['Input', inp.name[:20], shape, dtype]) - + type_map = { + 1: "float32", + 2: "uint8", + 3: "int8", + 6: "int32", + 7: "int64", + 9: "bool", + 10: "float16", + 11: "double", + } + dtype = type_map.get(elem_type, f"type_{elem_type}") + + io_data.append(["Input", inp.name[:20], shape, dtype]) + # Add output info for out in graph.output[:5]: # Limit to first 5 shape = "Unknown" dtype = "Unknown" if out.type and out.type.tensor_type: if out.type.tensor_type.shape: - dims = [str(d.dim_value) if d.dim_value > 0 else "?" - for d in out.type.tensor_type.shape.dim] + dims = [str(d.dim_value) if d.dim_value > 0 else "?" for d in out.type.tensor_type.shape.dim] shape = f"[{', '.join(dims)}]" - + elem_type = out.type.tensor_type.elem_type - type_map = {1: 'float32', 2: 'uint8', 3: 'int8', 6: 'int32', - 7: 'int64', 9: 'bool', 10: 'float16', 11: 'double'} - dtype = type_map.get(elem_type, f'type_{elem_type}') - - io_data.append(['Output', out.name[:20], shape, dtype]) - + type_map = { + 1: "float32", + 2: "uint8", + 3: "int8", + 6: "int32", + 7: "int64", + 9: "bool", + 10: "float16", + 11: "double", + } + dtype = type_map.get(elem_type, f"type_{elem_type}") + + io_data.append(["Output", out.name[:20], shape, dtype]) + if io_data: fig.add_trace( go.Table( header=dict( - values=['Type', 'Name', 'Shape', 'Data Type'], - fill_color='lightblue', - align='left' + values=["Type", "Name", "Shape", "Data Type"], + fill_color="lightblue", + align="left" ), cells=dict( values=list(zip(*io_data)), - fill_color='white', - align='left' - ) + fill_color="white", + align="left" + ), ), - row=2, col=1 + row=2, + col=1, ) - + # Parameters indicator fig.add_trace( go.Indicator( mode="number", value=total_params, - title={'text': "Total Parameters"}, - number={'suffix': 'M', 'valueformat': '.2f'}, - number_font_size=30 + title={"text": "Total Parameters"}, + number={"suffix": "M", "valueformat": ".2f"}, + number_font_size=30, ), - row=2, col=2 + row=2, + col=2, ) - + fig.update_layout( title={ - 'text': f"ONNX Model Metadata
{total_params/1e6:.2f}M Parameters", - 'x': 0.5, - 'xanchor': 'center', - 'font': {'size': 22} + "text": ("ONNX Model Metadata
" + f"{total_params/1e6:.2f}M Parameters"), + "x": 0.5, + "xanchor": "center", + "font": {"size": 22}, }, height=600, template="plotly_dark", - showlegend=False + showlegend=False, ) - + return fig - + except Exception as e: return create_styled_error_figure( "Metadata Analysis Error", - f"Could not extract ONNX model metadata.", + "Could not extract ONNX model metadata.", f"Error: {str(e)}" ) @@ -435,124 +462,130 @@ def performance_metrics(file_path: Path) -> go.Figure: return create_styled_error_figure( "ONNX Not Available", "ONNX library is required for performance analysis.", - "Install with: pip install onnx" + "Install with: pip install onnx", ) - + try: model = onnx.load(str(file_path)) graph = model.graph - + # Calculate metrics model_size_bytes = file_path.stat().st_size model_size_mb = model_size_bytes / (1024 * 1024) - + # Count parameters total_params = 0 for initializer in graph.initializer: try: tensor = onnx.numpy_helper.to_array(initializer) total_params += tensor.size - except: + except Exception: pass - + # Estimate memory usage (rough approximation) param_memory_mb = (total_params * 4) / (1024 * 1024) # Assume float32 - + # Count operations by complexity - compute_ops = ['Conv', 'MatMul', 'Gemm', 'LSTM', 'GRU'] - efficient_ops = ['Relu', 'Add', 'Mul', 'BatchNormalization', 'Dropout'] - - compute_count = sum(1 for node in graph.node - if any(op in node.op_type for op in compute_ops)) - efficient_count = sum(1 for node in graph.node - if any(op in node.op_type for op in efficient_ops)) + compute_ops = ["Conv", "MatMul", "Gemm", "LSTM", "GRU"] + efficient_ops = ["Relu", "Add", "Mul", "BatchNormalization", "Dropout"] + + compute_count = sum(1 for node in graph.node if any(op in node.op_type for op in compute_ops)) + efficient_count = sum(1 for node in graph.node if any(op in node.op_type for op in efficient_ops)) total_ops = len(graph.node) other_count = total_ops - compute_count - efficient_count - + # Create performance dashboard fig = make_subplots( - rows=2, cols=2, + rows=2, + cols=2, subplot_titles=("Model Efficiency", "Memory Usage", "Operation Types", "Complexity Score"), - specs=[[{"type": "bar"}, {"type": "bar"}], - [{"type": "pie"}, {"type": "indicator"}]] + specs=[[{"type": "bar"}, {"type": "bar"}], [{"type": "pie"}, {"type": "indicator"}]], ) - + # Model efficiency metrics efficiency_metrics = ["Model Size (MB)", "Parameters (M)", "Total Ops"] - efficiency_values = [model_size_mb, total_params/1e6, total_ops] - + efficiency_values = [model_size_mb, total_params / 1e6, total_ops] + fig.add_trace( go.Bar( x=efficiency_metrics, y=efficiency_values, - marker_color=['blue', 'green', 'orange'], + marker_color=["blue", "green", "orange"], showlegend=False ), - row=1, col=1 + row=1, + col=1, ) - + # Memory usage memory_types = ["Parameters", "Est. Inference"] memory_values = [param_memory_mb, param_memory_mb * 2] # Rough estimate - + fig.add_trace( go.Bar( x=memory_types, y=memory_values, - marker_color=['purple', 'red'], + marker_color=["purple", "red"], showlegend=False ), - row=1, col=2 + row=1, + col=2, ) - + # Operation types pie chart fig.add_trace( go.Pie( - labels=['Compute Ops', 'Efficient Ops', 'Other Ops'], + labels=["Compute Ops", "Efficient Ops", "Other Ops"], values=[compute_count, efficient_count, other_count], - marker_colors=['red', 'green', 'gray'] + marker_colors=["red", "green", "gray"], ), - row=2, col=1 + row=2, + col=1, ) - + # Complexity score (simple heuristic) complexity_score = min(100, (model_size_mb * 10 + total_params / 1e6 * 20 + compute_count)) - + fig.add_trace( go.Indicator( mode="gauge+number", value=complexity_score, - title={'text': "Complexity Score"}, + title={"text": "Complexity Score"}, gauge={ - 'axis': {'range': [0, 100]}, - 'bar': {'color': "darkred" if complexity_score > 70 else "orange" if complexity_score > 40 else "green"}, - 'steps': [ - {'range': [0, 40], 'color': "lightgreen"}, - {'range': [40, 70], 'color': "yellow"}, - {'range': [70, 100], 'color': "lightcoral"} - ] - } + "axis": {"range": [0, 100]}, + "bar": { + "color": "darkred" if complexity_score > 70 else "orange" if complexity_score > 40 else "green" + }, + "steps": [ + {"range": [0, 40], "color": "lightgreen"}, + {"range": [40, 70], "color": "yellow"}, + {"range": [70, 100], "color": "lightcoral"}, + ], + }, ), - row=2, col=2 + row=2, + col=2, ) - + fig.update_layout( title={ - 'text': f"ONNX Performance Metrics
Complexity Score: {complexity_score:.0f}/100", - 'x': 0.5, - 'xanchor': 'center', - 'font': {'size': 22} + "text": ("ONNX Performance Metrics
" + f"" + f"Complexity Score: {complexity_score:.0f}/100"), + "x": 0.5, + "xanchor": "center", + "font": {"size": 22}, }, height=600, template="plotly_dark", - showlegend=False + showlegend=False, ) - + return fig - + except Exception as e: return create_styled_error_figure( "Performance Analysis Error", - f"Could not analyze ONNX model performance.", + "Could not analyze ONNX model performance.", f"Error: {str(e)}" - ) \ No newline at end of file + ) diff --git a/src/ria_toolkit_oss/viz/pytorch_state_dict.py b/src/ria_toolkit_oss/viz/pytorch_state_dict.py index 7db7528..b549662 100644 --- a/src/ria_toolkit_oss/viz/pytorch_state_dict.py +++ b/src/ria_toolkit_oss/viz/pytorch_state_dict.py @@ -1,7 +1,7 @@ -import torch +import numpy as np import plotly.graph_objects as go from plotly.graph_objects import Figure -import numpy as np + def model_summary_plot(state_dict: dict) -> Figure: """Generate a summary plot of the PyTorch model state dict.""" @@ -10,220 +10,212 @@ def model_summary_plot(state_dict: dict) -> Figure: fig = go.Figure() fig.add_annotation( text="No parameters found in state dict", - xref="paper", yref="paper", - x=0.5, y=0.5, showarrow=False, - font=dict(size=16) + xref="paper", + yref="paper", + x=0.5, + y=0.5, + showarrow=False, + font=dict(size=16), ) fig.update_layout( title="Model Layer Parameter Counts", xaxis_title="Layer", yaxis_title="Number of Parameters", - template="plotly_dark" + template="plotly_dark", ) return fig - + # Count parameters by layer type layer_info = [] for key, tensor in state_dict.items(): - if 'weight' in key: + if "weight" in key: try: - layer_name = key.replace('.weight', '') - param_count = tensor.numel() if hasattr(tensor, 'numel') else len(tensor.flatten()) if hasattr(tensor, 'flatten') else 0 - shape = list(tensor.shape) if hasattr(tensor, 'shape') else [len(tensor)] if hasattr(tensor, '__len__') else [] - layer_info.append({ - 'layer': layer_name, - 'parameters': param_count, - 'shape': shape - }) + layer_name = key.replace(".weight", "") + param_count = ( + tensor.numel() + if hasattr(tensor, "numel") + else len(tensor.flatten()) if hasattr(tensor, "flatten") else 0 + ) + shape = ( + list(tensor.shape) + if hasattr(tensor, "shape") + else [len(tensor)] if hasattr(tensor, "__len__") else [] + ) + layer_info.append({"layer": layer_name, "parameters": param_count, "shape": shape}) except Exception as e: print(f"Warning: Could not process layer {key}: {e}") continue - + if not layer_info: # Handle case where no weight layers found fig = go.Figure() fig.add_annotation( text="No weight layers found in state dict", - xref="paper", yref="paper", - x=0.5, y=0.5, showarrow=False, - font=dict(size=16) + xref="paper", + yref="paper", + x=0.5, + y=0.5, + showarrow=False, + font=dict(size=16), ) fig.update_layout( title="Model Layer Parameter Counts", xaxis_title="Layer", yaxis_title="Number of Parameters", - template="plotly_dark" + template="plotly_dark", ) return fig - + # Create bar chart of parameter counts - fig = go.Figure(data=[ - go.Bar( - x=[info['layer'] for info in layer_info], - y=[info['parameters'] for info in layer_info], - text=[f"Shape: {info['shape']}" for info in layer_info], - textposition='auto', - ) - ]) - + fig = go.Figure( + data=[ + go.Bar( + x=[info["layer"] for info in layer_info], + y=[info["parameters"] for info in layer_info], + text=[f"Shape: {info['shape']}" for info in layer_info], + textposition="auto", + ) + ] + ) + fig.update_layout( title="Model Layer Parameter Counts", xaxis_title="Layer", yaxis_title="Number of Parameters", - template="plotly_dark" + template="plotly_dark", ) - + return fig + def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure: """Visualize weights for a specific layer.""" if not state_dict: fig = go.Figure() fig.add_annotation( - text="No data in state dict", - xref="paper", yref="paper", - x=0.5, y=0.5, showarrow=False, - font=dict(size=16) - ) - fig.update_layout( - title="Layer Weights", - template="plotly_dark" + text="No data in state dict", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=16) ) + fig.update_layout(title="Layer Weights", template="plotly_dark") return fig - + if layer_name is None: # Get first weight tensor - weight_keys = [k for k in state_dict.keys() if 'weight' in k] + weight_keys = [k for k in state_dict.keys() if "weight" in k] if not weight_keys: fig = go.Figure() fig.add_annotation( text="No weight tensors found in state dict", - xref="paper", yref="paper", - x=0.5, y=0.5, showarrow=False, - font=dict(size=16) - ) - fig.update_layout( - title="Layer Weights", - template="plotly_dark" + xref="paper", + yref="paper", + x=0.5, + y=0.5, + showarrow=False, + font=dict(size=16), ) + fig.update_layout(title="Layer Weights", template="plotly_dark") return fig layer_name = weight_keys[0] - + try: weights = state_dict[layer_name] - + # Convert to numpy if it's a torch tensor - if hasattr(weights, 'numpy'): - weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy() - elif hasattr(weights, 'cpu'): + if hasattr(weights, "numpy"): + weights_np = weights.detach().numpy() if hasattr(weights, "detach") else weights.numpy() + elif hasattr(weights, "cpu"): weights_np = weights.cpu().detach().numpy() else: weights_np = np.array(weights) - + # For 2D weights, create heatmap if len(weights_np.shape) == 2: - fig = go.Figure(data=go.Heatmap( - z=weights_np, - colorscale='RdBu', - zmid=0 - )) - fig.update_layout( - title=f"Weights Heatmap: {layer_name}", - template="plotly_dark" - ) + fig = go.Figure(data=go.Heatmap(z=weights_np, colorscale="RdBu", zmid=0)) + fig.update_layout(title=f"Weights Heatmap: {layer_name}", template="plotly_dark") else: # For other shapes, flatten and show histogram flat_weights = weights_np.flatten() fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)]) - fig.update_layout( - title=f"Weight Distribution: {layer_name}", - template="plotly_dark" - ) - + fig.update_layout(title=f"Weight Distribution: {layer_name}", template="plotly_dark") + return fig - + except Exception as e: fig = go.Figure() fig.add_annotation( text=f"Error processing layer {layer_name}: {str(e)}", - xref="paper", yref="paper", - x=0.5, y=0.5, showarrow=False, - font=dict(size=14) - ) - fig.update_layout( - title="Layer Weights - Error", - template="plotly_dark" + xref="paper", + yref="paper", + x=0.5, + y=0.5, + showarrow=False, + font=dict(size=14), ) + fig.update_layout(title="Layer Weights - Error", template="plotly_dark") return fig + def weight_distribution_plot(state_dict: dict) -> Figure: """Show distribution of weights across all layers.""" if not state_dict: fig = go.Figure() fig.add_annotation( - text="No data in state dict", - xref="paper", yref="paper", - x=0.5, y=0.5, showarrow=False, - font=dict(size=16) + text="No data in state dict", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=16) ) fig.update_layout( title="Overall Weight Distribution", xaxis_title="Weight Value", yaxis_title="Frequency", - template="plotly_dark" + template="plotly_dark", ) return fig - + all_weights = [] layer_names = [] - + for key, tensor in state_dict.items(): - if 'weight' in key: + if "weight" in key: try: # Convert to numpy if it's a torch tensor - if hasattr(tensor, 'numpy'): - weights_np = tensor.detach().numpy() if hasattr(tensor, 'detach') else tensor.numpy() - elif hasattr(tensor, 'cpu'): + if hasattr(tensor, "numpy"): + weights_np = tensor.detach().numpy() if hasattr(tensor, "detach") else tensor.numpy() + elif hasattr(tensor, "cpu"): weights_np = tensor.cpu().detach().numpy() else: weights_np = np.array(tensor) - + flat_weights = weights_np.flatten() all_weights.extend(flat_weights) layer_names.extend([key] * len(flat_weights)) except Exception as e: print(f"Warning: Could not process weights for layer {key}: {e}") continue - + if not all_weights: fig = go.Figure() fig.add_annotation( text="No weight data found in state dict", - xref="paper", yref="paper", - x=0.5, y=0.5, showarrow=False, - font=dict(size=16) + xref="paper", + yref="paper", + x=0.5, + y=0.5, + showarrow=False, + font=dict(size=16), ) fig.update_layout( title="Overall Weight Distribution", xaxis_title="Weight Value", yaxis_title="Frequency", - template="plotly_dark" + template="plotly_dark", ) return fig - - fig = go.Figure(data=[ - go.Histogram( - x=all_weights, - nbinsx=100, - name="All Weights" - ) - ]) - + + fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=100, name="All Weights")]) + fig.update_layout( title="Overall Weight Distribution", xaxis_title="Weight Value", yaxis_title="Frequency", - template="plotly_dark" + template="plotly_dark", ) - - return fig \ No newline at end of file + + return fig diff --git a/src/ria_toolkit_oss/viz/radio_dataset.py b/src/ria_toolkit_oss/viz/radio_dataset.py index edc5004..cae4084 100644 --- a/src/ria_toolkit_oss/viz/radio_dataset.py +++ b/src/ria_toolkit_oss/viz/radio_dataset.py @@ -6,7 +6,6 @@ import random from typing import Optional import numpy as np -import pandas as pd import plotly.express as px import plotly.graph_objects as go from plotly.graph_objects import Figure @@ -16,33 +15,32 @@ from plotly.subplots import make_subplots def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> Figure: """Create a professional error figure with Qoherent dark theme styling.""" fig = go.Figure() - + # Create a clean, centered text display using Plotly's text formatting main_text = f"⚠️ {title}

" main_text += f"{message}" - + if suggestion: - main_text += f"

💡 Suggestion:
" + main_text += "

💡 Suggestion:
" main_text += f"{suggestion}" - + # Add the main text annotation fig.add_annotation( text=main_text, - xref="paper", yref="paper", - x=0.5, y=0.5, - xanchor='center', yanchor='middle', + xref="paper", + yref="paper", + x=0.5, + y=0.5, + xanchor="center", + yanchor="middle", showarrow=False, align="center", borderwidth=2, bordercolor="#4a5568", bgcolor="#2d3748", - font=dict( - family="Arial, sans-serif", - size=14, - color="#e2e8f0" - ) + font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"), ) - + # Update layout with dark theme fig.update_layout( title="", @@ -51,13 +49,13 @@ def create_styled_error_figure(title: str, message: str, suggestion: str = None) margin=dict(l=40, r=40, t=40, b=40), plot_bgcolor="#1a202c", paper_bgcolor="#1a202c", - font=dict(color="#e2e8f0") + font=dict(color="#e2e8f0"), ) - + # Remove axes and grid fig.update_xaxes(visible=False) fig.update_yaxes(visible=False) - + return fig @@ -67,37 +65,37 @@ def _check_dataset_compatibility(dataset, plot_type: str) -> tuple[bool, str]: """ try: metadata = dataset.metadata - + if len(metadata) == 0: return False, "Dataset is empty" - + if plot_type == "class_distribution": # Check if we have any categorical columns - categorical_cols = [col for col in metadata.columns if metadata[col].dtype == 'object'] + categorical_cols = [col for col in metadata.columns if metadata[col].dtype == "object"] alternatives = ["class", "label", "modulation", "impairment", "use_case", "category", "labels"] - + has_class_col = any(alt in metadata.columns for alt in alternatives) has_categorical = len(categorical_cols) > 0 - + if not has_class_col and not has_categorical: return False, "No categorical columns found for class distribution" - + elif plot_type == "sample_spectrogram": # Check if we can generate a valid spectrogram if len(metadata) < 1: return False, "No samples available for spectrogram" - + # Check if we can access sample data (basic test) try: - sample_data = dataset[0] if hasattr(dataset, '__getitem__') else None + sample_data = dataset[0] if hasattr(dataset, "__getitem__") else None if sample_data is None or len(sample_data) < 32: return False, "Insufficient sample data for spectrogram (need at least 32 points)" except Exception: # If we can't access data, we'll rely on synthetic data generation pass - + return True, "" - + except Exception as e: return False, f"Dataset compatibility check failed: {str(e)}" @@ -111,11 +109,11 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure: return create_styled_error_figure( "Dataset Not Compatible", "This dataset doesn't have categorical labels needed for class distribution analysis.", - "Try using the Dataset Overview widget to explore the available data columns." + "Try using the Dataset Overview widget to explore the available data columns.", ) - + metadata = dataset.metadata - + # Find the class column if class_key not in metadata.columns: # Try common alternatives @@ -127,47 +125,44 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure: else: # Use first categorical column for col in metadata.columns: - if metadata[col].dtype == 'object' or metadata[col].nunique() < 50: + if metadata[col].dtype == "object" or metadata[col].nunique() < 50: class_key = col break - + if class_key not in metadata.columns: return create_styled_error_figure( "No Class Labels Found", "This dataset contains numerical data without categorical labels.", - "Try using the Dataset Overview widget for data analysis, or check if your dataset has hidden categorical columns." + ("Try using the Dataset Overview widget for data analysis, " + "or check if your dataset has hidden categorical columns."), ) - + # Count examples per class (limit to top 20 for performance) class_counts = metadata[class_key].value_counts() if len(class_counts) > 20: class_counts = class_counts.head(20) - + class_counts = class_counts.sort_index() - + # Create simple bar plot - fig = px.bar( - x=class_counts.index, - y=class_counts.values, - title=f'Class Distribution: {class_key.title()}' - ) - - fig.update_traces(texttemplate='%{y}', textposition='outside') + fig = px.bar(x=class_counts.index, y=class_counts.values, title=f"Class Distribution: {class_key.title()}") + + fig.update_traces(texttemplate="%{y}", textposition="outside") fig.update_layout( xaxis_title=class_key.title(), - yaxis_title='Number of Examples', + yaxis_title="Number of Examples", showlegend=False, height=400, - template="plotly_dark" + template="plotly_dark", ) - + return fig - + except Exception as e: return create_styled_error_figure( "Class Distribution Error", - f"An error occurred while generating the class distribution plot.", - f"Technical details: {str(e)}" + "An error occurred while generating the class distribution plot.", + f"Technical details: {str(e)}", ) @@ -176,96 +171,84 @@ def dataset_overview_plot(dataset) -> Figure: try: metadata = dataset.metadata total_examples = len(metadata) - + # Create subplot with multiple charts - + # Determine subplot titles based on data type - categorical_cols = [col for col in metadata.columns if metadata[col].dtype == 'object'] - numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ['int64', 'float64']] - + categorical_cols = [col for col in metadata.columns if metadata[col].dtype == "object"] + numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ["int64", "float64"]] + dist_title = "Value Distribution" if categorical_cols else "Data Distribution" - + fig = make_subplots( - rows=2, cols=2, + rows=2, + cols=2, subplot_titles=("Dataset Size", "Data Types", dist_title, "Statistics Summary"), - specs=[[{"type": "indicator"}, {"type": "bar"}], - [{"type": "histogram" if not categorical_cols else "bar"}, {"type": "table"}]] + specs=[ + [{"type": "indicator"}, {"type": "bar"}], + [{"type": "histogram" if not categorical_cols else "bar"}, {"type": "table"}], + ], ) - + # Top left: Dataset size indicator fig.add_trace( go.Indicator( - mode="number", - value=total_examples, - title={"text": "Total Examples"}, - number={"font": {"size": 40}} + mode="number", value=total_examples, title={"text": "Total Examples"}, number={"font": {"size": 40}} ), - row=1, col=1 + row=1, + col=1, ) - + # Top right: Data types distribution dtype_counts = metadata.dtypes.value_counts() fig.add_trace( go.Bar( - x=[str(dt) for dt in dtype_counts.index], - y=dtype_counts.values, - name="Data Types", - showlegend=False + x=[str(dt) for dt in dtype_counts.index], y=dtype_counts.values, name="Data Types", showlegend=False ), - row=1, col=2 + row=1, + col=2, ) - + # Bottom left: Show distribution of numeric columns or categorical if available - categorical_cols = [col for col in metadata.columns if metadata[col].dtype == 'object'] - numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ['int64', 'float64']] - + categorical_cols = [col for col in metadata.columns if metadata[col].dtype == "object"] + numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ["int64", "float64"]] + if categorical_cols: col = categorical_cols[0] # Show first categorical column value_counts = metadata[col].value_counts().head(10) fig.add_trace( - go.Bar( - x=value_counts.index, - y=value_counts.values, - name=f"{col} Distribution", - showlegend=False - ), - row=2, col=1 + go.Bar(x=value_counts.index, y=value_counts.values, name=f"{col} Distribution", showlegend=False), + row=2, + col=1, ) elif numeric_cols: # Show histogram of first numeric column col = numeric_cols[0] fig.add_trace( - go.Histogram( - x=metadata[col], - name=f"{col} Distribution", - showlegend=False, - nbinsx=20 - ), - row=2, col=1 + go.Histogram(x=metadata[col], name=f"{col} Distribution", showlegend=False, nbinsx=20), row=2, col=1 ) - + # Bottom right: Basic statistics table stats_data = [] - display_cols = (numeric_cols[:5] if len(numeric_cols) > 0 else metadata.columns[:5]) - + display_cols = numeric_cols[:5] if len(numeric_cols) > 0 else metadata.columns[:5] + for col in display_cols: - if metadata[col].dtype in ['int64', 'float64']: - stats_data.append([ - col[:15] + "..." if len(col) > 15 else col, # Truncate long column names - f"{metadata[col].mean():.3f}", - f"{metadata[col].std():.3f}", - f"{metadata[col].min():.3f}", - f"{metadata[col].max():.3f}" - ]) + if metadata[col].dtype in ["int64", "float64"]: + stats_data.append( + [ + col[:15] + "..." if len(col) > 15 else col, # Truncate long column names + f"{metadata[col].mean():.3f}", + f"{metadata[col].std():.3f}", + f"{metadata[col].min():.3f}", + f"{metadata[col].max():.3f}", + ] + ) else: unique_count = metadata[col].nunique() - stats_data.append([ - col[:15] + "..." if len(col) > 15 else col, - "N/A", "N/A", - f"{unique_count} unique", - "N/A" - ]) - + stats_data.append( + [col[:15] + "..." if len(col) > 15 else col, "N/A", "N/A", f"{unique_count} unique", "N/A"] + ) + if stats_data: fig.add_trace( go.Table( @@ -273,41 +256,127 @@ def dataset_overview_plot(dataset) -> Figure: values=["Column", "Mean", "Std", "Min/Unique", "Max"], fill_color="rgba(30, 30, 30, 0.8)", align="center", - font=dict(color="white", size=12) + font=dict(color="white", size=12), ), cells=dict( values=list(zip(*stats_data)), fill_color="rgba(50, 50, 50, 0.6)", align="center", - font=dict(color="white", size=11) - ) + font=dict(color="white", size=11), + ), ), - row=2, col=2 + row=2, + col=2, ) - + # Create informative title total_cols = len(metadata.columns) title = f"Dataset Overview - {total_examples} samples, {total_cols} columns" if total_cols > 5: - title += f" (showing first 5)" - - fig.update_layout( - title=title, - height=600, - showlegend=False, - template="plotly_dark" - ) - + title += " (showing first 5)" + + fig.update_layout(title=title, height=600, showlegend=False, template="plotly_dark") + return fig - + except Exception as e: return create_styled_error_figure( "Dataset Overview Error", "An error occurred while generating the dataset overview.", - f"Technical details: {str(e)}" + f"Technical details: {str(e)}", ) +def _find_class_column(metadata, class_key: str) -> str: + """Find the appropriate class column in metadata.""" + if class_key in metadata.columns: + return class_key + + alternatives = ["class", "label", "modulation", "impairment", "use_case"] + for alt in alternatives: + if alt in metadata.columns: + return alt + return class_key + + +def _get_sample_data(dataset, sample_idx: int): + """Get sample data from dataset, with synthetic fallback.""" + try: + return dataset[sample_idx] + except Exception: + # Generate synthetic signal based on class + n_samples = 1024 + t = np.linspace(0, 1, n_samples) + freq = 0.1 + 0.05 * sample_idx % 5 # Vary frequency by sample + sample_data = np.exp(1j * 2 * np.pi * freq * t) + # Add some noise + sample_data += 0.1 * (np.random.randn(n_samples) + 1j * np.random.randn(n_samples)) + return sample_data + + +def _calculate_spectrogram_params(n_samples: int) -> tuple[int, int, int, int]: + """Calculate spectrogram parameters based on sample length.""" + if n_samples < 32: + raise ValueError(f"Insufficient data: need at least 32 samples, got {n_samples}") + + nperseg = min(256, max(32, n_samples // 4)) + hop_length = max(1, nperseg // 2) + + # Adjust for very short signals + if n_samples < nperseg: + nperseg = n_samples + hop_length = 1 + + n_frames = max(1, (n_samples - nperseg) // hop_length + 1) + freq_bins = max(1, nperseg // 2) + + return nperseg, hop_length, n_frames, freq_bins + + +def _compute_spectrogram(sample_data, nperseg: int, hop_length: int, n_frames: int, freq_bins: int): + """Compute spectrogram using FFT.""" + n_samples = len(sample_data) + Sxx = np.zeros((freq_bins, n_frames)) + + for i in range(n_frames): + start_idx = i * hop_length + end_idx = min(start_idx + nperseg, n_samples) + + if end_idx > start_idx: + windowed = sample_data[start_idx:end_idx] + + # Pad if necessary to maintain nperseg size + if len(windowed) < nperseg: + windowed = np.pad(windowed, (0, nperseg - len(windowed)), mode="constant") + + fft_result = np.fft.fft(windowed) + Sxx[:, i] = np.abs(fft_result[:freq_bins]) ** 2 + + return Sxx + + +def _create_spectrogram_figure(Sxx, n_frames: int, hop_length: int, n_samples: int, freq_bins: int, + sample_idx: int, class_key: str, sample_metadata) -> Figure: + """Create the plotly figure for the spectrogram.""" + # Convert to dB + Sxx_db = 10 * np.log10(Sxx + 1e-10) + + # Create time and frequency vectors + t = np.arange(n_frames) * hop_length / max(1, n_samples) + f = np.linspace(0, 0.5, freq_bins) + + # Create plot + fig = go.Figure(data=go.Heatmap(z=Sxx_db, x=t, y=f, colorscale="viridis", colorbar=dict(title="Power (dB)"))) + + # Add title with metadata + title = f"Sample Spectrogram (Index: {sample_idx})" + if class_key in sample_metadata: + title += f" - {class_key}: {sample_metadata[class_key]}" + + fig.update_layout(title=title, xaxis_title="Time", yaxis_title="Frequency", height=400, template="plotly_dark") + return fig + + def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx: Optional[int] = None) -> Figure: """Generate a spectrogram plot from a sample in the dataset.""" try: @@ -317,114 +386,36 @@ def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx: return create_styled_error_figure( "Spectrogram Not Available", "This dataset doesn't have sufficient signal data for spectrogram visualization.", - "Ensure your dataset contains complex-valued signal samples with at least 32 data points per sample." + "Ensure your dataset contains complex-valued signal samples with at least 32 data points per sample.", ) - + metadata = dataset.metadata - if len(metadata) == 0: raise ValueError("Dataset is empty") - - # Find class column - if class_key not in metadata.columns: - alternatives = ["class", "label", "modulation", "impairment", "use_case"] - for alt in alternatives: - if alt in metadata.columns: - class_key = alt - break - - # Select sample + + # Find class column and select sample + class_key = _find_class_column(metadata, class_key) if sample_idx is None: sample_idx = random.randint(0, len(metadata) - 1) - sample_metadata = metadata.iloc[sample_idx] - - # Try to get actual sample data, fall back to synthetic - try: - sample_data = dataset[sample_idx] - except: - # Generate synthetic signal based on class - n_samples = 1024 - t = np.linspace(0, 1, n_samples) - freq = 0.1 + 0.05 * sample_idx % 5 # Vary frequency by sample - sample_data = np.exp(1j * 2 * np.pi * freq * t) - # Add some noise - sample_data += 0.1 * (np.random.randn(n_samples) + 1j * np.random.randn(n_samples)) - - # Ensure complex data + + # Get sample data and ensure it's complex + sample_data = _get_sample_data(dataset, sample_idx) if not np.iscomplexobj(sample_data): sample_data = sample_data.astype(complex) - - # Simple FFT-based spectrogram + + # Calculate spectrogram parameters and compute spectrogram n_samples = len(sample_data) - - # Ensure minimum viable data size - if n_samples < 32: - raise ValueError(f"Insufficient data: need at least 32 samples, got {n_samples}") - - nperseg = min(256, max(32, n_samples // 4)) - - # Create spectrogram using numpy (no scipy dependency) - hop_length = max(1, nperseg // 2) # Prevent zero hop_length - - # Ensure we can create at least one frame - if n_samples < nperseg: - nperseg = n_samples - hop_length = 1 - - n_frames = max(1, (n_samples - nperseg) // hop_length + 1) - - freq_bins = max(1, nperseg // 2) # Prevent zero frequency bins - Sxx = np.zeros((freq_bins, n_frames)) - - for i in range(n_frames): - start_idx = i * hop_length - end_idx = min(start_idx + nperseg, n_samples) # Prevent index overflow - - if end_idx > start_idx: # Ensure we have data to process - windowed = sample_data[start_idx:end_idx] - - # Pad if necessary to maintain nperseg size - if len(windowed) < nperseg: - windowed = np.pad(windowed, (0, nperseg - len(windowed)), mode='constant') - - fft_result = np.fft.fft(windowed) - Sxx[:, i] = np.abs(fft_result[:freq_bins]) ** 2 - - # Convert to dB - Sxx_db = 10 * np.log10(Sxx + 1e-10) - - # Create time and frequency vectors - t = np.arange(n_frames) * hop_length / max(1, n_samples) # Prevent division by zero - f = np.linspace(0, 0.5, freq_bins) - - # Create plot - fig = go.Figure(data=go.Heatmap( - z=Sxx_db, - x=t, - y=f, - colorscale='viridis', - colorbar=dict(title="Power (dB)") - )) - - # Add title with metadata - title = f"Sample Spectrogram (Index: {sample_idx})" - if class_key in sample_metadata: - title += f" - {class_key}: {sample_metadata[class_key]}" - - fig.update_layout( - title=title, - xaxis_title="Time", - yaxis_title="Frequency", - height=400, - template="plotly_dark" - ) - - return fig - + nperseg, hop_length, n_frames, freq_bins = _calculate_spectrogram_params(n_samples) + Sxx = _compute_spectrogram(sample_data, nperseg, hop_length, n_frames, freq_bins) + + # Create and return the figure + return _create_spectrogram_figure(Sxx, n_frames, hop_length, n_samples, freq_bins, + sample_idx, class_key, sample_metadata) + except Exception as e: return create_styled_error_figure( "Spectrogram Error", "An error occurred while generating the spectrogram plot.", - f"Technical details: {str(e)}" - ) \ No newline at end of file + f"Technical details: {str(e)}", + ) -- 2.34.1 From c06e58f5d6576e7f299017478cf6a7a6aa1a8cb1 Mon Sep 17 00:00:00 2001 From: ben Date: Tue, 21 Oct 2025 11:48:40 -0400 Subject: [PATCH 6/8] Formatting Fixes --- src/ria_toolkit_oss/viz/pytorch_state_dict.py | 259 +++++++++--------- 1 file changed, 127 insertions(+), 132 deletions(-) diff --git a/src/ria_toolkit_oss/viz/pytorch_state_dict.py b/src/ria_toolkit_oss/viz/pytorch_state_dict.py index b549662..9bceb1c 100644 --- a/src/ria_toolkit_oss/viz/pytorch_state_dict.py +++ b/src/ria_toolkit_oss/viz/pytorch_state_dict.py @@ -1,188 +1,190 @@ -import numpy as np import plotly.graph_objects as go from plotly.graph_objects import Figure +import numpy as np + + +def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure: + """Create a professional error figure with Qoherent dark theme styling.""" + fig = go.Figure() + + # Create a clean, centered text display using Plotly's text formatting + main_text = f"⚠️ {title}

" + main_text += f"{message}" + + if suggestion: + main_text += "

💡 Suggestion:
" + main_text += f"{suggestion}" + + # Add the main text annotation + fig.add_annotation( + text=main_text, + xref="paper", + yref="paper", + x=0.5, + y=0.5, + xanchor="center", + yanchor="middle", + showarrow=False, + align="center", + borderwidth=2, + bordercolor="#4a5568", + bgcolor="#2d3748", + font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"), + ) + + # Update layout with dark theme + fig.update_layout( + title="", + height=400, + template="plotly_dark", + margin=dict(l=40, r=40, t=40, b=40), + plot_bgcolor="#1a202c", + paper_bgcolor="#1a202c", + font=dict(color="#e2e8f0"), + ) + + # Remove axes and grid + fig.update_xaxes(visible=False) + fig.update_yaxes(visible=False) + + return fig def model_summary_plot(state_dict: dict) -> Figure: """Generate a summary plot of the PyTorch model state dict.""" if not state_dict: - # Handle empty state dict - fig = go.Figure() - fig.add_annotation( - text="No parameters found in state dict", - xref="paper", - yref="paper", - x=0.5, - y=0.5, - showarrow=False, - font=dict(size=16), + return create_styled_error_figure( + "Empty State Dict", + "No parameters found in state dict", + "Ensure the model state dictionary contains weight parameters" ) - fig.update_layout( - title="Model Layer Parameter Counts", - xaxis_title="Layer", - yaxis_title="Number of Parameters", - template="plotly_dark", - ) - return fig - # Count parameters by layer type layer_info = [] for key, tensor in state_dict.items(): - if "weight" in key: + if 'weight' in key: try: - layer_name = key.replace(".weight", "") + layer_name = key.replace('.weight', '') param_count = ( - tensor.numel() - if hasattr(tensor, "numel") - else len(tensor.flatten()) if hasattr(tensor, "flatten") else 0 + tensor.numel() if hasattr(tensor, 'numel') + else len(tensor.flatten()) if hasattr(tensor, 'flatten') + else 0 ) shape = ( - list(tensor.shape) - if hasattr(tensor, "shape") - else [len(tensor)] if hasattr(tensor, "__len__") else [] + list(tensor.shape) if hasattr(tensor, 'shape') + else [len(tensor)] if hasattr(tensor, '__len__') + else [] ) - layer_info.append({"layer": layer_name, "parameters": param_count, "shape": shape}) + layer_info.append({ + 'layer': layer_name, + 'parameters': param_count, + 'shape': shape + }) except Exception as e: print(f"Warning: Could not process layer {key}: {e}") continue - if not layer_info: - # Handle case where no weight layers found - fig = go.Figure() - fig.add_annotation( - text="No weight layers found in state dict", - xref="paper", - yref="paper", - x=0.5, - y=0.5, - showarrow=False, - font=dict(size=16), + return create_styled_error_figure( + "No Weight Layers Found", + "No weight layers found in state dict", + "Ensure the state dictionary contains layers with '.weight' parameters" ) - fig.update_layout( - title="Model Layer Parameter Counts", - xaxis_title="Layer", - yaxis_title="Number of Parameters", - template="plotly_dark", - ) - return fig - # Create bar chart of parameter counts - fig = go.Figure( - data=[ - go.Bar( - x=[info["layer"] for info in layer_info], - y=[info["parameters"] for info in layer_info], - text=[f"Shape: {info['shape']}" for info in layer_info], - textposition="auto", - ) - ] - ) - + fig = go.Figure(data=[ + go.Bar( + x=[info['layer'] for info in layer_info], + y=[info['parameters'] for info in layer_info], + text=[f"Shape: {info['shape']}" for info in layer_info], + textposition='auto', + ) + ]) fig.update_layout( title="Model Layer Parameter Counts", xaxis_title="Layer", yaxis_title="Number of Parameters", - template="plotly_dark", + template="plotly_dark" ) - return fig def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure: """Visualize weights for a specific layer.""" if not state_dict: - fig = go.Figure() - fig.add_annotation( - text="No data in state dict", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=16) + return create_styled_error_figure( + "Empty State Dict", + "No data in state dict", + "Ensure the model state dictionary contains data" ) - fig.update_layout(title="Layer Weights", template="plotly_dark") - return fig - if layer_name is None: # Get first weight tensor - weight_keys = [k for k in state_dict.keys() if "weight" in k] + weight_keys = [k for k in state_dict.keys() if 'weight' in k] if not weight_keys: - fig = go.Figure() - fig.add_annotation( - text="No weight tensors found in state dict", - xref="paper", - yref="paper", - x=0.5, - y=0.5, - showarrow=False, - font=dict(size=16), + return create_styled_error_figure( + "No Weight Tensors Found", + "No weight tensors found in state dict", + "Ensure the state dictionary contains layers with '.weight' parameters" ) - fig.update_layout(title="Layer Weights", template="plotly_dark") - return fig layer_name = weight_keys[0] - try: weights = state_dict[layer_name] - # Convert to numpy if it's a torch tensor - if hasattr(weights, "numpy"): - weights_np = weights.detach().numpy() if hasattr(weights, "detach") else weights.numpy() - elif hasattr(weights, "cpu"): + if hasattr(weights, 'numpy'): + weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy() + elif hasattr(weights, 'cpu'): weights_np = weights.cpu().detach().numpy() else: weights_np = np.array(weights) - # For 2D weights, create heatmap if len(weights_np.shape) == 2: - fig = go.Figure(data=go.Heatmap(z=weights_np, colorscale="RdBu", zmid=0)) - fig.update_layout(title=f"Weights Heatmap: {layer_name}", template="plotly_dark") + fig = go.Figure(data=go.Heatmap( + z=weights_np, + colorscale='RdBu', + zmid=0 + )) + fig.update_layout( + title=f"Weights Heatmap: {layer_name}", + template="plotly_dark" + ) else: # For other shapes, flatten and show histogram flat_weights = weights_np.flatten() fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)]) - fig.update_layout(title=f"Weight Distribution: {layer_name}", template="plotly_dark") + fig.update_layout( + title=f"Weight Distribution: {layer_name}", + template="plotly_dark" + ) return fig except Exception as e: - fig = go.Figure() - fig.add_annotation( - text=f"Error processing layer {layer_name}: {str(e)}", - xref="paper", - yref="paper", - x=0.5, - y=0.5, - showarrow=False, - font=dict(size=14), + return create_styled_error_figure( + "Layer Processing Error", + f"Error processing layer {layer_name}: {str(e)}", + "Check that the layer name exists and contains valid tensor data" ) - fig.update_layout(title="Layer Weights - Error", template="plotly_dark") - return fig def weight_distribution_plot(state_dict: dict) -> Figure: """Show distribution of weights across all layers.""" if not state_dict: - fig = go.Figure() - fig.add_annotation( - text="No data in state dict", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=16) + return create_styled_error_figure( + "Empty State Dict", + "No data in state dict", + "Ensure the model state dictionary contains data" ) - fig.update_layout( - title="Overall Weight Distribution", - xaxis_title="Weight Value", - yaxis_title="Frequency", - template="plotly_dark", - ) - return fig all_weights = [] layer_names = [] for key, tensor in state_dict.items(): - if "weight" in key: + if 'weight' in key: try: # Convert to numpy if it's a torch tensor - if hasattr(tensor, "numpy"): - weights_np = tensor.detach().numpy() if hasattr(tensor, "detach") else tensor.numpy() - elif hasattr(tensor, "cpu"): + if hasattr(tensor, 'numpy'): + weights_np = tensor.detach().numpy() if hasattr(tensor, 'detach') else tensor.numpy() + elif hasattr(tensor, 'cpu'): weights_np = tensor.cpu().detach().numpy() else: weights_np = np.array(tensor) - flat_weights = weights_np.flatten() all_weights.extend(flat_weights) layer_names.extend([key] * len(flat_weights)) @@ -191,31 +193,24 @@ def weight_distribution_plot(state_dict: dict) -> Figure: continue if not all_weights: - fig = go.Figure() - fig.add_annotation( - text="No weight data found in state dict", - xref="paper", - yref="paper", - x=0.5, - y=0.5, - showarrow=False, - font=dict(size=16), + return create_styled_error_figure( + "No Weight Data Found", + "No weight data found in state dict", + "Ensure the state dictionary contains layers with '.weight' parameters" ) - fig.update_layout( - title="Overall Weight Distribution", - xaxis_title="Weight Value", - yaxis_title="Frequency", - template="plotly_dark", - ) - return fig - fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=100, name="All Weights")]) + fig = go.Figure(data=[ + go.Histogram( + x=all_weights, + nbinsx=100, + name="All Weights" + ) + ]) fig.update_layout( title="Overall Weight Distribution", xaxis_title="Weight Value", yaxis_title="Frequency", - template="plotly_dark", + template="plotly_dark" ) - return fig -- 2.34.1 From 4872eea1161e63d2cfa6bf35518614eeadb9027c Mon Sep 17 00:00:00 2001 From: ben Date: Wed, 22 Oct 2025 11:41:37 -0400 Subject: [PATCH 7/8] more formatting fixes --- src/ria_toolkit_oss/viz/pytorch_state_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ria_toolkit_oss/viz/pytorch_state_dict.py b/src/ria_toolkit_oss/viz/pytorch_state_dict.py index 9bceb1c..578ebd0 100644 --- a/src/ria_toolkit_oss/viz/pytorch_state_dict.py +++ b/src/ria_toolkit_oss/viz/pytorch_state_dict.py @@ -1,6 +1,6 @@ +import numpy as np import plotly.graph_objects as go from plotly.graph_objects import Figure -import numpy as np def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure: -- 2.34.1 From a0b46a35e28b13a8f289db41dea822fe775e16cb Mon Sep 17 00:00:00 2001 From: ben Date: Wed, 22 Oct 2025 12:02:05 -0400 Subject: [PATCH 8/8] format fix 3? --- src/ria_toolkit_oss/viz/onnx.py | 89 ++++++--------- src/ria_toolkit_oss/viz/pytorch_state_dict.py | 102 +++++++----------- src/ria_toolkit_oss/viz/radio_dataset.py | 23 ++-- 3 files changed, 87 insertions(+), 127 deletions(-) diff --git a/src/ria_toolkit_oss/viz/onnx.py b/src/ria_toolkit_oss/viz/onnx.py index b92c3e4..e260eeb 100644 --- a/src/ria_toolkit_oss/viz/onnx.py +++ b/src/ria_toolkit_oss/viz/onnx.py @@ -75,9 +75,7 @@ def graph_structure(file_path: Path) -> go.Figure: """ if not ONNX_AVAILABLE: return create_styled_error_figure( - "ONNX Not Available", - "ONNX library is required for model analysis.", - "Install with: pip install onnx" + "ONNX Not Available", "ONNX library is required for model analysis.", "Install with: pip install onnx" ) try: @@ -88,9 +86,7 @@ def graph_structure(file_path: Path) -> go.Figure: if len(nodes) == 0: return create_styled_error_figure( - "Empty Model", - "This ONNX model contains no operators.", - "Please check if the model file is valid." + "Empty Model", "This ONNX model contains no operators.", "Please check if the model file is valid." ) # Create network diagram data @@ -153,8 +149,10 @@ def graph_structure(file_path: Path) -> go.Figure: fig.update_layout( title={ - "text": ("ONNX Graph Structure
" - f"{len(nodes)} Operators"), + "text": ( + "ONNX Graph Structure
" + f"{len(nodes)} Operators" + ), "x": 0.5, "xanchor": "center", "font": {"size": 22}, @@ -173,9 +171,7 @@ def graph_structure(file_path: Path) -> go.Figure: except Exception as e: return create_styled_error_figure( - "Graph Analysis Error", - "Could not analyze ONNX model structure.", - f"Error: {str(e)}" + "Graph Analysis Error", "Could not analyze ONNX model structure.", f"Error: {str(e)}" ) @@ -186,9 +182,7 @@ def operator_analysis(file_path: Path) -> go.Figure: """ if not ONNX_AVAILABLE: return create_styled_error_figure( - "ONNX Not Available", - "ONNX library is required for operator analysis.", - "Install with: pip install onnx" + "ONNX Not Available", "ONNX library is required for operator analysis.", "Install with: pip install onnx" ) try: @@ -248,8 +242,10 @@ def operator_analysis(file_path: Path) -> go.Figure: fig.update_layout( title={ - "text": ("ONNX Operator Analysis
" - f"{len(op_counts)} Unique Types"), + "text": ( + "ONNX Operator Analysis
" + f"{len(op_counts)} Unique Types" + ), "x": 0.5, "xanchor": "center", "font": {"size": 22}, @@ -262,9 +258,7 @@ def operator_analysis(file_path: Path) -> go.Figure: except Exception as e: return create_styled_error_figure( - "Operator Analysis Error", - "Could not analyze ONNX operators.", - f"Error: {str(e)}" + "Operator Analysis Error", "Could not analyze ONNX operators.", f"Error: {str(e)}" ) @@ -275,9 +269,7 @@ def model_metadata(file_path: Path) -> go.Figure: """ if not ONNX_AVAILABLE: return create_styled_error_figure( - "ONNX Not Available", - "ONNX library is required for metadata analysis.", - "Install with: pip install onnx" + "ONNX Not Available", "ONNX library is required for metadata analysis.", "Install with: pip install onnx" ) try: @@ -336,12 +328,7 @@ def model_metadata(file_path: Path) -> go.Figure: arch_values = [total_nodes, total_inputs, total_outputs, total_initializers] fig.add_trace( - go.Bar( - x=arch_data, - y=arch_values, - marker_color=["blue", "green", "orange", "red"], - showlegend=False - ), + go.Bar(x=arch_data, y=arch_values, marker_color=["blue", "green", "orange", "red"], showlegend=False), row=1, col=2, ) @@ -402,16 +389,8 @@ def model_metadata(file_path: Path) -> go.Figure: if io_data: fig.add_trace( go.Table( - header=dict( - values=["Type", "Name", "Shape", "Data Type"], - fill_color="lightblue", - align="left" - ), - cells=dict( - values=list(zip(*io_data)), - fill_color="white", - align="left" - ), + header=dict(values=["Type", "Name", "Shape", "Data Type"], fill_color="lightblue", align="left"), + cells=dict(values=list(zip(*io_data)), fill_color="white", align="left"), ), row=2, col=1, @@ -432,8 +411,10 @@ def model_metadata(file_path: Path) -> go.Figure: fig.update_layout( title={ - "text": ("ONNX Model Metadata
" - f"{total_params/1e6:.2f}M Parameters"), + "text": ( + "ONNX Model Metadata
" + f"{total_params/1e6:.2f}M Parameters" + ), "x": 0.5, "xanchor": "center", "font": {"size": 22}, @@ -447,9 +428,7 @@ def model_metadata(file_path: Path) -> go.Figure: except Exception as e: return create_styled_error_figure( - "Metadata Analysis Error", - "Could not extract ONNX model metadata.", - f"Error: {str(e)}" + "Metadata Analysis Error", "Could not extract ONNX model metadata.", f"Error: {str(e)}" ) @@ -508,10 +487,7 @@ def performance_metrics(file_path: Path) -> go.Figure: fig.add_trace( go.Bar( - x=efficiency_metrics, - y=efficiency_values, - marker_color=["blue", "green", "orange"], - showlegend=False + x=efficiency_metrics, y=efficiency_values, marker_color=["blue", "green", "orange"], showlegend=False ), row=1, col=1, @@ -522,12 +498,7 @@ def performance_metrics(file_path: Path) -> go.Figure: memory_values = [param_memory_mb, param_memory_mb * 2] # Rough estimate fig.add_trace( - go.Bar( - x=memory_types, - y=memory_values, - marker_color=["purple", "red"], - showlegend=False - ), + go.Bar(x=memory_types, y=memory_values, marker_color=["purple", "red"], showlegend=False), row=1, col=2, ) @@ -569,9 +540,11 @@ def performance_metrics(file_path: Path) -> go.Figure: fig.update_layout( title={ - "text": ("ONNX Performance Metrics
" - f"" - f"Complexity Score: {complexity_score:.0f}/100"), + "text": ( + "ONNX Performance Metrics
" + f"" + f"Complexity Score: {complexity_score:.0f}/100" + ), "x": 0.5, "xanchor": "center", "font": {"size": 22}, @@ -585,7 +558,5 @@ def performance_metrics(file_path: Path) -> go.Figure: except Exception as e: return create_styled_error_figure( - "Performance Analysis Error", - "Could not analyze ONNX model performance.", - f"Error: {str(e)}" + "Performance Analysis Error", "Could not analyze ONNX model performance.", f"Error: {str(e)}" ) diff --git a/src/ria_toolkit_oss/viz/pytorch_state_dict.py b/src/ria_toolkit_oss/viz/pytorch_state_dict.py index 578ebd0..6c625bc 100644 --- a/src/ria_toolkit_oss/viz/pytorch_state_dict.py +++ b/src/ria_toolkit_oss/viz/pytorch_state_dict.py @@ -56,29 +56,25 @@ def model_summary_plot(state_dict: dict) -> Figure: return create_styled_error_figure( "Empty State Dict", "No parameters found in state dict", - "Ensure the model state dictionary contains weight parameters" + "Ensure the model state dictionary contains weight parameters", ) # Count parameters by layer type layer_info = [] for key, tensor in state_dict.items(): - if 'weight' in key: + if "weight" in key: try: - layer_name = key.replace('.weight', '') + layer_name = key.replace(".weight", "") param_count = ( - tensor.numel() if hasattr(tensor, 'numel') - else len(tensor.flatten()) if hasattr(tensor, 'flatten') - else 0 + tensor.numel() + if hasattr(tensor, "numel") + else len(tensor.flatten()) if hasattr(tensor, "flatten") else 0 ) shape = ( - list(tensor.shape) if hasattr(tensor, 'shape') - else [len(tensor)] if hasattr(tensor, '__len__') - else [] + list(tensor.shape) + if hasattr(tensor, "shape") + else [len(tensor)] if hasattr(tensor, "__len__") else [] ) - layer_info.append({ - 'layer': layer_name, - 'parameters': param_count, - 'shape': shape - }) + layer_info.append({"layer": layer_name, "parameters": param_count, "shape": shape}) except Exception as e: print(f"Warning: Could not process layer {key}: {e}") continue @@ -86,22 +82,24 @@ def model_summary_plot(state_dict: dict) -> Figure: return create_styled_error_figure( "No Weight Layers Found", "No weight layers found in state dict", - "Ensure the state dictionary contains layers with '.weight' parameters" + "Ensure the state dictionary contains layers with '.weight' parameters", ) # Create bar chart of parameter counts - fig = go.Figure(data=[ - go.Bar( - x=[info['layer'] for info in layer_info], - y=[info['parameters'] for info in layer_info], - text=[f"Shape: {info['shape']}" for info in layer_info], - textposition='auto', - ) - ]) + fig = go.Figure( + data=[ + go.Bar( + x=[info["layer"] for info in layer_info], + y=[info["parameters"] for info in layer_info], + text=[f"Shape: {info['shape']}" for info in layer_info], + textposition="auto", + ) + ] + ) fig.update_layout( title="Model Layer Parameter Counts", xaxis_title="Layer", yaxis_title="Number of Parameters", - template="plotly_dark" + template="plotly_dark", ) return fig @@ -110,48 +108,36 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure: """Visualize weights for a specific layer.""" if not state_dict: return create_styled_error_figure( - "Empty State Dict", - "No data in state dict", - "Ensure the model state dictionary contains data" + "Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data" ) if layer_name is None: # Get first weight tensor - weight_keys = [k for k in state_dict.keys() if 'weight' in k] + weight_keys = [k for k in state_dict.keys() if "weight" in k] if not weight_keys: return create_styled_error_figure( "No Weight Tensors Found", "No weight tensors found in state dict", - "Ensure the state dictionary contains layers with '.weight' parameters" + "Ensure the state dictionary contains layers with '.weight' parameters", ) layer_name = weight_keys[0] try: weights = state_dict[layer_name] # Convert to numpy if it's a torch tensor - if hasattr(weights, 'numpy'): - weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy() - elif hasattr(weights, 'cpu'): + if hasattr(weights, "numpy"): + weights_np = weights.detach().numpy() if hasattr(weights, "detach") else weights.numpy() + elif hasattr(weights, "cpu"): weights_np = weights.cpu().detach().numpy() else: weights_np = np.array(weights) # For 2D weights, create heatmap if len(weights_np.shape) == 2: - fig = go.Figure(data=go.Heatmap( - z=weights_np, - colorscale='RdBu', - zmid=0 - )) - fig.update_layout( - title=f"Weights Heatmap: {layer_name}", - template="plotly_dark" - ) + fig = go.Figure(data=go.Heatmap(z=weights_np, colorscale="RdBu", zmid=0)) + fig.update_layout(title=f"Weights Heatmap: {layer_name}", template="plotly_dark") else: # For other shapes, flatten and show histogram flat_weights = weights_np.flatten() fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)]) - fig.update_layout( - title=f"Weight Distribution: {layer_name}", - template="plotly_dark" - ) + fig.update_layout(title=f"Weight Distribution: {layer_name}", template="plotly_dark") return fig @@ -159,7 +145,7 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure: return create_styled_error_figure( "Layer Processing Error", f"Error processing layer {layer_name}: {str(e)}", - "Check that the layer name exists and contains valid tensor data" + "Check that the layer name exists and contains valid tensor data", ) @@ -167,21 +153,19 @@ def weight_distribution_plot(state_dict: dict) -> Figure: """Show distribution of weights across all layers.""" if not state_dict: return create_styled_error_figure( - "Empty State Dict", - "No data in state dict", - "Ensure the model state dictionary contains data" + "Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data" ) all_weights = [] layer_names = [] for key, tensor in state_dict.items(): - if 'weight' in key: + if "weight" in key: try: # Convert to numpy if it's a torch tensor - if hasattr(tensor, 'numpy'): - weights_np = tensor.detach().numpy() if hasattr(tensor, 'detach') else tensor.numpy() - elif hasattr(tensor, 'cpu'): + if hasattr(tensor, "numpy"): + weights_np = tensor.detach().numpy() if hasattr(tensor, "detach") else tensor.numpy() + elif hasattr(tensor, "cpu"): weights_np = tensor.cpu().detach().numpy() else: weights_np = np.array(tensor) @@ -196,21 +180,15 @@ def weight_distribution_plot(state_dict: dict) -> Figure: return create_styled_error_figure( "No Weight Data Found", "No weight data found in state dict", - "Ensure the state dictionary contains layers with '.weight' parameters" + "Ensure the state dictionary contains layers with '.weight' parameters", ) - fig = go.Figure(data=[ - go.Histogram( - x=all_weights, - nbinsx=100, - name="All Weights" - ) - ]) + fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=100, name="All Weights")]) fig.update_layout( title="Overall Weight Distribution", xaxis_title="Weight Value", yaxis_title="Frequency", - template="plotly_dark" + template="plotly_dark", ) return fig diff --git a/src/ria_toolkit_oss/viz/radio_dataset.py b/src/ria_toolkit_oss/viz/radio_dataset.py index cae4084..a96b4d2 100644 --- a/src/ria_toolkit_oss/viz/radio_dataset.py +++ b/src/ria_toolkit_oss/viz/radio_dataset.py @@ -133,8 +133,10 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure: return create_styled_error_figure( "No Class Labels Found", "This dataset contains numerical data without categorical labels.", - ("Try using the Dataset Overview widget for data analysis, " - "or check if your dataset has hidden categorical columns."), + ( + "Try using the Dataset Overview widget for data analysis, " + "or check if your dataset has hidden categorical columns." + ), ) # Count examples per class (limit to top 20 for performance) @@ -355,8 +357,16 @@ def _compute_spectrogram(sample_data, nperseg: int, hop_length: int, n_frames: i return Sxx -def _create_spectrogram_figure(Sxx, n_frames: int, hop_length: int, n_samples: int, freq_bins: int, - sample_idx: int, class_key: str, sample_metadata) -> Figure: +def _create_spectrogram_figure( + Sxx, + n_frames: int, + hop_length: int, + n_samples: int, + freq_bins: int, + sample_idx: int, + class_key: str, + sample_metadata, +) -> Figure: """Create the plotly figure for the spectrogram.""" # Convert to dB Sxx_db = 10 * np.log10(Sxx + 1e-10) @@ -410,8 +420,9 @@ def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx: Sxx = _compute_spectrogram(sample_data, nperseg, hop_length, n_frames, freq_bins) # Create and return the figure - return _create_spectrogram_figure(Sxx, n_frames, hop_length, n_samples, freq_bins, - sample_idx, class_key, sample_metadata) + return _create_spectrogram_figure( + Sxx, n_frames, hop_length, n_samples, freq_bins, sample_idx, class_key, sample_metadata + ) except Exception as e: return create_styled_error_figure( -- 2.34.1