import json
import os
import re

import gradio as gr
from dotenv import load_dotenv
from gradio import ChatMessage
from openai import OpenAI

from tools import tools, oitools
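# `tools` and `oitools` come from the local tools.py module, which is not shown
# here. A minimal sketch of the interface this file relies on (the "book_room"
# name and its schema below are illustrative assumptions, not the real module):
# `oitools` is a list of OpenAI function-calling schemas, e.g.
#     {"type": "function",
#      "function": {"name": "book_room",
#                   "description": "Book a room at the hotel.",
#                   "parameters": {"type": "object",
#                                  "properties": {"date": {"type": "string"}},
#                                  "required": ["date"]}}}
# and `tools` maps each function name to an object exposing .invoke(input=...),
# e.g. a LangChain-style tool: tools["book_room"].invoke(input={"date": "2025-01-01"})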
load_dotenv(".env")
HF_TOKEN = os.environ.get("HF_TOKEN")
BASE_URL = os.environ.get("BASE_URL")
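# Expected .env contents (values are placeholders, not real credentials):
#     HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
#     BASE_URL=https://your-inference-endpoint.example.com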
SYSTEM_PROMPT_TEMPLATE = """You are an AI assistant designed to assist users with a hotel booking and information system. Your role is to provide detailed and accurate information about the hotel, including available accommodations, facilities, dining options, and reservation services. You can assist with bookings, modify or cancel reservations, and answer general inquiries about the hotel.
Maintain clarity, conciseness, and relevance in your responses, ensuring a seamless user experience.
Always respond in the same **language as the user’s query** to preserve their preferred language."""
client = OpenAI(
    base_url=f"{BASE_URL}/v1",
    api_key=HF_TOKEN,
)
def clean_json_string(json_str):
    """Repair a possibly truncated JSON object string.

    Strips any trailing whitespace, commas, and closing braces the model may
    have emitted, then re-appends a single closing brace. Note this assumes a
    flat (non-nested) object: trailing braces of nested objects would be
    collapsed into one.
    """
    return re.sub(r'[ ,}\s]+$', '', json_str) + '}'
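# Example of the repair on a flat, truncated arguments payload:
#     >>> clean_json_string('{"date": "2025-01-01", ')
#     '{"date": "2025-01-01"}'
# Already-valid flat JSON passes through unchanged:
#     >>> clean_json_string('{"date": "2025-01-01"}')
#     '{"date": "2025-01-01"}'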
def completion(history, model, system_prompt, tools=None):
    messages = [{"role": "system", "content": system_prompt}]
    for msg in history:
        if isinstance(msg, dict):
            msg = ChatMessage(**msg)
        if msg.role == "assistant" and hasattr(msg, "metadata") and msg.metadata:
            # Tool-call turns store the serialized call list in the metadata
            # "title"; replay them as an assistant tool_calls message followed
            # by the tool result.
            tool_calls = json.loads(msg.metadata.get("title", "[]"))
            messages.append({"role": "assistant", "tool_calls": tool_calls, "content": ""})
            messages.append({"role": "tool", "content": msg.content})
        else:
            messages.append({"role": msg.role, "content": msg.content})
    # Debug: dump the exact conversation sent to the model.
    for msg in messages:
        print(msg)
    print()
    request_params = {
        "model": model,
        "messages": messages,
        "stream": False,
        "max_tokens": 1000,
        "temperature": 0.2,
        "frequency_penalty": 1,
        "extra_body": {"repetition_penalty": 1.1},
    }
    if tools:
        request_params.update({"tool_choice": "auto", "tools": tools})
    return client.chat.completions.create(**request_params)
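# Hedged usage sketch (assumes the OpenAI-compatible server behind BASE_URL is
# up and serving a model; "my-model" is a placeholder id):
#     resp = completion(history=[ChatMessage(role="user", content="Any rooms free tonight?")],
#                       model="my-model", system_prompt=SYSTEM_PROMPT_TEMPLATE, tools=oitools)
#     print(resp.choices[0].message)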
def llm_in_loop(history, system_prompt, recursive):
    # `recursive` is a negative offset: history[recursive:] yields back to the
    # UI only the messages produced since the current user turn.
    try:
        models = client.models.list()
        model = models.data[0].id if models.data else "gpt-3.5-turbo"
    except Exception as err:
        gr.Warning("The model is initializing. Please wait; this may take 5 to 10 minutes ⏳.", duration=20)
        raise err

    name = ""
    arguments = ""
    chat_completion = completion(history=history, tools=oitools, model=model, system_prompt=system_prompt)
    appended = False
    choices = chat_completion.choices
    if choices and choices[0].message.tool_calls:
        call = choices[0].message.tool_calls[0]
        if hasattr(call.function, "name") and call.function.name:
            name = call.function.name
        if hasattr(call.function, "arguments") and call.function.arguments:
            arguments += call.function.arguments
    elif choices and choices[0].message.content:  # also guards the empty-choices case
        history.append(ChatMessage(role="assistant", content=choices[0].message.content))
        appended = True
        yield history[recursive:]
    # Repair and parse the (possibly truncated) tool-call arguments.
    arguments = clean_json_string(arguments) if arguments else "{}"
    print(name, arguments)  # debug
    arguments = json.loads(arguments)
    if appended:
        recursive -= 1
    if name:
        result = (
            f"💥 Error using tool {name}: tool doesn't exist"
            if name not in tools
            else str(tools[name].invoke(input=arguments))
        )
        result = json.dumps({name: result}, ensure_ascii=False)
        # Store the tool result as an assistant message; the serialized call is
        # kept in metadata["title"] so completion() can replay it next turn.
        history.append(
            ChatMessage(
                role="assistant",
                content=result,
                metadata={
                    "title": json.dumps(
                        [{"id": "call_id", "function": {"arguments": json.dumps(arguments), "name": name}, "type": "function"}]
                    )
                },
            )
        )
        yield history[recursive:]
        yield from llm_in_loop(history, system_prompt, recursive - 1)
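# The stored metadata round-trips through completion(): with the assumed
# "book_room" tool from the sketch above, metadata["title"] would hold
#     [{"id": "call_id",
#       "function": {"arguments": "{\"date\": \"2025-01-01\"}", "name": "book_room"},
#       "type": "function"}]
# which completion() re-parses into an assistant tool_calls turn followed by a
# "tool" turn carrying the result string.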
def respond(message, history, additional_inputs):
    history.append(ChatMessage(role="user", content=message))
    # Start at recursive=-1 so the UI receives only the newly generated messages.
    yield from llm_in_loop(history, additional_inputs, -1)
if __name__ == "__main__":
    system_prompt = gr.Textbox(label="System prompt", value=SYSTEM_PROMPT_TEMPLATE, lines=3)
    demo = gr.ChatInterface(respond, type="messages", additional_inputs=[system_prompt])
    demo.launch()