"""Gradio demo: dataset-mention NER and relation extraction with GLiNER.

The app loads a fine-tuned GLiNER model, runs entity recognition over the
input text, asks a relation extractor for attributes of every detected
entity, and renders both as inline highlights plus raw JSON.
"""

import json
import os
import re
from collections import defaultdict
from typing import Any, Dict, List, Tuple

import gradio as gr
from gliner import GLiNER

from relation_extraction import CustomGLiNERRelationExtractor

# Optional cache directory for downloaded model files (useful on Spaces).
_CACHE_DIR = os.environ.get("CACHE_DIR", None)

# Load the model once at import time and share it with the relation extractor.
DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v3"
model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)
relation_extractor = CustomGLiNERRelationExtractor(model=model, return_index=True)

# Sample text pre-filled in the input box.
SAMPLE_TEXT = (
    "In 2019, the third round of the Demographic and Health Survey (DHS) was conducted in 2020 by the World Bank, serving as the principal data source for a nationally representative cross‐sectional survey that covered Nigeria, Kenya, and Ghana and providing detailed demographic, health, and nutrition data, including household composition, fertility and mortality rates, maternal and child health indicators, and access to water and sanitation."
)

# Entity labels passed to the NER step.
labels = ['named dataset', 'unnamed dataset', 'vague dataset']

# Relation types queried for every detected entity.
rels = [
    'acronym', 'author', 'data description',
    'data geography', 'data source', 'data type',
    'publication year', 'publisher', 'reference year', 'version',
]

# Which relation types apply to which entity label (currently all of them).
TYPE2RELS = {
    "named dataset": rels,
    "unnamed dataset": rels,
    "vague dataset": rels,
}


