import logging
from collections import Counter
from typing import Dict, List, TypeVar

from pytorch_ie import Annotation, AnnotationLayer, Document, DocumentStatistic
from pytorch_ie.annotations import BinaryRelation

from src.utils.graph_utils import get_connected_components

logger = logging.getLogger(__name__)

A = TypeVar("A")


# TODO: remove when "counts" aggregation function is available in DocumentStatistic
def count_func(values: List[int]) -> Dict[int, int]:
    """Build a histogram of the given values.

    Args:
        values: The values to tally.

    Returns:
        A dict mapping each distinct value to the number of times it occurs
        in ``values``, with keys inserted in ascending order.
    """
    # Keys of a Counter are unique, so sorting the (value, count) pairs
    # orders them by value alone.
    return dict(sorted(Counter(values).items()))


class ConnectedComponentSizes(DocumentStatistic):
    """Document statistic that collects connected-component sizes per document.

    The components are computed by ``get_connected_components`` over the
    annotations in the target layer of the configured relation layer, using
    the relations of that layer and the configured link relation label.

    Args:
        relation_layer: Name of the document layer holding the relations.
        link_relation_label: Label forwarded to ``get_connected_components``
            as ``link_relation_label``.
        **kwargs: Passed through unchanged to ``DocumentStatistic.__init__``.
    """

    # TODO: use "counts" aggregation function when available in DocumentStatistic
    DEFAULT_AGGREGATION_FUNCTIONS = ["src.metrics.connected_component_sizes.count_func"]

    def __init__(self, relation_layer: str, link_relation_label: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self.link_relation_label = link_relation_label
        self.relation_layer = relation_layer

    def _collect(self, document: Document) -> List[int]:
        """Return the size of every connected component found in ``document``.

        Args:
            document: The document whose relation layer is analysed.

        Returns:
            One integer per component returned by ``get_connected_components``,
            in the order the components are returned.
        """
        layer: AnnotationLayer[BinaryRelation] = document[self.relation_layer]
        # The graph nodes are the annotations the relation layer points at.
        nodes: AnnotationLayer[Annotation] = layer.target_layer
        # NOTE(review): add_singletons=True presumably also yields
        # unconnected annotations as size-1 components - confirm against
        # get_connected_components.
        components: List[List] = get_connected_components(
            elements=nodes,
            relations=layer,
            link_relation_label=self.link_relation_label,
            add_singletons=True,
        )
        return list(map(len, components))