from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
from itertools import groupby

from torch.nn import functional as F
from pydantic import BaseModel, Field
from langchain_core.documents import Document
from elasticsearch import Elasticsearch

from ask_candid.retrieval.sparse_lexical import SpladeEncoder
from ask_candid.retrieval.sources.schema import ElasticHitsResult
from ask_candid.retrieval.sources.issuelab import IssueLabConfig, process_issuelab_hit
from ask_candid.retrieval.sources.youtube import YoutubeConfig, process_youtube_hit
from ask_candid.retrieval.sources.candid_blog import CandidBlogConfig, process_blog_hit
from ask_candid.retrieval.sources.candid_learning import CandidLearningConfig, process_learning_hit
from ask_candid.retrieval.sources.candid_help import CandidHelpConfig, process_help_hit
from ask_candid.retrieval.sources.candid_news import CandidNewsConfig, process_news_hit
from ask_candid.services.small_lm import CandidSLM
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
from ask_candid.base.config.data import DataIndices, ALL_INDICES

encoder = SpladeEncoder()


class RetrieverInput(BaseModel):
    """Input to the Elasticsearch retriever."""
    user_input: str = Field(description="query to look up in retriever")


def build_sparse_vector_query(
    query: str,
    fields: Tuple[str, ...],
    inference_id: str = ".elser-2-elasticsearch"
) -> Dict[str, Any]:
    """Builds a valid Elasticsearch sparse vector (text expansion) query payload.

    Parameters
    ----------
    query : str
        Search context string
    fields : Tuple[str, ...]
        Semantic text field names
    inference_id : str, optional
        ID of the model deployed in Elasticsearch, by default ".elser-2-elasticsearch"

    Returns
    -------
    Dict[str, Any]
        `bool`/`should` query payload with one nested sparse vector clause per field
    """

    output = []
    for f in fields:
        output.append({
            "nested": {
                "path": f"embeddings.{f}.chunks",
                "query": {
                    "sparse_vector": {
                        "field": f"embeddings.{f}.chunks.vector",
                        "inference_id": inference_id,
                        "prune": True,
                        "query": query,
                        # split the boost evenly across fields so documents
                        # matching many fields are not over-weighted
                        "boost": 1 / len(fields)
                    }
                },
                "inner_hits": {
                    "_source": False,
                    "size": 2,
                    "fields": [f"embeddings.{f}.chunks.chunk"]
                }
            }
        })
    return {"query": {"bool": {"should": output}}}
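# A minimal sketch of the payload shape this produces (the query string and
# field names below are hypothetical; real field names come from each source's
# config, e.g. IssueLabConfig.text_fields):
#
#   build_sparse_vector_query("youth homelessness", fields=("title", "summary"))
#
# yields one nested sparse_vector clause per field under bool/should:
#
#   {"query": {"bool": {"should": [
#       {"nested": {"path": "embeddings.title.chunks", ...}},
#       {"nested": {"path": "embeddings.summary.chunks", ...}},
#   ]}}}
#
# Each clause is boosted by 1 / len(fields), and `inner_hits` requests the top
# 2 matching chunks per field for downstream re-scoring.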
def news_query_builder(query: str) -> Dict[str, Any]:
    """Builds a valid Elasticsearch query against Candid news, simulating a
    token expansion.

    Parameters
    ----------
    query : str
        Search context string

    Returns
    -------
    Dict[str, Any]
        Elasticsearch query payload
    """

    tokens = encoder.token_expand(query)

    elastic_query: Dict[str, Any] = {
        "_source": ["id", "link", "title", "content", "site_name"],
        "query": {
            "bool": {
                "filter": [
                    {"range": {"event_date": {"gte": "now-60d/d"}}},
                    {"range": {"insert_date": {"gte": "now-60d/d"}}},
                    {"range": {"article_trust_worthiness": {"gt": 0.8}}}
                ],
                "should": []
            }
        }
    }

    # keep only the stronger expansion tokens, weighting each match clause by
    # its expansion score
    for token, score in tokens.items():
        if score > 0.4:
            elastic_query["query"]["bool"]["should"].append({
                "multi_match": {
                    "query": token,
                    "fields": CandidNewsConfig.text_fields,
                    "boost": score
                }
            })
    return elastic_query


def query_builder(
    query: str,
    indices: Optional[List[DataIndices]] = None
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Builds an Elasticsearch multi-search query payload.

    Parameters
    ----------
    query : str
        Search context string
    indices : Optional[List[DataIndices]], optional
        Semantic index names to search over, by default None (all indices)

    Returns
    -------
    Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]
        (semantic index queries, news queries)
    """

    queries, news_queries = [], []
    if indices is None:
        indices = list(ALL_INDICES)

    for index in indices:
        if index == "issuelab":
            q = build_sparse_vector_query(query=query, fields=IssueLabConfig.text_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 2
            queries.extend([{"index": IssueLabConfig.index_name}, q])
        elif index == "youtube":
            q = build_sparse_vector_query(query=query, fields=YoutubeConfig.text_fields)
            q["_source"] = {"excludes": ["embeddings", *YoutubeConfig.excluded_fields]}
            q["size"] = 5
            queries.extend([{"index": YoutubeConfig.index_name}, q])
        elif index == "candid_blog":
            q = build_sparse_vector_query(query=query, fields=CandidBlogConfig.text_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 5
            queries.extend([{"index": CandidBlogConfig.index_name}, q])
        elif index == "candid_learning":
            q = build_sparse_vector_query(query=query, fields=CandidLearningConfig.text_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 5
            queries.extend([{"index": CandidLearningConfig.index_name}, q])
        elif index == "candid_help":
            q = build_sparse_vector_query(query=query, fields=CandidHelpConfig.text_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 5
            queries.extend([{"index": CandidHelpConfig.index_name}, q])
        elif index == "news":
            q = news_query_builder(query=query)
            q["size"] = 5
            news_queries.extend([{"index": CandidNewsConfig.index_name}, q])
    return queries, news_queries
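# The multi-search payload alternates a header line naming the target index
# with the query body for that index, which is the list format the
# Elasticsearch `msearch` API accepts directly. A sketch (index names are
# whatever the source configs define; bodies elided):
#
#   queries == [
#       {"index": "issuelab-elser"}, {"query": {...}, "_source": {...}, "size": 2},
#       {"index": "youtube"},        {"query": {...}, "_source": {...}, "size": 5},
#   ]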
def multi_search(
    queries: List[Dict[str, Any]],
    news_queries: Optional[List[Dict[str, Any]]] = None
) -> List[ElasticHitsResult]:
    """Runs the multi-search queries.

    Parameters
    ----------
    queries : List[Dict[str, Any]]
        Pre-built multi-search query payload for the semantic indices
    news_queries : Optional[List[Dict[str, Any]]], optional
        Pre-built multi-search query payload for the news index, by default None

    Returns
    -------
    List[ElasticHitsResult]
    """

    def _msearch_response_generator(responses: List[Dict[str, Any]]) -> Iterator[ElasticHitsResult]:
        for query_group in responses:
            for h in query_group.get("hits", {}).get("hits", []):
                inner_hits = h.get("inner_hits", {})
                if not inner_hits and "news" in h.get("_index", ""):
                    # news hits carry no nested chunks, so fall back to the
                    # article body
                    inner_hits = {"text": h.get("_source", {}).get("content")}
                yield ElasticHitsResult(
                    index=h["_index"],
                    id=h["_id"],
                    score=h["_score"],
                    source=h["_source"],
                    inner_hits=inner_hits
                )

    results = []
    if len(queries) > 0:
        with Elasticsearch(
            cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
            api_key=SEMANTIC_ELASTIC_QA.api_key,
            verify_certs=False,
            request_timeout=60 * 3
        ) as es:
            for hit in _msearch_response_generator(es.msearch(body=queries).get("responses", [])):
                results.append(hit)

    if news_queries is not None and len(news_queries) > 0:
        with Elasticsearch(
            NEWS_ELASTIC.url,
            http_auth=(NEWS_ELASTIC.username, NEWS_ELASTIC.password),
            timeout=60
        ) as es:
            for hit in _msearch_response_generator(es.msearch(body=news_queries).get("responses", [])):
                results.append(hit)
    return results


def get_query_results(search_text: str, indices: Optional[List[DataIndices]] = None) -> List[ElasticHitsResult]:
    """Builds and executes Elasticsearch data queries from a search string.

    Parameters
    ----------
    search_text : str
        Search context string
    indices : Optional[List[DataIndices]], optional
        Semantic index names to search over, by default None (all indices)

    Returns
    -------
    List[ElasticHitsResult]
    """

    queries, news_queries = query_builder(query=search_text, indices=indices)
    return multi_search(queries, news_queries=news_queries)


def retrieved_text(hits: Dict[str, Any]) -> str:
    """Extracts retrieved sub-texts from documents which are strong hits from
    semantic queries, for the purpose of re-scoring by a secondary language
    model.

    Parameters
    ----------
    hits : Dict[str, Any]
        Inner hits from a single retrieved document

    Returns
    -------
    str
        Newline-joined chunk texts
    """

    text = []
    for key, value in hits.items():
        # news results carry their text directly under a "text" key
        if key == "text":
            text.append(value)
            continue
        # semantic results carry their text as nested chunk field values
        for h in (value.get("hits", {}).get("hits") or []):
            for field in h.get("fields", {}).values():
                for chunk in field:
                    if chunk.get("chunk"):
                        text.extend(chunk["chunk"])
    return "\n".join(text)


def cosine_rescore(query: str, contexts: List[str]) -> List[float]:
    """Computes cosine similarity scores between retrieved contexts and the
    original query, in order to re-score results by overall relevance to that
    query.

    Parameters
    ----------
    query : str
        Search context string
    contexts : List[str]
        Semantic field sub-texts, ordered by document as retrieved from the
        original multi-search query

    Returns
    -------
    List[float]
        Scores in the same order as the input document contexts
    """

    nlp = CandidSLM()
    # encode the query together with the contexts, then L2-normalize so the
    # dot product below is a cosine similarity
    X = nlp.encode([query, *contexts]).vectors
    X = F.normalize(X, dim=-1, p=2.)
    cosine = X[1:] @ X[:1].T
    return cosine.flatten().cpu().numpy().tolist()


def reranker(
    query_results: Iterable[ElasticHitsResult],
    search_text: Optional[str] = None,
    max_num_results: int = 5
) -> Iterator[ElasticHitsResult]:
    """Re-ranks Elasticsearch hits coming from multiple indices/queries, which
    may score on different scales. Note that this reorders results across
    sources.

    Parameters
    ----------
    query_results : Iterable[ElasticHitsResult]
        Hits from the multi-search query, grouped by index
    search_text : Optional[str], optional
        Original search string; when supplied, results are re-scored against
        it, by default None
    max_num_results : int, optional
        Maximum number of hits to yield, by default 5

    Yields
    ------
    Iterator[ElasticHitsResult]
    """

    results: List[ElasticHitsResult] = []
    texts: List[str] = []
    # `groupby` assumes hits arrive grouped by index, which `multi_search`
    # guarantees since msearch responses come back one block per query
    for _, data in groupby(query_results, key=lambda x: x.index):
        data = list(data)
        max_score = max(data, key=lambda x: x.score).score
        min_score = min(data, key=lambda x: x.score).score

        # min-max normalize each index's scores to [0, 1]; the epsilon guards
        # against division by zero when all scores in a group are equal
        for d in data:
            d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
            results.append(d)

            if search_text:
                texts.append(retrieved_text(d.inner_hits))

    if search_text and len(texts) == len(results):
        # scores = cosine_rescore(search_text, texts)
        scores = encoder.query_reranking(query=search_text, documents=texts)
        for r, s in zip(results, scores):
            r.score = s

    yield from sorted(results, key=lambda x: x.score, reverse=True)[:max_num_results]
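# Worked illustration of the per-index normalization above (numbers are made
# up): raw scores from different indices are not comparable, so each group is
# min-max scaled to [0, 1] before merging.
#
#   raw:        issuelab -> [14.2, 9.8, 9.1]    news -> [1.31, 0.44]
#   normalized: issuelab -> [1.0, 0.14, 0.0]    news -> [1.0, 0.0]
#
# When all scores in a group are identical, the epsilon keeps the division
# defined and every score in that group collapses to ~0.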
def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
    """Parses an Elasticsearch hit into the data structures handled by the RAG
    pipeline.

    Parameters
    ----------
    hit : ElasticHitsResult

    Returns
    -------
    Union[Document, None]
        None if the hit's index has no registered processor
    """

    if "issuelab-elser" in hit.index:
        doc = process_issuelab_hit(hit)
    elif "youtube" in hit.index:
        doc = process_youtube_hit(hit)
    elif "candid-blog" in hit.index:
        doc = process_blog_hit(hit)
    elif "candid-learning" in hit.index:
        doc = process_learning_hit(hit)
    elif "candid-help" in hit.index:
        doc = process_help_hit(hit)
    elif "news" in hit.index:
        doc = process_news_hit(hit)
    else:
        doc = None
    return doc


def get_reranked_results(results: List[ElasticHitsResult], search_text: Optional[str] = None) -> List[Document]:
    """Runs data re-ranking and document building for tool usage.

    Parameters
    ----------
    results : List[ElasticHitsResult]
    search_text : Optional[str], optional
        Search context string, by default None

    Returns
    -------
    List[Document]
    """

    output = []
    for r in reranker(results, search_text=search_text):
        hit = process_hit(r)
        if hit is not None:
            output.append(hit)
    return output
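if __name__ == "__main__":
    # Minimal end-to-end sketch, assuming live cluster credentials are
    # configured in ask_candid.base.config.connections. The query string and
    # the "source" metadata key are illustrative, not guaranteed by the
    # source processors.
    question = "capacity building for small nonprofits"
    hits = get_query_results(search_text=question, indices=["issuelab", "news"])
    for doc in get_reranked_results(hits, search_text=question):
        print(doc.metadata.get("source"), "->", doc.page_content[:80])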