import torch
from torch_geometric.datasets import Planetoid, TUDataset, Amazon, Coauthor
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import NormalizeFeatures, Compose
from torch_geometric.utils import degree
import yaml
import os
class GraphDataLoader:
"""
Production data loading with comprehensive dataset support
"""
def __init__(self, config_path='config.yaml'):
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                # safe_load returns None for an empty file
                self.config = yaml.safe_load(f) or {}
        else:
            # Default config
            self.config = {
                'data': {
                    'batch_size': 32,
                    'test_split': 0.2
                }
            }
        # Use .get with defaults so a partial config file does not crash
        data_cfg = self.config.get('data', {})
        self.batch_size = data_cfg.get('batch_size', 32)
        self.test_split = data_cfg.get('test_split', 0.2)
# Standard transform
self.transform = Compose([
NormalizeFeatures()
])
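    # A minimal config.yaml matching the keys read in __init__ might look
    # like this (illustrative sketch; only data.batch_size and
    # data.test_split are consumed by this class):
    #
    #   data:
    #     batch_size: 32
    #     test_split: 0.2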
def load_node_classification_data(self, dataset_name='Cora'):
"""Load node classification datasets with proper splits"""
try:
if dataset_name in ['Cora', 'CiteSeer', 'PubMed']:
dataset = Planetoid(
root=f'./data/{dataset_name}',
name=dataset_name,
transform=self.transform
)
elif dataset_name in ['Computers', 'Photo']:
dataset = Amazon(
root=f'./data/Amazon{dataset_name}',
name=dataset_name,
transform=self.transform
)
elif dataset_name in ['CS', 'Physics']:
dataset = Coauthor(
root=f'./data/Coauthor{dataset_name}',
name=dataset_name,
transform=self.transform
)
else:
print(f"Unknown dataset {dataset_name}, falling back to Cora")
dataset = Planetoid(
root='./data/Cora',
name='Cora',
transform=self.transform
)
except Exception as e:
print(f"Error loading {dataset_name}: {e}")
# Fallback to Cora
dataset = Planetoid(
root='./data/Cora',
name='Cora',
transform=self.transform
)
        # Note: dataset[0] materializes a fresh Data copy, so masks added
        # here would not persist on the returned dataset. Splits are
        # therefore ensured on the object actually used downstream; see
        # create_dataloaders, which calls _ensure_masks.
        return dataset
def _ensure_masks(self, data):
"""Ensure train/val/test masks exist"""
num_nodes = data.num_nodes
if not hasattr(data, 'train_mask') or data.train_mask is None:
# Create random splits
indices = torch.randperm(num_nodes)
train_size = int(0.6 * num_nodes)
val_size = int(0.2 * num_nodes)
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[indices[:train_size]] = True
val_mask[indices[train_size:train_size + val_size]] = True
test_mask[indices[train_size + val_size:]] = True
data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask
def load_graph_classification_data(self, dataset_name='MUTAG'):
"""Load graph classification datasets"""
valid_datasets = ['MUTAG', 'ENZYMES', 'PROTEINS', 'COLLAB', 'IMDB-BINARY', 'DD']
try:
            if dataset_name not in valid_datasets:
                print(f"Unknown dataset {dataset_name}, falling back to MUTAG")
                dataset_name = 'MUTAG'
dataset = TUDataset(
root=f'./data/{dataset_name}',
name=dataset_name,
transform=self.transform
)
            # Handle datasets without node features by using normalized
            # node degree. TUDataset items are materialized on access, so
            # in-place edits would be lost; collect patched copies instead.
            if dataset[0].x is None:
                max_degree = 0
                for data in dataset:
                    if data.num_edges > 0:
                        deg = degree(data.edge_index[0], num_nodes=data.num_nodes)
                        max_degree = max(max_degree, int(deg.max().item()))
                graphs = []
                for data in dataset:
                    if data.num_edges > 0:
                        deg = degree(data.edge_index[0], num_nodes=data.num_nodes)
                        data.x = (deg / max(max_degree, 1)).unsqueeze(1)
                    else:
                        data.x = torch.zeros(data.num_nodes, 1)
                    graphs.append(data)
                dataset = graphs
except Exception as e:
print(f"Error loading {dataset_name}: {e}")
# Create minimal synthetic dataset
from torch_geometric.data import Data
dataset = [
Data(
x=torch.randn(10, 5),
edge_index=torch.randint(0, 10, (2, 20)),
y=torch.randint(0, 2, (1,))
) for _ in range(100)
]
return dataset
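    # Design note (hedged alternative, not used above): PyG also ships
    # torch_geometric.transforms.OneHotDegree, which can be passed as the
    # dataset transform to bake degree-based features in at load time
    # instead of patching Data objects after the fact.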
    def create_dataloaders(self, dataset, task_type='node_classification'):
        """Create train/val/test splits; returns (data, None, None) for the
        transductive node task and three DataLoaders for the graph task."""
        if task_type == 'node_classification':
            # Single graph, transductive setting: splits are boolean masks
            # on the one Data object, so no DataLoaders are needed.
            data = dataset[0]
            self._ensure_masks(data)
            return data, None, None
        elif task_type == 'graph_classification':
            # Shuffle, then split 80/10/10 into train/val/test
            num_graphs = len(dataset)
            indices = torch.randperm(num_graphs)
            train_size = int(0.8 * num_graphs)
            val_size = int(0.1 * num_graphs)
            train_dataset = [dataset[i] for i in indices[:train_size]]
            val_dataset = [dataset[i] for i in indices[train_size:train_size + val_size]]
            test_dataset = [dataset[i] for i in indices[train_size + val_size:]]
            train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
            test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)
            return train_loader, val_loader, test_loader
        else:
            raise ValueError(f"Unknown task_type: {task_type}")
def get_dataset_info(self, dataset):
"""Get comprehensive dataset information"""
try:
if hasattr(dataset, 'num_features'):
num_features = dataset.num_features
else:
num_features = dataset[0].x.size(1) if dataset[0].x is not None else 1
if hasattr(dataset, 'num_classes'):
num_classes = dataset.num_classes
else:
if hasattr(dataset[0], 'y') and dataset[0].y is not None:
if len(dataset) > 1:
all_labels = []
for data in dataset:
if data.y is not None:
all_labels.extend(data.y.flatten().tolist())
num_classes = len(set(all_labels)) if all_labels else 2
else:
num_classes = len(torch.unique(dataset[0].y))
else:
num_classes = 2
            num_graphs = len(dataset)
            # Per-graph size statistics, computed once and reused below
            node_counts = [data.num_nodes for data in dataset]
            edge_counts = [data.num_edges for data in dataset]
            total_nodes = sum(node_counts)
            total_edges = sum(edge_counts)
            avg_nodes = total_nodes / num_graphs
            avg_edges = total_edges / num_graphs
stats = {
'num_features': num_features,
'num_classes': num_classes,
'num_graphs': num_graphs,
'avg_nodes': avg_nodes,
'avg_edges': avg_edges,
'min_nodes': min(node_counts),
'max_nodes': max(node_counts),
'min_edges': min(edge_counts),
'max_edges': max(edge_counts),
'total_nodes': total_nodes,
'total_edges': total_edges
}
except Exception as e:
print(f"Error getting dataset info: {e}")
            # Fall back to Cora-sized defaults
stats = {
'num_features': 1433,
'num_classes': 7,
'num_graphs': 1,
'avg_nodes': 2708.0,
'avg_edges': 10556.0,
'min_nodes': 2708,
'max_nodes': 2708,
'min_edges': 10556,
'max_edges': 10556,
'total_nodes': 2708,
'total_edges': 10556
}
        return stats
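

if __name__ == '__main__':
    # Illustrative usage sketch, not part of the class API: exercises both
    # tasks end to end with the defaults above. Assumes the benchmark
    # datasets can be downloaded into ./data on first run.
    loader = GraphDataLoader()

    # Transductive node classification on a single graph with masks
    cora = loader.load_node_classification_data('Cora')
    data, _, _ = loader.create_dataloaders(cora, task_type='node_classification')
    print(f"Cora: {data.num_nodes} nodes, "
          f"{int(data.train_mask.sum())} train / "
          f"{int(data.val_mask.sum())} val / "
          f"{int(data.test_mask.sum())} test")

    # Inductive graph classification with batched DataLoaders
    mutag = loader.load_graph_classification_data('MUTAG')
    train_loader, val_loader, test_loader = loader.create_dataloaders(
        mutag, task_type='graph_classification'
    )
    print(f"MUTAG info: {loader.get_dataset_info(mutag)}")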