rafmacalaba committed on
Commit
eb6e673
·
1 Parent(s): 079aa2a
Files changed (1) hide show
  1. app.py +29 -69
app.py CHANGED
@@ -109,84 +109,44 @@ def prune_acronym_and_self_relations(ner_preds, rel_preds):
109
 
110
  # Highlighting function
111
 
112
- # def highlight_text(text, ner_threshold, re_threshold):
113
- # # Run inference
114
- # ner_preds, rel_preds = inference_pipeline(
115
- # text,
116
- # model=model,
117
- # labels=labels,
118
- # relation_extractor=relation_extractor,
119
- # TYPE2RELS=TYPE2RELS,
120
- # ner_threshold=ner_threshold,
121
- # re_threshold=re_threshold,
122
- # re_multi_label=False
123
- # )
124
-
125
- # # Post-process
126
- # ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
127
-
128
- # # Gather all spans
129
- # spans = []
130
- # for ent in ner_preds:
131
- # spans.append((ent["start"], ent["end"], ent["label"]))
132
- # for src, rels in rel_preds.items():
133
- # for r in rels:
134
- # for m in re.finditer(re.escape(r["target"]), text):
135
- # spans.append((m.start(), m.end(), f"{src} <> {r['relation']}"))
136
-
137
- # # Merge labels by span
138
- # merged = defaultdict(list)
139
- # for start, end, lbl in spans:
140
- # merged[(start, end)].append(lbl)
141
-
142
- # # Build Gradio entities
143
- # entities = []
144
- # for (start, end), lbls in sorted(merged.items(), key=lambda x: x[0]):
145
- # entities.append({
146
- # "entity": ", ".join(lbls),
147
- # "start": start,
148
- # "end": end
149
- # })
150
  def highlight_text(text, ner_threshold, re_threshold):
151
- # inference + pruning …
152
- ner_preds, rel_preds = inference_pipeline(…)
 
 
 
 
 
 
 
 
 
 
 
153
  ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
154
 
 
155
  spans = []
156
- # 1) NER spans
157
  for ent in ner_preds:
158
  spans.append((ent["start"], ent["end"], ent["label"]))
159
-
160
- # 2) RE spans, closest‐match logic (no math import needed)
161
  for src, rels in rel_preds.items():
162
- # find the source span center
163
- src_ent = next((e for e in ner_preds if e["text"] == src), None)
164
- src_center = ((src_ent["start"] + src_ent["end"]) / 2) if src_ent else None
165
-
166
  for r in rels:
167
- target = r["target"]
168
- matches = list(re.finditer(re.escape(target), text))
169
- if not matches:
170
- continue
171
- # pick the match whose center is nearest to src_center
172
- if src_center is not None:
173
- best = min(
174
- matches,
175
- key=lambda m: abs(((m.start() + m.end()) / 2) - src_center)
176
- )
177
- else:
178
- best = matches[0]
179
- spans.append((best.start(), best.end(), f"{src} <> {r['relation']}"))
180
-
181
- # 3) merge & return…
182
- merged = defaultdict(list)
183
- for s, e, lbl in spans:
184
- merged[(s, e)].append(lbl)
185
 
186
- entities = [
187
- {"entity": ", ".join(lbls), "start": s, "end": e}
188
- for (s, e), lbls in sorted(merged.items(), key=lambda x: x[0])
189
- ]
 
 
 
 
 
 
 
 
 
190
  return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
191
 
192
  # JSON output function
 
109
 
110
  # Highlighting function
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  def highlight_text(text, ner_threshold, re_threshold):
113
+ # Run inference
114
+ ner_preds, rel_preds = inference_pipeline(
115
+ text,
116
+ model=model,
117
+ labels=labels,
118
+ relation_extractor=relation_extractor,
119
+ TYPE2RELS=TYPE2RELS,
120
+ ner_threshold=ner_threshold,
121
+ re_threshold=re_threshold,
122
+ re_multi_label=False
123
+ )
124
+
125
+ # Post-process
126
  ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
127
 
128
+ # Gather all spans
129
  spans = []
 
130
  for ent in ner_preds:
131
  spans.append((ent["start"], ent["end"], ent["label"]))
 
 
132
  for src, rels in rel_preds.items():
 
 
 
 
133
  for r in rels:
134
+ for m in re.finditer(re.escape(r["target"]), text):
135
+ spans.append((m.start(), m.end(), f"{src} <> {r['relation']}"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
+ # Merge labels by span
138
+ merged = defaultdict(list)
139
+ for start, end, lbl in spans:
140
+ merged[(start, end)].append(lbl)
141
+
142
+ # Build Gradio entities
143
+ entities = []
144
+ for (start, end), lbls in sorted(merged.items(), key=lambda x: x[0]):
145
+ entities.append({
146
+ "entity": ", ".join(lbls),
147
+ "start": start,
148
+ "end": end
149
+ })
150
  return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
151
 
152
  # JSON output function