"""Create gold retrieval results for the SciArg abstracts-vs-remaining setup.

For each document, spans in the abstract that are linked via a
``semantically_same`` relation to a span in the remaining text are used as
queries; all relation partners of that remaining-text span are collected as
gold retrieval targets. The result is written as a JSON file in the same
format as the (predicted) retrieval results.

NOTE(review): this file arrived with its physical line structure destroyed
(everything collapsed onto a few lines). The code below is a reconstruction;
logic is unchanged, but see the TODO in ``clean_doc`` about string literals
that were lost in transit.
"""

import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    # dotenv=True,
)

import argparse
import logging
import os
from collections import defaultdict
from typing import List, Optional, Sequence, Tuple, TypeVar

import pandas as pd
from pie_datasets import load_dataset
from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans
from pytorch_ie.annotations import LabeledMultiSpan
from pytorch_ie.documents import (
    TextBasedDocument,
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)

from src.document.processing import replace_substrings_in_text_with_spaces

logger = logging.getLogger(__name__)


def multi_span_is_in_span(multi_span: LabeledMultiSpan, range_span: Tuple[int, int]) -> bool:
    """Return True iff every slice of `multi_span` lies within `range_span` (start, end)."""
    start, end = range_span
    starts, ends = zip(*multi_span.slices)
    return start <= min(starts) and max(ends) <= end


def filter_multi_spans(
    multi_spans: Sequence[LabeledMultiSpan], filter_span: Tuple[int, int]
) -> List[LabeledMultiSpan]:
    """Keep only the multi-spans fully contained in `filter_span`."""
    return [
        span
        for span in multi_spans
        if multi_span_is_in_span(multi_span=span, range_span=filter_span)
    ]


def shift_multi_span_slices(
    slices: Sequence[Tuple[int, int]], shift: int
) -> List[Tuple[int, int]]:
    """Shift all (start, end) slices by `shift` (used to re-base offsets onto a sub-document)."""
    return [(start + shift, end + shift) for start, end in slices]


def construct_gold_retrievals(
    doc: TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    symmetric_relations: Optional[List[str]] = None,
    relation_label_whitelist: Optional[List[str]] = None,
) -> Optional[pd.DataFrame]:
    """Construct gold retrieval rows for a single document.

    The document is split at the end of its (single) "abstract" partition into an
    abstract part and a remaining part. For every ``semantically_same`` relation
    whose head lies in the abstract and whose tail lies in the remaining text, all
    relation partners of the tail become gold retrieval targets for the head
    (the query span).

    Args:
        doc: Document with labeled multi-spans, binary relations, and partitions.
        symmetric_relations: Relation labels treated as symmetric, i.e. the reversed
            direction keeps the original label instead of getting a "_reversed" suffix.
        relation_label_whitelist: If provided, only keep candidate entries whose
            (possibly reversed) relation label is in this list.

    Returns:
        A DataFrame with one row per gold retrieval target, or None if the document
        has no usable abstract annotation or produces no rows.
    """
    abstract_annotations = [
        span for span in doc.labeled_partitions if span.label.lower().strip() == "abstract"
    ]
    if len(abstract_annotations) != 1:
        logger.warning(
            f"Expected exactly one abstract annotation, found {len(abstract_annotations)}"
        )
        return None
    abstract_annotation = abstract_annotations[0]
    span_abstract = (abstract_annotation.start, abstract_annotation.end)
    span_remaining = (abstract_annotation.end, len(doc.text))

    labeled_multi_spans = list(doc.labeled_multi_spans)
    spans_in_abstract = {
        span for span in labeled_multi_spans if multi_span_is_in_span(span, span_abstract)
    }
    spans_in_remaining = {
        span for span in labeled_multi_spans if multi_span_is_in_span(span, span_remaining)
    }
    spans_not_covered = set(labeled_multi_spans) - spans_in_abstract - spans_in_remaining
    if len(spans_not_covered) > 0:
        logger.warning(
            f"Found {len(spans_not_covered)} spans not covered by abstract or remaining text"
        )

    # Map each relation argument to all of its partners; the reversed direction is
    # marked with a "_reversed" suffix unless the relation label is symmetric.
    rel_arg_and_label2other = defaultdict(list)
    for rel in doc.binary_relations:
        rel_arg_and_label2other[rel.head].append((rel.tail, rel.label))
        if symmetric_relations is not None and rel.label in symmetric_relations:
            label_reversed = rel.label
        else:
            label_reversed = f"{rel.label}_reversed"
        rel_arg_and_label2other[rel.tail].append((rel.head, label_reversed))

    result_rows = []
    for rel in doc.binary_relations:
        # We check all semantically_same relations whose head is in the abstract and
        # whose tail is in the remaining text ...
        # (NOTE(review): the original inline comment claimed the opposite direction,
        # but the condition below checks head-in-abstract / tail-in-remaining.)
        if rel.label == "semantically_same":
            if rel.head in spans_in_abstract and rel.tail in spans_in_remaining:
                # ... and collect all relation partners of the tail as gold targets.
                candidate_spans_with_label = rel_arg_and_label2other[rel.tail]
                for candidate_span, rel_label in candidate_spans_with_label:
                    if (
                        relation_label_whitelist is not None
                        and rel_label not in relation_label_whitelist
                    ):
                        continue
                    # All span offsets are re-based onto the respective sub-document
                    # (abstract or remaining text) by subtracting its start offset.
                    result_row = {
                        "doc_id": f"{doc.id}.remaining.{span_remaining[0]}.txt",
                        "query_doc_id": f"{doc.id}.abstract.{span_abstract[0]}_{span_abstract[1]}.txt",
                        "span": shift_multi_span_slices(
                            candidate_span.slices, -span_remaining[0]
                        ),
                        "query_span": shift_multi_span_slices(
                            rel.head.slices, -span_abstract[0]
                        ),
                        "ref_span": shift_multi_span_slices(
                            rel.tail.slices, -span_remaining[0]
                        ),
                        "type": rel_label,
                        "label": candidate_span.label,
                        "ref_label": rel.tail.label,
                    }
                    result_rows.append(result_row)

    if len(result_rows) > 0:
        return pd.DataFrame(result_rows)
    else:
        return None


