# Updated from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
# (commit d868d2e, verified).
from typing import Hashable, List, Optional, Sequence, Set, TypeVar

from pytorch_ie.annotations import BinaryRelation
# Element type for graph nodes. Bound to Hashable because elements are added
# as nodes of a networkx graph in get_connected_components.
H = TypeVar("H", bound=Hashable)
def get_connected_components(
    relations: Sequence[BinaryRelation],
    elements: Optional[Sequence[H]] = None,
    link_relation_label: Optional[str] = None,
    link_relation_relation_score_threshold: Optional[float] = None,
    add_singletons: bool = False,
) -> List[Set[H]]:
    """Group elements into connected components induced by "link" relations.

    A relation is treated as a link (an undirected edge between its ``head``
    and ``tail``) when it passes both optional filters:

    - its ``label`` equals ``link_relation_label``;
    - its ``score`` is ``>= link_relation_relation_score_threshold``.

    A filter set to ``None`` is not applied. Relations that fail a filter
    contribute no edges.

    Args:
        relations: Candidate link relations; their heads and tails become
            graph nodes.
        elements: Elements to add as singleton components. They are only used,
            and then required, when ``add_singletons`` is True.
        link_relation_label: If given, only relations with this label are
            links.
        link_relation_relation_score_threshold: If given, only relations whose
            ``score`` is at least this value are links.
        add_singletons: If True, every element of ``elements`` that is not an
            endpoint of a link relation is returned as its own component.

    Returns:
        A list of connected components, each a set of nodes, in the order
        yielded by ``networkx.connected_components``.

    Raises:
        ImportError: If NetworkX is not installed.
        ValueError: If ``add_singletons`` is True but ``elements`` is None.
    """
    try:
        import networkx as nx
    except ImportError as e:
        raise ImportError(
            "NetworkX must be installed to use get_connected_components. "
            "You can install NetworkX with `pip install networkx`."
        ) from e

    # Build an undirected graph from the link relations; its connected
    # components are the groups of (transitively) linked elements.
    g = nx.Graph()
    for rel in relations:
        # Keep the positive ">=" comparison so that e.g. a NaN score is
        # rejected exactly as before (a negated "<" test would accept it).
        is_link = (link_relation_label is None or rel.label == link_relation_label) and (
            link_relation_relation_score_threshold is None
            or rel.score >= link_relation_relation_score_threshold
        )
        if is_link:
            g.add_edge(rel.head, rel.tail)

    if add_singletons:
        if elements is None:
            raise ValueError("elements must be provided if add_singletons is True")
        # At this point the graph's nodes are exactly the link endpoints, so
        # any element not yet in the graph forms its own component.
        for elem in elements:
            if elem not in g:
                g.add_node(elem)

    return list(nx.connected_components(g))