Commit 2ae65ac
Parent(s): 48711e5

add caching of preds

Files changed:
- app.py +13 -26
- relation_extraction.py +147 -0
- requirements.txt +2 -1
app.py
CHANGED

@@ -9,12 +9,12 @@ _CACHE_DIR = os.environ.get("CACHE_DIR", None)
 
 # Import GLiNER model and relation extractor
 from gliner import GLiNER
-from
+#from relation_extraction import CustomGLiNERRelationExtractor
 
 # Cache and initialize model + relation extractor
 DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v3"
 model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)
-relation_extractor =
+relation_extractor = CustomGLiNERRelationExtractor(model=model, return_index=True)
 
 # Sample text
 SAMPLE_TEXT = (
@@ -45,11 +45,12 @@ def inference_pipeline(
     text: str,
     model,
     labels: List[str],
-    relation_extractor:
+    relation_extractor: CustomGLiNERRelationExtractor,
     TYPE2RELS: Dict[str, List[str]],
-    ner_threshold: float = 0.
-
+    ner_threshold: float = 0.7,
+    rel_threshold: float = 0.5,
     re_multi_label: bool = False,
+    return_index: bool = False,
 ) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
     ner_preds = model.predict_entities(
         text,
@@ -72,9 +73,9 @@ def inference_pipeline(
             relations=None,
             entities=None,
             relation_labels=slot_labels,
-            threshold=
+            threshold=rel_threshold,
             multi_label=re_multi_label,
-
+            return_index=return_index,
         )[0]
 
         re_results[span] = preds
@@ -109,7 +110,7 @@ def prune_acronym_and_self_relations(ner_preds, rel_preds):
 
 # Highlighting function
 
-def highlight_text(text, ner_threshold, re_threshold):
+def highlight_text(text, ner_threshold, rel_threshold):
     # Run inference
     ner_preds, rel_preds = inference_pipeline(
         text,
@@ -118,8 +119,9 @@ def highlight_text(text, ner_threshold, re_threshold):
         relation_extractor=relation_extractor,
         TYPE2RELS=TYPE2RELS,
         ner_threshold=ner_threshold,
-
-        re_multi_label=False
+        rel_threshold=rel_threshold,
+        re_multi_label=False,
+        return_index=True,
     )
 
     # Post-process
@@ -150,21 +152,6 @@ def highlight_text(text, ner_threshold, re_threshold):
     return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
 
 # JSON output function
-
-def get_model_predictions(text, ner_threshold, re_threshold):
-    ner_preds, rel_preds = inference_pipeline(
-        text,
-        model=model,
-        labels=labels,
-        relation_extractor=relation_extractor,
-        TYPE2RELS=TYPE2RELS,
-        ner_threshold=ner_threshold,
-        re_threshold=re_threshold,
-        re_multi_label=False
-    )
-    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
-    return json.dumps({"ner": ner_preds, "relations": rel_preds}, indent=2)
-
 def _cached_predictions(state):
     if not state:
         return "📋 No predictions yet. Click **Submit** first."
@@ -216,4 +203,4 @@ with gr.Blocks() as demo:
 
 # Launch the app
 
-demo.launch(debug=True)
+demo.launch(debug=True, inline=True)
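The caching this commit adds follows the usual Gradio session-state pattern: Submit runs the pipeline once, stores the raw predictions, and the JSON button just serializes what is already in state instead of re-running the model (hence the deletion of `get_model_predictions` above). Below is a minimal sketch of that wiring, assuming standard `gr.State` plumbing; the widget names and layout are illustrative, not the Space's exact UI.

import json
import gradio as gr

def _cached_predictions(state):
    # state holds {"ner": ..., "relations": ...} from the last Submit
    if not state:
        return "📋 No predictions yet. Click **Submit** first."
    return json.dumps(state, indent=2)

with gr.Blocks() as demo:
    text_in = gr.Textbox(label="Input text")
    ner_slider = gr.Slider(0.0, 1.0, value=0.7, label="NER threshold")
    rel_slider = gr.Slider(0.0, 1.0, value=0.5, label="Relation threshold")
    highlighted = gr.HighlightedText(label="Entities")
    preds_state = gr.State()  # cache for the raw predictions
    json_out = gr.Textbox(label="Raw predictions")

    # highlight_text (see app.py above) returns the highlight payload plus the
    # raw predictions; the second output lands in preds_state.
    gr.Button("Submit").click(
        highlight_text,
        inputs=[text_in, ner_slider, rel_slider],
        outputs=[highlighted, preds_state],
    )
    # Reading from state makes this button instant: no second model call.
    gr.Button("Get model predictions").click(
        _cached_predictions, inputs=preds_state, outputs=json_out
    )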
relation_extraction.py
ADDED

@@ -0,0 +1,147 @@
+from gliner import GLiNER
+from gliner.multitask.base import GLiNERBasePipeline
+from typing import Optional, List, Union
+from datasets import load_dataset, Dataset
+
+
+class CustomGLiNERRelationExtractor(GLiNERBasePipeline):
+    """
+    A class to use GLiNER for relation extraction inference and evaluation.
+
+    Attributes:
+        device (str): Device to run the model on, e.g., 'cuda:0' or 'cpu'.
+        model (GLiNER): Loaded GLiNER model instance.
+        prompt (str): Template prompt for relation extraction.
+
+    Methods:
+        process_predictions(predictions):
+            Processes model predictions to extract the most likely labels.
+        prepare_texts(texts, labels):
+            Creates relation extraction prompts for each input text.
+        __call__(texts, labels, threshold=0.5):
+            Runs the model on the given texts and returns predicted labels.
+        evaluate(dataset_id, labels=None, threshold=0.5, max_examples=-1):
+            Evaluates the model on a dataset and computes F1 scores.
+    """
+
+    prompt = "Extract relationships between entities from the text: "
+
+    def __init__(self, model_id: str = None, model: GLiNER = None, device: str = 'cuda:0',
+                 ner_threshold: float = 0.5, rel_threshold: float = 0.5,
+                 return_index: bool = False, prompt: Optional[str] = None):
+        """
+        Initializes the CustomGLiNERRelationExtractor.
+
+        Args:
+            model_id (str, optional): Identifier for the model to be loaded. Defaults to None.
+            model (GLiNER, optional): Preloaded GLiNER model. Defaults to None.
+            device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'.
+            ner_threshold (float, optional): Named entity recognition threshold. Defaults to 0.5.
+            rel_threshold (float, optional): Relation extraction threshold. Defaults to 0.5.
+            return_index (bool, optional): Whether to attach span offsets to predictions. Defaults to False.
+            prompt (str, optional): Template prompt for relation extraction.
+        """
+        # Use the provided prompt or default to the class-level prompt
+        prompt = prompt if prompt is not None else self.prompt
+        super().__init__(model_id=model_id, model=model, prompt=prompt, device=device)
+        # Stored here because process_predictions reads it later
+        self.return_index = return_index
+
+    def prepare_texts(self, texts: List[str], **kwargs):
+        """
+        Prepends the relation-extraction prompt to each input text.
+
+        Args:
+            texts (list): List of input texts.
+
+        Returns:
+            list: List of formatted prompts.
+        """
+        prompts = []
+        for text in texts:
+            prompts.append(f"{self.prompt} \n {text}")
+        return prompts
+
+    def prepare_source_relation(self, ner_predictions: List[dict], relations: List[str]):
+        # Build one "source <> relation" candidate label per (entity, relation) pair
+        relation_labels = []
+        for prediction in ner_predictions:
+            curr_labels = []
+            unique_entities = {ent['text'] for ent in prediction}
+            for relation in relations:
+                for ent in unique_entities:
+                    curr_labels.append(f"{ent} <> {relation}")
+            relation_labels.append(curr_labels)
+        return relation_labels
+
+    def process_predictions(self, predictions, **kwargs):
+        """
+        Processes predictions to extract the highest-scoring relation(s).
+
+        Args:
+            predictions (list): List of predictions with scores.
+
+        Returns:
+            list: List of predicted labels for each input.
+        """
+        batch_predicted_relations = []
+
+        for prediction in predictions:
+            curr_relations = []
+            for target in prediction:
+                target_ent = target['text']
+                score = target['score']
+                source, relation = target['label'].split('<>')
+                relation = {
+                    "source": source.strip(),
+                    "relation": relation.strip(),
+                    "target": target_ent.strip(),
+                    "score": score,
+                }
+                # Pull through span info if present
+                if self.return_index:
+                    relation['start'] = target.get('start', None)
+                    relation['end'] = target.get('end', None)
+                curr_relations.append(relation)
+            batch_predicted_relations.append(curr_relations)
+
+        return batch_predicted_relations
+
+    def __call__(self, texts: Union[str, List[str]], relations: List[str] = None,
+                 entities: List[str] = ['named entity'],
+                 relation_labels: Optional[List[List[str]]] = None,
+                 ner_threshold: float = 0.5,
+                 rel_threshold: float = 0.5,
+                 batch_size: int = 8, **kwargs):
+        if isinstance(texts, str):
+            texts = [texts]
+
+        prompts = self.prepare_texts(texts, **kwargs)
+
+        if relation_labels is None:
+            # NER pass to find candidate source entities
+            ner_predictions = self.model.run(texts, entities, threshold=ner_threshold, batch_size=batch_size)
+            # Relation-extraction labels built from the entities found
+            relation_labels = self.prepare_source_relation(ner_predictions, relations)
+
+        predictions = self.model.run(prompts, relation_labels, threshold=rel_threshold, batch_size=batch_size)
+        results = self.process_predictions(predictions, **kwargs)
+        return results
+
+    def evaluate(self, dataset_id: Optional[str] = None, dataset: Optional[Dataset] = None,
+                 labels: Optional[List[str]] = None, threshold: float = 0.5, max_examples: int = -1):
+        """
+        Evaluates the model on a specified dataset and computes evaluation metrics.
+
+        Args:
+            dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets).
+            dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored.
+            labels (list, optional): List of target labels to consider for relation extraction. Defaults to None (use all).
+            threshold (float): Confidence threshold for predictions. Defaults to 0.5.
+            max_examples (int): Maximum number of examples to evaluate. Defaults to -1 (use all available examples).
+
+        Returns:
+            dict: A dictionary containing evaluation metrics such as F1 scores.
+
+        Raises:
+            ValueError: If neither `dataset_id` nor `dataset` is provided.
+        """
+        raise NotImplementedError("Currently the `evaluate` method is not implemented.")
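For reference, here is a hedged usage sketch of the extractor defined above, outside the Gradio app. The model ID is the one the Space loads; the sample sentence and relation names are purely illustrative.

from gliner import GLiNER
from relation_extraction import CustomGLiNERRelationExtractor

model = GLiNER.from_pretrained("rafmacalaba/gliner_re_finetuned-v3")
extractor = CustomGLiNERRelationExtractor(model=model, device="cpu", return_index=True)

text = "The study draws on the 2021 Global Findex survey published by the World Bank."

# With no relation_labels given, __call__ takes the two-stage path: an NER pass
# finds candidate source entities, then every "entity <> relation" slot is
# scored against the prompted text.
preds = extractor(
    text,
    relations=["published by", "reference year"],
    entities=["named entity"],
    ner_threshold=0.5,
    rel_threshold=0.5,
)[0]  # one result list per input text

for rel in preds:
    # with return_index=True each rel also carries 'start'/'end' offsets
    print(f"{rel['source']} --{rel['relation']}--> {rel['target']} ({rel['score']:.2f})")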
requirements.txt
CHANGED

@@ -2,4 +2,5 @@ gradio
 gliner
 torch
 scipy
-scikit-learn
+scikit-learn
+datasets