# update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
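# Set up the project root so that absolute imports from `src.` and .env loading
# work regardless of the current working directory.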
import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import argparse
import logging
import os
import sys
from typing import Callable, Dict, List, Optional, Tuple

import pandas as pd
from pie_datasets import Dataset, DatasetDict
from pytorch_ie import Annotation
from pytorch_ie.annotations import BinaryRelation, MultiSpan, Span

from src.demo.retriever_utils import (
    retrieve_all_relevant_spans,
    retrieve_all_relevant_spans_for_all_documents,
    retrieve_all_similar_spans,
    retrieve_all_similar_spans_for_all_documents,
    retrieve_relevant_spans,
    retrieve_similar_spans,
)
from src.document.types import (
    RelatedRelation,
    TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
)
from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations

logger = logging.getLogger(__name__)


def get_original_doc_id_and_offsets(doc_id: str) -> Tuple[str, int, Optional[int]]:
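    """Split a retriever doc_id into the original document ID and character offsets.

    Two formats are handled (inferred from the parsing below):
      - "<orig_id>.remaining.<start>.<ext>"      -> (orig_id, start, None)
      - "<orig_id>.abstract.<start>_<end>.<ext>" -> (orig_id, start, end)

    Illustrative example: "doc1.abstract.0_250.json" -> ("doc1", 0, 250).
    """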
    original_doc_id, middle, start_end, _ext = doc_id.split(".")
    if middle == "remaining":
        return original_doc_id, int(start_end), None
    elif middle == "abstract":
        start, end = start_end.split("_")
        return original_doc_id, int(start), int(end)
    else:
        raise ValueError(f"unexpected doc_id format: {doc_id}")


def add_base_annotations(
    documents: Dict[
        str, TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations
    ],
    retrieved_doc_ids: List[str],
    retriever: DocumentAwareSpanRetrieverWithRelations,
) -> Dict[Tuple[str, Annotation], Tuple[str, Annotation]]:
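    """Copy predicted spans and binary relations from the retriever's sub-documents into
    the corresponding original documents, shifting span offsets by the sub-document's
    start offset within the original text.
    """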
    # (retrieved_doc_id, retrieved_annotation) -> (original_doc_id, original_annotation)
    annotation_mapping = {}
    for retrieved_doc_id in retrieved_doc_ids:
        pie_doc = retriever.get_document(retrieved_doc_id).metadata["pie_document"].copy()
        original_doc_id, offset, _ = get_original_doc_id_and_offsets(retrieved_doc_id)
        document = documents[original_doc_id]
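        # shift each predicted span from sub-document coordinates back into the
        # coordinate system of the original document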
        span_mapping = {}
        for span in pie_doc.labeled_multi_spans.predictions:
            if isinstance(span, MultiSpan):
                new_span = span.copy(
                    slices=[(start + offset, end + offset) for start, end in span.slices]
                )
            elif isinstance(span, Span):
                new_span = span.copy(start=span.start + offset, end=span.end + offset)
            else:
                raise ValueError(f"unexpected span type: {span}")
            span_mapping[span] = new_span
        document.labeled_multi_spans.predictions.extend(span_mapping.values())
        for relation in pie_doc.binary_relations.predictions:
            new_relation = relation.copy(
                head=span_mapping[relation.head], tail=span_mapping[relation.tail]
            )
            document.binary_relations.predictions.append(new_relation)
        for old_ann, new_ann in span_mapping.items():
            annotation_mapping[(retrieved_doc_id, old_ann)] = (original_doc_id, new_ann)
    return annotation_mapping


def get_doc_and_span_id2annotation_mapping(
    span_ids: pd.Series,
    doc_ids: pd.Series,
    retriever: DocumentAwareSpanRetrieverWithRelations,
    base_annotation_mapping: Dict[Tuple[str, Annotation], Tuple[str, Annotation]],
) -> Dict[Tuple[str, str], Tuple[str, Annotation]]:
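    """Resolve the (doc_id, span_id) pairs of a result column to the corresponding
    annotations in the original documents, using the mapping created by
    add_base_annotations.
    """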
    if len(doc_ids) != len(span_ids):
        raise ValueError("doc_ids and span_ids must have the same length")
    doc_and_span_ids = zip(doc_ids.tolist(), span_ids.tolist())
    return {
        (doc_id, span_id): base_annotation_mapping[(doc_id, retriever.get_span_by_id(span_id))]
        for doc_id, span_id in set(doc_and_span_ids)
    }


def add_result_to_gold_data(
    result: pd.DataFrame,
    gold_dataset_dir: str,
    dataset_out_dir: str,
    retriever: DocumentAwareSpanRetrieverWithRelations,
    split: Optional[str] = None,
    link_relation_label: str = "semantically_same",
    reversed_relation_suffix: str = "_reversed",
):
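    """Add retrieval results as predictions to a gold dataset and write the enriched
    dataset to dataset_out_dir.

    Each result row yields a link relation (label: link_relation_label) between the
    query span and the retrieved span. If the result carries reference spans
    (relevant-span retrieval), a RelatedRelation is added as well, combining the link
    relation with the predicted base relation between reference and retrieved span.
    """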
    if not os.path.exists(gold_dataset_dir):
        raise ValueError(f"gold dataset directory does not exist: {gold_dataset_dir}")
    dataset_dict = DatasetDict.from_json(data_dir=gold_dataset_dir)
    if split is None and len(dataset_dict) == 1:
        split = list(dataset_dict.keys())[0]
    if split is None:
        raise ValueError("need to provide a split name to add results to the gold dataset")
    dataset = dataset_dict[split]
    doc_id2doc = {doc.id: doc for doc in dataset}
    # deduplicate doc ids: a document may appear both as target and as query document,
    # and adding its annotations twice would duplicate predictions
    retriever_doc_ids = list(
        dict.fromkeys(
            result["doc_id"].unique().tolist() + result["query_doc_id"].unique().tolist()
        )
    )
    base_annotation_mapping = add_base_annotations(
        documents=doc_id2doc, retrieved_doc_ids=retriever_doc_ids, retriever=retriever
    )
    # (retriever_doc_id, retriever_span_id) -> (original_doc_id, original_span)
    doc_and_span_id2annotation = {}
    doc_and_span_id2annotation.update(
        get_doc_and_span_id2annotation_mapping(
            span_ids=result["span_id"],
            doc_ids=result["doc_id"],
            retriever=retriever,
            base_annotation_mapping=base_annotation_mapping,
        )
    )
    # only relevant-span retriever results have a ref_span_id;
    # similar-span retriever results only carry a query_span_id
if "ref_span_id" in result.columns: | |
doc_and_span_id2annotation.update( | |
get_doc_and_span_id2annotation_mapping( | |
span_ids=result["ref_span_id"], | |
doc_ids=result["doc_id"], | |
retriever=retriever, | |
base_annotation_mapping=base_annotation_mapping, | |
) | |
) | |
doc_and_span_id2annotation.update( | |
get_doc_and_span_id2annotation_mapping( | |
span_ids=result["query_span_id"], | |
doc_ids=result["query_doc_id"], | |
retriever=retriever, | |
base_annotation_mapping=base_annotation_mapping, | |
) | |
) | |
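    # index the predicted base relations of each document by (head, tail) so that the
    # base relation backing each retrieved pair can be looked up directly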
    doc_id2head_tail2relation = {}
    for doc_id, doc in doc_id2doc.items():
        head_and_tail2relation = {}
        for relation in doc.binary_relations.predictions:
            head_and_tail2relation[(relation.head, relation.tail)] = relation
        doc_id2head_tail2relation[doc_id] = head_and_tail2relation
    for row in result.itertuples():
        query_doc_id, query_span = doc_and_span_id2annotation[
            (row.query_doc_id, row.query_span_id)
        ]
        doc_id, span = doc_and_span_id2annotation[(row.doc_id, row.span_id)]
        if doc_id != query_doc_id:
            raise ValueError("doc_id and query_doc_id must be the same")
        doc = doc_id2doc[doc_id]
        # if we have a reference span, we need to construct the related relation
        if hasattr(row, "ref_span_id"):
            doc_id2, ref_span = doc_and_span_id2annotation[(row.doc_id, row.ref_span_id)]
            if doc_id != doc_id2:
                raise ValueError("doc_id and ref_doc_id must be the same")
            # create a link relation between the query span and the reference span
            link_rel = BinaryRelation(
                head=query_span, tail=ref_span, label=link_relation_label, score=row.sim_score
            )
            doc.binary_relations.predictions.append(link_rel)
            head_and_tail2relation = doc_id2head_tail2relation[doc_id]
            related_rel_label = row.type
            if related_rel_label.endswith(reversed_relation_suffix):
                base_rel = head_and_tail2relation[(span, ref_span)]
            else:
                base_rel = head_and_tail2relation[(ref_span, span)]
            related_rel = RelatedRelation(
                head=query_span,
                tail=span,
                link_relation=link_rel,
                relation=base_rel,
                label=related_rel_label,
                score=link_rel.score * base_rel.score,
            )
            doc.related_relations.predictions.append(related_rel)
        # otherwise, just create a link relation between the query span and the returned span
        else:
            link_rel = BinaryRelation(
                head=query_span, tail=span, label=link_relation_label, score=row.sim_score
            )
            doc.binary_relations.predictions.append(link_rel)
    dataset = Dataset.from_documents(list(doc_id2doc.values()))
    dataset_dict = DatasetDict({split: dataset})
    os.makedirs(dataset_out_dir, exist_ok=True)
    dataset_dict.to_json(dataset_out_dir, mode="w")


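# Example invocation (the script and file names are illustrative; the flags are
# defined below):
#   python retrieve_related_spans.py \
#     --data_path data/retriever_dump \
#     --output_path results/retrieved_spans.json \
#     --variant relevant --top_k 10 --threshold 0.95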
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"-c", | |
"--config_path", | |
type=str, | |
default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml", | |
) | |
parser.add_argument( | |
"--data_path", | |
type=str, | |
required=True, | |
help="Path to a zip or directory containing a retriever dump.", | |
) | |
parser.add_argument("-k", "--top_k", type=int, default=10) | |
parser.add_argument("-t", "--threshold", type=float, default=0.95) | |
parser.add_argument( | |
"-o", | |
"--output_path", | |
type=str, | |
required=True, | |
) | |
parser.add_argument( | |
"-v", | |
"--variant", | |
choices=["relevant", "similar"], | |
default="relevant", | |
help="Variant of the retriever to use: 'relevant' for relevant spans, 'similar' for similar spans.", | |
) | |
parser.add_argument( | |
"--query_doc_id", | |
type=str, | |
default=None, | |
help="If provided, retrieve all spans for only this query document.", | |
) | |
parser.add_argument( | |
"--query_span_id", | |
type=str, | |
default=None, | |
help="If provided, retrieve all spans for only this query span.", | |
) | |
parser.add_argument( | |
"--doc_id_whitelist", | |
type=str, | |
nargs="+", | |
default=None, | |
help="If provided, only consider documents with these IDs.", | |
) | |
parser.add_argument( | |
"--doc_id_blacklist", | |
type=str, | |
nargs="+", | |
default=None, | |
help="If provided, ignore documents with these IDs.", | |
) | |
parser.add_argument( | |
"--query_target_doc_id_pairs", | |
type=str, | |
nargs="+", | |
default=None, | |
help="One or more pairs of query and target document IDs " | |
'(each separated by ":") to retrieve spans for. If provided, ' | |
"--query_doc_id and --query_span_id are ignored.", | |
) | |
parser.add_argument( | |
"--gold_dataset_dir", | |
type=str, | |
default=None, | |
help="If provided, add the spans and base relations from the retriever data as well " | |
"as the related relations to the gold dataset.", | |
) | |
parser.add_argument( | |
"--dataset_out_dir", | |
type=str, | |
default=None, | |
help="If provided, save the enriched gold dataset to this directory.", | |
) | |
args = parser.parse_args() | |
    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    if not args.output_path.endswith(".json"):
        raise ValueError("only JSON output is supported")

    logger.info(f"instantiating retriever from {args.config_path}...")
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        args.config_path
    )
    logger.info(f"loading data from {args.data_path}...")
    retriever.load_from_disc(args.data_path)
    methods: Dict[str, Callable]
    if args.variant == "relevant":
        logger.info("using *relevant* span retriever methods")
        methods = {
            "retrieve_all_spans": retrieve_all_relevant_spans,
            "retrieve_spans": retrieve_relevant_spans,
            "retrieve_all_spans_for_all_documents": retrieve_all_relevant_spans_for_all_documents,
        }
    elif args.variant == "similar":
        logger.info("using *similar* span retriever methods")
        methods = {
            "retrieve_all_spans": retrieve_all_similar_spans,
            "retrieve_spans": retrieve_similar_spans,
            "retrieve_all_spans_for_all_documents": retrieve_all_similar_spans_for_all_documents,
        }
    else:
        raise ValueError(f"unknown variant: {args.variant}")
    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
    if args.doc_id_whitelist is not None:
        search_kwargs["doc_id_whitelist"] = args.doc_id_whitelist
    if args.doc_id_blacklist is not None:
        search_kwargs["doc_id_blacklist"] = args.doc_id_blacklist
    logger.info(f"using search_kwargs: {search_kwargs}")
    if args.query_target_doc_id_pairs is not None:
        # a global whitelist would clash with the per-pair whitelist passed below
        if args.doc_id_whitelist is not None:
            raise ValueError(
                "--doc_id_whitelist cannot be combined with --query_target_doc_id_pairs"
            )
        all_spans_for_all_documents = None
        for doc_id_pair in args.query_target_doc_id_pairs:
            query_doc_id, target_doc_id = doc_id_pair.split(":")
            current_result = methods["retrieve_all_spans"](
                retriever=retriever,
                query_doc_id=query_doc_id,
                doc_id_whitelist=[target_doc_id],
                **search_kwargs,
            )
            if current_result is None:
                logger.warning(
                    f"no spans found for query_doc_id={query_doc_id} and "
                    f"target_doc_id={target_doc_id}"
                )
                continue
            logger.info(
                f"retrieved {len(current_result)} spans for query_doc_id={query_doc_id} "
                f"and target_doc_id={target_doc_id}"
            )
            current_result["query_doc_id"] = query_doc_id
            if all_spans_for_all_documents is None:
                all_spans_for_all_documents = current_result
            else:
                all_spans_for_all_documents = pd.concat(
                    [all_spans_for_all_documents, current_result], ignore_index=True
                )
    elif args.query_span_id is not None:
        logger.warning(f"retrieving results for single span: {args.query_span_id}")
        all_spans_for_all_documents = methods["retrieve_spans"](
            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
        )
    elif args.query_doc_id is not None:
        logger.warning(f"retrieving results for single document: {args.query_doc_id}")
        all_spans_for_all_documents = methods["retrieve_all_spans"](
            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
        )
    else:
        all_spans_for_all_documents = methods["retrieve_all_spans_for_all_documents"](
            retriever=retriever, **search_kwargs
        )
    if all_spans_for_all_documents is None:
        logger.warning("no spans found in any document")
        sys.exit(0)

    logger.info(f"dumping results ({len(all_spans_for_all_documents)}) to {args.output_path}...")
    # guard against a bare file name: os.makedirs("") would raise
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    all_spans_for_all_documents.to_json(args.output_path)
    if args.gold_dataset_dir is not None:
        logger.info(
            f"reading gold data from {args.gold_dataset_dir} and adding results as predictions ..."
        )
        if args.dataset_out_dir is None:
            raise ValueError("need to provide --dataset_out_dir to save the enriched dataset")
        add_result_to_gold_data(
            all_spans_for_all_documents,
            gold_dataset_dir=args.gold_dataset_dir,
            dataset_out_dir=args.dataset_out_dir,
            retriever=retriever,
        )

    logger.info("done")