# Data Use Detector — Gradio app.
# Author: rafmacalaba · last commit eb6e673 ("just") · 7.35 kB
import os
import re
import json
from collections import defaultdict
import gradio as gr
from typing import List, Dict, Any, Tuple
from gliner import GLiNER
from gliner.multitask import GLiNERRelationExtractor

# Optional model-weights cache directory taken from the environment (handy on
# Spaces). When CACHE_DIR is unset this is None and is forwarded unchanged.
_CACHE_DIR = os.environ.get("CACHE_DIR")

# Hub checkpoint used for both entity detection and relation extraction.
DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v3"
model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)
# The relation extractor wraps the already-loaded model instance.
relation_extractor = GLiNERRelationExtractor(model=model)
# Demo passage used as the default value of the input textbox. Adjacent
# string literals are concatenated by the parser into a single paragraph.
SAMPLE_TEXT = (
    "In early 2012, the World Bank published the full report of the 2011 Demographic and Health Survey (DHS) "
    "for the Republic of Mali. Conducted between June and December 2011 under the technical oversight of Mali’s "
    "National Institute of Statistics and paired with on-the-ground data-collection teams, this nationally representative survey "
    "gathered detailed information on household composition, education levels, employment and income, fertility and family planning, "
    "maternal and child health, nutrition, mortality, and access to basic services. By combining traditional census modules with "
    "specialized questionnaires on women’s and children’s health, the DHS offers policymakers, development partners, and researchers "
    "a rich dataset of socioeconomic characteristics—ranging from literacy and school attendance to water and sanitation infrastructure—"
    "that can be used to monitor progress on poverty reduction, inform targeted social programs, and guide longer-term economic planning."
)
# Entity labels requested from the NER step.
labels = ['named dataset', 'unnamed dataset', 'vague dataset']

# Relation names queried for each detected entity (as "<span> <> <relation>").
rels = [
    'acronym', 'author', 'data description',
    'data geography', 'data source', 'data type',
    'publication year', 'publisher', 'reference year', 'version',
]

