import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import time
import logging
from utils.metrics import GraphMetrics

logger = logging.getLogger(__name__)


class GraphMambaTrainer:
    """Anti-overfitting trainer with heavy regularization"""

    def __init__(self, model, config, device):
        self.model = model
        self.config = config
        self.device = device

        # Optimized learning parameters
        self.lr = config['training']['learning_rate']
        self.epochs = config['training']['epochs']
        self.patience = config['training'].get('patience', 20)
        self.min_lr = config['training'].get('min_lr', 1e-6)
        self.max_gap = config['training'].get('max_gap', 0.4)

        # Heavily regularized optimizer
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=self.lr,
            weight_decay=config['training']['weight_decay'],
            betas=(0.9, 0.999),
            eps=1e-8
        )

        # Proper loss function with label smoothing
        self.criterion = nn.CrossEntropyLoss(
            label_smoothing=config['training'].get('label_smoothing', 0.1)
        )

        # Balanced scheduler
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='max',
            factor=0.5,
            patience=5,
            min_lr=self.min_lr
        )

        # Training state
        self.best_val_acc = 0.0
        self.best_val_loss = float('inf')
        self.patience_counter = 0
        self.training_history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'lr': []
        }

        # Track overfitting
        self.best_gap = float('inf')
        self.overfitting_threshold = 0.25  # Balanced threshold

    def train_node_classification(self, data, verbose=True):
        """Anti-overfitting training with gap monitoring"""
        if verbose:
            total_params = sum(p.numel() for p in self.model.parameters())
            train_samples = data.train_mask.sum().item()
            params_per_sample = total_params / train_samples

            print(f"🏋️ Training GraphMamba for {self.epochs} epochs")
            print(f"📊 Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
            print(f"🎯 Classes: {len(torch.unique(data.y))}")
            print(f"💾 Device: {self.device}")
            print(f"⚙️ Parameters: {total_params:,}")
            print(f"📚 Training samples: {train_samples}")
            print(f"⚠️ Params per sample: {params_per_sample:.1f}")
            print(f"🚨 Max allowed gap: {self.max_gap:.3f}")

            if params_per_sample > 500:
                print("🚨 WARNING: High params per sample ratio - overfitting risk!")

        # Initialize classifier
        num_classes = len(torch.unique(data.y))
        self.model._init_classifier(num_classes, self.device)

        self.model.train()
        start_time = time.time()

        for epoch in range(self.epochs):
            # Training step
            train_metrics = self._train_epoch(data, epoch)

            # Validation step
            val_metrics = self._validate_epoch(data, epoch)

            # Calculate overfitting gap
            acc_gap = train_metrics['acc'] - val_metrics['acc']

            # Update history
            self.training_history['train_loss'].append(train_metrics['loss'])
            self.training_history['train_acc'].append(train_metrics['acc'])
            self.training_history['val_loss'].append(val_metrics['loss'])
            self.training_history['val_acc'].append(val_metrics['acc'])
            self.training_history['lr'].append(self.optimizer.param_groups[0]['lr'])

            # Step scheduler
            self.scheduler.step(val_metrics['acc'])

            # Check for improvement
            if val_metrics['acc'] > self.best_val_acc:
                self.best_val_acc = val_metrics['acc']
                self.best_val_loss = val_metrics['loss']
                self.best_gap = acc_gap
                self.patience_counter = 0
                if verbose:
                    print(f"🎉 New best validation accuracy: {self.best_val_acc:.4f}")
            else:
                self.patience_counter += 1

            # Aggressive overfitting detection
            if acc_gap > self.overfitting_threshold:
                if verbose:
                    print(f"🚨 OVERFITTING detected: {acc_gap:.3f} gap")
                    print(f"   Train: {train_metrics['acc']:.3f}, Val: {val_metrics['acc']:.3f}")

            # Progress logging
            if verbose and (epoch == 0 or (epoch + 1) % 5 == 0 or epoch == self.epochs - 1):
                elapsed = time.time() - start_time
                gap_indicator = "🚨" if acc_gap > 0.25 else "⚠️" if acc_gap > 0.15 else "✅"
                print(f"Epoch {epoch:3d} | "
                      f"Train: {train_metrics['loss']:.4f} ({train_metrics['acc']:.4f}) | "
                      f"Val: {val_metrics['loss']:.4f} ({val_metrics['acc']:.4f}) | "
                      f"Gap: {acc_gap:.3f} {gap_indicator} | "
                      f"LR: {self.optimizer.param_groups[0]['lr']:.6f}")

            # Enhanced early stopping conditions
            if self.patience_counter >= self.patience:
                if verbose:
                    print(f"🛑 Early stopping at epoch {epoch} (patience)")
                break

            # Stop if gap exceeds threshold
            if acc_gap > self.max_gap:
                if verbose:
                    print(f"🛑 Stopping due to overfitting gap: {acc_gap:.3f} > {self.max_gap:.3f}")
                break

            # Stop if severe overfitting (backup check)
            if acc_gap > 0.6:
                if verbose:
                    print(f"🛑 Emergency stop - severe overfitting (gap: {acc_gap:.3f})")
                break

        if verbose:
            total_time = time.time() - start_time
            print(f"✅ Training completed in {total_time:.2f}s")
            print(f"🏆 Best validation accuracy: {self.best_val_acc:.4f}")
            print(f"📊 Best train-val gap: {self.best_gap:.4f}")

            if self.best_gap < 0.1:
                print("🎉 Excellent generalization!")
            elif self.best_gap < 0.2:
                print("👍 Good generalization")
            else:
                print("⚠️ Some overfitting detected")

        return self.training_history

    def _train_epoch(self, data, epoch):
        """Single training epoch with regularization"""
        self.model.train()
        self.optimizer.zero_grad()

        # Forward pass (with data augmentation)
        h = self.model(data.x, data.edge_index)
        logits = self.model.classifier(h)

        # Compute loss on training nodes only
        train_loss = self.criterion(logits[data.train_mask], data.y[data.train_mask])

        # Add stronger L2 regularization
        l2_reg = 0.0
        for param in self.model.parameters():
            l2_reg += torch.norm(param, p=2)
        train_loss += 5e-5 * l2_reg  # Increased from 1e-5

        # Backward pass with gradient clipping
        train_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)  # Reduced from 1.0
        self.optimizer.step()

        # Compute accuracy
        with torch.no_grad():
            train_pred = logits[data.train_mask].argmax(dim=1)
            train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()

        return {'loss': train_loss.item(), 'acc': train_acc}

    def _validate_epoch(self, data, epoch):
        """Validation without augmentation"""
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)
            logits = self.model.classifier(h)

            # Validation loss and accuracy
            val_loss = self.criterion(logits[data.val_mask], data.y[data.val_mask])
            val_pred = logits[data.val_mask].argmax(dim=1)
            val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()

        return {'loss': val_loss.item(), 'acc': val_acc}

    def test(self, data):
        """Test evaluation"""
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)

            if self.model.classifier is None:
                num_classes = len(torch.unique(data.y))
                self.model._init_classifier(num_classes, self.device)

            logits = self.model.classifier(h)

            # Test metrics
            test_loss = self.criterion(logits[data.test_mask], data.y[data.test_mask])
            test_pred = logits[data.test_mask]
            test_target = data.y[data.test_mask]

            metrics = {
                'test_loss': test_loss.item(),
                'test_acc': GraphMetrics.accuracy(test_pred, test_target),
                'f1_macro': GraphMetrics.f1_score_macro(test_pred, test_target),
                'f1_micro': GraphMetrics.f1_score_micro(test_pred, test_target),
            }

            precision, recall = GraphMetrics.precision_recall(test_pred, test_target)
            metrics['precision'] = precision
            metrics['recall'] = recall
        return metrics

    def get_embeddings(self, data):
        """Get node embeddings"""
        self.model.eval()
        with torch.no_grad():
            return self.model(data.x, data.edge_index)
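

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original trainer): a minimal, self-contained
# example of how GraphMambaTrainer is driven. The real GraphMamba model lives
# elsewhere in this repo; the `_StubEncoder`, the synthetic graph, and the
# config values below are illustrative assumptions that only mirror the
# interface the trainer calls: forward(x, edge_index) -> node embeddings,
# a lazily created `.classifier` head, and `_init_classifier(num_classes,
# device)`. The `test()` call relies on the repo's utils.metrics helpers
# imported at the top of this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    class _StubEncoder(nn.Module):
        """Hypothetical stand-in exposing the interface the trainer expects."""

        def __init__(self, in_dim, hidden_dim):
            super().__init__()
            self.proj = nn.Linear(in_dim, hidden_dim)
            self.classifier = None  # created lazily, as in the real model

        def _init_classifier(self, num_classes, device):
            self.classifier = nn.Linear(self.proj.out_features, num_classes).to(device)

        def forward(self, x, edge_index):
            # Ignores the graph structure; for illustration only.
            return torch.relu(self.proj(x))

    # Synthetic node-classification data in the PyG-style layout the trainer reads.
    num_nodes, num_edges, in_dim = 200, 800, 16
    data = SimpleNamespace(
        x=torch.randn(num_nodes, in_dim),
        edge_index=torch.randint(0, num_nodes, (2, num_edges)),
        y=torch.randint(0, 3, (num_nodes,)),
        train_mask=torch.zeros(num_nodes, dtype=torch.bool),
        val_mask=torch.zeros(num_nodes, dtype=torch.bool),
        test_mask=torch.zeros(num_nodes, dtype=torch.bool),
        num_nodes=num_nodes,
        num_edges=num_edges,
    )
    data.train_mask[:100] = True
    data.val_mask[100:150] = True
    data.test_mask[150:] = True

    # Assumed config layout, matching the keys read in __init__ (only
    # learning_rate, epochs, and weight_decay are required; the rest default).
    config = {
        'training': {
            'learning_rate': 1e-3,
            'epochs': 10,
            'weight_decay': 5e-4,
            'patience': 5,
        }
    }

    # Keep everything on CPU so the synthetic tensors and the model match.
    device = torch.device('cpu')
    model = _StubEncoder(in_dim, hidden_dim=32).to(device)

    trainer = GraphMambaTrainer(model, config, device)
    history = trainer.train_node_classification(data, verbose=True)
    results = trainer.test(data)
    print(results)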