import torch import torch.nn.functional as F from torch_geometric.data import Data from torch_geometric.transforms import Compose import numpy as np class GraphProcessor: """Advanced data preprocessing utilities""" @staticmethod def normalize_features(x, method='l2'): """Normalize node features""" if method == 'l2': return F.normalize(x, p=2, dim=1) elif method == 'minmax': x_min = x.min(dim=0, keepdim=True)[0] x_max = x.max(dim=0, keepdim=True)[0] return (x - x_min) / (x_max - x_min + 1e-8) elif method == 'standard': return (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-8) else: return x @staticmethod def add_self_loops(edge_index, num_nodes): """Add self loops to graph""" self_loops = torch.arange(num_nodes, device=edge_index.device).unsqueeze(0).repeat(2, 1) edge_index = torch.cat([edge_index, self_loops], dim=1) return edge_index @staticmethod def remove_self_loops(edge_index): """Remove self loops from graph""" mask = edge_index[0] != edge_index[1] return edge_index[:, mask] @staticmethod def add_positional_features(data, encoding_dim=8): """Add positional encodings as features""" num_nodes = data.num_nodes # Random walk positional encoding if data.edge_index.shape[1] > 0: adj = torch.zeros(num_nodes, num_nodes) adj[data.edge_index[0], data.edge_index[1]] = 1 adj = adj + adj.t() # Make symmetric # Degree normalization degree = adj.sum(dim=1) degree[degree == 0] = 1 # Avoid division by zero D_inv_sqrt = torch.diag(1.0 / torch.sqrt(degree)) # Normalized adjacency A_norm = D_inv_sqrt @ adj @ D_inv_sqrt # Random walk features rw_features = [] A_power = torch.eye(num_nodes) for k in range(encoding_dim): A_power = A_power @ A_norm rw_features.append(A_power.diag().unsqueeze(1)) pos_encoding = torch.cat(rw_features, dim=1) else: # No edges - use node indices pos_encoding = torch.zeros(num_nodes, encoding_dim) for i in range(min(encoding_dim, num_nodes)): pos_encoding[i, i] = 1.0 # Concatenate with existing features if data.x is not None: data.x = torch.cat([data.x, pos_encoding], dim=1) else: data.x = pos_encoding return data @staticmethod def augment_graph(data, aug_type='edge_drop', aug_ratio=0.1): """Graph augmentation for training""" if aug_type == 'edge_drop': # Randomly drop edges num_edges = data.edge_index.shape[1] mask = torch.rand(num_edges) > aug_ratio data.edge_index = data.edge_index[:, mask] elif aug_type == 'node_drop': # Randomly drop nodes num_nodes = data.num_nodes keep_mask = torch.rand(num_nodes) > aug_ratio keep_nodes = torch.where(keep_mask)[0] # Update edge index node_map = torch.full((num_nodes,), -1, dtype=torch.long) node_map[keep_nodes] = torch.arange(len(keep_nodes)) # Filter edges edge_mask = keep_mask[data.edge_index[0]] & keep_mask[data.edge_index[1]] filtered_edges = data.edge_index[:, edge_mask] data.edge_index = node_map[filtered_edges] # Update features data.x = data.x[keep_nodes] if hasattr(data, 'y') and data.y.size(0) == num_nodes: data.y = data.y[keep_nodes] elif aug_type == 'feature_noise': # Add Gaussian noise to features if data.x is not None: noise = torch.randn_like(data.x) * aug_ratio data.x = data.x + noise elif aug_type == 'feature_mask': # Randomly mask features if data.x is not None: mask = torch.rand_like(data.x) > aug_ratio data.x = data.x * mask return data @staticmethod def to_device_safe(data, device): """Move data to device safely""" if hasattr(data, 'to'): return data.to(device) elif isinstance(data, (list, tuple)): return [GraphProcessor.to_device_safe(item, device) for item in data] elif isinstance(data, dict): return {k: GraphProcessor.to_device_safe(v, device) for k, v in data.items()} else: return data @staticmethod def validate_data(data): """Validate graph data integrity""" errors = [] # Check basic structure if not hasattr(data, 'edge_index'): errors.append("Missing edge_index") elif data.edge_index.shape[0] != 2: errors.append("edge_index must have shape (2, num_edges)") if hasattr(data, 'x') and data.x is not None: if hasattr(data, 'num_nodes') and data.x.shape[0] != data.num_nodes: errors.append("Feature matrix size mismatch") # Check edge indices if hasattr(data, 'edge_index') and data.edge_index.shape[1] > 0: max_idx = data.edge_index.max().item() if hasattr(data, 'num_nodes') and max_idx >= data.num_nodes: errors.append("Edge indices exceed number of nodes") return errors