|
import torch |
|
import torch.nn.functional as F |
|
from torch_geometric.data import Data |
|
from torch_geometric.transforms import Compose |
|
import numpy as np |
|
|
|
class GraphProcessor:
    """Stateless preprocessing, augmentation and validation helpers for graphs.

    Every method is a ``@staticmethod``.  Methods that take ``data`` expect an
    object exposing the attributes they read (``edge_index``, ``x``,
    ``num_nodes`` and optionally ``y`` / ``edge_attr``), as a
    ``torch_geometric.data.Data`` instance does.  ``edge_index`` is a
    ``(2, num_edges)`` integer tensor of (source, target) node indices.
    """

    @staticmethod
    def normalize_features(x, method='l2'):
        """Normalize a node feature matrix.

        Args:
            x: Feature tensor with one row per node.
            method: ``'l2'`` scales every row to unit L2 norm, ``'minmax'``
                rescales every column to ``[0, 1]``, ``'standard'`` z-scores
                every column.  Any other value returns ``x`` unchanged.

        Returns:
            The normalized tensor (or ``x`` itself for an unknown method).
        """
        if method == 'l2':
            return F.normalize(x, p=2, dim=1)

        if method == 'minmax':
            col_min = x.min(dim=0, keepdim=True)[0]
            col_max = x.max(dim=0, keepdim=True)[0]
            # The epsilon keeps constant columns from dividing by zero.
            return (x - col_min) / (col_max - col_min + 1e-8)

        if method == 'standard':
            mean = x.mean(dim=0, keepdim=True)
            if x.size(0) > 1:
                std = x.std(dim=0, keepdim=True)
            else:
                # The unbiased std of a single sample is NaN; treat the
                # spread as zero so the result is finite (all zeros).
                std = torch.zeros_like(mean)
            return (x - mean) / (std + 1e-8)

        return x

    @staticmethod
    def add_self_loops(edge_index, num_nodes):
        """Append one self loop ``(i, i)`` for every node ``0..num_nodes-1``.

        Existing self loops are not removed, so calling this on a graph that
        already contains them produces duplicates.

        Args:
            edge_index: ``(2, num_edges)`` edge tensor.
            num_nodes: Number of nodes in the graph.

        Returns:
            A new ``(2, num_edges + num_nodes)`` tensor on the same device
            and with the same dtype as ``edge_index``.
        """
        node_ids = torch.arange(
            num_nodes, dtype=edge_index.dtype, device=edge_index.device
        )
        loops = torch.stack([node_ids, node_ids], dim=0)
        return torch.cat([edge_index, loops], dim=1)

    @staticmethod
    def remove_self_loops(edge_index):
        """Return ``edge_index`` without the edges whose endpoints coincide."""
        not_loop = edge_index[0] != edge_index[1]
        return edge_index[:, not_loop]

    @staticmethod
    def add_positional_features(data, encoding_dim=8):
        """Append random-walk positional encodings to ``data.x`` (in place).

        Column ``k`` holds, for every node, the diagonal entry of the
        ``(k + 1)``-th power of the symmetrically normalized adjacency
        matrix, which equals the probability that a ``(k + 1)``-step random
        walk on the undirected graph returns to its start node.

        If the graph has no edges, the encoding is instead a truncated
        identity: node ``i`` gets a one in column ``i`` for
        ``i < min(encoding_dim, num_nodes)``.

        Note:
            A dense ``num_nodes x num_nodes`` matrix is built, so memory
            grows quadratically with the number of nodes.

        Args:
            data: Graph object; ``edge_index``, ``num_nodes`` and ``x`` are
                read and ``x`` is overwritten.
            encoding_dim: Number of encoding columns to add.

        Returns:
            The same ``data`` object with the encoding concatenated to
            ``data.x`` (or used as ``data.x`` when there were no features).
        """
        num_nodes = data.num_nodes
        edge_index = getattr(data, 'edge_index', None)
        features = getattr(data, 'x', None)

        # Build everything on the graph's own device so GPU graphs work.
        if edge_index is not None:
            device = edge_index.device
        elif features is not None:
            device = features.device
        else:
            device = torch.device('cpu')

        pos_encoding = torch.zeros(num_nodes, encoding_dim, device=device)
        has_edges = edge_index is not None and edge_index.shape[1] > 0

        if has_edges:
            adj = torch.zeros(num_nodes, num_nodes, device=device)
            adj[edge_index[0], edge_index[1]] = 1
            # Symmetrize as a binary matrix: an edge stored in both
            # directions (or a self loop) must not count twice.
            adj = ((adj + adj.t()) > 0).to(adj.dtype)

            # Isolated nodes get degree 1 to avoid dividing by zero.
            degree = adj.sum(dim=1).clamp(min=1)
            inv_sqrt = degree.rsqrt()
            # Same as D^-1/2 @ A @ D^-1/2 without building diagonal matrices.
            a_norm = adj * inv_sqrt.unsqueeze(1) * inv_sqrt.unsqueeze(0)

            walk = torch.eye(num_nodes, device=device)
            for step in range(encoding_dim):
                walk = walk @ a_norm
                pos_encoding[:, step] = walk.diagonal()
        else:
            ones = torch.arange(min(encoding_dim, num_nodes), device=device)
            pos_encoding[ones, ones] = 1.0

        if features is not None:
            data.x = torch.cat(
                [features, pos_encoding.to(features.device)], dim=1
            )
        else:
            data.x = pos_encoding

        return data

    @staticmethod
    def _filter_edge_attr(data, edge_mask, num_edges):
        """Keep ``data.edge_attr`` rows aligned with a filtered ``edge_index``."""
        edge_attr = getattr(data, 'edge_attr', None)
        if (
            torch.is_tensor(edge_attr)
            and edge_attr.dim() > 0
            and edge_attr.size(0) == num_edges
        ):
            data.edge_attr = edge_attr[edge_mask.to(edge_attr.device)]

    @staticmethod
    def augment_graph(data, aug_type='edge_drop', aug_ratio=0.1):
        """Randomly perturb a graph for training-time augmentation (in place).

        Args:
            data: Graph object to modify.
            aug_type: One of
                ``'edge_drop'`` - drop each edge with probability
                ``aug_ratio``;
                ``'node_drop'`` - drop each node with probability
                ``aug_ratio``, remove its edges and renumber the rest;
                ``'feature_noise'`` - add Gaussian noise with standard
                deviation ``aug_ratio`` to ``data.x``;
                ``'feature_mask'`` - zero each feature entry with
                probability ``aug_ratio``.
                Any other value leaves ``data`` untouched.
            aug_ratio: Drop probability or noise scale, see ``aug_type``.

        Returns:
            The same ``data`` object, modified in place.  Copy it first if
            the original graph is still needed.
        """
        if aug_type == 'edge_drop':
            edge_index = data.edge_index
            num_edges = edge_index.shape[1]
            keep_edges = (
                torch.rand(num_edges, device=edge_index.device) > aug_ratio
            )
            data.edge_index = edge_index[:, keep_edges]
            # Edge attributes must shrink together with the edge list.
            GraphProcessor._filter_edge_attr(data, keep_edges, num_edges)

        elif aug_type == 'node_drop':
            edge_index = data.edge_index
            device = edge_index.device
            num_nodes = data.num_nodes
            num_edges = edge_index.shape[1]

            keep_mask = torch.rand(num_nodes, device=device) > aug_ratio
            keep_nodes = torch.where(keep_mask)[0]

            # Old node id -> new consecutive id (-1 marks dropped nodes).
            node_map = torch.full(
                (num_nodes,), -1, dtype=torch.long, device=device
            )
            node_map[keep_nodes] = torch.arange(
                keep_nodes.numel(), device=device
            )

            # An edge survives only when both of its endpoints survive.
            edge_mask = keep_mask[edge_index[0]] & keep_mask[edge_index[1]]
            data.edge_index = node_map[edge_index[:, edge_mask]]
            GraphProcessor._filter_edge_attr(data, edge_mask, num_edges)

            features = getattr(data, 'x', None)
            if features is not None:
                data.x = features[keep_nodes.to(features.device)]

            # Only node-level labels (one row per node) are filtered.
            labels = getattr(data, 'y', None)
            if (
                torch.is_tensor(labels)
                and labels.dim() > 0
                and labels.size(0) == num_nodes
            ):
                data.y = labels[keep_nodes.to(labels.device)]

            # Keep an explicitly stored node count in sync with the graph.
            data.num_nodes = int(keep_nodes.numel())

        elif aug_type == 'feature_noise':
            if data.x is not None:
                data.x = data.x + torch.randn_like(data.x) * aug_ratio

        elif aug_type == 'feature_mask':
            if data.x is not None:
                keep_entries = torch.rand_like(data.x) > aug_ratio
                data.x = data.x * keep_entries

        return data

    @staticmethod
    def to_device_safe(data, device):
        """Move ``data`` to ``device``, recursing through containers.

        Objects with a ``.to`` method are moved with it; lists, tuples
        (including namedtuples) and dicts are rebuilt with every element
        moved; anything else is returned unchanged.

        Args:
            data: Tensor-like object, container of such objects, or any
                other value.
            device: Target device, passed through to ``.to``.

        Returns:
            The moved object.  Lists stay lists and tuples stay tuples;
            dicts come back as plain ``dict``.
        """
        if hasattr(data, 'to'):
            return data.to(device)

        if isinstance(data, dict):
            return {
                key: GraphProcessor.to_device_safe(value, device)
                for key, value in data.items()
            }

        if isinstance(data, (list, tuple)):
            moved = [
                GraphProcessor.to_device_safe(item, device) for item in data
            ]
            if isinstance(data, tuple):
                # Preserve the container type; namedtuples take their
                # fields as positional arguments.
                if hasattr(data, '_fields'):
                    return type(data)(*moved)
                return tuple(moved)
            return moved

        return data

    @staticmethod
    def validate_data(data):
        """Check a graph object for structural inconsistencies.

        Args:
            data: Graph object; ``edge_index``, ``x`` and ``num_nodes`` are
                read when present.  Checks that need a missing (or ``None``)
                attribute are skipped.

        Returns:
            A list of human-readable error strings; empty when no problem
            was found.
        """
        errors = []

        edge_index = getattr(data, 'edge_index', None)
        features = getattr(data, 'x', None)
        num_nodes = getattr(data, 'num_nodes', None)

        # A present-but-None edge_index counts as missing.
        if edge_index is None:
            errors.append("Missing edge_index")
        elif edge_index.dim() != 2 or edge_index.shape[0] != 2:
            errors.append("edge_index must have shape (2, num_edges)")

        if (
            features is not None
            and num_nodes is not None
            and features.shape[0] != num_nodes
        ):
            errors.append("Feature matrix size mismatch")

        if edge_index is not None and edge_index.numel() > 0:
            if num_nodes is not None and edge_index.max().item() >= num_nodes:
                errors.append("Edge indices exceed number of nodes")
            if edge_index.min().item() < 0:
                errors.append("Edge indices must be non-negative")

        return errors