from typing import List, Tuple, Callable, Optional, Any | |
from functools import partial | |
import logging | |
from pydantic import BaseModel, Field | |
from langchain_core.language_models.llms import LLM | |
from langchain_core.documents import Document | |
from langchain_core.tools import Tool | |
from ask_candid.retrieval.elastic import get_query_results, get_reranked_results | |
from ask_candid.base.config.data import DataIndices | |
from ask_candid.agents.schema import AgentState | |
# Configure root logging once at import time with a "[LEVEL] (timestamp) :: message" format.
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
# Module-level logger; emits INFO and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class RetrieverInput(BaseModel):
    """Input to the Elasticsearch retriever."""
    # Free-text query string forwarded to the retrieval tool.
    user_input: str = Field(description="query to look up in retriever")
def get_search_results(
    user_input: str,
    indices: List[DataIndices],
    user_callback: Optional[Callable[[str], Any]] = None
) -> Tuple[str, List[Document]]:
    """End-to-end search and re-rank function.

    Parameters
    ----------
    user_input : str
        Search context string
    indices : List[DataIndices]
        Semantic index names to search over
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of apps states, by default None

    Returns
    -------
    Tuple[str, List[Document]]
        (concatenated text from search results, documents list)
    """
    # Best-effort UI notification: a broken callback must never abort the search.
    if user_callback is not None:
        try:
            user_callback("Searching for relevant information")
        except Exception as ex:
            logger.warning("User callback was passed in but failed: %s", ex)

    results = get_query_results(search_text=user_input, indices=indices)
    if not results:
        # No hits: surface the same message both as the content string and as
        # the artifact, so downstream consumers always get a non-empty payload.
        fallback = "Search didn't return any Candid sources"
        return fallback, [fallback]

    output = get_reranked_results(results, search_text=user_input)
    content = "\n\n".join(doc.page_content for doc in output)
    # For the tool we need to return a tuple for content_and_artifact type.
    return content, output
def retriever_tool(
    indices: List[DataIndices],
    user_callback: Optional[Callable[[str], Any]] = None
) -> Tool:
    """Build the retrieval tool used as a conditional edge in the RAG execution graph.

    `create_retriever_tool` is not usable here because it returns content only,
    losing all document metadata on the way:
    https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution

    Parameters
    ----------
    indices : List[DataIndices]
        Semantic index names to search over
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of apps states, by default None

    Returns
    -------
    Tool
    """
    # Pre-bind the index list and UI callback so the tool only needs the query.
    search_fn = partial(get_search_results, indices=indices, user_callback=user_callback)
    tool_description = (
        "Return additional information about social and philanthropic sector, "
        "including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
    )
    return Tool(
        name="retrieve_social_sector_information",
        func=search_fn,
        description=tool_description,
        args_schema=RetrieverInput,
        response_format="content_and_artifact"
    )
def search_agent(state: AgentState, llm: LLM, tools: List[Tool]) -> AgentState:
    """Invoke the agent model to generate a response based on the current state.

    Given the question, the model decides whether to retrieve using the
    retriever tool or simply end.

    Parameters
    ----------
    state : AgentState
        The current graph state
    llm : LLM
        Base language model, bound to `tools` before invocation
    tools : List[Tool]
        Tools the model may choose to call

    Returns
    -------
    AgentState
        The updated state with the agent response appended to messages
    """
    logger.info("---SEARCH AGENT---")

    history = state["messages"]
    latest_question = history[-1].content

    # Let the model see the available tools so it can decide to call one or answer.
    tool_aware_model = llm.bind_tools(tools)
    reply = tool_aware_model.invoke(history)

    # Wrap the reply in a list because it gets appended to the existing message list.
    return {"messages": [reply], "user_input": latest_question}