import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, classification_report, precision_score, recall_score
import numpy as np
import logging

logger = logging.getLogger(__name__)


class GraphMetrics:
    """Comprehensive evaluation metrics for graph learning.

    Every public metric is best-effort: on invalid input or an internal
    failure it logs the problem and returns 0.0 (or ``(0.0, 0.0)`` for
    ``precision_recall``) instead of raising.
    """

    @staticmethod
    def _to_labels(pred):
        """Return class labels: argmax over dim 1 for multi-dim ``pred``, else ``pred`` unchanged."""
        # NOTE(review): an (N, 1) single-logit output argmaxes to all zeros;
        # binary single-logit callers presumably need to threshold first — confirm.
        return pred.argmax(dim=1) if pred.dim() > 1 else pred

    @staticmethod
    def _to_numpy(pred, target):
        """Return ``(pred_labels, target_labels)`` as host NumPy arrays."""
        # detach() so tensors still attached to the autograd graph can be
        # converted; numpy() raises on tensors that require grad.
        pred_labels = GraphMetrics._to_labels(pred).detach().cpu().numpy()
        target_labels = target.detach().cpu().numpy()
        return pred_labels, target_labels

    @staticmethod
    def _finite_or_zero(value):
        """Return ``float(value)``, or 0.0 when it is NaN or infinite."""
        value = float(value)
        return value if np.isfinite(value) else 0.0

    @staticmethod
    def _f1(pred, target, average):
        """Shared body of the F1 wrappers; ``average`` is passed to sklearn and used in log messages."""
        try:
            pred_labels, target_labels = GraphMetrics._to_numpy(pred, target)
            if len(pred_labels) == 0 or len(target_labels) == 0:
                return 0.0
            f1 = f1_score(target_labels, pred_labels, average=average, zero_division=0)
            if not np.isfinite(f1):
                logger.warning("Invalid F1 %s score, returning 0.0", average)
                return 0.0
            return float(f1)
        except Exception as e:
            logger.error("F1 %s computation failed: %s", average, e)
            return 0.0

    @staticmethod
    def accuracy(pred, target):
        """Classification accuracy with validation.

        Args:
            pred: Class scores (argmax-ed over dim 1 when ``pred.dim() > 1``)
                or already-computed labels.
            target: Ground-truth labels; must have the same shape as the
                derived prediction labels.

        Returns:
            Fraction of matching labels as a Python float. Returns 0.0 for
            empty input, mismatched shapes, a non-finite result, or any
            other failure (logged).
        """
        try:
            pred_labels = GraphMetrics._to_labels(pred)
            if pred_labels.shape != target.shape:
                raise ValueError("Prediction and target shapes don't match")
            if pred_labels.numel() == 0:
                # The mean of an empty tensor is NaN. Return 0.0 quietly, as the
                # sklearn-based metrics below do for empty input.
                return 0.0
            # Compare on target's device so mixed CPU/GPU inputs work, matching
            # the sklearn-based metrics, which move both tensors to the CPU.
            correct = (pred_labels.to(target.device) == target).float()
            accuracy = correct.mean().item()
            if not np.isfinite(accuracy):
                logger.warning("Invalid accuracy computed, returning 0.0")
                return 0.0
            return accuracy
        except Exception as e:
            logger.error("Accuracy computation failed: %s", e)
            return 0.0

    @staticmethod
    def f1_score_macro(pred, target):
        """Macro F1 score with robust error handling.

        Computes sklearn's ``f1_score(..., average='macro', zero_division=0)``
        on labels derived from ``pred``. Returns 0.0 for empty input, a
        non-finite score, or any failure (logged).
        """
        return GraphMetrics._f1(pred, target, 'macro')

    @staticmethod
    def f1_score_micro(pred, target):
        """Micro F1 score with robust error handling.

        Computes sklearn's ``f1_score(..., average='micro', zero_division=0)``
        on labels derived from ``pred``. Returns 0.0 for empty input, a
        non-finite score, or any failure (logged).
        """
        return GraphMetrics._f1(pred, target, 'micro')

    @staticmethod
    def precision_recall(pred, target, average='macro'):
        """Compute precision and recall scores.

        Args:
            pred: Class scores (argmax-ed over dim 1 when ``pred.dim() > 1``)
                or already-computed labels.
            target: Ground-truth labels.
            average: Passed to sklearn's ``precision_score``/``recall_score``.
                Only modes that yield a scalar are supported. ``None``
                (per-class arrays) falls through to the ``(0.0, 0.0)``
                failure path.

        Returns:
            ``(precision, recall)`` as floats; each is replaced by 0.0 when
            non-finite. Returns ``(0.0, 0.0)`` for empty input or any failure.
        """
        try:
            pred_labels, target_labels = GraphMetrics._to_numpy(pred, target)
            if len(pred_labels) == 0 or len(target_labels) == 0:
                return 0.0, 0.0
            precision = precision_score(target_labels, pred_labels, average=average, zero_division=0)
            recall = recall_score(target_labels, pred_labels, average=average, zero_division=0)
            return GraphMetrics._finite_or_zero(precision), GraphMetrics._finite_or_zero(recall)
        except Exception as e:
            logger.error("Precision/recall computation failed: %s", e)
            return 0.0, 0.0