Spaces:
Running
Running
from gliner.multitask.base import GLiNERBasePipeline | |
from typing import Optional, List, Union | |
from datasets import load_dataset, Dataset | |
from gliner import GLiNER | |
class CustomGLiNERRelationExtractor(GLiNERBasePipeline): | |
""" | |
A class to use GLiNER for relation extraction inference and evaluation. | |
Attributes: | |
device (str): Device to run the model on, e.g., 'cuda:0' or 'cpu'. | |
model (GLiNER): Loaded GLiNER model instance. | |
prompt (str): Template prompt for relation extraction. | |
Methods: | |
process_predictions(predictions): | |
Processes model predictions to extract the most likely labels. | |
prepare_texts(texts, labels): | |
Creates relation extraction prompts for each input text. | |
__call__(texts, labels, threshold=0.5): | |
Runs the model on the given texts and returns predicted labels. | |
evaluate(dataset_id, labels=None, threshold=0.5, max_examples=-1): | |
Evaluates the model on a dataset and computes F1 scores. | |
""" | |
prompt = "Extract relationships between entities from the text: " | |
def __init__(self, model_id: str = None, model: GLiNER = None, device: str = 'cuda:0', ner_threshold: float = 0.5, rel_threshold: float = 0.5, return_index: bool = False, prompt: Optional[str] = None): | |
""" | |
Initializes the GLiNERRelationExtractor. | |
Args: | |
model_id (str, optional): Identifier for the model to be loaded. Defaults to None. | |
model (GLiNER, optional): Preloaded GLiNER model. Defaults to None. | |
device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'. | |
ner_threshold (float, optional): Named Entity Recognition threshold to use. Defaults to 0.5. | |
rel_threshold (float, optional): Relation Extraction threshold to use. Defaults to 0.5. | |
prompt (str, optional): Template prompt for question-answering. | |
""" | |
# Use the provided prompt or default to the class-level prompt | |
prompt = prompt if prompt is not None else self.prompt | |
self.return_index = return_index | |
super().__init__(model_id=model_id, model=model, prompt=prompt, device=device) | |
def prepare_texts(self, texts: List[str], **kwargs): | |
""" | |
Prepares prompts for relation extraction to texts. | |
Args: | |
texts (list): List of input texts. | |
Returns: | |
list: List of formatted prompts. | |
""" | |
prompts = [] | |
for id, text in enumerate(texts): | |
prompt = f"{self.prompt} \n {text}" | |
prompts.append(prompt) | |
return prompts | |
def prepare_source_relation(self, ner_predictions: List[dict], relations: List[str]): | |
relation_labels = [] | |
for prediction in ner_predictions: | |
curr_labels = [] | |
unique_entities = {ent['text'] for ent in prediction} | |
for relation in relations: | |
for ent in unique_entities: | |
curr_labels.append(f"{ent} <> {relation}") | |
relation_labels.append(curr_labels) | |
return relation_labels | |
def process_predictions(self, predictions, **kwargs): | |
""" | |
Processes predictions to extract the highest-scoring relation(s). | |
Args: | |
predictions (list): List of predictions with scores. | |
Returns: | |
list: List of predicted labels for each input. | |
""" | |
batch_predicted_relations = [] | |
for prediction in predictions: | |
# Sort predictions by score in descending order | |
curr_relations = [] | |
for target in prediction: | |
target_ent = target['text'] | |
score = target['score'] | |
source, relation = target['label'].split('<>') | |
relation = { | |
"source": source.strip(), | |
"relation": relation.strip(), | |
"target": target_ent.strip(), | |
"score": score | |
} | |
# **pull through** span info if present | |
if self.return_index: | |
relation['start'] = target.get('start', None) | |
relation['end'] = target.get('end', None) | |
curr_relations.append(relation) | |
batch_predicted_relations.append(curr_relations) | |
return batch_predicted_relations | |
def __call__(self, texts: Union[str, List[str]], relations: List[str]=None, | |
entities: List[str] = ['named entity'], | |
relation_labels: Optional[List[List[str]]]=None, | |
ner_threshold: float = 0.5, | |
rel_threshold: float = 0.5, | |
batch_size: int = 8, **kwargs): | |
if isinstance(texts, str): | |
texts = [texts] | |
prompts = self.prepare_texts(texts, **kwargs) | |
if relation_labels is None: | |
# ner | |
ner_predictions = self.model.run(texts, entities, threshold=ner_threshold, batch_size=batch_size) | |
#rex | |
relation_labels = self.prepare_source_relation(ner_predictions, relations) | |
predictions = self.model.run(prompts, relation_labels, threshold=rel_threshold, batch_size=batch_size) | |
results = self.process_predictions(predictions, **kwargs) | |
return results | |
def evaluate(self, dataset_id: Optional[str] = None, dataset: Optional[Dataset] = None, | |
labels: Optional[List[str]]=None, threshold: float =0.5, max_examples: float =-1): | |
""" | |
Evaluates the model on a specified dataset and computes evaluation metrics. | |
Args: | |
dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets). | |
dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored. | |
labels (list, optional): List of target labels to consider for relation extraction. Defaults to None (use all). | |
threshold (float): Confidence threshold for predictions. Defaults to 0.5. | |
max_examples (int): Maximum number of examples to evaluate. Defaults to -1 (use all available examples). | |
Returns: | |
dict: A dictionary containing evaluation metrics such as F1 scores. | |
Raises: | |
ValueError: If neither `dataset_id` nor `dataset` is provided. | |
""" | |
raise NotImplementedError("Currently `evaluate` method is not implemented.") |