Spaces:
Running
Running
from typing import List, Optional, Dict, Any, TypedDict, Annotated, Sequence | |
from functools import partial | |
import os | |
import gradio as gr | |
from langchain_core.messages import AIMessage, BaseMessage | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.language_models.llms import LLM | |
from langgraph.prebuilt import tools_condition, ToolNode | |
from langgraph.checkpoint.memory import MemorySaver | |
from langgraph.graph.state import StateGraph | |
from langgraph.graph.message import add_messages | |
from langgraph.constants import START, END | |
try: | |
from utils import html_format_docs_chat, get_session_id | |
from tools.question_reformulation import reformulate_question_using_history | |
from tools.org_seach import ( | |
extract_org_links_from_chatbot, | |
embed_org_links_in_text, | |
generate_org_link_dict, | |
) | |
from retrieval.elastic import retriever_tool | |
except ImportError: | |
from .utils import html_format_docs_chat, get_session_id | |
from .tools.question_reformulation import reformulate_question_using_history | |
from .tools.org_seach import ( | |
extract_org_links_from_chatbot, | |
embed_org_links_in_text, | |
generate_org_link_dict, | |
) | |
from .retrieval.elastic import retriever_tool | |
# Absolute path of the directory that contains this module.
ROOT = os.path.split(os.path.abspath(__file__))[0]
# TODO https://www.metadocs.co/2024/08/29/simple-domain-specific-corrective-rag-with-langchain-and-langgraph/
class AgentState(TypedDict):
    """State shared by the nodes of the RAG chat graph."""

    # The add_messages function defines how an update should be processed.
    # Default is to replace; add_messages says "append".
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Question text recorded by search_agent and read by generate_with_context.
    user_input: str
    # Link dictionary built by has_org_name and consumed by insert_org_link.
    org_dict: Dict
    # Routing decision written by has_org_name ("insert_org_link" or END) and read
    # by the conditional edge that follows it. Declared here so it is a real state
    # key instead of an undeclared extra in the node's output.
    next: str
def search_agent(state, llm: LLM, tools) -> AgentState:
    """Let the tool-calling model decide whether to search or answer directly.

    The most recent message is taken as the question. The model is bound to
    ``tools`` and invoked on the full message history. Its reply either requests
    a tool call (retrieval) or does not, and the graph routes on that.

    Parameters
    ----------
    state : AgentState
        Current graph state; only ``state["messages"]`` is read.
    llm : LLM
        Model exposing ``bind_tools`` (a tool-calling model).
    tools : list
        Tools the model may call (the retriever tool in this graph).

    Returns
    -------
    AgentState
        Partial update with the model reply under ``messages`` and the question
        text under ``user_input``.
    """
    print("---SEARCH AGENT---")
    history = state["messages"]
    latest_text = history[-1].content
    tool_caller = llm.bind_tools(tools)
    reply = tool_caller.invoke(history)
    # One-element list: the add_messages reducer appends it to the history.
    update = {"messages": [reply], "user_input": latest_text}
    return update
# System prompt for grounded answering; {question} and {context} are filled at invoke time.
_QA_SYSTEM_PROMPT = """
You are an assistant for question-answering tasks in the social and philanthropic sector. \n
Use the following pieces of retrieved context to answer the question at the end. \n
If you don't know the answer, just say that you don't know. \n
Keep the response professional, friendly, and as concise as possible. \n
Question: {question}
Context: {context}
Answer:
"""


def generate_with_context(state, llm: LLM) -> AgentState:
    """Generate an answer grounded in the retriever tool's output.

    Parameters
    ----------
    state : AgentState
        Current graph state. The last message is expected to be the retriever
        tool's output: its ``content`` is used as the prompt context and its
        ``artifact`` holds the retrieved source documents. ``user_input`` holds
        the question.
    llm : LLM
        Model used to write the answer.

    Returns
    -------
    AgentState
        Partial update whose ``messages`` contains an HTML sources message (only
        when sources were retrieved) followed by the AI answer. The add_messages
        reducer appends them in that order.
    """
    print("---GENERATE ANSWER---")
    messages = state["messages"]
    question = state["user_input"]
    last_message = messages[-1]
    sources_str = last_message.content
    # The artifact holds the raw retrieved documents; they cannot be used directly
    # in the prompt, so they only feed the HTML sources panel.
    sources_list = getattr(last_message, "artifact", None)

    new_messages: List[BaseMessage] = []
    if sources_list:
        print("---ADD SOURCES---")
        sources_html = html_format_docs_chat(sources_list)
        # Returned through the update rather than appended to state["messages"] in
        # place, so the reducer records it. run_chat looks for this message just
        # before the final answer.
        new_messages.append(BaseMessage(content=sources_html, type="HTML"))

    qa_prompt = ChatPromptTemplate(
        [
            ("system", _QA_SYSTEM_PROMPT),
            # Use a template variable rather than the raw question text: braces
            # in user input would otherwise be parsed as template placeholders.
            ("human", "{question}"),
        ]
    )
    rag_chain = qa_prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"context": sources_str, "question": question})

    # Wrap explicitly: a bare string returned under "messages" is converted to a
    # human message by the add_messages reducer.
    new_messages.append(AIMessage(content=response))
    return {"messages": new_messages, "user_input": question}
