# serpent/data/processor.py
# Last revision: "Update data/processor.py" (commit 4992374, verified; author: kfoughali)
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.transforms import Compose
import numpy as np
class GraphProcessor:
    """Advanced data preprocessing utilities for graph data.

    All methods are static. ``data`` arguments are expected to expose
    PyTorch Geometric ``Data``-style attributes (``x``, ``y``, ``edge_index``,
    ``edge_attr``, ``num_nodes``). Attributes that are missing or ``None`` are
    tolerated where noted.
    """

    @staticmethod
    def normalize_features(x, method='l2'):
        """Normalize node features.

        Args:
            x: feature tensor with one row per node (nodes along dim 0).
            method: ``'l2'`` scales each row to unit L2 norm; ``'minmax'``
                rescales each column to [0, 1]; ``'standard'`` gives each
                column zero mean and unit std. Any other value returns
                ``x`` unchanged.

        Returns:
            The normalized tensor, or ``x`` itself for an unknown method.
        """
        if method == 'l2':
            return F.normalize(x, p=2, dim=1)
        if method == 'minmax':
            x_min = x.min(dim=0, keepdim=True)[0]
            x_max = x.max(dim=0, keepdim=True)[0]
            # The epsilon maps constant columns to 0 instead of dividing by zero.
            return (x - x_min) / (x_max - x_min + 1e-8)
        if method == 'standard':
            # NOTE: torch.std is unbiased by default, so a single-row input yields NaN.
            return (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-8)
        return x

    @staticmethod
    def add_self_loops(edge_index, num_nodes):
        """Append a self loop ``(i, i)`` for every node ``i`` in ``0..num_nodes-1``.

        Existing self loops are not checked for, so they end up duplicated.
        Returns a new ``(2, E + num_nodes)`` edge index on ``edge_index``'s device.
        """
        nodes = torch.arange(num_nodes, device=edge_index.device)
        self_loops = nodes.unsqueeze(0).repeat(2, 1)
        return torch.cat([edge_index, self_loops], dim=1)

    @staticmethod
    def remove_self_loops(edge_index):
        """Return ``edge_index`` without the columns whose source equals their target."""
        mask = edge_index[0] != edge_index[1]
        return edge_index[:, mask]

    @staticmethod
    def _random_walk_encoding(edge_index, num_nodes, encoding_dim):
        """Return a ``(num_nodes, encoding_dim)`` tensor of k-step return probabilities.

        Column ``k - 1`` holds the diagonal of ``(D^-1/2 A D^-1/2)^k``. That
        matrix is similar to ``(D^-1 A)^k`` via a diagonal transform, so its
        diagonal is exactly the random-walk return probability. Edges are
        treated as undirected and unweighted. This builds dense n x n
        matrices: O(n^2) memory and O(encoding_dim * n^3) time.
        """
        device = edge_index.device
        adj = torch.zeros(num_nodes, num_nodes, device=device)
        adj[edge_index[0], edge_index[1]] = 1.0
        # Symmetrize without double-weighting edges already given in both directions.
        adj = (adj + adj.t()).clamp(max=1.0)
        degree = adj.sum(dim=1)
        degree[degree == 0] = 1  # isolated nodes: avoid division by zero
        d_inv_sqrt = 1.0 / torch.sqrt(degree)
        # D^-1/2 A D^-1/2 by broadcasting instead of two dense matmuls with diag(D).
        a_norm = d_inv_sqrt.unsqueeze(1) * adj * d_inv_sqrt.unsqueeze(0)
        columns = []
        a_power = torch.eye(num_nodes, device=device)
        for _ in range(encoding_dim):
            a_power = a_power @ a_norm
            columns.append(a_power.diagonal())
        if not columns:
            return adj.new_zeros(num_nodes, 0)
        return torch.stack(columns, dim=1)

    @staticmethod
    def add_positional_features(data, encoding_dim=8):
        """Append ``encoding_dim`` positional-encoding columns to ``data.x``.

        If the graph has edges, the columns are random-walk return
        probabilities for 1..encoding_dim steps. Without edges (or with no
        ``edge_index``), the first ``min(encoding_dim, num_nodes)`` nodes get a
        one-hot encoding of their index and the rest get zeros.

        ``data.x`` is set to the encoding if it is None; otherwise the
        encoding is moved to ``x``'s device (and floating dtype) and
        concatenated along dim 1. Mutates and returns ``data``.
        """
        x = getattr(data, 'x', None)
        edge_index = getattr(data, 'edge_index', None)
        num_nodes = data.num_nodes

        if edge_index is not None and edge_index.size(1) > 0:
            pos_encoding = GraphProcessor._random_walk_encoding(
                edge_index, num_nodes, encoding_dim)
        else:
            # No edges: fall back to one-hot node indices.
            device = None if edge_index is None else edge_index.device
            pos_encoding = torch.zeros(num_nodes, encoding_dim, device=device)
            diag = torch.arange(min(encoding_dim, num_nodes), device=device)
            pos_encoding[diag, diag] = 1.0

        if x is not None:
            dtype = x.dtype if x.is_floating_point() else pos_encoding.dtype
            pos_encoding = pos_encoding.to(device=x.device, dtype=dtype)
            data.x = torch.cat([x, pos_encoding], dim=1)
        else:
            data.x = pos_encoding
        return data

    @staticmethod
    def _filter_edge_attrs(data, edge_mask, num_edges):
        """Apply ``edge_mask`` to per-edge tensors so they stay aligned with ``edge_index``."""
        for key in ('edge_attr', 'edge_weight'):
            value = getattr(data, key, None)
            if torch.is_tensor(value) and value.dim() > 0 and value.size(0) == num_edges:
                setattr(data, key, value[edge_mask.to(value.device)])

    @staticmethod
    def _drop_nodes(data, aug_ratio):
        """Drop each node with probability ``aug_ratio`` and re-index the survivors.

        Filters ``edge_index`` (and per-edge attributes), ``x``, and a
        node-level ``y``, then updates ``num_nodes``. Other node-level
        attributes (e.g. masks, ``pos``) are not filtered.
        """
        num_nodes = data.num_nodes
        edge_index = getattr(data, 'edge_index', None)
        x = getattr(data, 'x', None)
        if edge_index is not None:
            device = edge_index.device
        elif x is not None:
            device = x.device
        else:
            device = None

        keep_mask = torch.rand(num_nodes, device=device) > aug_ratio
        keep_nodes = torch.where(keep_mask)[0]

        if edge_index is not None:
            # Map old node ids to new contiguous ids; dropped nodes map to -1.
            node_map = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
            node_map[keep_nodes] = torch.arange(keep_nodes.numel(), device=device)
            # Keep only edges whose endpoints both survive, so no -1 remains.
            edge_mask = keep_mask[edge_index[0]] & keep_mask[edge_index[1]]
            data.edge_index = node_map[edge_index[:, edge_mask]]
            GraphProcessor._filter_edge_attrs(data, edge_mask, edge_index.size(1))

        if x is not None:
            data.x = x[keep_nodes.to(x.device)]
        y = getattr(data, 'y', None)
        # Only node-level labels (one row per node) are filtered.
        if torch.is_tensor(y) and y.dim() > 0 and y.size(0) == num_nodes:
            data.y = y[keep_nodes.to(y.device)]
        data.num_nodes = int(keep_nodes.numel())

    @staticmethod
    def augment_graph(data, aug_type='edge_drop', aug_ratio=0.1):
        """Graph augmentation for training. Mutates ``data`` in place and returns it.

        Args:
            data: graph object; modified in place.
            aug_type: one of
                ``'edge_drop'``: drop each edge with probability ``aug_ratio``;
                ``'node_drop'``: drop each node with probability ``aug_ratio``,
                    together with its incident edges, and re-index the rest;
                ``'feature_noise'``: add Gaussian noise scaled by ``aug_ratio``
                    to ``x``;
                ``'feature_mask'``: zero each entry of ``x`` with probability
                    ``aug_ratio``.
                Any other value leaves ``data`` unchanged.
            aug_ratio: drop/mask probability or noise scale.

        Per-edge ``edge_attr`` / ``edge_weight`` tensors are filtered together
        with ``edge_index``.
        """
        if aug_type == 'edge_drop':
            edge_index = getattr(data, 'edge_index', None)
            if edge_index is not None:
                num_edges = edge_index.size(1)
                keep = torch.rand(num_edges, device=edge_index.device) > aug_ratio
                data.edge_index = edge_index[:, keep]
                GraphProcessor._filter_edge_attrs(data, keep, num_edges)
        elif aug_type == 'node_drop':
            GraphProcessor._drop_nodes(data, aug_ratio)
        elif aug_type == 'feature_noise':
            if getattr(data, 'x', None) is not None:
                noise = torch.randn_like(data.x) * aug_ratio
                data.x = data.x + noise
        elif aug_type == 'feature_mask':
            if getattr(data, 'x', None) is not None:
                mask = torch.rand_like(data.x) > aug_ratio
                data.x = data.x * mask
        return data

    @staticmethod
    def to_device_safe(data, device):
        """Recursively move ``data`` to ``device``.

        Objects with a ``.to`` method are moved directly. Lists, tuples
        (including namedtuples) and dicts are rebuilt with the same container
        type and moved elements. Anything else is returned unchanged.
        """
        if hasattr(data, 'to'):
            return data.to(device)
        if isinstance(data, dict):
            return {k: GraphProcessor.to_device_safe(v, device) for k, v in data.items()}
        if isinstance(data, (list, tuple)):
            moved = [GraphProcessor.to_device_safe(item, device) for item in data]
            if isinstance(data, list):
                return moved
            if hasattr(data, '_fields'):  # namedtuple
                return type(data)(*moved)
            return tuple(moved)
        return data

    @staticmethod
    def validate_data(data):
        """Validate graph data integrity.

        Returns:
            A list of human-readable error strings; empty if no problem was found.
        """
        errors = []
        edge_index = getattr(data, 'edge_index', None)
        num_nodes = getattr(data, 'num_nodes', None)

        # Check basic structure
        if edge_index is None:
            errors.append("Missing edge_index")
        elif edge_index.dim() != 2 or edge_index.size(0) != 2:
            errors.append("edge_index must have shape (2, num_edges)")

        x = getattr(data, 'x', None)
        if x is not None and num_nodes is not None and x.shape[0] != num_nodes:
            errors.append("Feature matrix size mismatch")

        # Check edge indices (only meaningful for a 2-D, non-empty index tensor)
        if edge_index is not None and edge_index.dim() == 2 and edge_index.size(1) > 0:
            max_idx = edge_index.max().item()
            if num_nodes is not None and max_idx >= num_nodes:
                errors.append("Edge indices exceed number of nodes")
            if edge_index.min().item() < 0:
                errors.append("Edge indices must be non-negative")
        return errors