def inference_pipeline(
    text: str,
    model,
    labels: List[str],
    relation_extractor: CustomGLiNERRelationExtractor,
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.7,
    rel_threshold: float = 0.5,
    re_multi_label: bool = False,
    return_index: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Run NER, then relation extraction for every detected entity.

    Args:
        text: Raw input text.
        model: Object exposing ``predict_entities(text, labels, flat_ner=...,
            threshold=...)``.
        labels: Entity labels to detect.
        relation_extractor: Callable that receives the text plus
            ``"<entity> <> <relation>"`` slot labels; element ``[0]`` of its
            return value is stored as the predictions for that entity.
        TYPE2RELS: Maps an entity label to the relation types to query.
            Entities whose label has no relation types are skipped.
        ner_threshold: Threshold forwarded to the NER call.
        rel_threshold: Threshold forwarded to the relation extractor.
        re_multi_label: Forwarded to the extractor as ``multi_label``.
        return_index: Forwarded to the extractor as ``return_index``.

    Returns:
        ``(ner_preds, re_results)`` where ``re_results`` is keyed by entity
        text. If the same entity text is detected more than once, the
        predictions of the later occurrence replace the earlier ones.
    """
    ner_preds = model.predict_entities(
        text,
        labels,
        flat_ner=True,
        threshold=ner_threshold,
    )

    re_results: Dict[str, List[Dict[str, Any]]] = {}
    for ner in ner_preds:
        span = ner['text']
        rel_types = TYPE2RELS.get(ner['label'], [])
        if not rel_types:
            continue
        # One slot label per relation type, all anchored on this entity.
        slot_labels = [f"{span} <> {r}" for r in rel_types]
        preds = relation_extractor(
            text,
            relations=None,
            entities=None,
            relation_labels=slot_labels,
            threshold=rel_threshold,
            multi_label=re_multi_label,
            return_index=return_index,
        )[0]
        re_results[span] = preds
    return ner_preds, re_results


def prune_acronym_and_self_relations(ner_preds, rel_preds):
    """Post-process predictions: prune acronym duplicates and self-relations.

    Args:
        ner_preds: List of entity dicts with at least ``"label"`` and
            ``"text"`` keys.
        rel_preds: Dict mapping a source entity text to a list of relation
            dicts with at least ``"relation"`` and ``"target"`` keys.

    Returns:
        ``(filtered_ner, filtered_re)``; the inputs are not mutated.
    """
    # 1) Acronym targets that are strictly shorter than their source.
    acronym_targets = {
        r["target"]
        for src, rels in rel_preds.items()
        for r in rels
        if r["relation"] == "acronym" and len(r["target"]) < len(src)
    }

    # 2) Drop any "named dataset" entity whose text is one of those acronyms.
    filtered_ner = [
        ent
        for ent in ner_preds
        if not (ent["label"] == "named dataset" and ent["text"] in acronym_targets)
    ]

    # 3) Drop relation blocks whose source is such an acronym, drop relations
    #    that point back at their own source, and drop sources left empty.
    filtered_re = {}
    for src, rels in rel_preds.items():
        if src in acronym_targets:
            continue
        kept = [r for r in rels if r["target"] != src]
        if kept:
            filtered_re[src] = kept

    return filtered_ner, filtered_re


def highlight_text(text, ner_threshold, rel_threshold):
    """Run the pipeline on ``text`` and build the highlight payload.

    Args:
        text: Raw input text from the textbox.
        ner_threshold: Threshold for entity spans.
        rel_threshold: Threshold for relation spans.

    Returns:
        A pair ``(highlight, predictions)``: ``highlight`` is a
        ``{"text", "entities"}`` dict for the highlighted-text component and
        ``predictions`` is ``{"ner": ..., "relations": ...}`` holding the
        pruned raw predictions.
    """
    # Nothing to analyse: skip the model calls and return an empty result.
    if not text or not text.strip():
        return {"text": text or "", "entities": []}, {"ner": [], "relations": {}}

    # 1) Inference
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        rel_threshold=rel_threshold,
        re_multi_label=False,
        return_index=True,
    )
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)

    # 2) Length of the prompt prefix used to shift relation indices.
    #    NOTE(review): assumes the extractor prepends exactly its prompt
    #    followed by " \n " -- keep in sync with the extractor.
    prefix = f"{relation_extractor.prompt} \n "
    prefix_len = len(prefix)
    text_len = len(text)

    # 3) Gather spans: entity spans are used as returned by the NER step.
    spans = [(ent["start"], ent["end"], ent["label"]) for ent in ner_preds]

    for rels in rel_preds.values():
        for r in rels:
            # Shift the extractor's indices back onto the raw text.
            s = r["start"] - prefix_len
            e = r["end"] - prefix_len
            # Skip anything that does not map onto a non-empty slice of the
            # raw text (inside the prefix, past the end, or inverted).
            if s < 0 or e > text_len or s >= e:
                continue
            spans.append((s, e, f"{r['source']} <> {r['relation']}"))

    # 4) Merge labels that share the exact same span, ordered by position.
    merged = defaultdict(list)
    for s, e, lbl in spans:
        merged[(s, e)].append(lbl)

    entities = [
        {"entity": ", ".join(merged[(s, e)]), "start": s, "end": e}
        for (s, e) in sorted(merged)
    ]
    return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}


def _cached_predictions(state):
    """Return the stored predictions as pretty-printed JSON text.

    Args:
        state: The predictions dict saved by the last submit, or a falsy
            value when nothing has been submitted yet.

    Returns:
        A JSON string, or a hint message when ``state`` is empty.
    """
    if not state:
        return "📋 No predictions yet. Click **Submit** first."
    # ensure_ascii=False keeps non-ASCII characters of the input readable
    # instead of rendering them as \uXXXX escapes.
    return json.dumps(state, indent=2, ensure_ascii=False)


# Intro text shown at the top of the page; kept flush-left so Markdown does
# not interpret indentation as a code block.
_INTRO_MARKDOWN = """# Data Use Detector

This Space demonstrates our fine-tuned GLiNER model’s ability to spot **dataset mentions** and **relations** in any input text. It identifies dataset names via NER, then extracts relations such as **publisher**, **acronym**, **publication year**, **data geography**, and more.

**How it works**

1. **NER**: Recognizes dataset names in your text.
2. **RE**: Links each dataset to its attributes (e.g., publisher, year, acronym).
3. **Visualization**: Highlights entities and relation spans inline.

**Instructions**

1. Paste or edit your text in the box below.
2. Tweak the **NER** & **RE** confidence sliders.
3. Click **Submit** to see highlights.
4. Click **Get Model Predictions** to view the raw JSON output.

**Resources**

- **Model:** [rafmacalaba/gliner_re_finetuned-v3](https://huggingface.co/rafmacalaba/gliner_re_finetuned-v3)
- **Paper:** _Large Language Models and Synthetic Data for Monitoring Dataset Mentions in Research Papers_ – ArXiv: [2502.10263](https://arxiv.org/pdf/2502.10263)
- [GLiNER GitHub Repo](https://github.com/urchade/GLiNER)
- [Project Docs](https://worldbank.github.io/ai4data-use/docs/introduction.html)
"""

with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    txt_in = gr.Textbox(
        label="Input Text",
        lines=4,
        value=SAMPLE_TEXT,
    )
    ner_slider = gr.Slider(
        0, 1,
        value=0.7,
        step=0.01,
        label="NER Threshold",
        info="Minimum confidence for named-entity spans.",
    )
    re_slider = gr.Slider(
        0, 1,
        value=0.5,
        step=0.01,
        label="RE Threshold",
        info="Minimum confidence for relation extractions.",
    )
    highlight_btn = gr.Button("Submit")
    txt_out = gr.HighlightedText(label="Annotated Entities")
    get_pred_btn = gr.Button("Get Model Predictions")
    json_out = gr.Textbox(label="Model Predictions (JSON)", lines=15)

    # Holds the predictions of the last submit for the JSON button.
    state = gr.State()

    # Wire up interactions
    highlight_btn.click(
        fn=highlight_text,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=[txt_out, state],
    )
    get_pred_btn.click(
        fn=_cached_predictions,
        inputs=[state],
        outputs=[json_out],
    )

# Enable queue for concurrency
demo.queue(default_concurrency_limit=5)

# Launch the app
demo.launch(debug=True, inline=True)