from typing import Any
from collections.abc import Iterator

from elasticsearch import Elasticsearch

from ask_candid.base.retrieval.sparse_lexical import SpladeEncoder
from ask_candid.base.config.connections import BaseElasticAPIKeyCredential, BaseElasticSearchConnection

# News articles must score strictly above this trustworthiness value to be returned.
NEWS_TRUST_SCORE_THRESHOLD = 0.8
# Minimum SPLADE token weight for an expanded token to contribute a match clause.
SPARSE_ENCODING_SCORE_THRESHOLD = 0.4


def build_sparse_vector_query(
    query: str,
    fields: tuple[str, ...],
    inference_id: str = ".elser-2-elasticsearch"
) -> dict[str, Any]:
    """Builds a valid Elasticsearch text expansion query payload.

    One nested ``sparse_vector`` clause is emitted per field, targeting the
    per-chunk embedding vectors under ``embeddings.<field>.chunks``; the
    clauses are OR-ed together in a ``bool.should``.

    Parameters
    ----------
    query : str
        Search context string
    fields : tuple[str, ...]
        Semantic text field names
    inference_id : str, optional
        ID of model deployed in Elasticsearch, by default ".elser-2-elasticsearch"

    Returns
    -------
    dict[str, Any]
        Elasticsearch query body.
    """

    clauses = []
    for field_name in fields:
        clauses.append({
            "nested": {
                "path": f"embeddings.{field_name}.chunks",
                "query": {
                    "sparse_vector": {
                        "field": f"embeddings.{field_name}.chunks.vector",
                        "inference_id": inference_id,
                        "prune": True,
                        "query": query,
                        # "boost": 1 / len(fields)
                    }
                },
                # Surface the two best-matching chunks per field for downstream context.
                "inner_hits": {
                    "_source": False,
                    "size": 2,
                    "fields": [f"embeddings.{field_name}.chunks.chunk"]
                }
            }
        })
    return {"query": {"bool": {"should": clauses}}}


def build_sparse_vector_and_text_query(
    query: str,
    semantic_fields: tuple[str, ...],
    text_fields: tuple[str, ...] | None,
    highlight_fields: tuple[str, ...] | None,
    excluded_fields: tuple[str, ...] | None,
    inference_id: str = ".elser-2-elasticsearch"
) -> dict[str, Any]:
    """Builds Elasticsearch sparse vector and text query payload.

    Combines per-field ``sparse_vector`` clauses with an optional boosted
    ``multi_match`` over regular text fields, plus optional semantic
    highlighting and ``_source`` exclusions.

    Parameters
    ----------
    query : str
        Search context string
    semantic_fields : tuple[str, ...]
        Semantic text field names
    text_fields : tuple[str, ...] | None
        Regular text fields
    highlight_fields : tuple[str, ...] | None
        Fields which relevant chunks will be helpful for the agent to read
    excluded_fields : tuple[str, ...] | None
        Fields to exclude from the source
    inference_id : str, optional
        ID of model deployed in Elasticsearch, by default ".elser-2-elasticsearch"

    Returns
    -------
    dict[str, Any]
        Elasticsearch query body.
    """

    should_clauses: list[dict[str, Any]] = []
    for field_name in semantic_fields:
        should_clauses.append({
            "sparse_vector": {
                "field": field_name,
                "inference_id": inference_id,
                "query": query,
                "boost": 1,
                "prune": True  # doesn't seem it changes anything if we use text queries additionally
            }
        })

    if text_fields:
        # Lexical matches are boosted above the individual sparse-vector clauses.
        should_clauses.append({
            "multi_match": {
                "fields": text_fields,
                "query": query,
                "boost": 3
            }
        })

    final_query: dict[str, Any] = {
        "track_total_hits": False,
        "query": {
            "bool": {"should": should_clauses}
        }
    }

    if highlight_fields:
        final_query["highlight"] = {
            "fields": {
                field_name: {
                    "type": "semantic",  # ensures that highlighting is applied exclusively to semantic_text fields.
                    "number_of_fragments": 2,  # number of chunks
                    "order": "none"  # can be "score", but we have only two and hope for context
                }
                for field_name in highlight_fields
            }
        }

    if excluded_fields:
        final_query["_source"] = {"excludes": list(excluded_fields)}
    return final_query


