ScientificArgumentRecommender / src /metrics /connected_component_sizes.py
ArneBinder's picture
update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e verified
import logging
from collections import Counter
from typing import Dict, List, TypeVar
from pytorch_ie import Annotation, AnnotationLayer, Document, DocumentStatistic
from pytorch_ie.annotations import BinaryRelation
from src.utils.graph_utils import get_connected_components
logger = logging.getLogger(__name__)
A = TypeVar("A")
# TODO: remove when "counts" aggregation function is available in DocumentStatistic
def count_func(values: List[int]) -> Dict[int, int]:
    """Tally how often each value occurs in ``values``.

    Args:
        values: The values to count.

    Returns:
        A dict mapping each distinct value to its number of occurrences,
        with keys inserted in ascending order. An empty input yields an
        empty dict.
    """
    occurrences = Counter(values)
    ordered_counts: Dict[int, int] = {}
    # Insert in ascending key order so the resulting dict iterates sorted.
    for value in sorted(occurrences.keys()):
        ordered_counts[value] = occurrences[value]
    return ordered_counts
class ConnectedComponentSizes(DocumentStatistic):
    """Document statistic that reports the size of each connected component.

    For every document, the components are obtained from
    ``get_connected_components`` using the annotations of the configured
    relation layer and the elements of that layer's target layer. By default
    the collected sizes are aggregated with ``count_func`` from this module.

    Args:
        relation_layer: Name of the document layer holding the relations.
        link_relation_label: Passed through to ``get_connected_components``
            as ``link_relation_label``.
        **kwargs: Forwarded unchanged to ``DocumentStatistic.__init__``.
    """

    # TODO: use "counts" aggregation function when available in DocumentStatistic
    DEFAULT_AGGREGATION_FUNCTIONS = ["src.metrics.connected_component_sizes.count_func"]

    def __init__(self, relation_layer: str, link_relation_label: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self.relation_layer = relation_layer
        self.link_relation_label = link_relation_label

    def _collect(self, document: Document) -> List[int]:
        """Return the number of elements in each connected component of ``document``."""
        relation_annotations: AnnotationLayer[BinaryRelation] = document[self.relation_layer]
        # The candidate elements come from the layer the relations point to.
        element_annotations: AnnotationLayer[Annotation] = document[
            self.relation_layer
        ].target_layer
        # add_singletons=True: presumably makes elements without any relation
        # show up as their own components -- confirm in graph_utils.
        components: List[List] = get_connected_components(
            elements=element_annotations,
            relations=relation_annotations,
            link_relation_label=self.link_relation_label,
            add_singletons=True,
        )
        return list(map(len, components))