from typing import Optional, List, Union

from datasets import load_dataset, Dataset

from gliner import GLiNER
from gliner.multitask.base import GLiNERBasePipeline


class CustomGLiNERRelationExtractor(GLiNERBasePipeline):
    """
    A class to use GLiNER for relation extraction inference and evaluation.

    Attributes:
        device (str): Device to run the model on, e.g., 'cuda:0' or 'cpu'.
        model (GLiNER): Loaded GLiNER model instance.
        prompt (str): Template prompt for relation extraction.

    Methods:
        process_predictions(predictions):
            Processes model predictions into structured relation dictionaries.
        prepare_texts(texts):
            Prepends the relation-extraction prompt to each input text.
        __call__(texts, relations, ...):
            Runs the model on the given texts and returns predicted relations.
        evaluate(dataset_id, labels=None, threshold=0.5, max_examples=-1):
            Evaluates the model on a dataset and computes F1 scores.
    """
    prompt = "Extract relationships between entities from the text: "

    def __init__(self, model_id: str = None, model: GLiNER = None, device: str = 'cuda:0',
                 ner_threshold: float = 0.5, rel_threshold: float = 0.5,
                 return_index: bool = False, prompt: Optional[str] = None):
        """
        Initializes the CustomGLiNERRelationExtractor.

        Args:
            model_id (str, optional): Identifier for the model to be loaded. Defaults to None.
            model (GLiNER, optional): Preloaded GLiNER model. Defaults to None.
            device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'.
            ner_threshold (float, optional): Named entity recognition threshold. Defaults to 0.5.
            rel_threshold (float, optional): Relation extraction threshold. Defaults to 0.5.
            return_index (bool, optional): Whether to include the character span ('start'/'end') of the
                target entity in each predicted relation. Defaults to False.
            prompt (str, optional): Template prompt for relation extraction.
        """
        # Use the provided prompt or default to the class-level prompt
        prompt = prompt if prompt is not None else self.prompt
        self.return_index = return_index
        super().__init__(model_id=model_id, model=model, prompt=prompt, device=device)

    def prepare_texts(self, texts: List[str], **kwargs):
        """
        Prepends the relation-extraction prompt to each input text.

        Args:
            texts (list): List of input texts.

        Returns:
            list: List of formatted prompts.
        """
        prompts = []
        for text in texts:
            prompt = f"{self.prompt} \n {text}"
            prompts.append(prompt)
        return prompts

    def prepare_source_relation(self, ner_predictions: List[dict], relations: List[str]):
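        """
        Builds candidate relation labels of the form "<source entity> <> <relation>".

        For every text, each unique entity found by the NER stage is paired with every
        relation type, e.g. entities {"Microsoft", "Bill Gates"} and relations ["founded"]
        yield ["Microsoft <> founded", "Bill Gates <> founded"].

        Args:
            ner_predictions (list): Per-text lists of NER predictions (dicts with a 'text' key).
            relations (list): Relation types to pair with each detected entity.

        Returns:
            list: Per-text lists of "<entity> <> <relation>" label strings.
        """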
        relation_labels = []
        for prediction in ner_predictions:
            curr_labels = []
            unique_entities = {ent['text'] for ent in prediction}
            for relation in relations:
                for ent in unique_entities:
                    curr_labels.append(f"{ent} <> {relation}")
            relation_labels.append(curr_labels)
        return relation_labels

    def process_predictions(self, predictions, **kwargs):
        """
        Formats raw model predictions into structured relation dictionaries.

        Args:
            predictions (list): Per-text lists of predictions with scores.

        Returns:
            list: For each input text, a list of dicts with 'source', 'relation',
                'target', and 'score' keys (plus 'start'/'end' when return_index is True).
        """
        batch_predicted_relations = []
        for prediction in predictions:
            curr_relations = []
            for target in prediction:
                target_ent = target['text']
                score = target['score']
                # The label encodes "<source entity> <> <relation>"
                source, relation = target['label'].split('<>')
                relation = {
                    "source": source.strip(),
                    "relation": relation.strip(),
                    "target": target_ent.strip(),
                    "score": score
                }
                # Pull through span info if present
                if self.return_index:
                    relation['start'] = target.get('start', None)
                    relation['end'] = target.get('end', None)
                curr_relations.append(relation)
            batch_predicted_relations.append(curr_relations)
        return batch_predicted_relations

    def __call__(self, texts: Union[str, List[str]], relations: Optional[List[str]] = None,
                 entities: List[str] = ['named entity'],
                 relation_labels: Optional[List[List[str]]] = None,
                 ner_threshold: float = 0.5,
                 rel_threshold: float = 0.5,
                 batch_size: int = 8, **kwargs):
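        """
        Runs NER followed by relation extraction on the given text(s).

        Args:
            texts (str or list): Input text or list of texts.
            relations (list, optional): Relation types to extract. Required unless
                `relation_labels` is provided.
            entities (list, optional): Entity types for the NER stage. Defaults to ['named entity'].
            relation_labels (list of lists, optional): Precomputed "<entity> <> <relation>" labels;
                if given, the NER stage is skipped.
            ner_threshold (float): Confidence threshold for the NER stage. Defaults to 0.5.
            rel_threshold (float): Confidence threshold for the relation-extraction stage. Defaults to 0.5.
            batch_size (int): Batch size for model inference. Defaults to 8.

        Returns:
            list: For each input text, a list of predicted relation dictionaries.
        """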
        if isinstance(texts, str):
            texts = [texts]
        prompts = self.prepare_texts(texts, **kwargs)
        if relation_labels is None:
            # Stage 1: NER to find candidate source entities
            ner_predictions = self.model.run(texts, entities, threshold=ner_threshold, batch_size=batch_size)
            # Stage 2: pair each detected entity with each relation type
            relation_labels = self.prepare_source_relation(ner_predictions, relations)
        predictions = self.model.run(prompts, relation_labels, threshold=rel_threshold, batch_size=batch_size)
        results = self.process_predictions(predictions, **kwargs)
        return results

    def evaluate(self, dataset_id: Optional[str] = None, dataset: Optional[Dataset] = None,
                 labels: Optional[List[str]] = None, threshold: float = 0.5, max_examples: int = -1):
        """
        Evaluates the model on a specified dataset and computes evaluation metrics.

        Args:
            dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets).
            dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored.
            labels (list, optional): List of target labels to consider for relation extraction. Defaults to None (use all).
            threshold (float): Confidence threshold for predictions. Defaults to 0.5.
            max_examples (int): Maximum number of examples to evaluate. Defaults to -1 (use all available examples).

        Returns:
            dict: A dictionary containing evaluation metrics such as F1 scores.

        Raises:
            ValueError: If neither `dataset_id` nor `dataset` is provided.
        """
        raise NotImplementedError("Currently the `evaluate` method is not implemented.")
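

# Illustrative usage sketch, not part of the pipeline itself. The checkpoint name,
# example sentence, and relation/entity lists below are assumptions for demonstration;
# substitute whichever GLiNER multitask checkpoint and device you actually use.
if __name__ == "__main__":
    extractor = CustomGLiNERRelationExtractor(
        model_id="knowledgator/gliner-multitask-large-v0.5",  # assumed example checkpoint
        device="cpu",
        return_index=True,
    )
    sample_text = "Bill Gates and Paul Allen founded Microsoft in 1975."
    relations = ["founded", "founding year"]
    entities = ["person", "organisation", "date"]
    results = extractor(sample_text, relations=relations, entities=entities)
    # Expected shape: [[{'source': ..., 'relation': ..., 'target': ..., 'score': ..., 'start': ..., 'end': ...}, ...]]
    print(results)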