# Source: serpent/core/graph_sequencer.py
# Last commit: 8e24e05 "Update core/graph_sequencer.py" by kfoughali (verified)
import warnings
from collections import deque

import torch
import numpy as np
import networkx as nx
from scipy.sparse.linalg import eigsh
from sklearn.cluster import SpectralClustering
# NOTE(review): importing this module silences *every* warning for the whole
# process, including warnings raised by torch/numpy/scipy/sklearn and by
# unrelated code that imports this file. This is presumably meant to hide
# solver/clustering noise. Consider scoping it with `warnings.catch_warnings()`
# around the specific calls instead.
warnings.filterwarnings('ignore')
class GraphSequencer:
    """
    Production-ready graph ordering strategies.

    Every public method is a ``@staticmethod`` that returns a ``torch.long``
    tensor listing node indices in visiting order, allocated on
    ``edge_index.device``. ``edge_index`` is read as two rows
    (``edge_index[0]`` = sources, ``edge_index[1]`` = destinations), i.e. it
    is assumed to have shape ``(2, num_edges)``. Edges are treated as
    undirected. Any edge with an endpoint outside ``[0, num_nodes)``
    (including negative indices) is ignored.
    """

    @staticmethod
    def _valid_edge_arrays(edge_index, num_nodes):
        """Return ``(src, dst)`` NumPy arrays containing only in-range edges.

        The data is copied to host memory because the NumPy/SciPy/sklearn
        based strategies cannot work on device tensors. Negative indices are
        dropped as well; NumPy indexing would otherwise silently wrap them.
        """
        edges = edge_index.cpu().numpy()
        src, dst = edges[0], edges[1]
        keep = (src >= 0) & (src < num_nodes) & (dst >= 0) & (dst < num_nodes)
        return src[keep], dst[keep]

    @staticmethod
    def _dense_adjacency(edge_index, num_nodes):
        """Return the symmetric 0/1 adjacency matrix as a dense float64 array."""
        src, dst = GraphSequencer._valid_edge_arrays(edge_index, num_nodes)
        adjacency = np.zeros((num_nodes, num_nodes))
        adjacency[src, dst] = 1
        adjacency[dst, src] = 1  # undirected
        return adjacency

    @staticmethod
    def bfs_ordering(edge_index, num_nodes, start_node=None):
        """Breadth-first search ordering.

        Args:
            edge_index: edge tensor (see class docstring).
            num_nodes: number of nodes in the graph.
            start_node: node to start from. Defaults to the highest-degree
                node (lowest index on ties). If it lies outside
                ``[0, num_nodes)``, no traversal happens and the result is
                simply ``0 .. num_nodes - 1``.

        Returns:
            A ``torch.long`` permutation of ``range(num_nodes)``. Neighbours
            are enqueued highest degree first, with ties going to the larger
            node index. Nodes not reachable from ``start_node`` are appended
            afterwards in ascending index order; they are not traversed
            component by component.
        """
        device = edge_index.device
        if num_nodes <= 1:
            return torch.arange(num_nodes, device=device)

        src, dst = GraphSequencer._valid_edge_arrays(edge_index, num_nodes)
        neighbour_sets = [set() for _ in range(num_nodes)]
        for u, v in zip(src.tolist(), dst.tolist()):
            neighbour_sets[u].add(v)
            neighbour_sets[v].add(u)
        degrees = [len(neighbours) for neighbours in neighbour_sets]
        # Sort every neighbour list once, up front, by (degree, index)
        # descending, instead of re-sorting it each time a node is visited.
        adj_list = [
            sorted(neighbours, key=lambda n: (degrees[n], n), reverse=True)
            for neighbours in neighbour_sets
        ]

        if start_node is None:
            start_node = int(np.argmax(degrees))
        else:
            start_node = int(start_node)

        # Mark nodes as seen when they are enqueued, not when they are
        # dequeued. Each node then enters the queue at most once, and the
        # visiting order is unchanged (a node is visited at its first
        # enqueue position either way).
        seen = [False] * num_nodes
        queue = deque()
        if 0 <= start_node < num_nodes:
            seen[start_node] = True
            queue.append(start_node)
        order = []
        while queue:
            node = queue.popleft()
            order.append(node)
            for neighbour in adj_list[node]:
                if not seen[neighbour]:
                    seen[neighbour] = True
                    queue.append(neighbour)

        # Nodes unreachable from start_node, in ascending index order.
        order.extend(node for node in range(num_nodes) if not seen[node])
        return torch.tensor(order, dtype=torch.long, device=device)

    @staticmethod
    def spectral_ordering(edge_index, num_nodes):
        """Order nodes by the Fiedler vector of the normalized graph Laplacian.

        The Laplacian is ``L = D^-1/2 (D - A) D^-1/2``. Isolated nodes get a
        self-loop first, so every degree is positive. The eigenvector of the
        second-smallest eigenvalue is rescaled by ``D^-1/2``. This gives the
        random-walk (Shi-Malik) form, which is monotone along a path graph,
        whereas the raw symmetric eigenvector has ties there. Its sign is
        then normalized so that the entry of largest magnitude is positive.
        Nodes are sorted by this vector (stable sort).

        The matrix is dense, so memory is O(num_nodes^2) and the full
        eigendecomposition is O(num_nodes^3).

        Returns:
            A ``torch.long`` permutation of ``range(num_nodes)``.
            ``num_nodes <= 2`` returns the identity order. If
            eigendecomposition fails, nodes are ordered by degree
            (descending). Any other failure falls back to ``degree_ordering``.
        """
        device = edge_index.device
        if num_nodes <= 2:
            return torch.arange(num_nodes, device=device)

        try:
            A = GraphSequencer._dense_adjacency(edge_index, num_nodes)
            degrees = A.sum(axis=1)
            # Give isolated nodes a self-loop so D^-1/2 is defined. Their
            # Laplacian rows become zero.
            isolated = np.flatnonzero(degrees == 0)
            if isolated.size:
                A[isolated, isolated] = 1
                degrees = A.sum(axis=1)

            inv_sqrt_deg = 1.0 / np.sqrt(degrees)
            L = inv_sqrt_deg[:, None] * (np.diag(degrees) - A) * inv_sqrt_deg[None, :]

            try:
                # eigh returns eigenvalues in ascending order, so column 1
                # belongs to the second-smallest eigenvalue. The previous
                # eigsh(which='SM', sigma=0.0) call selected eigenvalues
                # *farthest* from 0 in shift-invert mode and factorized the
                # singular Laplacian.
                _, eigenvecs = np.linalg.eigh(L)
            except np.linalg.LinAlgError:
                order = np.argsort(-degrees, kind='stable')
            else:
                fiedler = eigenvecs[:, 1] * inv_sqrt_deg
                # Eigenvectors are only defined up to sign; pin it so the
                # resulting order is reproducible.
                if fiedler[np.argmax(np.abs(fiedler))] < 0:
                    fiedler = -fiedler
                order = np.argsort(fiedler, kind='stable')
            return torch.tensor(order, dtype=torch.long, device=device)
        except Exception as e:
            print(f"Spectral ordering failed: {e}, falling back to degree ordering")
            return GraphSequencer.degree_ordering(edge_index, num_nodes)

    @staticmethod
    def degree_ordering(edge_index, num_nodes):
        """Order nodes by degree, highest first, entirely on ``edge_index.device``.

        Each in-range edge adds 1 to the degree of both endpoints, so a
        self-loop counts twice. Ties are broken by node index, larger index
        first, because the sort key is ``-(degree * num_nodes + index)``.

        Returns:
            A ``torch.long`` permutation of ``range(num_nodes)``.
        """
        device = edge_index.device
        if edge_index.shape[1] > 0:
            in_range = (
                (edge_index[0] >= 0) & (edge_index[0] < num_nodes)
                & (edge_index[1] >= 0) & (edge_index[1] < num_nodes)
            )
            endpoints = edge_index[:, in_range].reshape(-1).long()
            degrees = torch.bincount(endpoints, minlength=num_nodes)
        else:
            degrees = torch.zeros(num_nodes, dtype=torch.long, device=device)

        node_indices = torch.arange(num_nodes, device=device)
        # Keys are unique, so the (unstable) sort is still deterministic.
        _, order = torch.sort(-degrees * num_nodes - node_indices)
        return order

    @staticmethod
    def community_ordering(edge_index, num_nodes, n_clusters=None):
        """Community-aware ordering using scikit-learn ``SpectralClustering``.

        Nodes are grouped by cluster label 0, 1, ..., ``n_clusters - 1``.
        Within a cluster they are ordered by weighted degree, descending,
        with ties broken by ascending index (stable sort). Every node gets a
        0.01 self-affinity before clustering.

        Args:
            n_clusters: number of clusters. Defaults to
                ``clamp(int(sqrt(num_nodes)), 2, 10)`` and is capped at
                ``num_nodes``.

        Returns:
            A ``torch.long`` permutation of ``range(num_nodes)``. Graphs with
            ``num_nodes <= 3`` use ``degree_ordering``. Any failure, including
            clustering errors, falls back to ``bfs_ordering``.
        """
        device = edge_index.device
        if num_nodes <= 3:
            return GraphSequencer.degree_ordering(edge_index, num_nodes)

        try:
            if n_clusters is None:
                n_clusters = max(2, min(10, int(np.sqrt(num_nodes))))
            n_clusters = min(n_clusters, num_nodes)

            A = GraphSequencer._dense_adjacency(edge_index, num_nodes)
            # Small diagonal for numerical stability of the affinity matrix.
            A += np.eye(num_nodes) * 0.01

            clustering = SpectralClustering(
                n_clusters=n_clusters,
                affinity='precomputed',
                random_state=42,
                assign_labels='discretize'
            )
            labels = clustering.fit_predict(A)

            degrees = A.sum(axis=1)
            order = []
            for cluster in range(n_clusters):
                members = np.flatnonzero(labels == cluster)
                ranked = members[np.argsort(-degrees[members], kind='stable')]
                order.extend(ranked.tolist())

            # Safety net for labels outside range(n_clusters). A set keeps
            # this O(n) instead of O(n^2) list membership tests.
            placed = set(order)
            order.extend(i for i in range(num_nodes) if i not in placed)
            return torch.tensor(order, dtype=torch.long, device=device)
        except Exception as e:
            print(f"Community ordering failed: {e}, falling back to BFS ordering")
            return GraphSequencer.bfs_ordering(edge_index, num_nodes)
