# serpent / utils/metrics.py
# Last commit: kfoughali - "Update utils/metrics.py" (8d3a013, verified)
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, classification_report, precision_score, recall_score
import numpy as np
import logging
logger = logging.getLogger(__name__)


class GraphMetrics:
    """Classification metrics (accuracy, F1, precision/recall) for graph-learning model outputs.

    Every public metric is best-effort: on invalid input or an internal
    failure it logs the problem and returns 0.0 (or ``(0.0, 0.0)``) instead
    of raising, so a bad batch cannot abort an evaluation loop.

    ``pred`` may be either a tensor of per-class scores (any tensor with more
    than one dimension is reduced with ``argmax(dim=1)``) or a tensor of
    already-decided labels (1-D).
    """

    @staticmethod
    def _pred_to_labels(pred):
        """Reduce score tensors (dim > 1) to labels via argmax over dim 1; pass 1-D labels through."""
        return pred.argmax(dim=1) if pred.dim() > 1 else pred

    @staticmethod
    def _to_label_arrays(pred, target):
        """Return ``(pred_labels, target_labels)`` as numpy arrays for sklearn.

        ``detach()`` is required because ``numpy()`` refuses tensors that
        require grad (e.g. 1-D model outputs passed in directly).
        """
        pred_labels = GraphMetrics._pred_to_labels(pred)
        return (
            pred_labels.detach().cpu().numpy(),
            target.detach().cpu().numpy(),
        )

    @staticmethod
    def _finite_or_zero(value):
        """Replace NaN/inf with 0.0.

        Returns a Python float for scalar input and a float ndarray for
        array input (as produced by sklearn when ``average=None``).
        """
        arr = np.asarray(value, dtype=float)
        arr = np.where(np.isfinite(arr), arr, 0.0)
        return float(arr) if arr.ndim == 0 else arr

    @staticmethod
    def accuracy(pred, target):
        """Fraction of predictions whose label equals ``target``.

        Args:
            pred: score tensor (reduced with argmax over dim 1) or label tensor.
            target: label tensor; must have the same shape as the predicted labels.

        Returns:
            Accuracy as a float in [0, 1]; 0.0 for empty input, mismatched
            shapes, or any computation failure (the failure is logged).
        """
        try:
            pred_labels = GraphMetrics._pred_to_labels(pred)
            if pred_labels.shape != target.shape:
                logger.error(
                    "Accuracy computation failed: prediction shape %s does not match target shape %s",
                    tuple(pred_labels.shape),
                    tuple(target.shape),
                )
                return 0.0
            # An empty mean would be NaN; report 0.0 like the other metrics do.
            if target.numel() == 0:
                return 0.0
            return (pred_labels == target).float().mean().item()
        except Exception as e:
            logger.error("Accuracy computation failed: %s", e)
            return 0.0

    @staticmethod
    def _f1(pred, target, average):
        """Shared implementation of the averaged F1 scores (``average`` is passed to sklearn)."""
        try:
            pred_labels, target_labels = GraphMetrics._to_label_arrays(pred, target)
            if pred_labels.size == 0 or target_labels.size == 0:
                return 0.0
            f1 = f1_score(target_labels, pred_labels, average=average, zero_division=0)
            if not np.isfinite(f1):
                logger.warning("Invalid F1 %s score, returning 0.0", average)
                return 0.0
            return float(f1)
        except Exception as e:
            logger.error("F1 %s computation failed: %s", average, e)
            return 0.0

    @staticmethod
    def f1_score_macro(pred, target):
        """Macro-averaged F1 (unweighted mean over classes); 0.0 on empty input or failure."""
        return GraphMetrics._f1(pred, target, "macro")

    @staticmethod
    def f1_score_micro(pred, target):
        """Micro-averaged F1 (computed from global TP/FP/FN counts); 0.0 on empty input or failure."""
        return GraphMetrics._f1(pred, target, "micro")

    @staticmethod
    def precision_recall(pred, target, average='macro'):
        """Compute precision and recall with sklearn's ``average`` semantics.

        Args:
            pred: score tensor (reduced with argmax over dim 1) or label tensor.
            target: label tensor.
            average: forwarded to sklearn. Scalar modes (``'macro'``,
                ``'micro'``, ``'weighted'``, ...) yield floats; ``None`` yields
                per-class float arrays.

        Returns:
            ``(precision, recall)`` with NaN/inf replaced by 0.0. Returns
            ``(0.0, 0.0)`` on empty input or failure (failures are logged).
        """
        try:
            pred_labels, target_labels = GraphMetrics._to_label_arrays(pred, target)
            if pred_labels.size == 0 or target_labels.size == 0:
                return 0.0, 0.0
            precision = precision_score(target_labels, pred_labels, average=average, zero_division=0)
            recall = recall_score(target_labels, pred_labels, average=average, zero_division=0)
            # Element-wise sanitising: a plain `if np.isnan(...)` raises on the
            # per-class arrays returned for average=None.
            return GraphMetrics._finite_or_zero(precision), GraphMetrics._finite_or_zero(recall)
        except Exception as e:
            logger.error("Precision/recall computation failed: %s", e)
            return 0.0, 0.0