ask-candid / chat_v2.py
brainsqueeze's picture
shorten page
aed4c76 verified
from typing import TypedDict, Any
from collections.abc import Iterator, AsyncIterator
import os
import gradio as gr
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import create_react_agent
from langchain_aws import ChatBedrock
import boto3
from ask_candid.tools.org_search import OrganizationIdentifier, find_mentioned_organizations
from ask_candid.tools.search import search_candid_knowledge_base
from ask_candid.tools.general import get_current_day
from ask_candid.utils import html_format_docs_chat
from ask_candid.base.config.constants import START_SYSTEM_PROMPT
from ask_candid.base.config.models import Name2Endpoint
from ask_candid.chat import convert_history_for_graph_agent, format_tool_call, format_tool_response
# Resolve where bundled assets live. Two layouts are handled:
#   * ``feedback`` importable as a top-level module -> paths are relative to the
#     current working directory;
#   * otherwise the packaged ``demos.feedback`` is used and ROOT is two
#     directories above this file.
try:
    from feedback import FeedbackApi
    ROOT = "."
except ImportError:
    from demos.feedback import FeedbackApi
    ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..")

# Chat avatar for the bot: prefer ROOT/static, else fall back to a static/
# directory two levels above ROOT (no existence check on the fallback).
_logo_here = os.path.join(ROOT, "static", "candid_logo_yellow.png")
BOT_LOGO = (
    _logo_here
    if os.path.isfile(_logo_here)
    else os.path.join(ROOT, "..", "..", "static", "candid_logo_yellow.png")
)
del _logo_here
class _LoggedContext(TypedDict):
    """Required part of ``LoggedComponents``."""

    # The chat component whose transcript accompanies a feedback submission.
    context: gr.Component


class LoggedComponents(_LoggedContext, total=False):
    """Gradio components whose values are submitted together as user feedback.

    The mapping is created with only ``context``; the feedback widgets below are
    added afterwards, so they are declared optional (``total=False``).
    """

    found_helpful: gr.Component  # "Did you find what you were looking for?" input
    will_recommend: gr.Component  # "Will you recommend ..." input
    comments: gr.Component  # free-text comments input
    email: gr.Component  # optional contact email input
def build_execution_graph(
    model_name: str = "claude-3.5-haiku",
    region_name: str = "us-east-1",
) -> CompiledStateGraph:
    """Build the ReAct agent graph used to answer one chat turn.

    A single Bedrock chat model drives the agent and is also handed to the
    ``OrganizationIdentifier`` tool. The agent is given four tools:
    ``get_current_day``, the organization identifier,
    ``find_mentioned_organizations`` and ``search_candid_knowledge_base``.

    Args:
        model_name: Key into ``Name2Endpoint`` selecting the Bedrock model. The
            default keeps the previously hard-coded model.
        region_name: AWS region for the ``bedrock-runtime`` client. The default
            keeps the previously hard-coded region.

    Returns:
        A freshly compiled agent graph. Nothing is cached: every call creates a
        new boto3 client, chat model and graph.

    Raises:
        KeyError: if ``model_name`` is not present in ``Name2Endpoint``.
    """
    llm = ChatBedrock(
        client=boto3.client("bedrock-runtime", region_name=region_name),
        model=Name2Endpoint[model_name],
    )
    org_name_recognition = OrganizationIdentifier(llm=llm)  # bind the main chat model to the tool
    return create_react_agent(
        model=llm,
        tools=[
            get_current_day,
            org_name_recognition,
            find_mentioned_organizations,
            search_candid_knowledge_base,
        ],
    )
def generate_postscript_messages(history: list[gr.ChatMessage]) -> Iterator[gr.ChatMessage]:
    """Yield follow-up messages derived from tool records in ``history``.

    For every record whose metadata ``"tool_name"`` equals the name of
    ``search_candid_knowledge_base``, an assistant message titled
    "Source citations" is yielded, rendering the record's ``"documents"``
    metadata with ``html_format_docs_chat``. Other records produce nothing.

    Args:
        history: Messages to scan. Entries may be ``gr.ChatMessage`` objects or
            plain message dicts; missing or ``None`` metadata is treated as empty.

    Yields:
        Citation messages, in the order their source records appear.
    """
    for record in history:
        # Tolerate both message objects and dict-form messages, and records
        # without metadata, instead of failing with AttributeError.
        if isinstance(record, dict):
            metadata = record.get("metadata")
        else:
            metadata = getattr(record, "metadata", None)
        metadata = metadata or {}

        if metadata.get("tool_name") == search_candid_knowledge_base.name:
            yield gr.ChatMessage(
                role="assistant",
                content=html_format_docs_chat(metadata.get("documents")),
                metadata={"title": "Source citations"},
            )
        # NOTE: results of ``find_mentioned_organizations`` currently get no postscript.
def _fresh_textbox() -> gr.MultimodalTextbox:
    """Return an empty, still-editable message box to send with every UI update."""
    return gr.MultimodalTextbox(value=None, interactive=True)


