from typing import Any
from collections.abc import Iterator
from elasticsearch import Elasticsearch
from ask_candid.base.retrieval.sparse_lexical import SpladeEncoder
from ask_candid.base.config.connections import BaseElasticAPIKeyCredential, BaseElasticSearchConnection
NEWS_TRUST_SCORE_THRESHOLD = 0.8
SPARSE_ENCODING_SCORE_THRESHOLD = 0.4
def build_sparse_vector_query(
query: str,
fields: tuple[str, ...],
inference_id: str = ".elser-2-elasticsearch"
) -> dict[str, Any]:
"""Builds a valid Elasticsearch text expansion query payload
Parameters
----------
query : str
Search context string
fields : Tuple[str, ...]
Semantic text field names
inference_id : str, optional
ID of model deployed in Elasticsearch, by default ".elser-2-elasticsearch"
Returns
-------
Dict[str, Any]
"""
output = []
for f in fields:
output.append({
"nested": {
"path": f"embeddings.{f}.chunks",
"query": {
"sparse_vector": {
"field": f"embeddings.{f}.chunks.vector",
"inference_id": inference_id,
"prune": True,
"query": query,
# "boost": 1 / len(fields)
}
},
"inner_hits": {
"_source": False,
"size": 2,
"fields": [f"embeddings.{f}.chunks.chunk"]
}
}
})
return {"query": {"bool": {"should": output}}}
def build_sparse_vector_and_text_query(
query: str,
semantic_fields: tuple[str, ...],
text_fields: tuple[str, ...] | None,
highlight_fields: tuple[str, ...] | None,
excluded_fields: tuple[str, ...] | None,
inference_id: str = ".elser-2-elasticsearch"
) -> dict[str, Any]:
"""Builds Elasticsearch sparse vector and text query payload
Parameters
----------
query : str
Search context string
semantic_fields : Tuple[str]
Semantic text field names
highlight_fields: Tuple[str]
Fields which relevant chunks will be helpful for the agent to read
text_fields : Tuple[str]
Regular text fields
excluded_fields : Tuple[str]
Fields to exclude from the source
inference_id : str, optional
ID of model deployed in Elasticsearch, by default ".elser-2-elasticsearch"
Returns
-------
Dict[str, Any]
"""
output = []
final_query = {}
for f in semantic_fields:
output.append({
"sparse_vector": {
"field": f"{f}",
"inference_id": inference_id,
"query": query,
"boost": 1,
"prune": True # doesn't seem it changes anything if we use text queries additionally
}
})
if text_fields:
output.append({
"multi_match": {
"fields": text_fields,
"query": query,
"boost": 3
}
})
final_query = {
"track_total_hits": False,
"query": {
"bool": {"should": output}
}
}
if highlight_fields:
final_query["highlight"] = {
"fields": {
f"{f}": {
"type": "semantic", # ensures that highlighting is applied exclusively to semantic_text fields.
"number_of_fragments": 2, # number of chunks
"order": "none" # can be "score", but we have only two and hope for context
}
for f in highlight_fields
}
}
if excluded_fields:
final_query["_source"] = {"excludes": list(excluded_fields)}
return final_query
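# Illustrative usage sketch (index and field names here are hypothetical):
#   body = build_sparse_vector_and_text_query(
#       query="climate resilience grants",
#       semantic_fields=("description_semantic",),
#       text_fields=("title", "description"),
#       highlight_fields=("description_semantic",),
#       excluded_fields=("embeddings",),
#   )
#   resp = es.search(index="documents", body=body)
# The semantic clauses score hits via sparse token expansion while the `multi_match` clause
# (boost=3) keeps exact keyword matches competitive; semantic highlighting returns up to two
# chunks per highlighted field.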
def news_query_builder(
query: str,
fields: tuple[str, ...],
encoder: SpladeEncoder,
days_ago: int = 60,
) -> dict[str, Any]:
"""Builds a valid Elasticsearch query against Candid news, simulating a token expansion.
Parameters
----------
query : str
Search context string
Returns
-------
Dict[str, Any]
"""
tokens = encoder.token_expand(query)
elastic_query = {
"_source": ["id", "link", "title", "content", "site_name"],
"query": {
"bool": {
"filter": [
{"range": {"event_date": {"gte": f"now-{days_ago}d/d"}}},
{"range": {"insert_date": {"gte": f"now-{days_ago}d/d"}}},
{"range": {"article_trust_worthiness": {"gt": NEWS_TRUST_SCORE_THRESHOLD}}}
],
"should": []
}
}
}
for token, score in tokens.items():
if score > SPARSE_ENCODING_SCORE_THRESHOLD:
elastic_query["query"]["bool"]["should"].append({
"multi_match": {
"query": token,
"fields": fields,
"boost": score
}
})
return elastic_query
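# Illustrative usage sketch (constructing SpladeEncoder with no arguments is an assumption
# about its API, not confirmed by this module):
#   encoder = SpladeEncoder()
#   body = news_query_builder("museum funding cuts", fields=("title", "content"), encoder=encoder)
# Tokens whose SPLADE weight exceeds SPARSE_ENCODING_SCORE_THRESHOLD each become a `multi_match`
# clause boosted by that weight, approximating sparse-vector scoring on an index that has no
# `sparse_vector` fields; the recency and trust filters bound the candidate set.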
def multi_search_base(
queries: list[dict[str, Any]],
credentials: BaseElasticSearchConnection | BaseElasticAPIKeyCredential,
timeout: int = 180
) -> Iterator[dict[str, Any]]:
if isinstance(credentials, BaseElasticAPIKeyCredential):
es = Elasticsearch(
cloud_id=credentials.cloud_id,
api_key=credentials.api_key,
verify_certs=False,
request_timeout=timeout
)
elif isinstance(credentials, BaseElasticSearchConnection):
es = Elasticsearch(
credentials.url,
            basic_auth=(credentials.username, credentials.password),  # `basic_auth`/`request_timeout` match the 8.x client style used above
            request_timeout=timeout
)
else:
        raise TypeError(f"Invalid credentials of type `{type(credentials)}`")
    try:
        yield from es.msearch(body=queries).get("responses", [])
    finally:
        # Close the client even if the consumer stops iterating early
        es.close()
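# Illustrative usage sketch (index names and credential values are hypothetical). Elasticsearch's
# msearch API expects the body to alternate header lines ({"index": ...}) with search bodies,
# so `queries` should be built in pairs:
#   queries = [
#       {"index": "organizations"},
#       build_sparse_vector_query("youth education", fields=("description",)),
#       {"index": "news"},
#       news_query_builder("youth education", fields=("title", "content"), encoder=encoder),
#   ]
#   for response in multi_search_base(queries, credentials=BaseElasticAPIKeyCredential(...)):
#       for hit in response.get("hits", {}).get("hits", []):
#           ...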