import os
import re
import json
from collections import defaultdict
from typing import List, Dict, Any, Tuple

import gradio as gr

# GLiNER model and the custom relation extractor
from gliner import GLiNER
from relation_extraction import CustomGLiNERRelationExtractor

# Load environment variable for the cache dir (useful on Spaces)
_CACHE_DIR = os.environ.get("CACHE_DIR", None)

# Cache and initialize model + relation extractor
# DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-vfin"
DATA_MODEL_ID = "rafmacalaba/datause-extraction-v0"
model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)
relation_extractor = CustomGLiNERRelationExtractor(model=model, return_index=True)

# Sample text shown in the input box on load
SAMPLE_TEXT = (
    "The 2020 Demographic and Health Survey (DHS), conducted by the Ministry of Health in Kenya, "
    "provides data on maternal health and child nutrition for rural households."
)

# Entity labels and the relation types queried for each entity type
labels = ["named dataset", "unnamed dataset", "vague dataset"]
rels = [
    "acronym", "author", "data description", "data geography", "data source",
    "data type", "publication year", "publisher", "reference population",
    "reference year", "version",
]
TYPE2RELS = {
    "named dataset": rels,
    "unnamed dataset": rels,
    "vague dataset": rels,
}
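
# Example slot queries built later in inference_pipeline from these relation types
# (illustrative; the exact span text depends on what the NER step returns for the input):
#   "2020 Demographic and Health Survey (DHS) <> publisher"
#   "2020 Demographic and Health Survey (DHS) <> acronym"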


def inference_pipeline(
    text: str,
    model,
    labels: List[str],
    relation_extractor: CustomGLiNERRelationExtractor,
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.5,
    rel_threshold: float = 0.5,
    re_multi_label: bool = False,
    return_index: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Run NER over the text, then query relations for each detected entity span."""
    # 1) NER: detect dataset-mention spans above the confidence threshold
    ner_preds = model.predict_entities(
        text,
        labels,
        flat_ner=True,
        threshold=ner_threshold,
    )

    # 2) RE: for each detected span, query its candidate relation types
    re_results: Dict[str, List[Dict[str, Any]]] = {}
    for ner in ner_preds:
        span = ner["text"]
        rel_types = TYPE2RELS.get(ner["label"], [])
        if not rel_types:
            continue
        slot_labels = [f"{span} <> {r}" for r in rel_types]
        preds = relation_extractor(
            text,
            relations=None,
            entities=None,
            relation_labels=slot_labels,
            threshold=rel_threshold,
            multi_label=re_multi_label,
            return_index=return_index,
        )[0]
        re_results[span] = preds

    return ner_preds, re_results
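

# Illustrative shapes of what inference_pipeline returns (offsets and scores below are
# made up, not real model output). ner_preds follows GLiNER's predict_entities format,
# and re_results maps each detected span to the relation records consumed further down:
#
#   ner_preds  -> [{"text": "2020 Demographic and Health Survey (DHS)",
#                   "label": "named dataset", "start": 4, "end": 45, "score": 0.92}, ...]
#   re_results -> {"2020 Demographic and Health Survey (DHS)": [
#                     {"source": "2020 Demographic and Health Survey (DHS)",
#                      "relation": "publisher", "target": "Ministry of Health",
#                      "start": 210, "end": 228, "score": 0.81}, ...]}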


# Post-processing: prune acronym duplicates and self-relations
def prune_acronym_and_self_relations(ner_preds, rel_preds):
    # 1) Find acronym targets strictly shorter than their source
    acronym_targets = {
        r["target"]
        for src, rels in rel_preds.items()
        for r in rels
        if r["relation"] == "acronym" and len(r["target"]) < len(src)
    }

    # 2) Filter NER: drop any named dataset whose text is in that set
    filtered_ner = [
        ent for ent in ner_preds
        if not (ent["label"] == "named dataset" and ent["text"] in acronym_targets)
    ]

    # 3) Filter RE: drop blocks keyed by an acronym short form, and drop self-relations
    filtered_re = {}
    for src, rels in rel_preds.items():
        if src in acronym_targets:
            continue
        kept = [r for r in rels if r["target"] != src]
        if kept:
            filtered_re[src] = kept

    return filtered_ner, filtered_re
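
# Worked example (illustrative): if the model tags both
# "2020 Demographic and Health Survey (DHS)" and "DHS" as named datasets and links the
# long form to "DHS" via an "acronym" relation, then "DHS" becomes an acronym target, so
# its standalone NER entity and any relation block keyed by "DHS" are dropped, along with
# any relation whose target equals its own source span.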


# Highlighting function used by the Submit button
def highlight_text(text, ner_threshold, rel_threshold):
    # 1) Run NER + RE, then prune acronym duplicates and self-relations
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        rel_threshold=rel_threshold,
        re_multi_label=False,
        return_index=True,
    )
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)

    # 2) Compute the length of the RE prompt prefix.
    #    This must match exactly what the extractor prepends to the input text.
    prefix = f"{relation_extractor.prompt} \n "
    prefix_len = len(prefix)

    # 3) Gather spans: NER spans as-is, RE spans shifted back onto the raw text
    spans = []
    for ent in ner_preds:
        spans.append((ent["start"], ent["end"], ent["label"]))
    # Use the extractor-returned start/end, minus prefix_len
    for src, rels in rel_preds.items():
        for r in rels:
            # Adjust the indices back onto the raw text
            s = r["start"] - prefix_len
            e = r["end"] - prefix_len
            # Skip anything that fell outside the raw text
            if s < 0 or e < 0:
                continue
            label = f"{r['source']} <> {r['relation']}"
            spans.append((s, e, label))

    # 4) Merge labels that share a span and build the HighlightedText payload
    merged = defaultdict(list)
    for s, e, lbl in spans:
        merged[(s, e)].append(lbl)

    entities = []
    for (s, e), lbls in sorted(merged.items(), key=lambda x: x[0]):
        entities.append({
            "entity": ", ".join(lbls),
            "start": s,
            "end": e,
        })

    return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}


# JSON output function
def _cached_predictions(state):
    if not state:
        return "📋 No predictions yet. Click **Submit** first."
    return json.dumps(state, indent=2)


with gr.Blocks() as demo:
    gr.Markdown(f"""# Data Use Detector

This Space demonstrates our fine-tuned GLiNER model’s ability to spot **dataset mentions** and **relations** in any input text. It identifies dataset names via NER, then extracts relations such as **publisher**, **acronym**, **publication year**, **data geography**, and more.

**How it works**
1. **NER**: Recognizes dataset names in your text.
2. **RE**: Links each dataset to its attributes (e.g., publisher, year, acronym).
3. **Visualization**: Highlights entities and relation spans inline.

**Instructions**
1. Paste or edit your text in the box below.
2. Tweak the **NER** & **RE** confidence sliders.
3. Click **Submit** to see highlights.
4. Click **Get Model Predictions** to view the raw JSON output.

**Resources**
- **Model:** [{DATA_MODEL_ID}](https://huggingface.co/{DATA_MODEL_ID})
- **Paper:** _Large Language Models and Synthetic Data for Monitoring Dataset Mentions in Research Papers_ – ArXiv: [2502.10263](https://arxiv.org/pdf/2502.10263)
- [GLiNER GitHub Repo](https://github.com/urchade/GLiNER)
- [Project Docs](https://worldbank.github.io/ai4data-use/docs/introduction.html)
""")

    txt_in = gr.Textbox(
        label="Input Text",
        lines=4,
        value=SAMPLE_TEXT,
    )

    EXAMPLE_TEXTS = [
        "Using the Demographic and Health Survey (DHS) 2015–2016 dataset, a longitudinal panel dataset of household expenditures compiled by the National Bureau of Statistics from 2010 to 2020, and anonymized administrative records, we examine trends in maternal health outcomes across rural districts.",
        "Leveraging the NOAA Global Historical Climatology Network (GHCN) daily temperature records from 2010 to 2020, we examine long-term warming trends across North American cities.",
        "We analyze anonymized administrative hospital discharge records from the UK National Health Service (NHS) for the years 2015 through 2019 to identify seasonal patterns in respiratory illnesses.",
        "Our analysis uses a panel of student performance scores from the OECD’s Programme for International Student Assessment (PISA) 2015 cycle to explore educational outcomes across member states.",
        "We employ a longitudinal survey of household income and expenditure collected by India’s National Sample Survey Office (NSSO) from 2005 to examine poverty dynamics.",
    ]

    examples = gr.Examples(
        examples=[[t] for t in EXAMPLE_TEXTS],
        inputs=[txt_in],
        label="Try one of these example texts",
    )

    ner_slider = gr.Slider(
        0, 1, value=0.5, step=0.01,
        label="NER Threshold",
        info="Minimum confidence for named-entity spans.",
    )
    re_slider = gr.Slider(
        0, 1, value=0.5, step=0.01,
        label="RE Threshold",
        info="Minimum confidence for relation extractions.",
    )

    highlight_btn = gr.Button("Submit")
    txt_out = gr.HighlightedText(label="Annotated Entities")
    get_pred_btn = gr.Button("Get Model Predictions")
    json_out = gr.Textbox(label="Model Predictions (JSON)", lines=15)
    state = gr.State()

    # Wire up interactions
    highlight_btn.click(
        fn=highlight_text,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=[txt_out, state],
    )
    get_pred_btn.click(
        fn=_cached_predictions,
        inputs=[state],
        outputs=[json_out],
    )

# Enable queue for concurrency
demo.queue(default_concurrency_limit=5)

# Launch the app
demo.launch(debug=True, inline=True)
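
# For a quick smoke test without the UI (assumes the model and extractor load
# successfully), something like the following could be run instead of demo.launch():
#
#   ner, rels = inference_pipeline(
#       SAMPLE_TEXT, model, labels, relation_extractor, TYPE2RELS
#   )
#   ner, rels = prune_acronym_and_self_relations(ner, rels)
#   print(json.dumps({"ner": ner, "relations": rels}, indent=2, default=str))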