Spaces:
Running
Running
from typing import List, Tuple, Dict, TypedDict, Optional, Any | |
import os | |
import gradio as gr | |
from langchain_openai.chat_models import ChatOpenAI | |
try: | |
from utils import format_chat_ag_response | |
from retrieval.config import ALL_INDICES | |
from chat import run_chat | |
except ImportError: | |
from .utils import format_chat_ag_response | |
from .retrieval.config import ALL_INDICES | |
from .chat import run_chat | |
# Absolute path of the directory that contains this module; bundled static
# assets (e.g. the chatbot avatar image) are resolved relative to it.
ROOT = os.path.split(os.path.abspath(__file__))[0]
class _RequiredLoggedComponents(TypedDict):
    """Keys of ``LoggedComponents`` that must always be present."""

    # Components describing the conversation; build_chat supplies
    # [thread_id, chatbot].
    context: List[gr.components.Component]


class LoggedComponents(_RequiredLoggedComponents, total=False):
    """Gradio components that ``build_chat`` exposes to its caller for logging.

    ``context`` is required. The remaining keys are optional (``total=False``)
    because ``build_chat`` currently fills in only ``context``; with a total
    TypedDict, that construction is reported by type checkers as missing keys.
    NOTE(review): the optional keys look like feedback-form widgets — confirm
    against whatever consumes this mapping.
    """

    found_helpful: gr.components.Component
    will_recommend: gr.components.Component
    comments: gr.components.Component
    email: gr.components.Component
def execute(
    thread_id: str,
    user_input: Dict[str, Any],
    chatbot: List[Dict],
    max_new_tokens: int,
    indices: Optional[List[str]] = None,
):
    """Handle one chat turn using a freshly configured OpenAI chat model.

    Every call builds a new ``ChatOpenAI`` client (model ``gpt-4o``,
    temperature 0.0, streaming on, API key read from the ``OPENAI_API_KEY``
    environment variable) and hands it, along with the conversation state,
    to ``run_chat``.

    Args:
        thread_id: Identifier of the conversation thread (the hidden
            ``thread_id`` textbox in ``build_chat`` starts as "").
        user_input: Value of the multimodal message box; presumably a dict
            with "text"/"files" entries — confirm against ``run_chat``.
        chatbot: Current chat history in Gradio "messages" format.
        max_new_tokens: Completion token cap passed as ``max_tokens``.
        indices: Retrieval sources selected in the UI, or ``None``.

    Returns:
        Whatever ``run_chat`` returns; the ``msg.submit`` wiring in
        ``build_chat`` maps it onto ``(msg, chatbot, thread_id)``.
    """
    client_settings = {
        "model_name": "gpt-4o",
        "max_tokens": max_new_tokens,
        "api_key": os.getenv("OPENAI_API_KEY"),
        "temperature": 0.0,
        "streaming": True,
    }
    chat_model = ChatOpenAI(**client_settings)

    return run_chat(
        thread_id=thread_id,
        user_input=user_input,
        chatbot=chatbot,
        llm=chat_model,
        indices=indices,
    )
def build_chat() -> Tuple[LoggedComponents, gr.Blocks]:
    """Assemble the "Ask Candid" chat interface (without launching it).

    Layout:
      * an "Advanced settings" accordion with a checkbox group of retrieval
        sources (all of ``ALL_INDICES`` selected by default) and a
        max-new-tokens slider (128-2048, default 768);
      * a "messages"-type chatbot, a multimodal input box, a hidden
        thread-id textbox and a clear button for all three.

    Submitting a message calls ``execute`` with the thread id, message,
    history, token cap and selected sources. ``execute``'s result is written
    back to the message box, chatbot and thread id. The chatbot value is then
    passed through ``format_chat_ag_response`` (API name "bot_response").

    Returns:
        ``(logged, demo)``: ``logged`` carries ``context=[thread_id, chatbot]``
        and ``demo`` is the un-launched ``gr.Blocks`` app.
    """
    with gr.Blocks(theme=gr.themes.Soft(), title="Ask Candid") as demo:
        with gr.Accordion(label="Advanced settings", open=False):
            es_indices = gr.CheckboxGroup(
                choices=list(ALL_INDICES),
                value=list(ALL_INDICES),
                label="Sources to include",
                interactive=True,
            )
            max_new_tokens = gr.Slider(
                value=256 * 3, minimum=128, maximum=2048, step=128,
                label="Max new tokens", interactive=True,
            )

        with gr.Column():
            # `bubble_full_width` is intentionally not passed: it is deprecated
            # in Gradio 5 and no longer has any effect there.
            chatbot = gr.Chatbot(
                label="Candid Assistant",
                elem_id="chatbot",
                avatar_images=(
                    None,
                    os.path.join(ROOT, "static", "candid_logo_yellow.png"),
                ),
                height="45vh",
                type="messages",
                show_label=False,
                show_copy_button=True,
                show_share_button=True,
                show_copy_all_button=True,
            )
            msg = gr.MultimodalTextbox(label="Your message", interactive=True)
            # Hidden field that carries the conversation thread id between turns.
            thread_id = gr.Text(visible=False, value="", label="thread_id")
            gr.ClearButton(components=[msg, chatbot, thread_id], size="sm")

        # Event listeners must be registered inside the Blocks context.
        # pylint: disable=no-member
        chat_msg = msg.submit(
            fn=execute,
            inputs=[thread_id, msg, chatbot, max_new_tokens, es_indices],
            outputs=[msg, chatbot, thread_id],
        )
        chat_msg.then(format_chat_ag_response, chatbot, chatbot, api_name="bot_response")

    logged = LoggedComponents(
        context=[thread_id, chatbot]
    )
    return logged, demo
def load_auth_credentials() -> List[Tuple[str, str]]:
    """Collect the (username, password) login pairs configured in the environment.

    Reads APP_USERNAME/APP_PASSWORD and APP_PUBLIC_USERNAME/APP_PUBLIC_PASSWORD.
    A pair is kept only when both values are non-empty: ``os.getenv`` returns
    ``None`` for unset variables, and a pair with a missing or empty value is
    not a usable credential.

    Returns:
        Usable credential pairs in the order above; empty if none are set.
    """
    env_pairs = (
        ("APP_USERNAME", "APP_PASSWORD"),
        ("APP_PUBLIC_USERNAME", "APP_PUBLIC_PASSWORD"),
    )
    credentials = []
    for user_var, password_var in env_pairs:
        username = os.getenv(user_var)
        password = os.getenv(password_var)
        if username and password:
            credentials.append((username, password))
    return credentials


if __name__ == '__main__':
    auth = load_auth_credentials()
    if not auth:
        # Fail closed on misconfiguration instead of launching an app nobody can
        # log into. NOTE(review): an empty `auth` list is falsy, which Gradio may
        # treat as "no authentication" — so it must never be passed through.
        raise RuntimeError(
            "No login credentials configured: set APP_USERNAME/APP_PASSWORD "
            "and/or APP_PUBLIC_USERNAME/APP_PUBLIC_PASSWORD."
        )

    _, app = build_chat()
    app.queue(max_size=5).launch(
        show_api=False,
        auth=auth,
        auth_message="Login to Candid's AI assistant",
        ssr_mode=False,
    )