import torch
import torch.nn.functional as F
from torch_geometric.data import Data

class GraphProcessor:
    """Advanced data preprocessing utilities"""
    
    @staticmethod
    def normalize_features(x, method='l2'):
        """Normalize node features"""
        if method == 'l2':
            return F.normalize(x, p=2, dim=1)
        elif method == 'minmax':
            x_min = x.min(dim=0, keepdim=True)[0]
            x_max = x.max(dim=0, keepdim=True)[0]
            return (x - x_min) / (x_max - x_min + 1e-8)
        elif method == 'standard':
            return (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-8)
        else:
            return x
    
    @staticmethod
    def add_self_loops(edge_index, num_nodes):
        """Add self loops to graph"""
        self_loops = torch.arange(num_nodes, device=edge_index.device).unsqueeze(0).repeat(2, 1)
        edge_index = torch.cat([edge_index, self_loops], dim=1)
        return edge_index
    
    @staticmethod
    def remove_self_loops(edge_index):
        """Remove self loops from graph"""
        mask = edge_index[0] != edge_index[1]
        return edge_index[:, mask]
    
    @staticmethod
    def add_positional_features(data, encoding_dim=8):
        """Add random-walk positional encodings as extra node features"""
        num_nodes = data.num_nodes
        device = data.edge_index.device

        if data.edge_index.shape[1] > 0:
            # Dense, symmetric, binary adjacency (note: O(N^2) memory)
            adj = torch.zeros(num_nodes, num_nodes, device=device)
            adj[data.edge_index[0], data.edge_index[1]] = 1
            adj = ((adj + adj.t()) > 0).float()  # Symmetrize without double-counting edges

            # Symmetric degree normalization: D^{-1/2} A D^{-1/2}
            degree = adj.sum(dim=1)
            degree[degree == 0] = 1  # Avoid division by zero for isolated nodes
            D_inv_sqrt = torch.diag(1.0 / torch.sqrt(degree))
            A_norm = D_inv_sqrt @ adj @ D_inv_sqrt

            # k-step random-walk return probabilities (diagonal of A_norm^k)
            rw_features = []
            A_power = torch.eye(num_nodes, device=device)

            for k in range(encoding_dim):
                A_power = A_power @ A_norm
                rw_features.append(A_power.diag().unsqueeze(1))

            pos_encoding = torch.cat(rw_features, dim=1)
        else:
            # No edges: fall back to a one-hot encoding of the first
            # min(encoding_dim, num_nodes) nodes
            pos_encoding = torch.zeros(num_nodes, encoding_dim, device=device)
            for i in range(min(encoding_dim, num_nodes)):
                pos_encoding[i, i] = 1.0

        # Concatenate with existing features, or use as the only features
        if data.x is not None:
            data.x = torch.cat([data.x, pos_encoding], dim=1)
        else:
            data.x = pos_encoding

        return data
    
    @staticmethod
    def augment_graph(data, aug_type='edge_drop', aug_ratio=0.1):
        """Graph augmentation for training (modifies `data` in place)"""
        device = data.edge_index.device

        if aug_type == 'edge_drop':
            # Randomly drop a fraction of edges
            num_edges = data.edge_index.shape[1]
            mask = torch.rand(num_edges, device=device) > aug_ratio
            data.edge_index = data.edge_index[:, mask]

        elif aug_type == 'node_drop':
            # Randomly drop nodes and relabel the survivors
            num_nodes = data.num_nodes
            keep_mask = torch.rand(num_nodes, device=device) > aug_ratio
            keep_nodes = torch.where(keep_mask)[0]

            # Map old node indices to new, compact indices
            node_map = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
            node_map[keep_nodes] = torch.arange(len(keep_nodes), device=device)

            # Keep only edges whose endpoints both survive
            edge_mask = keep_mask[data.edge_index[0]] & keep_mask[data.edge_index[1]]
            filtered_edges = data.edge_index[:, edge_mask]
            data.edge_index = node_map[filtered_edges]

            # Update node-level attributes
            if data.x is not None:
                data.x = data.x[keep_nodes]
            if hasattr(data, 'y') and data.y is not None and data.y.size(0) == num_nodes:
                data.y = data.y[keep_nodes]
            data.num_nodes = int(keep_mask.sum())

        elif aug_type == 'feature_noise':
            # Add Gaussian noise to features
            if data.x is not None:
                noise = torch.randn_like(data.x) * aug_ratio
                data.x = data.x + noise

        elif aug_type == 'feature_mask':
            # Randomly zero out individual feature entries
            if data.x is not None:
                mask = (torch.rand_like(data.x) > aug_ratio).to(data.x.dtype)
                data.x = data.x * mask

        return data
    
    @staticmethod
    def to_device_safe(data, device):
        """Move (possibly nested) data structures to a device"""
        if hasattr(data, 'to'):
            return data.to(device)
        elif isinstance(data, (list, tuple)):
            # Recurse into containers, preserving list vs. tuple type
            moved = [GraphProcessor.to_device_safe(item, device) for item in data]
            return tuple(moved) if isinstance(data, tuple) else moved
        elif isinstance(data, dict):
            return {k: GraphProcessor.to_device_safe(v, device) for k, v in data.items()}
        else:
            return data
    
    @staticmethod
    def validate_data(data):
        """Validate graph data integrity; returns a list of error messages"""
        errors = []

        # Check basic structure
        if not hasattr(data, 'edge_index') or data.edge_index is None:
            errors.append("Missing edge_index")
        elif data.edge_index.dim() != 2 or data.edge_index.shape[0] != 2:
            errors.append("edge_index must have shape (2, num_edges)")

        if getattr(data, 'x', None) is not None:
            if hasattr(data, 'num_nodes') and data.x.shape[0] != data.num_nodes:
                errors.append("Feature matrix size mismatch")

        # Check that edge indices reference valid nodes
        if getattr(data, 'edge_index', None) is not None and data.edge_index.shape[1] > 0:
            if data.edge_index.min().item() < 0:
                errors.append("Edge indices must be non-negative")
            max_idx = data.edge_index.max().item()
            if hasattr(data, 'num_nodes') and max_idx >= data.num_nodes:
                errors.append("Edge indices exceed number of nodes")

        return errors
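

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original utilities): a minimal, hypothetical
# example that exercises the GraphProcessor helpers on a tiny toy graph. The
# graph below (4 nodes on a cycle) and all tensor values are made up purely
# for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Toy graph: 4 nodes on a directed cycle, 3-dimensional random features
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)
    x = torch.randn(4, 3)
    data = Data(x=x, edge_index=edge_index, num_nodes=4)

    # Sanity-check the graph before touching it
    print("validation errors:", GraphProcessor.validate_data(data))

    # Feature normalization and structural preprocessing
    data.x = GraphProcessor.normalize_features(data.x, method='l2')
    data.edge_index = GraphProcessor.add_self_loops(data.edge_index, data.num_nodes)
    data.edge_index = GraphProcessor.remove_self_loops(data.edge_index)

    # Append random-walk positional encodings (feature dim grows by 8)
    data = GraphProcessor.add_positional_features(data, encoding_dim=8)
    print("features with positional encodings:", tuple(data.x.shape))

    # Training-time augmentation (operates in place)
    data = GraphProcessor.augment_graph(data, aug_type='edge_drop', aug_ratio=0.1)

    # Move everything to GPU if one is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = GraphProcessor.to_device_safe(data, device)
    print("device:", data.x.device, "| remaining edges:", data.edge_index.shape[1])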