# X-iZhang's picture
# Update factual/RaTEScore/scorer.py
# d61d98e verified
import torch
import json
import numpy as np
from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForTokenClassification
import pandas as pd
import os
from .score import run_ner, process_embedding
from .utils import compute
# Built-in affinity weight tables, selected by passing affinity_matrix="long"
# or affinity_matrix="short" to RaTEScore.
# Every key is a "<type>_<type>" pair over the entity types abnormality,
# anatomy, disease, non-abnormality and non-disease, plus one "neg_weight"
# entry. RaTEScore.__init__ splits each key on "_" and upper-cases the first
# two parts to build tuple keys, so "neg_weight" becomes ("NEG", "WEIGHT").
# NOTE(review): which element of a pair is the reference type and which is the
# candidate type is decided by utils.compute, not in this file -- confirm there.
# NOTE(review): "long"/"short" presumably refer to report length -- confirm.
DEFAULT_MATRIX_LONG = {"abnormality_abnormality": 0.4276119164393705, "abnormality_anatomy": 0.6240929990607657, "abnormality_disease": 0.0034478181112993847, "abnormality_non-abnormality": 0.5431049700217344, "abnormality_non-disease": 0.27005425386213877, "anatomy_abnormality": 0.7487824274337533, "anatomy_anatomy": 0.2856134859160784, "anatomy_disease": 0.4592143222158069, "anatomy_non-abnormality": 0.02097055139911715, "anatomy_non-disease": 0.00013736314126696204, "disease_abnormality": 0.8396510075734789, "disease_anatomy": 0.9950209388542061, "disease_disease": 0.8460555030578727, "disease_non-abnormality": 0.9820689020512646, "disease_non-disease": 0.3789136708096537, "non-abnormality_abnormality": 0.16546764653692908, "non-abnormality_anatomy": 0.018670610691852826, "non-abnormality_disease": 0.719397354576018, "non-abnormality_non-abnormality": 0.0009357166071730684, "non-abnormality_non-disease": 0.0927333564267591, "non-disease_abnormality": 0.7759420231214385, "non-disease_anatomy": 0.1839139293714062, "non-disease_disease": 0.10073046076318157, "non-disease_non-abnormality": 0.03860183811876373, "non-disease_non-disease": 0.34065681486566446, "neg_weight":0.8716553966489615}
DEFAULT_MATRIX_SHORT = {"abnormality_abnormality": 0.4070293318365468, "abnormality_anatomy": 0.6952639610605605, "abnormality_disease": 0.28342529466226446, "abnormality_non-abnormality": 0.9479148658006686, "abnormality_non-disease": 0.23875064111146294, "anatomy_abnormality": 0.5829759950441763, "anatomy_anatomy": 0.7709590751917746, "anatomy_disease": 0.0006059634829551632, "anatomy_non-abnormality": 0.794672584951181, "anatomy_non-disease": 0.27982942400798977, "disease_abnormality": 0.8840397619834857, "disease_anatomy": 0.9637659445696822, "disease_disease": 0.19018958438059513, "disease_non-abnormality": 0.6962283914800402, "disease_non-disease": 0.943727057946997, "non-abnormality_abnormality": 0.1712744286898638, "non-abnormality_anatomy": 0.4485149671497294, "non-abnormality_disease": 0.00045065329822896076, "non-abnormality_non-abnormality": 0.0007887930317199857, "non-abnormality_non-disease": 0.8555432840895761, "non-disease_abnormality": 0.9555801066212176, "non-disease_anatomy": 0.13122106162635216, "non-disease_disease": 0.6072996585919443, "non-disease_non-abnormality": 0.05650711141169969, "non-disease_non-disease": 0.3214769399791204, "neg_weight":0.3611577852354489}
class RaTEScore:
    """Entity-aware metric comparing generated medical reports with references.

    Each report is run through a token-classification (NER) model to extract
    typed entities, the entities are embedded with a second encoder, and the
    candidate/reference entity sets are compared in both directions using a
    type-pair affinity matrix. The two directional scores are combined with a
    harmonic mean.
    """

    def __init__(self,
                 bert_model="Angelakeke/RaTE-NER-Deberta",
                 eval_model='FremyCompany/BioLORD-2023-C',
                 batch_size=1,
                 use_gpu=None,
                 visualization_path=None,
                 affinity_matrix="long",
                 ):
        """ RaTEScore is a novel, entity-aware metric to assess the quality of medical reports generated by AI models.
        It emphasizes crucial medical entities such as diagnostic outcomes and anatomical details, and is robust
        against complex medical synonyms and sensitive to negation expressions. The evaluations demonstrate that
        RaTEScore aligns more closely with human preference than existing metrics.

        Args:
            bert_model (str, optional): Medical entity recognition module. Defaults to "Angelakeke/RaTE-NER-Deberta".
            eval_model (str, optional): Synonym disambiguation encoding module. Defaults to 'FremyCompany/BioLORD-2023-C'.
            batch_size (int, optional): Batch size to choose. Defaults to 1.
            use_gpu (bool, optional): Whether to use the GPU. Defaults to None, which selects the GPU
                automatically when ``torch.cuda.is_available()`` is true.
            visualization_path (str, optional): Directory that receives the visualization output; it is
                created if missing and ``rate_score.json`` is written inside it by ``compute_score``.
                Defaults to None (nothing is written).
            affinity_matrix (str, optional): Pre-searched type weights, which can be changed to account for
                human rating bias. Either "long", "short" (case-insensitive) or the path of a JSON file
                with the same key layout. Defaults to 'long'.

        Raises:
            ValueError: If ``affinity_matrix`` is not a string, or names a file that cannot be loaded.
        """
        # Auto select GPU
        if use_gpu is None:
            use_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_gpu else "cpu")

        # load the Medical entity recognition module
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model)
        self.model = AutoModelForTokenClassification.from_pretrained(bert_model).eval().to(self.device)

        # load the Synonym disambiguation module
        self.eval_tokenizer = AutoTokenizer.from_pretrained(eval_model)
        self.eval_model = AutoModel.from_pretrained(eval_model).eval().to(self.device)

        # load the weight matrix; `matrix_path` holds the raw dict (the
        # attribute name is kept for backward compatibility)
        self.matrix_path = self._load_affinity_matrix(affinity_matrix)
        # "a_b" -> ("A", "B"); the "neg_weight" entry becomes ("NEG", "WEIGHT")
        self.affinity_matrix = {
            (k.split('_')[0].upper(), k.split('_')[1].upper()): v
            for k, v in self.matrix_path.items()
        }

        # load the label file
        self.config = AutoConfig.from_pretrained(bert_model)
        self.label2idx = self.config.label2id
        self.idx2label = self.config.id2label

        # save the input
        self.batch_size = batch_size
        self.visualization_path = self._prepare_visualization_dir(visualization_path)

    @staticmethod
    def _load_affinity_matrix(affinity_matrix):
        """Resolve a preset name ("long"/"short") or JSON file path to a weight dict."""
        if not isinstance(affinity_matrix, str):
            raise ValueError("affinity_matrix must be a string")

        preset = affinity_matrix.lower()
        if preset == "long":
            return DEFAULT_MATRIX_LONG
        if preset == "short":
            return DEFAULT_MATRIX_SHORT

        # Anything else is assumed to be a file path
        try:
            with open(affinity_matrix, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            # Keep the original ValueError contract, but chain the cause so
            # the underlying I/O or JSON error stays visible in tracebacks.
            raise ValueError(f"Failed to load affinity matrix from {affinity_matrix}: {e}") from e

    @staticmethod
    def _prepare_visualization_dir(visualization_path):
        """Create the visualization directory if requested and return the stored path (or None)."""
        if not visualization_path:
            return None
        # compute_score writes <visualization_path>/rate_score.json, so the
        # directory that must exist is visualization_path itself, not its
        # parent. exist_ok avoids a check-then-create race.
        os.makedirs(visualization_path, exist_ok=True)
        return visualization_path

    @staticmethod
    def _check_report_lists(candidate_list, reference_list):
        """Validate that both inputs are equally long lists of strings."""
        if not isinstance(candidate_list, list):
            raise ValueError("candidate must be a list")
        if not isinstance(reference_list, list):
            raise ValueError("reference must be a list")
        if len(candidate_list) != len(reference_list):
            # Raised explicitly (instead of `assert`) so the check survives
            # `python -O`; AssertionError is kept for existing callers.
            raise AssertionError("candidate and reference must have the same length")
        if not all(isinstance(x, str) for x in candidate_list):
            raise ValueError("candidate must be a list of strings")
        if not all(isinstance(x, str) for x in reference_list):
            raise ValueError("reference must be a list of strings")

    def compute_score(self, candidate_list, reference_list):
        '''Compute the RaTEScore for the candidate and reference reports.

        Args:
            candidate_list (list): list of candidate reports
            reference_list (list): list of reference reports

        Returns:
            tuple: ``(rate_score, pred_pairs, gt_pairs)`` where ``rate_score`` is a list with one
            score per report pair, and ``pred_pairs`` / ``gt_pairs`` are the ``run_ner`` outputs
            for the candidates and the references respectively.

        Raises:
            ValueError: If either argument is not a list of strings.
            AssertionError: If the two lists differ in length.
        '''
        self._check_report_lists(candidate_list, reference_list)

        gt_pairs = run_ner(reference_list, self.idx2label, self.tokenizer, self.model, self.device, self.batch_size)
        pred_pairs = run_ner(candidate_list, self.idx2label, self.tokenizer, self.model, self.device, self.batch_size)

        rate_score = []
        for gt_pair, pred_pair in zip(gt_pairs, pred_pairs):
            # process the embedding for gt
            gt_embeds_word, gt_types = process_embedding(gt_pair, self.eval_tokenizer, self.eval_model, self.device)
            # process the embedding for pred
            pred_embeds_word, pred_types = process_embedding(pred_pair, self.eval_tokenizer, self.eval_model, self.device)

            # compute the score, if the length of gt or pred is 0, the score is 0.5
            if len(gt_embeds_word) == 0 or len(pred_embeds_word) == 0:
                rate_score.append(0.5)
                continue

            precision_score = compute(gt_embeds_word, pred_embeds_word, gt_types, pred_types, self.affinity_matrix)
            recall_score = compute(pred_embeds_word, gt_embeds_word, pred_types, gt_types, self.affinity_matrix)

            # harmonic mean of the two directions, guarding the 0/0 case
            if precision_score + recall_score == 0:
                rate_score.append(0)
            else:
                rate_score.append(2*precision_score*recall_score/(precision_score+recall_score))

        if self.visualization_path:
            save_file = pd.DataFrame({
                'candidate': candidate_list,
                'reference': reference_list,
                'candidate_entities': pred_pairs,
                'reference_entities': gt_pairs,
                'rate_score': rate_score
            })
            save_file.to_json(os.path.join(self.visualization_path, 'rate_score.json'), lines=True, orient='records')

        return rate_score, pred_pairs, gt_pairs