#!/usr/bin/env python3
"""
🚨 EMERGENCY OVERFITTING FIX 🚨
Tiny GraphMamba designed specifically for 140 training samples
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils import to_undirected, add_self_loops
import torch.optim as optim
import time
def get_device():
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print(f"🚀 Using GPU: {torch.cuda.get_device_name()}")
        torch.cuda.empty_cache()
    else:
        device = torch.device('cpu')
        print("💻 Using CPU")
    return device
class EmergencyTinyMamba(nn.Module):
    """Emergency ultra-tiny model for 140 samples"""
    def __init__(self, input_dim=1433, hidden_dim=8, num_classes=7):
        super().__init__()
        # TINY feature extraction
        self.feature_reduce = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Dropout(0.9),  # Extreme dropout
            nn.Linear(32, hidden_dim)
        )
        # Single GCN layer
        self.gcn = GCNConv(hidden_dim, hidden_dim)
        # Tiny "Mamba-inspired" temporal processing
        self.temporal = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),  # Bounded activation
            nn.Dropout(0.9)
        )
        # Direct classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.95),  # Extreme dropout before classification
            nn.Linear(hidden_dim, num_classes)
        )
        print(f"🦾 Emergency Model - Parameters: {sum(p.numel() for p in self.parameters()):,}")
    def forward(self, x, edge_index):
        # Feature reduction
        h = self.feature_reduce(x)
        # Graph convolution
        h_gcn = F.relu(self.gcn(h, edge_index))
        # Temporal processing (Mamba-inspired)
        h_temporal = self.temporal(h_gcn)
        # Small residual connection
        h = h + h_temporal * 0.1  # Very small update
        # Classification
        return self.classifier(h)
class MicroMamba(nn.Module):
    """Even smaller model"""
    def __init__(self, input_dim=1433, hidden_dim=4, num_classes=7):
        super().__init__()
        # Ultra-compressed feature extraction
        self.features = nn.Sequential(
            nn.Linear(input_dim, 16),
            nn.ReLU(),
            nn.Dropout(0.95),
            nn.Linear(16, hidden_dim)
        )
        # Minimal graph processing (GCNConv needs edge_index, so it is kept
        # as a standalone module instead of inside nn.Sequential)
        self.gcn = GCNConv(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.9)
        # Direct classification
        self.classify = nn.Sequential(
            nn.Dropout(0.95),
            nn.Linear(hidden_dim, num_classes)
        )
        print(f"🤏 Micro Model - Parameters: {sum(p.numel() for p in self.parameters()):,}")
    def forward(self, x, edge_index):
        h = self.features(x)
        h = F.relu(self.gcn(h, edge_index))
        h = self.dropout(h)
        return self.classify(h)
class NanoMamba(nn.Module):
    """Absolutely minimal model"""
    def __init__(self, input_dim=1433, num_classes=7):
        super().__init__()
        # Direct path - no hidden layers
        self.direct = nn.Sequential(
            nn.Linear(input_dim, num_classes),
            nn.Dropout(0.8)
        )
        # GCN path
        self.gcn_path = nn.Sequential(
            nn.Linear(input_dim, 8),
            nn.Dropout(0.9)
        )
        self.gcn = GCNConv(8, num_classes)
        print(f"⚛️ Nano Model - Parameters: {sum(p.numel() for p in self.parameters()):,}")
    def forward(self, x, edge_index):
        # Direct classification
        direct_out = self.direct(x)
        # GCN path
        h = self.gcn_path(x)
        gcn_out = self.gcn(h, edge_index)
        # Minimal combination
        return direct_out * 0.7 + gcn_out * 0.3
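# Rough parameter arithmetic (a sanity-check note, not printed by the script): the input
# projection dominates each model's size. For EmergencyTinyMamba, nn.Linear(1433, 32) alone
# contributes 1433 * 32 + 32 = 45,888 parameters, so the parameters-per-sample ratio against
# the 140 training nodes is driven almost entirely by that first layer.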
def emergency_train(model, data, device, epochs=2000):
    """Emergency training with extreme regularization"""
    model = model.to(device)
    data = data.to(device)
    # Very conservative optimizer
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.5)
    # Label smoothing cross entropy
    criterion = nn.CrossEntropyLoss(label_smoothing=0.5)
    print(f"🚨 Emergency Training Protocol")
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f" Per sample: {sum(p.numel() for p in model.parameters())/140:.1f}")
    print(f" Epochs: {epochs}")
    print(f" Learning rate: 0.001")
    print(f" Weight decay: 0.5")
    print(f" Label smoothing: 0.5")
    best_val_acc = 0
    patience = 0
    for epoch in range(epochs):
        # Training
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # Tiny gradients
        optimizer.step()
        # Evaluation every 100 epochs
        if (epoch + 1) % 100 == 0:
            model.eval()
            with torch.no_grad():
                out = model(data.x, data.edge_index)
                train_pred = out[data.train_mask].argmax(dim=1)
                train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()
                val_pred = out[data.val_mask].argmax(dim=1)
                val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()
                test_pred = out[data.test_mask].argmax(dim=1)
                test_acc = (test_pred == data.y[data.test_mask]).float().mean().item()
            gap = train_acc - val_acc
            print(f" Epoch {epoch+1:4d}: Train {train_acc:.3f} | Val {val_acc:.3f} | "
                  f"Test {test_acc:.3f} | Gap {gap:.3f}")
            # Patience is counted in epochs; each evaluation covers 100 of them
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience = 0
            else:
                patience += 100
            if patience >= 500:  # Stop after 500 epochs with no validation improvement
                print(f" Early stopping at epoch {epoch+1}")
                break
    # Final evaluation
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        train_pred = out[data.train_mask].argmax(dim=1)
        train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()
        val_pred = out[data.val_mask].argmax(dim=1)
        val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()
        test_pred = out[data.test_mask].argmax(dim=1)
        test_acc = (test_pred == data.y[data.test_mask]).float().mean().item()
    gap = train_acc - val_acc
    return {
        'train_acc': train_acc,
        'val_acc': val_acc,
        'test_acc': test_acc,
        'gap': gap
    }
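# Minimal single-model usage sketch (hypothetical; mirrors what run_emergency_fix()
# below does for all three models):
#
#     device = get_device()
#     data = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())[0].to(device)
#     stats = emergency_train(MicroMamba(hidden_dim=4), data, device, epochs=500)
#     print(stats['test_acc'], stats['gap'])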
def run_emergency_fix():
    """Emergency overfitting fix"""
    print("🚨🚨🚨 EMERGENCY OVERFITTING FIX 🚨🚨🚨")
    print("🩹 Ultra-Tiny Models for 140 Training Samples")
    print("=" * 60)
    device = get_device()
    # Load data
    print("\n📊 Loading Cora dataset...")
    dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
    data = dataset[0].to(device)
    data.edge_index = to_undirected(data.edge_index)
    data.edge_index, _ = add_self_loops(data.edge_index, num_nodes=data.x.size(0))
    print(f"✅ Dataset: {data.num_nodes} nodes, Train: {data.train_mask.sum()} samples")
    print(f"🎯 Target: <50 parameters per sample = <7,000 total parameters")
    # Test emergency models
    models = {
        'Emergency Tiny (8D)': EmergencyTinyMamba(hidden_dim=8),
        'Micro (4D)': MicroMamba(hidden_dim=4),
        'Nano (Direct)': NanoMamba()
    }
    results = {}
    for name, model in models.items():
        print(f"\n🏗️ Testing {name}...")
        total_params = sum(p.numel() for p in model.parameters())
        params_per_sample = total_params / 140
        print(f" Parameters: {total_params:,} ({params_per_sample:.1f} per sample)")
        if params_per_sample < 50:
            print(f" ✅ EXCELLENT parameter ratio!")
        elif params_per_sample < 100:
            print(f" 👍 Good parameter ratio!")
        else:
            print(f" ⚠️ Still might overfit")
        # Move the model to the same device as the data before the forward-pass check
        model = model.to(device)
        # Test forward pass
        with torch.no_grad():
            out = model(data.x, data.edge_index)
            print(f" Forward: {data.x.shape} -> {out.shape} ✅")
        try:
            # Emergency training
            result = emergency_train(model, data, device)
            results[name] = result
            print(f" 🎯 Final Results:")
            print(f" Test Accuracy: {result['test_acc']:.3f} ({result['test_acc']*100:.1f}%)")
            print(f" Train Accuracy: {result['train_acc']:.3f}")
            print(f" Overfitting Gap: {result['gap']:.3f}")
            if result['gap'] < 0.1:
                print(f" 🎉 OVERFITTING SOLVED!")
            elif result['gap'] < 0.2:
                print(f" 👍 Much better generalization!")
            elif result['gap'] < 0.3:
                print(f" 📈 Improved generalization")
            else:
                print(f" ⚠️ Still overfitting")
        except Exception as e:
            print(f" ❌ Training failed: {e}")
    # Emergency summary
    print(f"\n{'='*60}")
    print("🚨 EMERGENCY RESULTS SUMMARY")
    print(f"{'='*60}")
    best_gap = float('inf')
    best_model = None
    for name, result in results.items():
        print(f"📊 {name}:")
        print(f" Test: {result['test_acc']:.3f} | Gap: {result['gap']:.3f}")
        if result['gap'] < best_gap:
            best_gap = result['gap']
            best_model = name
    if best_model:
        print(f"\n🏆 Best Generalization: {best_model} (Gap: {best_gap:.3f})")
        if best_gap < 0.1:
            print(f"🎉 MISSION ACCOMPLISHED! Overfitting crisis resolved!")
        elif best_gap < 0.2:
            print(f"👍 Significant improvement in generalization!")
        else:
            print(f"📈 Progress made, but still work to do...")
    # Comparison with your current model
    print(f"\n📈 Comparison:")
    print(f" Your model: 194K params, Gap ~0.5")
    if best_model and best_gap < 0.3:
        improvement = 0.5 - best_gap
        print(f" Best tiny model: Gap {best_gap:.3f} (Improvement: {improvement:.3f})")
        print(f" 🎯 {improvement/0.5*100:.0f}% reduction in overfitting!")
    print(f"\n💡 Key Lesson: With only 140 samples, bigger ≠ better!")
    print(f"🧠 Tiny models can achieve competitive performance with much better generalization.")
    return results
if __name__ == "__main__":
    results = run_emergency_fix()
    print(f"\n🌐 Emergency fix complete. Process staying alive...")
    try:
        # Keep the process alive until interrupted
        while True:
            time.sleep(60)
    except KeyboardInterrupt:
        print("\n👋 Emergency protocol terminated.")