from typing import Literal, Any
from collections.abc import Iterator, Iterable
from itertools import groupby
import logging

from langchain_core.documents import Document

from ask_candid.base.retrieval.elastic import (
    build_sparse_vector_query,
    build_sparse_vector_and_text_query,
    news_query_builder,
    multi_search_base
)
from ask_candid.base.retrieval.sparse_lexical import SpladeEncoder
from ask_candid.base.retrieval.schemas import ElasticHitsResult
import ask_candid.base.retrieval.sources as S
from ask_candid.services.small_lm import CandidSLM
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC

SourceNames = Literal[
    "Candid Blog",
    "Candid Help",
    "Candid Learning",
    "Candid News",
    "IssueLab Research Reports",
    "YouTube Training"
]
sparse_encoder = SpladeEncoder()

logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


# TODO remove
def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
    """Pads the relevant chunk of text with context before and after it.

    Parameters
    ----------
    field_name : str
        A field with the long text that was chunked into pieces
    hit : ElasticHitsResult
    context_length : int, optional
        Length of text to add before and after the chunk, by default 1024
    add_context : bool, optional
        Set to `False` to skip expanding the text context by searching for the Elastic inner hit inside the larger
        document, by default True

    Returns
    -------
    str
        Longer chunks stuffed together
    """

    chunks = []
    # NOTE chunks have tokens, long text is a string, but may contain html which affects tokenization
    long_text = hit.source.get(field_name) or ""
    long_text = long_text.lower()
    inner_hits_field = f"embeddings.{field_name}.chunks"
    found_chunks = hit.inner_hits.get(inner_hits_field, {}) if hit.inner_hits else None
    if found_chunks:
        for h in found_chunks.get("hits", {}).get("hits") or []:
            chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]

            # cutting the middle because we may have tokenizing artifacts there
            chunk = chunk[3:-3]

            if add_context:
                # Find the start and end indices of the chunk in the large text
                start_index = long_text.find(chunk[:20])

                # Chunk is found
                if start_index != -1:
                    end_index = start_index + len(chunk)
                    pre_start_index = max(0, start_index - context_length)
                    post_end_index = min(len(long_text), end_index + context_length)
                    chunks.append(long_text[pre_start_index:post_end_index])
            else:
                chunks.append(chunk)
    return '\n\n'.join(chunks)
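# Illustrative sketch (toy strings, not part of the retrieval flow): the padding above amounts to locating the
# retrieved chunk inside the full document text and slicing `context_length` characters on either side of it.
#
#   long_text = "... the quick brown fox jumps over the lazy dog ..."
#   chunk = "brown fox jumps"
#   start = long_text.find(chunk[:20])                            # -> 14
#   padded = long_text[max(0, start - 8): start + len(chunk) + 8]
#   # padded == "e quick brown fox jumps over th"  (8 characters of context on each side)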
def generate_queries(
    query: str,
    sources: list[SourceNames],
    news_days_ago: int = 60
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Builds Elastic queries against indices which do or do not support sparse vector queries.

    Parameters
    ----------
    query : str
        Text describing a user's question or a description of investigative work which requires support from Candid's
        knowledge base
    sources : list[SourceNames]
        One or more sources of knowledge from different areas at Candid.
        * Candid Blog: Blog posts from Candid staff and trusted partners intended to help those in the sector or
          illuminate ongoing work
        * Candid Help: Candid FAQs to help users get started with Candid's product platform and learning resources
        * Candid Learning: Training documents from Candid's subject matter experts
        * Candid News: News articles and press releases about real-time activity in the philanthropic sector
        * IssueLab Research Reports: Academic research reports about the social/philanthropic sector
        * YouTube Training: Transcripts from video-based training seminars from Candid's subject matter experts
    news_days_ago : int, optional
        How many days in the past to search for news articles; if a user is asking for recent trends then this value
        should be set lower (~10), by default 60

    Returns
    -------
    tuple[list[dict[str, Any]], list[dict[str, Any]]]
        (sparse vector queries, queries for indices which do not support sparse vectors)
    """

    vector_queries = []
    quasi_vector_queries = []
    for source_name in sources:
        if source_name == "Candid Blog":
            q = build_sparse_vector_query(query=query, fields=S.CandidBlogConfig.semantic_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 5
            vector_queries.extend([{"index": S.CandidBlogConfig.index_name}, q])
        elif source_name == "Candid Help":
            q = build_sparse_vector_query(query=query, fields=S.CandidHelpConfig.semantic_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 5
            vector_queries.extend([{"index": S.CandidHelpConfig.index_name}, q])
        elif source_name == "Candid Learning":
            q = build_sparse_vector_query(query=query, fields=S.CandidLearningConfig.semantic_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 5
            vector_queries.extend([{"index": S.CandidLearningConfig.index_name}, q])
        elif source_name == "Candid News":
            q = news_query_builder(
                query=query,
                fields=S.CandidNewsConfig.semantic_fields,
                encoder=sparse_encoder,
                days_ago=news_days_ago
            )
            q["size"] = 5
            quasi_vector_queries.extend([{"index": S.CandidNewsConfig.index_name}, q])
        elif source_name == "IssueLab Research Reports":
            q = build_sparse_vector_query(query=query, fields=S.IssueLabConfig.semantic_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 1
            vector_queries.extend([{"index": S.IssueLabConfig.index_name}, q])
        elif source_name == "YouTube Training":
            q = build_sparse_vector_and_text_query(
                query=query,
                semantic_fields=S.YoutubeConfig.semantic_fields,
                text_fields=S.YoutubeConfig.text_fields,
                highlight_fields=S.YoutubeConfig.highlight_fields,
                excluded_fields=S.YoutubeConfig.excluded_fields
            )
            q["size"] = 5
            vector_queries.extend([{"index": S.YoutubeConfig.index_name}, q])
    return vector_queries, quasi_vector_queries
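# Illustrative note (shapes only, index names are placeholders): each returned list is in Elasticsearch msearch
# format, i.e. a flat list alternating between a header dict naming the target index and the query body built by the
# corresponding query builder (with "_source" exclusions and "size" added above), roughly:
#
#   vector_queries = [
#       {"index": "<blog index name>"},
#       {<sparse vector query body>, "_source": {"excludes": ["embeddings"]}, "size": 5},
#       {"index": "<issuelab index name>"},
#       {<sparse vector query body>, "_source": {"excludes": ["embeddings"]}, "size": 1},
#   ]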
def run_search(
    vector_searches: list[dict[str, Any]] | None = None,
    non_vector_searches: list[dict[str, Any]] | None = None,
) -> list[ElasticHitsResult]:
    """Runs multi-search requests for the sparse-vector and non-sparse-vector query batches and collects all hits."""

    def _msearch_response_generator(responses: Iterable[dict[str, Any]]) -> Iterator[ElasticHitsResult]:
        for query_group in responses:
            for h in query_group.get("hits", {}).get("hits", []):
                inner_hits = h.get("inner_hits", {})
                if not inner_hits and "news" in h.get("_index"):
                    inner_hits = {"text": h.get("_source", {}).get("content")}

                yield ElasticHitsResult(
                    index=h["_index"],
                    id=h["_id"],
                    score=h["_score"],
                    source=h["_source"],
                    inner_hits=inner_hits,
                    highlight=h.get("highlight", {})
                )

    results = []
    if vector_searches is not None and len(vector_searches) > 0:
        hits = multi_search_base(queries=vector_searches, credentials=SEMANTIC_ELASTIC_QA)
        for hit in _msearch_response_generator(responses=hits):
            results.append(hit)

    if non_vector_searches is not None and len(non_vector_searches) > 0:
        hits = multi_search_base(queries=non_vector_searches, credentials=NEWS_ELASTIC)
        for hit in _msearch_response_generator(responses=hits):
            results.append(hit)
    return results


def retrieved_text(hits: dict[str, Any]) -> str:
    """Extracts retrieved sub-texts from documents which are strong hits from semantic queries for the purpose of
    re-scoring by a secondary language model.

    Parameters
    ----------
    hits : dict[str, Any]

    Returns
    -------
    str
    """

    nlp = CandidSLM()

    text = []
    for key, v in hits.items():
        if key == "text":
            s = nlp.summarize(v, top_k=3)
            text.append(s.summary)
            # text.append(v)
            continue

        for h in (v.get("hits", {}).get("hits") or []):
            for _, field in h.get("fields", {}).items():
                for chunk in field:
                    if chunk.get("chunk"):
                        text.extend(chunk["chunk"])
    return '\n'.join(text)


def reranker(
    query_results: Iterable[ElasticHitsResult],
    search_text: str | None = None,
    max_num_results: int = 5
) -> Iterator[ElasticHitsResult]:
    """Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
    This will shuffle results.

    Parameters
    ----------
    query_results : Iterable[ElasticHitsResult]
    search_text : str, optional
        Query text used to re-score results with the sparse encoder, by default None
    max_num_results : int, optional
        Maximum number of results to yield, by default 5

    Yields
    ------
    Iterator[ElasticHitsResult]
    """

    results: list[ElasticHitsResult] = []
    texts: list[str] = []
    for _, data in groupby(query_results, key=lambda x: x.index):
        data = list(data)  # noqa: PLW2901
        max_score = max(data, key=lambda x: x.score).score
        min_score = min(data, key=lambda x: x.score).score

        for d in data:
            # min-max normalize scores within each index so they are comparable across indices
            d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
            results.append(d)

            if search_text:
                text = ""
                if d.inner_hits:
                    text = retrieved_text(d.inner_hits)
                if d.highlight:
                    highlight_texts = []
                    for v in d.highlight.values():
                        highlight_texts.append('\n'.join(v))
                    text = '\n'.join(highlight_texts)
                texts.append(text)

    if search_text and len(texts) == len(results) and len(texts) > 1:
        logger.info("Re-ranking %d retrieval results", len(results))
        scores = sparse_encoder.query_reranking(query=search_text, documents=texts)
        for r, s in zip(results, scores):
            r.score = s

    yield from sorted(results, key=lambda x: x.score, reverse=True)[:max_num_results]
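# Illustrative note (toy numbers): within each index the raw scores are min-max normalized before hits from different
# indices are merged, e.g. scores [2.0, 5.0, 8.0] from one index become roughly [0.0, 0.5, 1.0] via
# (score - min) / (max - min + 1e-9), so top hits from indices with small absolute scores can still rank highly.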
def process_hit(hit: ElasticHitsResult) -> Document:
    """Converts an Elasticsearch hit into a LangChain `Document`, based on the index the hit came from."""

    if "issuelab-elser" in hit.index:
        doc = Document(
            page_content='\n\n'.join([
                hit.source.get("combined_item_description", ""),
                hit.source.get("description", ""),
                hit.source.get("combined_issuelab_findings", ""),
                get_context("content", hit, context_length=12)
            ]),
            metadata={
                "title": hit.source["title"],
                "source": "IssueLab",
                "source_id": hit.source["resource_id"],
                "url": hit.source.get("permalink", "")
            }
        )
    elif "youtube" in hit.index:
        highlight = hit.highlight or {}
        doc = Document(
            page_content='\n\n'.join([
                hit.source.get("title", ""),
                hit.source.get("semantic_description", ""),
                ' '.join(highlight.get("semantic_cc_text", []))
            ]),
            metadata={
                "title": hit.source.get("title", ""),
                "source": "Candid YouTube",
                "source_id": hit.source["video_id"],
                "url": f"https://www.youtube.com/watch?v={hit.source['video_id']}"
            }
        )
    elif "candid-blog" in hit.index:
        doc = Document(
            page_content='\n\n'.join([
                hit.source.get("title", ""),
                hit.source.get("excerpt", ""),
                get_context("content", hit, context_length=12, add_context=False),
                get_context("authors_text", hit, context_length=12, add_context=False),
                hit.source.get("title_summary_tags", "")
            ]),
            metadata={
                "title": hit.source.get("title", ""),
                "source": "Candid Blog",
                "source_id": hit.source["id"],
                "url": hit.source["link"]
            }
        )
    elif "candid-learning" in hit.index:
        doc = Document(
            page_content='\n\n'.join([
                hit.source.get("title", ""),
                hit.source.get("staff_recommendations", ""),
                hit.source.get("training_topics", ""),
                get_context("content", hit, context_length=12)
            ]),
            metadata={
                "title": hit.source["title"],
                "source": "Candid Learning",
                "source_id": hit.source["post_id"],
                "url": hit.source.get("url", "")
            }
        )
    elif "candid-help" in hit.index:
        doc = Document(
            page_content='\n\n'.join([
                hit.source.get("combined_article_description", ""),
                get_context("content", hit, context_length=12)
            ]),
            metadata={
                "title": hit.source.get("title", ""),
                "source": "Candid Help",
                "source_id": hit.source["id"],
                "url": hit.source.get("link", "")
            }
        )
    elif "news" in hit.index:
        doc = Document(
            page_content='\n\n'.join([hit.source.get("title", ""), hit.source.get("content", "")]),
            metadata={
                "title": hit.source.get("title", ""),
                "source": hit.source.get("site_name") or "Candid News",
                "source_id": hit.source["id"],
                "url": hit.source.get("link", "")
            }
        )
    else:
        raise ValueError(f"Unknown source result from index {hit.index}")
    return doc
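# Minimal end-to-end sketch of how these pieces fit together. The question text and source list below are arbitrary
# examples; actually running this requires live Elastic credentials and the SPLADE encoder to be available.
if __name__ == "__main__":
    user_question = "What are current trends in climate philanthropy?"

    # 1. Build per-index queries for the requested knowledge sources
    vector_q, non_vector_q = generate_queries(
        query=user_question,
        sources=["Candid Blog", "Candid News", "IssueLab Research Reports"],
        news_days_ago=30
    )

    # 2. Execute both msearch batches and collect the raw hits
    hits = run_search(vector_searches=vector_q, non_vector_searches=non_vector_q)

    # 3. Normalize and re-rank scores across indices, then convert the top hits into LangChain documents
    docs = [process_hit(h) for h in reranker(hits, search_text=user_question, max_num_results=5)]
    for doc in docs:
        logger.info("%s :: %s", doc.metadata.get("source"), doc.metadata.get("title"))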