import logging

import plotly.graph_objects as go
import plotly.express as px
import plotly.figure_factory as ff
from plotly.subplots import make_subplots
import networkx as nx
import torch
import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)


class GraphVisualizer:
    """Advanced graph visualization utilities.

    Every public ``create_*`` method is a best-effort plotting boundary: any
    exception is logged and converted into a placeholder error figure, so
    callers always receive a ``plotly.graph_objects.Figure``.
    """

    # Marker diameter (px) used when ``node_size_factor`` is 1.0.
    _BASE_NODE_SIZE = 8

    @staticmethod
    def _compute_layout(graph, layout_algorithm):
        """Return ``{node: (x, y)}`` positions for *graph*.

        Unknown algorithm names fall back to the seeded spring layout (with a
        warning) so that callers which previously passed arbitrary values keep
        getting a figure.
        """
        layouts = {
            'spring': lambda g: nx.spring_layout(g, seed=42),
            'circular': nx.circular_layout,
            'kamada_kawai': nx.kamada_kawai_layout,
            'shell': nx.shell_layout,
            'random': lambda g: nx.random_layout(g, seed=42),
        }
        layout_fn = layouts.get(str(layout_algorithm).lower())
        if layout_fn is None:
            logger.warning(
                "Unknown layout algorithm %r; falling back to 'spring'",
                layout_algorithm,
            )
            layout_fn = layouts['spring']
        return layout_fn(graph)

    @staticmethod
    def _has_series(history, key):
        """Return True when ``history[key]`` exists and is non-empty.

        Uses ``len`` instead of truthiness so numpy arrays are accepted.
        """
        series = history.get(key)
        return series is not None and len(series) > 0

    @staticmethod
    def create_graph_plot(data, max_nodes=500, layout_algorithm='spring', node_size_factor=1.0):
        """Create interactive graph visualization.

        Args:
            data: Object exposing ``num_nodes`` and a two-row ``edge_index``
                tensor (row 0 = sources, row 1 = targets). An optional ``y``
                tensor with one value per node is used to colour the nodes.
                (Presumably a PyTorch Geometric ``Data`` object - confirm
                against callers.)
            max_nodes: Only nodes with id below ``min(num_nodes, max_nodes)``
                and the edges between them are drawn.
            layout_algorithm: One of ``'spring'``, ``'circular'``,
                ``'kamada_kawai'``, ``'shell'`` or ``'random'``. Unknown
                values fall back to ``'spring'``.
            node_size_factor: Multiplier applied to the base marker size.

        Returns:
            A plotly ``Figure``; on failure, an error figure describing the
            problem.
        """
        try:
            if not hasattr(data, 'edge_index') or not hasattr(data, 'num_nodes'):
                raise ValueError("Data must have edge_index and num_nodes attributes")

            num_nodes = min(data.num_nodes, max_nodes)
            if num_nodes <= 0:
                raise ValueError("No nodes to visualize")

            # Insert nodes first so that graph order == node id order; node
            # colours and hover labels below are indexed by node id.
            graph = nx.Graph()
            graph.add_nodes_from(range(num_nodes))

            if data.edge_index.size(1) > 0:
                edge_list = data.edge_index.t().detach().cpu().numpy()
                sources, targets = edge_list[:, 0], edge_list[:, 1]
                in_range = (
                    (sources >= 0) & (sources < num_nodes)
                    & (targets >= 0) & (targets < num_nodes)
                )
                # .tolist() yields plain Python ints rather than numpy scalars.
                graph.add_edges_from(
                    zip(sources[in_range].tolist(), targets[in_range].tolist())
                )

            pos = GraphVisualizer._compute_layout(graph, layout_algorithm)

            # Node colours: use labels only when there is one per drawn node;
            # otherwise fall back to a uniform colour instead of mis-colouring.
            node_colors = [0] * num_nodes
            labels = getattr(data, 'y', None)
            if labels is not None:
                label_values = labels.detach().cpu().numpy()
                if label_values.ndim == 1 and len(label_values) >= num_nodes:
                    node_colors = label_values[:num_nodes]

            # Edge segments, separated by None so plotly draws them as
            # disconnected lines within a single trace.
            edge_x, edge_y = [], []
            for source, target in graph.edges():
                x0, y0 = pos[source]
                x1, y1 = pos[target]
                edge_x.extend([x0, x1, None])
                edge_y.extend([y0, y1, None])

            node_ids = range(num_nodes)
            node_x = [pos[node][0] for node in node_ids]
            node_y = [pos[node][1] for node in node_ids]

            fig = go.Figure()

            if edge_x:
                fig.add_trace(go.Scatter(
                    x=edge_x,
                    y=edge_y,
                    line=dict(width=0.8, color='rgba(125,125,125,0.5)'),
                    hoverinfo='none',
                    mode='lines',
                    showlegend=False,
                ))

            fig.add_trace(go.Scatter(
                x=node_x,
                y=node_y,
                mode='markers',
                marker=dict(
                    size=GraphVisualizer._BASE_NODE_SIZE * node_size_factor,
                    color=node_colors,
                    colorscale='Viridis',
                    line=dict(width=2, color='white'),
                    opacity=0.8,
                ),
                text=[f"Node {node}" for node in node_ids],
                hoverinfo='text',
                showlegend=False,
            ))

            fig.update_layout(
                title=f'Graph Visualization ({num_nodes} nodes)',
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                plot_bgcolor='white',
                width=800,
                height=600,
            )

            return fig

        except Exception as e:
            logger.error("Graph visualization error: %s", e)
            return GraphVisualizer._create_error_figure(f"Visualization error: {str(e)}")

    @staticmethod
    def create_metrics_plot(metrics):
        """Create comprehensive metrics visualization.

        Builds a bar chart and a radar chart from the entries of *metrics*
        that are finite real numbers in ``[0, 1]``. Keys ``'error'`` and
        ``'loss'`` are skipped, as are booleans and non-numeric values.

        Args:
            metrics: Mapping of metric name (str) to value.

        Returns:
            A plotly ``Figure``; an error figure if nothing is plottable or
            plotting fails.
        """
        try:
            metric_names = []
            metric_values = []

            for key, value in metrics.items():
                if key in ('error', 'loss'):
                    continue
                # bool is a subclass of int but is not a score; numpy scalars
                # (e.g. np.float32) are scores but not Python int/float.
                if isinstance(value, bool):
                    continue
                if not isinstance(value, (int, float, np.integer, np.floating)):
                    continue
                if np.isfinite(value) and 0 <= value <= 1:
                    metric_names.append(key.replace('_', ' ').title())
                    metric_values.append(float(value))

            if not metric_names:
                return GraphVisualizer._create_error_figure("No valid metrics to display")

            fig = make_subplots(
                rows=1, cols=2,
                subplot_titles=('Performance Metrics', 'Metric Radar Chart'),
                specs=[[{"type": "bar"}, {"type": "polar"}]],
            )

            # Cycle the palette so every bar gets a colour even when there are
            # more metrics than palette entries.
            palette = px.colors.qualitative.Set3
            colors = [palette[i % len(palette)] for i in range(len(metric_names))]

            fig.add_trace(
                go.Bar(
                    x=metric_names,
                    y=metric_values,
                    marker_color=colors,
                    text=[f'{v:.3f}' for v in metric_values],
                    textposition='auto',
                    showlegend=False,
                ),
                row=1, col=1,
            )

            # Repeat the first point to close the radar polygon.
            fig.add_trace(
                go.Scatterpolar(
                    r=metric_values + [metric_values[0]],
                    theta=metric_names + [metric_names[0]],
                    fill='toself',
                    line=dict(color='blue'),
                    marker=dict(size=8),
                    showlegend=False,
                ),
                row=1, col=2,
            )

            fig.update_layout(
                title='Model Performance Dashboard',
                height=400,
                showlegend=False,
            )

            fig.update_xaxes(title_text="Metrics", tickangle=45, row=1, col=1)
            fig.update_yaxes(title_text="Score", range=[0, 1], row=1, col=1)
            fig.update_polars(
                radialaxis=dict(range=[0, 1], showticklabels=True),
                row=1, col=2,
            )

            return fig

        except Exception as e:
            logger.error("Metrics plot error: %s", e)
            return GraphVisualizer._create_error_figure(f"Metrics plot error: {str(e)}")

    @staticmethod
    def create_training_history_plot(history):
        """Create comprehensive training history visualization.

        Produces a 2x2 dashboard: loss curves, accuracy curves, learning rate
        (log scale) and a bar summary of final/best values.

        Args:
            history: Non-empty dict of per-epoch sequences. ``'train_loss'``
                and ``'train_acc'`` are required; ``'val_loss'``,
                ``'val_acc'`` and ``'lr'`` are plotted when present and
                non-empty.

        Returns:
            A plotly ``Figure``; an error figure if required data is missing
            or plotting fails.
        """
        try:
            if not isinstance(history, dict) or not history:
                return GraphVisualizer._create_error_figure("No training history available")

            for key in ('train_loss', 'train_acc'):
                if not GraphVisualizer._has_series(history, key):
                    return GraphVisualizer._create_error_figure(f"Missing {key} in training history")

            fig = make_subplots(
                rows=2, cols=2,
                subplot_titles=('Loss Over Time', 'Accuracy Over Time',
                                'Learning Rate', 'Training Progress'),
                specs=[[{"secondary_y": False}, {"secondary_y": False}],
                       [{"secondary_y": False}, {"secondary_y": False}]],
            )

            def add_line(key, name, color, row, col):
                # Each series gets an x-axis of its own length so a series
                # shorter or longer than train_loss is neither truncated nor
                # padded. Legend entries are enabled so that curves sharing a
                # subplot can be told apart.
                series = history[key]
                fig.add_trace(
                    go.Scatter(
                        x=list(range(len(series))),
                        y=series,
                        mode='lines',
                        name=name,
                        line=dict(color=color, width=2),
                        showlegend=True,
                    ),
                    row=row, col=col,
                )

            has_val_acc = GraphVisualizer._has_series(history, 'val_acc')

            add_line('train_loss', 'Train Loss', 'blue', 1, 1)
            if GraphVisualizer._has_series(history, 'val_loss'):
                add_line('val_loss', 'Val Loss', 'red', 1, 1)

            add_line('train_acc', 'Train Acc', 'green', 1, 2)
            if has_val_acc:
                add_line('val_acc', 'Val Acc', 'orange', 1, 2)

            if GraphVisualizer._has_series(history, 'lr'):
                add_line('lr', 'Learning Rate', 'purple', 2, 1)

            # Training progress summary (required series are known non-empty).
            final_metrics = {
                'Final Train Acc': history['train_acc'][-1],
                'Final Train Loss': history['train_loss'][-1],
            }
            if has_val_acc:
                final_metrics['Final Val Acc'] = history['val_acc'][-1]
                final_metrics['Best Val Acc'] = max(history['val_acc'])

            metric_names = list(final_metrics.keys())
            metric_values = list(final_metrics.values())
            bar_colors = ['lightblue', 'lightcoral', 'lightgreen', 'gold']

            fig.add_trace(
                go.Bar(
                    x=metric_names,
                    y=metric_values,
                    marker_color=bar_colors[:len(metric_names)],
                    text=[f'{v:.3f}' for v in metric_values],
                    textposition='auto',
                    showlegend=False,
                ),
                row=2, col=2,
            )

            fig.update_layout(
                title='Training History Dashboard',
                height=600,
                showlegend=True,
            )

            fig.update_xaxes(title_text="Epoch", row=1, col=1)
            fig.update_xaxes(title_text="Epoch", row=1, col=2)
            fig.update_xaxes(title_text="Epoch", row=2, col=1)
            fig.update_xaxes(title_text="Metric", tickangle=45, row=2, col=2)

            fig.update_yaxes(title_text="Loss", row=1, col=1)
            fig.update_yaxes(title_text="Accuracy", range=[0, 1], row=1, col=2)
            fig.update_yaxes(title_text="Learning Rate", type="log", row=2, col=1)
            fig.update_yaxes(title_text="Value", row=2, col=2)

            return fig

        except Exception as e:
            logger.error("Training history plot error: %s", e)
            return GraphVisualizer._create_error_figure(f"Training history plot error: {str(e)}")

    @staticmethod
    def _create_error_figure(error_message):
        """Create an error figure with message.

        Args:
            error_message: Text shown centred in the (otherwise empty) figure.

        Returns:
            A 600x400 plotly ``Figure`` titled "Visualization Error".
        """
        fig = go.Figure()
        fig.add_annotation(
            text=error_message,
            x=0.5, y=0.5,
            xref="paper", yref="paper",  # position relative to the plot area
            showarrow=False,
            font=dict(size=14, color="red"),
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor="red",
            borderwidth=1,
        )
        fig.update_layout(
            title="Visualization Error",
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            plot_bgcolor='white',
            width=600,
            height=400,
        )
        return fig