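"""Gradio entry point for the Ask Candid retrieval-augmented chat assistant."""
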
from typing import List, Tuple, Dict, TypedDict, Optional, Any
import os

import boto3
import gradio as gr
from langchain_core.language_models.llms import LLM
from langchain_openai.chat_models import ChatOpenAI
from langchain_aws import ChatBedrock

from ask_candid.base.config.rest import OPENAI
from ask_candid.base.config.models import Name2Endpoint
from ask_candid.base.config.data import ALL_INDICES
from ask_candid.utils import format_chat_ag_response
from ask_candid.chat import run_chat

ROOT = os.path.dirname(os.path.abspath(__file__))
BUCKET = "candid-data-science-reporting"
PREFIX = "Assistant"


class LoggedComponents(TypedDict):
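    """Components whose values are logged for analysis; only `context` is
    populated in this module, and the remaining feedback fields are not set
    in this file.
    """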
    context: List[gr.components.Component]
    found_helpful: gr.components.Component
    will_recommend: gr.components.Component
    comments: gr.components.Component
    email: gr.components.Component


def select_foundation_model(model_name: str, max_new_tokens: int) -> LLM:
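    """Construct the LangChain chat model for `model_name`: `gpt-4o` is served
    through OpenAI, while the Claude/Llama/Mistral options are served through
    Amazon Bedrock.
    """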
    if model_name == "gpt-4o":
        llm = ChatOpenAI(
            model_name=Name2Endpoint[model_name],
            max_tokens=max_new_tokens,
            api_key=OPENAI["key"],
            temperature=0.0,
            streaming=True,
        )
    elif model_name in {"claude-3.5-haiku", "llama-3.1-70b-instruct", "mistral-large", "mixtral-8x7B"}:
        llm = ChatBedrock(
            client=boto3.client("bedrock-runtime"),
            model=Name2Endpoint[model_name],
            max_tokens=max_new_tokens,
            temperature=0.0,
        )
    else:
        raise gr.Error(f"Base model `{model_name}` is not supported")
    return llm


def execute(
    thread_id: str,
    user_input: Dict[str, Any],
    history: List[Dict],
    model_name: str,
    max_new_tokens: int,
    indices: Optional[List[str]] = None,
):
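    """Execute a single chat turn via `run_chat`, building the selected LLM on
    the fly; the return value feeds the `[msg, chatbot, thread_id]` outputs
    wired up in `build_rag_chat`.
    """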
    return run_chat(
        thread_id=thread_id,
        user_input=user_input,
        history=history,
        llm=select_foundation_model(model_name=model_name, max_new_tokens=max_new_tokens),
        indices=indices,
    )


def build_rag_chat() -> Tuple[LoggedComponents, gr.Blocks]:
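    """Assemble the chat UI: advanced settings (source indices, model choice,
    token budget), the chatbot pane, and the submit/response event wiring.
    Returns the components worth logging together with the Blocks app.
    """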
    with gr.Blocks(theme=gr.themes.Soft(), title="Chat") as demo:
        gr.Markdown(
            """
            <h1>Ask Candid</h1>
            <p>
                Please read the <a
                    href='https://info.candid.org/chatbot-reference-guide'
                    target="_blank"
                    rel="noopener noreferrer"
                >guide</a> to get started.
            </p>
            <hr>
            """
        )
        with gr.Accordion(label="Advanced settings", open=False):
            es_indices = gr.CheckboxGroup(
                choices=list(ALL_INDICES),
                value=list(ALL_INDICES),
                label="Sources to include",
                interactive=True,
            )
            llmname = gr.Radio(
                label="Language model",
                value="gpt-4o",
                choices=list(Name2Endpoint.keys()),
                interactive=True,
            )
            max_new_tokens = gr.Slider(
                value=256 * 3,
                minimum=128,
                maximum=2048,
                step=128,
                label="Max new tokens",
                interactive=True,
            )
        with gr.Column():
            chatbot = gr.Chatbot(
                label="AskCandid",
                elem_id="chatbot",
                bubble_full_width=True,
                avatar_images=(
                    None,
                    os.path.join(ROOT, "static", "candid_logo_yellow.png"),
                ),
                height="45vh",
                type="messages",
                show_label=False,
                show_copy_button=True,
                show_share_button=True,
                show_copy_all_button=True,
            )
            msg = gr.MultimodalTextbox(label="Your message", interactive=True)
            thread_id = gr.Text(visible=False, value="", label="thread_id")
            gr.ClearButton(components=[msg, chatbot, thread_id], size="sm")

        # Submitting a message runs the chat turn, then a follow-up step
        # formats the assistant's answer before it is rendered.
        # pylint: disable=no-member
        chat_msg = msg.submit(
            fn=execute,
            inputs=[thread_id, msg, chatbot, llmname, max_new_tokens, es_indices],
            outputs=[msg, chatbot, thread_id],
        )
        chat_msg.then(format_chat_ag_response, chatbot, chatbot, api_name="bot_response")

    # Only the conversation context is captured here; the feedback fields of
    # LoggedComponents are not set in this module.
    logged = LoggedComponents(context=[thread_id, chatbot])
    return logged, demo


def build_app():
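    """Wrap the chat interface in a single-tab `gr.TabbedInterface` styled with
    the bundled CSS.
    """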
    _, candid_chat = build_rag_chat()

    with open(os.path.join(ROOT, "static", "chatStyle.css"), "r", encoding="utf8") as f:
        css_chat = f.read()

    demo = gr.TabbedInterface(
        interface_list=[candid_chat],
        tab_names=["AskCandid"],
        theme=gr.themes.Soft(),
        css=css_chat,
    )
    return demo


if __name__ == "__main__":
    app = build_app()
    app.queue(max_size=5).launch(show_api=False)