# serpent/core/trainer.py
# Last commit by kfoughali: "Update core/trainer.py" (7a2f94c, verified)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import time
import logging
from utils.metrics import GraphMetrics
# Module-level logger; the trainer below reports its progress with print() instead.
logger = logging.getLogger(name=__name__)
class GraphMambaTrainer:
    """Anti-overfitting trainer for node classification with heavy regularization.

    Combines AdamW weight decay, label-smoothed cross-entropy, an explicit
    parameter-norm penalty, gradient clipping, a ReduceLROnPlateau scheduler
    driven by validation accuracy, and early stopping on both patience and
    the train/validation accuracy gap.

    The model is called as ``model(x, edge_index)`` and must expose
    ``classifier`` (may be ``None`` until built) and
    ``_init_classifier(num_classes, device)``. ``data`` must provide ``x``,
    ``edge_index``, ``y``, ``train_mask``, ``val_mask``, ``test_mask``,
    ``num_nodes`` and ``num_edges`` (presumably a PyG ``Data`` object --
    confirm against callers).

    ``config['training']`` keys (required unless a default is shown):
        learning_rate, epochs, weight_decay,
        patience (20), min_lr (1e-6), max_gap (0.4), label_smoothing (0.1),
        lr_factor (0.5), scheduler_patience (5), l2_lambda (5e-5),
        grad_clip (0.5), overfitting_threshold (0.25), emergency_gap (0.6),
        restore_best (True).
    """

    def __init__(self, model, config, device):
        self.model = model
        self.config = config
        self.device = device
        train_cfg = config['training']

        # Core optimisation / stopping parameters
        self.lr = train_cfg['learning_rate']
        self.epochs = train_cfg['epochs']
        self.patience = train_cfg.get('patience', 20)
        self.min_lr = train_cfg.get('min_lr', 1e-6)
        self.max_gap = train_cfg.get('max_gap', 0.4)

        # Regularisation / monitoring knobs; defaults equal the values that
        # used to be hard-coded, so existing configs behave the same.
        self.l2_lambda = train_cfg.get('l2_lambda', 5e-5)
        self.grad_clip = train_cfg.get('grad_clip', 0.5)
        self.overfitting_threshold = train_cfg.get('overfitting_threshold', 0.25)
        self.emergency_gap = train_cfg.get('emergency_gap', 0.6)
        # Reload the best-validation weights when training stops, so the model
        # left behind matches ``best_val_acc`` rather than the last (possibly
        # overfit) epoch.
        self.restore_best = train_cfg.get('restore_best', True)

        # Heavily regularized optimizer
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=self.lr,
            weight_decay=train_cfg['weight_decay'],
            betas=(0.9, 0.999),
            eps=1e-8,
        )
        # Cross-entropy with label smoothing
        self.criterion = nn.CrossEntropyLoss(
            label_smoothing=train_cfg.get('label_smoothing', 0.1)
        )
        # Multiply the LR by ``lr_factor`` when validation accuracy plateaus
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='max',
            factor=train_cfg.get('lr_factor', 0.5),
            patience=train_cfg.get('scheduler_patience', 5),
            min_lr=self.min_lr,
        )

        # Training state
        self.best_val_acc = 0.0
        self.best_val_loss = float('inf')
        self.patience_counter = 0
        self.training_history = {
            'train_loss': [], 'train_acc': [],
            'val_loss': [], 'val_acc': [], 'lr': [],
        }
        # Train-minus-validation accuracy gap recorded at the best validation epoch
        self.best_gap = float('inf')
        # Detached copy of the best-validation state_dict (when restore_best)
        self._best_state = None

    def train_node_classification(self, data, verbose=True):
        """Train on ``data.train_mask`` nodes while monitoring the train/val gap.

        Each epoch runs one full-batch optimisation step and then a validation
        pass. Training stops early in three cases: validation accuracy has not
        improved for ``patience`` epochs, or the train-minus-validation
        accuracy gap exceeds ``max_gap``, or it exceeds ``emergency_gap``.
        When ``restore_best`` is set, the weights from the best validation
        epoch are loaded back before returning.

        Args:
            data: graph with features, labels and train/val masks.
            verbose: print setup, per-epoch progress and a final summary.

        Returns:
            dict: ``training_history`` with per-epoch ``train_loss``,
            ``train_acc``, ``val_loss``, ``val_acc`` and ``lr`` lists.
        """
        if verbose:
            self._print_setup(data)

        # Build the classifier head, then make sure the optimizer actually
        # updates it: the optimizer was created in __init__, before these
        # parameters existed, so they would otherwise never be stepped.
        num_classes = len(torch.unique(data.y))
        self.model._init_classifier(num_classes, self.device)
        self._register_new_params()

        self._best_state = None
        self.model.train()
        start_time = time.time()

        for epoch in range(self.epochs):
            train_metrics = self._train_epoch(data, epoch)
            val_metrics = self._validate_epoch(data, epoch)
            # Positive gap => the model fits training nodes better than validation nodes
            acc_gap = train_metrics['acc'] - val_metrics['acc']

            history = self.training_history
            history['train_loss'].append(train_metrics['loss'])
            history['train_acc'].append(train_metrics['acc'])
            history['val_loss'].append(val_metrics['loss'])
            history['val_acc'].append(val_metrics['acc'])
            # LR recorded before the scheduler step for this epoch
            history['lr'].append(self.optimizer.param_groups[0]['lr'])

            self.scheduler.step(val_metrics['acc'])

            if val_metrics['acc'] > self.best_val_acc:
                self.best_val_acc = val_metrics['acc']
                self.best_val_loss = val_metrics['loss']
                self.best_gap = acc_gap
                self.patience_counter = 0
                if self.restore_best:
                    self._best_state = self._snapshot_state()
                if verbose:
                    print(f"πŸŽ‰ New best validation accuracy: {self.best_val_acc:.4f}")
            else:
                self.patience_counter += 1

            # Report (but do not stop on) a gap above the warning threshold
            if acc_gap > self.overfitting_threshold and verbose:
                print(f"🚨 OVERFITTING detected: {acc_gap:.3f} gap")
                print(f" Train: {train_metrics['acc']:.3f}, Val: {val_metrics['acc']:.3f}")

            if verbose and (epoch == 0 or (epoch + 1) % 5 == 0 or epoch == self.epochs - 1):
                self._print_progress(epoch, train_metrics, val_metrics, acc_gap)

            # Early stopping conditions
            if self.patience_counter >= self.patience:
                if verbose:
                    print(f"πŸ›' Early stopping at epoch {epoch} (patience)")
                break
            if acc_gap > self.max_gap:
                if verbose:
                    print(f"πŸ›' Stopping due to overfitting gap: {acc_gap:.3f} > {self.max_gap:.3f}")
                break
            # Backup check; only reachable when max_gap is configured above emergency_gap
            if acc_gap > self.emergency_gap:
                if verbose:
                    print(f"πŸ›' Emergency stop - severe overfitting (gap: {acc_gap:.3f})")
                break

        if self.restore_best and self._best_state is not None:
            self.model.load_state_dict(self._best_state)

        if verbose:
            self._print_summary(time.time() - start_time)
        return self.training_history

    def _register_new_params(self):
        """Add model parameters the optimizer does not yet track (e.g. a lazily built head)."""
        tracked = {id(p) for group in self.optimizer.param_groups for p in group['params']}
        fresh = [p for p in self.model.parameters() if id(p) not in tracked]
        if not fresh:
            return
        # New group inherits the optimizer defaults but starts at the current LR
        self.optimizer.add_param_group(
            {'params': fresh, 'lr': self.optimizer.param_groups[0]['lr']}
        )
        # ReduceLROnPlateau keeps one min_lr per param group and indexes it by
        # group when reducing, so keep the list in sync with the new group count.
        self.scheduler.min_lrs = [self.min_lr] * len(self.optimizer.param_groups)

    def _snapshot_state(self):
        """Return a detached copy of the model's current ``state_dict``."""
        return {key: value.detach().clone() for key, value in self.model.state_dict().items()}

    def _print_setup(self, data):
        """Print dataset/model statistics and an overfitting-risk warning."""
        total_params = sum(p.numel() for p in self.model.parameters())
        train_samples = data.train_mask.sum().item()
        # Guard against an empty training mask so the summary cannot divide by zero
        params_per_sample = total_params / max(train_samples, 1)
        print(f"πŸ‹οΈ Training GraphMamba for {self.epochs} epochs")
        print(f"πŸ“Š Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
        print(f"🎯 Classes: {len(torch.unique(data.y))}")
        print(f"πŸ’Ύ Device: {self.device}")
        print(f"βš™οΈ Parameters: {total_params:,}")
        print(f"πŸ“š Training samples: {train_samples}")
        print(f"⚠️ Params per sample: {params_per_sample:.1f}")
        print(f"🚨 Max allowed gap: {self.max_gap:.3f}")
        if params_per_sample > 500:
            print(f"🚨 WARNING: High params per sample ratio - overfitting risk!")

    def _print_progress(self, epoch, train_metrics, val_metrics, acc_gap):
        """Print one progress line with losses, accuracies, gap and current LR."""
        if acc_gap > self.overfitting_threshold:
            gap_indicator = "🚨"
        elif acc_gap > 0.15:
            gap_indicator = "⚠️"
        else:
            gap_indicator = "βœ…"
        print(f"Epoch {epoch:3d} | "
              f"Train: {train_metrics['loss']:.4f} ({train_metrics['acc']:.4f}) | "
              f"Val: {val_metrics['loss']:.4f} ({val_metrics['acc']:.4f}) | "
              f"Gap: {acc_gap:.3f} {gap_indicator} | "
              f"LR: {self.optimizer.param_groups[0]['lr']:.6f}")

    def _print_summary(self, total_time):
        """Print wall-clock time, best validation accuracy and a gap verdict."""
        print(f"βœ… Training completed in {total_time:.2f}s")
        print(f"πŸ† Best validation accuracy: {self.best_val_acc:.4f}")
        print(f"πŸ“Š Best train-val gap: {self.best_gap:.4f}")
        if self.best_gap < 0.1:
            print("πŸŽ‰ Excellent generalization!")
        elif self.best_gap < 0.2:
            print("πŸ' Good generalization")
        else:
            print("⚠️ Some overfitting detected")

    def _train_epoch(self, data, epoch):
        """Run one full-batch optimisation step on the training nodes.

        ``epoch`` is accepted for symmetry with ``_validate_epoch`` and is unused.

        Returns:
            dict: ``loss`` (label-smoothed cross-entropy plus the parameter-norm
            penalty) and ``acc`` (training accuracy from the logits computed
            before the optimizer step).
        """
        self.model.train()
        self.optimizer.zero_grad()

        # Forward pass over the whole graph; loss uses training nodes only
        h = self.model(data.x, data.edge_index)
        logits = self.model.classifier(h)
        train_logits = logits[data.train_mask]
        train_targets = data.y[data.train_mask]
        loss = self.criterion(train_logits, train_targets)

        # Extra penalty: sum of the (unsquared) L2 norms of every parameter
        # tensor, applied on top of AdamW's decoupled weight decay.
        l2_reg = sum(torch.linalg.vector_norm(param) for param in self.model.parameters())
        loss = loss + self.l2_lambda * l2_reg

        # Backward pass with gradient clipping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_clip)
        self.optimizer.step()

        with torch.no_grad():
            train_acc = (train_logits.argmax(dim=1) == train_targets).float().mean().item()
        return {'loss': loss.item(), 'acc': train_acc}

    def _validate_epoch(self, data, epoch):
        """Compute loss and accuracy on ``data.val_mask`` nodes in eval mode, without gradients.

        ``epoch`` is unused.
        """
        self.model.eval()
        with torch.no_grad():
            h = self.model(data.x, data.edge_index)
            logits = self.model.classifier(h)
            val_logits = logits[data.val_mask]
            val_targets = data.y[data.val_mask]
            val_loss = self.criterion(val_logits, val_targets)
            val_acc = (val_logits.argmax(dim=1) == val_targets).float().mean().item()
        return {'loss': val_loss.item(), 'acc': val_acc}

    def test(self, data):
        """Evaluate on ``data.test_mask`` nodes.

        Returns:
            dict: ``test_loss``, ``test_acc``, ``f1_macro``, ``f1_micro``,
            ``precision`` and ``recall`` (metric definitions come from
            ``GraphMetrics``).
        """
        self.model.eval()
        with torch.no_grad():
            h = self.model(data.x, data.edge_index)
            # NOTE(review): this fallback builds a fresh classifier, so results
            # are presumably only meaningful after train_node_classification().
            if self.model.classifier is None:
                num_classes = len(torch.unique(data.y))
                self.model._init_classifier(num_classes, self.device)
            logits = self.model.classifier(h)

            test_loss = self.criterion(logits[data.test_mask], data.y[data.test_mask])
            test_pred = logits[data.test_mask]
            test_target = data.y[data.test_mask]
            metrics = {
                'test_loss': test_loss.item(),
                'test_acc': GraphMetrics.accuracy(test_pred, test_target),
                'f1_macro': GraphMetrics.f1_score_macro(test_pred, test_target),
                'f1_micro': GraphMetrics.f1_score_micro(test_pred, test_target),
            }
            precision, recall = GraphMetrics.precision_recall(test_pred, test_target)
            metrics['precision'] = precision
            metrics['recall'] = recall
        return metrics

    def get_embeddings(self, data):
        """Return the model's output for all nodes (the classifier's input), in eval mode without gradients."""
        self.model.eval()
        with torch.no_grad():
            return self.model(data.x, data.edge_index)