File size: 5,817 Bytes
aab84ef 4992374 aab84ef 4992374 aab84ef 4992374 aab84ef 4992374 aab84ef 4992374 aab84ef 4992374 aab84ef 4992374 aab84ef 4992374 aab84ef 4992374 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.transforms import Compose
import numpy as np
class GraphProcessor:
    """Stateless graph-preprocessing helpers.

    Every method is a ``@staticmethod`` that operates either on tensors or on
    a graph object exposing ``x`` / ``edge_index`` / ``num_nodes`` (and
    optionally ``edge_attr`` / ``y``), e.g. ``torch_geometric.data.Data``.
    Optional attributes are read with ``getattr(..., None)`` and compared to
    ``None``: PyG ``Data`` reports unset standard attributes as ``None``, so a
    bare ``hasattr`` check is not enough to know a value is present.
    """

    @staticmethod
    def normalize_features(x, method='l2'):
        """Normalize a node-feature matrix.

        Args:
            x: Feature tensor with one row per node.
            method: ``'l2'`` scales each row to unit L2 norm; ``'minmax'``
                rescales each column to [0, 1]; ``'standard'`` z-scores each
                column. Any other value returns ``x`` unchanged.

        Returns:
            Tensor of the same shape as ``x``.
        """
        if method == 'l2':
            return F.normalize(x, p=2, dim=1)
        if method == 'minmax':
            if x.size(0) == 0:
                # min/max reductions over an empty dimension raise; nothing to scale.
                return x
            x_min = x.min(dim=0, keepdim=True)[0]
            x_max = x.max(dim=0, keepdim=True)[0]
            # eps keeps constant columns finite (they map to 0).
            return (x - x_min) / (x_max - x_min + 1e-8)
        if method == 'standard':
            mean = x.mean(dim=0)
            if x.size(0) > 1:
                std = x.std(dim=0)
            else:
                # The unbiased std of a single row is NaN. Treat it as zero
                # spread so the (all-zero) centred values stay finite.
                std = torch.zeros_like(mean)
            return (x - mean) / (std + 1e-8)
        return x

    @staticmethod
    def add_self_loops(edge_index, num_nodes):
        """Append an ``(i, i)`` edge for every ``i`` in ``range(num_nodes)``.

        Self loops that already exist are not deduplicated.

        Args:
            edge_index: Edge tensor of shape ``(2, num_edges)``.
            num_nodes: Number of nodes to add loops for.

        Returns:
            A new edge tensor of shape ``(2, num_edges + num_nodes)``.
        """
        self_loops = torch.arange(num_nodes, device=edge_index.device).unsqueeze(0).repeat(2, 1)
        return torch.cat([edge_index, self_loops], dim=1)

    @staticmethod
    def remove_self_loops(edge_index):
        """Return ``edge_index`` without the columns whose source equals its target."""
        mask = edge_index[0] != edge_index[1]
        return edge_index[:, mask]

    @staticmethod
    def add_positional_features(data, encoding_dim=8):
        """Append random-walk positional encodings to ``data.x``.

        Column ``k`` (0-based) holds the diagonal of ``S ** (k + 1)``, where
        ``S = D^-1/2 A D^-1/2`` and ``A`` is the binary, symmetrised adjacency
        matrix. ``S`` is similar to the random-walk matrix ``D^-1 A``, so these
        diagonals equal the ``(k + 1)``-step return probabilities.

        When the graph has no edges, a one-hot of the node index is used
        instead. Only the first ``encoding_dim`` nodes get a 1; the rest are
        all zeros.

        A dense ``num_nodes x num_nodes`` adjacency is built, so memory grows
        quadratically with the node count.

        Args:
            data: Graph object with ``num_nodes``, ``edge_index`` (may be
                ``None``) and an optional ``x``. It is modified in place.
            encoding_dim: Number of encoding columns to append.

        Returns:
            ``data``, whose ``x`` is extended by ``encoding_dim`` columns (or
            set to the encoding when ``x`` was ``None``).
        """
        num_nodes = data.num_nodes
        edge_index = getattr(data, 'edge_index', None)
        x = getattr(data, 'x', None)

        # Build the encoding on the device the features live on, so that
        # torch.cat below does not mix devices.
        if x is not None:
            device = x.device
        elif edge_index is not None:
            device = edge_index.device
        else:
            device = torch.device('cpu')

        pos_encoding = torch.zeros(num_nodes, encoding_dim, device=device)
        if edge_index is not None and edge_index.numel() > 0:
            edge_index = edge_index.to(device)
            adj = torch.zeros(num_nodes, num_nodes, device=device)
            adj[edge_index[0], edge_index[1]] = 1.0
            # Symmetrise without double-counting edges that are already stored
            # in both directions (adj + adj.t() would give them weight 2).
            adj = torch.maximum(adj, adj.t())
            degree = adj.sum(dim=1)
            degree[degree == 0] = 1  # isolated nodes: avoid division by zero
            d_inv_sqrt = degree.rsqrt()
            # Row/column scaling equals D^-1/2 @ A @ D^-1/2, without the two
            # dense matrix products.
            a_norm = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
            a_power = torch.eye(num_nodes, device=device)
            for k in range(encoding_dim):
                a_power = a_power @ a_norm
                pos_encoding[:, k] = a_power.diagonal()
        else:
            # No edges: one-hot of the node index, first encoding_dim nodes only.
            n = min(encoding_dim, num_nodes)
            idx = torch.arange(n, device=device)
            pos_encoding[idx, idx] = 1.0

        if x is not None:
            data.x = torch.cat([x, pos_encoding], dim=1)
        else:
            data.x = pos_encoding
        return data

    @staticmethod
    def _filter_edge_attr(data, edge_mask, num_edges):
        """Apply ``edge_mask`` to ``data.edge_attr`` when it has one row per edge."""
        edge_attr = getattr(data, 'edge_attr', None)
        if edge_attr is not None and edge_attr.dim() > 0 and edge_attr.size(0) == num_edges:
            data.edge_attr = edge_attr[edge_mask.to(edge_attr.device)]

    @staticmethod
    def augment_graph(data, aug_type='edge_drop', aug_ratio=0.1):
        """Apply one random augmentation to ``data`` in place and return it.

        Args:
            data: Graph object with ``edge_index`` and optionally ``x``,
                ``edge_attr`` and ``y``. It is mutated in place, so pass a
                copy if the original must stay intact.
            aug_type: One of the following. Any other value leaves ``data``
                unchanged.
                - ``'edge_drop'``: drop each edge, together with its
                  ``edge_attr`` row, with probability ``aug_ratio``.
                - ``'node_drop'``: drop each node with probability
                  ``aug_ratio``. Incident edges (and their ``edge_attr``
                  rows) are removed, the surviving nodes are re-indexed
                  contiguously, and ``x``, node-level ``y`` and ``num_nodes``
                  are updated.
                - ``'feature_noise'``: add Gaussian noise with std
                  ``aug_ratio`` to ``x``.
                - ``'feature_mask'``: zero each entry of ``x`` with
                  probability ``aug_ratio``.
            aug_ratio: Drop/mask probability, or the noise std for
                ``'feature_noise'``.

        Returns:
            The same ``data`` object.
        """
        if aug_type == 'edge_drop':
            edge_index = data.edge_index
            num_edges = edge_index.size(1)
            keep = torch.rand(num_edges, device=edge_index.device) > aug_ratio
            data.edge_index = edge_index[:, keep]
            GraphProcessor._filter_edge_attr(data, keep, num_edges)
        elif aug_type == 'node_drop':
            num_nodes = data.num_nodes
            edge_index = data.edge_index
            device = edge_index.device
            keep_mask = torch.rand(num_nodes, device=device) > aug_ratio
            keep_nodes = torch.where(keep_mask)[0]
            # Map old index -> new contiguous index (-1 for dropped nodes).
            node_map = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
            node_map[keep_nodes] = torch.arange(keep_nodes.numel(), device=device)
            # Keep only edges whose endpoints both survive, then re-index them.
            edge_mask = keep_mask[edge_index[0]] & keep_mask[edge_index[1]]
            data.edge_index = node_map[edge_index[:, edge_mask]]
            GraphProcessor._filter_edge_attr(data, edge_mask, edge_index.size(1))
            x = getattr(data, 'x', None)
            if x is not None:
                data.x = x[keep_nodes.to(x.device)]
            y = getattr(data, 'y', None)
            # Only node-level targets (one entry per node) are filtered.
            if y is not None and y.dim() > 0 and y.size(0) == num_nodes:
                data.y = y[keep_nodes.to(y.device)]
            # Keep num_nodes consistent even when it was stored explicitly or
            # cannot be inferred from x.
            data.num_nodes = keep_nodes.numel()
        elif aug_type == 'feature_noise':
            x = getattr(data, 'x', None)
            if x is not None:
                data.x = x + torch.randn_like(x) * aug_ratio
        elif aug_type == 'feature_mask':
            x = getattr(data, 'x', None)
            if x is not None:
                data.x = x * (torch.rand_like(x) > aug_ratio)
        return data

    @staticmethod
    def to_device_safe(data, device):
        """Recursively move ``data`` to ``device``.

        Objects with a ``.to`` method (tensors, PyG data objects) are moved
        directly. Lists, tuples (including named tuples) and dicts are rebuilt
        with their elements moved. Anything else is returned unchanged.
        """
        if hasattr(data, 'to'):
            return data.to(device)
        if isinstance(data, dict):
            return {k: GraphProcessor.to_device_safe(v, device) for k, v in data.items()}
        if isinstance(data, (list, tuple)):
            moved = [GraphProcessor.to_device_safe(item, device) for item in data]
            if isinstance(data, tuple):
                # Preserve tuple / namedtuple type instead of degrading to a list.
                return type(data)(*moved) if hasattr(data, '_fields') else tuple(moved)
            return moved
        return data

    @staticmethod
    def validate_data(data):
        """Check a graph object for structural inconsistencies.

        Returns:
            list[str]: Human-readable error messages. The list is empty when
            no problem was found.
        """
        errors = []
        edge_index = getattr(data, 'edge_index', None)
        num_nodes = getattr(data, 'num_nodes', None)
        x = getattr(data, 'x', None)

        edge_shape_ok = False
        if edge_index is None:
            errors.append("Missing edge_index")
        elif edge_index.dim() != 2 or edge_index.size(0) != 2:
            errors.append("edge_index must have shape (2, num_edges)")
        else:
            edge_shape_ok = True

        if x is not None and num_nodes is not None and x.shape[0] != num_nodes:
            errors.append("Feature matrix size mismatch")

        # Index-range checks are only meaningful on a well-formed edge_index.
        if edge_shape_ok and edge_index.numel() > 0:
            if edge_index.min().item() < 0:
                errors.append("Edge indices must be non-negative")
            if num_nodes is not None and edge_index.max().item() >= num_nodes:
                errors.append("Edge indices exceed number of nodes")
        return errors