from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
from itertools import groupby
from torch.nn import functional as F
from pydantic import BaseModel, Field
from langchain_core.documents import Document
from elasticsearch import Elasticsearch
from ask_candid.retrieval.sparse_lexical import SpladeEncoder
from ask_candid.retrieval.sources.schema import ElasticHitsResult
from ask_candid.retrieval.sources.issuelab import IssueLabConfig, process_issuelab_hit
from ask_candid.retrieval.sources.youtube import YoutubeConfig, process_youtube_hit
from ask_candid.retrieval.sources.candid_blog import CandidBlogConfig, process_blog_hit
from ask_candid.retrieval.sources.candid_learning import CandidLearningConfig, process_learning_hit
from ask_candid.retrieval.sources.candid_help import CandidHelpConfig, process_help_hit
from ask_candid.retrieval.sources.candid_news import CandidNewsConfig, process_news_hit
from ask_candid.services.small_lm import CandidSLM
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
from ask_candid.base.config.data import DataIndices, ALL_INDICES
# Shared sparse-lexical encoder, loaded once at module import
encoder = SpladeEncoder()
class RetrieverInput(BaseModel):
"""Input to the Elasticsearch retriever."""
user_input: str = Field(description="query to look up in retriever")
def build_sparse_vector_query(
query: str,
    fields: Tuple[str, ...],
inference_id: str = ".elser-2-elasticsearch"
) -> Dict[str, Any]:
"""Builds a valid Elasticsearch text expansion query payload
Parameters
----------
query : str
Search context string
    fields : Tuple[str, ...]
Semantic text field names
inference_id : str, optional
ID of model deployed in Elasticsearch, by default ".elser-2-elasticsearch"
Returns
-------
    Dict[str, Any]
        Elasticsearch ``bool``/``should`` query with one nested ``sparse_vector`` clause per field
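
    Examples
    --------
    A minimal sketch; the field names here are hypothetical stand-ins for real semantic
    text fields on an index:

    >>> payload = build_sparse_vector_query("climate funding", fields=("title", "description"))
    >>> len(payload["query"]["bool"]["should"])
    2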
"""
output = []
for f in fields:
output.append({
"nested": {
"path": f"embeddings.{f}.chunks",
"query": {
"sparse_vector": {
"field": f"embeddings.{f}.chunks.vector",
"inference_id": inference_id,
"prune": True,
"query": query,
"boost": 1 / len(fields)
}
},
"inner_hits": {
"_source": False,
"size": 2,
"fields": [f"embeddings.{f}.chunks.chunk"]
}
}
})
return {"query": {"bool": {"should": output}}}
def news_query_builder(query: str) -> Dict[str, Any]:
"""Builds a valid Elasticsearch query against Candid news, simulating a token expansion.
Parameters
----------
query : str
Search context string
Returns
-------
    Dict[str, Any]
        Query payload with recency/trust filters and one boosted ``multi_match`` clause per kept token
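
    Examples
    --------
    A sketch of the payload shape; running it needs a live ``SpladeEncoder``, so it is skipped here:

    >>> q = news_query_builder("artificial intelligence grantmaking")  # doctest: +SKIP
    >>> sorted(q["query"]["bool"].keys())  # doctest: +SKIP
    ['filter', 'should']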
"""
tokens = encoder.token_expand(query)
    es_query = {
        "_source": ["id", "link", "title", "content", "site_name"],
        "query": {
            "bool": {
                "filter": [
                    {"range": {"event_date": {"gte": "now-60d/d"}}},
                    {"range": {"insert_date": {"gte": "now-60d/d"}}},
                    {"range": {"article_trust_worthiness": {"gt": 0.8}}}
                ],
                "should": []
            }
        }
    }

    # Keep only the stronger expansion tokens; each contributes a boosted multi_match clause
    for token, score in tokens.items():
        if score > 0.4:
            es_query["query"]["bool"]["should"].append({
                "multi_match": {
                    "query": token,
                    "fields": CandidNewsConfig.text_fields,
                    "boost": score
                }
            })
    return es_query
def query_builder(query: str, indices: Optional[List[DataIndices]] = None) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""Builds Elasticsearch multi-search query payload
Parameters
----------
query : str
Search context string
    indices : Optional[List[DataIndices]], optional
        Semantic index names to search over, by default None (all indices)
Returns
-------
Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]
(semantic index queries, news queries)
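
    Examples
    --------
    The semantic payload alternates header and body lines, as ``Elasticsearch.msearch`` expects:

    >>> queries, news_queries = query_builder("youth mental health", indices=["issuelab"])
    >>> len(queries), len(news_queries)
    (2, 0)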
"""
queries, news_queries = [], []
if indices is None:
indices = list(ALL_INDICES)
for index in indices:
if index == "issuelab":
q = build_sparse_vector_query(query=query, fields=IssueLabConfig.text_fields)
q["_source"] = {"excludes": ["embeddings"]}
q["size"] = 2
queries.extend([{"index": IssueLabConfig.index_name}, q])
elif index == "youtube":
q = build_sparse_vector_query(query=query, fields=YoutubeConfig.text_fields)
q["_source"] = {"excludes": ["embeddings", *YoutubeConfig.excluded_fields]}
q["size"] = 5
queries.extend([{"index": YoutubeConfig.index_name}, q])
elif index == "candid_blog":
q = build_sparse_vector_query(query=query, fields=CandidBlogConfig.text_fields)
q["_source"] = {"excludes": ["embeddings"]}
q["size"] = 5
queries.extend([{"index": CandidBlogConfig.index_name}, q])
elif index == "candid_learning":
q = build_sparse_vector_query(query=query, fields=CandidLearningConfig.text_fields)
q["_source"] = {"excludes": ["embeddings"]}
q["size"] = 5
queries.extend([{"index": CandidLearningConfig.index_name}, q])
elif index == "candid_help":
q = build_sparse_vector_query(query=query, fields=CandidHelpConfig.text_fields)
q["_source"] = {"excludes": ["embeddings"]}
q["size"] = 5
queries.extend([{"index": CandidHelpConfig.index_name}, q])
elif index == "news":
q = news_query_builder(query=query)
q["size"] = 5
news_queries.extend([{"index": CandidNewsConfig.index_name}, q])
return queries, news_queries
def multi_search(
queries: List[Dict[str, Any]],
news_queries: Optional[List[Dict[str, Any]]] = None
) -> List[ElasticHitsResult]:
"""Runs multi-search query
Parameters
----------
    queries : List[Dict[str, Any]]
        Pre-built multi-search payload for the semantic indices
    news_queries : Optional[List[Dict[str, Any]]], optional
        Pre-built multi-search payload for the news index, by default None
Returns
-------
List[ElasticHitsResult]
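
    Examples
    --------
    A typical call chains the query builder and the search; it needs live cluster credentials,
    so it is skipped here:

    >>> queries, news_q = query_builder("early childhood education", indices=["issuelab"])  # doctest: +SKIP
    >>> hits = multi_search(queries, news_queries=news_q)  # doctest: +SKIP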
"""
def _msearch_response_generator(responses: List[Dict[str, Any]]) -> Iterator[ElasticHitsResult]:
for query_group in responses:
for h in query_group.get("hits", {}).get("hits", []):
                inner_hits = h.get("inner_hits", {})
                if not inner_hits:
                    # News hits carry no nested chunk inner-hits; fall back to the article body
                    if "news" in h.get("_index", ""):
                        inner_hits = {"text": h.get("_source", {}).get("content")}
yield ElasticHitsResult(
index=h["_index"],
id=h["_id"],
score=h["_score"],
source=h["_source"],
inner_hits=inner_hits
)
results = []
    if queries:
with Elasticsearch(
cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
api_key=SEMANTIC_ELASTIC_QA.api_key,
verify_certs=False,
request_timeout=60 * 3
) as es:
for hit in _msearch_response_generator(es.msearch(body=queries).get("responses", [])):
results.append(hit)
    if news_queries:
        with Elasticsearch(
            NEWS_ELASTIC.url,
            basic_auth=(NEWS_ELASTIC.username, NEWS_ELASTIC.password),
            request_timeout=60
        ) as es:
for hit in _msearch_response_generator(es.msearch(body=news_queries).get("responses", [])):
results.append(hit)
return results
def get_query_results(search_text: str, indices: Optional[List[DataIndices]] = None) -> List[ElasticHitsResult]:
"""Builds and executes Elasticsearch data queries from a search string.
Parameters
----------
search_text : str
Search context string
    indices : Optional[List[DataIndices]], optional
Semantic index names to search over, by default None
Returns
-------
List[ElasticHitsResult]
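
    Examples
    --------
    End-to-end retrieval against the configured clusters; it needs live credentials, so it is
    skipped here:

    >>> hits = get_query_results("capacity building", indices=["issuelab", "candid_blog"])  # doctest: +SKIP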
"""
queries, news_q = query_builder(query=search_text, indices=indices)
return multi_search(queries, news_queries=news_q)
def retrieved_text(hits: Dict[str, Any]) -> str:
"""Extracts retrieved sub-texts from documents which are strong hits from semantic queries for the purpose of
re-scoring by a secondary language model.
Parameters
----------
    hits : Dict[str, Any]
        ``inner_hits`` payload from a single Elasticsearch hit
Returns
-------
str
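
    Examples
    --------
    A sketch with a hand-built ``inner_hits`` payload; the field name is illustrative:

    >>> hits = {
    ...     "text": "full article body",
    ...     "embeddings.title.chunks": {
    ...         "hits": {"hits": [{"fields": {"embeddings.title.chunks": [{"chunk": ["chunk one"]}]}}]}
    ...     }
    ... }
    >>> retrieved_text(hits)
    'full article body\nchunk one'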
"""
text = []
    for key, v in hits.items():
        if key == "text":
            text.append(v)
            continue
for h in (v.get("hits", {}).get("hits") or []):
for _, field in h.get("fields", {}).items():
for chunk in field:
if chunk.get("chunk"):
text.extend(chunk["chunk"])
return '\n'.join(text)
def cosine_rescore(query: str, contexts: List[str]) -> List[float]:
"""Computes cosine scores between retrieved contexts and the original query to re-score results based on overall
relevance to the original query.
Parameters
----------
query : str
Search context string
contexts : List[str]
Semantic field sub-texts, order is by document retrieved from the original multi-search query.
Returns
-------
List[float]
Scores in the same order as the input document contexts
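
    Examples
    --------
    The score is plain cosine similarity, computed by L2-normalizing the stacked embeddings
    and taking a matrix product (a small torch-only illustration of the same operation):

    >>> import torch
    >>> X = F.normalize(torch.tensor([[1., 0.], [1., 1.], [0., 1.]]), dim=-1, p=2.0)
    >>> [round(v, 4) for v in (X[1:] @ X[:1].T).flatten().tolist()]
    [0.7071, 0.0]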
"""
nlp = CandidSLM()
X = nlp.encode([query, *contexts]).vectors
X = F.normalize(X, dim=-1, p=2.)
cosine = X[1:] @ X[:1].T
return cosine.flatten().cpu().numpy().tolist()
def reranker(
query_results: Iterable[ElasticHitsResult],
search_text: Optional[str] = None,
max_num_results: int = 5
) -> Iterator[ElasticHitsResult]:
"""Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
    Note that this reorders results across indices. ``itertools.groupby`` only groups
    consecutive items, so hits are assumed to arrive already grouped by index, which is how
    ``multi_search`` returns them.

    Parameters
    ----------
    query_results : Iterable[ElasticHitsResult]
        Hits from the multi-search, grouped by index
    search_text : Optional[str], optional
        Original query; when provided, hits are re-scored against it, by default None
    max_num_results : int, optional
        Maximum number of hits to yield, by default 5
Yields
------
Iterator[ElasticHitsResult]
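
    Examples
    --------
    Scores are min-max normalized within each index group before mixing. A sketch, assuming
    ``ElasticHitsResult`` accepts the same fields ``multi_search`` constructs it with:

    >>> hits = [
    ...     ElasticHitsResult(index="a", id="1", score=12.0, source={}, inner_hits={}),
    ...     ElasticHitsResult(index="a", id="2", score=4.0, source={}, inner_hits={}),
    ...     ElasticHitsResult(index="b", id="3", score=0.9, source={}, inner_hits={}),
    ... ]
    >>> [round(h.score, 2) for h in reranker(hits)]
    [1.0, 0.0, 0.0]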
"""
results: List[ElasticHitsResult] = []
texts: List[str] = []
    # Min-max normalize scores within each index group so hits from indices with
    # different score scales become comparable
    for _, data in groupby(query_results, key=lambda x: x.index):
data = list(data)
max_score = max(data, key=lambda x: x.score).score
min_score = min(data, key=lambda x: x.score).score
for d in data:
d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
results.append(d)
if search_text:
text = retrieved_text(d.inner_hits)
texts.append(text)
if search_text and len(texts) == len(results):
# scores = cosine_rescore(search_text, texts)
scores = encoder.query_reranking(query=search_text, documents=texts)
for r, s in zip(results, scores):
r.score = s
yield from sorted(results, key=lambda x: x.score, reverse=True)[:max_num_results]
def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
"""Parse Elasticsearch hit results into data structures handled by the RAG pipeline.
Parameters
----------
hit : ElasticHitsResult
Returns
-------
Union[Document, None]
"""
if "issuelab-elser" in hit.index:
doc = process_issuelab_hit(hit)
elif "youtube" in hit.index:
doc = process_youtube_hit(hit)
elif "candid-blog" in hit.index:
doc = process_blog_hit(hit)
elif "candid-learning" in hit.index:
doc = process_learning_hit(hit)
elif "candid-help" in hit.index:
doc = process_help_hit(hit)
elif "news" in hit.index:
doc = process_news_hit(hit)
else:
doc = None
return doc
def get_reranked_results(results: List[ElasticHitsResult], search_text: Optional[str] = None) -> List[Document]:
"""Run data re-ranking and document building for tool usage.
Parameters
----------
results : List[ElasticHitsResult]
search_text : Optional[str], optional
Search context string, by default None
Returns
-------
List[Document]
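
    Examples
    --------
    Full pipeline from a user query to LangChain documents; needs live services, so skipped here:

    >>> hits = get_query_results("trends in climate philanthropy")  # doctest: +SKIP
    >>> docs = get_reranked_results(hits, search_text="trends in climate philanthropy")  # doctest: +SKIP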
"""
output = []
for r in reranker(results, search_text=search_text):
hit = process_hit(r)
if hit is not None:
output.append(hit)
return output