D_text = TypeVar("D_text", bound=TextBasedDocument)


def clean_doc(doc: D_text) -> D_text:
    """Blank out XML markup in the document text (replaced with spaces, offsets kept).

    Note that we also remove the Abstract tag, in contrast to the preprocessing
    pipeline (see configs/dataset/sciarg_cleaned.yaml). This is because there, the
    abstracts are removed completely.
    """
    # TODO(review): the original string literals below (XML markup tags) were
    # garbled/lost in this copy of the file — only empty-string fragments remained.
    # The list is reconstructed from the SciArg document markup; it MUST be verified
    # against configs/dataset/sciarg_cleaned.yaml before relying on the output.
    doc = replace_substrings_in_text_with_spaces(
        doc,
        substrings=[
            "<Title>",
            "</Title>",
            "<Abstract>",
            "</Abstract>",
            "<H1>",
            "</H1>",
            "<H2>",
            "</H2>",
            "</Document>",
        ],
    )
    return doc


def main(
    data_dir: str,
    out_path: str,
    doc_id_whitelist: Optional[List[str]] = None,
    symmetric_relations: Optional[List[str]] = None,
    relation_label_whitelist: Optional[List[str]] = None,
) -> None:
    """Load the SciArg dataset, build gold retrieval rows per document, and save as JSON.

    Args:
        data_dir: Path to the SciArg data directory.
        out_path: Output path for the JSON result file.
        doc_id_whitelist: If provided, only process documents with these ids.
        symmetric_relations: See ``construct_gold_retrievals``.
        relation_label_whitelist: See ``construct_gold_retrievals``.
    """
    logger.info(f"Loading dataset from {data_dir}")
    sciarg_with_abstracts = load_dataset(
        "pie/sciarg",
        revision="171478ce3c13cc484be5d7c9bc8f66d7d2f1c210",
        base_dataset_kwargs={"data_dir": data_dir, "split_paths": None},
        name="resolve_parts_of_same",
        split="train",
    )
    if issubclass(sciarg_with_abstracts.document_type, BratDocument):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
        )
    elif issubclass(sciarg_with_abstracts.document_type, BratDocumentWithMergedSpans):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
        )
    else:
        raise ValueError(f"Unsupported document type {sciarg_with_abstracts.document_type}")

    ds_clean = ds_converted.map(clean_doc)
    if doc_id_whitelist is not None:
        num_before = len(ds_clean)
        ds_clean = [doc for doc in ds_clean if doc.id in doc_id_whitelist]
        logger.info(
            f"Filtered dataset from {num_before} to {len(ds_clean)} documents based on doc_id_whitelist"
        )

    results_per_doc = [
        construct_gold_retrievals(
            doc,
            symmetric_relations=symmetric_relations,
            relation_label_whitelist=relation_label_whitelist,
        )
        for doc in ds_clean
    ]
    results_per_doc_not_empty = [doc for doc in results_per_doc if doc is not None]
    if len(results_per_doc_not_empty) > 0:
        results = pd.concat(results_per_doc_not_empty, ignore_index=True)
        # sort to make the output deterministic
        results = results.sort_values(
            by=results.columns.tolist(), ignore_index=True, key=lambda s: s.apply(str)
        )
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        logger.info(f"Saving result ({len(results)}) to {out_path}")
        results.to_json(out_path)
    else:
        logger.warning("No results found")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Create gold retrievals for SciArg-abstracts-remaining in the same format as the retrieval results"
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="data/annotations/sciarg-with-abstracts-and-cross-section-rels",
        help="Path to the sciarg data directory",
    )
    parser.add_argument(
        "--out_path",
        type=str,
        default="data/retrieval_results/sciarg-with-abstracts-and-cross-section-rels/gold.json",
        help="Path to save the results",
    )
    parser.add_argument(
        "--symmetric_relations",
        type=str,
        nargs="+",
        default=None,
        help="Relations that are symmetric, i.e., if A is related to B, then B is related to A",
    )
    parser.add_argument(
        "--relation_label_whitelist",
        type=str,
        nargs="+",
        default=None,
        help="Only consider relations with these labels",
    )

    logging.basicConfig(level=logging.INFO)

    kwargs = vars(parser.parse_args())
    main(**kwargs)

    logger.info("Done")