rafmacalaba committed
Commit 2ae65ac · 1 Parent(s): 48711e5

add caching of preds

Files changed (3):
1. app.py +13 -26
2. relation_extraction.py +147 -0
3. requirements.txt +2 -1
app.py CHANGED
@@ -9,12 +9,12 @@ _CACHE_DIR = os.environ.get("CACHE_DIR", None)
 
 # Import GLiNER model and relation extractor
 from gliner import GLiNER
-from gliner.multitask import GLiNERRelationExtractor
+from relation_extraction import CustomGLiNERRelationExtractor
 
 # Cache and initialize model + relation extractor
 DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v3"
 model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)
-relation_extractor = GLiNERRelationExtractor(model=model)
+relation_extractor = CustomGLiNERRelationExtractor(model=model, return_index=True)
 
 # Sample text
 SAMPLE_TEXT = (
@@ -45,11 +45,12 @@ def inference_pipeline(
     text: str,
     model,
     labels: List[str],
-    relation_extractor: GLiNERRelationExtractor,
+    relation_extractor: CustomGLiNERRelationExtractor,
     TYPE2RELS: Dict[str, List[str]],
-    ner_threshold: float = 0.5,
-    re_threshold: float = 0.4,
+    ner_threshold: float = 0.7,
+    rel_threshold: float = 0.5,
     re_multi_label: bool = False,
+    return_index: bool = False,
 ) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
     ner_preds = model.predict_entities(
         text,
@@ -72,9 +73,9 @@ def inference_pipeline(
             relations=None,
             entities=None,
             relation_labels=slot_labels,
-            threshold=re_threshold,
+            threshold=rel_threshold,
             multi_label=re_multi_label,
-            distance_threshold=100,
+            return_index=return_index,
         )[0]
 
         re_results[span] = preds
@@ -109,7 +110,7 @@ def prune_acronym_and_self_relations(ner_preds, rel_preds):
 
 # Highlighting function
 
-def highlight_text(text, ner_threshold, re_threshold):
+def highlight_text(text, ner_threshold, rel_threshold):
     # Run inference
     ner_preds, rel_preds = inference_pipeline(
         text,
@@ -118,8 +119,9 @@ def highlight_text(text, ner_threshold, re_threshold):
         relation_extractor=relation_extractor,
         TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
-        re_threshold=re_threshold,
-        re_multi_label=False
+        rel_threshold=rel_threshold,
+        re_multi_label=False,
+        return_index=True,
     )
 
     # Post-process
@@ -150,21 +152,6 @@ def highlight_text(text, ner_threshold, re_threshold)
     return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
 
 # JSON output function
-
-def get_model_predictions(text, ner_threshold, re_threshold):
-    ner_preds, rel_preds = inference_pipeline(
-        text,
-        model=model,
-        labels=labels,
-        relation_extractor=relation_extractor,
-        TYPE2RELS=TYPE2RELS,
-        ner_threshold=ner_threshold,
-        re_threshold=re_threshold,
-        re_multi_label=False
-    )
-    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
-    return json.dumps({"ner": ner_preds, "relations": rel_preds}, indent=2)
-
 def _cached_predictions(state):
     if not state:
         return "📋 No predictions yet. Click **Submit** first."
@@ -216,4 +203,4 @@ with gr.Blocks() as demo:
 
 # Launch the app
 
-demo.launch(debug=True)
+demo.launch(debug=True, inline=True)
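
On the commit message: `highlight_text` now returns the raw predictions as a second output, and `_cached_predictions` serves the JSON view from that cached state, where the deleted `get_model_predictions` used to re-run the whole pipeline. The Blocks wiring itself is outside the visible hunks, so the following is only a minimal sketch of how such caching is typically wired in Gradio; every component name in it is hypothetical:

import gradio as gr

with gr.Blocks() as demo:
    inp = gr.Textbox(value=SAMPLE_TEXT, lines=8, label="Text")
    ner_slider = gr.Slider(0.0, 1.0, value=0.7, label="NER threshold")
    rel_slider = gr.Slider(0.0, 1.0, value=0.5, label="Relation threshold")
    out = gr.HighlightedText(label="Entities")
    state = gr.State()  # caches {"ner": ..., "relations": ...} from the last run
    json_out = gr.Markdown()

    # Submit runs inference once; the raw predictions land in `state`.
    gr.Button("Submit").click(highlight_text, [inp, ner_slider, rel_slider], [out, state])
    # The JSON view only reads the cache, so there is no second forward pass.
    gr.Button("Get Model Predictions").click(_cached_predictions, state, json_out)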
relation_extraction.py ADDED
@@ -0,0 +1,147 @@
+from gliner import GLiNER
+from gliner.multitask.base import GLiNERBasePipeline
+from typing import Optional, List, Union
+from datasets import load_dataset, Dataset
+
+class CustomGLiNERRelationExtractor(GLiNERBasePipeline):
+    """
+    A class to use GLiNER for relation extraction inference and evaluation.
+
+    Attributes:
+        device (str): Device to run the model on, e.g., 'cuda:0' or 'cpu'.
+        model (GLiNER): Loaded GLiNER model instance.
+        prompt (str): Template prompt for relation extraction.
+
+    Methods:
+        process_predictions(predictions):
+            Processes model predictions to extract the most likely labels.
+        prepare_texts(texts, labels):
+            Creates relation extraction prompts for each input text.
+        __call__(texts, labels, threshold=0.5):
+            Runs the model on the given texts and returns predicted labels.
+        evaluate(dataset_id, labels=None, threshold=0.5, max_examples=-1):
+            Evaluates the model on a dataset and computes F1 scores.
+    """
+
+    prompt = "Extract relationships between entities from the text: "
+
+    def __init__(self, model_id: str = None, model: GLiNER = None, device: str = 'cuda:0',
+                 ner_threshold: float = 0.5, rel_threshold: float = 0.5,
+                 return_index: bool = False, prompt: Optional[str] = None):
+        """
+        Initializes the CustomGLiNERRelationExtractor.
+
+        Args:
+            model_id (str, optional): Identifier for the model to be loaded. Defaults to None.
+            model (GLiNER, optional): Preloaded GLiNER model. Defaults to None.
+            device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'.
+            ner_threshold (float, optional): Named entity recognition threshold. Defaults to 0.5.
+            rel_threshold (float, optional): Relation extraction threshold. Defaults to 0.5.
+            return_index (bool, optional): Whether to attach character offsets to each relation. Defaults to False.
+            prompt (str, optional): Template prompt for relation extraction.
+        """
+        # Use the provided prompt or default to the class-level prompt
+        prompt = prompt if prompt is not None else self.prompt
+        self.return_index = return_index  # consumed by process_predictions
+        super().__init__(model_id=model_id, model=model, prompt=prompt, device=device)
+
+    def prepare_texts(self, texts: List[str], **kwargs):
+        """
+        Prepends the relation extraction prompt to each input text.
+
+        Args:
+            texts (list): List of input texts.
+
+        Returns:
+            list: List of formatted prompts.
+        """
+        prompts = []
+        for text in texts:
+            prompts.append(f"{self.prompt} \n {text}")
+        return prompts
+
+    def prepare_source_relation(self, ner_predictions: List[dict], relations: List[str]):
+        # Build one "source-entity <> relation" candidate label per (entity, relation)
+        # pair, e.g. "DHS <> publisher"; process_predictions later splits on "<>".
+        relation_labels = []
+        for prediction in ner_predictions:
+            curr_labels = []
+            unique_entities = {ent['text'] for ent in prediction}
+            for relation in relations:
+                for ent in unique_entities:
+                    curr_labels.append(f"{ent} <> {relation}")
+            relation_labels.append(curr_labels)
+        return relation_labels
+
+    def process_predictions(self, predictions, **kwargs):
+        """
+        Unpacks raw span predictions into structured relation records.
+
+        Args:
+            predictions (list): List of predictions with scores.
+
+        Returns:
+            list: List of predicted relations for each input.
+        """
+        batch_predicted_relations = []
+
+        for prediction in predictions:
+            curr_relations = []
+            for target in prediction:
+                target_ent = target['text']
+                score = target['score']
+                # The label was built as "source <> relation" by prepare_source_relation
+                source, relation = target['label'].split('<>')
+                relation = {
+                    "source": source.strip(),
+                    "relation": relation.strip(),
+                    "target": target_ent.strip(),
+                    "score": score,
+                }
+                # Pull through span info if present
+                if self.return_index:
+                    relation['start'] = target.get('start', None)
+                    relation['end'] = target.get('end', None)
+                curr_relations.append(relation)
+            batch_predicted_relations.append(curr_relations)
+
+        return batch_predicted_relations
+
+    def __call__(self, texts: Union[str, List[str]], relations: List[str] = None,
+                 entities: List[str] = ['named entity'],
+                 relation_labels: Optional[List[List[str]]] = None,
+                 ner_threshold: float = 0.5,
+                 rel_threshold: float = 0.5,
+                 batch_size: int = 8, **kwargs):
+        if isinstance(texts, str):
+            texts = [texts]
+
+        prompts = self.prepare_texts(texts, **kwargs)
+
+        if relation_labels is None:
+            # NER pass: find candidate source entities
+            ner_predictions = self.model.run(texts, entities, threshold=ner_threshold, batch_size=batch_size)
+            # RE pass: derive "source <> relation" labels from the entities found
+            relation_labels = self.prepare_source_relation(ner_predictions, relations)
+
+        predictions = self.model.run(prompts, relation_labels, threshold=rel_threshold, batch_size=batch_size)
+
+        results = self.process_predictions(predictions, **kwargs)
+
+        return results
+
+    def evaluate(self, dataset_id: Optional[str] = None, dataset: Optional[Dataset] = None,
+                 labels: Optional[List[str]] = None, threshold: float = 0.5, max_examples: int = -1):
+        """
+        Evaluates the model on a specified dataset and computes evaluation metrics.
+
+        Args:
+            dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets).
+            dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored.
+            labels (list, optional): List of target labels to consider for relation extraction. Defaults to None (use all).
+            threshold (float): Confidence threshold for predictions. Defaults to 0.5.
+            max_examples (int): Maximum number of examples to evaluate. Defaults to -1 (use all available examples).
+
+        Returns:
+            dict: A dictionary containing evaluation metrics such as F1 scores.
+
+        Raises:
+            ValueError: If neither `dataset_id` nor `dataset` is provided.
+        """
+        raise NotImplementedError("Currently the `evaluate` method is not implemented.")
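
As a quick standalone check of the new extractor (the model id matches app.py; the example text, relation labels, and printed output are illustrative only):

from gliner import GLiNER
from relation_extraction import CustomGLiNERRelationExtractor

model = GLiNER.from_pretrained("rafmacalaba/gliner_re_finetuned-v3")
extractor = CustomGLiNERRelationExtractor(model=model, return_index=True, device="cpu")

text = "The study draws on the 2018 DHS, conducted by the National Bureau of Statistics."
# Passing relation_labels directly skips the internal NER pass; labels must use the
# "source <> relation" format that process_predictions splits on.
preds = extractor(text, relation_labels=[["DHS <> publisher", "DHS <> reference year"]])
print(preds[0])
# e.g. [{'source': 'DHS', 'relation': 'publisher',
#        'target': 'National Bureau of Statistics', 'score': 0.91,
#        'start': 47, 'end': 77}]  (score and offsets illustrative)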
requirements.txt CHANGED
@@ -2,4 +2,5 @@ gradio
 gliner
 torch
 scipy
-scikit-learn
+scikit-learn
+datasets