def has_org_name(state: AgentState) -> AgentState:
    """Check whether the latest reply mentions organizations that should be linked.

    Args:
        state (AgentState): Current graph state; the content of the last message
            is scanned for organization links.

    Returns:
        dict: ``{"next": "insert_org_link", "org_dict": <links>}`` when links were
        found. Otherwise ``{"next": END, "messages": <current messages>}``. The
        ``next`` value is what the conditional edge after this node routes on.
    """
    print("---HAS ORG NAMES?---")
    history = state["messages"]
    candidates = extract_org_links_from_chatbot(history[-1].content)
    org_links = generate_org_link_dict(candidates) if candidates else {}

    if not org_links:
        print("---NO ORG NAMES FOUND---")
        # NOTE(review): the existing messages are re-emitted, presumably so the
        # update always touches a declared state key; confirm before removing.
        return {"next": END, "messages": history}

    print("---FOUND ORG NAMES---")
    return {"next": "insert_org_link", "org_dict": org_links}
def insert_org_link(state: AgentState) -> AgentState:
    """Rewrite the latest reply with organization links embedded in its text.

    Args:
        state (AgentState): Current graph state. The last message's content is
            rewritten using ``state["org_dict"]``, the link dictionary produced by
            ``has_org_name``.

    Returns:
        dict: ``{"messages": [AIMessage]}``. The new message reuses the id of the
        reply it rewrites, so the add_messages reducer replaces that reply in
        place instead of appending a second copy.
    """
    print("---INSERT ORG LINKS---")
    original = state["messages"][-1]
    linked_text = embed_org_links_in_text(original.content, state["org_dict"])
    # Do not pop from state["messages"]: that mutates the live state list outside
    # the reducer. A message whose id matches an existing one replaces it.
    return {"messages": [AIMessage(content=linked_text, id=original.id)]}
def build_compute_graph(llm: LLM, indices: List[str]) -> StateGraph:
    """Assemble the (uncompiled) agentic RAG workflow graph.

    Flow: START -> reformulate -> search_agent. From search_agent, a tool call
    goes retrieve -> generate_with_context -> has_org_name; otherwise it goes
    straight to has_org_name. has_org_name either ends or passes through
    insert_org_link before END.

    Parameters
    ----------
    llm : LLM
        Model shared by the reformulation, search and generation nodes.
    indices : List[str]
        Index names handed to ``retriever_tool``.

    Returns
    -------
    StateGraph
        The graph builder; call ``.compile()`` before invoking it.
    """
    retriever = retriever_tool(indices=indices)
    tools = [retriever]

    graph = StateGraph(AgentState)

    node_table = {
        "reformulate": partial(reformulate_question_using_history, llm=llm),
        "search_agent": partial(search_agent, llm=llm, tools=tools),
        "retrieve": ToolNode([retriever]),
        "generate_with_context": partial(generate_with_context, llm=llm),
        "has_org_name": has_org_name,
        "insert_org_link": insert_org_link,
    }
    for node_name, action in node_table.items():
        graph.add_node(node_name, action)

    graph.add_edge(START, "reformulate")
    graph.add_edge("reformulate", "search_agent")

    # tools_condition yields "tools" when the model asked for a tool call,
    # otherwise "__end__" (re-mapped here to continue at has_org_name).
    # TODO just a conditional edge here?
    graph.add_conditional_edges(
        source="search_agent",
        path=tools_condition,
        path_map={"tools": "retrieve", "__end__": "has_org_name"},
    )

    graph.add_edge("retrieve", "generate_with_context")
    graph.add_edge("generate_with_context", "has_org_name")

    # has_org_name reports its routing decision under the "next" key.
    graph.add_conditional_edges(
        "has_org_name",
        lambda update: update["next"],
        {"insert_org_link": "insert_org_link", END: END},
    )
    graph.add_edge("insert_org_link", END)

    return graph
def run_chat(
    thread_id: str,
    user_input: Dict[str, Any],
    chatbot: List[Dict],
    llm: LLM,
    indices: Optional[List[str]] = None,
):
    """Run one chat turn through the RAG graph and append the results to ``chatbot``.

    Parameters
    ----------
    thread_id : str
        Raw session identifier; normalized with ``get_session_id``.
    user_input : Dict[str, Any]
        Multimodal textbox value; only its ``"text"`` entry is used.
    chatbot : List[Dict]
        Chat history as role/content dicts. Mutated in place: the user turn, the
        answer and (when present) the sources HTML are appended.
    llm : LLM
        Model used by the graph's nodes.
    indices : Optional[List[str]]
        Index names forwarded to the retriever tool.

    Returns
    -------
    tuple
        A cleared, interactive ``gr.MultimodalTextbox``, the updated ``chatbot``
        list and the normalized thread id.
    """
    # https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_agentic_rag/#graph
    chatbot.append({"role": "user", "content": user_input["text"]})
    graph_input = {"messages": chatbot}

    # thread_id can be an email https://github.com/yurisasc/memory-enhanced-ai-assistant/blob/main/assistant.py
    thread_id = get_session_id(thread_id)
    run_config = {"configurable": {"thread_id": thread_id}}

    # TODO: MemorySaver is in-memory only; don't use for Prod
    compiled = build_compute_graph(llm=llm, indices=indices).compile(
        checkpointer=MemorySaver()
    )
    final_state = compiled.invoke(graph_input, config=run_config)

    history = final_state["messages"]
    answer = history[-1].content
    # A sources message, when produced, sits right before the final answer;
    # among the last two messages, keep the latest one of type "HTML".
    sources_html = next(
        (msg.content for msg in reversed(history[-2:]) if msg.type == "HTML"),
        "",
    )

    chatbot.append({"role": "assistant", "content": answer})
    if sources_html:
        chatbot.append(
            {
                "role": "assistant",
                "content": sources_html,
                "metadata": {"title": "Sources HTML"},
            }
        )
    return gr.MultimodalTextbox(value=None, interactive=True), chatbot, thread_id