from typing import Any
from collections.abc import Iterator

from elasticsearch import Elasticsearch

from ask_candid.base.retrieval.sparse_lexical import SpladeEncoder
from ask_candid.base.config.connections import BaseElasticAPIKeyCredential, BaseElasticSearchConnection

# Articles must score above this trustworthiness value to be included in news results
NEWS_TRUST_SCORE_THRESHOLD = 0.8
# SPLADE token weights must exceed this value for the token to join the expanded news query
SPARSE_ENCODING_SCORE_THRESHOLD = 0.4


def build_sparse_vector_query(
    query: str,
    fields: tuple[str, ...],
    inference_id: str = ".elser-2-elasticsearch"
) -> dict[str, Any]:
    """Builds a valid Elasticsearch sparse vector (text expansion) query payload.

    Parameters
    ----------
    query : str
        Search context string
    fields : tuple[str, ...]
        Semantic text field names
    inference_id : str, optional
        ID of the model deployed in Elasticsearch, by default ".elser-2-elasticsearch"

    Returns
    -------
    dict[str, Any]
        Elasticsearch query payload
    """
    output = []
    for f in fields:
        output.append({
            "nested": {
                # each semantic field stores chunked embeddings under a nested path
                "path": f"embeddings.{f}.chunks",
                "query": {
                    "sparse_vector": {
                        "field": f"embeddings.{f}.chunks.vector",
                        "inference_id": inference_id,
                        "prune": True,
                        "query": query,
                        # "boost": 1 / len(fields)
                    }
                },
                # surface the two best-matching chunks per field instead of the full source
                "inner_hits": {
                    "_source": False,
                    "size": 2,
                    "fields": [f"embeddings.{f}.chunks.chunk"]
                }
            }
        })
    return {"query": {"bool": {"should": output}}}
def build_sparse_vector_and_text_query(
    query: str,
    semantic_fields: tuple[str, ...],
    text_fields: tuple[str, ...] | None,
    highlight_fields: tuple[str, ...] | None,
    excluded_fields: tuple[str, ...] | None,
    inference_id: str = ".elser-2-elasticsearch"
) -> dict[str, Any]:
    """Builds an Elasticsearch query payload combining sparse vector and text search.

    Parameters
    ----------
    query : str
        Search context string
    semantic_fields : tuple[str, ...]
        Semantic text field names
    text_fields : tuple[str, ...] | None
        Regular text fields
    highlight_fields : tuple[str, ...] | None
        Fields whose most relevant chunks will be helpful for the agent to read
    excluded_fields : tuple[str, ...] | None
        Fields to exclude from the source
    inference_id : str, optional
        ID of the model deployed in Elasticsearch, by default ".elser-2-elasticsearch"

    Returns
    -------
    dict[str, Any]
        Elasticsearch query payload
    """
    output = []
    for f in semantic_fields:
        output.append({
            "sparse_vector": {
                "field": f,
                "inference_id": inference_id,
                "query": query,
                "boost": 1,
                "prune": True  # pruning does not seem to change anything when text queries are used additionally
            }
        })
    if text_fields:
        output.append({
            "multi_match": {
                "fields": text_fields,
                "query": query,
                "boost": 3
            }
        })
    final_query = {
        "track_total_hits": False,
        "query": {
            "bool": {"should": output}
        }
    }
    if highlight_fields:
        final_query["highlight"] = {
            "fields": {
                f: {
                    "type": "semantic",  # ensures that highlighting is applied exclusively to semantic_text fields
                    "number_of_fragments": 2,  # number of chunks
                    "order": "none"  # could be "score", but with only two fragments we keep document order for context
                }
                for f in highlight_fields
            }
        }
    if excluded_fields:
        final_query["_source"] = {"excludes": list(excluded_fields)}
    return final_query
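

# A minimal usage sketch (index and field names are hypothetical; `es` is an
# existing Elasticsearch client):
#
#   payload = build_sparse_vector_and_text_query(
#       query="youth education programs",
#       semantic_fields=("description_semantic",),
#       text_fields=("title", "description"),
#       highlight_fields=("description_semantic",),
#       excluded_fields=("embeddings",)
#   )
#   # the sparse vector clauses give semantic recall, while the boosted
#   # multi_match clause rewards exact keyword overlap on the regular text fields
#   results = es.search(index="documents", body=payload)

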
def news_query_builder(
    query: str,
    fields: tuple[str, ...],
    encoder: SpladeEncoder,
    days_ago: int = 60,
) -> dict[str, Any]:
    """Builds a valid Elasticsearch query against Candid news, simulating a token expansion.

    Parameters
    ----------
    query : str
        Search context string
    fields : tuple[str, ...]
        Text field names the expanded tokens are matched against
    encoder : SpladeEncoder
        Sparse lexical encoder used to expand the query into weighted tokens
    days_ago : int, optional
        Restrict results to articles from this trailing window, by default 60

    Returns
    -------
    dict[str, Any]
        Elasticsearch query payload
    """
    # expand the query into a {token: weight} map with the SPLADE encoder
    tokens = encoder.token_expand(query)

    elastic_query = {
        "_source": ["id", "link", "title", "content", "site_name"],
        "query": {
            "bool": {
                "filter": [
                    {"range": {"event_date": {"gte": f"now-{days_ago}d/d"}}},
                    {"range": {"insert_date": {"gte": f"now-{days_ago}d/d"}}},
                    {"range": {"article_trust_worthiness": {"gt": NEWS_TRUST_SCORE_THRESHOLD}}}
                ],
                "should": []
            }
        }
    }
    # keep only strongly weighted tokens; each one becomes a clause boosted by its weight
    for token, score in tokens.items():
        if score > SPARSE_ENCODING_SCORE_THRESHOLD:
            elastic_query["query"]["bool"]["should"].append({
                "multi_match": {
                    "query": token,
                    "fields": fields,
                    "boost": score
                }
            })
    return elastic_query
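

# A minimal usage sketch (field names are illustrative and the SpladeEncoder
# constructor arguments are assumed):
#
#   encoder = SpladeEncoder()
#   payload = news_query_builder(
#       query="disaster relief funding",
#       fields=("title", "content"),
#       encoder=encoder,
#       days_ago=30
#   )
#   # only tokens whose SPLADE weight exceeds SPARSE_ENCODING_SCORE_THRESHOLD become
#   # weighted multi_match clauses; the range filters keep results recent and trustworthy

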
def multi_search_base(
    queries: list[dict[str, Any]],
    credentials: BaseElasticSearchConnection | BaseElasticAPIKeyCredential,
    timeout: int = 180
) -> Iterator[dict[str, Any]]:
    """Runs a multi-search request and yields one response per submitted query.

    Parameters
    ----------
    queries : list[dict[str, Any]]
        Multi-search body: an alternating sequence of header lines (e.g. `{"index": ...}`)
        and query payloads
    credentials : BaseElasticSearchConnection | BaseElasticAPIKeyCredential
        Either API-key-based or username/password-based connection credentials
    timeout : int, optional
        Request timeout in seconds, by default 180

    Yields
    ------
    dict[str, Any]
        One search response per query
    """
    if isinstance(credentials, BaseElasticAPIKeyCredential):
        es = Elasticsearch(
            cloud_id=credentials.cloud_id,
            api_key=credentials.api_key,
            verify_certs=False,
            request_timeout=timeout
        )
    elif isinstance(credentials, BaseElasticSearchConnection):
        es = Elasticsearch(
            credentials.url,
            basic_auth=(credentials.username, credentials.password),
            request_timeout=timeout
        )
    else:
        raise TypeError(f"Invalid credentials of type `{type(credentials)}`")
    # close the client even if the consumer stops iterating early or an error is raised
    try:
        yield from es.msearch(body=queries).get("responses", [])
    finally:
        es.close()
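

# A minimal usage sketch (index names and credential values are hypothetical;
# note the msearch convention of alternating header and query lines):
#
#   queries = [
#       {"index": "documents"},
#       build_sparse_vector_query("climate resilience", fields=("description",)),
#       {"index": "news"},
#       news_query_builder("climate resilience", fields=("title", "content"), encoder=SpladeEncoder()),
#   ]
#   creds = BaseElasticAPIKeyCredential(cloud_id="...", api_key="...")
#   for response in multi_search_base(queries, credentials=creds):
#       for hit in response.get("hits", {}).get("hits", []):
#           print(hit["_id"], hit["_score"])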