File size: 1,674 Bytes
d868d2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from typing import Hashable, List, Optional, Sequence, TypeVar

from pytorch_ie.annotations import BinaryRelation

# Type variable for the items grouped by get_connected_components; it is
# bound to Hashable.
H = TypeVar("H", bound=Hashable)


def get_connected_components(
    relations: Sequence[BinaryRelation],
    elements: Optional[Sequence[H]] = None,
    link_relation_label: Optional[str] = None,
    link_relation_relation_score_threshold: Optional[float] = None,
    add_singletons: bool = False,
) -> List[List[H]]:
    """Group relation arguments into connected components.

    Every relation that passes the label and score filters contributes an
    undirected edge between its ``head`` and its ``tail``. The connected
    components of the resulting graph are returned.

    Args:
        relations: Relations to consider as candidate edges.
        elements: Nodes to add as potential singletons. Only read when
            ``add_singletons`` is true.
        link_relation_label: If given, only relations whose ``label`` equals
            this value are used as edges.
        link_relation_relation_score_threshold: If given, only relations
            whose ``score`` is greater than or equal to this value are used
            as edges.
        add_singletons: If true, every entry of ``elements`` becomes a node
            of the graph, so entries that are not linked by any accepted
            relation form a component of their own.

    Returns:
        A list with one entry per connected component.
        NOTE(review): networkx yields each component as a ``set`` of nodes,
        so the inner collections are sets rather than the lists the
        annotation suggests - confirm against callers before changing.

    Raises:
        ValueError: If ``add_singletons`` is true but ``elements`` is None.
        ImportError: If NetworkX is not installed.
    """
    # Validate the arguments first so that a bad call fails before the
    # optional dependency is imported or any graph is built.
    if add_singletons and elements is None:
        raise ValueError("elements must be provided if add_singletons is True")

    try:
        import networkx as nx
    except ImportError as err:
        raise ImportError(
            "NetworkX must be installed to use the SpansViaRelationMerger. "
            "You can install NetworkX with `pip install networkx`."
        ) from err

    # convert list of relations to a graph to easily calculate connected components to merge
    g = nx.Graph()
    for rel in relations:
        if (link_relation_label is None or rel.label == link_relation_label) and (
            link_relation_relation_score_threshold is None
            or rel.score >= link_relation_relation_score_threshold
        ):
            g.add_edge(rel.head, rel.tail)

    if add_singletons and elements is not None:
        # Adding a node that is already part of an edge leaves the graph
        # unchanged, so no membership check is needed here.
        g.add_nodes_from(elements)
    return list(nx.connected_components(g))