File size: 9,031 Bytes
bad8293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d61d98e
bad8293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d61d98e
 
 
 
bad8293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import json
import numpy as np
from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForTokenClassification
import pandas as pd
import os

from .score import run_ner, process_embedding
from .utils import compute


# Pre-searched pairwise type-affinity weights, keyed "{gt_type}_{pred_type}" over the
# five entity types {abnormality, anatomy, disease, non-abnormality, non-disease},
# plus a scalar "neg_weight". "long" vs "short" were tuned on long-form vs short-form
# reports respectively (choice exposed via RaTEScore(affinity_matrix=...)).
DEFAULT_MATRIX_LONG = {"abnormality_abnormality": 0.4276119164393705, "abnormality_anatomy": 0.6240929990607657, "abnormality_disease": 0.0034478181112993847, "abnormality_non-abnormality": 0.5431049700217344, "abnormality_non-disease": 0.27005425386213877, "anatomy_abnormality": 0.7487824274337533, "anatomy_anatomy": 0.2856134859160784, "anatomy_disease": 0.4592143222158069, "anatomy_non-abnormality": 0.02097055139911715, "anatomy_non-disease": 0.00013736314126696204, "disease_abnormality": 0.8396510075734789, "disease_anatomy": 0.9950209388542061, "disease_disease": 0.8460555030578727, "disease_non-abnormality": 0.9820689020512646, "disease_non-disease": 0.3789136708096537, "non-abnormality_abnormality": 0.16546764653692908, "non-abnormality_anatomy": 0.018670610691852826, "non-abnormality_disease": 0.719397354576018, "non-abnormality_non-abnormality": 0.0009357166071730684, "non-abnormality_non-disease": 0.0927333564267591, "non-disease_abnormality": 0.7759420231214385, "non-disease_anatomy": 0.1839139293714062, "non-disease_disease": 0.10073046076318157, "non-disease_non-abnormality": 0.03860183811876373, "non-disease_non-disease": 0.34065681486566446, "neg_weight":0.8716553966489615}
DEFAULT_MATRIX_SHORT = {"abnormality_abnormality": 0.4070293318365468, "abnormality_anatomy": 0.6952639610605605, "abnormality_disease": 0.28342529466226446, "abnormality_non-abnormality": 0.9479148658006686, "abnormality_non-disease": 0.23875064111146294, "anatomy_abnormality": 0.5829759950441763, "anatomy_anatomy": 0.7709590751917746, "anatomy_disease": 0.0006059634829551632, "anatomy_non-abnormality": 0.794672584951181, "anatomy_non-disease": 0.27982942400798977, "disease_abnormality": 0.8840397619834857, "disease_anatomy": 0.9637659445696822, "disease_disease": 0.19018958438059513, "disease_non-abnormality": 0.6962283914800402, "disease_non-disease": 0.943727057946997, "non-abnormality_abnormality": 0.1712744286898638, "non-abnormality_anatomy": 0.4485149671497294, "non-abnormality_disease": 0.00045065329822896076, "non-abnormality_non-abnormality": 0.0007887930317199857, "non-abnormality_non-disease": 0.8555432840895761, "non-disease_abnormality": 0.9555801066212176, "non-disease_anatomy": 0.13122106162635216, "non-disease_disease": 0.6072996585919443, "non-disease_non-abnormality": 0.05650711141169969, "non-disease_non-disease": 0.3214769399791204, "neg_weight":0.3611577852354489}


