from typing import Hashable, List, Optional, Sequence, TypeVar

from pytorch_ie.annotations import BinaryRelation

H = TypeVar("H", bound=Hashable)


def get_connected_components(
    relations: Sequence[BinaryRelation],
    elements: Optional[Sequence[H]] = None,
    link_relation_label: Optional[str] = None,
    link_relation_relation_score_threshold: Optional[float] = None,
    add_singletons: bool = False,
) -> List[List[H]]:
    """Group relation arguments into connected components.

    Every relation that passes the label and score filters contributes an
    undirected edge between its ``head`` and ``tail``. The connected
    components of the resulting graph are returned.

    Args:
        relations: Relations to build the graph from. Only ``head``, ``tail``,
            ``label`` and ``score`` are read.
        elements: Candidate nodes. Only consulted when ``add_singletons`` is
            True; elements that are not touched by any accepted relation are
            added as isolated nodes.
        link_relation_label: If given, only relations whose ``label`` equals
            this value create edges. ``None`` accepts every label.
        link_relation_relation_score_threshold: If given, only relations with
            ``score >= threshold`` create edges. ``None`` disables the check.
        add_singletons: Whether to include elements without any accepted
            relation as single-node components.

    Returns:
        One entry per connected component.
        NOTE(review): ``networkx.connected_components`` yields sets, so the
        inner containers are sets at runtime even though the annotation says
        ``List[List[H]]`` -- confirm what callers expect before changing either.

    Raises:
        ImportError: If NetworkX is not installed.
        ValueError: If ``add_singletons`` is True but ``elements`` is None.
    """
    try:
        import networkx as nx
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "NetworkX must be installed to use the SpansViaRelationMerger. "
            "You can install NetworkX with `pip install networkx`."
        ) from err

    # Validate up front so a bad call fails before any graph work is done.
    if add_singletons and elements is None:
        raise ValueError("elements must be provided if add_singletons is True")

    # convert list of relations to a graph to easily calculate connected components to merge
    g = nx.Graph()
    for rel in relations:
        # The score is only inspected when the label already matched
        # (short-circuit), so relations with other labels never have
        # their score read.
        if (link_relation_label is None or rel.label == link_relation_label) and (
            link_relation_relation_score_threshold is None
            or rel.score >= link_relation_relation_score_threshold
        ):
            g.add_edge(rel.head, rel.tail)

    if add_singletons and elements is not None:
        # add singletons to the graph; add_node leaves nodes that already
        # take part in an edge untouched, so no membership check is needed
        for elem in elements:
            g.add_node(elem)

    return list(nx.connected_components(g))