Commit a508c57
1 Parent(s): 0fd2cd1
Commit message: simulated

Files changed:
- app.py +50 -102
- requirements.py +2 -4
app.py
CHANGED
@@ -1,118 +1,66 @@
-import os
-import json
 import gradio as gr
-
-import spaces
-from gliner import GLiNER
-from gliner.multitask import GLiNERRelationExtractor
-from typing import List, Dict, Any, Tuple
+from typing import Union, Dict, Any, List

-#
-
-
-
-
-
+# Sample pre-calculated entities
+sample_text = (
+    "Recent studies on ocean currents from the Global Ocean Temperature Dataset "
+    "(GOTD) indicate significant shifts in marine biodiversity."
+)
+sample_entities = [
+    {"label": "named dataset", "text": "Global Ocean Temperature Dataset", "start": 29, "end": 62, "score": 0.99},
+    {"label": "acronym", "text": "GOTD", "start": 64, "end": 68, "score": 0.98},
+]
+rels = [
     'acronym', 'author', 'data description',
     'data geography', 'data source', 'data type',
     'publication year', 'publisher', 'reference year', 'version'
 ]
+MODELS = ["demo-model-1", "demo-model-2"]

-#
-TYPE2RELS = {
-    "named dataset": trels,
-    "unnamed dataset": trels,
-    "vague dataset": trels,
-}
-
-# Load models
-print("Loading NER+RE model...")
-model = GLiNER.from_pretrained(data_model_id, cache_dir=CACHE_DIR)
-relation_extractor = GLiNERRelationExtractor(model=model)
-# if torch.cuda.is_available():
-#     model.to("cuda")
-#     relation_extractor.model.to("cuda")
-# print("Models loaded.")
-
-# Inference pipeline
-def inference_pipeline(
-    text: str,
-    model,
-    labels: List[str],
-    relation_extractor: GLiNERRelationExtractor,
-    TYPE2RELS: Dict[str, List[str]],
-    ner_threshold: float = 0.5,
-    re_threshold: float = 0.4,
-    re_multi_label: bool = False,
-) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
-    # NER predictions
-    ner_preds = model.predict_entities(
-        text,
-        labels,
-        flat_ner=True,
-        threshold=ner_threshold
-    )
-
-    # Relation extraction per entity span
-    re_results: Dict[str, List[Dict[str, Any]]] = {}
-    for ner in ner_preds:
-        span = ner['text']
-        rel_types = TYPE2RELS.get(ner['label'], [])
-        if not rel_types:
-            continue
-        slot_labels = [f"{span} <> {r}" for r in rel_types]
-        preds = relation_extractor(
-            text,
-            relations=None,
-            entities=None,
-            relation_labels=slot_labels,
-            threshold=re_threshold,
-            multi_label=re_multi_label,
-            distance_threshold=100,
-        )[0]
-        re_results[span] = preds
+# Annotate_query simulation

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        "relations": re_results,
+def annotate_query(
+    query: str,
+    labels: Union[str, List[str]],
+    threshold: float = 0.3,
+    nested_ner: bool = False,
+    model_name: str = None
+) -> Dict[str, Any]:
+    # In a real app, you'd call parse_query/inference_pipeline here.
+    # For simulation, reuse sample_entities.
+    return {
+        "text": query,
+        "entities": [
+            {"start": ent["start"], "end": ent["end"], "label": ent["label"]}
+            for ent in sample_entities
+        ]
     }
-    return json.dumps(output, indent=2)

-
+# Build Gradio UI
+demo = gr.Blocks()
+with demo:
     gr.Markdown(
         """
-        ## Step
-        Enter text and click
+        ## Step: Annotate Query Simulation
+        Enter text (prepopulated) and click **Annotate** to see how entities are highlighted.
         """
     )
-
-
-
-
-    )
-
-
-
-
-
-
-
-
-
+    # Inputs
+    query = gr.Textbox(lines=3, value=sample_text, label="Input Text")
+    entities = gr.Textbox(value=", ".join(rels), label="Relations (unused in simulation)")
+    threshold = gr.Slider(0, 1, value=0.3, step=0.01, label="Threshold")
+    nested = gr.Checkbox(value=False, label="Nested NER")
+    model = gr.Radio(choices=MODELS, value=MODELS[0], label="Model")
+
+    # Outputs
+    output_hl = gr.HighlightedText(label="Annotated Entities")
+
+    # Button
+    annotate_btn = gr.Button("Annotate")
+    annotate_btn.click(
+        fn=annotate_query,
+        inputs=[query, entities, threshold, nested, model],
+        outputs=[output_hl]
     )

-
-demo.launch(debug=True)
+demo.launch(debug=True)
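A note on the new UI wiring: gr.HighlightedText accepts either a list of (text, label) tuples or a dict with "text" and "entities" keys, and in the dict form current Gradio releases generally expect each entity to carry an "entity" key (mirroring the Hugging Face token-classification pipeline output) rather than the "label" key that annotate_query returns above. If the highlighted spans do not render as expected, a small adapter along these lines should help; to_highlight is a hypothetical helper sketched here, not part of the commit:

from typing import Any, Dict

def to_highlight(result: Dict[str, Any]) -> Dict[str, Any]:
    # Rename "label" -> "entity" so gr.HighlightedText's dict input recognizes each span.
    return {
        "text": result["text"],
        "entities": [
            {"entity": e["label"], "start": e["start"], "end": e["end"]}
            for e in result["entities"]
        ],
    }

# Possible wiring, replacing fn=annotate_query in the click handler above:
#   annotate_btn.click(
#       fn=lambda q, ents, th, nn, m: to_highlight(annotate_query(q, ents, th, nn, m)),
#       inputs=[query, entities, threshold, nested, model],
#       outputs=[output_hl],
#   )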
requirements.py
CHANGED
@@ -1,4 +1,2 @@
-gradio
-
-torch # PyTorch (CPU build by default)
-huggingface-hub
+gradio
+spaces
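The trimmed requirements match the simulation-only app: just gradio plus spaces (even though the rewritten app.py no longer imports spaces). If the removed GLiNER pipeline were restored, the file would likely need the model stack back, roughly as below; this list is inferred from the removed imports and the old requirements, with versions left unpinned:

gradio
spaces
gliner
torch  # PyTorch (CPU build by default)
huggingface-hub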