# ask-candid/chat.py
from typing import List, Optional, Dict, Any, TypedDict, Annotated, Sequence
from functools import partial
import os
import gradio as gr
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models.llms import LLM
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.state import StateGraph
from langgraph.graph.message import add_messages
from langgraph.constants import START, END
try:
from utils import html_format_docs_chat, get_session_id
from tools.question_reformulation import reformulate_question_using_history
from tools.org_seach import (
extract_org_links_from_chatbot,
embed_org_links_in_text,
generate_org_link_dict,
)
from retrieval.elastic import retriever_tool
except ImportError:
from .utils import html_format_docs_chat, get_session_id
from .tools.question_reformulation import reformulate_question_using_history
from .tools.org_seach import (
extract_org_links_from_chatbot,
embed_org_links_in_text,
generate_org_link_dict,
)
from .retrieval.elastic import retriever_tool

ROOT = os.path.dirname(os.path.abspath(__file__))

# TODO https://www.metadocs.co/2024/08/29/simple-domain-specific-corrective-rag-with-langchain-and-langgraph/


class AgentState(TypedDict):
    # The `add_messages` reducer controls how updates to `messages` are merged:
    # the default behavior is to replace the value; `add_messages` appends.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    user_input: str
    org_dict: Dict
    # Routing key set by `has_org_name` and read by the conditional edge; it must
    # be declared in the schema so LangGraph accepts updates to it.
    next: str
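
# A minimal sketch of the `add_messages` merge semantics described above
# (illustrative only; the reducer can also be called directly):
#
#   from langchain_core.messages import HumanMessage
#   merged = add_messages([HumanMessage("hi")], [AIMessage("hello!")])
#   assert len(merged) == 2  # appended, not replaced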


def search_agent(state, llm: LLM, tools) -> AgentState:
    """Invokes the agent model to generate a response based on the current state.
    Given the question, it decides whether to retrieve using the retriever tool
    or simply end.

    Parameters
    ----------
    state : AgentState
        The current state
    llm : LLM
        Chat model driving the agent; must support `bind_tools`
    tools : list
        Tools the model may call (here, the retriever tool)

    Returns
    -------
    AgentState
        The updated state with the agent response appended to messages
    """
print("---SEARCH AGENT---")
messages = state["messages"]
question = messages[-1].content
model = llm.bind_tools(tools)
response = model.invoke(messages)
# return a list, because this will get added to the existing list
return {"messages": [response], "user_input": question}


def generate_with_context(state, llm: LLM) -> AgentState:
    """Generate an answer from the retrieved context.

    Parameters
    ----------
    state : AgentState
        The current state
    llm : LLM
        Chat model used to compose the final answer

    Returns
    -------
    AgentState
        The updated state with the agent response appended to messages
    """
print("---GENERATE ANSWER---")
messages = state["messages"]
question = state["user_input"]
last_message = messages[-1]
sources_str = last_message.content
sources_list = last_message.artifact # cannot use directly as list of Documents
# converting to html string
sources_html = html_format_docs_chat(sources_list)
    if sources_list:
        print("---ADD SOURCES---")
        # Appended in place; the `type="HTML"` marker is matched later in
        # `run_chat` to surface the sources block in the UI.
        state["messages"].append(BaseMessage(content=sources_html, type="HTML"))
# Prompt
    qa_system_prompt = """
    You are an assistant for question-answering tasks in the social and philanthropic sector.
    Use the following pieces of retrieved context to answer the question at the end.
    If you don't know the answer, just say that you don't know.
    Keep the response professional, friendly, and as concise as possible.

    Question: {question}
    Context: {context}
    Answer:
    """
    qa_prompt = ChatPromptTemplate(
        [
            ("system", qa_system_prompt),
            # Use a template placeholder instead of interpolating the raw question:
            # literal braces in user input would otherwise be parsed as variables.
            ("human", "{question}"),
        ]
    )
rag_chain = qa_prompt | llm | StrOutputParser()
response = rag_chain.invoke({"context": sources_str, "question": question})
    # couldn't figure out why returning the plain `response` was seen as a HumanMessage
return {"messages": [AIMessage(content=response)], "user_input": question}


def has_org_name(state: AgentState) -> AgentState:
    """
    Processes the latest message to extract organization links and determine the next step.

    Args:
        state (AgentState): The current state of the agent, including a list of messages.

    Returns:
        dict: A dictionary with the next agent action and, if available, a dictionary of organization links.
    """
print("---HAS ORG NAMES?---")
messages = state["messages"]
last_message = messages[-1].content
output_list = extract_org_links_from_chatbot(last_message)
link_dict = generate_org_link_dict(output_list) if output_list else {}
if link_dict:
print("---FOUND ORG NAMES---")
return {"next": "insert_org_link", "org_dict": link_dict}
print("---NO ORG NAMES FOUND---")
return {"next": END, "messages": messages}


def insert_org_link(state: AgentState) -> AgentState:
    """
    Embeds organization links in the latest message content and returns it as an AI message.

    Args:
        state (dict): The current state, including the organization links and latest message.

    Returns:
        dict: A dictionary with the updated message content as an AIMessage.
    """
print("---INSERT ORG LINKS---")
messages = state["messages"]
last_message = messages[-1].content
messages.pop(-1) # Deleting the original message because we will append the same one but with links
link_dict = state["org_dict"]
last_message = embed_org_links_in_text(last_message, link_dict)
return {"messages": [AIMessage(content=last_message)]}


def build_compute_graph(llm: LLM, indices: List[str]) -> StateGraph:
    """Assemble the LangGraph workflow:
    reformulate -> search_agent -> (retrieve -> generate_with_context) -> has_org_name -> insert_org_link.
    """
    candid_retriever_tool = retriever_tool(indices=indices)
    retrieve = ToolNode([candid_retriever_tool])
    tools = [candid_retriever_tool]

    G = StateGraph(AgentState)
# Add nodes
G.add_node("reformulate", partial(reformulate_question_using_history, llm=llm))
G.add_node("search_agent", partial(search_agent, llm=llm, tools=tools))
G.add_node("retrieve", retrieve)
G.add_node("generate_with_context", partial(generate_with_context, llm=llm))
G.add_node("has_org_name", has_org_name)
G.add_node("insert_org_link", insert_org_link)
# Add edges
G.add_edge(START, "reformulate")
G.add_edge("reformulate", "search_agent")
# Conditional edges from search_agent
G.add_conditional_edges(
source="search_agent",
path=tools_condition, # TODO just a conditional edge here?
path_map={
"tools": "retrieve",
"__end__": "has_org_name",
},
)
G.add_edge("retrieve", "generate_with_context")
# Add edges
G.add_edge("generate_with_context", "has_org_name")
# Use add_conditional_edges for has_org_name
G.add_conditional_edges(
"has_org_name",
lambda x: x["next"], # Now we're accessing the 'next' key from the dict
{"insert_org_link": "insert_org_link", END: END},
)
G.add_edge("insert_org_link", END)
return G
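
# A quick way to sanity-check the wiring above (illustrative; `draw_ascii` needs
# the optional `grandalf` package, and `some_chat_model` is a placeholder):
#
#   compiled = build_compute_graph(llm=some_chat_model, indices=["demo-index"]).compile()
#   print(compiled.get_graph().draw_ascii())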


def run_chat(
    thread_id: str,
    user_input: Dict[str, Any],
    chatbot: List[Dict],
    llm: LLM,
    indices: Optional[List[str]] = None,
):
    """Run one conversational turn through the graph and update the Gradio history."""
    # https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_agentic_rag/#graph
chatbot.append({"role": "user", "content": user_input["text"]})
inputs = {"messages": chatbot}
# thread_id can be an email https://github.com/yurisasc/memory-enhanced-ai-assistant/blob/main/assistant.py
thread_id = get_session_id(thread_id)
config = {"configurable": {"thread_id": thread_id}}
    workflow = build_compute_graph(llm=llm, indices=indices)
    memory = MemorySaver()  # TODO: don't use for Prod
    # NOTE: the graph and checkpointer are rebuilt on every call, so saved thread
    # state lives only for this invocation; conversational context is carried by
    # the `chatbot` list passed in as messages.
    graph = workflow.compile(checkpointer=memory)
response = graph.invoke(inputs, config=config)
messages = response["messages"]
last_message = messages[-1]
ai_answer = last_message.content
sources_html = ""
for message in messages[-2:]:
if message.type == "HTML":
sources_html = message.content
chatbot.append({"role": "assistant", "content": ai_answer})
if sources_html:
chatbot.append(
{
"role": "assistant",
"content": sources_html,
"metadata": {"title": "Sources HTML"},
}
)
return gr.MultimodalTextbox(value=None, interactive=True), chatbot, thread_id
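

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original app wiring. Assumes
    # `langchain-openai` is installed and OPENAI_API_KEY is set; any chat model
    # supporting `.bind_tools` should work. The index name is hypothetical, and
    # `get_session_id` is assumed to mint a session id for an empty thread_id.
    from langchain_openai import ChatOpenAI

    demo_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    _, history, session_id = run_chat(
        thread_id="",
        user_input={"text": "What is a donor-advised fund?"},
        chatbot=[],
        llm=demo_llm,
        indices=["demo-index"],  # hypothetical Elasticsearch index name
    )
    print(f"session: {session_id}")
    for turn in history:
        print(f"{turn['role']}: {turn['content']}")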