import json
import logging
import os
from threading import Thread
from flask import Flask, Response, abort, jsonify, request
import google.generativeai as genai
import gradio as gr
import requests
import config
# get app logger
app_logger = logging.getLogger(__name__)
# configure the model (expects the GEMINI_API_KEY environment variable to be set)
genai.configure(api_key=os.environ["GEMINI_API_KEY"])
generation_config = {
"max_output_tokens": 512,
"temperature": 0.9
}
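# the 512-token cap and 0.9 temperature above apply to every generate_content call made with this model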
model = genai.GenerativeModel(
    model_name="gemini-1.5-flash",
    generation_config=generation_config,
)
# handle the chatbot's messages
def chatbot_response(user_input: str, history: list[dict[str, str]]) -> str:
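    """Run one chat turn: send the accumulated messages to Gemini, parse its JSON
    reply, dispatch any requested tool calls, and loop until it returns an answer."""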
messages = [{"role": "system", "content": config.system_prompt}] + history + [{"role": "user", "content": user_input}]
while True:
chat_completion = model.generate_content(str(messages))
chat_completion = chat_completion.text.replace('```json', '').replace('```', '').strip()
app_logger.info(chat_completion)
        # keep the model's reply in the running context as an assistant message
        messages.append({"role": "assistant", "content": chat_completion})
try:
response = json.loads(chat_completion)
except json.decoder.JSONDecodeError as e:
app_logger.error(e)
            message = '{"error": "Your previous message was invalid JSON and caused an error during parsing. (Hint: you may have hit the token limit; try splitting your response across multiple messages.)"}'
app_logger.info(message)
messages.append({"role": "system", "content": message})
continue
        if response["type"] == "tool":
            handle_tools(response["content"], messages)
        elif response["type"] == "answer":
            return response["content"]
        else:
            # unknown type: tell the model so it can correct itself instead of looping silently
            messages.append({"role": "system", "content": '{"error": "Invalid response type, expected tool or answer"}'})
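# Sketch of the JSON protocol chatbot_response expects the model to follow. The
# authoritative wording lives in config.system_prompt; the field values below are
# illustrative assumptions only:
#   {"type": "answer", "content": "<final reply to show the user>"}
#   {"type": "tool", "content": [{"name": "calculator", "data": {...}}, {"name": "datetime"}]}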
# make the tool calls
def handle_tools(tools: list[dict[str, str | dict]], messages: list[dict[str, str]]) -> None:
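    """Call each requested tool on the external tools API and append every tool
    response (or error) to the message list as a system message."""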
for tool in tools:
try:
if tool["name"] not in {"datetime", "calculator", "websearch"}:
response = f'{{"error": "Tool {tool["name"]} is an invalid tool"}}'
else:
url = f"{config.tools_external_url}/{tool['name']}"
data = tool.get("data")
app_logger.info(f"Making request to {url}{f' with {data}' if data else ''}")
r = requests.post(url, json=data) if data else requests.get(url)
r.raise_for_status()
response = r.json()
except Exception as e:
app_logger.error(e)
response = f'{{"error": "An error with making the API call to the tool {tool["name"]} has occurred, please inform the user of this"}}'
app_logger.info(f"Tool response: {response}")
messages.append({"role": "system", "content": response})
# start the api server in a thread (with its own logger)
api = Flask(__name__)
formatter = logging.Formatter('%(asctime)s - API - %(levelname)s - %(message)s')
api_logger = api.logger
for handler in api_logger.handlers:
    handler.setFormatter(formatter)
@api.route("/", methods=["POST"])
def query() -> Response:
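    """Handle POST /: expects a JSON body with "user_input" and "history" keys and
    returns {"result": <chatbot answer>}."""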
try:
data = json.loads(request.data)
return jsonify({"result": chatbot_response(data["user_input"], data["history"])})
except KeyError as e:
abort(400, description="Missing value in request body: " + str(e))
except Exception as e:
abort(400, description="Error: " + str(e))
api_thread = Thread(
    target=api.run,
    kwargs={
        "host": config.api_host,
        "port": config.api_port,
    },
    daemon=True,  # let the process exit when the Gradio app shuts down
)
api_thread.start()
# start the gradio interface
demo = gr.ChatInterface(fn=chatbot_response,
                        type="messages",  # assumes a Gradio version with the messages format; matches chatbot_response's history type hint
                        title="✨ Gemini Tooluse Prototype 🔨",
                        description="A prototype of a Gemini 1.5 Flash chatbot with tool use via a demo API implemented in Flask. It can use a calculator to perform basic arithmetic, get the current date and time, and search the web.")
demo.launch(server_name=config.gradio_host, server_port=config.gradio_port, show_api=False)