# Updated from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
# (commit d868d2e, verified)
import logging
from collections import Counter
from typing import Dict, List, TypeVar

from pytorch_ie import Annotation, AnnotationLayer, Document, DocumentStatistic
from pytorch_ie.annotations import BinaryRelation

from src.utils.graph_utils import get_connected_components

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Generic type variable; not referenced in the visible part of this module.
A = TypeVar("A")


# TODO: remove when "counts" aggregation function is available in DocumentStatistic
def count_func(values: List[int]) -> Dict[int, int]:
    """Count the number of occurrences of each value in the list.

    Args:
        values: The values to tally; may be empty.

    Returns:
        A dict mapping each distinct value to how often it occurs. Keys are
        inserted in ascending order, so iterating the result yields the
        values sorted from smallest to largest.
    """
    occurrences = Counter(values)
    # Build the result in sorted key order so that the output is stable and
    # independent of the order in which the values were encountered.
    ordered_counts: Dict[int, int] = {}
    for value in sorted(occurrences):
        ordered_counts[value] = occurrences[value]
    return ordered_counts


class ConnectedComponentSizes(DocumentStatistic):
    """Document statistic that collects the sizes of connected components.

    For each document, the annotations of the relation layer's target layer
    are passed as elements, together with the relations themselves, to
    ``get_connected_components``; the statistic value is the list of the
    resulting component sizes.
    """

    # TODO: use "counts" aggregation function when available in DocumentStatistic
    # The aggregation function is referenced by its dotted import path.
    DEFAULT_AGGREGATION_FUNCTIONS = ["src.metrics.connected_component_sizes.count_func"]

    def __init__(self, relation_layer: str, link_relation_label: str, **kwargs) -> None:
        """Store the layer name and link label used when collecting.

        Args:
            relation_layer: Name of the document layer holding the relations.
            link_relation_label: Label forwarded to ``get_connected_components``
                as ``link_relation_label``.
            **kwargs: Forwarded unchanged to ``DocumentStatistic.__init__``.
        """
        super().__init__(**kwargs)
        self.relation_layer = relation_layer
        self.link_relation_label = link_relation_label

    def _collect(self, document: Document) -> List[int]:
        """Return the size of every connected component found in ``document``.

        Args:
            document: The document to analyse; it must provide the layer named
                by ``self.relation_layer``.

        Returns:
            One integer per component returned by ``get_connected_components``,
            namely the number of elements in that component.
        """
        relations: AnnotationLayer[BinaryRelation] = document[self.relation_layer]
        # The elements of the graph are the annotations the relations point to.
        spans: AnnotationLayer[Annotation] = relations.target_layer
        # NOTE(review): add_singletons=True presumably makes elements without any
        # matching relation show up as components of size 1 - confirm in
        # src.utils.graph_utils.
        connected_components: List[List] = get_connected_components(
            elements=spans,
            relations=relations,
            link_relation_label=self.link_relation_label,
            add_singletons=True,
        )
        return [len(component) for component in connected_components]