# update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
# (commit d868d2e)
from __future__ import annotations

import itertools
import logging
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union

from pie_modules.utils.span import have_overlap
from pytorch_ie import AnnotationLayer
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan, MultiSpan, Span
from pytorch_ie.core import Document
from pytorch_ie.core.document import Annotation, _enumerate_dependencies
from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations

from src.document.types import (
    RelatedRelation,
    TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
)
from src.utils import distance, distance_slices
from src.utils.graph_utils import get_connected_components
from src.utils.span_utils import get_overlap_len

logger = logging.getLogger(__name__)

D = TypeVar("D", bound=Document)


def _remove_overlapping_entities(
    entities: Iterable[Dict[str, Any]], relations: Iterable[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Greedily keep entities (sorted by start offset) that do not overlap a previously
    kept entity, and drop all relations whose head or tail entity was removed."""
    sorted_entities = sorted(entities, key=lambda span: span["start"])
    entities_wo_overlap = []
    skipped_entities = []
    last_end = 0
    for entity_dict in sorted_entities:
        if entity_dict["start"] < last_end:
            skipped_entities.append(entity_dict)
        else:
            entities_wo_overlap.append(entity_dict)
            last_end = entity_dict["end"]
    if len(skipped_entities) > 0:
        logger.warning(f"skipped overlapping entities: {skipped_entities}")
    valid_entity_ids = set(entity_dict["_id"] for entity_dict in entities_wo_overlap)
    valid_relations = [
        relation_dict
        for relation_dict in relations
        if relation_dict["head"] in valid_entity_ids and relation_dict["tail"] in valid_entity_ids
    ]
    return entities_wo_overlap, valid_relations


def remove_overlapping_entities(
    doc: D,
    entity_layer_name: str = "entities",
    relation_layer_name: str = "relations",
) -> D:
    """Remove overlapping entities, and relations that reference them, from a document."""
    # TODO: use document.add_all_annotations_from_other()
    document_dict = doc.asdict()
    entities_wo_overlap, valid_relations = _remove_overlapping_entities(
        entities=document_dict[entity_layer_name]["annotations"],
        relations=document_dict[relation_layer_name]["annotations"],
    )
    document_dict[entity_layer_name] = {
        "annotations": entities_wo_overlap,
        "predictions": [],
    }
    document_dict[relation_layer_name] = {
        "annotations": valid_relations,
        "predictions": [],
    }
    new_doc = type(doc).fromdict(document_dict)
    return new_doc
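
# Illustrative usage (a sketch, not executed here): drop overlapping entities and any
# relations that reference a dropped entity. The layer names below are the defaults
# assumed by this function; adjust them to the layers of the actual document type.
#
#     doc = remove_overlapping_entities(
#         doc, entity_layer_name="entities", relation_layer_name="relations"
#     )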


def remove_partitions_by_labels(
    document: D, partition_layer: str, label_blacklist: List[str], span_layer: Optional[str] = None
) -> D:
    """Remove partitions whose label is in the blacklist from a document.

    Args:
        document: The document to process.
        partition_layer: The name of the partition layer.
        label_blacklist: The list of labels to remove.
        span_layer: The name of the span layer to remove spans from if they are not fully
            contained in any remaining partition. Any dependent annotations will be removed
            as well.

    Returns:
        The processed document.
    """
    document = document.copy()
    p_layer: AnnotationLayer = document[partition_layer]
    new_partitions = []
    for partition in p_layer.clear():
        if partition.label not in label_blacklist:
            new_partitions.append(partition)
    p_layer.extend(new_partitions)
    if span_layer is not None:
        result = document.copy(with_annotations=False)
        removed_span_ids = set()
        for span in document[span_layer]:
            # keep only spans that are fully contained in one of the remaining partitions
            if any(
                partition.start <= span.start and span.end <= partition.end
                for partition in new_partitions
            ):
                result[span_layer].append(span.copy())
            else:
                removed_span_ids.add(span._id)
        result.add_all_annotations_from_other(
            document,
            removed_annotations={span_layer: removed_span_ids},
            strict=False,
            verbose=False,
        )
        document = result
    return document
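
# Illustrative usage (a sketch; "labeled_partitions" / "labeled_spans" are layer names
# of common pytorch_ie document types, and "title" is a made-up label): drop all
# partitions labeled "title" and remove spans no longer covered by any partition.
#
#     doc = remove_partitions_by_labels(
#         doc,
#         partition_layer="labeled_partitions",
#         label_blacklist=["title"],
#         span_layer="labeled_spans",
#     )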


D_text = TypeVar("D_text", bound=Document)


def remove_annotations_by_label(
    document: D, layer2label_blacklist: Dict[str, List[str]], verbose: bool = False
) -> D:
    """Remove annotations whose label is in the blacklist from a document.

    Args:
        document: The document to process.
        layer2label_blacklist: A mapping from layer names to lists of labels to remove.
        verbose: Whether to log the number of removed annotations.

    Returns:
        The processed document.
    """
    result = document.copy(with_annotations=False)
    override_annotations: Dict[str, Dict[int, Annotation]] = defaultdict(dict)
    removed_annotations: Dict[str, Set[int]] = defaultdict(set)
    for layer_name, labels in layer2label_blacklist.items():
        # process gold annotations and predictions
        for src_layer, tgt_layer in [
            (document[layer_name], result[layer_name]),
            (document[layer_name].predictions, result[layer_name].predictions),
        ]:
            current_override_annotations = dict()
            current_removed_annotations = set()
            for annotation in src_layer:
                # use a default so that annotations without a label attribute raise the
                # informative ValueError below instead of an AttributeError
                label = getattr(annotation, "label", None)
                if label is None:
                    raise ValueError(
                        f"Annotation {annotation} has no label. Please check the annotation type."
                    )
                if label not in labels:
                    current_override_annotations[annotation._id] = annotation.copy()
                else:
                    current_removed_annotations.add(annotation._id)
            tgt_layer.extend(current_override_annotations.values())
            override_annotations[layer_name].update(current_override_annotations)
            removed_annotations[layer_name].update(current_removed_annotations)
    if verbose:
        num_removed = {
            layer_name: len(removed_ids) for layer_name, removed_ids in removed_annotations.items()
        }
        if len(num_removed) > 0:
            num_total = {
                layer_name: len(kept_ids) + num_removed[layer_name]
                for layer_name, kept_ids in override_annotations.items()
            }
            logger.warning(
                f"doc.id={document.id}: Removed {num_removed} (total: {num_total}) "
                f"annotations with label blacklists {layer2label_blacklist}"
            )
    result.add_all_annotations_from_other(
        other=document,
        removed_annotations=removed_annotations,
        override_annotations=override_annotations,
        strict=False,
        verbose=False,
    )
    return result
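
# Illustrative usage (a sketch; layer and label names are made up): remove all
# "background" spans, including predictions, and re-add everything that does not
# depend on a removed annotation.
#
#     doc = remove_annotations_by_label(
#         doc, layer2label_blacklist={"labeled_spans": ["background"]}, verbose=True
#     )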


def replace_substrings_in_text(
    document: D_text, replacements: Dict[str, str], enforce_same_length: bool = True
) -> D_text:
    """Replace substrings in the text of a document. By default, each replacement must
    have the same length as the replaced string so that all annotation offsets stay
    valid."""
    new_text = document.text
    for old_str, new_str in replacements.items():
        if enforce_same_length and len(old_str) != len(new_str):
            raise ValueError(
                f'Replacement strings must have the same length, but got "{old_str}" -> "{new_str}"'
            )
        new_text = new_text.replace(old_str, new_str)
    result_dict = document.asdict()
    result_dict["text"] = new_text
    result = type(document).fromdict(result_dict)
    result.text = new_text
    return result


def replace_substrings_in_text_with_spaces(document: D_text, substrings: Iterable[str]) -> D_text:
    """Blank out all occurrences of the given substrings with spaces of the same length."""
    replacements = {substring: " " * len(substring) for substring in substrings}
    return replace_substrings_in_text(document, replacements=replacements)
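
# Illustrative usage (a sketch): blank out markup tokens without shifting any span
# offsets, since every replacement has the same length as the original substring.
#
#     doc = replace_substrings_in_text_with_spaces(doc, substrings=["<b>", "</b>"])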


def relabel_annotations(
    document: D,
    label_mapping: Dict[str, Dict[str, str]],
) -> D:
    """Replace annotation labels in a document.

    Args:
        document: The document to process.
        label_mapping: A mapping from layer names to mappings from old labels to new labels.

    Returns:
        The processed document.
    """
    dependency_ordered_fields: List[str] = []
    _enumerate_dependencies(
        dependency_ordered_fields,
        dependency_graph=document._annotation_graph,
        nodes=document._annotation_graph["_artificial_root"],
    )
    result = document.copy(with_annotations=False)
    store: Dict[int, Annotation] = {}
    # not yet used
    invalid_annotation_ids: Set[int] = set()
    for field_name in dependency_ordered_fields:
        if field_name in document._annotation_fields:
            layer = document[field_name]
            for is_prediction, anns in [(False, layer), (True, layer.predictions)]:
                for ann in anns:
                    new_ann = ann.copy_with_store(
                        override_annotation_store=store,
                        invalid_annotation_ids=invalid_annotation_ids,
                    )
                    if field_name in label_mapping:
                        if ann.label in label_mapping[field_name]:
                            new_label = label_mapping[field_name][ann.label]
                            new_ann = new_ann.copy(label=new_label)
                        else:
                            raise ValueError(
                                f"Label {ann.label} not found in label mapping for {field_name}"
                            )
                    store[ann._id] = new_ann
                    target_layer = result[field_name]
                    if is_prediction:
                        target_layer.predictions.append(new_ann)
                    else:
                        target_layer.append(new_ann)
    return result
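
# Illustrative usage (a sketch; layer and label names are made up): rename the label
# "claim" to "own_claim" in the "labeled_spans" layer. Note that every label that
# occurs in a mapped layer needs an entry in the mapping, otherwise a ValueError is
# raised, so identity entries like "premise" -> "premise" are required.
#
#     doc = relabel_annotations(
#         doc, label_mapping={"labeled_spans": {"claim": "own_claim", "premise": "premise"}}
#     )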


DWithSpans = TypeVar("DWithSpans", bound=Document)


def get_start_end(span: Union[Span, MultiSpan]) -> Tuple[int, int]:
    """Return the (start, end) offsets of a Span, or the bounding offsets (minimum start,
    maximum end) of a MultiSpan."""
    if isinstance(span, Span):
        return span.start, span.end
    elif isinstance(span, MultiSpan):
        starts, ends = zip(*span.slices)
        return min(starts), max(ends)
    else:
        raise ValueError(f"Unsupported span type: {type(span)}")
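
# For example (a sketch; constructor signatures as in pytorch_ie.annotations):
# a Span(2, 5) yields (2, 5), while a MultiSpan with slices ((2, 5), (10, 12))
# yields its bounding offsets (2, 12).
#
#     assert get_start_end(Span(start=2, end=5)) == (2, 5)
#     assert get_start_end(MultiSpan(slices=((2, 5), (10, 12)))) == (2, 12)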


def _get_aligned_span_mappings(
    gold_spans: Iterable[Union[Span, MultiSpan]],
    pred_spans: Iterable[Union[Span, MultiSpan]],
    distance_type: str,
) -> Tuple[Dict[int, Union[Span, MultiSpan]], Dict[int, Union[Span, MultiSpan]]]:
    """For each predicted span, find the closest gold span according to `distance_type`.
    If the two spans overlap by at least half of the longer span, map the predicted
    span's id to a copy aligned to the gold offsets and to the gold span itself."""
    old2new_pred_span = {}
    span_id2gold_span = {}
    for pred_span in pred_spans:
        gold_spans_with_distance = [
            (
                gold_span,
                distance(
                    start_end=get_start_end(pred_span),
                    other_start_end=get_start_end(gold_span),
                    distance_type=distance_type,
                ),
            )
            for gold_span in gold_spans
        ]
        if len(gold_spans_with_distance) == 0:
            continue
        closest_gold_span, min_distance = min(gold_spans_with_distance, key=lambda x: x[1])
        # if the closest gold span is the same as the predicted span, we don't need to align
        if min_distance == 0.0:
            continue
        pred_start_end = get_start_end(pred_span)
        closest_gold_start_end = get_start_end(closest_gold_span)
        if have_overlap(
            start_end=pred_start_end,
            other_start_end=closest_gold_start_end,
        ):
            overlap_len = get_overlap_len(pred_start_end, closest_gold_start_end)
            l_max = max(
                pred_start_end[1] - pred_start_end[0],
                closest_gold_start_end[1] - closest_gold_start_end[0],
            )
            # if the overlap is at least half of the maximum length, we consider it a valid
            # match for alignment
            valid_match = overlap_len >= (l_max / 2)
        else:
            valid_match = False
        if valid_match:
            if isinstance(pred_span, Span):
                aligned_pred_span = pred_span.copy(
                    start=closest_gold_span.start, end=closest_gold_span.end
                )
            elif isinstance(pred_span, MultiSpan):
                aligned_pred_span = pred_span.copy(slices=closest_gold_span.slices)
            else:
                raise ValueError(f"Unsupported span type: {type(pred_span)}")
            old2new_pred_span[pred_span._id] = aligned_pred_span
            span_id2gold_span[pred_span._id] = closest_gold_span
    return old2new_pred_span, span_id2gold_span


def get_spans2multi_spans_mapping(multi_spans: Iterable[MultiSpan]) -> Dict[Span, MultiSpan]:
    """Create one (Labeled)Span per slice of each MultiSpan and map it back to the
    MultiSpan it originates from."""
    result = {}
    for multi_span in multi_spans:
        for start, end in multi_span.slices:
            span_kwargs = dict(start=start, end=end, score=multi_span.score)
            if isinstance(multi_span, LabeledMultiSpan):
                result[LabeledSpan(label=multi_span.label, **span_kwargs)] = multi_span
            else:
                result[Span(**span_kwargs)] = multi_span
    return result
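
# Illustrative behavior (a sketch): a LabeledMultiSpan with slices ((0, 3), (7, 9))
# and label "X" is mapped to two LabeledSpans covering (0, 3) and (7, 9), both with
# label "X" and both pointing back to the original MultiSpan.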


def align_predicted_span_annotations(
    document: DWithSpans,
    span_layer: str,
    distance_type: str = "center",
    simple_multi_span: bool = False,
    verbose: bool = False,
) -> DWithSpans:
    """Aligns predicted span annotations with the closest gold spans in a document.

    First, calculates the distance between each predicted span and each gold span. Then,
    for each predicted span, the gold span with the smallest distance is selected. If the
    predicted span and the gold span have an overlap of at least half of the maximum length
    of the two spans, the predicted span is aligned with the gold span.

    This also works for MultiSpan annotations, where the slices of the MultiSpan are used
    to align the predicted spans. If any of the slices is aligned with a gold slice,
    the MultiSpan is aligned with the respective gold MultiSpan. However, this may result in
    the predicted MultiSpan being aligned with multiple gold MultiSpans, in which case the
    closest gold MultiSpan is selected. A simplified version of this alignment can be
    achieved by setting `simple_multi_span=True`, which treats MultiSpan annotations as
    simple Spans by using their minimum start and maximum end indices.

    Args:
        document: The document to process.
        span_layer: The name of the span layer.
        distance_type: The type of distance to calculate. One of: center, inner, outer.
        simple_multi_span: Whether to treat MultiSpan annotations as simple Spans by using
            their minimum start and maximum end indices.
        verbose: Whether to print debug information.

    Returns:
        The processed document.
    """
    gold_spans = document[span_layer]
    if len(gold_spans) == 0:
        return document.copy()
    pred_spans = document[span_layer].predictions
    span_annotation_type = document.annotation_types()[span_layer]
    if issubclass(span_annotation_type, Span) or simple_multi_span:
        old2new_pred_span, span_id2gold_span = _get_aligned_span_mappings(
            gold_spans=gold_spans, pred_spans=pred_spans, distance_type=distance_type
        )
    elif issubclass(span_annotation_type, MultiSpan):
        # create Span objects from MultiSpan slices
        gold_single_spans2multi_spans = get_spans2multi_spans_mapping(gold_spans)
        pred_single_spans2multi_spans = get_spans2multi_spans_mapping(pred_spans)
        # create the alignment mappings for the single spans
        single_old2new_pred_span, single_span_id2gold_span = _get_aligned_span_mappings(
            gold_spans=gold_single_spans2multi_spans.keys(),
            pred_spans=pred_single_spans2multi_spans.keys(),
            distance_type=distance_type,
        )
        # collect all Spans that are part of the same MultiSpan
        pred_multi_span2single_spans: Dict[MultiSpan, List[Span]] = defaultdict(list)
        for pred_span, multi_span in pred_single_spans2multi_spans.items():
            pred_multi_span2single_spans[multi_span].append(pred_span)
        # create the new mappings for the MultiSpans
        old2new_pred_span = {}
        span_id2gold_span = {}
        for pred_multi_span, pred_single_spans in pred_multi_span2single_spans.items():
            # if any of the single spans is aligned with a gold span, align the multi span
            if any(
                pred_single_span._id in single_old2new_pred_span
                for pred_single_span in pred_single_spans
            ):
                # get aligned gold multi spans
                aligned_gold_multi_spans = set()
                for pred_single_span in pred_single_spans:
                    if pred_single_span._id in single_old2new_pred_span:
                        aligned_gold_single_span = single_span_id2gold_span[pred_single_span._id]
                        aligned_gold_multi_span = gold_single_spans2multi_spans[
                            aligned_gold_single_span
                        ]
                        aligned_gold_multi_spans.add(aligned_gold_multi_span)
                # calculate distances between the predicted multi span and the aligned
                # gold multi spans
                gold_multi_spans_with_distance = [
                    (
                        gold_multi_span,
                        distance_slices(
                            slices=pred_multi_span.slices,
                            other_slices=gold_multi_span.slices,
                            distance_type=distance_type,
                        ),
                    )
                    for gold_multi_span in aligned_gold_multi_spans
                ]
                if len(aligned_gold_multi_spans) > 1:
                    logger.warning(
                        f"Multiple gold multi spans aligned with predicted multi span "
                        f"({pred_multi_span}): {aligned_gold_multi_spans}"
                    )
                # get the closest gold multi span
                closest_gold_multi_span, min_distance = min(
                    gold_multi_spans_with_distance, key=lambda x: x[1]
                )
                old2new_pred_span[pred_multi_span._id] = pred_multi_span.copy(
                    slices=closest_gold_multi_span.slices
                )
                span_id2gold_span[pred_multi_span._id] = closest_gold_multi_span
    else:
        raise ValueError(f"Unsupported span annotation type: {span_annotation_type}")
    result = document.copy(with_annotations=False)
    # multiple predicted spans can be aligned with the same gold span,
    # so we need to keep track of the added spans
    added_pred_span_ids = dict()
    for pred_span in pred_spans:
        # just add the predicted span if it was not aligned with a gold span
        if pred_span._id not in old2new_pred_span:
            # if this was not added before (e.g. as aligned span), add it
            if pred_span._id not in added_pred_span_ids:
                keep_pred_span = pred_span.copy()
                result[span_layer].predictions.append(keep_pred_span)
                added_pred_span_ids[pred_span._id] = keep_pred_span
            elif verbose:
                print(f"Skipping duplicate predicted span. pred_span='{str(pred_span)}'")
        else:
            aligned_pred_span = old2new_pred_span[pred_span._id]
            # if this was not added before (e.g. as aligned or original pred span), add it
            if aligned_pred_span._id not in added_pred_span_ids:
                result[span_layer].predictions.append(aligned_pred_span)
                added_pred_span_ids[aligned_pred_span._id] = aligned_pred_span
            elif verbose:
                prev_pred_span = added_pred_span_ids[aligned_pred_span._id]
                gold_span = span_id2gold_span[pred_span._id]
                print(
                    f"Skipping duplicate aligned predicted span. aligned gold_span='{str(gold_span)}', "
                    f"prev_pred_span='{str(prev_pred_span)}', current_pred_span='{str(pred_span)}'"
                )
# print("bbb") | |
    result[span_layer].extend([span.copy() for span in gold_spans])
    # add remaining gold and predicted spans (the result, _aligned_spans, is just for debugging)
    _aligned_spans = result.add_all_annotations_from_other(
        document, override_annotations={span_layer: old2new_pred_span}
    )
    return result
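
# Illustrative usage (a sketch; "labeled_spans" is a layer name of common pytorch_ie
# document types): snap predicted spans onto the closest gold spans when they overlap
# by at least half of the longer span.
#
#     doc = align_predicted_span_annotations(
#         doc, span_layer="labeled_spans", distance_type="center", verbose=True
#     )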


def add_related_relations_from_binary_relations(
    document: TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
    link_relation_label: str,
    link_partition_whitelist: Optional[List[List[str]]] = None,
    relation_label_whitelist: Optional[List[str]] = None,
    reversed_relation_suffix: Optional[str] = "_reversed",
    symmetric_relations: Optional[List[str]] = None,
) -> TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations:
    """Create related relations by following link relations: if a link relation (with label
    `link_relation_label`) connects span A to span B, every binary relation B -> C is added
    as a related relation A -> C with the label of the original relation. If
    `reversed_relation_suffix` is set, relations C -> B are also added as A -> C, with the
    label suffixed by `reversed_relation_suffix` unless it is in `symmetric_relations`.

    Args:
        document: The document to process; related relations are added in-place.
        link_relation_label: The label of the link relations to follow.
        link_partition_whitelist: If provided, only link relations whose head and tail
            partition labels form a pair in this list are considered.
        relation_label_whitelist: If provided, only related relations with these (possibly
            suffixed) labels are added.
        reversed_relation_suffix: The suffix to append to the labels of reversed relations.
        symmetric_relations: Labels of symmetric relations that keep their label when
            reversed.

    Returns:
        The processed document.
    """
    span2partition = {}
    for multi_span in document.labeled_multi_spans:
        found_partition = False
        for partition in document.labeled_partitions or [
            LabeledSpan(start=0, end=len(document.text), label="ALL")
        ]:
            starts, ends = zip(*multi_span.slices)
            if partition.start <= min(starts) and max(ends) <= partition.end:
                span2partition[multi_span] = partition
                found_partition = True
                break
        if not found_partition:
            raise ValueError(f"No partition found for multi_span {multi_span}")
    rel_head2rels = defaultdict(list)
    rel_tail2rels = defaultdict(list)
    for rel in document.binary_relations:
        rel_head2rels[rel.head].append(rel)
        rel_tail2rels[rel.tail].append(rel)
    link_partition_whitelist_tuples = None
    if link_partition_whitelist is not None:
        link_partition_whitelist_tuples = {tuple(pair) for pair in link_partition_whitelist}
    skipped_labels = []
    for link_rel in document.binary_relations:
        if link_rel.label == link_relation_label:
            head_partition = span2partition[link_rel.head]
            tail_partition = span2partition[link_rel.tail]
            if link_partition_whitelist_tuples is None or (
                (head_partition.label, tail_partition.label) in link_partition_whitelist_tuples
            ):
                # link_head -> link_tail == rel_head -> rel_tail
                for rel in rel_head2rels.get(link_rel.tail, []):
                    label = rel.label
                    if relation_label_whitelist is None or label in relation_label_whitelist:
                        new_rel = RelatedRelation(
                            head=link_rel.head,
                            tail=rel.tail,
                            link_relation=link_rel,
                            relation=rel,
                            label=label,
                        )
                        document.related_relations.append(new_rel)
                    else:
                        skipped_labels.append(label)
                # link_head -> link_tail == rel_tail -> rel_head
                if reversed_relation_suffix is not None:
                    for reversed_rel in rel_tail2rels.get(link_rel.tail, []):
                        label = reversed_rel.label
                        if symmetric_relations is None or label not in symmetric_relations:
                            label = f"{label}{reversed_relation_suffix}"
                        if relation_label_whitelist is None or label in relation_label_whitelist:
                            new_rel = RelatedRelation(
                                head=link_rel.head,
                                tail=reversed_rel.head,
                                link_relation=link_rel,
                                relation=reversed_rel,
                                label=label,
                            )
                            document.related_relations.append(new_rel)
                        else:
                            skipped_labels.append(label)
            else:
                logger.warning(
                    f"Skipping related relation because of partition whitelist "
                    f"({[head_partition.label, tail_partition.label]}): {link_rel.resolve()}"
                )
    if len(skipped_labels) > 0:
        logger.warning(
            f"Skipped relations with labels not in whitelist: {sorted(set(skipped_labels))}"
        )
    return document
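
# Illustrative usage (a sketch; the labels are made up): if "semantically_same" links
# span A to span B and a "supports" relation holds from B to C, a related relation
# A -> C with label "supports" is added; relations C -> B yield A -> C with label
# "supports_reversed".
#
#     doc = add_related_relations_from_binary_relations(
#         doc,
#         link_relation_label="semantically_same",
#         symmetric_relations=["semantically_same"],
#     )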


T = TypeVar("T", bound=TextDocumentWithLabeledSpansAndBinaryRelations)


def remove_discontinuous_spans(
    document: T,
    parts_of_same_relation: str,
    verbose: bool = False,
) -> T:
    """Remove discontinuous spans from a document, i.e. spans that are linked to other
    spans via relations with the label `parts_of_same_relation`. Any other relation that
    involves a part of a discontinuous span is removed as well.

    Args:
        document: The document to process.
        parts_of_same_relation: The label of the relation that links the parts of a
            discontinuous span.
        verbose: Whether to print debug information.

    Returns:
        The processed document.
    """
    result = document.copy()
    spans = result.labeled_spans.clear()
    rels = result.binary_relations.clear()
    segment_spans = set()
    segment_rels = set()
    # collect all spans that are linked
    for rel in rels:
        if rel.label == parts_of_same_relation:
            segment_spans.add(rel.head)
            segment_spans.add(rel.tail)
            segment_rels.add(rel)
    for span in spans:
        if span not in segment_spans:
            result.labeled_spans.append(span)
    other_rels_dropped = set()
    for rel in rels:
        if rel not in segment_rels:
            if rel.head not in segment_spans and rel.tail not in segment_spans:
                result.binary_relations.append(rel)
            else:
                other_rels_dropped.add(rel)
    if verbose:
        if len(segment_rels) > 0:
            num_dropped = len(document.binary_relations) - len(result.binary_relations)
            pct_dropped = round(num_dropped * 100 / len(document.binary_relations), 1)
            logger.warning(
                f"doc={document.id}: Dropped {len(segment_rels)} segment rels "
                f"and {len(other_rels_dropped)} other rels "
                f"({pct_dropped}% of all relations dropped)"
            )
    return result
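
# Illustrative usage (a sketch; the label is made up): remove all spans that are
# linked via "parts_of_same" relations, i.e. the parts of discontinuous spans,
# together with any other relation that involves one of these parts.
#
#     doc = remove_discontinuous_spans(doc, parts_of_same_relation="parts_of_same")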


def close_clusters_transitively(
    document: D, relation_layer: str, link_relation_label: str, verbose: bool = False
) -> D:
    """Close clusters transitively by adding relations between all pairs of spans in the
    same cluster.

    Args:
        document: The document to process.
        relation_layer: The name of the relation layer.
        link_relation_label: The label of the link relation.
        verbose: Whether to print debug information.

    Returns:
        The processed document.
    """
    result = document.copy()
    connected_components: List[List[Annotation]] = get_connected_components(
        relations=result[relation_layer],
        link_relation_label=link_relation_label,
        add_singletons=False,
    )
    # detach from document
    relations = result[relation_layer].clear()
    # use a set to speed up membership checks
    relations_set = set(relations)
    n_before = len(relations)
    for cluster in connected_components:
        for head, tail in itertools.combinations(sorted(cluster), 2):
            rel = BinaryRelation(
                head=head,
                tail=tail,
                label=link_relation_label,
            )
            rel_reversed = BinaryRelation(
                head=tail,
                tail=head,
                label=link_relation_label,
            )
            if rel not in relations_set and rel_reversed not in relations_set:
                # append to relations to keep the order
                relations.append(rel)
                relations_set.add(rel)
    result[relation_layer].extend(relations)
    if verbose:
        num_added = len(relations) - n_before
        if num_added > 0:
            logger.warning(
                f"doc.id={document.id}: added {num_added} relations to {relation_layer} layer"
            )
    return result
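
# Illustrative usage (a sketch; layer and label names are made up): if A-B and B-C
# are linked, also add the missing A-C link so that every cluster forms a complete
# graph.
#
#     doc = close_clusters_transitively(
#         doc, relation_layer="binary_relations", link_relation_label="parts_of_same"
#     )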


def get_ancestor_layers(children: Dict[str, Set[str]], layer: str) -> Set[str]:
    """Get all ancestor layers of a given layer in the dependency graph.

    Args:
        children: A mapping from layers to their child layers.
        layer: The layer for which to find ancestors.

    Returns:
        A set of ancestor layers.
    """
    ancestors = set()

    def _get_ancestors(current_layer: str):
        for parent_layer, child_layers in children.items():
            if current_layer in child_layers:
                ancestors.add(parent_layer)
                _get_ancestors(parent_layer)

    _get_ancestors(layer)
    # drop the _artificial_root
    ancestors.discard("_artificial_root")
    return ancestors
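
# Illustrative behavior (a sketch with a made-up dependency graph in which spans
# depend on partitions): both "labeled_spans" and, transitively, "binary_relations"
# are ancestors of "labeled_partitions".
#
#     children = {
#         "_artificial_root": {"labeled_partitions"},
#         "labeled_spans": {"labeled_partitions"},
#         "binary_relations": {"labeled_spans"},
#     }
#     assert get_ancestor_layers(children, "labeled_partitions") == {
#         "labeled_spans",
#         "binary_relations",
#     }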


def remove_binary_relations_by_partition_labels(
    document: D,
    partition_layer: str,
    relation_layer: str,
    partition_label_whitelist: Optional[List[List[str]]] = None,
    partition_label_blacklist: Optional[List[List[str]]] = None,
    verbose: bool = False,
) -> D:
    """Remove binary relations whose (head partition label, tail partition label) pair is
    not in the whitelist (if provided) or is in the blacklist (if provided). All layers
    that depend on the relation layer are cleared beforehand.

    Args:
        document: The document to process.
        partition_layer: The name of the partition layer.
        relation_layer: The name of the relation layer.
        partition_label_whitelist: The list of head-tail label pairs to keep.
        partition_label_blacklist: The list of head-tail label pairs to remove.
        verbose: Whether to log the removed relations.

    Returns:
        The processed document.
    """
    result = document.copy()
    relation_annotation_layer = result[relation_layer]
    # get all layers that target the relation layer
    relation_dependent_layers = get_ancestor_layers(
        children=result._annotation_graph, layer=relation_layer
    )
    # clear all layers that depend on the relation layer
    for layer_name in relation_dependent_layers:
        dependent_layer = result[layer_name]
        gold_anns_cleared = dependent_layer.clear()
        pred_anns_cleared = dependent_layer.predictions.clear()
        if len(gold_anns_cleared) > 0 or len(pred_anns_cleared) > 0:
            if verbose:
                logger.warning(
                    f"doc.id={document.id}: Cleared {len(gold_anns_cleared)} gold and "
                    f"{len(pred_anns_cleared)} predicted annotations from layer {layer_name} "
                    f"because it depends on the relation layer {relation_layer}."
                )
    span2partition = {}
    span_layer: AnnotationLayer
    for span_layer in relation_annotation_layer.target_layers.values():
        for span in list(span_layer) + list(span_layer.predictions):
            if isinstance(span, Span):
                span_start, span_end = span.start, span.end
            elif isinstance(span, MultiSpan):
                starts, ends = zip(*span.slices)
                span_start, span_end = min(starts), max(ends)
            else:
                raise ValueError(f"Unsupported span type: {type(span)}")
            found_partition = False
            for partition in result[partition_layer]:
                if partition.start <= span_start and span_end <= partition.end:
                    span2partition[span] = partition
                    found_partition = True
                    break
            if not found_partition:
                raise ValueError(f"No partition found for span {span}")
    # use sets of tuples for fast membership checks
    if partition_label_whitelist is not None:
        partition_label_whitelist_tuples = {tuple(pair) for pair in partition_label_whitelist}
    else:
        partition_label_whitelist_tuples = None
    if partition_label_blacklist is not None:
        partition_label_blacklist_tuples = {tuple(pair) for pair in partition_label_blacklist}
    else:
        partition_label_blacklist_tuples = None
    for relation_base_layer in [relation_annotation_layer, relation_annotation_layer.predictions]:
        # get all relations and clear the layer
        relations = relation_base_layer.clear()
        for relation in relations:
            head_partition = span2partition[relation.head]
            tail_partition = span2partition[relation.tail]
            pair = (head_partition.label, tail_partition.label)
            if (
                partition_label_whitelist_tuples is None
                or pair in partition_label_whitelist_tuples
            ) and (
                partition_label_blacklist_tuples is None
                or pair not in partition_label_blacklist_tuples
            ):
                relation_base_layer.append(relation)
            else:
                if verbose:
                    logger.info(
                        f"Removing relation {relation} because its partitions "
                        f"({pair}) are not in the whitelist or are in the blacklist."
                    )
    return result
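
# Illustrative usage (a sketch; layer and label names are made up): keep only
# relations whose head and tail lie in partitions with a whitelisted label pair,
# clearing any layer that depends on the relation layer beforehand.
#
#     doc = remove_binary_relations_by_partition_labels(
#         doc,
#         partition_layer="labeled_partitions",
#         relation_layer="binary_relations",
#         partition_label_whitelist=[["own_claim", "own_claim"], ["own_claim", "data"]],
#         verbose=True,
#     )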