|
import torch |
|
import numpy as np |
|
import networkx as nx |
|
from scipy.sparse.linalg import eigsh |
|
from sklearn.cluster import SpectralClustering |
|
import warnings |
|
warnings.filterwarnings('ignore') |
|
|
|
class GraphSequencer:
    """Strategies for turning a graph into a linear ordering of its nodes.

    All public methods are static.  They take ``edge_index`` (an integer
    tensor indexed as ``edge_index[0]`` / ``edge_index[1]`` for the two
    endpoints of each edge, with ``edge_index.shape[1]`` edges) and
    ``num_nodes``, and return a 1-D ``torch.long`` tensor on
    ``edge_index.device`` containing a permutation of ``range(num_nodes)``.

    Edges are treated as undirected.  Edges with an endpoint outside
    ``[0, num_nodes)`` are ignored.
    """

    @staticmethod
    def _dense_adjacency(edge_index, num_nodes):
        """Build a symmetric 0/1 ``num_nodes x num_nodes`` NumPy adjacency matrix.

        Out-of-range edges (either endpoint negative or ``>= num_nodes``) are
        dropped.  Duplicate edges collapse to a single ``1`` entry.
        """
        edges = edge_index.cpu().numpy()
        adjacency = np.zeros((num_nodes, num_nodes))
        if edges.shape[1] > 0:
            in_range = ((edges >= 0) & (edges < num_nodes)).all(axis=0)
            src, dst = edges[0, in_range], edges[1, in_range]
            adjacency[src, dst] = 1
            adjacency[dst, src] = 1
        return adjacency

    @staticmethod
    def bfs_ordering(edge_index, num_nodes, start_node=None):
        """Order nodes by breadth-first traversal.

        Args:
            edge_index: edge endpoint tensor (see class docstring).
            num_nodes: number of nodes in the graph.
            start_node: node to start from.  Defaults to the lowest-indexed
                node of maximum degree.  A start node outside
                ``[0, num_nodes)`` is skipped, which yields plain index order.

        Returns:
            Long tensor of length ``num_nodes``.  Neighbours are expanded in
            order of decreasing ``(degree, index)``.  Nodes not reachable from
            the start node are appended afterwards in increasing index order.
        """
        # Local import keeps this block self-contained in the file.
        from collections import deque

        device = edge_index.device
        if num_nodes <= 1:
            return torch.arange(num_nodes, device=device)

        neighbor_sets = [set() for _ in range(num_nodes)]
        for src, dst in edge_index.t().cpu().tolist():
            # The lower bound matters: a negative id would otherwise wrap
            # around via Python indexing and leak into the result.
            if 0 <= src < num_nodes and 0 <= dst < num_nodes:
                neighbor_sets[src].add(dst)
                neighbor_sets[dst].add(src)

        degrees = [len(neighbors) for neighbors in neighbor_sets]
        # Degrees are fixed from here on, so each list is sorted only once.
        adjacency = [
            sorted(neighbors, key=lambda n: (degrees[n], n), reverse=True)
            for neighbors in neighbor_sets
        ]

        if start_node is None:
            # max() returns the first maximal element, i.e. the lowest index.
            start_node = max(range(num_nodes), key=degrees.__getitem__)
        start_node = int(start_node)

        order = []
        seen = set()
        if 0 <= start_node < num_nodes:
            seen.add(start_node)
            queue = deque([start_node])
            while queue:
                node = queue.popleft()
                order.append(node)
                for neighbor in adjacency[node]:
                    if neighbor not in seen:
                        seen.add(neighbor)
                        queue.append(neighbor)

        order.extend(node for node in range(num_nodes) if node not in seen)
        return torch.tensor(order, dtype=torch.long, device=device)

    @staticmethod
    def spectral_ordering(edge_index, num_nodes):
        """Order nodes by the Fiedler vector of the normalized graph Laplacian.

        The symmetric normalized Laplacian ``D^-1/2 (D - A) D^-1/2`` is built
        densely and fully decomposed with ``numpy.linalg.eigh``, which returns
        eigenvalues in ascending order.  Nodes are sorted by their component
        in the eigenvector of the second-smallest eigenvalue.  The sign of an
        eigenvector is arbitrary, so it is flipped if necessary to make its
        first non-negligible component positive, which keeps the result
        reproducible.

        Isolated nodes are given a self-loop so that every degree is positive.

        Returns:
            Long tensor of length ``num_nodes``.  Graphs with at most two
            nodes get index order.  If the eigendecomposition does not
            converge the nodes are ordered by decreasing degree.  On any other
            failure a message is printed and ``degree_ordering`` is used.
        """
        device = edge_index.device
        if num_nodes <= 2:
            return torch.arange(num_nodes, device=device)

        try:
            adjacency = GraphSequencer._dense_adjacency(edge_index, num_nodes)
            degrees = adjacency.sum(axis=1)

            isolated = np.flatnonzero(degrees == 0)
            if isolated.size:
                adjacency[isolated, isolated] = 1
                degrees = adjacency.sum(axis=1)

            # Every degree is now >= 1, so the inverse square root is finite.
            inv_sqrt = 1.0 / np.sqrt(degrees)
            laplacian = (
                (np.diag(degrees) - adjacency)
                * inv_sqrt[:, None]
                * inv_sqrt[None, :]
            )

            try:
                _, eigenvecs = np.linalg.eigh(laplacian)
            except np.linalg.LinAlgError:
                order = np.argsort(-degrees, kind="stable")
            else:
                # num_nodes >= 3 here, so column 1 always exists.
                fiedler = eigenvecs[:, 1]
                significant = np.flatnonzero(np.abs(fiedler) > 1e-12)
                if significant.size and fiedler[significant[0]] < 0:
                    fiedler = -fiedler
                order = np.argsort(fiedler, kind="stable")

            return torch.tensor(order, dtype=torch.long, device=device)

        except Exception as e:
            print(f"Spectral ordering failed: {e}, falling back to degree ordering")
            return GraphSequencer.degree_ordering(edge_index, num_nodes)

    @staticmethod
    def degree_ordering(edge_index, num_nodes):
        """Order nodes by decreasing degree.

        Each edge column adds one to the degree of both of its endpoints, so
        repeated columns are counted repeatedly.  Nodes with equal degree are
        ordered by decreasing index.

        Returns:
            Long tensor of length ``num_nodes``.
        """
        device = edge_index.device
        degrees = torch.zeros(num_nodes, dtype=torch.long, device=device)

        if edge_index.shape[1] > 0:
            in_range = ((edge_index >= 0) & (edge_index < num_nodes)).all(dim=0)
            endpoints = edge_index[:, in_range].reshape(-1)
            if endpoints.numel() > 0:
                ones = torch.ones(endpoints.numel(), dtype=torch.long, device=device)
                degrees.index_add_(0, endpoints, ones)

        # Node indices are < num_nodes, so degree * num_nodes + index is a
        # unique key per node and the sort is fully determined.
        node_indices = torch.arange(num_nodes, device=device)
        _, order = torch.sort(-degrees * num_nodes - node_indices)
        return order

    @staticmethod
    def community_ordering(edge_index, num_nodes, n_clusters=None):
        """Order nodes cluster by cluster, highest degree first in each cluster.

        Clusters come from ``SpectralClustering`` run on the dense adjacency
        matrix with ``0.01`` added to the diagonal.

        Args:
            edge_index: edge endpoint tensor (see class docstring).
            num_nodes: number of nodes in the graph.
            n_clusters: number of clusters.  Defaults to
                ``int(sqrt(num_nodes))`` clamped to ``[2, 10]``, and is never
                larger than ``num_nodes``.

        Returns:
            Long tensor of length ``num_nodes``.  Graphs with at most three
            nodes use ``degree_ordering``.  On any failure a message is
            printed and ``bfs_ordering`` is used.
        """
        device = edge_index.device
        if num_nodes <= 3:
            return GraphSequencer.degree_ordering(edge_index, num_nodes)

        try:
            if n_clusters is None:
                n_clusters = max(2, min(10, int(np.sqrt(num_nodes))))
            n_clusters = min(n_clusters, num_nodes)

            affinity = GraphSequencer._dense_adjacency(edge_index, num_nodes)
            affinity += np.eye(num_nodes) * 0.01

            clustering = SpectralClustering(
                n_clusters=n_clusters,
                affinity='precomputed',
                random_state=42,
                assign_labels='discretize',
            )
            labels = clustering.fit_predict(affinity)

            degrees = affinity.sum(axis=1)

            order = []
            for cluster in range(n_clusters):
                members = np.flatnonzero(labels == cluster)
                if members.size:
                    ranking = np.argsort(-degrees[members], kind="stable")
                    order.extend(members[ranking].tolist())

            # Safety net in case some label falls outside range(n_clusters).
            placed = set(order)
            order.extend(node for node in range(num_nodes) if node not in placed)

            return torch.tensor(order, dtype=torch.long, device=device)

        except Exception as e:
            print(f"Community ordering failed: {e}, falling back to BFS ordering")
            return GraphSequencer.bfs_ordering(edge_index, num_nodes)
