import torch
from bert_score import score


def _get_default_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


class RadEvalBERTScorer:
    """
    Wrapper around bert_score for radiology reports using a custom BERT model.
    """

    def __init__(self,
                 model_type: str = "IAMJB/RadEvalModernBERT",
                 num_layers: int | None = None,
                 use_fast_tokenizer: bool = True,
                 rescale_with_baseline: bool = False,
                 device: torch.device | None = None):
        self.model_type = model_type
        self.num_layers = num_layers
        self.use_fast_tokenizer = use_fast_tokenizer
        self.rescale_with_baseline = rescale_with_baseline
        self.device = device or _get_default_device()
    def score(self, refs: list[str], hyps: list[str]) -> tuple[float, torch.Tensor]:
        """
        Compute BERTScore F1 between reference and hypothesis texts.

        Args:
            refs: list of reference sentences.
            hyps: list of hypothesis sentences (predictions).

        Returns:
            A tuple of (mean F1 over all pairs as a float, per-pair F1 tensor).
        """
        # bert_score expects cands (hypotheses) first, then refs
        P, R, F1 = score(
            cands=hyps,
            refs=refs,
            model_type=self.model_type,
            num_layers=self.num_layers,
            use_fast_tokenizer=self.use_fast_tokenizer,
            rescale_with_baseline=self.rescale_with_baseline,
            device=self.device,
        )
        # Return the mean F1 over all pairs together with the per-pair scores
        return F1.mean().item(), F1


if __name__ == "__main__":
    # Example usage
    refs = ["Chronic mild to moderate cardiomegaly and pulmonary venous hypertension."]
    hyps = ["Mild left basal atelectasis; no pneumonia."]
    scorer = RadEvalBERTScorer(num_layers=23)
    f1_score, _ = scorer.score(refs, hyps)
    print(f"Mean F1 score: {f1_score:.4f}")