Spaces:
Running
Running
Commit
·
28e7655
1
Parent(s):
d1de289
fix highlighting
Browse files
app.py
CHANGED
@@ -112,7 +112,7 @@ def prune_acronym_and_self_relations(ner_preds, rel_preds):
|
|
112 |
# Highlighting function
|
113 |
|
114 |
def highlight_text(text, ner_threshold, rel_threshold):
|
115 |
-
#
|
116 |
ner_preds, rel_preds = inference_pipeline(
|
117 |
text,
|
118 |
model=model,
|
@@ -124,32 +124,43 @@ def highlight_text(text, ner_threshold, rel_threshold):
|
|
124 |
re_multi_label=False,
|
125 |
return_index=True,
|
126 |
)
|
127 |
-
|
128 |
-
# Post-process
|
129 |
ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
|
130 |
|
131 |
-
#
|
|
|
|
|
|
|
|
|
|
|
132 |
spans = []
|
133 |
for ent in ner_preds:
|
134 |
spans.append((ent["start"], ent["end"], ent["label"]))
|
|
|
|
|
135 |
for src, rels in rel_preds.items():
|
136 |
for r in rels:
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
merged = defaultdict(list)
|
142 |
-
for
|
143 |
-
merged[(
|
144 |
|
145 |
-
# Build Gradio entities
|
146 |
entities = []
|
147 |
-
for (
|
148 |
entities.append({
|
149 |
"entity": ", ".join(lbls),
|
150 |
-
"start":
|
151 |
-
"end":
|
152 |
})
|
|
|
153 |
return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
|
154 |
|
155 |
# JSON output function
|
|
|
112 |
# Highlighting function
|
113 |
|
114 |
def highlight_text(text, ner_threshold, rel_threshold):
|
115 |
+
# 1) Inference
|
116 |
ner_preds, rel_preds = inference_pipeline(
|
117 |
text,
|
118 |
model=model,
|
|
|
124 |
re_multi_label=False,
|
125 |
return_index=True,
|
126 |
)
|
|
|
|
|
127 |
ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
|
128 |
|
129 |
+
# 2) Compute how long the RE prompt prefix is
|
130 |
+
# This must match exactly what your extractor prepends:
|
131 |
+
prefix = f"{relation_extractor.prompt} \n "
|
132 |
+
prefix_len = len(prefix)
|
133 |
+
|
134 |
+
# 3) Gather spans
|
135 |
spans = []
|
136 |
for ent in ner_preds:
|
137 |
spans.append((ent["start"], ent["end"], ent["label"]))
|
138 |
+
|
139 |
+
# Use the extractor-returned start/end, minus prefix_len
|
140 |
for src, rels in rel_preds.items():
|
141 |
for r in rels:
|
142 |
+
# adjust the indices back onto the raw text
|
143 |
+
s = r["start"] - prefix_len
|
144 |
+
e = r["end"] - prefix_len
|
145 |
+
# skip anything that fell outside
|
146 |
+
if s < 0 or e < 0:
|
147 |
+
continue
|
148 |
+
label = f"{r['source']} <> {r['relation']}"
|
149 |
+
spans.append((s, e, label))
|
150 |
+
|
151 |
+
# 4) Merge & build entities (same as before)
|
152 |
merged = defaultdict(list)
|
153 |
+
for s, e, lbl in spans:
|
154 |
+
merged[(s, e)].append(lbl)
|
155 |
|
|
|
156 |
entities = []
|
157 |
+
for (s, e), lbls in sorted(merged.items(), key=lambda x: x[0]):
|
158 |
entities.append({
|
159 |
"entity": ", ".join(lbls),
|
160 |
+
"start": s,
|
161 |
+
"end": e
|
162 |
})
|
163 |
+
|
164 |
return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
|
165 |
|
166 |
# JSON output function
|