# datause-detector / relation_extraction.py
from typing import List, Optional, Union

from datasets import Dataset, load_dataset
from gliner import GLiNER
from gliner.multitask.base import GLiNERBasePipeline


class CustomGLiNERRelationExtractor(GLiNERBasePipeline):
"""
A class to use GLiNER for relation extraction inference and evaluation.
Attributes:
device (str): Device to run the model on, e.g., 'cuda:0' or 'cpu'.
model (GLiNER): Loaded GLiNER model instance.
prompt (str): Template prompt for relation extraction.
Methods:
process_predictions(predictions):
Processes model predictions to extract the most likely labels.
prepare_texts(texts, labels):
Creates relation extraction prompts for each input text.
__call__(texts, labels, threshold=0.5):
Runs the model on the given texts and returns predicted labels.
evaluate(dataset_id, labels=None, threshold=0.5, max_examples=-1):
Evaluates the model on a dataset and computes F1 scores.
"""
prompt = "Extract relationships between entities from the text: "
def __init__(self, model_id: str = None, model: GLiNER = None, device: str = 'cuda:0', ner_threshold: float = 0.5, rel_threshold: float = 0.5, return_index: bool = False, prompt: Optional[str] = None):
"""
Initializes the GLiNERRelationExtractor.
Args:
model_id (str, optional): Identifier for the model to be loaded. Defaults to None.
model (GLiNER, optional): Preloaded GLiNER model. Defaults to None.
device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'.
ner_threshold (float, optional): Named Entity Recognition threshold to use. Defaults to 0.5.
rel_threshold (float, optional): Relation Extraction threshold to use. Defaults to 0.5.
prompt (str, optional): Template prompt for question-answering.
"""
# Use the provided prompt or default to the class-level prompt
prompt = prompt if prompt is not None else self.prompt
self.return_index = return_index
super().__init__(model_id=model_id, model=model, prompt=prompt, device=device)

    def prepare_texts(self, texts: List[str], **kwargs):
        """
        Prepends the relation-extraction prompt to each input text.

        Args:
            texts (list): List of input texts.

        Returns:
            list: List of formatted prompts.
        """
        return [f"{self.prompt} \n {text}" for text in texts]
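
    # Example (illustrative input, not from the project): for the text
    # "Researchers used the 2020 census data.", prepare_texts returns
    # ["Extract relationships between entities from the text:  \n Researchers used the 2020 census data."]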

    def prepare_source_relation(self, ner_predictions: List[dict], relations: List[str]):
        """
        Builds candidate relation labels of the form "<entity> <> <relation>" for
        every unique entity found by the NER pass, one label list per input text.

        Args:
            ner_predictions (list): Per-text NER predictions, each a list of entity dicts.
            relations (list): Relation types to pair with each entity.

        Returns:
            list: One list of "<entity> <> <relation>" label strings per input text.
        """
        relation_labels = []
        for prediction in ner_predictions:
            curr_labels = []
            # Deduplicate entity surface forms before pairing them with relations
            unique_entities = {ent['text'] for ent in prediction}
            for relation in relations:
                for ent in unique_entities:
                    curr_labels.append(f"{ent} <> {relation}")
            relation_labels.append(curr_labels)
        return relation_labels
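
    # Example (illustrative): if the NER pass found entities "DHS" and "World Bank"
    # in a text and relations=["acronym", "publisher"], that text's labels become
    # ["DHS <> acronym", "World Bank <> acronym", "DHS <> publisher", "World Bank <> publisher"]
    # (set iteration order is arbitrary, so the entity order may vary).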

    def process_predictions(self, predictions, **kwargs):
        """
        Converts raw span predictions into structured relation dicts.

        Args:
            predictions (list): Per-text lists of predicted spans, each with a
                "<source> <> <relation>" label, a target text, and a score.

        Returns:
            list: One list of relation dicts per input text, each with keys
                "source", "relation", "target", and "score" (plus "start"/"end"
                when return_index is True).
        """
        batch_predicted_relations = []
        for prediction in predictions:
            curr_relations = []
            for target in prediction:
                target_ent = target['text']
                score = target['score']
                # The label encodes the source entity and relation type as "source <> relation"
                source, relation = target['label'].split('<>', 1)
                relation_entry = {
                    "source": source.strip(),
                    "relation": relation.strip(),
                    "target": target_ent.strip(),
                    "score": score,
                }
                # Pull through span offsets of the target entity if requested
                if self.return_index:
                    relation_entry['start'] = target.get('start')
                    relation_entry['end'] = target.get('end')
                curr_relations.append(relation_entry)
            batch_predicted_relations.append(curr_relations)
        return batch_predicted_relations
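
    # Example (illustrative) output entry for one extracted relation:
    #   {"source": "World Bank", "relation": "publisher", "target": "DHS", "score": 0.91}
    # with "start"/"end" character offsets of the target span added when return_index=True.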

    def __call__(self, texts: Union[str, List[str]], relations: Optional[List[str]] = None,
                 entities: Optional[List[str]] = None,
                 relation_labels: Optional[List[List[str]]] = None,
                 ner_threshold: Optional[float] = None,
                 rel_threshold: Optional[float] = None,
                 batch_size: int = 8, **kwargs):
        """
        Runs relation extraction on the given texts.

        When relation_labels is not supplied, a first NER pass extracts candidate
        entities, which are paired with the given relation types to build the
        labels for the relation-extraction pass.
        """
        if isinstance(texts, str):
            texts = [texts]
        # Avoid a mutable default argument for the entity types
        entities = entities if entities is not None else ['named entity']
        # Fall back to the thresholds set at construction time
        ner_threshold = ner_threshold if ner_threshold is not None else self.ner_threshold
        rel_threshold = rel_threshold if rel_threshold is not None else self.rel_threshold
        prompts = self.prepare_texts(texts, **kwargs)
        if relation_labels is None:
            if relations is None:
                raise ValueError("Either `relations` or `relation_labels` must be provided.")
            # First pass: NER to find candidate source entities
            ner_predictions = self.model.run(texts, entities, threshold=ner_threshold, batch_size=batch_size)
            # Build "<entity> <> <relation>" labels for the second pass
            relation_labels = self.prepare_source_relation(ner_predictions, relations)
        # Second pass: relation extraction over the prompted texts
        predictions = self.model.run(prompts, relation_labels, threshold=rel_threshold, batch_size=batch_size)
        return self.process_predictions(predictions, **kwargs)

    def evaluate(self, dataset_id: Optional[str] = None, dataset: Optional[Dataset] = None,
                 labels: Optional[List[str]] = None, threshold: float = 0.5, max_examples: int = -1):
        """
        Evaluates the model on a specified dataset and computes evaluation metrics.

        Args:
            dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets).
            dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored.
            labels (list, optional): List of target labels to consider for relation extraction. Defaults to None (use all).
            threshold (float): Confidence threshold for predictions. Defaults to 0.5.
            max_examples (int): Maximum number of examples to evaluate. Defaults to -1 (use all available examples).

        Returns:
            dict: A dictionary containing evaluation metrics such as F1 scores.

        Raises:
            ValueError: If neither `dataset_id` nor `dataset` is provided.
        """
        raise NotImplementedError("The `evaluate` method is not implemented yet.")
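

# Minimal usage sketch (assumptions: network access to the Hugging Face Hub, and
# "urchade/gliner_multi-v2.1" as an illustrative GLiNER checkpoint, not one tied
# to this project). It runs the two-pass pipeline on a single sentence.
if __name__ == "__main__":
    model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1")
    extractor = CustomGLiNERRelationExtractor(model=model, device="cpu", return_index=True)
    text = "Microsoft was founded by Bill Gates and Paul Allen."
    results = extractor(text, relations=["founded by"], entities=["organization", "person"])
    for relation in results[0]:
        print(relation)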