File size: 4,158 Bytes
6d0498a 8d3a013 6d0498a 972fdf4 6d0498a 972fdf4 8d3a013 6d0498a 8d3a013 850d736 8d3a013 850d736 8d3a013 972fdf4 8d3a013 972fdf4 8d3a013 972fdf4 8d3a013 972fdf4 8d3a013 6d0498a 8d3a013 8a7e32b 8d3a013 972fdf4 8d3a013 972fdf4 8d3a013 972fdf4 8d3a013 972fdf4 8d3a013 972fdf4 8d3a013 8a7e32b 8d3a013 850d736 8d3a013 972fdf4 8d3a013 6d0498a 8d3a013 c3b0dee 8d3a013 972fdf4 8d3a013 972fdf4 8d3a013 972fdf4 8d3a013 972fdf4 c3b0dee 8d3a013 c3b0dee 8d3a013 972fdf4 8d3a013 972fdf4 8d3a013 972fdf4 8d3a013 972fdf4 8d3a013 972fdf4 8d3a013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, classification_report, precision_score, recall_score
import numpy as np
import logging
logger = logging.getLogger(__name__)
class GraphMetrics:
    """Comprehensive evaluation metrics for graph learning.

    Every public metric is a static method taking ``pred`` and ``target``
    tensors. ``pred`` may hold per-class scores (more than one dimension, in
    which case ``argmax`` over dimension 1 selects the label) or labels that
    are already decoded. The metrics are deliberately best-effort: any
    failure is logged through the module logger and reported as ``0.0``
    instead of being raised to the caller.
    """

    @staticmethod
    def _to_labels(pred: torch.Tensor) -> torch.Tensor:
        """Return label indices for ``pred``.

        Input with more than one dimension is reduced with ``argmax(dim=1)``;
        anything else is returned unchanged.
        """
        return pred.argmax(dim=1) if pred.dim() > 1 else pred

    @staticmethod
    def _label_arrays(pred: torch.Tensor, target: torch.Tensor):
        """Return ``(pred_labels, target_labels)`` as NumPy arrays.

        Both tensors are detached and moved to the CPU first, because
        ``Tensor.numpy()`` rejects tensors that track gradients.
        """
        pred_labels = GraphMetrics._to_labels(pred).detach().cpu().numpy()
        target_labels = target.detach().cpu().numpy()
        return pred_labels, target_labels

    @staticmethod
    def _f1(pred: torch.Tensor, target: torch.Tensor, average: str) -> float:
        """Shared implementation behind the macro and micro F1 metrics.

        Args:
            pred: Class scores or label indices.
            target: Ground-truth labels.
            average: Averaging mode forwarded to ``f1_score``; it is also
                interpolated into the log messages.

        Returns:
            The F1 score as a float, or ``0.0`` for empty input, a
            non-finite score, or any exception.
        """
        try:
            pred_labels, target_labels = GraphMetrics._label_arrays(pred, target)
            if len(pred_labels) == 0 or len(target_labels) == 0:
                return 0.0
            score = f1_score(
                target_labels, pred_labels, average=average, zero_division=0
            )
            if not np.isfinite(score):
                logger.warning("Invalid F1 %s score, returning 0.0", average)
                return 0.0
            return float(score)
        except Exception as e:
            # Best-effort boundary: a broken metric must not abort the caller.
            logger.error("F1 %s computation failed: %s", average, e)
            return 0.0

    @staticmethod
    def accuracy(pred: torch.Tensor, target: torch.Tensor) -> float:
        """Classification accuracy with validation.

        Args:
            pred: Class scores or label indices.
            target: Ground-truth labels; must have the same shape as the
                decoded prediction labels.

        Returns:
            The fraction of matching labels as a float. ``0.0`` is returned
            (and the problem logged) on a shape mismatch, a non-finite
            result, or any exception.
        """
        try:
            pred_labels = GraphMetrics._to_labels(pred)
            if pred_labels.shape != target.shape:
                # Same log line the former raise-and-catch produced.
                logger.error(
                    "Accuracy computation failed: %s",
                    "Prediction and target shapes don't match",
                )
                return 0.0
            score = (pred_labels == target).float().mean().item()
            # ``score`` is already a Python float; no tensor round-trip is
            # needed to detect NaN/Inf.
            if not np.isfinite(score):
                logger.warning("Invalid accuracy computed, returning 0.0")
                return 0.0
            return score
        except Exception as e:
            # Best-effort boundary: a broken metric must not abort the caller.
            logger.error("Accuracy computation failed: %s", e)
            return 0.0

    @staticmethod
    def f1_score_macro(pred: torch.Tensor, target: torch.Tensor) -> float:
        """Macro F1 score with robust error handling.

        Returns ``0.0`` for empty input, a non-finite score, or any error.
        """
        return GraphMetrics._f1(pred, target, 'macro')

    @staticmethod
    def f1_score_micro(pred: torch.Tensor, target: torch.Tensor) -> float:
        """Micro F1 score with robust error handling.

        Returns ``0.0`` for empty input, a non-finite score, or any error.
        """
        return GraphMetrics._f1(pred, target, 'micro')

    @staticmethod
    def precision_recall(pred: torch.Tensor, target: torch.Tensor, average: str = 'macro'):
        """Compute precision and recall scores.

        Args:
            pred: Class scores or label indices.
            target: Ground-truth labels.
            average: Averaging mode forwarded to ``precision_score`` and
                ``recall_score``.

        Returns:
            A ``(precision, recall)`` tuple of floats. Empty input or any
            exception yields ``(0.0, 0.0)``; a non-finite component is
            replaced by ``0.0`` on its own.

        NOTE(review): modes for which scikit-learn returns per-class arrays
        (presumably ``average=None``) are not handled as arrays here and are
        expected to end in the ``(0.0, 0.0)`` fallback -- confirm if
        per-class output is ever needed.
        """
        try:
            pred_labels, target_labels = GraphMetrics._label_arrays(pred, target)
            if len(pred_labels) == 0 or len(target_labels) == 0:
                return 0.0, 0.0
            precision = precision_score(
                target_labels, pred_labels, average=average, zero_division=0
            )
            recall = recall_score(
                target_labels, pred_labels, average=average, zero_division=0
            )
            # Each component is sanitised independently and without logging.
            if not np.isfinite(precision):
                precision = 0.0
            if not np.isfinite(recall):
                recall = 0.0
            return float(precision), float(recall)
        except Exception as e:
            # Best-effort boundary: a broken metric must not abort the caller.
            logger.error("Precision/recall computation failed: %s", e)
            return 0.0, 0.0