class PositionalEncoder:
    """Graph-aware positional encoding - optimized version."""

    @staticmethod
    def encode_positions(x, edge_index, order, max_dist=10):
        """
        Create positional encodings that preserve graph structure.

        Args:
            x: node tensor. Only ``x.size(0)`` (the node count) and
                ``x.device`` are read.
            edge_index: edge tensor read as ``edge_index[0]`` (sources) and
                ``edge_index[1]`` (destinations). Edges are treated as
                undirected, and endpoints outside ``[0, num_nodes)``
                (including negative ones) are ignored.
            order: visiting order as a tensor or sequence of node indices.
                Assumed to be a permutation of ``range(num_nodes)``.
            max_dist: number of preceding ranks encoded per node.

        Returns:
            ``(seq_pos, distances)``:

            - ``seq_pos`` has shape ``(num_nodes, 1)`` and holds each node's
              rank in ``order`` divided by ``num_nodes``.
            - ``distances`` has shape ``(num_nodes, max_dist)``. For the node
              at rank ``i`` the window covers ranks ``s .. i - 1`` with
              ``s = max(0, i - max_dist)``. Slot ``j - s`` describes the node
              at rank ``j``: 0.9 = direct neighbour, 0.6 = 2-hop neighbour,
              0.3 = farther. Unused slots stay 0. If ``edge_index`` has no
              columns, every row is instead the ramp
              ``(max_dist - j) / max_dist``. If it has columns but none in
              range, ``distances`` is all zeros.

        NOTE(review): slot alignment depends on rank. For ``i < max_dist``
        the immediate predecessor sits in slot ``i - 1``; from then on it sits
        in slot ``max_dist - 1``. This is kept for compatibility with existing
        encodings; confirm it is intended.
        """
        num_nodes = x.size(0)
        device = x.device

        # Sequential positions: rank in `order`, normalised by node count.
        seq_pos = torch.zeros(num_nodes, device=device)
        seq_pos[order] = torch.arange(num_nodes, device=device, dtype=torch.float)
        seq_pos = seq_pos / max(num_nodes, 1)

        distances = torch.zeros((num_nodes, max_dist), device=device)

        if edge_index.shape[1] == 0:
            # No edges: use the same position-based ramp for every node.
            ramp = torch.tensor(
                [(max_dist - j) / max_dist for j in range(max_dist)],
                dtype=distances.dtype,
                device=device,
            )
            distances[:] = ramp
            return seq_pos.unsqueeze(1), distances

        src, dst = edge_index[0], edge_index[1]
        in_range = (src >= 0) & (src < num_nodes) & (dst >= 0) & (dst < num_nodes)
        if not in_range.any():
            return seq_pos.unsqueeze(1), distances
        src, dst = src[in_range], dst[in_range]

        adj = torch.zeros(num_nodes, num_nodes, device=device, dtype=torch.bool)
        adj[src, dst] = True
        adj[dst, src] = True  # undirected
        # Pairs joined by a walk of length exactly two.
        adj_float = adj.float()
        adj2 = torch.matmul(adj_float, adj_float) > 0

        order_idx = torch.as_tensor(order, device=device).reshape(-1).long()
        num_positions = order_idx.numel()
        if num_positions == 0 or max_dist <= 0:
            return seq_pos.unsqueeze(1), distances

        # For every rank i and slot k, the earlier rank it looks at is
        # j = max(0, i - max_dist) + k. The slot is used only while j < i.
        ranks = torch.arange(num_positions, device=device)
        slots = torch.arange(max_dist, device=device)
        window_start = (ranks - max_dist).clamp(min=0)
        prev_rank = window_start.unsqueeze(1) + slots.unsqueeze(0)
        in_window = prev_rank < ranks.unsqueeze(1)
        # Clamp only so the gather below stays in bounds; masked out anyway.
        prev_rank = prev_rank.clamp(max=num_positions - 1)

        node_ok = (order_idx >= 0) & (order_idx < num_nodes)
        safe_nodes = order_idx.clamp(0, num_nodes - 1)
        cur = safe_nodes.unsqueeze(1).expand(-1, max_dist)
        prev = safe_nodes[prev_rank]
        filled = in_window & node_ok.unsqueeze(1) & node_ok[prev_rank]

        # Default "distant", then 2-hop, then 1-hop (1-hop wins over 2-hop).
        slot_values = torch.full(cur.shape, 0.3, dtype=distances.dtype, device=device)
        slot_values[adj2[cur, prev]] = 0.6
        slot_values[adj[cur, prev]] = 0.9
        slot_values[~filled] = 0.0

        distances[order_idx[node_ok]] = slot_values[node_ok]
        return seq_pos.unsqueeze(1), distances