import logging
import math

import numpy as np
import torch
from sklearn.metrics import f1_score, precision_score, recall_score

logger = logging.getLogger(__name__)
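
# Note: every metric in this module fails soft. On any error it logs the
# problem and returns a neutral 0.0 (or (0.0, 0.0)) instead of raising, so a
# single malformed batch cannot crash an evaluation or training loop.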

class GraphMetrics:
    """Comprehensive evaluation metrics for graph learning"""
    
    @staticmethod
    def accuracy(pred, target):
        """Classification accuracy with validation"""
        try:
            # 2-D input is treated as per-class scores (argmax over the class
            # dimension); 1-D input is assumed to already be hard labels
            if pred.dim() > 1:
                pred_labels = pred.argmax(dim=1)
            else:
                pred_labels = pred
                
            if pred_labels.shape != target.shape:
                raise ValueError("Prediction and target shapes don't match")
            
            correct = (pred_labels == target).float()
            accuracy = correct.mean().item()
            
            if math.isnan(accuracy) or math.isinf(accuracy):
                logger.warning("Invalid accuracy computed, returning 0.0")
                return 0.0
                
            return accuracy
            
        except Exception as e:
            logger.error(f"Accuracy computation failed: {e}")
            return 0.0
    
    @staticmethod
    def f1_score_macro(pred, target):
        """Macro F1 score with robust error handling"""
        try:
            if pred.dim() > 1:
                pred_labels = pred.argmax(dim=1)
            else:
                pred_labels = pred
                
            # detach() guards against tensors that still require grad
            pred_labels = pred_labels.detach().cpu().numpy()
            target_labels = target.detach().cpu().numpy()
            
            if len(pred_labels) == 0 or len(target_labels) == 0:
                return 0.0
            
            f1 = f1_score(target_labels, pred_labels, average='macro', zero_division=0)
            
            if np.isnan(f1) or np.isinf(f1):
                logger.warning("Invalid F1 macro score, returning 0.0")
                return 0.0
                
            return float(f1)
            
        except Exception as e:
            logger.error(f"F1 macro computation failed: {e}")
            return 0.0
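
    # Macro F1 gives every class equal weight, which surfaces minority-class
    # performance; micro F1 (below) pools all samples and, for single-label
    # classification, coincides with accuracy.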
    
    @staticmethod
    def f1_score_micro(pred, target):
        """Micro F1 score with robust error handling"""
        try:
            if pred.dim() > 1:
                pred_labels = pred.argmax(dim=1)
            else:
                pred_labels = pred
                
            pred_labels = pred_labels.detach().cpu().numpy()
            target_labels = target.detach().cpu().numpy()
            
            if len(pred_labels) == 0 or len(target_labels) == 0:
                return 0.0
            
            f1 = f1_score(target_labels, pred_labels, average='micro', zero_division=0)
            
            if np.isnan(f1) or np.isinf(f1):
                logger.warning("Invalid F1 micro score, returning 0.0")
                return 0.0
                
            return float(f1)
            
        except Exception as e:
            logger.error(f"F1 micro computation failed: {e}")
            return 0.0

    @staticmethod
    def precision_recall(pred, target, average='macro'):
        """Compute precision and recall scores"""
        try:
            if pred.dim() > 1:
                pred_labels = pred.argmax(dim=1)
            else:
                pred_labels = pred
                
            pred_labels = pred_labels.detach().cpu().numpy()
            target_labels = target.detach().cpu().numpy()
            
            if len(pred_labels) == 0 or len(target_labels) == 0:
                return 0.0, 0.0
            
            precision = precision_score(target_labels, pred_labels, average=average, zero_division=0)
            recall = recall_score(target_labels, pred_labels, average=average, zero_division=0)
            
            if np.isnan(precision) or np.isinf(precision):
                precision = 0.0
            if np.isnan(recall) or np.isinf(recall):
                recall = 0.0
                
            return float(precision), float(recall)
            
        except Exception as e:
            logger.error(f"Precision/recall computation failed: {e}")
            return 0.0, 0.0
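
# Minimal usage sketch: the synthetic logits and labels below are illustrative
# placeholders standing in for real model outputs, just to show the call
# pattern for each metric.
if __name__ == "__main__":
    logits = torch.randn(16, 3)          # e.g. 16 nodes, 3 classes
    labels = torch.randint(0, 3, (16,))  # ground-truth class per node

    acc = GraphMetrics.accuracy(logits, labels)
    f1_ma = GraphMetrics.f1_score_macro(logits, labels)
    f1_mi = GraphMetrics.f1_score_micro(logits, labels)
    precision, recall = GraphMetrics.precision_recall(logits, labels)

    print(f"accuracy={acc:.3f} f1_macro={f1_ma:.3f} f1_micro={f1_mi:.3f}")
    print(f"precision={precision:.3f} recall={recall:.3f}")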