rafmacalaba committed on
Commit
a508c57
·
1 Parent(s): 0fd2cd1
Files changed (2) hide show
  1. app.py +50 -102
  2. requirements.py +2 -4
app.py CHANGED
@@ -1,118 +1,66 @@
1
- import os
2
- import json
3
  import gradio as gr
4
- # import torch
5
- import spaces
6
- from gliner import GLiNER
7
- from gliner.multitask import GLiNERRelationExtractor
8
- from typing import List, Dict, Any, Tuple
9
 
10
# Configuration: model identifier and optional cache directory taken from the
# CACHE_DIR environment variable (None when the variable is unset).
data_model_id = "rafmacalaba/gliner_re_finetuned-v3"
CACHE_DIR = os.environ.get("CACHE_DIR")

# Relation types queried for every detected dataset mention.
trels = [
    'acronym', 'author', 'data description',
    'data geography', 'data source', 'data type',
    'publication year', 'publisher', 'reference year', 'version'
]

# Map NER labels to relation types. Every label points at the very same
# `trels` list object, so mutating `trels` affects all three entries.
TYPE2RELS = {
    ner_label: trels
    for ner_label in ("named dataset", "unnamed dataset", "vague dataset")
}

# Load the models once, at module import time.
print("Loading NER+RE model...")
model = GLiNER.from_pretrained(data_model_id, cache_dir=CACHE_DIR)
relation_extractor = GLiNERRelationExtractor(model=model)
# if torch.cuda.is_available():
#     model.to("cuda")
#     relation_extractor.model.to("cuda")
# print("Models loaded.")
36
-
37
# Inference pipeline
def inference_pipeline(
    text: str,
    model,
    labels: List[str],
    relation_extractor: "GLiNERRelationExtractor",
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.5,
    re_threshold: float = 0.4,
    re_multi_label: bool = False,
    distance_threshold: int = 100,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Run NER on *text*, then relation extraction for each detected span.

    Args:
        text: Input text to analyse.
        model: Object exposing ``predict_entities(text, labels, flat_ner=...,
            threshold=...)``; each prediction must provide ``"text"`` and
            ``"label"`` keys.
        labels: NER labels passed to ``model.predict_entities``.
        relation_extractor: Callable invoked once per distinct
            (span, relation types) pair; the first element of its return
            value is stored as that span's relations.
        TYPE2RELS: Maps an NER label to the relation types to query for it.
            Entities whose label is missing or maps to an empty list are
            skipped.
        ner_threshold: Threshold forwarded to ``model.predict_entities``.
        re_threshold: Threshold forwarded to ``relation_extractor``.
        re_multi_label: ``multi_label`` flag forwarded to
            ``relation_extractor``.
        distance_threshold: Forwarded unchanged to ``relation_extractor``
            (defaults to the previously hard-coded value, 100).

    Returns:
        A tuple ``(ner_preds, re_results)`` where ``ner_preds`` is the raw
        NER output and ``re_results`` maps each span text to its relation
        predictions. When the same span text occurs several times, the
        entry for the last occurrence wins.
    """
    # NER predictions
    ner_preds = model.predict_entities(
        text,
        labels,
        flat_ner=True,
        threshold=ner_threshold,
    )

    # Relation extraction per entity span. Identical (span, relation types)
    # queries are answered from a local cache so a mention that appears
    # several times in the text triggers only one extractor call.
    re_results: Dict[str, List[Dict[str, Any]]] = {}
    extracted: Dict[Tuple[str, Tuple[str, ...]], List[Dict[str, Any]]] = {}
    for ner in ner_preds:
        span = ner["text"]
        rel_types = TYPE2RELS.get(ner["label"], [])
        if not rel_types:
            continue
        cache_key = (span, tuple(rel_types))
        if cache_key not in extracted:
            slot_labels = [f"{span} <> {r}" for r in rel_types]
            extracted[cache_key] = relation_extractor(
                text,
                relations=None,
                entities=None,
                relation_labels=slot_labels,
                threshold=re_threshold,
                multi_label=re_multi_label,
                distance_threshold=distance_threshold,
            )[0]
        re_results[span] = extracted[cache_key]

    return ner_preds, re_results
76
-
77
- # Gradio UI - Step 2: Model Inference
78
@spaces.GPU(enable_queue=True, duration=120)
def model_inference(query: str) -> str:
    """Run the NER + relation pipeline on *query* and return pretty JSON.

    Args:
        query: Text entered by the user.

    Returns:
        A JSON string (2-space indent) with two keys: ``"entities"`` (the NER
        predictions) and ``"relations"`` (relation predictions keyed by span
        text). Blank input yields an empty result without calling the model.
    """
    # Nothing to analyse: skip inference and return an empty, well-formed
    # payload instead of passing an empty string to the model.
    if not query or not query.strip():
        return json.dumps({"entities": [], "relations": {}}, indent=2)

    labels = ["named dataset", "unnamed dataset", "vague dataset"]
    ner_preds, re_results = inference_pipeline(
        query,
        model,
        labels,
        relation_extractor,
        TYPE2RELS,
    )
    output = {
        "entities": ner_preds,
        "relations": re_results,
    }
    return json.dumps(output, indent=2)
93
 
94
# Single-page UI: a text input, a submit button and a JSON output box.
with gr.Blocks(title="Step 2: NER + Relation Inference") as demo:
    gr.Markdown(
        """
        ## Step 2: Integrate Model Inference
        Enter text and click submit to run your GLiNER-based NER + RE pipeline.
        """
    )
    query_input = gr.Textbox(lines=4, placeholder="Type your text here...", label="Input Text")
    submit_btn = gr.Button("Submit")
    output_box = gr.Textbox(lines=15, label="Model Output (JSON)")

    # Clicking the button sends the query text to model_inference and shows
    # the returned JSON string in the output box.
    submit_btn.click(fn=model_inference, inputs=[query_input], outputs=[output_box])

# Start the server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch(debug=True)
 
 
 
1
  import gradio as gr
2
+ from typing import Union, Dict, Any, List
 
 
 
 
3
 
4
# Sample pre-calculated entities
sample_text = (
    "Recent studies on ocean currents from the Global Ocean Temperature Dataset "
    "(GOTD) indicate significant shifts in marine biodiversity."
)
# `start`/`end` are character offsets into `sample_text` (end exclusive), so
# sample_text[start:end] reproduces the entity's `text` exactly.
sample_entities = [
    {"label": "named dataset", "text": "Global Ocean Temperature Dataset", "start": 42, "end": 74, "score": 0.99},
    {"label": "acronym", "text": "GOTD", "start": 76, "end": 80, "score": 0.98},
]
# Relation types offered in the UI (not consumed by the simulation itself).
rels = [
    'acronym', 'author', 'data description',
    'data geography', 'data source', 'data type',
    'publication year', 'publisher', 'reference year', 'version'
]
# Placeholder model names for the model selector.
MODELS = ["demo-model-1", "demo-model-2"]
19
 
20
+ # Annotate_query simulation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
def annotate_query(
    query: str,
    labels: Union[str, List[str]],
    threshold: float = 0.3,
    nested_ner: bool = False,
    model_name: Union[str, None] = None,
    entities: Union[List[Dict[str, Any]], None] = None,
) -> Dict[str, Any]:
    """Simulate entity annotation of *query* using pre-computed entities.

    No model is run. Each known entity is located in *query* by its mention
    text, so the returned offsets always refer to the text actually passed in
    (even after the user edits it). Entities whose mention does not occur in
    *query* are omitted.

    Args:
        query: Text to annotate. ``None`` is treated as an empty string.
        labels: Accepted for interface compatibility; unused in simulation.
        threshold: Accepted for interface compatibility; unused in simulation.
        nested_ner: Accepted for interface compatibility; unused in simulation.
        model_name: Accepted for interface compatibility; unused in simulation.
        entities: Entity dicts to use instead of the module-level
            ``sample_entities``. Each needs a ``"label"`` and either a
            ``"text"`` mention or explicit ``"start"``/``"end"`` offsets.

    Returns:
        ``{"text": ..., "entities": [...]}`` where each entity carries
        ``start``, ``end``, ``label`` and ``entity``. ``entity`` duplicates
        ``label`` because that is the key Gradio's ``HighlightedText`` reads
        for the highlight category.
    """
    # In a real app, you'd call parse_query/inference_pipeline here.
    text = query or ""
    known = sample_entities if entities is None else entities

    highlighted = []
    for ent in known:
        mention = ent.get("text")
        if mention:
            # Re-anchor on the current text rather than trusting stored offsets.
            start = text.find(mention)
            if start < 0:
                continue
            end = start + len(mention)
        else:
            # No mention text available: fall back to the stored offsets, but
            # only when they fit inside the current text.
            start, end = ent["start"], ent["end"]
            if not 0 <= start < end <= len(text):
                continue
        highlighted.append(
            {
                "start": start,
                "end": end,
                "label": ent["label"],
                "entity": ent["label"],
            }
        )

    return {"text": text, "entities": highlighted}
 
38
 
39
# Build Gradio UI
demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
        ## Step: Annotate Query Simulation
        Enter text (prepopulated) and click **Annotate** to see how entities are highlighted.
        """
    )
    # Inputs
    query = gr.Textbox(lines=3, value=sample_text, label="Input Text")
    # NOTE: despite its name, this textbox lists the relation types; its value
    # is passed as the `labels` argument of annotate_query.
    entities = gr.Textbox(value=", ".join(rels), label="Relations (unused in simulation)")
    threshold = gr.Slider(0, 1, value=0.3, step=0.01, label="Threshold")
    nested = gr.Checkbox(value=False, label="Nested NER")
    model = gr.Radio(choices=MODELS, value=MODELS[0], label="Model")

    # Outputs
    output_hl = gr.HighlightedText(label="Annotated Entities")

    # Button: the five inputs map positionally onto annotate_query's
    # (query, labels, threshold, nested_ner, model_name) parameters.
    annotate_btn = gr.Button("Annotate")
    annotate_btn.click(
        fn=annotate_query,
        inputs=[query, entities, threshold, nested, model],
        outputs=[output_hl],
    )

# Launch only when run as a script so that importing this module (e.g. for
# tests or reuse of `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch(debug=True)
 
requirements.py CHANGED
@@ -1,4 +1,2 @@
1
- gradio>=3.0
2
- gliner # your GLiNER package
3
- torch # PyTorch (CPU build by default)
4
- huggingface-hub
 
1
+ gradio
2
+ spaces