from gliner.multitask.base import GLiNERBasePipeline
from typing import Optional, List, Union
from datasets import load_dataset, Dataset
from gliner import GLiNER

class CustomGLiNERRelationExtractor(GLiNERBasePipeline):
    """
    A class to use GLiNER for relation extraction inference and evaluation.

    Attributes:
        device (str): Device to run the model on, e.g., 'cuda:0' or 'cpu'.
        model (GLiNER): Loaded GLiNER model instance.
        prompt (str): Template prompt for relation extraction.
        return_index (bool): Whether to include character span indices in the output.

    Methods:
        process_predictions(predictions):
            Converts raw model predictions into structured relation dicts.
        prepare_texts(texts):
            Creates relation-extraction prompts for each input text.
        __call__(texts, relations=None, ...):
            Runs the two-stage NER + relation-extraction pipeline and returns
            predicted relations.
        evaluate(dataset_id, labels=None, threshold=0.5, max_examples=-1):
            Evaluates the model on a dataset and computes F1 scores (not yet implemented).
    """

    prompt = "Extract relationships between entities from the text: "

    def __init__(
        self,
        model_id: Optional[str] = None,
        model: Optional[GLiNER] = None,
        device: str = 'cuda:0',
        ner_threshold: float = 0.5,
        rel_threshold: float = 0.5,
        return_index: bool = False,
        prompt: Optional[str] = None,
    ):
        """
        Initializes the GLiNERRelationExtractor.

        Args:
            model_id (str, optional): Identifier for the model to be loaded. Defaults to None.
            model (GLiNER, optional): Preloaded GLiNER model. Defaults to None.
            device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'.
            ner_threshold (float, optional): Named Entity Recognition threshold. Defaults to 0.5.
            rel_threshold (float, optional): Relation extraction threshold. Defaults to 0.5.
            return_index (bool, optional): Whether to include character span indices in the
                output. Defaults to False.
            prompt (str, optional): Template prompt for relation extraction.
        """
        # Use the provided prompt or default to the class-level prompt
        prompt = prompt if prompt is not None else self.prompt
        self.return_index = return_index
        # Store the thresholds so __call__ can fall back to them when the caller
        # does not override them per invocation.
        self.ner_threshold = ner_threshold
        self.rel_threshold = rel_threshold
        super().__init__(model_id=model_id, model=model, prompt=prompt, device=device)

    def prepare_texts(self, texts: List[str], **kwargs):
        """
        Prepends the relation-extraction prompt to each input text.

        Args:
            texts (list): List of input texts.

        Returns:
            list: List of formatted prompts.
        """
        return [f"{self.prompt} \n {text}" for text in texts]

    def prepare_source_relation(self, ner_predictions: List[dict], relations: List[str]):
        """
        Builds candidate labels of the form "<source entity> <> <relation>" for every
        unique entity found by the NER stage, one list of labels per input text.
        """
        relation_labels = []
        for prediction in ner_predictions:
            curr_labels = []
            unique_entities = {ent['text'] for ent in prediction}
            for relation in relations:
                for ent in unique_entities:
                    curr_labels.append(f"{ent} <> {relation}")
            relation_labels.append(curr_labels)
        return relation_labels

    def process_predictions(self, predictions, **kwargs):
        """
        Processes predictions to extract the highest-scoring relation(s).

        Args:
            predictions (list): List of predictions with scores.

        Returns:
            list: List of predicted labels for each input.
        """
        batch_predicted_relations = []

        for prediction in predictions:
            curr_relations = []

            for target in prediction:
                target_ent = target['text']
                score = target['score']
                # Labels have the form "<source entity> <> <relation>"; split on the
                # last "<>" so entity texts that themselves contain "<>" still parse.
                source, relation = target['label'].rsplit('<>', 1)
                relation = {
                    "source": source.strip(),
                    "relation": relation.strip(),
                    "target": target_ent.strip(),
                    "score": score
                }
                # Pass span indices through if requested.
                if self.return_index:
                    relation['start'] = target.get('start', None)
                    relation['end'] = target.get('end', None)
                curr_relations.append(relation)
            batch_predicted_relations.append(curr_relations)

        return batch_predicted_relations
    
    def __call__(self, texts: Union[str, List[str]], relations: Optional[List[str]] = None,
                 entities: List[str] = ['named entity'],
                 relation_labels: Optional[List[List[str]]] = None,
                 ner_threshold: Optional[float] = None,
                 rel_threshold: Optional[float] = None,
                 batch_size: int = 8, **kwargs):
        if isinstance(texts, str):
            texts = [texts]

        # Fall back to the thresholds supplied at construction time.
        if ner_threshold is None:
            ner_threshold = self.ner_threshold
        if rel_threshold is None:
            rel_threshold = self.rel_threshold
        
        prompts = self.prepare_texts(texts, **kwargs)

        if relation_labels is None:
            if relations is None:
                raise ValueError("Provide either `relations` or precomputed `relation_labels`.")
            # Stage 1: run NER to find candidate source entities.
            ner_predictions = self.model.run(texts, entities, threshold=ner_threshold, batch_size=batch_size)
            # Stage 2: build "<source entity> <> <relation>" candidate labels.
            relation_labels = self.prepare_source_relation(ner_predictions, relations)

        predictions = self.model.run(prompts, relation_labels, threshold=rel_threshold, batch_size=batch_size)

        results = self.process_predictions(predictions, **kwargs)

        return results

    def evaluate(self, dataset_id: Optional[str] = None, dataset: Optional[Dataset] = None,
                 labels: Optional[List[str]] = None, threshold: float = 0.5, max_examples: int = -1):
        """
        Evaluates the model on a specified dataset and computes evaluation metrics.

        Args:
            dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets).
            dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored.
            labels (list, optional): List of target labels to consider for relation extraction. Defaults to None (use all).
            threshold (float): Confidence threshold for predictions. Defaults to 0.5.
            max_examples (int): Maximum number of examples to evaluate. Defaults to -1 (use all available examples).

        Returns:
            dict: A dictionary containing evaluation metrics such as F1 scores.
        
        Raises:
            ValueError: If neither `dataset_id` nor `dataset` is provided.
        """
        raise NotImplementedError("The `evaluate` method is not implemented yet.")
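

# A minimal usage sketch, not part of the pipeline itself. The checkpoint name,
# example text, and relation types below are illustrative assumptions; swap in
# whatever GLiNER model and data you actually use.
if __name__ == "__main__":
    extractor = CustomGLiNERRelationExtractor(
        model_id="knowledgator/gliner-multitask-large-v0.5",  # example checkpoint (assumption)
        device="cpu",
        return_index=True,
    )

    texts = ["Satya Nadella is the CEO of Microsoft, headquartered in Redmond."]
    relations = ["works for", "headquartered in"]

    # Stage 1 (NER) proposes source entities; stage 2 scores
    # "<source entity> <> <relation>" candidate labels against each prompt.
    results = extractor(texts, relations=relations, ner_threshold=0.5, rel_threshold=0.5)

    for text_relations in results:
        for rel in text_relations:
            print(f"{rel['source']} --[{rel['relation']}]--> {rel['target']} "
                  f"(score={rel['score']:.2f}, span=({rel['start']}, {rel['end']}))")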