# serpent/utils/visualization.py
# Last change: "Update utils/visualization.py" by kfoughali (commit ed6db0b, verified)
import plotly.graph_objects as go
import plotly.express as px
import plotly.figure_factory as ff
from plotly.subplots import make_subplots
import networkx as nx
import torch
import numpy as np
import pandas as pd
import logging
# Module-level logger; the GraphVisualizer methods report visualization failures through it.
logger = logging.getLogger(__name__)
class GraphVisualizer:
    """Plotly-based visualization helpers for graphs, metrics and training runs.

    Every public method returns a ``plotly.graph_objects.Figure``. On failure
    the error is logged and a placeholder figure carrying the error message is
    returned instead of raising.
    """

    # Base marker size (px) for graph nodes; scaled by ``node_size_factor``.
    _BASE_NODE_SIZE = 8

    @staticmethod
    def create_graph_plot(data, max_nodes=500, layout_algorithm='spring', node_size_factor=1.0):
        """Create an interactive node-link plot of the first ``max_nodes`` nodes.

        Args:
            data: Object exposing ``edge_index`` (2 x E node-index tensor/array)
                and ``num_nodes``; an optional ``y`` attribute colors the nodes.
                Presumably a ``torch_geometric.data.Data`` -- any object with
                these attributes works.
            max_nodes: Only nodes ``0 .. max_nodes-1`` and the edges between
                them are drawn.
            layout_algorithm: ``'spring'`` (default), ``'circular'``,
                ``'kamada_kawai'``, ``'spectral'``, ``'shell'`` or ``'random'``.
                Unknown or unavailable layouts fall back to ``'spring'``.
            node_size_factor: Multiplier on the base marker size (clamped so
                the marker is at least 1 px).

        Returns:
            A Plotly figure, or an error figure if the input is invalid.
        """
        try:
            if not hasattr(data, 'edge_index') or not hasattr(data, 'num_nodes'):
                raise ValueError("Data must have edge_index and num_nodes attributes")
            num_nodes = min(data.num_nodes, max_nodes)
            if num_nodes <= 0:
                raise ValueError("No nodes to visualize")

            # Insert nodes before edges so node ids, positions, colors and hover
            # labels all line up in id order 0 .. num_nodes-1.
            node_ids = list(range(num_nodes))
            G = nx.Graph()
            G.add_nodes_from(node_ids)
            G.add_edges_from(GraphVisualizer._edge_pairs(data.edge_index, num_nodes))

            pos = GraphVisualizer._compute_layout(G, layout_algorithm)
            node_colors = GraphVisualizer._node_colors(getattr(data, 'y', None), num_nodes)

            # One line trace with None separators draws every edge segment.
            edge_x, edge_y = [], []
            for u, v in G.edges():
                (x0, y0), (x1, y1) = pos[u], pos[v]
                edge_x.extend([x0, x1, None])
                edge_y.extend([y0, y1, None])

            node_x = [pos[n][0] for n in node_ids]
            node_y = [pos[n][1] for n in node_ids]

            fig = go.Figure()
            if edge_x:
                fig.add_trace(go.Scatter(
                    x=edge_x, y=edge_y,
                    line=dict(width=0.8, color='rgba(125,125,125,0.5)'),
                    hoverinfo='none',
                    mode='lines',
                    showlegend=False
                ))

            marker_size = max(GraphVisualizer._BASE_NODE_SIZE * float(node_size_factor), 1.0)
            fig.add_trace(go.Scatter(
                x=node_x, y=node_y,
                mode='markers',
                marker=dict(
                    size=marker_size,
                    color=node_colors,
                    colorscale='Viridis',
                    line=dict(width=2, color='white'),
                    opacity=0.8
                ),
                text=[f"Node {i}" for i in node_ids],
                hoverinfo='text',
                showlegend=False
            ))
            fig.update_layout(
                title=f'Graph Visualization ({num_nodes} nodes)',
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                plot_bgcolor='white',
                width=800,
                height=600
            )
            return fig
        except Exception as e:
            logger.exception("Graph visualization error: %s", e)
            return GraphVisualizer._create_error_figure(f"Visualization error: {str(e)}")

    @staticmethod
    def _edge_pairs(edge_index, num_nodes):
        """Return ``(u, v)`` int pairs from a 2 x E index whose endpoints lie in [0, num_nodes)."""
        ei = torch.as_tensor(edge_index).detach().cpu()
        if ei.numel() == 0:
            return []
        src, dst = ei[0], ei[1]
        # Negative ids are rejected too; otherwise they would create phantom nodes.
        keep = (src >= 0) & (src < num_nodes) & (dst >= 0) & (dst < num_nodes)
        return [tuple(pair) for pair in ei[:, keep].t().tolist()]

    @staticmethod
    def _node_colors(y, num_nodes):
        """Return one color value per node (node order 0 .. num_nodes-1).

        Multi-column labels are reduced with ``argmax`` (e.g. one-hot targets --
        assumption, confirm against callers); a single column is flattened.
        Missing labels, or fewer labels than nodes (e.g. a graph-level target),
        fall back to a uniform color.
        """
        if y is None:
            return [0] * num_nodes
        labels = torch.as_tensor(y).detach().cpu()
        if labels.numel() == 0:
            return [0] * num_nodes
        if labels.dim() > 1:
            labels = labels.reshape(labels.size(0), -1)
            labels = labels[:, 0] if labels.size(1) == 1 else labels.argmax(dim=1)
        if labels.dim() != 1 or labels.size(0) < num_nodes:
            return [0] * num_nodes
        return labels[:num_nodes].tolist()

    @staticmethod
    def _compute_layout(G, algorithm='spring'):
        """Return a ``{node: position}`` dict for ``G`` using the named networkx layout.

        Spring and random layouts are seeded (42) so repeated calls give the
        same picture. Unknown names, and layouts that raise ImportError (some
        networkx layouts need SciPy), fall back to the spring layout.
        """
        layouts = {
            'spring': lambda g: nx.spring_layout(g, seed=42),
            'circular': nx.circular_layout,
            'kamada_kawai': nx.kamada_kawai_layout,
            'spectral': nx.spectral_layout,
            'shell': nx.shell_layout,
            'random': lambda g: nx.random_layout(g, seed=42),
        }
        name = str(algorithm or 'spring').lower().replace('-', '_')
        layout = layouts.get(name)
        if layout is None:
            logger.warning("Unknown layout algorithm %r; using 'spring'", algorithm)
            return layouts['spring'](G)
        try:
            return layout(G)
        except ImportError as e:
            logger.warning("Layout %r unavailable (%s); using 'spring'", name, e)
            return layouts['spring'](G)

    @staticmethod
    def create_metrics_plot(metrics):
        """Create a bar chart plus radar chart of the score-like entries of ``metrics``.

        Args:
            metrics: Mapping of metric name to value. Only finite numbers in
                [0, 1] are plotted; ``'error'``, ``'loss'`` and booleans are skipped.

        Returns:
            A Plotly figure, or an error figure when nothing is plottable.
        """
        try:
            metric_names, metric_values = GraphVisualizer._valid_metrics(metrics)
            if not metric_names:
                return GraphVisualizer._create_error_figure("No valid metrics to display")

            fig = make_subplots(
                rows=1, cols=2,
                subplot_titles=('Performance Metrics', 'Metric Radar Chart'),
                specs=[[{"type": "bar"}, {"type": "polar"}]]
            )
            # The qualitative palette has a fixed length; cycle so every bar gets a color.
            palette = px.colors.qualitative.Set3
            colors = [palette[i % len(palette)] for i in range(len(metric_names))]
            fig.add_trace(
                go.Bar(
                    x=metric_names,
                    y=metric_values,
                    marker_color=colors,
                    text=[f'{v:.3f}' for v in metric_values],
                    textposition='auto',
                    showlegend=False
                ),
                row=1, col=1
            )
            # Repeat the first point so the radar polygon closes.
            fig.add_trace(
                go.Scatterpolar(
                    r=metric_values + [metric_values[0]],
                    theta=metric_names + [metric_names[0]],
                    fill='toself',
                    line=dict(color='blue'),
                    marker=dict(size=8),
                    showlegend=False
                ),
                row=1, col=2
            )
            fig.update_layout(
                title='Model Performance Dashboard',
                height=400,
                showlegend=False
            )
            fig.update_xaxes(title_text="Metrics", tickangle=45, row=1, col=1)
            fig.update_yaxes(title_text="Score", range=[0, 1], row=1, col=1)
            fig.update_polars(
                radialaxis=dict(range=[0, 1], showticklabels=True),
                row=1, col=2
            )
            return fig
        except Exception as e:
            logger.exception("Metrics plot error: %s", e)
            return GraphVisualizer._create_error_figure(f"Metrics plot error: {str(e)}")

    @staticmethod
    def _valid_metrics(metrics):
        """Select the plottable entries of ``metrics``.

        Keeps finite real numbers in [0, 1]: Python/NumPy numbers and
        single-element tensors. Booleans and the ``'error'``/``'loss'`` keys are
        skipped. Returns ``(names, values)`` in dict order, with names
        prettified (``'f1_score'`` -> ``'F1 Score'``) and values as floats.
        """
        names, values = [], []
        for key, value in metrics.items():
            if key in ('error', 'loss'):
                continue
            if isinstance(value, torch.Tensor):
                if value.numel() != 1:
                    continue
                value = value.item()
            # bool subclasses int, but a flag is not a score.
            if isinstance(value, (bool, np.bool_)):
                continue
            if not isinstance(value, (int, float, np.integer, np.floating)):
                continue
            value = float(value)
            if np.isfinite(value) and 0.0 <= value <= 1.0:
                names.append(str(key).replace('_', ' ').title())
                values.append(value)
        return names, values

    @staticmethod
    def create_training_history_plot(history):
        """Create a 2x2 dashboard of loss, accuracy, learning rate and final scores.

        Args:
            history: Dict of per-epoch sequences. ``'train_loss'`` and
                ``'train_acc'`` are required; ``'val_loss'``, ``'val_acc'`` and
                ``'lr'`` are drawn when present and non-empty. Lists, NumPy
                arrays and 1-D tensors are accepted.

        Returns:
            A Plotly figure, or an error figure if required series are missing.
        """
        try:
            if not isinstance(history, dict) or not history:
                return GraphVisualizer._create_error_figure("No training history available")

            series = {
                key: GraphVisualizer._as_series(history, key)
                for key in ('train_loss', 'train_acc', 'val_loss', 'val_acc', 'lr')
            }
            for key in ('train_loss', 'train_acc'):
                if not series[key]:
                    return GraphVisualizer._create_error_figure(f"Missing {key} in training history")

            fig = make_subplots(
                rows=2, cols=2,
                subplot_titles=('Loss Over Time', 'Accuracy Over Time', 'Learning Rate', 'Training Progress'),
                specs=[[{"secondary_y": False}, {"secondary_y": False}],
                       [{"secondary_y": False}, {"secondary_y": False}]]
            )

            # (series key, legend label, line color, subplot row, subplot col)
            line_specs = (
                ('train_loss', 'Train Loss', 'blue', 1, 1),
                ('val_loss', 'Val Loss', 'red', 1, 1),
                ('train_acc', 'Train Acc', 'green', 1, 2),
                ('val_acc', 'Val Acc', 'orange', 1, 2),
                ('lr', 'Learning Rate', 'purple', 2, 1),
            )
            for key, label, color, row, col in line_specs:
                values = series[key]
                if not values:
                    continue
                # Each series gets its own epoch index, so series of different
                # lengths are drawn in full rather than truncated to train_loss.
                fig.add_trace(
                    go.Scatter(
                        x=list(range(len(values))), y=values,
                        mode='lines', name=label,
                        line=dict(color=color, width=2),
                        showlegend=True
                    ),
                    row=row, col=col
                )

            train_acc, train_loss, val_acc = series['train_acc'], series['train_loss'], series['val_acc']
            final_metrics = {
                'Final Train Acc': train_acc[-1],
                'Final Train Loss': train_loss[-1],
            }
            if val_acc:
                final_metrics['Final Val Acc'] = val_acc[-1]
                final_metrics['Best Val Acc'] = max(val_acc)
            metric_names = list(final_metrics.keys())
            metric_values = list(final_metrics.values())
            fig.add_trace(
                go.Bar(
                    x=metric_names,
                    y=metric_values,
                    marker_color=['lightblue', 'lightcoral', 'lightgreen', 'gold'][:len(metric_names)],
                    text=[f'{v:.3f}' for v in metric_values],
                    textposition='auto',
                    showlegend=False
                ),
                row=2, col=2
            )

            fig.update_layout(
                title='Training History Dashboard',
                height=600,
                showlegend=True
            )
            fig.update_xaxes(title_text="Epoch", row=1, col=1)
            fig.update_xaxes(title_text="Epoch", row=1, col=2)
            fig.update_xaxes(title_text="Epoch", row=2, col=1)
            fig.update_xaxes(title_text="Metric", tickangle=45, row=2, col=2)
            fig.update_yaxes(title_text="Loss", row=1, col=1)
            fig.update_yaxes(title_text="Accuracy", range=[0, 1], row=1, col=2)
            # A log axis cannot show zero/negative rates (e.g. warm-up from 0).
            lr = series['lr']
            lr_axis_type = 'log' if not lr or min(lr) > 0 else 'linear'
            fig.update_yaxes(title_text="Learning Rate", type=lr_axis_type, row=2, col=1)
            fig.update_yaxes(title_text="Value", row=2, col=2)
            return fig
        except Exception as e:
            logger.exception("Training history plot error: %s", e)
            return GraphVisualizer._create_error_figure(f"Training history plot error: {str(e)}")

    @staticmethod
    def _as_series(history, key):
        """Return ``history[key]`` as a list of floats; ``[]`` if the key is missing or None.

        Accepts any iterable whose items ``float()`` understands (lists, NumPy
        arrays, 1-D tensors).
        """
        values = history.get(key)
        if values is None:
            return []
        return [float(v) for v in values]

    @staticmethod
    def _create_error_figure(error_message):
        """Return a blank 600x400 figure showing ``error_message`` in red at its center."""
        fig = go.Figure()
        fig.add_annotation(
            text=error_message,
            x=0.5, y=0.5,
            xref="paper", yref="paper",
            showarrow=False,
            font=dict(size=14, color="red"),
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor="red",
            borderwidth=1
        )
        fig.update_layout(
            title="Visualization Error",
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            plot_bgcolor='white',
            width=600,
            height=400
        )
        return fig