import os
import re
import json
from collections import defaultdict
from typing import Any, Dict, List, Tuple

import gradio as gr

# Load environment variable for cache dir (useful on Spaces)
_CACHE_DIR = os.environ.get("CACHE_DIR", None)

# Import GLiNER model and relation extractor
from gliner import GLiNER
from gliner.multitask import GLiNERRelationExtractor

# Cache and initialize model + relation extractor
DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v3"
model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)
relation_extractor = GLiNERRelationExtractor(model=model)

# Sample text
SAMPLE_TEXT = (
    "In early 2012, the World Bank published the full report of the 2011 Demographic and Health Survey (DHS) "
    "for the Republic of Mali. Conducted between June and December 2011 under the technical oversight of Mali’s "
    "National Institute of Statistics and paired with on-the-ground data-collection teams, this nationally representative survey "
    "gathered detailed information on household composition, education levels, employment and income, fertility and family planning, "
    "maternal and child health, nutrition, mortality, and access to basic services. By combining traditional census modules with "
    "specialized questionnaires on women’s and children’s health, the DHS offers policymakers, development partners, and researchers "
    "a rich dataset of socioeconomic characteristics—ranging from literacy and school attendance to water and sanitation infrastructure—"
    "that can be used to monitor progress on poverty reduction, inform targeted social programs, and guide longer-term economic planning."
)

# Entity labels and the relation types queried for each entity type
labels = ['named dataset', 'unnamed dataset', 'vague dataset']
rels = [
    'acronym', 'author', 'data description',
    'data geography', 'data source', 'data type',
    'publication year', 'publisher', 'reference year', 'version',
]
TYPE2RELS = {
    "named dataset": rels,
    "unnamed dataset": rels,
    "vague dataset": rels,
}

def inference_pipeline(
    text: str,
    model,
    labels: List[str],
    relation_extractor: GLiNERRelationExtractor,
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.5,
    re_threshold: float = 0.4,
    re_multi_label: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
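    """Run NER over `text`, then query relations for each detected entity.

    For every entity span, the relation extractor is prompted with one slot
    label per relation type allowed for that entity's label (per TYPE2RELS).
    Returns the raw NER predictions and a mapping from each entity span's
    text to its predicted relations.
    """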
    ner_preds = model.predict_entities(
        text,
        labels,
        flat_ner=True,
        threshold=ner_threshold,
    )

    re_results: Dict[str, List[Dict[str, Any]]] = {}
    for ner in ner_preds:
        span = ner['text']
        rel_types = TYPE2RELS.get(ner['label'], [])
        if not rel_types:
            continue
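        # Build one slot prompt per candidate relation, in the
        # "<entity text> <> <relation>" format used throughout this app,
        # e.g. "DHS <> publisher".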
        slot_labels = [f"{span} <> {r}" for r in rel_types]
        preds = relation_extractor(
            text,
            relations=None,
            entities=None,
            relation_labels=slot_labels,
            threshold=re_threshold,
            multi_label=re_multi_label,
            distance_threshold=100,
        )[0]
        re_results[span] = preds

    return ner_preds, re_results
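
# A minimal sanity check of the pipeline, left commented out so the Space
# does not run it at import time (the thresholds here are illustrative only):
#
#   ner, rels = inference_pipeline(
#       SAMPLE_TEXT, model, labels, relation_extractor, TYPE2RELS,
#       ner_threshold=0.7, re_threshold=0.5,
#   )
#   print(ner[:2], list(rels)[:2])
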
def prune_acronym_and_self_relations(ner_preds, rel_preds):
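    """Drop acronym duplicates and self-referential relations.

    An entity whose text appears as an `acronym` target (and is shorter than
    its source span) is treated as a duplicate of the spelled-out name and
    removed from both the NER and relation outputs.
    """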
    # 1) Find acronym targets strictly shorter than their source
    acronym_targets = {
        r["target"]
        for src, rels in rel_preds.items()
        for r in rels
        if r["relation"] == "acronym" and len(r["target"]) < len(src)
    }

    # 2) Filter NER: drop any named dataset whose text is in that set
    filtered_ner = [
        ent for ent in ner_preds
        if not (ent["label"] == "named dataset" and ent["text"] in acronym_targets)
    ]

    # 3) Filter RE: drop blocks for acronym sources, and self-relations
    filtered_re = {}
    for src, rels in rel_preds.items():
        if src in acronym_targets:
            continue
        kept = [r for r in rels if r["target"] != src]
        if kept:
            filtered_re[src] = kept

    return filtered_ner, filtered_re
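
# With the sample text, this step would, for example, drop a separate "DHS"
# entity whenever the model links it to "Demographic and Health Survey" via
# the `acronym` relation.
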
# Highlighting function
def highlight_text(text, ner_threshold, re_threshold):
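    """Annotate `text` for gr.HighlightedText.

    Entity spans keep their NER label; relation targets are labelled
    "<source> <> <relation>". Multiple labels on the same (start, end)
    span are merged into one comma-separated entity.
    """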
    # Run inference
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        re_threshold=re_threshold,
        re_multi_label=False,
    )

    # Post-process
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)

    # Gather all spans: NER spans plus every occurrence of each relation target
    spans = []
    for ent in ner_preds:
        spans.append((ent["start"], ent["end"], ent["label"]))
    for src, rels in rel_preds.items():
        for r in rels:
            for m in re.finditer(re.escape(r["target"]), text):
                spans.append((m.start(), m.end(), f"{src} <> {r['relation']}"))

    # Merge labels that share the same span
    merged = defaultdict(list)
    for start, end, lbl in spans:
        merged[(start, end)].append(lbl)

    # Build Gradio entities
    entities = []
    for (start, end), lbls in sorted(merged.items(), key=lambda x: x[0]):
        entities.append({
            "entity": ", ".join(lbls),
            "start": start,
            "end": end,
        })

    return {"text": text, "entities": entities}

# JSON output function
def get_model_predictions(text, ner_threshold, re_threshold):
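    """Run the full pipeline and return the pruned predictions as pretty-printed JSON."""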
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        re_threshold=re_threshold,
        re_multi_label=False,
    )
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
    return json.dumps({"ner": ner_preds, "relations": rel_preds}, indent=2)

# Build Gradio UI
demo = gr.Blocks()
with demo:
    gr.Markdown(
        "## Data Use Detector\n"
        "Adjust the sliders below to set thresholds, then:\n"
        "- **Submit** to highlight entities.\n"
        "- **Get Model Predictions** to see the raw JSON output."
    )
    txt_in = gr.Textbox(
        label="Input Text",
        lines=4,
        value=SAMPLE_TEXT,
    )
    ner_slider = gr.Slider(
        0, 1, value=0.7, step=0.01,
        label="NER Threshold",
        info="Minimum confidence for named-entity spans.",
    )
    re_slider = gr.Slider(
        0, 1, value=0.5, step=0.01,
        label="RE Threshold",
        info="Minimum confidence for relation extractions.",
    )
    highlight_btn = gr.Button("Submit")
    txt_out = gr.HighlightedText(label="Annotated Entities")
    get_pred_btn = gr.Button("Get Model Predictions")
    ner_rel_box = gr.Textbox(label="Model Predictions (JSON)", lines=15)

    # Wire up interactions
    highlight_btn.click(
        fn=highlight_text,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=txt_out,
    )
    get_pred_btn.click(
        fn=get_model_predictions,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=ner_rel_box,
    )

# Enable queue for concurrency
demo.queue(default_concurrency_limit=5)

# Launch the app
demo.launch(debug=True)