format fix 3?
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 13s
Test with tox / Test with tox (3.11) (pull_request) Successful in 32s
Test with tox / Test with tox (3.12) (pull_request) Successful in 31s
Test with tox / Test with tox (3.10) (pull_request) Successful in 42s
Build Project / Build Project (3.10) (pull_request) Successful in 50s
Build Project / Build Project (3.11) (pull_request) Successful in 49s
Build Project / Build Project (3.12) (pull_request) Successful in 48s
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 13s
Test with tox / Test with tox (3.11) (pull_request) Successful in 32s
Test with tox / Test with tox (3.12) (pull_request) Successful in 31s
Test with tox / Test with tox (3.10) (pull_request) Successful in 42s
Build Project / Build Project (3.10) (pull_request) Successful in 50s
Build Project / Build Project (3.11) (pull_request) Successful in 49s
Build Project / Build Project (3.12) (pull_request) Successful in 48s
This commit is contained in:
parent
4872eea116
commit
a0b46a35e2
|
|
@ -75,9 +75,7 @@ def graph_structure(file_path: Path) -> go.Figure:
|
||||||
"""
|
"""
|
||||||
if not ONNX_AVAILABLE:
|
if not ONNX_AVAILABLE:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"ONNX Not Available",
|
"ONNX Not Available", "ONNX library is required for model analysis.", "Install with: pip install onnx"
|
||||||
"ONNX library is required for model analysis.",
|
|
||||||
"Install with: pip install onnx"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -88,9 +86,7 @@ def graph_structure(file_path: Path) -> go.Figure:
|
||||||
|
|
||||||
if len(nodes) == 0:
|
if len(nodes) == 0:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Empty Model",
|
"Empty Model", "This ONNX model contains no operators.", "Please check if the model file is valid."
|
||||||
"This ONNX model contains no operators.",
|
|
||||||
"Please check if the model file is valid."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create network diagram data
|
# Create network diagram data
|
||||||
|
|
@ -153,8 +149,10 @@ def graph_structure(file_path: Path) -> go.Figure:
|
||||||
|
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
title={
|
title={
|
||||||
"text": ("ONNX Graph Structure<br>"
|
"text": (
|
||||||
f"<span style='font-size:14px; color:#a0a0a0;'>{len(nodes)} Operators</span>"),
|
"ONNX Graph Structure<br>"
|
||||||
|
f"<span style='font-size:14px; color:#a0a0a0;'>{len(nodes)} Operators</span>"
|
||||||
|
),
|
||||||
"x": 0.5,
|
"x": 0.5,
|
||||||
"xanchor": "center",
|
"xanchor": "center",
|
||||||
"font": {"size": 22},
|
"font": {"size": 22},
|
||||||
|
|
@ -173,9 +171,7 @@ def graph_structure(file_path: Path) -> go.Figure:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Graph Analysis Error",
|
"Graph Analysis Error", "Could not analyze ONNX model structure.", f"Error: {str(e)}"
|
||||||
"Could not analyze ONNX model structure.",
|
|
||||||
f"Error: {str(e)}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -186,9 +182,7 @@ def operator_analysis(file_path: Path) -> go.Figure:
|
||||||
"""
|
"""
|
||||||
if not ONNX_AVAILABLE:
|
if not ONNX_AVAILABLE:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"ONNX Not Available",
|
"ONNX Not Available", "ONNX library is required for operator analysis.", "Install with: pip install onnx"
|
||||||
"ONNX library is required for operator analysis.",
|
|
||||||
"Install with: pip install onnx"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -248,8 +242,10 @@ def operator_analysis(file_path: Path) -> go.Figure:
|
||||||
|
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
title={
|
title={
|
||||||
"text": ("ONNX Operator Analysis<br>"
|
"text": (
|
||||||
f"<span style='font-size:14px; color:#a0a0a0;'>{len(op_counts)} Unique Types</span>"),
|
"ONNX Operator Analysis<br>"
|
||||||
|
f"<span style='font-size:14px; color:#a0a0a0;'>{len(op_counts)} Unique Types</span>"
|
||||||
|
),
|
||||||
"x": 0.5,
|
"x": 0.5,
|
||||||
"xanchor": "center",
|
"xanchor": "center",
|
||||||
"font": {"size": 22},
|
"font": {"size": 22},
|
||||||
|
|
@ -262,9 +258,7 @@ def operator_analysis(file_path: Path) -> go.Figure:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Operator Analysis Error",
|
"Operator Analysis Error", "Could not analyze ONNX operators.", f"Error: {str(e)}"
|
||||||
"Could not analyze ONNX operators.",
|
|
||||||
f"Error: {str(e)}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -275,9 +269,7 @@ def model_metadata(file_path: Path) -> go.Figure:
|
||||||
"""
|
"""
|
||||||
if not ONNX_AVAILABLE:
|
if not ONNX_AVAILABLE:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"ONNX Not Available",
|
"ONNX Not Available", "ONNX library is required for metadata analysis.", "Install with: pip install onnx"
|
||||||
"ONNX library is required for metadata analysis.",
|
|
||||||
"Install with: pip install onnx"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -336,12 +328,7 @@ def model_metadata(file_path: Path) -> go.Figure:
|
||||||
arch_values = [total_nodes, total_inputs, total_outputs, total_initializers]
|
arch_values = [total_nodes, total_inputs, total_outputs, total_initializers]
|
||||||
|
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
go.Bar(
|
go.Bar(x=arch_data, y=arch_values, marker_color=["blue", "green", "orange", "red"], showlegend=False),
|
||||||
x=arch_data,
|
|
||||||
y=arch_values,
|
|
||||||
marker_color=["blue", "green", "orange", "red"],
|
|
||||||
showlegend=False
|
|
||||||
),
|
|
||||||
row=1,
|
row=1,
|
||||||
col=2,
|
col=2,
|
||||||
)
|
)
|
||||||
|
|
@ -402,16 +389,8 @@ def model_metadata(file_path: Path) -> go.Figure:
|
||||||
if io_data:
|
if io_data:
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
go.Table(
|
go.Table(
|
||||||
header=dict(
|
header=dict(values=["Type", "Name", "Shape", "Data Type"], fill_color="lightblue", align="left"),
|
||||||
values=["Type", "Name", "Shape", "Data Type"],
|
cells=dict(values=list(zip(*io_data)), fill_color="white", align="left"),
|
||||||
fill_color="lightblue",
|
|
||||||
align="left"
|
|
||||||
),
|
|
||||||
cells=dict(
|
|
||||||
values=list(zip(*io_data)),
|
|
||||||
fill_color="white",
|
|
||||||
align="left"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
row=2,
|
row=2,
|
||||||
col=1,
|
col=1,
|
||||||
|
|
@ -432,8 +411,10 @@ def model_metadata(file_path: Path) -> go.Figure:
|
||||||
|
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
title={
|
title={
|
||||||
"text": ("ONNX Model Metadata<br>"
|
"text": (
|
||||||
f"<span style='font-size:14px; color:#a0a0a0;'>{total_params/1e6:.2f}M Parameters</span>"),
|
"ONNX Model Metadata<br>"
|
||||||
|
f"<span style='font-size:14px; color:#a0a0a0;'>{total_params/1e6:.2f}M Parameters</span>"
|
||||||
|
),
|
||||||
"x": 0.5,
|
"x": 0.5,
|
||||||
"xanchor": "center",
|
"xanchor": "center",
|
||||||
"font": {"size": 22},
|
"font": {"size": 22},
|
||||||
|
|
@ -447,9 +428,7 @@ def model_metadata(file_path: Path) -> go.Figure:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Metadata Analysis Error",
|
"Metadata Analysis Error", "Could not extract ONNX model metadata.", f"Error: {str(e)}"
|
||||||
"Could not extract ONNX model metadata.",
|
|
||||||
f"Error: {str(e)}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -508,10 +487,7 @@ def performance_metrics(file_path: Path) -> go.Figure:
|
||||||
|
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
go.Bar(
|
go.Bar(
|
||||||
x=efficiency_metrics,
|
x=efficiency_metrics, y=efficiency_values, marker_color=["blue", "green", "orange"], showlegend=False
|
||||||
y=efficiency_values,
|
|
||||||
marker_color=["blue", "green", "orange"],
|
|
||||||
showlegend=False
|
|
||||||
),
|
),
|
||||||
row=1,
|
row=1,
|
||||||
col=1,
|
col=1,
|
||||||
|
|
@ -522,12 +498,7 @@ def performance_metrics(file_path: Path) -> go.Figure:
|
||||||
memory_values = [param_memory_mb, param_memory_mb * 2] # Rough estimate
|
memory_values = [param_memory_mb, param_memory_mb * 2] # Rough estimate
|
||||||
|
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
go.Bar(
|
go.Bar(x=memory_types, y=memory_values, marker_color=["purple", "red"], showlegend=False),
|
||||||
x=memory_types,
|
|
||||||
y=memory_values,
|
|
||||||
marker_color=["purple", "red"],
|
|
||||||
showlegend=False
|
|
||||||
),
|
|
||||||
row=1,
|
row=1,
|
||||||
col=2,
|
col=2,
|
||||||
)
|
)
|
||||||
|
|
@ -569,9 +540,11 @@ def performance_metrics(file_path: Path) -> go.Figure:
|
||||||
|
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
title={
|
title={
|
||||||
"text": ("ONNX Performance Metrics<br>"
|
"text": (
|
||||||
|
"ONNX Performance Metrics<br>"
|
||||||
f"<span style='font-size:14px; color:#a0a0a0;'>"
|
f"<span style='font-size:14px; color:#a0a0a0;'>"
|
||||||
f"Complexity Score: {complexity_score:.0f}/100</span>"),
|
f"Complexity Score: {complexity_score:.0f}/100</span>"
|
||||||
|
),
|
||||||
"x": 0.5,
|
"x": 0.5,
|
||||||
"xanchor": "center",
|
"xanchor": "center",
|
||||||
"font": {"size": 22},
|
"font": {"size": 22},
|
||||||
|
|
@ -585,7 +558,5 @@ def performance_metrics(file_path: Path) -> go.Figure:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Performance Analysis Error",
|
"Performance Analysis Error", "Could not analyze ONNX model performance.", f"Error: {str(e)}"
|
||||||
"Could not analyze ONNX model performance.",
|
|
||||||
f"Error: {str(e)}"
|
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -56,29 +56,25 @@ def model_summary_plot(state_dict: dict) -> Figure:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Empty State Dict",
|
"Empty State Dict",
|
||||||
"No parameters found in state dict",
|
"No parameters found in state dict",
|
||||||
"Ensure the model state dictionary contains weight parameters"
|
"Ensure the model state dictionary contains weight parameters",
|
||||||
)
|
)
|
||||||
# Count parameters by layer type
|
# Count parameters by layer type
|
||||||
layer_info = []
|
layer_info = []
|
||||||
for key, tensor in state_dict.items():
|
for key, tensor in state_dict.items():
|
||||||
if 'weight' in key:
|
if "weight" in key:
|
||||||
try:
|
try:
|
||||||
layer_name = key.replace('.weight', '')
|
layer_name = key.replace(".weight", "")
|
||||||
param_count = (
|
param_count = (
|
||||||
tensor.numel() if hasattr(tensor, 'numel')
|
tensor.numel()
|
||||||
else len(tensor.flatten()) if hasattr(tensor, 'flatten')
|
if hasattr(tensor, "numel")
|
||||||
else 0
|
else len(tensor.flatten()) if hasattr(tensor, "flatten") else 0
|
||||||
)
|
)
|
||||||
shape = (
|
shape = (
|
||||||
list(tensor.shape) if hasattr(tensor, 'shape')
|
list(tensor.shape)
|
||||||
else [len(tensor)] if hasattr(tensor, '__len__')
|
if hasattr(tensor, "shape")
|
||||||
else []
|
else [len(tensor)] if hasattr(tensor, "__len__") else []
|
||||||
)
|
)
|
||||||
layer_info.append({
|
layer_info.append({"layer": layer_name, "parameters": param_count, "shape": shape})
|
||||||
'layer': layer_name,
|
|
||||||
'parameters': param_count,
|
|
||||||
'shape': shape
|
|
||||||
})
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Could not process layer {key}: {e}")
|
print(f"Warning: Could not process layer {key}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
@ -86,22 +82,24 @@ def model_summary_plot(state_dict: dict) -> Figure:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"No Weight Layers Found",
|
"No Weight Layers Found",
|
||||||
"No weight layers found in state dict",
|
"No weight layers found in state dict",
|
||||||
"Ensure the state dictionary contains layers with '.weight' parameters"
|
"Ensure the state dictionary contains layers with '.weight' parameters",
|
||||||
)
|
)
|
||||||
# Create bar chart of parameter counts
|
# Create bar chart of parameter counts
|
||||||
fig = go.Figure(data=[
|
fig = go.Figure(
|
||||||
|
data=[
|
||||||
go.Bar(
|
go.Bar(
|
||||||
x=[info['layer'] for info in layer_info],
|
x=[info["layer"] for info in layer_info],
|
||||||
y=[info['parameters'] for info in layer_info],
|
y=[info["parameters"] for info in layer_info],
|
||||||
text=[f"Shape: {info['shape']}" for info in layer_info],
|
text=[f"Shape: {info['shape']}" for info in layer_info],
|
||||||
textposition='auto',
|
textposition="auto",
|
||||||
|
)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
])
|
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
title="Model Layer Parameter Counts",
|
title="Model Layer Parameter Counts",
|
||||||
xaxis_title="Layer",
|
xaxis_title="Layer",
|
||||||
yaxis_title="Number of Parameters",
|
yaxis_title="Number of Parameters",
|
||||||
template="plotly_dark"
|
template="plotly_dark",
|
||||||
)
|
)
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
@ -110,48 +108,36 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
|
||||||
"""Visualize weights for a specific layer."""
|
"""Visualize weights for a specific layer."""
|
||||||
if not state_dict:
|
if not state_dict:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Empty State Dict",
|
"Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data"
|
||||||
"No data in state dict",
|
|
||||||
"Ensure the model state dictionary contains data"
|
|
||||||
)
|
)
|
||||||
if layer_name is None:
|
if layer_name is None:
|
||||||
# Get first weight tensor
|
# Get first weight tensor
|
||||||
weight_keys = [k for k in state_dict.keys() if 'weight' in k]
|
weight_keys = [k for k in state_dict.keys() if "weight" in k]
|
||||||
if not weight_keys:
|
if not weight_keys:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"No Weight Tensors Found",
|
"No Weight Tensors Found",
|
||||||
"No weight tensors found in state dict",
|
"No weight tensors found in state dict",
|
||||||
"Ensure the state dictionary contains layers with '.weight' parameters"
|
"Ensure the state dictionary contains layers with '.weight' parameters",
|
||||||
)
|
)
|
||||||
layer_name = weight_keys[0]
|
layer_name = weight_keys[0]
|
||||||
try:
|
try:
|
||||||
weights = state_dict[layer_name]
|
weights = state_dict[layer_name]
|
||||||
# Convert to numpy if it's a torch tensor
|
# Convert to numpy if it's a torch tensor
|
||||||
if hasattr(weights, 'numpy'):
|
if hasattr(weights, "numpy"):
|
||||||
weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy()
|
weights_np = weights.detach().numpy() if hasattr(weights, "detach") else weights.numpy()
|
||||||
elif hasattr(weights, 'cpu'):
|
elif hasattr(weights, "cpu"):
|
||||||
weights_np = weights.cpu().detach().numpy()
|
weights_np = weights.cpu().detach().numpy()
|
||||||
else:
|
else:
|
||||||
weights_np = np.array(weights)
|
weights_np = np.array(weights)
|
||||||
# For 2D weights, create heatmap
|
# For 2D weights, create heatmap
|
||||||
if len(weights_np.shape) == 2:
|
if len(weights_np.shape) == 2:
|
||||||
fig = go.Figure(data=go.Heatmap(
|
fig = go.Figure(data=go.Heatmap(z=weights_np, colorscale="RdBu", zmid=0))
|
||||||
z=weights_np,
|
fig.update_layout(title=f"Weights Heatmap: {layer_name}", template="plotly_dark")
|
||||||
colorscale='RdBu',
|
|
||||||
zmid=0
|
|
||||||
))
|
|
||||||
fig.update_layout(
|
|
||||||
title=f"Weights Heatmap: {layer_name}",
|
|
||||||
template="plotly_dark"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# For other shapes, flatten and show histogram
|
# For other shapes, flatten and show histogram
|
||||||
flat_weights = weights_np.flatten()
|
flat_weights = weights_np.flatten()
|
||||||
fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
|
fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
|
||||||
fig.update_layout(
|
fig.update_layout(title=f"Weight Distribution: {layer_name}", template="plotly_dark")
|
||||||
title=f"Weight Distribution: {layer_name}",
|
|
||||||
template="plotly_dark"
|
|
||||||
)
|
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
@ -159,7 +145,7 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Layer Processing Error",
|
"Layer Processing Error",
|
||||||
f"Error processing layer {layer_name}: {str(e)}",
|
f"Error processing layer {layer_name}: {str(e)}",
|
||||||
"Check that the layer name exists and contains valid tensor data"
|
"Check that the layer name exists and contains valid tensor data",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -167,21 +153,19 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
|
||||||
"""Show distribution of weights across all layers."""
|
"""Show distribution of weights across all layers."""
|
||||||
if not state_dict:
|
if not state_dict:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Empty State Dict",
|
"Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data"
|
||||||
"No data in state dict",
|
|
||||||
"Ensure the model state dictionary contains data"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
all_weights = []
|
all_weights = []
|
||||||
layer_names = []
|
layer_names = []
|
||||||
|
|
||||||
for key, tensor in state_dict.items():
|
for key, tensor in state_dict.items():
|
||||||
if 'weight' in key:
|
if "weight" in key:
|
||||||
try:
|
try:
|
||||||
# Convert to numpy if it's a torch tensor
|
# Convert to numpy if it's a torch tensor
|
||||||
if hasattr(tensor, 'numpy'):
|
if hasattr(tensor, "numpy"):
|
||||||
weights_np = tensor.detach().numpy() if hasattr(tensor, 'detach') else tensor.numpy()
|
weights_np = tensor.detach().numpy() if hasattr(tensor, "detach") else tensor.numpy()
|
||||||
elif hasattr(tensor, 'cpu'):
|
elif hasattr(tensor, "cpu"):
|
||||||
weights_np = tensor.cpu().detach().numpy()
|
weights_np = tensor.cpu().detach().numpy()
|
||||||
else:
|
else:
|
||||||
weights_np = np.array(tensor)
|
weights_np = np.array(tensor)
|
||||||
|
|
@ -196,21 +180,15 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"No Weight Data Found",
|
"No Weight Data Found",
|
||||||
"No weight data found in state dict",
|
"No weight data found in state dict",
|
||||||
"Ensure the state dictionary contains layers with '.weight' parameters"
|
"Ensure the state dictionary contains layers with '.weight' parameters",
|
||||||
)
|
)
|
||||||
|
|
||||||
fig = go.Figure(data=[
|
fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=100, name="All Weights")])
|
||||||
go.Histogram(
|
|
||||||
x=all_weights,
|
|
||||||
nbinsx=100,
|
|
||||||
name="All Weights"
|
|
||||||
)
|
|
||||||
])
|
|
||||||
|
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
title="Overall Weight Distribution",
|
title="Overall Weight Distribution",
|
||||||
xaxis_title="Weight Value",
|
xaxis_title="Weight Value",
|
||||||
yaxis_title="Frequency",
|
yaxis_title="Frequency",
|
||||||
template="plotly_dark"
|
template="plotly_dark",
|
||||||
)
|
)
|
||||||
return fig
|
return fig
|
||||||
|
|
|
||||||
|
|
@ -133,8 +133,10 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"No Class Labels Found",
|
"No Class Labels Found",
|
||||||
"This dataset contains numerical data without categorical labels.",
|
"This dataset contains numerical data without categorical labels.",
|
||||||
("Try using the Dataset Overview widget for data analysis, "
|
(
|
||||||
"or check if your dataset has hidden categorical columns."),
|
"Try using the Dataset Overview widget for data analysis, "
|
||||||
|
"or check if your dataset has hidden categorical columns."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Count examples per class (limit to top 20 for performance)
|
# Count examples per class (limit to top 20 for performance)
|
||||||
|
|
@ -355,8 +357,16 @@ def _compute_spectrogram(sample_data, nperseg: int, hop_length: int, n_frames: i
|
||||||
return Sxx
|
return Sxx
|
||||||
|
|
||||||
|
|
||||||
def _create_spectrogram_figure(Sxx, n_frames: int, hop_length: int, n_samples: int, freq_bins: int,
|
def _create_spectrogram_figure(
|
||||||
sample_idx: int, class_key: str, sample_metadata) -> Figure:
|
Sxx,
|
||||||
|
n_frames: int,
|
||||||
|
hop_length: int,
|
||||||
|
n_samples: int,
|
||||||
|
freq_bins: int,
|
||||||
|
sample_idx: int,
|
||||||
|
class_key: str,
|
||||||
|
sample_metadata,
|
||||||
|
) -> Figure:
|
||||||
"""Create the plotly figure for the spectrogram."""
|
"""Create the plotly figure for the spectrogram."""
|
||||||
# Convert to dB
|
# Convert to dB
|
||||||
Sxx_db = 10 * np.log10(Sxx + 1e-10)
|
Sxx_db = 10 * np.log10(Sxx + 1e-10)
|
||||||
|
|
@ -410,8 +420,9 @@ def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx:
|
||||||
Sxx = _compute_spectrogram(sample_data, nperseg, hop_length, n_frames, freq_bins)
|
Sxx = _compute_spectrogram(sample_data, nperseg, hop_length, n_frames, freq_bins)
|
||||||
|
|
||||||
# Create and return the figure
|
# Create and return the figure
|
||||||
return _create_spectrogram_figure(Sxx, n_frames, hop_length, n_samples, freq_bins,
|
return _create_spectrogram_figure(
|
||||||
sample_idx, class_key, sample_metadata)
|
Sxx, n_frames, hop_length, n_samples, freq_bins, sample_idx, class_key, sample_metadata
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user