Spaces:
Running
Running
File size: 4,089 Bytes
cc80c3d c751e97 cc80c3d c751e97 cc80c3d c751e97 cc80c3d c751e97 cc80c3d c751e97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
from typing import List, Tuple, Callable, Optional, Any
from functools import partial
import logging
from pydantic import BaseModel, Field
from langchain_core.language_models.llms import LLM
from langchain_core.documents import Document
from langchain_core.tools import Tool
from ask_candid.retrieval.elastic import get_query_results, get_reranked_results
from ask_candid.base.config.data import DataIndices
from ask_candid.agents.schema import AgentState
# Module-level logger for this file, emitting INFO and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

_LOG_FORMAT = "[%(levelname)s] (%(asctime)s) :: %(message)s"
# NOTE(review): this configures the *root* logger as an import-time side effect; per the
# `logging.basicConfig` contract it does nothing if the root logger already has handlers.
logging.basicConfig(format=_LOG_FORMAT)
class RetrieverInput(BaseModel):
    """Input to the Elasticsearch retriever."""
    # Argument schema for the retrieval tool (passed as `args_schema` in `retriever_tool`).
    # NOTE(review): pydantic presumably uses the class docstring and the field `description`
    # when generating the JSON schema shown to the LLM, so treat both strings as runtime text.

    # Free-text search query supplied by the model when it calls the tool.
    user_input: str = Field(description="query to look up in retriever")
def get_search_results(
    user_input: str,
    indices: List[DataIndices],
    user_callback: Optional[Callable[[str], Any]] = None
) -> Tuple[str, List[Document]]:
    """End-to-end search and re-rank function.

    Parameters
    ----------
    user_input : str
        Search context string
    indices : List[DataIndices]
        Semantic index names to search over
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of app state, by default None

    Returns
    -------
    Tuple[str, List[Document]]
        (concatenated text from search results, re-ranked documents list).
        When no documents are found, the text is a fixed "no sources" message and the
        documents list is empty, so the artifact is always a list of ``Document`` objects.
    """
    no_results_message = "Search didn't return any Candid sources"

    if user_callback is not None:
        # The callback only reports progress to the UI; a failure there must not abort
        # the search, so it is logged and otherwise ignored.
        try:
            user_callback("Searching for relevant information")
        except Exception as ex:
            logger.warning("User callback was passed in but failed: %s", ex)

    documents: List[Document] = []
    results = get_query_results(search_text=user_input, indices=indices)
    if results:
        # Materialize so the documents can be both joined below and returned as the
        # tool artifact, even if the re-ranker hands back a one-shot iterator.
        documents = list(get_reranked_results(results, search_text=user_input))

    if not documents:
        # Previously the artifact here was a list containing the message string, which
        # broke the List[Document] contract; the message now lives only in the content.
        return no_results_message, []

    content = "\n\n".join(doc.page_content for doc in documents)
    # for the tool we need to return a tuple for content_and_artifact type
    return content, documents
def retriever_tool(
    indices: List[DataIndices],
    user_callback: Optional[Callable[[str], Any]] = None
) -> Tool:
    """Build the retrieval tool used by the RAG execution graph.

    `create_retriever_tool` is deliberately not used: it only surfaces text content and
    drops document metadata. This tool uses the ``content_and_artifact`` response format
    instead, so the ``(content, documents)`` pair returned by ``get_search_results`` keeps
    the documents available as the tool artifact.
    https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution

    Parameters
    ----------
    indices : List[DataIndices]
        Semantic index names to search over
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback used to report search progress, by default None

    Returns
    -------
    Tool
        Tool named ``retrieve_social_sector_information`` that runs
        ``get_search_results`` with ``indices`` and ``user_callback`` pre-bound.
    """
    # Pre-bind everything except the query, which the tool supplies at call time.
    bound_search = partial(get_search_results, indices=indices, user_callback=user_callback)

    tool_description = (
        "Return additional information about social and philanthropic sector, "
        "including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
    )

    return Tool(
        name="retrieve_social_sector_information",
        description=tool_description,
        func=bound_search,
        args_schema=RetrieverInput,
        response_format="content_and_artifact",
    )
def search_agent(state: AgentState, llm: LLM, tools: List[Tool]) -> AgentState:
    """Run the tool-aware model over the conversation so far.

    The model is bound to ``tools``, so for the latest question it can either emit a
    retrieval tool call or respond directly and end.

    Parameters
    ----------
    state : AgentState
        Current graph state; its ``"messages"`` entry holds the conversation history.
    llm : LLM
        Language model that must support ``bind_tools``.
        NOTE(review): ``bind_tools`` is a chat-model API; confirm the ``LLM`` annotation
        matches what callers actually pass.
    tools : List[Tool]
        Tools the model is allowed to call.

    Returns
    -------
    AgentState
        State update with the model response in a one-element ``"messages"`` list and
        ``"user_input"`` set to the content of the last incoming message.
    """
    logger.info("---SEARCH AGENT---")
    history = state["messages"]
    # NOTE(review): this is whatever message is last in the history, which is the user's
    # question only if this node runs right after a human turn — confirm the graph wiring.
    latest_content = history[-1].content

    tool_aware_model = llm.bind_tools(tools)
    reply = tool_aware_model.invoke(history)

    # A list is returned because the state's message history is extended, not replaced.
    return {"messages": [reply], "user_input": latest_content}
|