Commit eb6e673 · Parent(s): 079aa2a · "just"
app.py CHANGED
@@ -109,84 +109,44 @@ def prune_acronym_and_self_relations(ner_preds, rel_preds):
 
 # Highlighting function
 
-# def highlight_text(text, ner_threshold, re_threshold):
-#     # Run inference
-#     ner_preds, rel_preds = inference_pipeline(
-#         text,
-#         model=model,
-#         labels=labels,
-#         relation_extractor=relation_extractor,
-#         TYPE2RELS=TYPE2RELS,
-#         ner_threshold=ner_threshold,
-#         re_threshold=re_threshold,
-#         re_multi_label=False
-#     )
-
-#     # Post-process
-#     ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
-
-#     # Gather all spans
-#     spans = []
-#     for ent in ner_preds:
-#         spans.append((ent["start"], ent["end"], ent["label"]))
-#     for src, rels in rel_preds.items():
-#         for r in rels:
-#             for m in re.finditer(re.escape(r["target"]), text):
-#                 spans.append((m.start(), m.end(), f"{src} <> {r['relation']}"))
-
-#     # Merge labels by span
-#     merged = defaultdict(list)
-#     for start, end, lbl in spans:
-#         merged[(start, end)].append(lbl)
-
-#     # Build Gradio entities
-#     entities = []
-#     for (start, end), lbls in sorted(merged.items(), key=lambda x: x[0]):
-#         entities.append({
-#             "entity": ", ".join(lbls),
-#             "start": start,
-#             "end": end
-#         })
 def highlight_text(text, ner_threshold, re_threshold):
-    #
-    ner_preds, rel_preds = inference_pipeline(
+    # Run inference
+    ner_preds, rel_preds = inference_pipeline(
+        text,
+        model=model,
+        labels=labels,
+        relation_extractor=relation_extractor,
+        TYPE2RELS=TYPE2RELS,
+        ner_threshold=ner_threshold,
+        re_threshold=re_threshold,
+        re_multi_label=False
+    )
+
+    # Post-process
     ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
 
+    # Gather all spans
     spans = []
-    # 1) NER spans
     for ent in ner_preds:
         spans.append((ent["start"], ent["end"], ent["label"]))
-
-    # 2) RE spans, closest‐match logic (no math import needed)
     for src, rels in rel_preds.items():
-        # find the source span center
-        src_ent = next((e for e in ner_preds if e["text"] == src), None)
-        src_center = ((src_ent["start"] + src_ent["end"]) / 2) if src_ent else None
-
         for r in rels:
-            # [two removed lines not preserved in the rendered diff: they collected the re.finditer matches for r["target"] into matches]
-            if not matches:
-                continue
-            # pick the match whose center is nearest to src_center
-            if src_center is not None:
-                best = min(
-                    matches,
-                    key=lambda m: abs(((m.start() + m.end()) / 2) - src_center)
-                )
-            else:
-                best = matches[0]
-            spans.append((best.start(), best.end(), f"{src} <> {r['relation']}"))
-
-    # 3) merge & return…
-    merged = defaultdict(list)
-    for s, e, lbl in spans:
-        merged[(s, e)].append(lbl)
+            for m in re.finditer(re.escape(r["target"]), text):
+                spans.append((m.start(), m.end(), f"{src} <> {r['relation']}"))
 
-    # [four removed lines not preserved in the rendered diff: they built the Gradio entities list from merged]
+    # Merge labels by span
+    merged = defaultdict(list)
+    for start, end, lbl in spans:
+        merged[(start, end)].append(lbl)
+
+    # Build Gradio entities
+    entities = []
+    for (start, end), lbls in sorted(merged.items(), key=lambda x: x[0]):
+        entities.append({
+            "entity": ", ".join(lbls),
+            "start": start,
+            "end": end
+        })
     return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
 
 # JSON output function
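Note on the change: the rewritten highlight_text tags every occurrence of a relation target in the text (the re.finditer loop), whereas the removed version kept only the occurrence whose center was closest to the source entity. The merge step that follows groups labels landing on identical (start, end) offsets, so an NER label and a relation tag on the same mention render as a single comma-joined entity. A minimal standalone sketch of that merge behavior, with made-up spans and labels:

from collections import defaultdict

# Hypothetical (start, end, label) triples, as collected inside highlight_text:
# NER labels plus "src <> relation" tags can land on the same offsets.
spans = [
    (0, 4, "Protein"),
    (0, 4, "IL-6 <> interacts_with"),
    (12, 17, "Disease"),
]

merged = defaultdict(list)
for start, end, lbl in spans:
    merged[(start, end)].append(lbl)

entities = [
    {"entity": ", ".join(lbls), "start": start, "end": end}
    for (start, end), lbls in sorted(merged.items(), key=lambda x: x[0])
]

print(entities)
# [{'entity': 'Protein, IL-6 <> interacts_with', 'start': 0, 'end': 4},
#  {'entity': 'Disease', 'start': 12, 'end': 17}]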
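The function's two return values line up with Gradio's HighlightedText dictionary format ({"text": ..., "entities": [{"entity", "start", "end"}, ...]}) and a plain JSON payload. The UI wiring is not part of this hunk, so the snippet below is only a sketch of how highlight_text could be hooked up; the component labels, slider ranges, and default thresholds are assumptions, not taken from app.py.

import gradio as gr

# Hypothetical wiring; the real layout lives elsewhere in app.py and may differ.
with gr.Blocks() as demo:
    text_in = gr.Textbox(label="Input text", lines=6)
    ner_thr = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="NER threshold")
    re_thr = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="RE threshold")

    highlighted = gr.HighlightedText(label="Entities and relations")
    raw_json = gr.JSON(label="Raw predictions")

    run_btn = gr.Button("Annotate")
    # highlight_text returns ({"text": ..., "entities": [...]}, {"ner": ..., "relations": ...}),
    # which map onto the HighlightedText and JSON components respectively.
    run_btn.click(
        fn=highlight_text,
        inputs=[text_in, ner_thr, re_thr],
        outputs=[highlighted, raw_json],
    )

demo.launch()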