import pyrootutils root = pyrootutils.setup_root( search_from=__file__, indicator=[".project-root"], pythonpath=True, # dotenv=True, ) import argparse import logging import os from collections import defaultdict from typing import List, Optional, Sequence, Tuple, TypeVar import pandas as pd from pie_datasets import load_dataset from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans from pytorch_ie.annotations import LabeledMultiSpan from pytorch_ie.documents import ( TextBasedDocument, TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, ) from src.document.processing import replace_substrings_in_text_with_spaces logger = logging.getLogger(__name__) def multi_span_is_in_span(multi_span: LabeledMultiSpan, range_span: Tuple[int, int]) -> bool: start, end = range_span starts, ends = zip(*multi_span.slices) return start <= min(starts) and max(ends) <= end def filter_multi_spans( multi_spans: Sequence[LabeledMultiSpan], filter_span: Tuple[int, int] ) -> List[LabeledMultiSpan]: return [ span for span in multi_spans if multi_span_is_in_span(multi_span=span, range_span=filter_span) ] def shift_multi_span_slices( slices: Sequence[Tuple[int, int]], shift: int ) -> List[Tuple[int, int]]: return [(start + shift, end + shift) for start, end in slices] def construct_gold_retrievals( doc: TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, symmetric_relations: Optional[List[str]] = None, relation_label_whitelist: Optional[List[str]] = None, ) -> Optional[pd.DataFrame]: abstract_annotations = [ span for span in doc.labeled_partitions if span.label.lower().strip() == "abstract" ] if len(abstract_annotations) != 1: logger.warning( f"Expected exactly one abstract annotation, found {len(abstract_annotations)}" ) return None abstract_annotation = abstract_annotations[0] span_abstract = (abstract_annotation.start, abstract_annotation.end) span_remaining = (abstract_annotation.end, len(doc.text)) labeled_multi_spans = list(doc.labeled_multi_spans) spans_in_abstract = set( span for span in labeled_multi_spans if multi_span_is_in_span(span, span_abstract) ) spans_in_remaining = set( span for span in labeled_multi_spans if multi_span_is_in_span(span, span_remaining) ) spans_not_covered = set(labeled_multi_spans) - spans_in_abstract - spans_in_remaining if len(spans_not_covered) > 0: logger.warning( f"Found {len(spans_not_covered)} spans not covered by abstract or remaining text" ) rel_arg_and_label2other = defaultdict(list) for rel in doc.binary_relations: rel_arg_and_label2other[rel.head].append((rel.tail, rel.label)) if symmetric_relations is not None and rel.label in symmetric_relations: label_reversed = rel.label else: label_reversed = f"{rel.label}_reversed" rel_arg_and_label2other[rel.tail].append((rel.head, label_reversed)) result_rows = [] for rel in doc.binary_relations: # we check all semantically_same relations that point from (head) remaining to abstract (tail) ... if rel.label == "semantically_same": if rel.head in spans_in_abstract and rel.tail in spans_in_remaining: # ... and if the head is # candidate_query_span = rel.tail candidate_spans_with_label = rel_arg_and_label2other[rel.tail] for candidate_span, rel_label in candidate_spans_with_label: if ( relation_label_whitelist is not None and rel_label not in relation_label_whitelist ): continue result_row = { "doc_id": f"{doc.id}.remaining.{span_remaining[0]}.txt", "query_doc_id": f"{doc.id}.abstract.{span_abstract[0]}_{span_abstract[1]}.txt", "span": shift_multi_span_slices(candidate_span.slices, -span_remaining[0]), "query_span": shift_multi_span_slices(rel.head.slices, -span_abstract[0]), "ref_span": shift_multi_span_slices(rel.tail.slices, -span_remaining[0]), "type": rel_label, "label": candidate_span.label, "ref_label": rel.tail.label, } result_rows.append(result_row) if len(result_rows) > 0: return pd.DataFrame(result_rows) else: return None D_text = TypeVar("D_text", bound=TextBasedDocument) def clean_doc(doc: D_text) -> D_text: # remove xml tags. Note that we also remove the Abstract tag, in contrast to the preprocessing # pipeline (see configs/dataset/sciarg_cleaned.yaml). This is because there, the abstracts are # removed at completely. doc = replace_substrings_in_text_with_spaces( doc, substrings=[ "", "