from typing import List, Optional, Callable, Any
from functools import partial
import logging
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models.llms import LLM
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.graph.state import StateGraph
from langgraph.constants import START, END
from ask_candid.tools.recommendation import (
detect_intent_with_llm,
determine_context,
make_recommendation
)
from ask_candid.tools.question_reformulation import reformulate_question_using_history
from ask_candid.tools.org_seach import has_org_name, insert_org_link
from ask_candid.tools.search import search_agent, retriever_tool
from ask_candid.agents.schema import AgentState
from ask_candid.base.config.data import DataIndices
from ask_candid.utils import html_format_docs_chat

logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def generate_with_context(
state: AgentState,
llm: LLM,
user_callback: Optional[Callable[[str], Any]] = None
) -> AgentState:
"""Generate answer.
Parameters
----------
state : AgentState
The current state
llm : LLM
user_callback : Optional[Callable[[str], Any]], optional
Optional UI callback to inform the user of apps states, by default None
Returns
-------
AgentState
The updated state with the agent response appended to messages
"""
logger.info("---GENERATE ANSWER---")
if user_callback is not None:
try:
user_callback("Writing a response...")
except Exception as ex:
logger.warning("User callback was passed in but failed: %s", ex)
messages = state["messages"]
question = state["user_input"]
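    # The most recent message is the retriever tool's output: `content` holds the
    # formatted source text and `artifact` holds the raw document list.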
last_message = messages[-1]
sources_str = last_message.content
sources_list = last_message.artifact
sources_html = html_format_docs_chat(sources_list)
if sources_list:
logger.info("---ADD SOURCES---")
state["messages"].append(BaseMessage(content=sources_html, type="HTML"))
# Prompt
qa_system_prompt = """
You are an assistant for question-answering tasks in the social and philanthropic sector. \n
Use the following pieces of retrieved context to answer the question at the end. \n
If you don't know the answer, just say that you don't know. \n
Keep the response professional, friendly, and as concise as possible. \n
Question: {question}
Context: {context}
Answer:
"""
    qa_prompt = ChatPromptTemplate([
        ("system", qa_system_prompt),
        # Use a template variable rather than embedding the raw question, so that
        # curly braces in user input are not misparsed as template placeholders.
        ("human", "{question}"),
    ])
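    # LCEL pipeline: render the prompt, invoke the model, coerce the output to a string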
rag_chain = qa_prompt | llm | StrOutputParser()
response = rag_chain.invoke({"context": sources_str, "question": question})
return {"messages": [AIMessage(content=response)], "user_input": question}
def add_recommendations_pipeline_(
G: StateGraph,
llm: LLM,
reformulation_node_name: str = "reformulate",
search_node_name: str = "search_agent"
) -> None:
"""Adds execution sub-graph for recommendation engine flow. Graph changes are in-place.
Parameters
----------
G : StateGraph
Execution graph
reformulation_node_name : str, optional
Name of the node which reforumates input queries, by default "reformulate"
search_node_name : str, optional
Name of the node which executes document search + retrieval, by default "search_agent"
"""
# Nodes for recommendation functionalities
G.add_node(node="detect_intent_with_llm", action=partial(detect_intent_with_llm, llm=llm))
G.add_node(node="determine_context", action=determine_context)
G.add_node(node="make_recommendation", action=make_recommendation)
# Check for recommendation query first
# Execute until reaching END if user asks for recommendation
G.add_edge(start_key=reformulation_node_name, end_key="detect_intent_with_llm")
G.add_conditional_edges(
source="detect_intent_with_llm",
path=lambda state: "determine_context" if state["intent"] in ["rfp", "funder"] else search_node_name,
path_map={
"determine_context": "determine_context",
search_node_name: search_node_name
},
)
G.add_edge(start_key="determine_context", end_key="make_recommendation")
G.add_edge(start_key="make_recommendation", end_key=END)
def build_compute_graph(
llm: LLM,
indices: List[DataIndices],
enable_recommendations: bool = False,
user_callback: Optional[Callable[[str], Any]] = None
) -> StateGraph:
"""Execution graph builder, the output is the execution flow for an interaction with the assistant.
Parameters
----------
llm : LLM
indices : List[DataIndices]
Semantic index names to search over
enable_recommendations : bool, optional
Set to `True` to allow the flow to generate recommendations based on context, by default False
user_callback : Optional[Callable[[str], Any]], optional
Optional UI callback to inform the user of apps states, by default None
Returns
-------
StateGraph
Execution graph
"""
candid_retriever_tool = retriever_tool(indices=indices, user_callback=user_callback)
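    # ToolNode executes the retriever whenever the search agent emits a matching tool call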
retrieve = ToolNode([candid_retriever_tool])
tools = [candid_retriever_tool]
G = StateGraph(AgentState)
G.add_node(
node="reformulate",
action=partial(reformulate_question_using_history, llm=llm, focus_on_recommendations=enable_recommendations)
)
G.add_node(node="search_agent", action=partial(search_agent, llm=llm, tools=tools))
G.add_node(node="retrieve", action=retrieve)
G.add_node(
node="generate_with_context",
action=partial(generate_with_context, llm=llm, user_callback=user_callback)
)
G.add_node(node="has_org_name", action=partial(has_org_name, llm=llm, user_callback=user_callback))
G.add_node(node="insert_org_link", action=insert_org_link)
if enable_recommendations:
add_recommendations_pipeline_(
G, llm=llm,
reformulation_node_name="reformulate",
search_node_name="search_agent"
)
else:
G.add_edge(start_key="reformulate", end_key="search_agent")
G.add_edge(start_key=START, end_key="reformulate")
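    # tools_condition routes to "tools" when the agent's last message requests a
    # tool call, and to END otherwise; here END is remapped to the org-name check.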
G.add_conditional_edges(
source="search_agent",
path=tools_condition,
path_map={
"tools": "retrieve",
END: "has_org_name",
},
)
G.add_edge(start_key="retrieve", end_key="generate_with_context")
G.add_edge(start_key="generate_with_context", end_key="has_org_name")
G.add_conditional_edges(
source="has_org_name",
        path=lambda x: x["next"],  # route on the 'next' key set by the has_org_name node
path_map={
"insert_org_link": "insert_org_link",
END: END
},
)
G.add_edge(start_key="insert_org_link", end_key=END)
return G
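

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the chat model and index name below
    # are assumptions for demonstration, not part of this module. Any LangChain
    # model compatible with the node functions can be substituted for ChatOpenAI.
    from langchain_openai import ChatOpenAI

    demo_llm = ChatOpenAI(model="gpt-4o-mini")
    demo_graph = build_compute_graph(
        llm=demo_llm,
        indices=["issuelab"],  # hypothetical index name; see DataIndices for valid values
        enable_recommendations=False,
    )
    app = demo_graph.compile()
    result = app.invoke({"messages": [], "user_input": "What is a donor-advised fund?"})
    print(result["messages"][-1].content)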