RadEval / factual /RadCliQv1 /radcliq_radgraph.py
X-iZhang's picture
initial
bad8293 verified
import json
from radgraph import RadGraph
def compute_f1(test, retrieved):
    """Return the F1 score between two sets of entities or relations.

    Args:
        test: Set of reference (ground-truth) items.
        retrieved: Set of predicted items.

    Returns:
        The harmonic mean of precision and recall, or 0 when it is
        undefined (either set is empty or the sets share no items).
    """
    overlap = len(test & retrieved)
    n_retrieved = len(retrieved)
    n_test = len(test)

    # true positives + false positives is simply the number of retrieved
    # items; true positives + false negatives is the number of test items.
    precision = overlap / n_retrieved if n_retrieved else 0
    recall = overlap / n_test if n_test else 0

    denominator = precision + recall
    if not denominator:
        return 0
    return 2 * precision * recall / denominator
def extract_entities(output):
    """Collect the distinct entities of one RadGraph output dict.

    Args:
        output: Mapping with an optional "entities" entry whose values are
            dicts carrying "tokens" and "label".

    Returns:
        Set of ``(tuple(tokens), label)`` pairs; empty when there is no
        "entities" entry.
    """
    found = set()
    for entity in output.get("entities", {}).values():
        # tuple() makes the tokens hashable so the pair can live in a set.
        found.add((tuple(entity["tokens"]), entity["label"]))
    return found
def extract_relations(output):
    """Collect the distinct relations of one RadGraph output dict.

    Each entity may list "relations" as ``(relation_type, target_id)`` pairs,
    where ``target_id`` is a key of the same "entities" mapping. Relations
    whose target is missing (or empty) are ignored.

    Args:
        output: Mapping with an optional "entities" entry.

    Returns:
        Set of ``(source, target, relation_type)`` triples, where source and
        target are ``(tuple(tokens), label)`` pairs.
    """
    entity_map = output.get("entities", {})

    def as_key(entity):
        # Hashable identity of an entity: its tokens plus its label.
        return (tuple(entity["tokens"]), entity["label"])

    triples = set()
    for source in entity_map.values():
        source_key = as_key(source)
        for relation, target_id in source.get("relations", []):
            target = entity_map.get(target_id)
            if not target:
                # Dangling reference: the target id is not a known entity.
                continue
            triples.add((source_key, as_key(target), relation))
    return triples
def compute_radgraph_scores(refs, hyps, model_name='radgraph'):
    """Compute a combined RadGraph F1 score per reference/hypothesis pair.

    A RadGraph model is instantiated, run once on the references and once on
    the hypotheses, and the two outputs are compared report by report.

    Args:
        refs: Reference reports, passed as-is to the RadGraph model.
        hyps: Hypothesis reports, passed as-is to the RadGraph model.
        model_name: Value forwarded as ``model_type`` to ``RadGraph``.

    Returns:
        List of floats: ``(entity_f1 + relation_f1) / 2`` per report.

    Raises:
        ValueError: If the model yields a different number of outputs for
            the references than for the hypotheses, since the reports could
            then not be paired one-to-one.
    """
    # Initialize RadGraph model
    rad = RadGraph(model_type=model_name)

    # Perform inference
    gt_outputs = rad(refs)
    pred_outputs = rad(hyps)

    # Without this check, surplus hypotheses would be silently ignored and
    # missing ones would surface as an opaque KeyError on a string index.
    if len(gt_outputs) != len(pred_outputs):
        raise ValueError(
            "Number of reference and hypothesis reports must match: "
            f"got {len(gt_outputs)} and {len(pred_outputs)}."
        )

    scores = []
    for i in range(len(gt_outputs)):
        # Outputs are looked up by the stringified report index.
        key = str(i)
        gt_out = gt_outputs[key]
        pred_out = pred_outputs[key]

        ent_f1 = compute_f1(
            extract_entities(gt_out), extract_entities(pred_out)
        )
        rel_f1 = compute_f1(
            extract_relations(gt_out), extract_relations(pred_out)
        )
        scores.append((ent_f1 + rel_f1) / 2)
    return scores
if __name__ == '__main__':
    # Demo: score two reference/hypothesis report pairs.
    reference_reports = [
        "No evidence of pneumothorax following chest tube removal.",
        "There is a left pleural effusion.",
    ]
    hypothesis_reports = [
        "No pneumothorax detected.",
        "Left pleural effusion is present.",
    ]

    # Combined entity/relation F1 from this module: one score per pair.
    print(compute_radgraph_scores(reference_reports, hypothesis_reports))

    # Same reports scored with the radgraph package's own F1RadGraph,
    # printed for side-by-side comparison.
    from radgraph import F1RadGraph

    package_scorer = F1RadGraph(model_type="radgraph", reward_level="simple")
    print(package_scorer(reference_reports, hypothesis_reports))