daryadm committed
Commit c844ac9 · verified · 1 Parent(s): a0802f4

Upload elastic.py

Files changed (1)
  1. ask_candid/retrieval/elastic.py +500 -500
ask_candid/retrieval/elastic.py CHANGED
@@ -1,500 +1,500 @@
from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
from dataclasses import dataclass
from functools import partial
from itertools import groupby

from torch.nn import functional as F

from pydantic import BaseModel, Field
from langchain_core.documents import Document
from langchain_core.tools import Tool

from elasticsearch import Elasticsearch

from ask_candid.services.small_lm import CandidSLM
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA
from ask_candid.base.config.data import ElasticIndexMapping, ALL_INDICES


@dataclass
class ElasticHitsResult:
    """Dataclass for Elasticsearch hits results"""
    index: str
    id: Any
    score: float
    source: Dict[str, Any]
    inner_hits: Dict[str, Any]


class RetrieverInput(BaseModel):
    """Input to the Elasticsearch retriever."""
    user_input: str = Field(description="query to look up in retriever")


def build_sparse_vector_query(
    query: str,
    fields: Tuple[str],
    inference_id: str = ".elser-2-elasticsearch"
) -> Dict[str, Any]:
    """Builds a valid Elasticsearch text expansion query payload

    Parameters
    ----------
    query : str
        Search context string
    fields : Tuple[str]
        Semantic text field names
    inference_id : str, optional
        ID of model deployed in Elasticsearch, by default ".elser-2-elasticsearch"

    Returns
    -------
    Dict[str, Any]
    """

    output = []

    for f in fields:
        output.append({
            "nested": {
                "path": f"embeddings.{f}.chunks",
                "query": {
                    "sparse_vector": {
                        "field": f"embeddings.{f}.chunks.vector",
                        "inference_id": inference_id,
                        "query": query,
                        "boost": 1 / len(fields)
                    }
                },
                "inner_hits": {
                    "_source": False,
                    "size": 2,
                    "fields": [f"embeddings.{f}.chunks.chunk"]
                }
            }
        })
    return {"query": {"bool": {"should": output}}}
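
As a quick illustration (not part of the uploaded file), here is the payload shape this produces for a single hypothetical field; the query string is invented and the structure follows directly from the loop above.

# Illustrative only: shape of the payload for fields=("content",); the query text is made up.
_example_payload = build_sparse_vector_query(query="what is a grant?", fields=("content",))
# _example_payload == {
#     "query": {"bool": {"should": [{
#         "nested": {
#             "path": "embeddings.content.chunks",
#             "query": {"sparse_vector": {
#                 "field": "embeddings.content.chunks.vector",
#                 "inference_id": ".elser-2-elasticsearch",
#                 "query": "what is a grant?",
#                 "boost": 1.0        # 1 / len(fields)
#             }},
#             "inner_hits": {"_source": False, "size": 2,
#                            "fields": ["embeddings.content.chunks.chunk"]}
#         }
#     }]}}
# }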


def query_builder(query: str, indices: List[str]) -> List[Dict[str, Any]]:
    """Builds Elasticsearch multi-search query payload

    Parameters
    ----------
    query : str
        Search context string
    indices : List[str]
        Semantic index names to search over

    Returns
    -------
    List[Dict[str, Any]]
    """

    queries = []
    if indices is None:
        indices = list(ALL_INDICES)

    for index in indices:
        if index == "issuelab":
            q = build_sparse_vector_query(
                query=query,
                fields=("description", "content", "combined_issuelab_findings", "combined_item_description")
            )
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 1
            queries.extend([{"index": ElasticIndexMapping.ISSUELAB_INDEX_ELSER}, q])
        elif index == "youtube":
            q = build_sparse_vector_query(
                query=query,
                fields=("captions_cleaned", "description_cleaned", "title")
            )
            # text_cleaned duplicates captions_cleaned
            q["_source"] = {"excludes": ["embeddings", "captions", "description", "text_cleaned"]}
            q["size"] = 2
            queries.extend([{"index": ElasticIndexMapping.YOUTUBE_INDEX_ELSER}, q])
        elif index == "candid_blog":
            q = build_sparse_vector_query(
                query=query,
                fields=("content", "authors_text", "title_summary_tags")
            )
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 2
            queries.extend([{"index": ElasticIndexMapping.CANDID_BLOG_INDEX_ELSER}, q])
        elif index == "candid_learning":
            q = build_sparse_vector_query(
                query=query,
                fields=("content", "title", "training_topics", "staff_recommendations")
            )
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 2
            queries.extend([{"index": ElasticIndexMapping.CANDID_LEARNING_INDEX_ELSER}, q])
        elif index == "candid_help":
            q = build_sparse_vector_query(
                query=query,
                fields=("content", "combined_article_description")
            )
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 2
            queries.extend([{"index": ElasticIndexMapping.CANDID_HELP_INDEX_ELSER}, q])

    return queries
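
For reference (again illustrative, with a made-up search string): the returned list alternates index headers and query bodies, which is the flat layout `Elasticsearch.msearch` consumes.

# Illustrative only: two requested indices yield two (header, body) pairs, i.e. four list entries.
_example_body = query_builder(query="funding for youth programs", indices=["issuelab", "youtube"])
# _example_body ~ [
#     {"index": ElasticIndexMapping.ISSUELAB_INDEX_ELSER}, {"query": {...}, "_source": {...}, "size": 1},
#     {"index": ElasticIndexMapping.YOUTUBE_INDEX_ELSER},  {"query": {...}, "_source": {...}, "size": 2},
# ]
assert len(_example_body) == 4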


def multi_search(queries: List[Dict[str, Any]]) -> List[ElasticHitsResult]:
    """Runs multi-search query

    Parameters
    ----------
    queries : List[Dict[str, Any]]
        Pre-built multi-search query payload

    Returns
    -------
    List[ElasticHitsResult]
    """

    results = []
    with Elasticsearch(
        cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
        api_key=SEMANTIC_ELASTIC_QA.api_key,
        verify_certs=False,
        request_timeout=60 * 3
    ) as es:
        for query_group in es.msearch(body=queries).get("responses", []):
            for hit in query_group.get("hits", {}).get("hits", []):
                hit = ElasticHitsResult(
                    index=hit["_index"],
                    id=hit["_id"],
                    score=hit["_score"],
                    source=hit["_source"],
                    inner_hits=hit.get("inner_hits", {})
                )
                results.append(hit)
    return results


def get_query_results(search_text: str, indices: Optional[List[str]] = None) -> List[ElasticHitsResult]:
    """Builds and executes Elasticsearch data queries from a search string.

    Parameters
    ----------
    search_text : str
        Search context string
    indices : Optional[List[str]], optional
        Semantic index names to search over, by default None

    Returns
    -------
    List[ElasticHitsResult]
    """

    queries = query_builder(query=search_text, indices=indices)
    return multi_search(queries)


def retrieved_text(hits: Dict[str, Any]) -> str:
    """Extracts retrieved sub-texts from documents which are strong hits from semantic queries for the purpose of
    re-scoring by a secondary language model.

    Parameters
    ----------
    hits : Dict[str, Any]

    Returns
    -------
    str
    """

    text = []
    for _, v in hits.items():
        for h in (v.get("hits", {}).get("hits") or []):
            for _, field in h.get("fields", {}).items():
                for chunk in field:
                    if chunk.get("chunk"):
                        text.extend(chunk["chunk"])
    return '\n'.join(text)
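
To make the nested traversal concrete, below is a minimal sketch of the `inner_hits` payload shape the function walks; the field name and chunk strings are invented.

# Illustrative only: a minimal inner_hits-style payload of the shape retrieved_text expects.
_example_inner_hits = {
    "embeddings.content.chunks": {
        "hits": {"hits": [
            {"fields": {"embeddings.content.chunks": [{"chunk": ["first retrieved passage"]}]}},
            {"fields": {"embeddings.content.chunks": [{"chunk": ["second retrieved passage"]}]}},
        ]}
    }
}
# retrieved_text(_example_inner_hits) -> "first retrieved passage\nsecond retrieved passage"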


def cosine_rescore(query: str, contexts: List[str]) -> List[float]:
    """Computes cosine scores between retrieved contexts and the original query to re-score results based on overall
    relevance to the original query.

    Parameters
    ----------
    query : str
        Search context string
    contexts : List[str]
        Semantic field sub-texts, order is by document retrieved from the original multi-search query.

    Returns
    -------
    List[float]
        Scores in the same order as the input document contexts
    """

    nlp = CandidSLM()
    X = nlp.encode([query, *contexts]).vectors
    X = F.normalize(X, dim=-1, p=2.)
    cosine = X[1:] @ X[:1].T
    return cosine.flatten().cpu().numpy().tolist()
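
The normalize-then-matmul step is ordinary cosine similarity; a self-contained sketch below checks that with random tensors standing in for the CandidSLM embeddings.

# Illustrative only: cosine via L2 normalisation + matrix product, with random stand-in vectors.
import torch
_vecs = torch.randn(4, 8)                      # row 0 plays the query, rows 1-3 the contexts
_unit = F.normalize(_vecs, dim=-1, p=2.)
_scores = (_unit[1:] @ _unit[:1].T).flatten()  # one cosine score per context
assert torch.allclose(
    _scores,
    torch.nn.functional.cosine_similarity(_vecs[1:], _vecs[:1].expand(3, -1), dim=-1),
    atol=1e-6
)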


def reranker(
    query_results: Iterable[ElasticHitsResult],
    search_text: Optional[str] = None
) -> Iterator[ElasticHitsResult]:
    """Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
    This will shuffle results

    Parameters
    ----------
    query_results : Iterable[ElasticHitsResult]

    Yields
    ------
    Iterator[ElasticHitsResult]
    """

    results: List[ElasticHitsResult] = []
    texts: List[str] = []
    for _, data in groupby(query_results, key=lambda x: x.index):
        data = list(data)
        max_score = max(data, key=lambda x: x.score).score
        min_score = min(data, key=lambda x: x.score).score

        for d in data:
            d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
            results.append(d)

            if search_text:
                text = retrieved_text(d.inner_hits)
                texts.append(text)

    # if search_text and len(texts) == len(results):
    #     scores = cosine_rescore(search_text, texts)
    #     for r, s in zip(results, scores):
    #         r.score = s

    yield from sorted(results, key=lambda x: x.score, reverse=True)
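
A quick worked example of the per-index min-max normalisation (the raw scores are made up): scores 2.0, 5.0 and 8.0 within one index map to roughly 0.0, 0.5 and 1.0, so hits from differently scaled indices become comparable before the global sort.

# Illustrative only: min-max normalisation on made-up raw scores from a single index.
_raw = [2.0, 5.0, 8.0]
_lo, _hi = min(_raw), max(_raw)
_normed = [(s - _lo) / (_hi - _lo + 1e-9) for s in _raw]   # -> ~[0.0, 0.5, 1.0] (up to the 1e-9 epsilon)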


def get_results(user_input: str, indices: List[str]) -> Tuple[str, List[Document]]:
    """End-to-end search and re-rank function.

    Parameters
    ----------
    user_input : str
        Search context string
    indices : List[str]
        Semantic index names to search over

    Returns
    -------
    Tuple[str, List[Document]]
        (concatenated text from search results, documents list)
    """

    output = ["Search didn't return any Candid sources"]
    page_content = []
    content = "Search didn't return any Candid sources"
    results = get_query_results(search_text=user_input, indices=indices)
    if results:
        output = get_reranked_results(results, search_text=user_input)
        for doc in output:
            page_content.append(doc.page_content)
        content = "\n\n".join(page_content)

    # for the tool we need to return a tuple for content_and_artifact type
    return content, output


# TODO make it better!
def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
    """Pads the relevant chunk of text with context before and after

    Parameters
    ----------
    field_name : str
        a field with the long text that was chunked into pieces
    hit : ElasticHitsResult
    context_length : int, optional
        length of text to add before and after the chunk, by default 1024

    Returns
    -------
    str
        longer chunks stuffed together
    """

    chunks = []
    # TODO chunks have tokens, but long text is a normal text, but may contain html that also gets weird after tokenization
    long_text = hit.source.get(f"{field_name}", "")
    long_text = long_text.lower()
    inner_hits_field = f"embeddings.{field_name}.chunks"
    found_chunks = hit.inner_hits.get(inner_hits_field, {})
    if found_chunks:
        hits = found_chunks.get("hits", {}).get("hits", [])
        for h in hits:
            chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]

            # cutting the middle because we may have tokenizing artifacts there
            chunk = chunk[3: -3]

            if add_context:
                # Find the start and end indices of the chunk in the large text
                start_index = long_text.find(chunk[:20])
                if start_index != -1:  # Chunk is found
                    end_index = start_index + len(chunk)
                    pre_start_index = max(0, start_index - context_length)
                    post_end_index = min(len(long_text), end_index + context_length)
                    chunks.append(long_text[pre_start_index:post_end_index])
            else:
                chunks.append(chunk)
    return '\n\n'.join(chunks)
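
A toy example of the windowing (all values invented): with a context length of 5 and a chunk found at position 10 of the source text, the returned span runs from five characters before the match through five characters past its end.

# Illustrative only: the windowing logic above, run on a toy string instead of an ElasticHitsResult.
_long_text = "aaaaabbbbbthe chunk textcccccddddd"
_chunk = "the chunk text"
_start = _long_text.find(_chunk[:20])          # 10
_end = _start + len(_chunk)                    # 24
_window = _long_text[max(0, _start - 5):min(len(_long_text), _end + 5)]
# _window == "bbbbbthe chunk textccccc"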


def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
    """Parse Elasticsearch hit results into data structures handled by the RAG pipeline.

    Parameters
    ----------
    hit : ElasticHitsResult

    Returns
    -------
    Union[Document, None]
    """

    if "issuelab-elser" in hit.index:
        combined_item_description = hit.source.get("combined_item_description", "")  # title inside
        description = hit.source.get("description", "")
        combined_issuelab_findings = hit.source.get("combined_issuelab_findings", "")
        # we only need to process long texts
        chunks_with_context_txt = get_context("content", hit, context_length=12)
        doc = Document(
            page_content='\n\n'.join([
                combined_item_description,
                combined_issuelab_findings,
                description,
                chunks_with_context_txt
            ]),
            metadata={
                "title": hit.source["title"],
                "source": "IssueLab",
                "source_id": hit.source["resource_id"],
                "url": hit.source.get("permalink", "")
            }
        )
    elif "youtube" in hit.index:
        title = hit.source.get("title", "")
        # we only need to process long texts
        description_cleaned_with_context_txt = get_context("description_cleaned", hit, context_length=12)
        captions_cleaned_with_context_txt = get_context("captions_cleaned", hit, context_length=12)
        doc = Document(
            page_content='\n\n'.join([title, description_cleaned_with_context_txt, captions_cleaned_with_context_txt]),
            metadata={
                "title": title,
                "source": "Candid YouTube",
                "source_id": hit.source['video_id'],
                "url": f"https://www.youtube.com/watch?v={hit.source['video_id']}"
            }
        )
    elif "candid-blog" in hit.index:
        excerpt = hit.source.get("excerpt", "")
        title = hit.source.get("title", "")
        # we only need to process long text
        content_with_context_txt = get_context("content", hit, context_length=12, add_context=False)
        authors = get_context("authors_text", hit, context_length=12, add_context=False)
        tags = hit.source.get("title_summary_tags", "")
        doc = Document(
            page_content='\n\n'.join([title, excerpt, content_with_context_txt, authors, tags]),
            metadata={
                "title": title,
                "source": "Candid Blog",
                "source_id": hit.source["id"],
                "url": hit.source["link"]
            }
        )
    elif "candid-learning" in hit.index:
        title = hit.source.get("title", "")
        content_with_context_txt = get_context("content", hit, context_length=12)
        training_topics = hit.source.get("training_topics", "")
        staff_recommendations = hit.source.get("staff_recommendations", "")

        doc = Document(
            page_content='\n\n'.join([title, staff_recommendations, training_topics, content_with_context_txt]),
            metadata={
                "title": hit.source["title"],
                "source": "Candid Learning",
                "source_id": hit.source["post_id"],
                "url": hit.source.get("url", "")
            }
        )
    elif "candid-help" in hit.index:
        title = hit.source.get("title", "")
        content_with_context_txt = get_context("content", hit, context_length=12)
        combined_article_description = hit.source.get("combined_article_description", "")

        doc = Document(
            page_content='\n\n'.join([combined_article_description, content_with_context_txt]),
            metadata={
                "title": title,
                "source": "Candid Help",
                "source_id": hit.source["id"],
                "url": hit.source.get("link", "")
            }
        )
    else:
        doc = None
    return doc


def get_reranked_results(results: List[ElasticHitsResult], search_text: Optional[str] = None) -> List[Document]:
    """Run data re-ranking and document building for tool usage.

    Parameters
    ----------
    results : List[ElasticHitsResult]
    search_text : Optional[str], optional
        Search context string, by default None

    Returns
    -------
    List[Document]
    """

    output = []
    for r in reranker(results, search_text=search_text):
        hit = process_hit(r)
        if hit is not None:
            output.append(hit)
    return output


def retriever_tool(indices: List[str]) -> Tool:
    """Tool component for use in conditional edge building for RAG execution graph.
    Cannot use `create_retriever_tool` because it only provides content losing all metadata on the way
    https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution

    Parameters
    ----------
    indices : List[str]
        Semantic index names to search over

    Returns
    -------
    Tool
    """

    return Tool(
        name="retrieve_social_sector_information",
        func=partial(get_results, indices=indices),
        description=(
            "Return additional information about social and philanthropic sector, "
            "including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
        ),
        args_schema=RetrieverInput,
        response_format="content_and_artifact"
    )
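
Finally, a usage sketch that is not part of the commit: the index names below are the keys `query_builder` recognises, the search string is invented, and running it requires live Elasticsearch credentials.

# Illustrative only: build the tool, then run one retrieval through the same function it wraps.
_tool = retriever_tool(indices=["issuelab", "candid_blog", "candid_help"])
_content, _artifact_docs = get_results(
    user_input="capacity building grants for small nonprofits",
    indices=["issuelab", "candid_blog", "candid_help"]
)
# `_content` is the concatenated page_content string handed to the LLM;
# `_artifact_docs` keeps the langchain Documents (title, URL, source_id metadata) as the tool artifact.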
 