# Entity label -> relation names to query. Every label currently maps to the
# same `rels` list object (not a copy).
TYPE2RELS = {entity_label: rels for entity_label in labels}
def inference_pipeline(
    text: str,
    model,
    labels: List[str],
    relation_extractor: "GLiNERRelationExtractor",
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.5,
    re_threshold: float = 0.4,
    re_multi_label: bool = False,
    distance_threshold: int = 100,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Detect entities in *text*, then query relations for each detected entity.

    Args:
        text: Passage to annotate.
        model: Object exposing ``predict_entities(text, labels, flat_ner=..., threshold=...)``
            (a GLiNER model in this app).
        labels: Entity labels requested from the NER step.
        relation_extractor: Callable relation extractor (a ``GLiNERRelationExtractor``
            in this app). It is queried with slot labels of the form
            ``"<span> <> <relation>"``, and element ``[0]`` of its result is kept.
        TYPE2RELS: Entity label -> relation names to query. Entities whose label
            is missing or maps to an empty list are skipped.
        ner_threshold: Minimum NER confidence, forwarded as ``threshold``.
        re_threshold: Minimum relation confidence, forwarded as ``threshold``.
        re_multi_label: Forwarded to the extractor as ``multi_label``.
        distance_threshold: Forwarded to the extractor (previously hard-coded to 100).

    Returns:
        ``(ner_preds, re_results)``. ``ner_preds`` is the NER output as returned by
        the model. ``re_results`` maps each entity's surface text to its relation
        predictions. If the same surface text is detected more than once, the
        last occurrence's entry is kept.
    """
    # Nothing to annotate: skip both model calls.
    if not text or not text.strip():
        return [], {}

    ner_preds = model.predict_entities(
        text,
        labels,
        flat_ner=True,
        threshold=ner_threshold,
    )

    re_results: Dict[str, List[Dict[str, Any]]] = {}
    # Repeated mentions of one span with the same relation set produce identical
    # extractor queries, so each distinct query is sent only once.
    seen: Dict[Tuple[str, Tuple[str, ...]], List[Dict[str, Any]]] = {}
    for ner in ner_preds:
        span = ner['text']
        rel_types = TYPE2RELS.get(ner['label'], [])
        if not rel_types:
            continue
        key = (span, tuple(rel_types))
        if key not in seen:
            slot_labels = [f"{span} <> {r}" for r in rel_types]
            seen[key] = relation_extractor(
                text,
                relations=None,
                entities=None,
                relation_labels=slot_labels,
                threshold=re_threshold,
                multi_label=re_multi_label,
                distance_threshold=distance_threshold,
            )[0]
        re_results[span] = seen[key]
    return ner_preds, re_results
def prune_acronym_and_self_relations(ner_preds, rel_preds):
    """Drop acronym duplicates and self-referencing relations.

    An "acronym target" is the target of an ``"acronym"`` relation that is
    strictly shorter than its source span.

    Args:
        ner_preds: List of entity dicts with at least ``"label"`` and ``"text"``.
        rel_preds: Mapping of source span text -> list of relation dicts with
            ``"relation"`` and ``"target"``.

    Returns:
        ``(filtered_ner, filtered_re)``:

        - ``filtered_ner`` drops ``"named dataset"`` entities whose text is an
          acronym target.
        - ``filtered_re`` drops the relation list of any source that is itself
          an acronym target. From the remaining lists it removes relations whose
          target equals their source, and it omits sources left with no relations.
    """
    acronyms = set()
    for source, relations in rel_preds.items():
        for rel in relations:
            is_acronym = rel["relation"] == "acronym"
            if is_acronym and len(rel["target"]) < len(source):
                acronyms.add(rel["target"])

    def _keep_entity(ent):
        # Only named datasets are pruned; other labels survive even if their text matches.
        return ent["label"] != "named dataset" or ent["text"] not in acronyms

    kept_entities = list(filter(_keep_entity, ner_preds))

    kept_relations = {}
    for source, relations in rel_preds.items():
        if source in acronyms:
            continue
        survivors = [rel for rel in relations if rel["target"] != source]
        if survivors:
            kept_relations[source] = survivors
    return kept_entities, kept_relations
# Highlighting function
def _target_occurrences(target, text):
    """Yield ``(start, end)`` for each standalone occurrence of *target* in *text*."""
    # An empty pattern would match at every position and flood the output with
    # zero-length spans, so blank targets are ignored.
    if not target or not target.strip():
        return
    pattern = re.escape(target)
    # Add word-boundary guards only on edges where the target itself starts/ends
    # with a word character. This stops "Mali" from matching inside "Somalia"
    # while still letting targets that begin or end with punctuation match.
    if re.match(r"\w", target[0]):
        pattern = r"(?<!\w)" + pattern
    if re.match(r"\w", target[-1]):
        pattern = pattern + r"(?!\w)"
    for match in re.finditer(pattern, text):
        yield match.start(), match.end()


def highlight_text(text, ner_threshold, re_threshold):
    """Run the pipeline on *text* and build the payload for ``gr.HighlightedText``.

    Args:
        text: Passage to annotate.
        ner_threshold: Minimum NER confidence.
        re_threshold: Minimum relation confidence.

    Returns:
        ``(highlighted, predictions)``.

        - ``highlighted`` is ``{"text": text, "entities": [...]}``. It has one
          entity per distinct ``(start, end)`` span, sorted by position, and each
          entity's ``"entity"`` joins that span's distinct labels with ", ".
        - ``predictions`` is ``{"ner": ..., "relations": ...}`` after pruning.
    """
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        re_threshold=re_threshold,
        re_multi_label=False,
    )
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)

    # (start, end) -> insertion-ordered labels. Dict keys de-duplicate identical
    # labels on the same span while keeping first-seen order.
    merged = defaultdict(dict)
    for ent in ner_preds:
        merged[(ent["start"], ent["end"])][ent["label"]] = None
    for src, relations in rel_preds.items():
        for r in relations:
            # Relations carry only the target text, so the target is located by
            # searching the passage for its surface form.
            for start, end in _target_occurrences(r["target"], text):
                merged[(start, end)][f"{src} <> {r['relation']}"] = None

    entities = [
        {"entity": ", ".join(lbls), "start": start, "end": end}
        for (start, end), lbls in sorted(merged.items(), key=lambda item: item[0])
    ]
    return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
# JSON output function
def get_model_predictions(text, ner_threshold, re_threshold):
    """Run the pipeline on *text* and return the pruned predictions as JSON.

    Args:
        text: Passage to annotate.
        ner_threshold: Minimum NER confidence.
        re_threshold: Minimum relation confidence.

    Returns:
        An indented JSON string with keys ``"ner"`` and ``"relations"``.
    """
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        re_threshold=re_threshold,
        re_multi_label=False,
    )
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
    # ensure_ascii=False keeps characters such as ’ and — readable for people,
    # instead of rendering them as ASCII escape sequences.
    return json.dumps(
        {"ner": ner_preds, "relations": rel_preds},
        indent=2,
        ensure_ascii=False,
    )
def _cached_predictions(state):
    """Format previously stored predictions as JSON for display.

    Args:
        state: Predictions dict stored by the Submit handler. It may be ``None``
            or empty before the first Submit.

    Returns:
        A placeholder message when *state* is falsy. Otherwise, indented JSON
        with non-ASCII characters kept as-is rather than escaped.
    """
    if not state:
        return "📋 No predictions yet. Click **Submit** first."
    return json.dumps(state, indent=2, ensure_ascii=False)
with gr.Blocks() as demo:
    gr.Markdown(
        "## Data Use Detector\n"
        "Adjust the sliders below to set thresholds, then:\n"
        "- **Submit** to highlight entities.\n"
        "- **Get Model Predictions** to see the raw JSON output."
    )

    # Inputs: the passage plus the two thresholds handed to highlight_text.
    txt_in = gr.Textbox(label="Input Text", lines=4, value=SAMPLE_TEXT)
    ner_slider = gr.Slider(
        minimum=0,
        maximum=1,
        value=0.7,
        step=0.01,
        label="NER Threshold",
        info="Minimum confidence for named-entity spans.",
    )
    re_slider = gr.Slider(
        minimum=0,
        maximum=1,
        value=0.5,
        step=0.01,
        label="RE Threshold",
        info="Minimum confidence for relation extractions.",
    )

    # Buttons and outputs, rendered in this order below the inputs.
    highlight_btn = gr.Button("Submit")
    txt_out = gr.HighlightedText(label="Annotated Entities")
    get_pred_btn = gr.Button("Get Model Predictions")
    json_out = gr.Textbox(label="Model Predictions (JSON)", lines=15)
    # Receives the second value returned by highlight_text (pruned predictions).
    state = gr.State()

    # Submit runs the model and stores its predictions in `state` ...
    highlight_btn.click(
        fn=highlight_text,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=[txt_out, state],
    )
    # ... and this button only formats that stored state. It therefore reflects
    # the last Submit, not the current text or slider values.
    get_pred_btn.click(fn=_cached_predictions, inputs=[state], outputs=[json_out])
# Enable the request queue; event listeners without their own concurrency
# limit may run up to 5 at a time.
demo.queue(default_concurrency_limit=5)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module for tests or tooling does not launch the app.
# debug=True keeps the main thread blocked while the server runs.
if __name__ == "__main__":
    demo.launch(debug=True)