import torch
import torch.nn.functional as F
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    roc_auc_score,
    classification_report,
    precision_score,
    recall_score,
)
import numpy as np
import logging

logger = logging.getLogger(__name__)


class GraphMetrics:
    """Comprehensive evaluation metrics for graph learning"""
|
    @staticmethod
    def accuracy(pred, target):
        """Classification accuracy with validation"""
        try:
            # Reduce logits/probabilities to hard labels; 1-D inputs are assumed
            # to already be class labels.
            if pred.dim() > 1:
                pred_labels = pred.argmax(dim=1)
            else:
                pred_labels = pred

            if pred_labels.shape != target.shape:
                raise ValueError("Prediction and target shapes don't match")

            correct = (pred_labels == target).float()
            accuracy = correct.mean().item()

            # .item() yields a Python float, so NumPy finiteness checks suffice.
            if np.isnan(accuracy) or np.isinf(accuracy):
                logger.warning("Invalid accuracy computed, returning 0.0")
                return 0.0

            return accuracy

        except Exception as e:
            logger.error(f"Accuracy computation failed: {e}")
            return 0.0
|
    @staticmethod
    def f1_score_macro(pred, target):
        """Macro F1 score with robust error handling"""
        try:
            if pred.dim() > 1:
                pred_labels = pred.argmax(dim=1)
            else:
                pred_labels = pred

            # Detach before converting so tensors that carry gradients don't raise.
            pred_labels = pred_labels.detach().cpu().numpy()
            target_labels = target.detach().cpu().numpy()

            if len(pred_labels) == 0 or len(target_labels) == 0:
                return 0.0

            f1 = f1_score(target_labels, pred_labels, average='macro', zero_division=0)

            if np.isnan(f1) or np.isinf(f1):
                logger.warning("Invalid F1 macro score, returning 0.0")
                return 0.0

            return float(f1)

        except Exception as e:
            logger.error(f"F1 macro computation failed: {e}")
            return 0.0
|
    @staticmethod
    def f1_score_micro(pred, target):
        """Micro F1 score with robust error handling"""
        try:
            if pred.dim() > 1:
                pred_labels = pred.argmax(dim=1)
            else:
                pred_labels = pred

            # Detach before converting so tensors that carry gradients don't raise.
            pred_labels = pred_labels.detach().cpu().numpy()
            target_labels = target.detach().cpu().numpy()

            if len(pred_labels) == 0 or len(target_labels) == 0:
                return 0.0

            f1 = f1_score(target_labels, pred_labels, average='micro', zero_division=0)

            if np.isnan(f1) or np.isinf(f1):
                logger.warning("Invalid F1 micro score, returning 0.0")
                return 0.0

            return float(f1)

        except Exception as e:
            logger.error(f"F1 micro computation failed: {e}")
            return 0.0
|
    @staticmethod
    def precision_recall(pred, target, average='macro'):
        """Compute precision and recall scores"""
        try:
            if pred.dim() > 1:
                pred_labels = pred.argmax(dim=1)
            else:
                pred_labels = pred

            # Detach before converting so tensors that carry gradients don't raise.
            pred_labels = pred_labels.detach().cpu().numpy()
            target_labels = target.detach().cpu().numpy()

            if len(pred_labels) == 0 or len(target_labels) == 0:
                return 0.0, 0.0

            precision = precision_score(target_labels, pred_labels, average=average, zero_division=0)
            recall = recall_score(target_labels, pred_labels, average=average, zero_division=0)

            if np.isnan(precision) or np.isinf(precision):
                precision = 0.0
            if np.isnan(recall) or np.isinf(recall):
                recall = 0.0

            return float(precision), float(recall)

        except Exception as e:
            logger.error(f"Precision/recall computation failed: {e}")
            return 0.0, 0.0
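

if __name__ == "__main__":
    # Minimal usage sketch (illustrative assumption, not part of the original
    # module): score random 4-class logits against random labels to show the
    # expected call pattern for GraphMetrics. The sizes below are placeholders.
    torch.manual_seed(0)
    logits = torch.randn(32, 4)          # [num_nodes, num_classes]
    labels = torch.randint(0, 4, (32,))  # ground-truth class per node

    acc = GraphMetrics.accuracy(logits, labels)
    f1_macro = GraphMetrics.f1_score_macro(logits, labels)
    f1_micro = GraphMetrics.f1_score_micro(logits, labels)
    precision, recall = GraphMetrics.precision_recall(logits, labels, average='macro')

    print(f"acc={acc:.3f} f1_macro={f1_macro:.3f} f1_micro={f1_micro:.3f} "
          f"precision={precision:.3f} recall={recall:.3f}")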