rafmacalaba commited on
Commit
2463f9e
·
1 Parent(s): ab71a6e

add labels and rels

Browse files
Files changed (1) hide show
  1. app.py +41 -1
app.py CHANGED
@@ -3,7 +3,7 @@ import re
3
  import json
4
  from collections import defaultdict
5
  import gradio as gr
6
-
7
  # Load environment variable for cache dir (useful on Spaces)
8
  _CACHE_DIR = os.environ.get("CACHE_DIR", None)
9
 
@@ -41,6 +41,46 @@ TYPE2RELS = {
41
  "vague dataset": rels,
42
  }
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def prune_acronym_and_self_relations(ner_preds, rel_preds):
45
  # 1) Find acronym targets strictly shorter than their source
46
  acronym_targets = {
 
3
  import json
4
  from collections import defaultdict
5
  import gradio as gr
6
+ from typing import List, Dict, Any, Tuple
7
  # Load environment variable for cache dir (useful on Spaces)
8
  _CACHE_DIR = os.environ.get("CACHE_DIR", None)
9
 
 
41
  "vague dataset": rels,
42
  }
43
 
44
+ def inference_pipeline(
45
+ text: str,
46
+ model,
47
+ labels: List[str],
48
+ relation_extractor: GLiNERRelationExtractor,
49
+ TYPE2RELS: Dict[str, List[str]],
50
+ ner_threshold: float = 0.5,
51
+ re_threshold: float = 0.4,
52
+ re_multi_label: bool = False,
53
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
54
+ ner_preds = model.predict_entities(
55
+ text,
56
+ labels,
57
+ flat_ner=True,
58
+ threshold=ner_threshold
59
+ )
60
+
61
+ re_results: Dict[str, List[Dict[str, Any]]] = {}
62
+ for ner in ner_preds:
63
+ span = ner['text']
64
+ rel_types = TYPE2RELS.get(ner['label'], [])
65
+ if not rel_types:
66
+ continue
67
+
68
+ slot_labels = [f"{span} <> {r}" for r in rel_types]
69
+
70
+ preds = relation_extractor(
71
+ text,
72
+ relations=None,
73
+ entities=None,
74
+ relation_labels=slot_labels,
75
+ threshold=re_threshold,
76
+ multi_label=re_multi_label,
77
+ distance_threshold=100,
78
+ )[0]
79
+
80
+ re_results[span] = preds
81
+
82
+ return ner_preds, re_results
83
+
84
  def prune_acronym_and_self_relations(ner_preds, rel_preds):
85
  # 1) Find acronym targets strictly shorter than their source
86
  acronym_targets = {