# ask_candid/base/retrieval/knowledge_base.py
from typing import Literal, Any
from collections.abc import Iterator, Iterable
from itertools import groupby
import logging
from langchain_core.documents import Document
from ask_candid.base.retrieval.elastic import (
build_sparse_vector_query,
build_sparse_vector_and_text_query,
news_query_builder,
multi_search_base
)
from ask_candid.base.retrieval.sparse_lexical import SpladeEncoder
from ask_candid.base.retrieval.schemas import ElasticHitsResult
import ask_candid.base.retrieval.sources as S
from ask_candid.services.small_lm import CandidSLM
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
SourceNames = Literal[
"Candid Blog",
"Candid Help",
"Candid Learning",
"Candid News",
"IssueLab Research Reports",
"YouTube Training"
]
sparse_encoder = SpladeEncoder()
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# TODO remove
def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
"""Pads the relevant chunk of text with context before and after
Parameters
----------
field_name : str
a field with the long text that was chunked into pieces
hit : ElasticHitsResult
context_length : int, optional
length of text to add before and after the chunk, by default 1024
add_context : bool, optional
Set to `False` to expand the text context by searching for the Elastic inner hit inside the larger document
, by default True
Returns
-------
str
longer chunks stuffed together
"""
chunks = []
# NOTE chunks have tokens, long text is a string, but may contain html which affects tokenization
long_text = hit.source.get(field_name) or ""
long_text = long_text.lower()
inner_hits_field = f"embeddings.{field_name}.chunks"
found_chunks = hit.inner_hits.get(inner_hits_field, {}) if hit.inner_hits else None
if found_chunks:
for h in found_chunks.get("hits", {}).get("hits") or []:
chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]
            # trim the chunk edges because tokenization artifacts may appear there
            chunk = chunk[3: -3]
if add_context:
# Find the start and end indices of the chunk in the large text
start_index = long_text.find(chunk[:20])
# Chunk is found
if start_index != -1:
end_index = start_index + len(chunk)
pre_start_index = max(0, start_index - context_length)
post_end_index = min(len(long_text), end_index + context_length)
chunks.append(long_text[pre_start_index:post_end_index])
else:
chunks.append(chunk)
return '\n\n'.join(chunks)
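
# Illustrative sketch of `get_context` (hypothetical hit; field values are made up):
#
#   hit = ElasticHitsResult(
#       index="candid-blog", id="1", score=1.0,
#       source={"content": "... full article text containing the matched chunk ..."},
#       inner_hits={"embeddings.content.chunks": {"hits": {"hits": [
#           {"fields": {"embeddings.content.chunks": [{"chunk": ["article text containing the matched"]}]}}
#       ]}}},
#       highlight={},
#   )
#   get_context("content", hit, context_length=100)  # chunk padded with up to 100 characters on each side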
def generate_queries(
query: str,
sources: list[SourceNames],
news_days_ago: int = 60
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""Builds Elastic queries against indices which do or do not support sparse vector queries.
Parameters
----------
query : str
Text describing a user's question or a description of investigative work which requires support from Candid's
knowledge base
sources : list[SourceNames]
One or more sources of knowledge from different areas at Candid.
* Candid Blog: Blog posts from Candid staff and trusted partners intended to help those in the sector or
illuminate ongoing work
* Candid Help: Candid FAQs to help user's get started with Candid's product platform and learning resources
* Candid Learning: Training documents from Candid's subject matter experts
* Candid News: News articles and press releases about real-time activity in the philanthropic sector
* IssueLab Research Reports: Academic research reports about the social/philanthropic sector
* YouTube Training: Transcripts from video-based training seminars from Candid's subject matter experts
news_days_ago : int, optional
How many days in the past to search for news articles, if a user is asking for recent trends then this value
should be set lower >~ 10, by default 60
Returns
-------
tuple[list[dict[str, Any]], list[dict[str, Any]]]
(sparse vector queries, queries for indices which do not support sparse vectors)
"""
vector_queries = []
quasi_vector_queries = []
for source_name in sources:
if source_name == "Candid Blog":
q = build_sparse_vector_query(query=query, fields=S.CandidBlogConfig.semantic_fields)
q["_source"] = {"excludes": ["embeddings"]}
q["size"] = 5
vector_queries.extend([{"index": S.CandidBlogConfig.index_name}, q])
elif source_name == "Candid Help":
q = build_sparse_vector_query(query=query, fields=S.CandidHelpConfig.semantic_fields)
q["_source"] = {"excludes": ["embeddings"]}
q["size"] = 5
vector_queries.extend([{"index": S.CandidHelpConfig.index_name}, q])
elif source_name == "Candid Learning":
q = build_sparse_vector_query(query=query, fields=S.CandidLearningConfig.semantic_fields)
q["_source"] = {"excludes": ["embeddings"]}
q["size"] = 5
vector_queries.extend([{"index": S.CandidLearningConfig.index_name}, q])
elif source_name == "Candid News":
q = news_query_builder(
query=query,
fields=S.CandidNewsConfig.semantic_fields,
encoder=sparse_encoder,
days_ago=news_days_ago
)
q["size"] = 5
quasi_vector_queries.extend([{"index": S.CandidNewsConfig.index_name}, q])
elif source_name == "IssueLab Research Reports":
q = build_sparse_vector_query(query=query, fields=S.IssueLabConfig.semantic_fields)
q["_source"] = {"excludes": ["embeddings"]}
q["size"] = 1
vector_queries.extend([{"index": S.IssueLabConfig.index_name}, q])
elif source_name == "YouTube Training":
q = build_sparse_vector_and_text_query(
query=query,
semantic_fields=S.YoutubeConfig.semantic_fields,
text_fields=S.YoutubeConfig.text_fields,
highlight_fields=S.YoutubeConfig.highlight_fields,
excluded_fields=S.YoutubeConfig.excluded_fields
)
q["size"] = 5
vector_queries.extend([{"index": S.YoutubeConfig.index_name}, q])
return vector_queries, quasi_vector_queries
def run_search(
vector_searches: list[dict[str, Any]] | None = None,
non_vector_searches: list[dict[str, Any]] | None = None,
) -> list[ElasticHitsResult]:
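    """Executes multi-search requests against the semantic (sparse vector) cluster and the news cluster.

    Parameters
    ----------
    vector_searches : list[dict[str, Any]] | None, optional
        Interleaved index headers and query bodies for indices which support sparse vector search, by default None
    non_vector_searches : list[dict[str, Any]] | None, optional
        Interleaved index headers and query bodies for indices which do not, by default None

    Returns
    -------
    list[ElasticHitsResult]
        Flattened hits from all responses
    """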
def _msearch_response_generator(responses: Iterable[dict[str, Any]]) -> Iterator[ElasticHitsResult]:
for query_group in responses:
for h in query_group.get("hits", {}).get("hits", []):
inner_hits = h.get("inner_hits", {})
if not inner_hits and "news" in h.get("_index"):
inner_hits = {"text": h.get("_source", {}).get("content")}
yield ElasticHitsResult(
index=h["_index"],
id=h["_id"],
score=h["_score"],
source=h["_source"],
inner_hits=inner_hits,
highlight=h.get("highlight", {})
)
results = []
if vector_searches is not None and len(vector_searches) > 0:
hits = multi_search_base(queries=vector_searches, credentials=SEMANTIC_ELASTIC_QA)
for hit in _msearch_response_generator(responses=hits):
results.append(hit)
if non_vector_searches is not None and len(non_vector_searches) > 0:
hits = multi_search_base(queries=non_vector_searches, credentials=NEWS_ELASTIC)
for hit in _msearch_response_generator(responses=hits):
results.append(hit)
return results
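
# Illustrative usage sketch (requires live Elasticsearch credentials configured in
# ask_candid.base.config.connections):
#
#   vector_q, non_vector_q = generate_queries(
#       query="trends in climate philanthropy",
#       sources=["Candid News", "IssueLab Research Reports"],
#       news_days_ago=30,
#   )
#   hits = run_search(vector_searches=vector_q, non_vector_searches=non_vector_q)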
def retrieved_text(hits: dict[str, Any]) -> str:
"""Extracts retrieved sub-texts from documents which are strong hits from semantic queries for the purpose of
re-scoring by a secondary language model.
Parameters
----------
hits : Dict[str, Any]
Returns
-------
str
"""
    nlp = CandidSLM()
    text = []
    for key, value in hits.items():
        if key == "text":
            s = nlp.summarize(value, top_k=3)
            text.append(s.summary)
            # text.append(value)
            continue
        for h in (value.get("hits", {}).get("hits") or []):
            for field in h.get("fields", {}).values():
                for chunk in field:
                    if chunk.get("chunk"):
                        text.extend(chunk["chunk"])
    return '\n'.join(text)
def reranker(
query_results: Iterable[ElasticHitsResult],
search_text: str | None = None,
max_num_results: int = 5
) -> Iterator[ElasticHitsResult]:
"""Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
This will shuffle results
Parameters
----------
query_results : Iterable[ElasticHitsResult]
Yields
------
Iterator[ElasticHitsResult]
"""
results: list[ElasticHitsResult] = []
texts: list[str] = []
for _, data in groupby(query_results, key=lambda x: x.index):
data = list(data) # noqa: PLW2901
max_score = max(data, key=lambda x: x.score).score
min_score = min(data, key=lambda x: x.score).score
for d in data:
d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
results.append(d)
            if search_text:
                text = ""
                if d.inner_hits:
                    text = retrieved_text(d.inner_hits)
                if d.highlight:
                    # highlighted snippets take precedence over the reconstructed inner-hit text
                    highlight_texts = ['\n'.join(v) for v in d.highlight.values()]
                    text = '\n'.join(highlight_texts)
                texts.append(text)
if search_text and len(texts) == len(results) and len(texts) > 1:
logger.info("Re-ranking %d retrieval results", len(results))
scores = sparse_encoder.query_reranking(query=search_text, documents=texts)
for r, s in zip(results, scores):
r.score = s
yield from sorted(results, key=lambda x: x.score, reverse=True)[:max_num_results]
def process_hit(hit: ElasticHitsResult) -> Document:
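    """Converts a single Elasticsearch hit into a LangChain `Document`, dispatching on the source index.

    Parameters
    ----------
    hit : ElasticHitsResult

    Raises
    ------
    ValueError
        If the hit comes from an unrecognized index

    Returns
    -------
    Document
        Page content assembled from the hit's fields, with title/source/url metadata
    """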
if "issuelab-elser" in hit.index:
doc = Document(
page_content='\n\n'.join([
hit.source.get("combined_item_description", ""),
hit.source.get("description", ""),
hit.source.get("combined_issuelab_findings", ""),
get_context("content", hit, context_length=12)
]),
metadata={
"title": hit.source["title"],
"source": "IssueLab",
"source_id": hit.source["resource_id"],
"url": hit.source.get("permalink", "")
}
)
elif "youtube" in hit.index:
highlight = hit.highlight or {}
doc = Document(
page_content='\n\n'.join([
hit.source.get("title", ""),
hit.source.get("semantic_description", ""),
' '.join(highlight.get("semantic_cc_text", []))
]),
metadata={
"title": hit.source.get("title", ""),
"source": "Candid YouTube",
"source_id": hit.source['video_id'],
"url": f"https://www.youtube.com/watch?v={hit.source['video_id']}"
}
)
elif "candid-blog" in hit.index:
doc = Document(
page_content='\n\n'.join([
hit.source.get("title", ""),
hit.source.get("excerpt", ""),
get_context("content", hit, context_length=12, add_context=False),
get_context("authors_text", hit, context_length=12, add_context=False),
hit.source.get("title_summary_tags", "")
]),
metadata={
"title": hit.source.get("title", ""),
"source": "Candid Blog",
"source_id": hit.source["id"],
"url": hit.source["link"]
}
)
elif "candid-learning" in hit.index:
doc = Document(
page_content='\n\n'.join([
hit.source.get("title", ""),
hit.source.get("staff_recommendations", ""),
hit.source.get("training_topics", ""),
get_context("content", hit, context_length=12)
]),
metadata={
"title": hit.source["title"],
"source": "Candid Learning",
"source_id": hit.source["post_id"],
"url": hit.source.get("url", "")
}
)
elif "candid-help" in hit.index:
doc = Document(
page_content='\n\n'.join([
hit.source.get("combined_article_description", ""),
get_context("content", hit, context_length=12)
]),
metadata={
"title": hit.source.get("title", ""),
"source": "Candid Help",
"source_id": hit.source["id"],
"url": hit.source.get("link", "")
}
)
elif "news" in hit.index:
doc = Document(
page_content='\n\n'.join([hit.source.get("title", ""), hit.source.get("content", "")]),
metadata={
"title": hit.source.get("title", ""),
"source": hit.source.get("site_name") or "Candid News",
"source_id": hit.source["id"],
"url": hit.source.get("link", "")
}
)
else:
raise ValueError(f"Unknown source result from index {hit.index}")
return doc
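
# Illustrative end-to-end sketch continuing from the `run_search` example above
# (requires live Elasticsearch clusters and the Candid small LM service):
#
#   ranked = reranker(hits, search_text="trends in climate philanthropy", max_num_results=5)
#   documents = [process_hit(hit) for hit in ranked]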