File size: 4,089 Bytes
cc80c3d
 
c751e97
 
cc80c3d
c751e97
cc80c3d
c751e97
 
cc80c3d
 
c751e97
 
 
 
 
 
 
cc80c3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c751e97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from typing import List, Tuple, Callable, Optional, Any
from functools import partial
import logging

from pydantic import BaseModel, Field
from langchain_core.language_models.llms import LLM
from langchain_core.documents import Document
from langchain_core.tools import Tool

from ask_candid.retrieval.elastic import get_query_results, get_reranked_results
from ask_candid.base.config.data import DataIndices
from ask_candid.agents.schema import AgentState

# Module-level logging: configure a simple "[LEVEL] (timestamp) :: message" format
# and an INFO-level logger named after this module.
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class RetrieverInput(BaseModel):
    """Input to the Elasticsearch retriever."""
    # Free-text search query. NOTE: the Field description (and the class docstring)
    # are surfaced to the LLM through the tool's argument schema, so their exact
    # wording is part of runtime behavior — do not edit casually.
    user_input: str = Field(description="query to look up in retriever")


def get_search_results(
    user_input: str,
    indices: List[DataIndices],
    user_callback: Optional[Callable[[str], Any]] = None
) -> Tuple[str, List[Document]]:
    """End-to-end search and re-rank function.

    Parameters
    ----------
    user_input : str
        Search context string
    indices : List[DataIndices]
        Semantic index names to search over
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of apps states, by default None

    Returns
    -------
    Tuple[str, List[Document]]
        (concatenated text from search results, documents list)
    """

    if user_callback is not None:
        # Best-effort UI notification: a failing callback must never abort the search.
        try:
            user_callback("Searching for relevant information")
        except Exception as ex:
            logger.warning("User callback was passed in but failed: %s", ex)

    content = "Search didn't return any Candid sources"
    # FIX: the no-results artifact was previously a list containing the fallback
    # *string*, which violates the declared List[Document] return type and breaks
    # any consumer that iterates the artifact as Documents. Use an empty list.
    documents: List[Document] = []
    results = get_query_results(search_text=user_input, indices=indices)
    if results:
        documents = get_reranked_results(results, search_text=user_input)
        content = "\n\n".join(doc.page_content for doc in documents)

    # for the tool we need to return a tuple for content_and_artifact type
    return content, documents


def retriever_tool(
    indices: List[DataIndices],
    user_callback: Optional[Callable[[str], Any]] = None
) -> Tool:
    """Build the retrieval tool used by the conditional edges of the RAG execution graph.
    `create_retriever_tool` is deliberately not used here: it returns page content only
    and drops all document metadata along the way.
    https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution

    Parameters
    ----------
    indices : List[DataIndices]
        Semantic index names to search over
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of apps states, by default None

    Returns
    -------
    Tool
    """

    # Bind the search configuration up front so the tool callable only needs the query.
    search_fn = partial(get_search_results, indices=indices, user_callback=user_callback)
    tool_description = (
        "Return additional information about social and philanthropic sector, "
        "including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
    )
    return Tool(
        name="retrieve_social_sector_information",
        description=tool_description,
        func=search_fn,
        args_schema=RetrieverInput,
        response_format="content_and_artifact"
    )


def search_agent(state: AgentState, llm: LLM, tools: List[Tool]) -> AgentState:
    """Invoke the tool-bound agent model on the conversation so far. Given the
    question, the model decides whether to call the retriever tool or simply end.

    Parameters
    ----------
    state : AgentState
        The current graph state
    llm : LLM
    tools : List[Tool]

    Returns
    -------
    AgentState
        The updated state with the agent response appended to messages
    """

    logger.info("---SEARCH AGENT---")
    history = state["messages"]
    question = history[-1].content

    tool_bound_model = llm.bind_tools(tools)
    agent_response = tool_bound_model.invoke(history)
    # Wrapped in a list so the graph's reducer appends it to the existing messages
    return {"messages": [agent_response], "user_input": question}