from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Any
from dataclasses import dataclass
from functools import partial
from itertools import groupby

from pydantic import BaseModel, Field
from langchain_core.documents import Document
from langchain_core.tools import Tool
from elasticsearch import Elasticsearch

try:
    # from news import build_knn_query as news_query
    # from up_orgs import build_organizations_knn_model_query as org_query
    # from cds import build_transactions_knn_query as transactions_query
    from config import ONECANDID_QA, ALL_INDICES, Indices
except ImportError:
    # from .news import build_knn_query as news_query
    # from .up_orgs import build_organizations_knn_model_query as org_query
    # from .cds import build_transactions_knn_query as transactions_query
    from .config import ONECANDID_QA, ALL_INDICES, Indices


@dataclass
class ElasticHitsResult:
    """Dataclass for Elasticsearch hits results."""
    index: str
    id: Any
    score: float
    source: Dict[str, Any]
    inner_hits: Dict[str, Any]


class RetrieverInput(BaseModel):
    """Input to the Elasticsearch retriever."""
    user_input: str = Field(description="query to look up in retriever")


def build_text_expansion_query(
    query: str,
    fields: Tuple[str, ...],
    model_id: str = ".elser_model_2_linux-x86_64"
) -> Dict[str, Any]:
    """Builds an ELSER text-expansion query over nested embedding chunks.

    Each field contributes a nested `text_expansion` clause; the boost is
    split evenly across fields so no single field dominates the score.
    """
    output = []
    for f in fields:
        output.append({
            "nested": {
                "path": f"embeddings.{f}.chunks",
                "query": {
                    "text_expansion": {
                        f"embeddings.{f}.chunks.vector": {
                            "model_id": model_id,
                            "model_text": query,
                            "boost": 1 / len(fields)
                        }
                    }
                },
                "inner_hits": {
                    "_source": False,
                    "size": 2,
                    "fields": [f"embeddings.{f}.chunks.chunk"]
                }
            }
        })
    return {"query": {"bool": {"should": output}}}


def query_builder(query: str, indices: Optional[List[str]], **kwargs) -> List[Dict[str, Any]]:
    """Builds an msearch body: alternating index-header and query-body dicts."""
    queries = []
    if indices is None:
        indices = list(ALL_INDICES)

    for index in indices:
        if index == "news":
            # q = news_query(query)
            # q["_source"] = {"excludes": ["embeddings"]}
            # q["size"] = 5
            # queries.extend([{"index": Indices.NEWS_INDEX}, q])
            pass
        elif index == "organizations":
            # q = org_query(query)
            # q["_source"] = {"excludes": ["embeddings"]}
            # q["size"] = 10
            # queries.extend([{"index": Indices.ORGANIZATION_INDEX}, q])
            pass
        elif index == "grants":
            # q = transactions_query(query)
            # q["_source"] = {"excludes": ["embeddings"]}
            # q["size"] = 10
            # queries.extend([{"index": Indices.TRANSACTION_INDEX}, q])
            pass
        elif index == "issuelab":
            q = build_text_expansion_query(
                query=query,
                fields=(
                    "description",
                    "content",
                    "combined_issuelab_findings",
                    "combined_item_description"
                )
            )
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 1
            queries.extend([{"index": Indices.ISSUELAB_INDEX_ELSER}, q])
        elif index == "youtube":
            q = build_text_expansion_query(
                query=query,
                fields=("captions_cleaned", "description_cleaned", "title")
            )
            # text_cleaned duplicates captions_cleaned
            q["_source"] = {"excludes": ["embeddings", "captions", "description", "text_cleaned"]}
            q["size"] = 2
            queries.extend([{"index": Indices.YOUTUBE_INDEX_ELSER}, q])
        elif index == "candid_blog":
            q = build_text_expansion_query(query=query, fields=("content", "title"))
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 2
            queries.extend([{"index": Indices.CANDID_BLOG_INDEX_ELSER}, q])
        elif index == "candid_learning":
            q = build_text_expansion_query(
                query=query,
                fields=("content", "title", "training_topics", "staff_recommendations")
            )
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 2
            queries.extend([{"index": Indices.CANDID_LEARNING_INDEX_ELSER}, q])
        elif index == "candid_help":
            q = build_text_expansion_query(
                query=query,
                fields=("content", "combined_article_description")
            )
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 2
            queries.extend([{"index": Indices.CANDID_HELP_INDEX_ELSER}, q])
    return queries
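
# A minimal sketch (illustrative output, not from a real run) of the msearch
# body that query_builder produces: an alternating list of index-header dicts
# and query-body dicts, which is the format Elasticsearch's msearch API expects.
#
#   query_builder("funding for youth programs", ["candid_blog"])
#   -> [
#          {"index": Indices.CANDID_BLOG_INDEX_ELSER},
#          {"query": {"bool": {"should": [...]}},
#           "_source": {"excludes": ["embeddings"]},
#           "size": 2},
#      ]
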
"combined_article_description") ) q["_source"] = {"excludes": ["embeddings"]} q["size"] = 2 queries.extend([{"index": Indices.CANDID_HELP_INDEX_ELSER}, q]) return queries def multi_search(queries: List[ElasticHitsResult]): results = [] with Elasticsearch( cloud_id=ONECANDID_QA["ES_CLOUD_ID"], api_key=ONECANDID_QA["ES_API_KEY"], verify_certs=False, request_timeout=60 * 3 ) as es: for query_group in es.msearch(body=queries).get("responses", []): for hit in query_group.get("hits", {}).get("hits", []): hit = ElasticHitsResult( index=hit["_index"], id=hit["_id"], score=hit["_score"], source=hit["_source"], inner_hits=hit.get("inner_hits", {}) ) results.append(hit) return results def get_query_results(search_text: str, indices: Optional[List[str]] = None): queries = query_builder(query=search_text, indices=indices) return multi_search(queries) def reranker(query_results: Iterable[ElasticHitsResult]) -> Iterator[ElasticHitsResult]: """Reranks Elasticsearch hits coming from multiple indicies/queries which may have scores on different scales. This will shuffle results Parameters ---------- query_results : Iterable[ElasticHitsResult] Yields ------ Iterator[ElasticHitsResult] """ results: List[ElasticHitsResult] = [] for _, data in groupby(query_results, key=lambda x: x.index): data = list(data) max_score = max(data, key=lambda x: x.score).score min_score = min(data, key=lambda x: x.score).score for d in data: d.score = (d.score - min_score) / (max_score - min_score + 1e-9) results.append(d) yield from sorted(results, key=lambda x: x.score, reverse=True) def get_results(user_input: str, indices: List[str]) -> List[ElasticHitsResult]: output = ["Search didn't return any Candid sources"] page_content=[] content = "Search didn't return any Candid sources" results = get_query_results(search_text=user_input, indices=indices) if results: output = get_reranked_results(results) for doc in output: page_content.append(doc.page_content) content = "/n/n".join(page_content) # for the tool we need to return a tuple for content_and_artifact type return content, output def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024) -> str: """Pads the relevant chunk of text with context before and after Parameters ---------- field_name : str a field with the long text that was chunked into pieces hit : ElasticHitsResult context_length : int, optional length of text to add before and after the chunk, by default 1024 Returns ------- str longer chunks stuffed together """ chunks_with_context = [] long_text = hit.source.get(f"{field_name}", "") inner_hits_field = f"embeddings.{field_name}.chunks" inner_hits = hit.inner_hits found_chunks = inner_hits.get(inner_hits_field, {}) if found_chunks: hits = found_chunks.get("hits", {}).get("hits", []) for h in hits: chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0] chunk = chunk[3:-3] # cutting the middle because we may have tokenizing artefacts there # Find the start and end indices of the chunk in the large text start_index = long_text.find(chunk) if start_index != -1: # Chunk is found end_index = start_index + len(chunk) pre_start_index = max(0, start_index - context_length) post_end_index = min(len(long_text), end_index + context_length) context = long_text[pre_start_index:post_end_index] chunks_with_context.append(context) chunks_with_context_txt = '\n\n'.join(chunks_with_context) return chunks_with_context_txt def process_hit(hit: ElasticHitsResult) -> Document | None: if "issuelab-elser" in hit.index: combined_item_description = 
hit.source.get("combined_item_description", "") # title inside description = hit.source.get("description", "") combined_issuelab_findings = hit.source.get("combined_issuelab_findings", "") # we only need to process long texts chunks_with_context_txt = get_context("content", hit, context_length=12) doc = Document( page_content='\n\n'.join([ combined_item_description, combined_issuelab_findings, description, chunks_with_context_txt ]), metadata={ "source": "IssueLab", "source_id": hit.source["resource_id"], "url": hit.source.get("permalink", "") } ) elif "youtube" in hit.index: title = hit.source.get("title", "") # we only need to process long texts description_cleaned_with_context_txt = get_context("description_cleaned", hit, context_length=12) captions_cleaned_with_context_txt = get_context("captions_cleaned", hit, context_length=12) doc = Document( page_content='\n\n'.join([title, description_cleaned_with_context_txt, captions_cleaned_with_context_txt]), metadata={ "source": "Candid Youtube", "source_id": hit.source['video_id'], "url": f"https://www.youtube.com/watch?v={hit.source['video_id']}" } ) elif "candid-blog" in hit.index: excerpt = hit.source.get("excerpt", "") title = hit.source.get("title", "") # we only need to process long texts content_with_context_txt = get_context("content", hit, context_length=12) doc = Document( page_content='\n\n'.join([title, excerpt, content_with_context_txt]), metadata={ "source": "Candid Blog", "source_id": hit.source["id"], "url": hit.source["link"] } ) elif "candid-learning" in hit.index: title = hit.source.get("title", "") content_with_context_txt = get_context("content", hit, context_length=12) training_topics = hit.source.get("training_topics", "") staff_recommendations = hit.source.get("staff_recommendations", "") doc = Document( page_content='\n\n'.join([title, staff_recommendations, training_topics, content_with_context_txt]), metadata={ "source": "Candid Learning", "source_id": hit.source["post_id"], "url": hit.source.get("url", "") } ) elif "candid-help" in hit.index: title = hit.source.get("title", "") content_with_context_txt = get_context("content", hit, context_length=12) combined_article_description = hit.source.get("combined_article_description", "") doc = Document( page_content='\n\n'.join([combined_article_description, content_with_context_txt]), metadata={ "source": "Candid Help", "source_id": hit.source["id"], "url": hit.source.get("link", "") } ) else: doc = None return doc def get_reranked_results(results: List[ElasticHitsResult]) -> List[Document]: output = [] for r in reranker(results): hit = process_hit(r) output.append(hit) return output def retriever_tool(indices: List[str]) -> Tool: # cannot use create_retriever_tool because it only provides content losing all metadata on the way # https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution return Tool( name="retrieve_social_sector_information", func=partial(get_results, indices=indices), description=( "Return additional information about social and philanthropic sector, " "including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid." ), args_schema=RetrieverInput, response_format="content_and_artifact" )