Upload elastic.py
ask_candid/retrieval/elastic.py  +500 -500
CHANGED
@@ -1,500 +1,500 @@
from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
from dataclasses import dataclass
from functools import partial
from itertools import groupby

from torch.nn import functional as F

from pydantic import BaseModel, Field
from langchain_core.documents import Document
from langchain_core.tools import Tool

from elasticsearch import Elasticsearch

from ask_candid.services.small_lm import CandidSLM
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA
from ask_candid.base.config.data import ElasticIndexMapping, ALL_INDICES


@dataclass
class ElasticHitsResult:
    """Dataclass for Elasticsearch hits results
    """
    index: str
    id: Any
    score: float
    source: Dict[str, Any]
    inner_hits: Dict[str, Any]


class RetrieverInput(BaseModel):
    """Input to the Elasticsearch retriever."""
    user_input: str = Field(description="query to look up in retriever")


def build_sparse_vector_query(
    query: str,
    fields: Tuple[str],
    inference_id: str = ".elser-2-elasticsearch"
) -> Dict[str, Any]:
    """Builds a valid Elasticsearch text expansion query payload

    Parameters
    ----------
    query : str
        Search context string
    fields : Tuple[str]
        Semantic text field names
    inference_id : str, optional
        ID of model deployed in Elasticsearch, by default ".elser-2-elasticsearch"

    Returns
    -------
    Dict[str, Any]
    """

    output = []

    for f in fields:
        output.append({
            "nested": {
                "path": f"embeddings.{f}.chunks",
                "query": {
                    "sparse_vector": {
                        "field": f"embeddings.{f}.chunks.vector",
                        "inference_id": inference_id,
                        "query": query,
                        "boost": 1 / len(fields)
                    }
                },
                "inner_hits": {
                    "_source": False,
                    "size": 2,
                    "fields": [f"embeddings.{f}.chunks.chunk"]
                }
            }
        })
    return {"query": {"bool": {"should": output}}}


def query_builder(query: str, indices: List[str]) -> List[Dict[str, Any]]:
    """Builds Elasticsearch multi-search query payload

    Parameters
    ----------
    query : str
        Search context string
    indices : List[str]
        Semantic index names to search over

    Returns
    -------
    List[Dict[str, Any]]
    """

    queries = []
    if indices is None:
        indices = list(ALL_INDICES)

    for index in indices:
        if index == "issuelab":
            q = build_sparse_vector_query(
                query=query,
                fields=("description", "content", "combined_issuelab_findings", "combined_item_description")
            )
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 1
            queries.extend([{"index": ElasticIndexMapping.ISSUELAB_INDEX_ELSER}, q])
        elif index == "youtube":
            q = build_sparse_vector_query(
                query=query,
                fields=("captions_cleaned", "description_cleaned", "title")
            )
            # text_cleaned duplicates captions_cleaned
            q["_source"] = {"excludes": ["embeddings", "captions", "description", "text_cleaned"]}
            q["size"] = 2
            queries.extend([{"index": ElasticIndexMapping.YOUTUBE_INDEX_ELSER}, q])
        elif index == "candid_blog":
            q = build_sparse_vector_query(
                query=query,
                fields=("content", "authors_text", "title_summary_tags")
            )
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 2
            queries.extend([{"index": ElasticIndexMapping.CANDID_BLOG_INDEX_ELSER}, q])
        elif index == "candid_learning":
            q = build_sparse_vector_query(
                query=query,
                fields=("content", "title", "training_topics", "staff_recommendations")
            )
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 2
            queries.extend([{"index": ElasticIndexMapping.CANDID_LEARNING_INDEX_ELSER}, q])
        elif index == "candid_help":
            q = build_sparse_vector_query(
                query=query,
                fields=("content", "combined_article_description")
            )
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 2
            queries.extend([{"index": ElasticIndexMapping.CANDID_HELP_INDEX_ELSER}, q])

    return queries

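For reference, a minimal sketch of what the builder above produces (query text illustrative, and it assumes the package and its Elasticsearch config import cleanly): for each requested index, query_builder emits an msearch header/body pair, and the body carries one nested sparse_vector clause per semantic field with the boost split as 1 / len(fields).

import json
from ask_candid.retrieval.elastic import query_builder

payload = query_builder(query="climate justice", indices=["issuelab"])
# payload[0] is the msearch header line: {"index": ElasticIndexMapping.ISSUELAB_INDEX_ELSER}
# payload[1] is the body: {"query": {"bool": {"should": [<one nested clause per field>]}}}
# plus the "_source" excludes and "size" set in query_builder.
print(json.dumps(payload[1]["query"]["bool"]["should"][0], indent=2))
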

def multi_search(queries: List[Dict[str, Any]]) -> List[ElasticHitsResult]:
    """Runs multi-search query

    Parameters
    ----------
    queries : List[Dict[str, Any]]
        Pre-built multi-search query payload

    Returns
    -------
    List[ElasticHitsResult]
    """

    results = []
    with Elasticsearch(
        cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
        api_key=SEMANTIC_ELASTIC_QA.api_key,
        verify_certs=False,
        request_timeout=60 * 3
    ) as es:
        for query_group in es.msearch(body=queries).get("responses", []):
            for hit in query_group.get("hits", {}).get("hits", []):
                hit = ElasticHitsResult(
                    index=hit["_index"],
                    id=hit["_id"],
                    score=hit["_score"],
                    source=hit["_source"],
                    inner_hits=hit.get("inner_hits", {})
                )
                results.append(hit)
    return results


def get_query_results(search_text: str, indices: Optional[List[str]] = None) -> List[ElasticHitsResult]:
    """Builds and executes Elasticsearch data queries from a search string.

    Parameters
    ----------
    search_text : str
        Search context string
    indices : Optional[List[str]], optional
        Semantic index names to search over, by default None

    Returns
    -------
    List[ElasticHitsResult]
    """

    queries = query_builder(query=search_text, indices=indices)
    return multi_search(queries)


def retrieved_text(hits: Dict[str, Any]) -> str:
    """Extracts retrieved sub-texts from documents which are strong hits from semantic queries for the purpose of
    re-scoring by a secondary language model.

    Parameters
    ----------
    hits : Dict[str, Any]

    Returns
    -------
    str
    """

    text = []
    for _, v in hits.items():
        for h in (v.get("hits", {}).get("hits") or []):
            for _, field in h.get("fields", {}).items():
                for chunk in field:
                    if chunk.get("chunk"):
                        text.extend(chunk["chunk"])
    return '\n'.join(text)


def cosine_rescore(query: str, contexts: List[str]) -> List[float]:
    """Computes cosine scores between retrieved contexts and the original query to re-score results based on overall
    relevance to the original query.

    Parameters
    ----------
    query : str
        Search context string
    contexts : List[str]
        Semantic field sub-texts, order is by document retrieved from the original multi-search query.

    Returns
    -------
    List[float]
        Scores in the same order as the input document contexts
    """

    nlp = CandidSLM()
    X = nlp.encode([query, *contexts]).vectors
    X = F.normalize(X, dim=-1, p=2.)
    cosine = X[1:] @ X[:1].T
    return cosine.flatten().cpu().numpy().tolist()


def reranker(
    query_results: Iterable[ElasticHitsResult],
    search_text: Optional[str] = None
) -> Iterator[ElasticHitsResult]:
    """Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
    This will shuffle results

    Parameters
    ----------
    query_results : Iterable[ElasticHitsResult]

    Yields
    ------
    Iterator[ElasticHitsResult]
    """

    results: List[ElasticHitsResult] = []
    texts: List[str] = []
    for _, data in groupby(query_results, key=lambda x: x.index):
        data = list(data)
        max_score = max(data, key=lambda x: x.score).score
        min_score = min(data, key=lambda x: x.score).score

        for d in data:
            d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
            results.append(d)

            if search_text:
                text = retrieved_text(d.inner_hits)
                texts.append(text)

    # if search_text and len(texts) == len(results):
    #     scores = cosine_rescore(search_text, texts)
    #     for r, s in zip(results, scores):
    #         r.score = s

    yield from sorted(results, key=lambda x: x.score, reverse=True)


def get_results(user_input: str, indices: List[str]) -> Tuple[str, List[Document]]:
    """End-to-end search and re-rank function.

    Parameters
    ----------
    user_input : str
        Search context string
    indices : List[str]
        Semantic index names to search over

    Returns
    -------
    Tuple[str, List[Document]]
        (concatenated text from search results, documents list)
    """

    output = ["Search didn't return any Candid sources"]
    page_content = []
    content = "Search didn't return any Candid sources"
    results = get_query_results(search_text=user_input, indices=indices)
    if results:
        output = get_reranked_results(results, search_text=user_input)
        for doc in output:
            page_content.append(doc.page_content)
        content = "\n\n".join(page_content)

    # for the tool we need to return a tuple for content_and_artifact type
    return content, output


# TODO make it better!
def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
    """Pads the relevant chunk of text with context before and after

    Parameters
    ----------
    field_name : str
        a field with the long text that was chunked into pieces
    hit : ElasticHitsResult
    context_length : int, optional
        length of text to add before and after the chunk, by default 1024

    Returns
    -------
    str
        longer chunks stuffed together
    """

    chunks = []
    # TODO chunks have tokens, but long text is a normal text, but may contain html that also gets weird after tokenization
    long_text = hit.source.get(f"{field_name}", "")
    long_text = long_text.lower()
    inner_hits_field = f"embeddings.{field_name}.chunks"
    found_chunks = hit.inner_hits.get(inner_hits_field, {})
    if found_chunks:
        hits = found_chunks.get("hits", {}).get("hits", [])
        for h in hits:
            chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]

            # cutting the middle because we may have tokenizing artifacts there
            chunk = chunk[3: -3]

            if add_context:
                # Find the start and end indices of the chunk in the large text
                start_index = long_text.find(chunk[:20])
                if start_index != -1:  # Chunk is found
                    end_index = start_index + len(chunk)
                    pre_start_index = max(0, start_index - context_length)
                    post_end_index = min(len(long_text), end_index + context_length)
                    chunks.append(long_text[pre_start_index:post_end_index])
            else:
                chunks.append(chunk)
    return '\n\n'.join(chunks)


def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
    """Parse Elasticsearch hit results into data structures handled by the RAG pipeline.

    Parameters
    ----------
    hit : ElasticHitsResult

    Returns
    -------
    Union[Document, None]
    """

    if "issuelab-elser" in hit.index:
        combined_item_description = hit.source.get("combined_item_description", "")  # title inside
        description = hit.source.get("description", "")
        combined_issuelab_findings = hit.source.get("combined_issuelab_findings", "")
        # we only need to process long texts
        chunks_with_context_txt = get_context("content", hit, context_length=12)
        doc = Document(
            page_content='\n\n'.join([
                combined_item_description,
                combined_issuelab_findings,
                description,
                chunks_with_context_txt
            ]),
            metadata={
                "title": hit.source["title"],
                "source": "IssueLab",
                "source_id": hit.source["resource_id"],
                "url": hit.source.get("permalink", "")
            }
        )
    elif "youtube" in hit.index:
        title = hit.source.get("title", "")
        # we only need to process long texts
        description_cleaned_with_context_txt = get_context("description_cleaned", hit, context_length=12)
        captions_cleaned_with_context_txt = get_context("captions_cleaned", hit, context_length=12)
        doc = Document(
            page_content='\n\n'.join([title, description_cleaned_with_context_txt, captions_cleaned_with_context_txt]),
            metadata={
                "title": title,
                "source": "Candid YouTube",
                "source_id": hit.source['video_id'],
                "url": f"https://www.youtube.com/watch?v={hit.source['video_id']}"
            }
        )
    elif "candid-blog" in hit.index:
        excerpt = hit.source.get("excerpt", "")
        title = hit.source.get("title", "")
        # we only need to process long text
        content_with_context_txt = get_context("content", hit, context_length=12, add_context=False)
        authors = get_context("authors_text", hit, context_length=12, add_context=False)
        tags = hit.source.get("title_summary_tags", "")
        doc = Document(
            page_content='\n\n'.join([title, excerpt, content_with_context_txt, authors, tags]),
            metadata={
                "title": title,
                "source": "Candid Blog",
                "source_id": hit.source["id"],
                "url": hit.source["link"]
            }
        )
    elif "candid-learning" in hit.index:
        title = hit.source.get("title", "")
        content_with_context_txt = get_context("content", hit, context_length=12)
        training_topics = hit.source.get("training_topics", "")
        staff_recommendations = hit.source.get("staff_recommendations", "")

        doc = Document(
            page_content='\n\n'.join([title, staff_recommendations, training_topics, content_with_context_txt]),
            metadata={
                "title": hit.source["title"],
                "source": "Candid Learning",
                "source_id": hit.source["post_id"],
                "url": hit.source.get("url", "")
            }
        )
    elif "candid-help" in hit.index:
        title = hit.source.get("title", "")
        content_with_context_txt = get_context("content", hit, context_length=12)
        combined_article_description = hit.source.get("combined_article_description", "")

        doc = Document(
            page_content='\n\n'.join([combined_article_description, content_with_context_txt]),
            metadata={
                "title": title,
                "source": "Candid Help",
                "source_id": hit.source["id"],
                "url": hit.source.get("link", "")
            }
        )
    else:
        doc = None
    return doc


def get_reranked_results(results: List[ElasticHitsResult], search_text: Optional[str] = None) -> List[Document]:
    """Run data re-ranking and document building for tool usage.

    Parameters
    ----------
    results : List[ElasticHitsResult]
    search_text : Optional[str], optional
        Search context string, by default None

    Returns
    -------
    List[Document]
    """

    output = []
    for r in reranker(results, search_text=search_text):
        hit = process_hit(r)
        if hit is not None:
            output.append(hit)
    return output


def retriever_tool(indices: List[str]) -> Tool:
    """Tool component for use in conditional edge building for RAG execution graph.
    Cannot use `create_retriever_tool` because it only provides content losing all metadata on the way
    https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution

    Parameters
    ----------
    indices : List[str]
        Semantic index names to search over

    Returns
    -------
    Tool
    """

    return Tool(
        name="retrieve_social_sector_information",
        func=partial(get_results, indices=indices),
        description=(
            "Return additional information about social and philanthropic sector, "
            "including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
        ),
        args_schema=RetrieverInput,
        response_format="content_and_artifact"
    )
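A minimal usage sketch of the retrieval entry points defined above; the query text and index keys are illustrative, and it assumes the Space's SEMANTIC_ELASTIC_QA connection settings are in place.

from ask_candid.retrieval.elastic import get_query_results, get_reranked_results

search = "capacity building grants for small nonprofits"

# Raw Elasticsearch hits across the selected semantic indices
hits = get_query_results(search_text=search, indices=["issuelab", "candid_blog"])

# Scores min-max normalised per index, re-ranked, and converted to LangChain Documents
for doc in get_reranked_results(hits, search_text=search):
    print(doc.metadata["source"], "-", doc.metadata["title"])
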
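And a sketch of how the tool factory at the bottom of the module could be wired into an agent; `llm` here is an assumption standing in for any LangChain chat model that supports tool binding, not something this module provides.

from ask_candid.retrieval.elastic import retriever_tool

tool = retriever_tool(indices=["issuelab", "youtube", "candid_blog", "candid_learning", "candid_help"])

# Bind to a chat model for agent-style tool calling (llm is assumed to exist)
llm_with_tools = llm.bind_tools([tool])

# Or call the underlying function directly: because the tool is declared with
# response_format="content_and_artifact", it returns (concatenated text, Documents)
content, documents = tool.func("how do I write a grant proposal?")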