# Hugging Face Space app (app.py): GLiNER-based data-use detector UI.
# (Removed non-code page residue: Space status lines, git revision hashes,
# and a copied line-number gutter that made the file unparseable.)
import os
import re
import json
from collections import defaultdict
import gradio as gr
from typing import List, Dict, Any, Tuple
# Load environment variable for cache dir (useful on Spaces)
_CACHE_DIR = os.environ.get("CACHE_DIR", None)  # None -> library's default cache location
# Import GLiNER model and relation extractor
from gliner import GLiNER
from gliner.multitask import GLiNERRelationExtractor
# Cache and initialize model + relation extractor
# NOTE: loading happens at import time; on a Spaces cold start this downloads
# the checkpoint once into _CACHE_DIR (or the default HF cache).
DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v3"
model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)
relation_extractor = GLiNERRelationExtractor(model=model)
# Sample text
# Default text shown in the UI textbox; a paragraph known to contain a named
# dataset ("Demographic and Health Survey") plus several extractable relations.
SAMPLE_TEXT = (
    "In early 2012, the World Bank published the full report of the 2011 Demographic and Health Survey (DHS) "
    "for the Republic of Mali. Conducted between June and December 2011 under the technical oversight of Mali’s "
    "National Institute of Statistics and paired with on-the-ground data-collection teams, this nationally representative survey "
    "gathered detailed information on household composition, education levels, employment and income, fertility and family planning, "
    "maternal and child health, nutrition, mortality, and access to basic services. By combining traditional census modules with "
    "specialized questionnaires on women’s and children’s health, the DHS offers policymakers, development partners, and researchers "
    "a rich dataset of socioeconomic characteristics—ranging from literacy and school attendance to water and sanitation infrastructure—"
    "that can be used to monitor progress on poverty reduction, inform targeted social programs, and guide longer-term economic planning."
)
# Post-processing: prune acronyms and self-relations
# Entity labels the NER stage will look for.
labels = ['named dataset', 'unnamed dataset', 'vague dataset']
# Relation types queried for each entity span.
rels = ['acronym', 'author', 'data description',\
        'data geography', 'data source', 'data type',\
        'publication year', 'publisher', 'reference year', 'version']