class RaTEScore:
    def __init__(self, 
                    bert_model="Angelakeke/RaTE-NER-Deberta",
                    eval_model='FremyCompany/BioLORD-2023-C',
                    batch_size=1,
                    use_gpu=None,
                    visualization_path=None,
                    affinity_matrix="long",
                ):
        """ RaTEScore is a novel, entity-aware metric to assess the quality of medical reports generated by AI models. 
        It emphasizes crucial medical entities such as diagnostic outcomes and anatomical details, and is robust 
        against complex medical synonyms and sensitive to negation expressions. The evaluations demonstrate that 
        RaTEScore aligns more closely with human preference than existing metrics.

        Args:
            bert_model (str, optional): Medical entity recognition module. Defaults to "Angelakeke/RaTE-NER-Deberta".
            eval_model (str, optional): Synonym disambiguation encoding module. Defaults to 'FremyCompany/BioLORD-2023-C'.
            batch_size (int, optional): Batch size to choose. Defaults to 1.
            use_gpu (bool, optional): Whether to use GPU. Defaults to None, i.e. auto-detect CUDA.
            visualization_path (str, optional): Output directory for the visualized results
                                          (saved as ``rate_score.json`` inside it). Defaults to None.
            affinity_matrix (str, optional): Pre-searched type weights: "long", "short", or a path
                                          to a JSON file with the same schema; can be changed due
                                          to the human rating bias. Defaults to 'long'.

        Raises:
            ValueError: If ``affinity_matrix`` is not a string, or is a file path that
                cannot be read/parsed as JSON.
        """
        
        # Auto-select device: default to GPU only when CUDA is actually available.
        if use_gpu is None:
            use_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_gpu else "cpu")
        
        # Medical entity recognition (NER) module.
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model)
        self.model = AutoModelForTokenClassification.from_pretrained(bert_model).eval().to(self.device)
        
        # Synonym disambiguation (entity embedding) module.
        self.eval_tokenizer = AutoTokenizer.from_pretrained(eval_model)
        self.eval_model = AutoModel.from_pretrained(eval_model).eval().to(self.device)
       
        # Load the weight matrix: a built-in preset ("long"/"short") or a JSON file path.
        if isinstance(affinity_matrix, str):
            if affinity_matrix.lower() == "long":
                self.matrix_path = DEFAULT_MATRIX_LONG
            elif affinity_matrix.lower() == "short":
                self.matrix_path = DEFAULT_MATRIX_SHORT
            else:
                # Assume it's a file path to a JSON dict with the same keys as the presets.
                try:
                    with open(affinity_matrix, 'r') as f:
                        self.matrix_path = json.load(f)
                except Exception as e:
                    raise ValueError(f"Failed to load affinity matrix from {affinity_matrix}: {e}")
        else:
            raise ValueError("affinity_matrix must be a string")
            
        # Re-key "typeA_typeB" entries as upper-case (TYPEA, TYPEB) tuples for lookup in
        # compute(). NOTE: the scalar "neg_weight" entry deliberately flows through the same
        # split and ends up under the ('NEG', 'WEIGHT') key, matching the downstream lookup.
        self.affinity_matrix = {(k.split('_')[0].upper(), k.split('_')[1].upper()): v
                                for k, v in self.matrix_path.items()}
        
        # Label <-> id mappings come from the NER model's config.
        self.config = AutoConfig.from_pretrained(bert_model)
        self.label2idx = self.config.label2id
        self.idx2label = self.config.id2label
        
        # save the input
        self.batch_size = batch_size
        
        if visualization_path:
            self.visualization_path = visualization_path
            # Bug fix: compute_score() writes to <visualization_path>/rate_score.json,
            # i.e. the path is an output *directory*. The previous code only created its
            # parent (os.path.dirname), so to_json() failed when the directory itself did
            # not exist, and os.makedirs('') crashed for single-component paths.
            os.makedirs(visualization_path, exist_ok=True)
        else:
            self.visualization_path = None
        
        
    def compute_score(self, candidate_list, reference_list):
        '''Compute the RaTEScore for the candidate and reference reports.
    
        Args:
            candidate_list (list): list of candidate reports (strings)
            reference_list (list): list of reference reports (strings)

        Returns:
            tuple: (rate_score, pred_pairs, gt_pairs) where rate_score is a list of
            per-pair F1-style scores and the *_pairs are the extracted entity pairs.

        Raises:
            ValueError: If either argument is not a list of strings.
        '''
        
        # check if candidate and reference are list
        if not isinstance(candidate_list, list):
            raise ValueError("candidate must be a list")
        if not isinstance(reference_list, list):
            raise ValueError("reference must be a list")
        
        assert len(candidate_list) == len(reference_list), "candidate and reference must have the same length"
        
        # check if candidate and reference are list of strings
        if not all(isinstance(x, str) for x in candidate_list):
            raise ValueError("candidate must be a list of strings")
        # Bug fix: the reference side was previously never type-checked, so a bad
        # reference list would fail deep inside run_ner with an opaque error.
        if not all(isinstance(x, str) for x in reference_list):
            raise ValueError("reference must be a list of strings")
        
        gt_pairs = run_ner(reference_list, self.idx2label, self.tokenizer, self.model, self.device, self.batch_size)
        pred_pairs = run_ner(candidate_list, self.idx2label, self.tokenizer, self.model, self.device, self.batch_size)
        
        rate_score = []
        
        for gt_pair, pred_pair in zip(gt_pairs, pred_pairs):
            
            # process the embedding for gt
            gt_embeds_word, gt_types = process_embedding(gt_pair, self.eval_tokenizer, self.eval_model, self.device)

            # process the embedding for pred
            pred_embeds_word, pred_types = process_embedding(pred_pair, self.eval_tokenizer, self.eval_model, self.device)
            
            # If either side yields no entities, the pair is scored a neutral 0.5.
            if len(gt_embeds_word) == 0 or len(pred_embeds_word) == 0:
                rate_score.append(0.5)
                continue 

            precision_score = compute(gt_embeds_word, pred_embeds_word, gt_types, pred_types, self.affinity_matrix)
            recall_score = compute(pred_embeds_word, gt_embeds_word, pred_types, gt_types, self.affinity_matrix)

            # Harmonic mean (F1) of entity-level precision and recall; 0 when both vanish.
            if precision_score + recall_score == 0:
                rate_score.append(0)
            else:  
                rate_score.append(2 * precision_score * recall_score / (precision_score + recall_score))
            
        if self.visualization_path:
            # Persist one JSON record per report pair for inspection/visualization.
            save_file = pd.DataFrame({
                'candidate': candidate_list,
                'reference': reference_list,
                'candidate_entities': pred_pairs,
                'reference_entities': gt_pairs,
                'rate_score': rate_score
            })
            save_file.to_json(os.path.join(self.visualization_path, 'rate_score.json'), lines=True, orient='records')
                
        return rate_score, pred_pairs, gt_pairs