from typing import List, Optional, Dict, Any, TypedDict, Annotated, Sequence
from functools import partial
import logging
import os

import gradio as gr

from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models.llms import LLM
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.state import StateGraph
from langgraph.graph.message import add_messages
from langgraph.constants import START, END

from ask_candid.tools.org_seach import (
    extract_org_links_from_chatbot,
    embed_org_links_in_text,
    generate_org_link_dict,
)
from ask_candid.tools.question_reformulation import reformulate_question_using_history
from ask_candid.utils import html_format_docs_chat, get_session_id
from ask_candid.retrieval.elastic import retriever_tool

ROOT = os.path.dirname(os.path.abspath(__file__))
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# TODO https://www.metadocs.co/2024/08/29/simple-domain-specific-corrective-rag-with-langchain-and-langgraph/


class AgentState(TypedDict):
    # The add_messages reducer defines how updates to this key are applied.
    # The default is to replace; add_messages appends instead.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    user_input: str
    org_dict: Dict
    # Routing key written by `has_org_name`; it must be declared in the state
    # schema so LangGraph accepts it as a valid state update.
    next: str


def search_agent(state: AgentState, llm: LLM, tools: List) -> AgentState:
    """Invokes the agent model to generate a response based on the current state.
    Given the question, it will decide to retrieve using the retriever tool, or
    simply end.

    Parameters
    ----------
    state : AgentState
        The current state
    llm : LLM
    tools : List
        Tools the agent may bind and call (here, the Candid retriever tool)

    Returns
    -------
    AgentState
        The updated state with the agent response appended to messages
    """

    logger.info("---SEARCH AGENT---")
    messages = state["messages"]
    question = messages[-1].content

    model = llm.bind_tools(tools)
    response = model.invoke(messages)
    # Return a list, because this will get appended to the existing list
    return {"messages": [response], "user_input": question}


def generate_with_context(state: AgentState, llm: LLM) -> AgentState:
    """Generates an answer from the retrieved context.

    Parameters
    ----------
    state : AgentState
        The current state
    llm : LLM

    Returns
    -------
    AgentState
        The updated state with the agent response appended to messages
    """

    logger.info("---GENERATE ANSWER---")
    messages = state["messages"]
    question = state["user_input"]
    last_message = messages[-1]

    sources_str = last_message.content
    sources_list = last_message.artifact  # cannot be used directly as a list of Documents
    # Convert to an HTML string
    sources_html = html_format_docs_chat(sources_list)
    if sources_list:
        logger.info("---ADD SOURCES---")
        state["messages"].append(BaseMessage(content=sources_html, type="HTML"))

    # Prompt
    qa_system_prompt = """
    You are an assistant for question-answering tasks in the social and philanthropic sector.
    Use the following pieces of retrieved context to answer the question at the end.
    If you don't know the answer, just say that you don't know.
    Keep the response professional, friendly, and as concise as possible.
    Question: {question}
    Context: {context}
    Answer:
    """
    qa_prompt = ChatPromptTemplate([
        ("system", qa_system_prompt),
        # Use a template placeholder rather than the raw question text, so that
        # braces in user input cannot break template parsing.
        ("human", "{question}"),
    ])

    rag_chain = qa_prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"context": sources_str, "question": question})
    # Couldn't figure out why returning the usual `response` was seen as a HumanMessage
    return {"messages": [AIMessage(content=response)], "user_input": question}


def has_org_name(state: AgentState) -> AgentState:
    """Processes the latest message to extract organization links and determine
    the next step.

    Args:
        state (AgentState): The current state of the agent, including a list of messages.

    Returns:
        dict: A dictionary with the next agent action and, if available, a
        dictionary of organization links.
    """
    logger.info("---HAS ORG NAMES?---")
    messages = state["messages"]
    last_message = messages[-1].content
    output_list = extract_org_links_from_chatbot(last_message)
    link_dict = generate_org_link_dict(output_list) if output_list else {}
    if link_dict:
        logger.info("---FOUND ORG NAMES---")
        return {"next": "insert_org_link", "org_dict": link_dict}
    logger.info("---NO ORG NAMES FOUND---")
    return {"next": END, "messages": messages}


def insert_org_link(state: AgentState) -> AgentState:
    """Embeds organization links in the latest message content and returns it
    as an AI message.

    Args:
        state (AgentState): The current state, including the organization links
        and latest message.

    Returns:
        dict: A dictionary with the updated message content as an AIMessage.
    """
    logger.info("---INSERT ORG LINKS---")
    messages = state["messages"]
    last_message = messages[-1].content
    # Delete the original message because we will append the same one, but with links
    messages.pop(-1)
    link_dict = state["org_dict"]
    last_message = embed_org_links_in_text(last_message, link_dict)
    return {"messages": [AIMessage(content=last_message)]}


def build_compute_graph(llm: LLM, indices: List[str]) -> StateGraph:
    """Assembles the agentic RAG workflow as a LangGraph state graph."""
    candid_retriever_tool = retriever_tool(indices=indices)
    retrieve = ToolNode([candid_retriever_tool])
    tools = [candid_retriever_tool]

    G = StateGraph(AgentState)

    # Add nodes
    G.add_node("reformulate", partial(reformulate_question_using_history, llm=llm))
    G.add_node("search_agent", partial(search_agent, llm=llm, tools=tools))
    G.add_node("retrieve", retrieve)
    G.add_node("generate_with_context", partial(generate_with_context, llm=llm))
    G.add_node("has_org_name", has_org_name)
    G.add_node("insert_org_link", insert_org_link)

    # Add edges
    G.add_edge(START, "reformulate")
    G.add_edge("reformulate", "search_agent")
    # Conditional edges from search_agent: retrieve if the agent requested a
    # tool call, otherwise go straight to org-name detection
    G.add_conditional_edges(
        source="search_agent",
        path=tools_condition,
        path_map={
            "tools": "retrieve",
            END: "has_org_name",
        },
    )
    G.add_edge("retrieve", "generate_with_context")
    G.add_edge("generate_with_context", "has_org_name")
    # Use add_conditional_edges for has_org_name
    G.add_conditional_edges(
        "has_org_name",
        lambda x: x["next"],  # access the 'next' key set by has_org_name
        {"insert_org_link": "insert_org_link", END: END},
    )
    G.add_edge("insert_org_link", END)
    return G


def run_chat(
    thread_id: str,
    user_input: Dict[str, Any],
    history: List[Dict],
    llm: LLM,
    indices: Optional[List[str]] = None,
):
    # https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_agentic_rag/#graph
    if len(history) == 0:
        history.append({
            "role": "system",
            "content": (
                "You are a Candid subject matter expert on the social sector and philanthropy. "
                "You should address the user's queries and stay on topic."
            )
        })
    history.append({"role": "user", "content": user_input["text"]})
    inputs = {"messages": history}
    # thread_id can be an email
    # https://github.com/yurisasc/memory-enhanced-ai-assistant/blob/main/assistant.py
    thread_id = get_session_id(thread_id)
    config = {"configurable": {"thread_id": thread_id}}

    workflow = build_compute_graph(llm=llm, indices=indices)
    memory = MemorySaver()  # TODO: don't use for prod
    graph = workflow.compile(checkpointer=memory)
    response = graph.invoke(inputs, config=config)

    messages = response["messages"]
    last_message = messages[-1]
    ai_answer = last_message.content
    sources_html = ""
    for message in messages[-2:]:
        if message.type == "HTML":
            sources_html = message.content

    history.append({"role": "assistant", "content": ai_answer})
    if sources_html:
        history.append({
            "role": "assistant",
            "content": sources_html,
            "metadata": {"title": "Sources HTML"},
        })
    return gr.MultimodalTextbox(value=None, interactive=True), history, thread_id
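

# A minimal end-to-end smoke-test sketch, not part of the production app.
# Assumptions: a chat model that supports `.bind_tools()` (here
# `langchain_openai.ChatOpenAI`, which requires OPENAI_API_KEY to be set), and
# a hypothetical index name "news" -- substitute whichever model and indices
# your `retriever_tool` deployment actually uses.
if __name__ == "__main__":
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    history: List[Dict] = []
    _textbox, history, thread_id = run_chat(
        thread_id="local-test",
        user_input={"text": "What does Candid do for the philanthropic sector?"},
        history=history,
        llm=llm,
        indices=["news"],  # hypothetical index name
    )
    # Print each turn of the accumulated chat history, truncated for readability
    for turn in history:
        print(f"[{turn['role']}] {str(turn['content'])[:200]}")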