# Maps each entity label to the relation types to extract for it.
# All three dataset labels currently share the same relation set.
TYPE2RELS = {
    "named dataset": rels,
    "unnamed dataset": rels,
    "vague dataset": rels,
}
def inference_pipeline(
    text: str,
    model,
    labels: List[str],
    relation_extractor: GLiNERRelationExtractor,
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.5,
    re_threshold: float = 0.4,
    re_multi_label: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Run two-stage extraction over *text*: NER first, then relations.

    Returns a tuple ``(entity_predictions, relations_by_span)`` where the
    second element maps each entity's surface text to its relation hits.
    """
    # Stage 1: flat (non-overlapping) entity recognition.
    entities = model.predict_entities(
        text,
        labels,
        flat_ner=True,
        threshold=ner_threshold,
    )

    # Stage 2: for each entity, query only the relation types registered
    # for its label; labels with no registered relations are skipped.
    relations_by_span: Dict[str, List[Dict[str, Any]]] = {}
    for entity in entities:
        candidate_rels = TYPE2RELS.get(entity["label"], [])
        if not candidate_rels:
            continue
        span_text = entity["text"]
        # The relation extractor expects "<head> <> <relation>" prompt labels.
        prompts = [f"{span_text} <> {rel}" for rel in candidate_rels]
        batch = relation_extractor(
            text,
            relations=None,
            entities=None,
            relation_labels=prompts,
            threshold=re_threshold,
            multi_label=re_multi_label,
            distance_threshold=100,  # max char distance between head and tail
        )
        relations_by_span[span_text] = batch[0]
    return entities, relations_by_span
def prune_acronym_and_self_relations(ner_preds, rel_preds):
    """Drop acronym-duplicate entities and trivial self-relations.

    An entity is considered an acronym duplicate when some source span has an
    ``acronym`` relation pointing at it and the target is strictly shorter
    than the source. Such targets are removed from the NER list (named
    datasets only) and their relation blocks are discarded; relations whose
    target equals their own source are dropped everywhere.
    """
    # Collect acronym targets that are strictly shorter than their source span.
    acronym_targets = set()
    for src, rel_list in rel_preds.items():
        for rel in rel_list:
            if rel["relation"] == "acronym" and len(rel["target"]) < len(src):
                acronym_targets.add(rel["target"])

    # Keep every entity except named datasets that are just an acronym alias.
    filtered_ner = []
    for ent in ner_preds:
        is_acronym_dup = ent["label"] == "named dataset" and ent["text"] in acronym_targets
        if not is_acronym_dup:
            filtered_ner.append(ent)

    # Rebuild the relation map: skip acronym sources, drop self-references,
    # and omit sources left with no relations at all.
    filtered_re = {}
    for src, rel_list in rel_preds.items():
        if src in acronym_targets:
            continue
        non_self = [rel for rel in rel_list if rel["target"] != src]
        if non_self:
            filtered_re[src] = non_self

    return filtered_ner, filtered_re
# Highlighting function
def highlight_text(text, ner_threshold, re_threshold):
    """Annotate *text* for ``gr.HighlightedText``.

    Runs the full NER + RE pipeline, prunes acronym duplicates and
    self-relations, then returns the ``{"text": ..., "entities": [...]}``
    structure Gradio expects.
    """
    # Full pipeline with UI-supplied thresholds.
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        re_threshold=re_threshold,
        re_multi_label=False,
    )
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)

    # Gather labels per (start, end) span: entity spans come with offsets;
    # relation targets are located by literal search, tagging every occurrence.
    span_labels = defaultdict(list)
    for ent in ner_preds:
        span_labels[(ent["start"], ent["end"])].append(ent["label"])
    for src, rel_list in rel_preds.items():
        for rel in rel_list:
            for match in re.finditer(re.escape(rel["target"]), text):
                span_labels[(match.start(), match.end())].append(
                    f"{src} <> {rel['relation']}"
                )

    # One Gradio entity per span; co-located labels are joined with ", ".
    entities = [
        {"entity": ", ".join(lbls), "start": start, "end": end}
        for (start, end), lbls in sorted(span_labels.items(), key=lambda item: item[0])
    ]
    return {"text": text, "entities": entities}
# JSON output function
def get_model_predictions(text, ner_threshold, re_threshold):
    """Return the pruned NER and RE predictions as pretty-printed JSON."""
    entities, relations = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        re_threshold=re_threshold,
        re_multi_label=False,
    )
    entities, relations = prune_acronym_and_self_relations(entities, relations)
    payload = {"ner": entities, "relations": relations}
    return json.dumps(payload, indent=2)
# Build Gradio UI
demo = gr.Blocks()
with demo:
    gr.Markdown("## Data Use Detector\n"
                "Adjust the sliders below to set thresholds, then:\n"
                "- **Submit** to highlight entities.\n"
                "- **Get Model Predictions** to see the raw JSON output.")

    # Inputs: free text plus one confidence slider per pipeline stage.
    txt_in = gr.Textbox(
        label="Input Text",
        lines=4,
        value=SAMPLE_TEXT
    )
    ner_slider = gr.Slider(
        0, 1, value=0.7, step=0.01,
        label="NER Threshold",
        info="Minimum confidence for named-entity spans."
    )
    re_slider = gr.Slider(
        0, 1, value=0.5, step=0.01,
        label="RE Threshold",
        info="Minimum confidence for relation extractions."
    )

    # Outputs: highlighted annotation view and a raw JSON dump.
    highlight_btn = gr.Button("Submit")
    txt_out = gr.HighlightedText(label="Annotated Entities")
    get_pred_btn = gr.Button("Get Model Predictions")
    ner_rel_box = gr.Textbox(label="Model Predictions (JSON)", lines=15)

    # Wire up interactions
    highlight_btn.click(
        fn=highlight_text,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=txt_out
    )
    get_pred_btn.click(
        fn=get_model_predictions,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=ner_rel_box
    )

# Enable queue for concurrency
demo.queue(default_concurrency_limit=5)

# Launch the app
# FIX: removed the stray trailing " |" after launch() that made the file a
# syntax error (likely copy/paste residue from the Spaces file viewer).
demo.launch(debug=True)