import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import time
import logging

from utils.metrics import GraphMetrics

logger = logging.getLogger(__name__)


class GraphMambaTrainer:
    """Anti-overfitting trainer with heavy regularization."""

    def __init__(self, model, config, device):
        self.model = model
        self.config = config
        self.device = device

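        # Core training hyperparameters read from the config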
        self.lr = config['training']['learning_rate']
        self.epochs = config['training']['epochs']
        self.patience = config['training'].get('patience', 20)
        self.min_lr = config['training'].get('min_lr', 1e-6)
        self.max_gap = config['training'].get('max_gap', 0.4)

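        # AdamW optimizer: weight decay is the main regularizer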
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=self.lr,
            weight_decay=config['training']['weight_decay'],
            betas=(0.9, 0.999),
            eps=1e-8
        )

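        # Cross-entropy with label smoothing to discourage over-confident predictions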
        self.criterion = nn.CrossEntropyLoss(
            label_smoothing=config['training'].get('label_smoothing', 0.1)
        )

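        # Halve the learning rate when validation accuracy plateaus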
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='max',
            factor=0.5,
            patience=5,
            min_lr=self.min_lr
        )

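        # Best-model tracking, early-stopping counter, and per-epoch history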
        self.best_val_acc = 0.0
        self.best_val_loss = float('inf')
        self.patience_counter = 0
        self.training_history = {
            'train_loss': [], 'train_acc': [],
            'val_loss': [], 'val_acc': [], 'lr': []
        }

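        # Train-validation accuracy gap thresholds for overfitting monitoring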
        self.best_gap = float('inf')
        self.overfitting_threshold = 0.25

    def train_node_classification(self, data, verbose=True):
        """Anti-overfitting training with train-validation gap monitoring."""
        if verbose:
            total_params = sum(p.numel() for p in self.model.parameters())
            train_samples = data.train_mask.sum().item()
            params_per_sample = total_params / train_samples

            print(f"Training GraphMamba for {self.epochs} epochs")
            print(f"Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
            print(f"Classes: {len(torch.unique(data.y))}")
            print(f"Device: {self.device}")
            print(f"Parameters: {total_params:,}")
            print(f"Training samples: {train_samples}")
            print(f"Params per sample: {params_per_sample:.1f}")
            print(f"Max allowed gap: {self.max_gap:.3f}")

            if params_per_sample > 500:
                print("WARNING: high parameter-to-sample ratio - overfitting risk!")

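        # Build the model's classification head for this dataset's class count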
        num_classes = len(torch.unique(data.y))
        self.model._init_classifier(num_classes, self.device)

        self.model.train()
        start_time = time.time()

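        # Main loop: train, validate, log, and stop early on patience or overfitting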
        for epoch in range(self.epochs):
            train_metrics = self._train_epoch(data, epoch)
            val_metrics = self._validate_epoch(data, epoch)

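            # Train-validation accuracy gap: the primary overfitting signal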
            acc_gap = train_metrics['acc'] - val_metrics['acc']

            self.training_history['train_loss'].append(train_metrics['loss'])
            self.training_history['train_acc'].append(train_metrics['acc'])
            self.training_history['val_loss'].append(val_metrics['loss'])
            self.training_history['val_acc'].append(val_metrics['acc'])
            self.training_history['lr'].append(self.optimizer.param_groups[0]['lr'])

            self.scheduler.step(val_metrics['acc'])

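            # Track the best validation accuracy; reset patience on improvement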
            if val_metrics['acc'] > self.best_val_acc:
                self.best_val_acc = val_metrics['acc']
                self.best_val_loss = val_metrics['loss']
                self.best_gap = acc_gap
                self.patience_counter = 0
                if verbose:
                    print(f"New best validation accuracy: {self.best_val_acc:.4f}")
            else:
                self.patience_counter += 1

            if acc_gap > self.overfitting_threshold:
                if verbose:
                    print(f"OVERFITTING detected: {acc_gap:.3f} gap")
                    print(f"  Train: {train_metrics['acc']:.3f}, Val: {val_metrics['acc']:.3f}")

            if verbose and (epoch == 0 or (epoch + 1) % 5 == 0 or epoch == self.epochs - 1):
                elapsed = time.time() - start_time
                gap_indicator = "!!" if acc_gap > 0.25 else "!" if acc_gap > 0.15 else "ok"

                print(f"Epoch {epoch:3d} | "
                      f"Train: {train_metrics['loss']:.4f} ({train_metrics['acc']:.4f}) | "
                      f"Val: {val_metrics['loss']:.4f} ({val_metrics['acc']:.4f}) | "
                      f"Gap: {acc_gap:.3f} {gap_indicator} | "
                      f"LR: {self.optimizer.param_groups[0]['lr']:.6f}")

            if self.patience_counter >= self.patience:
                if verbose:
                    print(f"Early stopping at epoch {epoch} (patience)")
                break

            if acc_gap > self.max_gap:
                if verbose:
                    print(f"Stopping due to overfitting gap: {acc_gap:.3f} > {self.max_gap:.3f}")
                break

            if acc_gap > 0.6:
                if verbose:
                    print(f"Emergency stop - severe overfitting (gap: {acc_gap:.3f})")
                break

        if verbose:
            total_time = time.time() - start_time
            print(f"Training completed in {total_time:.2f}s")
            print(f"Best validation accuracy: {self.best_val_acc:.4f}")
            print(f"Best train-val gap: {self.best_gap:.4f}")

            if self.best_gap < 0.1:
                print("Excellent generalization!")
            elif self.best_gap < 0.2:
                print("Good generalization")
            else:
                print("Some overfitting detected")

        return self.training_history

    def _train_epoch(self, data, epoch):
        """Single training epoch with regularization."""
        self.model.train()
        self.optimizer.zero_grad()

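        # Full-graph forward pass; the classifier head maps node embeddings to logits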
        h = self.model(data.x, data.edge_index)
        logits = self.model.classifier(h)

        train_loss = self.criterion(logits[data.train_mask], data.y[data.train_mask])

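        # Explicit L2 penalty on all parameters, in addition to AdamW's weight decay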
        l2_reg = 0.0
        for param in self.model.parameters():
            l2_reg += torch.norm(param, p=2)
        train_loss += 5e-5 * l2_reg

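        # Backpropagate and clip the gradient norm to stabilize updates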
        train_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
        self.optimizer.step()

        with torch.no_grad():
            train_pred = logits[data.train_mask].argmax(dim=1)
            train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()

        return {'loss': train_loss.item(), 'acc': train_acc}

    def _validate_epoch(self, data, epoch):
        """Validation without augmentation."""
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)
            logits = self.model.classifier(h)

            val_loss = self.criterion(logits[data.val_mask], data.y[data.val_mask])
            val_pred = logits[data.val_mask].argmax(dim=1)
            val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()

        return {'loss': val_loss.item(), 'acc': val_acc}

    def test(self, data):
        """Test evaluation."""
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)

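            # Initialize the classifier head if test() is called before training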
            if self.model.classifier is None:
                num_classes = len(torch.unique(data.y))
                self.model._init_classifier(num_classes, self.device)

            logits = self.model.classifier(h)

            test_loss = self.criterion(logits[data.test_mask], data.y[data.test_mask])
            test_pred = logits[data.test_mask]
            test_target = data.y[data.test_mask]

            metrics = {
                'test_loss': test_loss.item(),
                'test_acc': GraphMetrics.accuracy(test_pred, test_target),
                'f1_macro': GraphMetrics.f1_score_macro(test_pred, test_target),
                'f1_micro': GraphMetrics.f1_score_micro(test_pred, test_target),
            }

            precision, recall = GraphMetrics.precision_recall(test_pred, test_target)
            metrics['precision'] = precision
            metrics['recall'] = recall

        return metrics

    def get_embeddings(self, data):
        """Get node embeddings."""
        self.model.eval()
        with torch.no_grad():
            return self.model(data.x, data.edge_index)
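

# Example usage (sketch; assumes a GraphMamba model exposing `_init_classifier` and a
# `classifier` attribute, a PyG-style `data` object with train/val/test masks, and a
# config dict whose 'training' section provides learning_rate, epochs, and weight_decay):
#
#   trainer = GraphMambaTrainer(model, config, device)
#   history = trainer.train_node_classification(data, verbose=True)
#   test_metrics = trainer.test(data)
#   embeddings = trainer.get_embeddings(data)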