rafmacalaba committed on
Commit
3d53082
·
1 Parent(s): c35975c
Files changed (1) hide show
  1. app.py +96 -33
app.py CHANGED
@@ -1,56 +1,119 @@
 
 
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- # Step 1: Textbox input
4
- # Define relation types and sample text
5
- rels = [
6
- 'acronym',
7
- 'author',
8
- 'data description',
9
- 'data geography',
10
- 'data source',
11
- 'data type',
12
- 'publication year',
13
- 'publisher',
14
- 'reference year',
15
- 'version'
16
  ]
17
- sample_text = (
18
- "Recent studies on ocean currents from the Global Ocean Temperature Dataset (GOTD) "
19
- "indicate significant shifts in marine biodiversity."
20
- )
21
 
22
- # Dummy inference echoes input + relations
23
- def dummy_inference(query: str) -> str:
24
- # TODO: replace with actual NER+RE model inference
25
- return f"Model received: '{query}' with relations: {rels}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- with gr.Blocks(title="Step 1: Input Box Demo") as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  gr.Markdown(
29
  """
30
- ## Step 1: Implement a Text Input
31
- Enter any text below (prepopulated with a sample).
32
- This is where your NER + relation-extraction model will later consume the query.
33
  """
34
  )
35
  query_input = gr.Textbox(
36
  lines=4,
37
- value=sample_text,
38
- label="Input Text",
39
  placeholder="Type your text here...",
 
40
  )
41
  submit_btn = gr.Button("Submit")
42
  output_box = gr.Textbox(
43
- lines=3,
44
- label="Echoed Output",
45
  )
46
  submit_btn.click(
47
- fn=dummy_inference,
48
  inputs=[query_input],
49
  outputs=[output_box],
50
  )
51
 
52
- # Launch the demo
53
  if __name__ == "__main__":
54
- demo.queue(default_concurrency_limit=5)
55
  demo.launch(debug=True)
56
-
 
1
+ import os
2
+ import json
3
  import gradio as gr
4
+ import torch
5
+ import spaces
6
+ from gliner import GLiNER
7
+ from gliner.multitask import GLiNERRelationExtractor
8
+ from typing import List, Dict, Any, Tuple
9
+ from tqdm.auto import tqdm
10
 
11
+ # Configuration
12
+ data_model_id = "rafmacalaba/gliner_re_finetuned-v3"
13
+ CACHE_DIR = os.environ.get("CACHE_DIR", None)
14
+
15
+ # Relation types
16
+ trels = [
17
+ 'acronym', 'author', 'data description',
18
+ 'data geography', 'data source', 'data type',
19
+ 'publication year', 'publisher', 'reference year', 'version'
 
 
 
 
20
  ]
 
 
 
 
21
 
22
+ # Map NER labels to relation types
23
+ TYPE2RELS = {
24
+ "named dataset": trels,
25
+ "unnamed dataset": trels,
26
+ "vague dataset": trels,
27
+ }
28
+
29
+ # Load models
30
+ print("Loading NER+RE model...")
31
+ model = GLiNER.from_pretrained(data_model_id, cache_dir=CACHE_DIR)
32
+ relation_extractor = GLiNERRelationExtractor(model=model)
33
+ if torch.cuda.is_available():
34
+ model.to("cuda")
35
+ relation_extractor.model.to("cuda")
36
+ print("Models loaded.")
37
+
38
+ # Inference pipeline
39
+ def inference_pipeline(
40
+ text: str,
41
+ model,
42
+ labels: List[str],
43
+ relation_extractor: GLiNERRelationExtractor,
44
+ TYPE2RELS: Dict[str, List[str]],
45
+ ner_threshold: float = 0.5,
46
+ re_threshold: float = 0.4,
47
+ re_multi_label: bool = False,
48
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
49
+ # NER predictions
50
+ ner_preds = model.predict_entities(
51
+ text,
52
+ labels,
53
+ flat_ner=True,
54
+ threshold=ner_threshold
55
+ )
56
+
57
+ # Relation extraction per entity span
58
+ re_results: Dict[str, List[Dict[str, Any]]] = {}
59
+ for ner in ner_preds:
60
+ span = ner['text']
61
+ rel_types = TYPE2RELS.get(ner['label'], [])
62
+ if not rel_types:
63
+ continue
64
+ slot_labels = [f"{span} <> {r}" for r in rel_types]
65
+ preds = relation_extractor(
66
+ text,
67
+ relations=None,
68
+ entities=None,
69
+ relation_labels=slot_labels,
70
+ threshold=re_threshold,
71
+ multi_label=re_multi_label,
72
+ distance_threshold=100,
73
+ )[0]
74
+ re_results[span] = preds
75
+
76
+ return ner_preds, re_results
77
 
78
+ # Gradio UI - Step 2: Model Inference
79
+ @spaces.GPU(enable_queue=True, duration=120)
80
+ def model_inference(query: str) -> str:
81
+ labels = ["named dataset", "unnamed dataset", "vague dataset"]
82
+ ner_preds, re_results = inference_pipeline(
83
+ query,
84
+ model,
85
+ labels,
86
+ relation_extractor,
87
+ TYPE2RELS
88
+ )
89
+ output = {
90
+ "entities": ner_preds,
91
+ "relations": re_results,
92
+ }
93
+ return json.dumps(output, indent=2)
94
+
95
+ with gr.Blocks(title="Step 2: NER + Relation Inference") as demo:
96
  gr.Markdown(
97
  """
98
+ ## Step 2: Integrate Model Inference
99
+ Enter text and click submit to run your GLiNER-based NER + RE pipeline.
 
100
  """
101
  )
102
  query_input = gr.Textbox(
103
  lines=4,
 
 
104
  placeholder="Type your text here...",
105
+ label="Input Text",
106
  )
107
  submit_btn = gr.Button("Submit")
108
  output_box = gr.Textbox(
109
+ lines=15,
110
+ label="Model Output (JSON)",
111
  )
112
  submit_btn.click(
113
+ fn=model_inference,
114
  inputs=[query_input],
115
  outputs=[output_box],
116
  )
117
 
 
118
  if __name__ == "__main__":
 
119
  demo.launch(debug=True)