rafmacalaba committed
Commit 079aa2a · 1 Parent(s): b99c40c

just highlight closest

Files changed (1)
  1. app.py +68 -29
app.py CHANGED
@@ -109,45 +109,84 @@ def prune_acronym_and_self_relations(ner_preds, rel_preds):
 
 # Highlighting function
 
+# def highlight_text(text, ner_threshold, re_threshold):
+#     # Run inference
+#     ner_preds, rel_preds = inference_pipeline(
+#         text,
+#         model=model,
+#         labels=labels,
+#         relation_extractor=relation_extractor,
+#         TYPE2RELS=TYPE2RELS,
+#         ner_threshold=ner_threshold,
+#         re_threshold=re_threshold,
+#         re_multi_label=False
+#     )
+
+#     # Post-process
+#     ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
+
+#     # Gather all spans
+#     spans = []
+#     for ent in ner_preds:
+#         spans.append((ent["start"], ent["end"], ent["label"]))
+#     for src, rels in rel_preds.items():
+#         for r in rels:
+#             for m in re.finditer(re.escape(r["target"]), text):
+#                 spans.append((m.start(), m.end(), f"{src} <> {r['relation']}"))
+
+#     # Merge labels by span
+#     merged = defaultdict(list)
+#     for start, end, lbl in spans:
+#         merged[(start, end)].append(lbl)
+
+#     # Build Gradio entities
+#     entities = []
+#     for (start, end), lbls in sorted(merged.items(), key=lambda x: x[0]):
+#         entities.append({
+#             "entity": ", ".join(lbls),
+#             "start": start,
+#             "end": end
+#         })
 def highlight_text(text, ner_threshold, re_threshold):
-    # Run inference
-    ner_preds, rel_preds = inference_pipeline(
-        text,
-        model=model,
-        labels=labels,
-        relation_extractor=relation_extractor,
-        TYPE2RELS=TYPE2RELS,
-        ner_threshold=ner_threshold,
-        re_threshold=re_threshold,
-        re_multi_label=False
-    )
-
-    # Post-process
+    # inference + pruning …
+    ner_preds, rel_preds = inference_pipeline(…)
     ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
 
-    # Gather all spans
     spans = []
+    # 1) NER spans
     for ent in ner_preds:
         spans.append((ent["start"], ent["end"], ent["label"]))
+
+    # 2) RE spans, closest-match logic (no math import needed)
     for src, rels in rel_preds.items():
-        for r in rels:
-            for m in re.finditer(re.escape(r["target"]), text):
-                spans.append((m.start(), m.end(), f"{src} <> {r['relation']}"))
+        # find the source span center
+        src_ent = next((e for e in ner_preds if e["text"] == src), None)
+        src_center = ((src_ent["start"] + src_ent["end"]) / 2) if src_ent else None
 
-    # Merge labels by span
+        for r in rels:
+            target = r["target"]
+            matches = list(re.finditer(re.escape(target), text))
+            if not matches:
+                continue
+            # pick the match whose center is nearest to src_center
+            if src_center is not None:
+                best = min(
+                    matches,
+                    key=lambda m: abs(((m.start() + m.end()) / 2) - src_center)
+                )
+            else:
+                best = matches[0]
+            spans.append((best.start(), best.end(), f"{src} <> {r['relation']}"))
+
+    # 3) merge & return…
     merged = defaultdict(list)
-    for start, end, lbl in spans:
-        merged[(start, end)].append(lbl)
-
-    # Build Gradio entities
-    entities = []
-    for (start, end), lbls in sorted(merged.items(), key=lambda x: x[0]):
-        entities.append({
-            "entity": ", ".join(lbls),
-            "start": start,
-            "end": end
-        })
+    for s, e, lbl in spans:
+        merged[(s, e)].append(lbl)
 
+    entities = [
+        {"entity": ", ".join(lbls), "start": s, "end": e}
+        for (s, e), lbls in sorted(merged.items(), key=lambda x: x[0])
+    ]
     return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
 
 # JSON output function
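
For reference, the core of this change (picking the occurrence of a relation target that sits closest to its source entity) can be exercised on its own. The snippet below is a minimal, standalone sketch of that closest-match rule; the helper name pick_closest_match and the sample sentence are illustrative only and are not part of app.py.

import re

def pick_closest_match(text, target, src_start, src_end):
    """Return (start, end) of the occurrence of `target` whose center is
    nearest to the center of the source entity span, or None if absent."""
    matches = list(re.finditer(re.escape(target), text))
    if not matches:
        return None
    src_center = (src_start + src_end) / 2
    best = min(matches, key=lambda m: abs(((m.start() + m.end()) / 2) - src_center))
    return best.start(), best.end()

# "2019" occurs twice; the source entity "census" is adjacent to the second
# occurrence, so that occurrence wins.
text = "Data from 2019 and the 2019 census were merged."
src_start = text.index("census")
print(pick_closest_match(text, "2019", src_start, src_start + len("census")))  # (23, 27)

In app.py the same rule runs per relation inside highlight_text, falling back to the first match when the source entity text cannot be found in ner_preds.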