|
|
|
class PositionalEncoder:
    """Positional encodings derived from a node ordering and graph adjacency."""

    @staticmethod
    def encode_positions(x, edge_index, order, max_dist=10):
        """Compute sequence-position and local-structure encodings.

        Args:
            x: tensor whose first dimension is the number of nodes; only
                ``x.size(0)`` and ``x.device`` are used.
            edge_index: integer tensor indexed as ``edge_index[0]`` /
                ``edge_index[1]`` for edge endpoints.  Edges are treated as
                undirected; edges with an endpoint outside
                ``[0, num_nodes)`` are ignored.
            order: node indices in sequence order (tensor or sequence of
                ints), one entry per node.
            max_dist: size of the look-back window, and the number of columns
                of the second output.

        Returns:
            A pair ``(seq_pos, distances)``.

            ``seq_pos`` has shape ``(num_nodes, 1)`` and holds each node's
            position in ``order`` divided by ``num_nodes``.

            ``distances`` has shape ``(num_nodes, max_dist)``.  For the node
            at sequence position ``i``, column ``k`` describes the ``k``-th
            node of the window ``order[max(0, i - max_dist):i]`` (oldest
            first): ``0.9`` if it is a direct neighbour, ``0.6`` if it is
            reachable in exactly two hops but not one, ``0.3`` otherwise.
            Columns beyond the window stay ``0``.  When ``edge_index`` has no
            columns at all, every row is instead the ramp
            ``(max_dist - k) / max_dist``.
        """
        num_nodes = x.size(0)
        device = x.device
        order = torch.as_tensor(order, dtype=torch.long, device=device)

        seq_pos = torch.zeros(num_nodes, device=device)
        seq_pos[order] = torch.arange(num_nodes, device=device, dtype=torch.float)
        seq_pos = (seq_pos / max(num_nodes, 1)).unsqueeze(1)

        if edge_index.shape[1] == 0:
            # No structure available: fall back to a fixed decaying ramp.
            ramp = torch.arange(max_dist, 0, -1, device=device, dtype=torch.float)
            if max_dist > 0:
                ramp = ramp / max_dist
            return seq_pos, ramp.repeat(num_nodes, 1)

        distances = torch.zeros((num_nodes, max_dist), device=device)

        adj = torch.zeros(num_nodes, num_nodes, device=device, dtype=torch.bool)
        in_range = ((edge_index >= 0) & (edge_index < num_nodes)).all(dim=0)
        if in_range.any():
            src, dst = edge_index[0, in_range], edge_index[1, in_range]
            adj[src, dst] = True
            adj[dst, src] = True

        adj_float = adj.float()
        # Reachable via some length-2 walk, excluding direct neighbours.
        two_hop_only = ((adj_float @ adj_float) > 0) & ~adj

        for i, node in enumerate(order.tolist()):
            start = max(0, i - max_dist)
            width = i - start
            if width == 0:
                continue
            window = order[start:i]
            scores = torch.full((width,), 0.3, device=device)
            scores[two_hop_only[node, window]] = 0.6
            scores[adj[node, window]] = 0.9
            distances[node, :width] = scores

        return seq_pos, distances