# Hugging Face Space file header (author: brainsqueeze, "v2 of public chat",
# commit ef088c2 verified, 5.93 kB) — page chrome converted to a comment so
# this file parses as valid Python.
from typing import Any
from collections.abc import Iterator
from elasticsearch import Elasticsearch
from ask_candid.base.retrieval.sparse_lexical import SpladeEncoder
from ask_candid.base.config.connections import BaseElasticAPIKeyCredential, BaseElasticSearchConnection
# Strict lower bound on `article_trust_worthiness` for a news article to be
# searchable (applied as a `range`/`gt` filter in `news_query_builder`).
NEWS_TRUST_SCORE_THRESHOLD = 0.8
# Minimum SPLADE token-expansion weight for a token to contribute a
# `multi_match` clause in `news_query_builder`; lower-weight tokens are dropped.
SPARSE_ENCODING_SCORE_THRESHOLD = 0.4
def build_sparse_vector_query(
    query: str,
    fields: tuple[str, ...],
    inference_id: str = ".elser-2-elasticsearch"
) -> dict[str, Any]:
    """Build a valid Elasticsearch text-expansion (sparse vector) query payload.

    Parameters
    ----------
    query : str
        Search context string
    fields : tuple[str, ...]
        Semantic text field names
    inference_id : str, optional
        ID of model deployed in Elasticsearch, by default ".elser-2-elasticsearch"

    Returns
    -------
    dict[str, Any]
    """
    # One nested sparse-vector clause per semantic field; the clauses are
    # OR-ed together under a single bool/should.
    should_clauses = [
        {
            "nested": {
                "path": f"embeddings.{field}.chunks",
                "query": {
                    "sparse_vector": {
                        "field": f"embeddings.{field}.chunks.vector",
                        "inference_id": inference_id,
                        "prune": True,
                        "query": query,
                        # "boost": 1 / len(fields)
                    }
                },
                # Return up to 2 matching chunks per field, without the full
                # document source.
                "inner_hits": {
                    "_source": False,
                    "size": 2,
                    "fields": [f"embeddings.{field}.chunks.chunk"]
                }
            }
        }
        for field in fields
    ]
    return {"query": {"bool": {"should": should_clauses}}}
def build_sparse_vector_and_text_query(
query: str,
semantic_fields: tuple[str, ...],
text_fields: tuple[str, ...] | None,
highlight_fields: tuple[str, ...] | None,
excluded_fields: tuple[str, ...] | None,
inference_id: str = ".elser-2-elasticsearch"
) -> dict[str, Any]:
"""Builds Elasticsearch sparse vector and text query payload
Parameters
----------
query : str
Search context string
semantic_fields : Tuple[str]
Semantic text field names
highlight_fields: Tuple[str]
Fields which relevant chunks will be helpful for the agent to read
text_fields : Tuple[str]
Regular text fields
excluded_fields : Tuple[str]
Fields to exclude from the source
inference_id : str, optional
ID of model deployed in Elasticsearch, by default ".elser-2-elasticsearch"
Returns
-------
Dict[str, Any]
"""
output = []
final_query = {}
for f in semantic_fields:
output.append({
"sparse_vector": {
"field": f"{f}",
"inference_id": inference_id,
"query": query,
"boost": 1,
"prune": True # doesn't seem it changes anything if we use text queries additionally
}
})
if text_fields:
output.append({
"multi_match": {
"fields": text_fields,
"query": query,
"boost": 3
}
})
final_query = {
"track_total_hits": False,
"query": {
"bool": {"should": output}
}
}
if highlight_fields:
final_query["highlight"] = {
"fields": {
f"{f}": {
"type": "semantic", # ensures that highlighting is applied exclusively to semantic_text fields.
"number_of_fragments": 2, # number of chunks
"order": "none" # can be "score", but we have only two and hope for context
}
for f in highlight_fields
}
}
if excluded_fields:
final_query["_source"] = {"excludes": list(excluded_fields)}
return final_query
def news_query_builder(
    query: str,
    fields: tuple[str, ...],
    encoder: SpladeEncoder,
    days_ago: int = 60,
) -> dict[str, Any]:
    """Builds a valid Elasticsearch query against Candid news, simulating a
    token expansion.

    Parameters
    ----------
    query : str
        Search context string
    fields : tuple[str, ...]
        Text fields each expanded token is matched against
    encoder : SpladeEncoder
        Sparse lexical encoder that expands the query into weighted tokens
    days_ago : int, optional
        Restrict articles to those with event/insert dates within this many
        days, by default 60

    Returns
    -------
    dict[str, Any]
    """
    tokens = encoder.token_expand(query)
    elastic_query: dict[str, Any] = {
        "_source": ["id", "link", "title", "content", "site_name"],
        "query": {
            "bool": {
                "filter": [
                    {"range": {"event_date": {"gte": f"now-{days_ago}d/d"}}},
                    {"range": {"insert_date": {"gte": f"now-{days_ago}d/d"}}},
                    {"range": {"article_trust_worthiness": {"gt": NEWS_TRUST_SCORE_THRESHOLD}}}
                ],
                # Each sufficiently-weighted expansion token becomes a boosted
                # multi_match clause: a manual analogue of a sparse-vector
                # query for an index without semantic fields.
                "should": [
                    {
                        "multi_match": {
                            "query": token,
                            "fields": fields,
                            "boost": score
                        }
                    }
                    for token, score in tokens.items()
                    if score > SPARSE_ENCODING_SCORE_THRESHOLD
                ]
            }
        }
    }
    return elastic_query
def multi_search_base(
    queries: list[dict[str, Any]],
    credentials: BaseElasticSearchConnection | BaseElasticAPIKeyCredential,
    timeout: int = 180
) -> Iterator[dict[str, Any]]:
    """Runs an Elasticsearch multi-search and yields the individual responses.

    Parameters
    ----------
    queries : list[dict[str, Any]]
        Interleaved header/body payload lines for the msearch API
    credentials : BaseElasticSearchConnection | BaseElasticAPIKeyCredential
        Either cloud-ID + API-key credentials or URL + basic-auth credentials
    timeout : int, optional
        Request timeout in seconds, by default 180

    Yields
    ------
    dict[str, Any]
        One response object per submitted query

    Raises
    ------
    TypeError
        If `credentials` is not one of the supported credential types
    """
    if isinstance(credentials, BaseElasticAPIKeyCredential):
        es = Elasticsearch(
            cloud_id=credentials.cloud_id,
            api_key=credentials.api_key,
            verify_certs=False,
            request_timeout=timeout
        )
    elif isinstance(credentials, BaseElasticSearchConnection):
        es = Elasticsearch(
            credentials.url,
            http_auth=(credentials.username, credentials.password),
            timeout=timeout
        )
    else:
        raise TypeError(f"Invalid credentials of type `{type(credentials)}`")
    # try/finally (rather than a trailing close) guarantees the client is
    # closed even if msearch raises or the consumer abandons the generator.
    try:
        yield from es.msearch(body=queries).get("responses", [])
    finally:
        es.close()