from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
from itertools import groupby

from torch.nn import functional as F
from pydantic import BaseModel, Field
from langchain_core.documents import Document
from elasticsearch import Elasticsearch

from ask_candid.retrieval.sparse_lexical import SpladeEncoder
from ask_candid.retrieval.sources.schema import ElasticHitsResult
from ask_candid.retrieval.sources.issuelab import IssueLabConfig, process_issuelab_hit
from ask_candid.retrieval.sources.youtube import YoutubeConfig, process_youtube_hit
from ask_candid.retrieval.sources.candid_blog import CandidBlogConfig, process_blog_hit
from ask_candid.retrieval.sources.candid_learning import CandidLearningConfig, process_learning_hit
from ask_candid.retrieval.sources.candid_help import CandidHelpConfig, process_help_hit
from ask_candid.retrieval.sources.candid_news import CandidNewsConfig, process_news_hit
from ask_candid.services.small_lm import CandidSLM
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
from ask_candid.base.config.data import DataIndices, ALL_INDICES

encoder = SpladeEncoder()


class RetrieverInput(BaseModel):
    """Input to the Elasticsearch retriever."""

    user_input: str = Field(description="query to look up in retriever")

def build_sparse_vector_query(
    query: str,
    fields: Tuple[str, ...],
    inference_id: str = ".elser-2-elasticsearch"
) -> Dict[str, Any]:
    """Builds a valid Elasticsearch text expansion query payload.

    Parameters
    ----------
    query : str
        Search context string
    fields : Tuple[str, ...]
        Semantic text field names
    inference_id : str, optional
        ID of the model deployed in Elasticsearch, by default ".elser-2-elasticsearch"

    Returns
    -------
    Dict[str, Any]
        Elasticsearch query payload with one nested sparse-vector clause per field
    """
    output = []
    for f in fields:
        output.append({
            "nested": {
                "path": f"embeddings.{f}.chunks",
                "query": {
                    "sparse_vector": {
                        "field": f"embeddings.{f}.chunks.vector",
                        "inference_id": inference_id,
                        "prune": True,
                        "query": query,
                        "boost": 1 / len(fields)
                    }
                },
                "inner_hits": {
                    "_source": False,
                    "size": 2,
                    "fields": [f"embeddings.{f}.chunks.chunk"]
                }
            }
        })
    return {"query": {"bool": {"should": output}}}

def news_query_builder(query: str) -> Dict[str, Any]:
    """Builds a valid Elasticsearch query against Candid news, simulating a token expansion.

    Parameters
    ----------
    query : str
        Search context string

    Returns
    -------
    Dict[str, Any]
        Elasticsearch query payload with one boosted multi-match clause per expanded token
    """
    tokens = encoder.token_expand(query)

    es_query = {
        "_source": ["id", "link", "title", "content", "site_name"],
        "query": {
            "bool": {
                "filter": [
                    {"range": {"event_date": {"gte": "now-60d/d"}}},
                    {"range": {"insert_date": {"gte": "now-60d/d"}}},
                    {"range": {"article_trust_worthiness": {"gt": 0.8}}}
                ],
                "should": []
            }
        }
    }

    # keep only the more salient expanded tokens, weighting each match clause by its expansion score
    for token, score in tokens.items():
        if score > 0.4:
            es_query["query"]["bool"]["should"].append({
                "multi_match": {
                    "query": token,
                    "fields": CandidNewsConfig.text_fields,
                    "boost": score
                }
            })
    return es_query
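
# Illustrative sketch: token expansion turns the free-text query into weighted keyword clauses. Supposing
# encoder.token_expand("food insecurity") returned {"food": 0.9, "insecurity": 0.7, "the": 0.1} (hypothetical
# scores), only "food" and "insecurity" clear the 0.4 threshold, so the payload gets two multi_match clauses
# boosted by 0.9 and 0.7; the recency and trustworthiness filters apply to the query as a whole.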

def query_builder(query: str, indices: Optional[List[DataIndices]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Builds Elasticsearch multi-search query payloads.

    Parameters
    ----------
    query : str
        Search context string
    indices : Optional[List[DataIndices]]
        Semantic index names to search over; all indices are used when None

    Returns
    -------
    Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]
        (semantic index queries, news queries)
    """
    queries, news_queries = [], []
    if indices is None:
        indices = list(ALL_INDICES)

    for index in indices:
        if index == "issuelab":
            q = build_sparse_vector_query(query=query, fields=IssueLabConfig.text_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 2
            queries.extend([{"index": IssueLabConfig.index_name}, q])
        elif index == "youtube":
            q = build_sparse_vector_query(query=query, fields=YoutubeConfig.text_fields)
            q["_source"] = {"excludes": ["embeddings", *YoutubeConfig.excluded_fields]}
            q["size"] = 5
            queries.extend([{"index": YoutubeConfig.index_name}, q])
        elif index == "candid_blog":
            q = build_sparse_vector_query(query=query, fields=CandidBlogConfig.text_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 5
            queries.extend([{"index": CandidBlogConfig.index_name}, q])
        elif index == "candid_learning":
            q = build_sparse_vector_query(query=query, fields=CandidLearningConfig.text_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 5
            queries.extend([{"index": CandidLearningConfig.index_name}, q])
        elif index == "candid_help":
            q = build_sparse_vector_query(query=query, fields=CandidHelpConfig.text_fields)
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 5
            queries.extend([{"index": CandidHelpConfig.index_name}, q])
        elif index == "news":
            q = news_query_builder(query=query)
            q["size"] = 5
            news_queries.extend([{"index": CandidNewsConfig.index_name}, q])
    return queries, news_queries
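
# Illustrative sketch: the semantic queries are laid out in Elasticsearch msearch format, an alternating list
# of header and body entries. For a hypothetical call with indices=["issuelab", "news"], the first returned
# list would look like
#     [{"index": IssueLabConfig.index_name}, {"query": {...}, "_source": {...}, "size": 2}]
# while the news query is returned separately in the second list, since it runs against a different cluster.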

def multi_search(
    queries: List[Dict[str, Any]],
    news_queries: Optional[List[Dict[str, Any]]] = None
) -> List[ElasticHitsResult]:
    """Runs the multi-search queries against the semantic and news Elasticsearch clusters.

    Parameters
    ----------
    queries : List[Dict[str, Any]]
        Pre-built multi-search query payload for the semantic indices
    news_queries : Optional[List[Dict[str, Any]]], optional
        Pre-built multi-search query payload for the news index, by default None

    Returns
    -------
    List[ElasticHitsResult]
    """
    def _msearch_response_generator(responses: List[Dict[str, Any]]) -> Iterator[ElasticHitsResult]:
        for query_group in responses:
            for h in query_group.get("hits", {}).get("hits", []):
                inner_hits = h.get("inner_hits", {})
                # news documents keep their full text in `_source` rather than in nested inner hits
                if not inner_hits and "news" in h.get("_index", ""):
                    inner_hits = {"text": h.get("_source", {}).get("content")}
                yield ElasticHitsResult(
                    index=h["_index"],
                    id=h["_id"],
                    score=h["_score"],
                    source=h["_source"],
                    inner_hits=inner_hits
                )

    results = []
    # the semantic indices and the news index live on separate Elasticsearch clusters
    if len(queries) > 0:
        with Elasticsearch(
            cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
            api_key=SEMANTIC_ELASTIC_QA.api_key,
            verify_certs=False,
            request_timeout=60 * 3
        ) as es:
            for hit in _msearch_response_generator(es.msearch(body=queries).get("responses", [])):
                results.append(hit)
    if news_queries is not None and len(news_queries):
        with Elasticsearch(
            NEWS_ELASTIC.url,
            http_auth=(NEWS_ELASTIC.username, NEWS_ELASTIC.password),
            timeout=60
        ) as es:
            for hit in _msearch_response_generator(es.msearch(body=news_queries).get("responses", [])):
                results.append(hit)
    return results

def get_query_results(search_text: str, indices: Optional[List[DataIndices]] = None) -> List[ElasticHitsResult]:
    """Builds and executes Elasticsearch data queries from a search string.

    Parameters
    ----------
    search_text : str
        Search context string
    indices : Optional[List[DataIndices]], optional
        Semantic index names to search over, by default None (search all indices)

    Returns
    -------
    List[ElasticHitsResult]
    """
    queries, news_q = query_builder(query=search_text, indices=indices)
    return multi_search(queries, news_queries=news_q)
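
# Illustrative usage sketch (assumes reachable clusters and valid credentials in SEMANTIC_ELASTIC_QA / NEWS_ELASTIC):
#
#     hits = get_query_results("capacity building grants for food banks", indices=["issuelab", "news"])
#     for h in hits:
#         print(h.index, h.id, round(h.score, 3))
#
# Each hit keeps the raw `_source` plus the matched chunk texts in `inner_hits`, which the re-ranking step uses below.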

def retrieved_text(hits: Dict[str, Any]) -> str:
    """Extracts the retrieved sub-texts from documents which are strong hits from semantic queries, for the
    purpose of re-scoring by a secondary language model.

    Parameters
    ----------
    hits : Dict[str, Any]
        `inner_hits` payload from an Elasticsearch hit

    Returns
    -------
    str
        Newline-joined sub-texts
    """
    text = []
    for key, value in hits.items():
        # news hits carry their content directly under the "text" key
        if key == "text":
            text.append(value)
            continue
        # semantic hits carry matched chunks nested under each inner-hits field
        for h in (value.get("hits", {}).get("hits") or []):
            for field in h.get("fields", {}).values():
                for chunk in field:
                    if chunk.get("chunk"):
                        text.extend(chunk["chunk"])
    return '\n'.join(text)

def cosine_rescore(query: str, contexts: List[str]) -> List[float]:
    """Computes cosine scores between retrieved contexts and the original query to re-score results based on
    overall relevance to the original query.

    Parameters
    ----------
    query : str
        Search context string
    contexts : List[str]
        Semantic field sub-texts, ordered by document retrieved from the original multi-search query

    Returns
    -------
    List[float]
        Scores in the same order as the input document contexts
    """
    nlp = CandidSLM()
    # embed the query together with the contexts, then L2-normalize so the dot product equals cosine similarity
    X = nlp.encode([query, *contexts]).vectors
    X = F.normalize(X, dim=-1, p=2.)
    cosine = X[1:] @ X[:1].T
    return cosine.flatten().cpu().numpy().tolist()
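
# Worked sketch of the arithmetic above: with unit-normalized rows, X[1:] @ X[:1].T is an (n_contexts x 1)
# matrix of cosine similarities between each context and the query. For two hypothetical contexts this might
# flatten to something like [0.83, 0.41], meaning the first context is the more relevant of the two
# (values are illustrative only).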

def reranker(
    query_results: Iterable[ElasticHitsResult],
    search_text: Optional[str] = None,
    max_num_results: int = 5
) -> Iterator[ElasticHitsResult]:
    """Re-ranks Elasticsearch hits coming from multiple indices/queries which may have scores on different
    scales. Note that this will shuffle the results.

    Parameters
    ----------
    query_results : Iterable[ElasticHitsResult]
        Hits from the multi-search step, grouped by index
    search_text : Optional[str], optional
        Original search string; when given, hits are re-scored against it, by default None
    max_num_results : int, optional
        Maximum number of hits to yield, by default 5

    Yields
    ------
    Iterator[ElasticHitsResult]
    """
    results: List[ElasticHitsResult] = []
    texts: List[str] = []
    for _, data in groupby(query_results, key=lambda x: x.index):
        data = list(data)
        # min-max normalize scores within each index group so scales are comparable across indices
        max_score = max(data, key=lambda x: x.score).score
        min_score = min(data, key=lambda x: x.score).score
        for d in data:
            d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
            results.append(d)

            if search_text:
                text = retrieved_text(d.inner_hits)
                texts.append(text)

    # when the original query is available, replace the normalized scores with query-document reranking scores
    if search_text and len(texts) == len(results):
        # scores = cosine_rescore(search_text, texts)
        scores = encoder.query_reranking(query=search_text, documents=texts)
        for r, s in zip(results, scores):
            r.score = s
    yield from sorted(results, key=lambda x: x.score, reverse=True)[:max_num_results]
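
# Illustrative sketch of the normalization step: if one index group scored [12.0, 8.0, 4.0] and another scored
# [1.4, 0.2] (hypothetical raw scores), min-max scaling maps them to [1.0, 0.5, 0.0] and [1.0, 0.0], so a strong
# hit from the low-scoring index is no longer drowned out by the other index's larger raw scores. When
# `search_text` is provided, those normalized scores are then replaced by the `encoder.query_reranking` scores.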

def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
    """Parses an Elasticsearch hit result into the data structure handled by the RAG pipeline.

    Parameters
    ----------
    hit : ElasticHitsResult

    Returns
    -------
    Union[Document, None]
        None if the hit's index is not recognized
    """
if "issuelab-elser" in hit.index: | |
doc = process_issuelab_hit(hit) | |
elif "youtube" in hit.index: | |
doc = process_youtube_hit(hit) | |
elif "candid-blog" in hit.index: | |
doc = process_blog_hit(hit) | |
elif "candid-learning" in hit.index: | |
doc = process_learning_hit(hit) | |
elif "candid-help" in hit.index: | |
doc = process_help_hit(hit) | |
elif "news" in hit.index: | |
doc = process_news_hit(hit) | |
else: | |
doc = None | |
return doc | |

def get_reranked_results(results: List[ElasticHitsResult], search_text: Optional[str] = None) -> List[Document]:
    """Runs data re-ranking and document building for tool usage.

    Parameters
    ----------
    results : List[ElasticHitsResult]
        Hits from the multi-search step
    search_text : Optional[str], optional
        Search context string, by default None

    Returns
    -------
    List[Document]
    """
    output = []
    for r in reranker(results, search_text=search_text):
        hit = process_hit(r)
        if hit is not None:
            output.append(hit)
    return output
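
# Minimal end-to-end sketch, assuming reachable Elasticsearch clusters and valid credentials; the search
# string and index list below are purely illustrative.
if __name__ == "__main__":
    example_query = "trends in corporate giving to environmental nonprofits"
    raw_hits = get_query_results(search_text=example_query, indices=["issuelab", "candid_blog", "news"])
    documents = get_reranked_results(raw_hits, search_text=example_query)
    for doc in documents:
        # each Document exposes `page_content` and `metadata`, populated by the per-index processors
        print(doc.metadata, "->", doc.page_content[:120])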