def news_query_builder(
    query: str,
    fields: tuple[str, ...],
    encoder: SpladeEncoder,
    days_ago: int = 60,
) -> dict[str, Any]:
    """Builds a valid Elasticsearch query against Candid news, simulating a token expansion.

    The query string is expanded into weighted tokens with the SPLADE
    encoder; each token above ``SPARSE_ENCODING_SCORE_THRESHOLD`` becomes a
    boosted ``multi_match`` should-clause. Results are filtered to recent,
    trustworthy articles.

    Parameters
    ----------
    query : str
        Search context string
    fields : tuple[str, ...]
        Text field names to match expanded tokens against
    encoder : SpladeEncoder
        Sparse lexical encoder used for token expansion
    days_ago : int, optional
        Only include articles dated within this many days, by default 60

    Returns
    -------
    dict[str, Any]
        Elasticsearch query body.
    """

    tokens = encoder.token_expand(query)
    elastic_query: dict[str, Any] = {
        "_source": ["id", "link", "title", "content", "site_name"],
        "query": {
            "bool": {
                "filter": [
                    {"range": {"event_date": {"gte": f"now-{days_ago}d/d"}}},
                    {"range": {"insert_date": {"gte": f"now-{days_ago}d/d"}}},
                    {"range": {"article_trust_worthiness": {"gt": NEWS_TRUST_SCORE_THRESHOLD}}}
                ],
                "should": []
            }
        }
    }

    for token, score in tokens.items():
        # Drop low-weight expansion tokens; the rest are boosted by their weight.
        if score > SPARSE_ENCODING_SCORE_THRESHOLD:
            elastic_query["query"]["bool"]["should"].append({
                "multi_match": {
                    "query": token,
                    "fields": fields,
                    "boost": score
                }
            })
    return elastic_query


def multi_search_base(
    queries: list[dict[str, Any]],
    credentials: BaseElasticSearchConnection | BaseElasticAPIKeyCredential,
    timeout: int = 180
) -> Iterator[dict[str, Any]]:
    """Runs a multi-search against Elasticsearch and yields per-query responses.

    Parameters
    ----------
    queries : list[dict[str, Any]]
        Interleaved msearch header/body lines
    credentials : BaseElasticSearchConnection | BaseElasticAPIKeyCredential
        Connection credentials; the type determines how the client is built
    timeout : int, optional
        Request timeout in seconds, by default 180

    Yields
    ------
    dict[str, Any]
        One response object per submitted query.

    Raises
    ------
    TypeError
        If ``credentials`` is of an unsupported type.
    """

    if isinstance(credentials, BaseElasticAPIKeyCredential):
        es = Elasticsearch(
            cloud_id=credentials.cloud_id,
            api_key=credentials.api_key,
            # NOTE(review): TLS verification disabled — confirm this is intentional
            # for the cloud deployment; it weakens transport security.
            verify_certs=False,
            request_timeout=timeout
        )
    elif isinstance(credentials, BaseElasticSearchConnection):
        es = Elasticsearch(
            credentials.url,
            # NOTE(review): `http_auth`/`timeout` are legacy kwargs (deprecated in
            # elasticsearch-py 8.x in favor of `basic_auth`/`request_timeout`) —
            # confirm the pinned client version before modernizing.
            http_auth=(credentials.username, credentials.password),
            timeout=timeout
        )
    else:
        raise TypeError(f"Invalid credentials of type `{type(credentials)}`")

    # Close the client even if the consumer abandons the generator early or
    # msearch raises; generator finalization triggers the `finally` clause.
    try:
        yield from es.msearch(body=queries).get("responses", [])
    finally:
        es.close()