File size: 2,580 Bytes
bad8293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import json
from radgraph import RadGraph


def compute_f1(test, retrieved):
    """Return the F1 score of ``retrieved`` measured against the gold ``test`` set.

    Both arguments are sets (of entities or of relations). Precision is the
    overlap divided by ``len(retrieved)``; recall is the overlap divided by
    ``len(test)``. When the two sets share nothing (including when either is
    empty) the score is 0.
    """
    overlap = len(test & retrieved)
    if not overlap:
        # No shared items: precision and recall are both zero, so F1 is 0.
        return 0
    # A non-zero overlap guarantees both denominators are non-zero.
    precision = overlap / len(retrieved)
    recall = overlap / len(test)
    return 2 * precision * recall / (precision + recall)


def extract_entities(output):
    """Collect the entities of one RadGraph annotation dict as a set.

    Each element is ``(tuple(tokens), label)`` taken from the values of
    ``output["entities"]``. ``tokens`` is converted with ``tuple()`` so the
    element is hashable; if it is a string this yields a tuple of characters,
    which still compares equal exactly when the strings do. A missing
    ``"entities"`` key yields an empty set.
    """
    found = set()
    for entity in output.get("entities", {}).values():
        found.add((tuple(entity["tokens"]), entity["label"]))
    return found


def extract_relations(output):
    """Collect the relations of one RadGraph annotation dict as a set.

    Each element is ``(source, target, relation_type)``, where ``source`` and
    ``target`` are ``(tuple(tokens), label)`` entity keys. Every entity's
    ``"relations"`` entries are ``(relation_type, target_id)`` pairs, and
    ``target_id`` is looked up in ``output["entities"]``. Relations whose
    target id is absent (or maps to a falsy value) are skipped.
    """
    entities = output.get("entities", {})

    def key_of(entity):
        # Same hashable key shape used for entities elsewhere in this module.
        return (tuple(entity["tokens"]), entity["label"])

    triples = set()
    for entity in entities.values():
        source = key_of(entity)
        for relation, target_id in entity.get("relations", []):
            target = entities.get(target_id)
            if not target:
                continue
            triples.add((source, key_of(target), relation))
    return triples


def compute_radgraph_scores(refs, hyps, model_name='radgraph', *, model=None):
    """
    Computes combined RadGraph F1 scores for each pair of reference and hypothesis reports.

    For the i-th pair, entity F1 and relation F1 are computed between the
    RadGraph annotations of ``refs[i]`` (ground truth) and ``hyps[i]``
    (prediction), and their mean is reported.

    Args:
      refs: Reference (ground-truth) reports, as a list of strings. A single
        string is treated as a one-element list.
      hyps: Hypothesis (generated) reports, aligned one-to-one with ``refs``.
        A single string is treated as a one-element list.
      model_name: RadGraph ``model_type`` used to build a model when
        ``model`` is not supplied.
      model: Optional, keyword-only, already-constructed RadGraph instance to
        reuse, so repeated calls do not reload the model.

    Returns:
      List of floats: (entity_f1 + relation_f1)/2 per report, in input order.
      An empty list when there are no reports (no model is loaded).

    Raises:
      ValueError: if ``refs`` and ``hyps`` contain different numbers of reports.
    """
    # Treat a bare string as a single report so len() counts reports, not characters.
    if isinstance(refs, str):
        refs = [refs]
    if isinstance(hyps, str):
        hyps = [hyps]

    # Scores are paired by position. Without this check, extra hypotheses were
    # silently ignored and missing ones surfaced as an opaque KeyError.
    if len(refs) != len(hyps):
        raise ValueError(
            f"refs and hyps must have the same length, got {len(refs)} and {len(hyps)}"
        )
    if not refs:
        # Nothing to score; avoid loading the (expensive) model at all.
        return []

    # Initialize RadGraph model, unless the caller supplied one to reuse.
    rad = model if model is not None else RadGraph(model_type=model_name)

    # Perform inference. Outputs are looked up by str(position) ('0', '1', ...),
    # the same indexing this function has always assumed.
    gt_outputs = rad(refs)
    pred_outputs = rad(hyps)

    scores = []
    # Drive the loop by the number of inputs (not the size of the output dict)
    # so a missing output fails loudly instead of silently truncating scores.
    for i in range(len(refs)):
        gt_out = gt_outputs[str(i)]
        pred_out = pred_outputs[str(i)]

        ent_f1 = compute_f1(extract_entities(gt_out), extract_entities(pred_out))
        rel_f1 = compute_f1(extract_relations(gt_out), extract_relations(pred_out))
        scores.append((ent_f1 + rel_f1) / 2)

    return scores


if __name__ == '__main__':
    # Example usage: score a few report pairs with this module's combined
    # metric, then with the radgraph package's own F1RadGraph for comparison.
    from radgraph import F1RadGraph

    refs = [
        "No evidence of pneumothorax following chest tube removal.",
        "There is a left pleural effusion.",
    ]
    hyps = [
        "No pneumothorax detected.",
        "Left pleural effusion is present.",
    ]

    # One float per (ref, hyp) pair: mean of entity F1 and relation F1.
    # The actual values depend on the model's annotations of these sentences.
    combined_scores = compute_radgraph_scores(refs, hyps)
    print(combined_scores)

    f1_radgraph = F1RadGraph(model_type="radgraph", reward_level="simple")
    # Pass by keyword so the reference/hypothesis roles cannot be swapped.
    f1_scores = f1_radgraph(refs=refs, hyps=hyps)
    print(f1_scores)