# ArneBinder's picture
# update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
# d868d2e verified
from typing import Hashable, List, Optional, Sequence, TypeVar
from pytorch_ie.annotations import BinaryRelation
H = TypeVar("H", bound=Hashable)
def get_connected_components(
    relations: Sequence[BinaryRelation],
    elements: Optional[Sequence[H]] = None,
    link_relation_label: Optional[str] = None,
    link_relation_relation_score_threshold: Optional[float] = None,
    add_singletons: bool = False,
) -> List[List[H]]:
    """Group relation arguments into connected components.

    An undirected graph is built with one edge ``head -- tail`` for every
    relation that qualifies as a link; the connected components of that graph
    are returned.

    Args:
        relations: Relations to inspect. ``head`` and ``tail`` of each link
            relation become graph nodes.
        elements: Candidate nodes to add as isolated components. Only read
            when ``add_singletons`` is True.
        link_relation_label: If given, only relations whose ``label`` equals
            this value are treated as links. If None, the label is ignored.
        link_relation_relation_score_threshold: If given, only relations with
            ``score >= threshold`` are treated as links. If None, the score
            is not read.
        add_singletons: If True, every entry of ``elements`` that is not
            already connected by a link relation is added as its own
            single-node component.

    Returns:
        The components as produced by ``networkx.connected_components``.
        NOTE(review): NetworkX yields each component as a ``set`` of nodes,
        so the entries are sets rather than the lists the annotation
        suggests; the value is returned unchanged to stay compatible with
        existing callers.

    Raises:
        ImportError: If NetworkX is not installed.
        ValueError: If ``add_singletons`` is True but ``elements`` is None.
    """
    try:
        import networkx as nx
    except ImportError as err:
        # NOTE(review): the message names SpansViaRelationMerger, presumably
        # the original caller of this helper; text kept as-is for callers
        # that may match on it.
        raise ImportError(
            "NetworkX must be installed to use the SpansViaRelationMerger. "
            "You can install NetworkX with `pip install networkx`."
        ) from err

    # Convert the link relations to a graph so that the connected components
    # can be computed directly.
    g = nx.Graph()
    for rel in relations:
        is_link = (link_relation_label is None or rel.label == link_relation_label) and (
            link_relation_relation_score_threshold is None
            or rel.score >= link_relation_relation_score_threshold
        )
        if is_link:
            g.add_edge(rel.head, rel.tail)

    if add_singletons:
        if elements is None:
            raise ValueError("elements must be provided if add_singletons is True")
        for elem in elements:
            # Adding a node that is already part of an edge leaves the graph
            # unchanged, so only unconnected elements become new components.
            g.add_node(elem)

    return list(nx.connected_components(g))