def _extract_text(content: Any) -> str:
    """Collect the plain text carried by a streamed message's ``content``.

    ``content`` may be a string, or a list whose items are strings or
    content-block dicts. Bare strings and string ``"text"`` values of dict
    blocks are concatenated; any other block (e.g. one without a ``"text"``
    key) contributes nothing.
    """
    if isinstance(content, str):
        return content
    pieces: list[str] = []
    for block in content or []:
        if isinstance(block, str):
            pieces.append(block)
        elif isinstance(block, dict) and isinstance(block.get("text"), str):
            pieces.append(block["text"])
    return "".join(pieces)


async def execute(
    user_input: dict[str, Any],
    history: list[gr.ChatMessage]
) -> AsyncIterator[tuple[gr.Component, list[gr.ChatMessage]]]:
    """Run one chat turn through the agent graph, streaming UI updates.

    Args:
        user_input: Value of the multimodal textbox: a ``"text"`` entry plus an
            optional ``"files"`` list of file paths. Only ``.txt`` files (any
            case) are read; each becomes an additional user message.
        history: The chatbot transcript, mutated in place. On the first turn the
            system prompt is prepended. Then the user message(s), streamed
            assistant text, tool-call/tool-result messages and any postscript
            messages are appended.

    Yields:
        ``(textbox, history)`` pairs: a cleared textbox plus the updated transcript.
    """
    if not history:
        history.append(gr.ChatMessage(role="system", content=START_SYSTEM_PROMPT))
    history.append(gr.ChatMessage(role="user", content=user_input["text"]))
    for fname in user_input.get("files") or []:
        # Case-insensitive extension check; undecodable bytes are replaced rather
        # than aborting the whole turn with UnicodeDecodeError.
        if fname.lower().endswith(".txt"):
            with open(fname, "r", encoding="utf8", errors="replace") as f:
                history.append(gr.ChatMessage(role="user", content=f.read()))
    yield _fresh_textbox(), history

    # Messages from this index onward are produced by the current turn; only
    # those are scanned for postscript messages at the end.
    horizon = len(history)
    inputs = {"messages": convert_history_for_graph_agent(history)}
    graph = build_execution_graph()

    # Placeholder assistant message that streamed model text is appended to.
    history.append(gr.ChatMessage(role="assistant", content=""))
    async for stream_mode, chunk in graph.astream(inputs, stream_mode=["messages", "tasks"]):
        if stream_mode == "messages":
            message, metadata = chunk
            # Skip anything emitted while the "tools" node runs (e.g. tool results,
            # or tokens from a model invoked inside a tool - OrganizationIdentifier
            # is given the chat model). At that point history[-1] is a tool-call
            # message, and tool activity is rendered from the "tasks" stream below.
            if metadata.get("langgraph_node") == "tools":
                continue
            text = _extract_text(message.content)
            if text:
                history[-1].content += text
                yield _fresh_textbox(), history
        elif stream_mode == "tasks" and chunk.get("name") == "tools" and chunk.get("error") is None:
            if "input" in chunk:
                for msg in format_tool_call(chunk):
                    history.append(msg)
                    yield _fresh_textbox(), history
            elif "result" in chunk:
                for msg in format_tool_response(chunk):
                    history.append(msg)
                    yield _fresh_textbox(), history
                # New placeholder for the model's text after the tool results.
                history.append(gr.ChatMessage(role="assistant", content=""))

    for post_msg in generate_postscript_messages(history=history[horizon:]):
        history.append(post_msg)
    yield _fresh_textbox(), history
def send_feedback(
    chat_context,
    found_helpful,
    will_recommend,
    comments,
    email
):
    """Submit one feedback record through ``FeedbackApi`` and notify the user.

    Args:
        chat_context: Value of the chat component (the conversation transcript).
        found_helpful: Answer to "Did you find what you were looking for?".
        will_recommend: Answer to "Will you recommend this Chatbot to others?".
        comments: Free-text comments (may be empty).
        email: Contact email (may be empty).

    Returns:
        The ``"response"`` field of the API reply, or 0 when it is absent
        (presumably a submission count - confirm against ``FeedbackApi``).

    Raises:
        gr.Error: if creating the client or submitting the feedback fails. The
            underlying exception is chained as ``__cause__`` for debugging.
    """
    try:
        # Client construction is guarded too, so configuration problems reach
        # the user as a Gradio error instead of an unhandled exception.
        api = FeedbackApi()
        response = api(
            context=chat_context,
            found_helpful=found_helpful,
            will_recommend=will_recommend,
            comments=comments,
            email=email,
        )
        total_submissions = response.get("response", 0)
    except Exception as ex:  # UI boundary: surface any failure to the user
        raise gr.Error(f"Error submitting feedback: {ex}") from ex

    gr.Info("Thank you for submitting feedback")
    return total_submissions
