import os
import json
from collections import defaultdict
from typing import List, Dict, Any, Tuple

import gradio as gr

# Import the GLiNER model and the custom relation extractor.
from gliner import GLiNER
from relation_extraction import CustomGLiNERRelationExtractor

# Cache directory for model weights (useful on Spaces, where CACHE_DIR can be
# set as a persistent environment variable).
_CACHE_DIR = os.environ.get("CACHE_DIR", None)

# Download and initialize the fine-tuned model, then wrap it in the
# relation extractor.
DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v7-pos"
model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)
relation_extractor = CustomGLiNERRelationExtractor(model=model, return_index=True)

# Sample text shown in the demo textbox.
SAMPLE_TEXT = (
    "Encuesta Nacional de Hogares (ENAHO) is the Peruvian version of the "
    "Living Standards Measurement Survey, i.e., a nationally representative "
    "household survey collected monthly on a continuous basis. For our "
    "analysis, we use data from January 2007 to December 2020. The survey "
    "covers a wide variety of topics, including basic demographics, "
    "educational background, labor market conditions, crime victimization, "
    "and a module on respondent’s perceptions about the main problems in the "
    "country and trust in different local and national-level institutions."
)

# Entity labels and the relation slots each entity type can fill.
labels = ["named dataset", "unnamed dataset", "vague dataset"]
rels = [
    "acronym", "author", "data description", "data geography", "data source",
    "data type", "publication year", "publisher", "reference population",
    "reference year", "version",
]
TYPE2RELS = {
    "named dataset": rels,
    "unnamed dataset": rels,
    "vague dataset": rels,
}


def inference_pipeline(
    text: str,
    model: GLiNER,
    labels: List[str],
    relation_extractor: CustomGLiNERRelationExtractor,
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.7,
    rel_threshold: float = 0.5,
    re_multi_label: bool = False,
    return_index: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Run NER, then extract relations for each detected entity span."""
    ner_preds = model.predict_entities(
        text, labels, flat_ner=True, threshold=ner_threshold
    )

    re_results: Dict[str, List[Dict[str, Any]]] = {}
    for ner in ner_preds:
        span = ner["text"]
        rel_types = TYPE2RELS.get(ner["label"], [])
        if not rel_types:
            continue
        # One slot label per candidate relation, e.g. "ENAHO <> publisher".
        slot_labels = [f"{span} <> {r}" for r in rel_types]
        preds = relation_extractor(
            text,
            relations=None,
            entities=None,
            relation_labels=slot_labels,
            threshold=rel_threshold,
            multi_label=re_multi_label,
            return_index=return_index,
        )[0]
        re_results[span] = preds

    return ner_preds, re_results
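
# Illustrative call (a sketch: the scores and trailing entries below are
# hypothetical, but the shapes match what the pipeline returns, i.e. a list of
# NER spans plus a dict of relation predictions keyed by entity text):
#
#   ner_preds, re_results = inference_pipeline(
#       SAMPLE_TEXT, model, labels, relation_extractor, TYPE2RELS
#   )
#   # ner_preds  -> [{"text": "Encuesta Nacional de Hogares",
#   #                 "label": "named dataset",
#   #                 "start": 0, "end": 28, "score": 0.98}, ...]
#   # re_results -> {"Encuesta Nacional de Hogares":
#   #                 [{"source": "Encuesta Nacional de Hogares",
#   #                   "relation": "acronym", "target": "ENAHO", ...}, ...]}
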
def prune_acronym_and_self_relations(ner_preds, rel_preds):
    """Drop acronym-only entities and relations that point back to their source."""
    # 1) Find acronym targets strictly shorter than their source span.
    acronym_targets = {
        r["target"]
        for src, rels in rel_preds.items()
        for r in rels
        if r["relation"] == "acronym" and len(r["target"]) < len(src)
    }

    # 2) Filter NER: drop any named dataset whose text is in that set.
    filtered_ner = [
        ent for ent in ner_preds
        if not (ent["label"] == "named dataset" and ent["text"] in acronym_targets)
    ]

    # 3) Filter RE: drop blocks for acronym sources, and self-relations.
    filtered_re = {}
    for src, rels in rel_preds.items():
        if src in acronym_targets:
            continue
        kept = [r for r in rels if r["target"] != src]
        if kept:
            filtered_re[src] = kept

    return filtered_ner, filtered_re


# Highlighting function: runs the pipeline and converts predictions into the
# {"text": ..., "entities": [...]} format expected by gr.HighlightedText.
def highlight_text(text, ner_threshold, rel_threshold):
    # 1) Inference
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        rel_threshold=rel_threshold,
        re_multi_label=False,
        return_index=True,
    )
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)

    # 2) Compute the length of the RE prompt prefix. This must match exactly
    #    what the extractor prepends to the text.
    prefix = f"{relation_extractor.prompt} \n "
    prefix_len = len(prefix)

    # 3) Gather spans
    spans = []
    for ent in ner_preds:
        spans.append((ent["start"], ent["end"], ent["label"]))

    # Use the extractor-returned start/end, minus prefix_len.
    for src, rels in rel_preds.items():
        for r in rels:
            # Adjust the indices back onto the raw text.
            s = r["start"] - prefix_len
            e = r["end"] - prefix_len
            # Skip anything that fell outside the raw text.
            if s < 0 or e < 0:
                continue
            label = f"{r['source']} <> {r['relation']}"
            spans.append((s, e, label))

    # 4) Merge labels that share a span, then build the entity list.
    merged = defaultdict(list)
    for s, e, lbl in spans:
        merged[(s, e)].append(lbl)

    entities = []
    for (s, e), lbls in sorted(merged.items(), key=lambda x: x[0]):
        entities.append({
            "entity": ", ".join(lbls),
            "start": s,
            "end": e,
        })

    return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}


# JSON output function: renders the predictions cached by the last Submit.
def _cached_predictions(state):
    if not state:
        return "📋 No predictions yet. Click **Submit** first."
    return json.dumps(state, indent=2)


with gr.Blocks() as demo:
    gr.Markdown(f"""# Data Use Detector

This Space demonstrates our fine-tuned GLiNER model’s ability to spot **dataset mentions** and **relations** in any input text. It identifies dataset names via NER, then extracts relations such as **publisher**, **acronym**, **publication year**, **data geography**, and more.

**How it works**
1. **NER**: Recognizes dataset names in your text.
2. **RE**: Links each dataset to its attributes (e.g., publisher, year, acronym).
3. **Visualization**: Highlights entities and relation spans inline.

**Instructions**
1. Paste or edit your text in the box below.
2. Tweak the **NER** & **RE** confidence sliders.
3. Click **Submit** to see highlights.
4. Click **Get Model Predictions** to view the raw JSON output.

**Resources**
- **Model:** [{DATA_MODEL_ID}](https://huggingface.co/{DATA_MODEL_ID})
- **Paper:** _Large Language Models and Synthetic Data for Monitoring Dataset Mentions in Research Papers_ – ArXiv: [2502.10263](https://arxiv.org/pdf/2502.10263)
- [GLiNER GitHub Repo](https://github.com/urchade/GLiNER)
- [Project Docs](https://worldbank.github.io/ai4data-use/docs/introduction.html)
""")

    txt_in = gr.Textbox(
        label="Input Text", lines=4, value=SAMPLE_TEXT
    )
    ner_slider = gr.Slider(
        0, 1, value=0.7, step=0.01, label="NER Threshold",
        info="Minimum confidence for named-entity spans.",
    )
    re_slider = gr.Slider(
        0, 1, value=0.5, step=0.01, label="RE Threshold",
        info="Minimum confidence for relation extractions.",
    )
    highlight_btn = gr.Button("Submit")
    txt_out = gr.HighlightedText(label="Annotated Entities")
    get_pred_btn = gr.Button("Get Model Predictions")
    json_out = gr.Textbox(label="Model Predictions (JSON)", lines=15)
    state = gr.State()

    # Wire up interactions.
    highlight_btn.click(
        fn=highlight_text,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=[txt_out, state],
    )
    get_pred_btn.click(
        fn=_cached_predictions,
        inputs=[state],
        outputs=[json_out],
    )

# Enable the queue to limit concurrent requests.
demo.queue(default_concurrency_limit=5)

# Launch the app.
demo.launch(debug=True, inline=True)
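
# Quick sanity check without the UI (a sketch; run this in a REPL after
# importing the module with demo.launch commented out, since
# launch(debug=True) blocks the interpreter):
#
#   highlighted, raw = highlight_text(SAMPLE_TEXT, ner_threshold=0.7, rel_threshold=0.5)
#   print(json.dumps(raw, indent=2))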