rafmacalaba committed on
Commit
28e7655
·
1 Parent(s): d1de289

fix highlighting

Browse files
Files changed (1) hide show
  1. app.py +25 -14
app.py CHANGED
@@ -112,7 +112,7 @@ def prune_acronym_and_self_relations(ner_preds, rel_preds):
112
  # Highlighting function
113
 
114
  def highlight_text(text, ner_threshold, rel_threshold):
115
- # Run inference
116
  ner_preds, rel_preds = inference_pipeline(
117
  text,
118
  model=model,
@@ -124,32 +124,43 @@ def highlight_text(text, ner_threshold, rel_threshold):
124
  re_multi_label=False,
125
  return_index=True,
126
  )
127
-
128
- # Post-process
129
  ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
130
 
131
- # Gather all spans
 
 
 
 
 
132
  spans = []
133
  for ent in ner_preds:
134
  spans.append((ent["start"], ent["end"], ent["label"]))
 
 
135
  for src, rels in rel_preds.items():
136
  for r in rels:
137
- for m in re.finditer(re.escape(r["target"]), text):
138
- spans.append((m.start(), m.end(), f"{src} <> {r['relation']}"))
139
-
140
- # Merge labels by span
 
 
 
 
 
 
141
  merged = defaultdict(list)
142
- for start, end, lbl in spans:
143
- merged[(start, end)].append(lbl)
144
 
145
- # Build Gradio entities
146
  entities = []
147
- for (start, end), lbls in sorted(merged.items(), key=lambda x: x[0]):
148
  entities.append({
149
  "entity": ", ".join(lbls),
150
- "start": start,
151
- "end": end
152
  })
 
153
  return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
154
 
155
  # JSON output function
 
112
  # Highlighting function
113
 
114
  def highlight_text(text, ner_threshold, rel_threshold):
115
+ # 1) Inference
116
  ner_preds, rel_preds = inference_pipeline(
117
  text,
118
  model=model,
 
124
  re_multi_label=False,
125
  return_index=True,
126
  )
 
 
127
  ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
128
 
129
+ # 2) Compute how long the RE prompt prefix is
130
+ # This must match exactly what your extractor prepends:
131
+ prefix = f"{relation_extractor.prompt} \n "
132
+ prefix_len = len(prefix)
133
+
134
+ # 3) Gather spans
135
  spans = []
136
  for ent in ner_preds:
137
  spans.append((ent["start"], ent["end"], ent["label"]))
138
+
139
+ # Use the extractor‐returned start/end, minus prefix_len
140
  for src, rels in rel_preds.items():
141
  for r in rels:
142
+ # adjust the indices back onto the raw text
143
+ s = r["start"] - prefix_len
144
+ e = r["end"] - prefix_len
145
+ # skip anything that fell outside
146
+ if s < 0 or e < 0:
147
+ continue
148
+ label = f"{r['source']} <> {r['relation']}"
149
+ spans.append((s, e, label))
150
+
151
+ # 4) Merge & build entities (same as before)
152
  merged = defaultdict(list)
153
+ for s, e, lbl in spans:
154
+ merged[(s, e)].append(lbl)
155
 
 
156
  entities = []
157
+ for (s, e), lbls in sorted(merged.items(), key=lambda x: x[0]):
158
  entities.append({
159
  "entity": ", ".join(lbls),
160
+ "start": s,
161
+ "end": e
162
  })
163
+
164
  return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
165
 
166
  # JSON output function