brainsqueeze's picture
UI callbacks and style changes
cc80c3d verified
from typing import List, Tuple, Callable, Optional, Any
from functools import partial
import logging
from pydantic import BaseModel, Field
from langchain_core.language_models.llms import LLM
from langchain_core.documents import Document
from langchain_core.tools import Tool
from ask_candid.retrieval.elastic import get_query_results, get_reranked_results
from ask_candid.base.config.data import DataIndices
from ask_candid.agents.schema import AgentState
# Configure the root logger's output format: "[LEVEL] (timestamp) :: message".
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Emit INFO and above from this module's logger.
logger.setLevel(logging.INFO)
# Argument schema for the retriever tool; passed as `args_schema` when the
# `Tool` is built in `retriever_tool`.
class RetrieverInput(BaseModel):
    """Input to the Elasticsearch retriever."""
    # Free-text query; the field name matches the first parameter of
    # `get_search_results`.
    user_input: str = Field(description="query to look up in retriever")
def get_search_results(
    user_input: str,
    indices: List[DataIndices],
    user_callback: Optional[Callable[[str], Any]] = None
) -> Tuple[str, List[Document]]:
    """End-to-end search and re-rank function.

    Runs the query against the given indices, re-ranks any hits, and returns
    both the concatenated text of the re-ranked documents and the documents
    themselves. If the search produces no hits, or re-ranking leaves no
    documents, the text falls back to a fixed "no sources" message so that the
    tool never hands an empty string back to the model.

    Parameters
    ----------
    user_input : str
        Search context string
    indices : List[DataIndices]
        Semantic index names to search over
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of apps states, by default None.
        A failing callback is logged and otherwise ignored.

    Returns
    -------
    Tuple[str, List[Document]]
        (concatenated text from search results, documents list)
    """

    if user_callback is not None:
        # The UI notification is best-effort: a broken callback must not
        # prevent the search itself from running.
        try:
            user_callback("Searching for relevant information")
        except Exception as ex:
            logger.warning("User callback was passed in but failed: %s", ex)

    fallback = "Search didn't return any Candid sources"

    results = get_query_results(search_text=user_input, indices=indices)
    if not results:
        # NOTE(review): the artifact here is a list holding the fallback string,
        # not `Document` objects as the annotation states. Kept unchanged for
        # backward compatibility -- confirm what downstream consumers expect.
        return fallback, [fallback]

    documents = get_reranked_results(results, search_text=user_input)
    page_content = [doc.page_content for doc in documents]

    # Re-ranking may leave nothing behind; joining an empty list would yield an
    # empty string, so use the fallback message as the content instead.
    content = "\n\n".join(page_content) if page_content else fallback

    # for the tool we need to return a tuple for content_and_artifact type
    return content, documents
def retriever_tool(
    indices: List[DataIndices],
    user_callback: Optional[Callable[[str], Any]] = None
) -> Tool:
    """Build the retrieval tool used by the RAG execution graph.

    The tool wraps `get_search_results` with `indices` and `user_callback`
    pre-bound, and is declared with the `content_and_artifact` response format.
    `create_retriever_tool` is not used because it only provides content,
    losing all metadata on the way; see
    https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution

    Parameters
    ----------
    indices : List[DataIndices]
        Semantic index names to search over
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of apps states, by default None

    Returns
    -------
    Tool
        Tool named ``retrieve_social_sector_information`` whose input schema is
        `RetrieverInput`.
    """

    # Pre-bind everything except the query text, which is supplied per call.
    bound_search = partial(
        get_search_results,
        indices=indices,
        user_callback=user_callback,
    )
    tool_description = (
        "Return additional information about social and philanthropic sector, "
        "including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
    )

    return Tool(
        name="retrieve_social_sector_information",
        func=bound_search,
        description=tool_description,
        args_schema=RetrieverInput,
        response_format="content_and_artifact",
    )
def search_agent(state: AgentState, llm: LLM, tools: List[Tool]) -> AgentState:
    """Run the tool-enabled model over the conversation held in the state.

    The model is given the full message history with `tools` bound to it, so
    it can either request a retrieval through a tool or answer directly.

    Parameters
    ----------
    state : AgentState
        The current state; its ``"messages"`` entry is read
    llm : LLM
        Model to invoke; the tools are attached through its ``bind_tools``
    tools : List[Tool]
        Tools made available to the model

    Returns
    -------
    AgentState
        State update holding the model response as a one-element ``"messages"``
        list and the content of the latest incoming message as ``"user_input"``
    """

    logger.info("---SEARCH AGENT---")

    history = state["messages"]
    # Capture the latest message text before calling the model.
    latest_question = history[-1].content

    tool_enabled_llm = llm.bind_tools(tools)
    reply = tool_enabled_llm.invoke(history)

    # return a list, because this will get added to the existing list
    return {"messages": [reply], "user_input": latest_question}