def build_chat_app():
    """Build the chat tab.

    Returns:
        ``(demo, logged)``: the ``gr.Blocks`` holding the chat UI, and a
        ``LoggedComponents`` mapping whose ``context`` is the chatbot component,
        so the feedback tab can submit the transcript.
    """
    intro_html = """
<h1>Candid's AI assistant</h1>
<p>
Please read the <a
href='https://info.candid.org/chatbot-reference-guide'
target="_blank"
rel="noopener noreferrer"
>guide</a> to get started.
</p>
<hr>
"""
    with gr.Blocks(theme=gr.themes.Soft(), title="Chat") as demo:
        gr.Markdown(intro_html)

        with gr.Column():
            chatbot = gr.Chatbot(
                type="messages",
                layout="panel",
                label="AskCandid",
                show_label=False,
                elem_id="chatbot",
                height="50vh",
                editable="user",
                autoscroll=True,
                show_copy_button=True,
                avatar_images=(None, BOT_LOGO),  # (user avatar, bot avatar)
            )
            msg = gr.MultimodalTextbox(label="Your message", interactive=True)
            gr.ClearButton(components=[msg, chatbot], size="sm")

        # pylint: disable=no-member
        msg.submit(
            fn=execute,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
            show_api=False,
        )

    return demo, LoggedComponents(context=chatbot)
def build_feedback(components: LoggedComponents) -> gr.Blocks:
    """Build the feedback tab and wire its submit button to ``send_feedback``.

    Side effect: the new widgets are stored in ``components`` under the keys
    ``found_helpful``, ``will_recommend``, ``comments`` and ``email``. Together
    with the existing ``context`` entry, they form the handler's inputs. The
    handler is registered with ``preprocess=False``, so it receives the
    components' values without Gradio's preprocessing step.
    """
    with gr.Blocks(theme=gr.themes.Soft(), title="Candid AI demo") as demo:
        gr.Markdown("<h1>Help us improve this tool with your valuable feedback</h1>")
        with gr.Row(), gr.Column():
            components.update(
                found_helpful=gr.Radio(
                    [True, False], label="Did you find what you were looking for?"
                ),
                will_recommend=gr.Radio(
                    [True, False],
                    label="Will you recommend this Chatbot to others?",
                ),
                comments=gr.Textbox(label="Additional comments (optional)", lines=4),
                email=gr.Textbox(label="Your email (optional)", lines=1),
            )
            submit = gr.Button("Submit Feedback")

        # Order must match send_feedback's positional parameters.
        feedback_inputs = [
            components[key]
            for key in ("context", "found_helpful", "will_recommend", "comments", "email")
        ]
        # pylint: disable=no-member
        submit.click(
            fn=send_feedback,
            inputs=feedback_inputs,
            outputs=None,
            show_api=False,
            api_name=False,
            preprocess=False,
        )
    return demo
def build_app():
    """Assemble the chat and feedback tabs into one tabbed Gradio app.

    The chat tab's logged components are handed to the feedback tab, so a
    feedback submission includes the chat transcript. The stylesheet
    ``static/chatStyle.css`` is looked up under ``ROOT`` first, then in a
    ``static/`` directory two levels above ``ROOT`` (the same fallback used for
    the bot logo).

    Returns:
        The ``gr.TabbedInterface`` for the whole application.

    Raises:
        FileNotFoundError: if the stylesheet is found in neither location.
    """
    candid_chat, logged = build_chat_app()
    feedback = build_feedback(logged)

    css_path = os.path.join(ROOT, "static", "chatStyle.css")
    if not os.path.isfile(css_path):
        css_path = os.path.join(ROOT, "..", "..", "static", "chatStyle.css")
    with open(css_path, "r", encoding="utf8") as f:
        css_chat = f.read()

    return gr.TabbedInterface(
        interface_list=[candid_chat, feedback],
        tab_names=["Candid's AI assistant", "Feedback"],
        title="Candid's AI assistant",
        theme=gr.themes.Soft(),
        css=css_chat,
    )
if __name__ == "__main__":
    # Keep only complete (username, password) pairs. Unset environment variables
    # would otherwise add entries with None as the username or password; such
    # entries can never be matched and may break password comparison.
    auth_pairs = [
        (username, password)
        for username, password in (
            (os.getenv("APP_USERNAME"), os.getenv("APP_PASSWORD")),
            (os.getenv("APP_PUBLIC_USERNAME"), os.getenv("APP_PUBLIC_PASSWORD")),
        )
        if username and password
    ]
    # Refuse to start without credentials rather than launching an app nobody can
    # log into (or, if an empty auth list were passed, possibly one with no login).
    if not auth_pairs:
        raise RuntimeError(
            "No login credentials configured: set APP_USERNAME/APP_PASSWORD "
            "and/or APP_PUBLIC_USERNAME/APP_PUBLIC_PASSWORD."
        )

    app = build_app()
    app.queue(max_size=5).launch(
        show_api=False,
        mcp_server=False,
        auth=auth_pairs,
        ssr_mode=False,
        auth_message="Login to Candid's AI assistant",
    )