from typing import Literal, Any
from collections.abc import Iterator, Iterable
from itertools import groupby
import logging

from langchain_core.documents import Document

from ask_candid.base.retrieval.elastic import (
    build_sparse_vector_query,
    build_sparse_vector_and_text_query,
    news_query_builder,
    multi_search_base
)
from ask_candid.base.retrieval.sparse_lexical import SpladeEncoder
from ask_candid.base.retrieval.schemas import ElasticHitsResult
import ask_candid.base.retrieval.sources as S
from ask_candid.services.small_lm import CandidSLM
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
SourceNames = Literal[
    "Candid Blog",
    "Candid Help",
    "Candid Learning",
    "Candid News",
    "IssueLab Research Reports",
    "YouTube Training"
]

sparse_encoder = SpladeEncoder()

logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


# TODO remove
def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
    """Pads the relevant chunk of text with context before and after it.

    Parameters
    ----------
    field_name : str
        A field with the long text that was chunked into pieces
    hit : ElasticHitsResult
        A single search hit
    context_length : int, optional
        Length of text to add before and after the chunk, by default 1024
    add_context : bool, optional
        Set to `False` to skip expanding the text context by searching for the Elastic inner hit inside the larger
        document, by default True

    Returns
    -------
    str
        Longer chunks stuffed together
    """
    chunks = []
    # NOTE chunks have tokens, long text is a string, but may contain HTML which affects tokenization
    long_text = hit.source.get(field_name) or ""
    long_text = long_text.lower()

    inner_hits_field = f"embeddings.{field_name}.chunks"
    found_chunks = hit.inner_hits.get(inner_hits_field, {}) if hit.inner_hits else None
    if found_chunks:
        for h in found_chunks.get("hits", {}).get("hits") or []:
            chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]

            # cutting the middle because we may have tokenizing artifacts there
            chunk = chunk[3:-3]

            if add_context:
                # find the start and end indices of the chunk in the large text; the probe is lower-cased so it
                # can match against the lower-cased long text
                start_index = long_text.find(chunk[:20].lower())
                if start_index != -1:  # chunk is found
                    end_index = start_index + len(chunk)
                    pre_start_index = max(0, start_index - context_length)
                    post_end_index = min(len(long_text), end_index + context_length)
                    chunks.append(long_text[pre_start_index:post_end_index])
            else:
                chunks.append(chunk)
    return '\n\n'.join(chunks)
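
# Illustrative sketch (not part of the module's API): how `get_context` pads a matched chunk. The index name,
# field contents, and hit values below are made up for demonstration; the constructor kwargs mirror how
# `run_search` builds `ElasticHitsResult` records.
#
#     hit = ElasticHitsResult(
#         index="candid-blog",
#         id="1",
#         score=1.0,
#         source={"content": "... some very long article body text surrounding the matched span ..."},
#         inner_hits={"embeddings.content.chunks": {"hits": {"hits": [
#             {"fields": {"embeddings.content.chunks": [{"chunk": ["long article body text"]}]}}
#         ]}}},
#         highlight={}
#     )
#     # returns the matched chunk plus up to 50 characters of surrounding context on each side
#     padded = get_context("content", hit, context_length=50)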
def generate_queries(
    query: str,
    sources: list[SourceNames],
    news_days_ago: int = 60
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Builds Elastic queries against indices which do or do not support sparse vector queries.

    Parameters
    ----------
    query : str
        Text describing a user's question, or a description of investigative work which requires support from
        Candid's knowledge base
    sources : list[SourceNames]
        One or more sources of knowledge from different areas at Candid.
        * Candid Blog: Blog posts from Candid staff and trusted partners intended to help those in the sector or
          illuminate ongoing work
        * Candid Help: Candid FAQs to help users get started with Candid's product platform and learning resources
        * Candid Learning: Training documents from Candid's subject matter experts
        * Candid News: News articles and press releases about real-time activity in the philanthropic sector
        * IssueLab Research Reports: Academic research reports about the social/philanthropic sector
        * YouTube Training: Transcripts from video-based training seminars from Candid's subject matter experts
    news_days_ago : int, optional
        How many days in the past to search for news articles. If a user is asking about recent trends then this
        value should be set lower (around 10), by default 60

    Returns
    -------
    tuple[list[dict[str, Any]], list[dict[str, Any]]]
        (sparse vector queries, queries for indices which do not support sparse vectors)
    """
    vector_queries = []
    quasi_vector_queries = []
    for source_name in sources:
        if source_name == "Candid Blog":
            q = build_sparse_vector_query(query=query, fields=S.CandidBlogConfig.semantic_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 5
            vector_queries.extend([{"index": S.CandidBlogConfig.index_name}, q])
        elif source_name == "Candid Help":
            q = build_sparse_vector_query(query=query, fields=S.CandidHelpConfig.semantic_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 5
            vector_queries.extend([{"index": S.CandidHelpConfig.index_name}, q])
        elif source_name == "Candid Learning":
            q = build_sparse_vector_query(query=query, fields=S.CandidLearningConfig.semantic_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 5
            vector_queries.extend([{"index": S.CandidLearningConfig.index_name}, q])
        elif source_name == "Candid News":
            q = news_query_builder(
                query=query,
                fields=S.CandidNewsConfig.semantic_fields,
                encoder=sparse_encoder,
                days_ago=news_days_ago
            )
            q["size"] = 5
            quasi_vector_queries.extend([{"index": S.CandidNewsConfig.index_name}, q])
        elif source_name == "IssueLab Research Reports":
            q = build_sparse_vector_query(query=query, fields=S.IssueLabConfig.semantic_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 1
            vector_queries.extend([{"index": S.IssueLabConfig.index_name}, q])
        elif source_name == "YouTube Training":
            q = build_sparse_vector_and_text_query(
                query=query,
                semantic_fields=S.YoutubeConfig.semantic_fields,
                text_fields=S.YoutubeConfig.text_fields,
                highlight_fields=S.YoutubeConfig.highlight_fields,
                excluded_fields=S.YoutubeConfig.excluded_fields
            )
            q["size"] = 5
            vector_queries.extend([{"index": S.YoutubeConfig.index_name}, q])
    return vector_queries, quasi_vector_queries
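
# Usage sketch (the query text and source selection are made up): each returned list is already interleaved in
# Elasticsearch msearch order, i.e. [{"index": ...}, {<query body>}, {"index": ...}, {<query body>}, ...], so it
# can be passed straight to `run_search` below.
#
#     vector_qs, news_qs = generate_queries(
#         query="capacity building for rural nonprofits",
#         sources=["Candid Blog", "Candid News"],
#         news_days_ago=30
#     )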
def run_search(
    vector_searches: list[dict[str, Any]] | None = None,
    non_vector_searches: list[dict[str, Any]] | None = None,
) -> list[ElasticHitsResult]:
    """Runs multi-searches against the semantic and news Elasticsearch clusters and flattens the responses.

    Parameters
    ----------
    vector_searches : list[dict[str, Any]] | None, optional
        Sparse vector msearch header/body pairs, by default None
    non_vector_searches : list[dict[str, Any]] | None, optional
        msearch header/body pairs for indices which do not support sparse vectors, by default None

    Returns
    -------
    list[ElasticHitsResult]
    """
    def _msearch_response_generator(responses: Iterable[dict[str, Any]]) -> Iterator[ElasticHitsResult]:
        for query_group in responses:
            for h in query_group.get("hits", {}).get("hits", []):
                inner_hits = h.get("inner_hits", {})
                # news documents carry no inner hits, so fall back to the raw article body
                if not inner_hits and "news" in h.get("_index", ""):
                    inner_hits = {"text": h.get("_source", {}).get("content")}
                yield ElasticHitsResult(
                    index=h["_index"],
                    id=h["_id"],
                    score=h["_score"],
                    source=h["_source"],
                    inner_hits=inner_hits,
                    highlight=h.get("highlight", {})
                )
    results = []
    if vector_searches:
        hits = multi_search_base(queries=vector_searches, credentials=SEMANTIC_ELASTIC_QA)
        results.extend(_msearch_response_generator(responses=hits))
    if non_vector_searches:
        hits = multi_search_base(queries=non_vector_searches, credentials=NEWS_ELASTIC)
        results.extend(_msearch_response_generator(responses=hits))
    return results
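
# Continuing the sketch above `generate_queries` (assumes the SEMANTIC_ELASTIC_QA and NEWS_ELASTIC credentials
# point at reachable clusters): `run_search` executes both msearch payloads and flattens every response into
# `ElasticHitsResult` records.
#
#     hits = run_search(vector_searches=vector_qs, non_vector_searches=news_qs)
#     for hit in hits:
#         print(hit.index, hit.id, hit.score)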
def retrieved_text(hits: dict[str, Any]) -> str:
    """Extracts retrieved sub-texts from documents which are strong hits from semantic queries, for the purpose of
    re-scoring by a secondary language model.

    Parameters
    ----------
    hits : dict[str, Any]
        Inner hits from a single Elasticsearch result

    Returns
    -------
    str
        Retrieved sub-texts joined by newlines
    """
    nlp = CandidSLM()

    text = []
    for key, v in hits.items():
        # news results carry the raw article body instead of inner hits; summarize it down to a few sentences
        if key == "text":
            s = nlp.summarize(v, top_k=3)
            text.append(s.summary)
            continue

        for h in (v.get("hits", {}).get("hits") or []):
            for field in h.get("fields", {}).values():
                for chunk in field:
                    if chunk.get("chunk"):
                        text.extend(chunk["chunk"])
    return '\n'.join(text)
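
# Inner-hit shape assumed by `retrieved_text` (a hedged illustration; the key and chunk values are made up, and
# mirror what `run_search` yields for sparse vector indices):
#
#     retrieved_text({
#         "embeddings.content.chunks": {"hits": {"hits": [
#             {"fields": {"embeddings.content.chunks": [{"chunk": ["first chunk", "second chunk"]}]}}
#         ]}}
#     })
#     # -> "first chunk\nsecond chunk"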
def reranker(
    query_results: Iterable[ElasticHitsResult],
    search_text: str | None = None,
    max_num_results: int = 5
) -> Iterator[ElasticHitsResult]:
    """Re-ranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
    Note that this will re-order results relative to the input.

    Parameters
    ----------
    query_results : Iterable[ElasticHitsResult]
        Hits from one or more Elasticsearch queries
    search_text : str | None, optional
        If supplied, the hits' retrieved texts are re-scored against this text by the sparse encoder, by default
        None
    max_num_results : int, optional
        Maximum number of hits to yield, by default 5

    Yields
    ------
    Iterator[ElasticHitsResult]
    """
    results: list[ElasticHitsResult] = []
    texts: list[str] = []

    # NOTE `groupby` only groups consecutive items, which works here because msearch responses arrive already
    # grouped by index
    for _, data in groupby(query_results, key=lambda x: x.index):
        data = list(data)  # noqa: PLW2901

        # min-max normalize scores within each index so that scores from different indices are comparable
        max_score = max(data, key=lambda x: x.score).score
        min_score = min(data, key=lambda x: x.score).score
        for d in data:
            d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
            results.append(d)

            if search_text:
                text = ""
                if d.inner_hits:
                    text = retrieved_text(d.inner_hits)
                if d.highlight:
                    # prefer highlighted snippets when they are available
                    text = '\n'.join('\n'.join(v) for v in d.highlight.values())
                texts.append(text)

    if search_text and len(texts) == len(results) and len(texts) > 1:
        logger.info("Re-ranking %d retrieval results", len(results))
        scores = sparse_encoder.query_reranking(query=search_text, documents=texts)
        for r, s in zip(results, scores):
            r.score = s

    yield from sorted(results, key=lambda x: x.score, reverse=True)[:max_num_results]
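
# Illustrative sketch of the per-index min-max normalization used above (the raw scores are made up):
#
#     scores = [12.0, 7.5, 3.0]                        # raw scores from one index
#     lo, hi = min(scores), max(scores)
#     normed = [(s - lo) / (hi - lo + 1e-9) for s in scores]
#     # -> [~1.0, 0.5, 0.0], comparable with normalized scores from any other index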
def process_hit(hit: ElasticHitsResult) -> Document:
    """Converts an Elasticsearch hit into a LangChain `Document`, assembling the page content and metadata
    appropriate to the source index.

    Parameters
    ----------
    hit : ElasticHitsResult

    Returns
    -------
    Document

    Raises
    ------
    ValueError
        If the hit comes from an unrecognized index
    """
if "issuelab-elser" in hit.index: | |
doc = Document( | |
page_content='\n\n'.join([ | |
hit.source.get("combined_item_description", ""), | |
hit.source.get("description", ""), | |
hit.source.get("combined_issuelab_findings", ""), | |
get_context("content", hit, context_length=12) | |
]), | |
metadata={ | |
"title": hit.source["title"], | |
"source": "IssueLab", | |
"source_id": hit.source["resource_id"], | |
"url": hit.source.get("permalink", "") | |
} | |
) | |
elif "youtube" in hit.index: | |
highlight = hit.highlight or {} | |
doc = Document( | |
page_content='\n\n'.join([ | |
hit.source.get("title", ""), | |
hit.source.get("semantic_description", ""), | |
' '.join(highlight.get("semantic_cc_text", [])) | |
]), | |
metadata={ | |
"title": hit.source.get("title", ""), | |
"source": "Candid YouTube", | |
"source_id": hit.source['video_id'], | |
"url": f"https://www.youtube.com/watch?v={hit.source['video_id']}" | |
} | |
) | |
elif "candid-blog" in hit.index: | |
doc = Document( | |
page_content='\n\n'.join([ | |
hit.source.get("title", ""), | |
hit.source.get("excerpt", ""), | |
get_context("content", hit, context_length=12, add_context=False), | |
get_context("authors_text", hit, context_length=12, add_context=False), | |
hit.source.get("title_summary_tags", "") | |
]), | |
metadata={ | |
"title": hit.source.get("title", ""), | |
"source": "Candid Blog", | |
"source_id": hit.source["id"], | |
"url": hit.source["link"] | |
} | |
) | |
elif "candid-learning" in hit.index: | |
doc = Document( | |
page_content='\n\n'.join([ | |
hit.source.get("title", ""), | |
hit.source.get("staff_recommendations", ""), | |
hit.source.get("training_topics", ""), | |
get_context("content", hit, context_length=12) | |
]), | |
metadata={ | |
"title": hit.source["title"], | |
"source": "Candid Learning", | |
"source_id": hit.source["post_id"], | |
"url": hit.source.get("url", "") | |
} | |
) | |
elif "candid-help" in hit.index: | |
doc = Document( | |
page_content='\n\n'.join([ | |
hit.source.get("combined_article_description", ""), | |
get_context("content", hit, context_length=12) | |
]), | |
metadata={ | |
"title": hit.source.get("title", ""), | |
"source": "Candid Help", | |
"source_id": hit.source["id"], | |
"url": hit.source.get("link", "") | |
} | |
) | |
elif "news" in hit.index: | |
doc = Document( | |
page_content='\n\n'.join([hit.source.get("title", ""), hit.source.get("content", "")]), | |
metadata={ | |
"title": hit.source.get("title", ""), | |
"source": hit.source.get("site_name") or "Candid News", | |
"source_id": hit.source["id"], | |
"url": hit.source.get("link", "") | |
} | |
) | |
else: | |
raise ValueError(f"Unknown source result from index {hit.index}") | |
return doc | |
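
# Continuing the sketch: retrieved hits can be re-ranked against the original query text and converted into
# LangChain documents for downstream use (results depend on the configured clusters).
#
#     top_hits = reranker(hits, search_text="capacity building for rural nonprofits", max_num_results=5)
#     documents = [process_hit(h) for h in top_hits]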