|
import json |
|
from radgraph import RadGraph |
|
|
|
|
|
def compute_f1(test, retrieved):
    """Return the F1 score of ``retrieved`` measured against the gold set ``test``.

    Both arguments are sets (of entities or of relations); their
    intersection gives the true positives. Precision is 0 when ``retrieved``
    is empty, recall is 0 when ``test`` is empty, and the F1 is 0 when
    precision and recall are both 0.
    """
    true_pos = len(test & retrieved)
    n_retrieved = len(retrieved)  # true positives + false positives
    n_relevant = len(test)        # true positives + false negatives

    precision = true_pos / n_retrieved if n_retrieved else 0
    recall = true_pos / n_relevant if n_relevant else 0

    denom = precision + recall
    if not denom:
        return 0
    return 2 * precision * recall / denom
|
|
|
|
|
def extract_entities(output):
    """Return the set of ``(tokens, label)`` pairs for every entity in ``output``.

    ``output`` is a RadGraph output dict. Its ``"entities"`` mapping is read,
    defaulting to empty when the key is absent. ``tokens`` is converted with
    ``tuple()`` so that each pair is hashable.
    """
    entity_map = output.get("entities", {})
    return set(
        (tuple(entity["tokens"]), entity["label"])
        for entity in entity_map.values()
    )
|
|
|
|
|
def extract_relations(output):
    """Return the set of ``(source, target, relation_type)`` triples in ``output``.

    ``output`` is a RadGraph output dict. ``source`` and ``target`` are
    ``(tuple(tokens), label)`` pairs. Each entity's ``"relations"`` entries
    (default: none) are unpacked as ``(relation_type, target_index)``, and
    ``target_index`` is looked up in the ``"entities"`` mapping. A relation
    is skipped when that lookup misses or yields a falsy value.
    """
    entity_map = output.get("entities", {})

    def as_key(entity):
        # Hashable identity for an entity: its tokens as a tuple plus its label.
        return (tuple(entity["tokens"]), entity["label"])

    triples = set()
    for entity in entity_map.values():
        source = as_key(entity)
        for relation_type, target_index in entity.get("relations", []):
            target = entity_map.get(target_index)
            if not target:
                continue
            triples.add((source, as_key(target), relation_type))
    return triples
|
|
|
|
|
def compute_radgraph_scores(refs, hyps, model_name='radgraph'):
    """
    Compute a combined RadGraph F1 score for each reference/hypothesis pair.

    Both lists are annotated with a RadGraph model. For every aligned pair,
    entity-set F1 and relation-set F1 are computed with ``compute_f1`` and
    then averaged.

    Args:
        refs: Reference (ground-truth) report texts. A single string is
            treated as a one-element list.
        hyps: Hypothesis report texts, aligned index-by-index with ``refs``.
            A single string is treated as a one-element list.
        model_name: Value passed as ``model_type`` when constructing
            ``RadGraph``.

    Returns:
        List of floats: (entity_f1 + relation_f1) / 2 per report pair, in
        input order. An empty list is returned for empty inputs without
        loading the model.

    Raises:
        ValueError: If ``refs`` and ``hyps`` have different lengths.
            Previously this surfaced as a ``KeyError``, or the extra
            hypotheses were silently dropped.
    """
    # Wrap bare strings so that len() and indexing count reports, not characters.
    if isinstance(refs, str):
        refs = [refs]
    if isinstance(hyps, str):
        hyps = [hyps]

    if len(refs) != len(hyps):
        raise ValueError(
            f"refs and hyps must have the same length "
            f"(got {len(refs)} and {len(hyps)})"
        )
    if not refs:
        # Nothing to score; avoid loading the model.
        return []

    rad = RadGraph(model_type=model_name)

    gt_outputs = rad(refs)
    pred_outputs = rad(hyps)

    scores = []
    for i in range(len(refs)):
        # Annotation results are looked up by the stringified input position.
        key = str(i)
        gt_out = gt_outputs[key]
        pred_out = pred_outputs[key]

        ent_f1 = compute_f1(extract_entities(gt_out), extract_entities(pred_out))
        rel_f1 = compute_f1(extract_relations(gt_out), extract_relations(pred_out))

        scores.append((ent_f1 + rel_f1) / 2)

    return scores
|
|
|
|
|
if __name__ == '__main__':
    # Toy reference/hypothesis report pairs, aligned by position.
    references = [
        "No evidence of pneumothorax following chest tube removal.",
        "There is a left pleural effusion.",
    ]
    hypotheses = [
        "No pneumothorax detected.",
        "Left pleural effusion is present.",
    ]

    # Per-pair scores from this module's entity/relation F1 average.
    print(compute_radgraph_scores(references, hypotheses))

    # Output of the radgraph package's own F1RadGraph scorer, for comparison.
    from radgraph import F1RadGraph

    scorer = F1RadGraph(model_type="radgraph", reward_level="simple")
    print(scorer(references, hypotheses))