Spaces:
Sleeping
Sleeping
import json | |
import logging | |
import os | |
from threading import Thread | |
from flask import Flask, Response, abort, jsonify, request | |
import google.generativeai as genai | |
import gradio as gr | |
import requests | |
import config | |
# get app logger
app_logger = logging.getLogger(__name__)

# Configure the Gemini client. Reading the key with os.environ[...] means a
# missing GEMINI_API_KEY fails with KeyError as soon as this module runs.
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

# Default generation settings attached to the model below: replies are capped
# at 512 output tokens and sampled with temperature 0.9.
generation_config = dict(max_output_tokens=512, temperature=0.9)
model = genai.GenerativeModel(
    model_name="gemini-1.5-flash",
    generation_config=generation_config,
)
def _strip_code_fence(text: str) -> str:
    """Return *text* trimmed, minus a Markdown code fence wrapping all of it.

    Only a fence that opens at the very start and closes at the very end is
    removed (together with an optional ``json`` language tag). Backticks
    inside the payload -- e.g. a code sample in an answer's content -- are
    preserved, unlike a blanket ``replace('```', '')``.
    """
    stripped = text.strip()
    if len(stripped) >= 6 and stripped.startswith("```") and stripped.endswith("```"):
        return stripped[3:-3].removeprefix("json").strip()
    return stripped


def _report_error(messages: list, error: str) -> None:
    """Append an ``{"error": ...}`` system message so the model can correct itself."""
    # json.dumps guarantees the payload is valid JSON whatever the text contains.
    message = json.dumps({"error": error})
    app_logger.info(message)
    messages.append({"role": "system", "content": message})


# handle the chatbot's messages
def chatbot_response(user_input: str, history: list[dict[str, str]], *, max_turns: int = 10) -> str:
    """Answer *user_input*, letting the model call tools until it gives an answer.

    The conversation (system prompt, *history*, the new user message, and
    every later model reply / tool result) is sent to the model as the
    ``str()`` of one list. Each reply must be a JSON object (optionally
    wrapped in a ```json fence) with a ``"type"`` and a ``"content"``:

    * ``"tool"``   -- ``content`` is passed to :func:`handle_tools`, which
      appends the tool results to the conversation; the model is asked again.
    * ``"answer"`` -- ``content`` is returned.

    Any other reply is reported back to the model as an ``{"error": ...}``
    system message and the model is asked again.

    Args:
        user_input: The new user message.
        history: Earlier turns, inserted between the system prompt and the
            new message. NOTE(review): annotated as role/content dicts, but
            the format gr.ChatInterface passes depends on its ``type``
            setting -- confirm.
        max_turns: Maximum number of model calls for this message. This
            bounds the loop (and API spend) if the model never answers.

    Returns:
        The ``content`` of the model's ``"answer"`` reply.

    Raises:
        RuntimeError: If no answer is produced within *max_turns* model calls.
    """
    messages = [
        {"role": "system", "content": config.system_prompt},
        *history,
        {"role": "user", "content": user_input},
    ]
    for _ in range(max_turns):
        # NOTE: the prompt is the Python repr of the message list, not JSON.
        completion = model.generate_content(str(messages))
        reply = _strip_code_fence(completion.text)
        app_logger.info(reply)
        # Record the model's own turn with a role, like every other entry.
        messages.append({"role": "assistant", "content": reply})
        try:
            response = json.loads(reply)
        except json.JSONDecodeError as e:
            app_logger.error(e)
            _report_error(
                messages,
                "Your previous message was invalid JSON and caused an error during parsing. "
                "(Hint: you may have hit the token limit, try separating your messages into multiple messages)",
            )
            continue
        kind = response.get("type") if isinstance(response, dict) else None
        if kind not in ("tool", "answer") or "content" not in response:
            _report_error(
                messages,
                'Your previous message must be a JSON object with a "type" of "tool" or "answer" and a "content" field.',
            )
            continue
        if kind == "tool":
            handle_tools(response["content"], messages)
        else:
            return response["content"]
    raise RuntimeError(f"The model did not produce an answer within {max_turns} turns")
# Tools the model is allowed to call; any other name is rejected without a request.
_TOOL_NAMES = frozenset({"datetime", "calculator", "websearch"})


# make the tool calls
def handle_tools(
    tools: list[dict[str, str | dict]],
    messages: list[dict[str, object]],
    *,
    timeout: float = 30,
) -> None:
    """Run each tool call requested by the model and record its result.

    Each tool is a dict with a ``"name"`` (one of ``_TOOL_NAMES``) and
    optional ``"data"``. A tool with data is called with POST and the data as
    the JSON body. A tool without data is called with GET. Either request goes
    to ``{config.tools_external_url}/{name}``. The parsed JSON response, or an
    ``{"error": ...}`` JSON string, is appended to *messages* as a
    ``"system"`` message. Failures are reported to the model in this way
    rather than raised.

    Args:
        tools: Tool calls requested by the model. A single tool object (not
            wrapped in a list) is treated as a one-item list.
        messages: Conversation to append results to (mutated in place).
        timeout: Seconds to wait for each tool HTTP request.
    """
    if not isinstance(tools, list):
        # The model may send a lone tool object (or junk) instead of a list.
        tools = [tools]
    for tool in tools:
        # Read the name defensively so a malformed entry cannot raise here.
        name = tool.get("name") if isinstance(tool, dict) else None
        if not isinstance(name, str) or name not in _TOOL_NAMES:
            response = json.dumps({"error": f"Tool {name} is an invalid tool"})
        else:
            url = f"{config.tools_external_url}/{name}"
            data = tool.get("data")
            app_logger.info(f"Making request to {url}{f' with {data}' if data else ''}")
            try:
                if data:
                    r = requests.post(url, json=data, timeout=timeout)
                else:
                    r = requests.get(url, timeout=timeout)
                r.raise_for_status()
                response = r.json()
            except (requests.RequestException, ValueError) as e:
                # Network/HTTP errors and non-JSON bodies are reported to the model.
                app_logger.error(e)
                response = json.dumps(
                    {
                        "error": f"An error with making the API call to the tool {name} has occurred, "
                        "please inform the user of this"
                    }
                )
        app_logger.info(f"Tool response: {response}")
        messages.append({"role": "system", "content": response})
# start the api server in a thread (with its own logger)
api = Flask(__name__)
formatter = logging.Formatter("%(asctime)s - API - %(levelname)s - %(message)s")
api_logger = api.logger
# Give every handler attached to the Flask logger the API format. Looping
# (instead of indexing handlers[0]) avoids an IndexError when the logger has
# no handlers of its own; records then fall through to ancestor handlers.
for handler in api_logger.handlers:
    handler.setFormatter(formatter)
# NOTE(review): no @api.route(...) decorator registers this view, so the Flask
# server started in this file exposes no endpoint for it -- confirm the
# intended URL rule and register it.
def query() -> Response:
    """Answer a chat request whose body is JSON.

    The body must be a JSON object with ``"user_input"`` and ``"history"``,
    which are passed to :func:`chatbot_response`. The response is
    ``{"result": <answer>}``.

    Aborts with:
        400: the body is not valid JSON, is not an object, or lacks a field.
        500: producing the answer failed.
    """
    try:
        data = json.loads(request.data)
        user_input = data["user_input"]
        history = data["history"]
    except KeyError as e:
        abort(400, description="Missing value in request body: " + str(e))
    except (ValueError, TypeError) as e:
        # ValueError covers malformed JSON; TypeError a body that is not an object.
        abort(400, description="Error: " + str(e))
    # Kept outside the parsing `try` so a KeyError raised while answering is
    # not misreported as a missing request field.
    try:
        result = chatbot_response(user_input, history)
    except Exception as e:
        api_logger.exception("Failed to answer query")
        abort(500, description="Error: " + str(e))
    return jsonify({"result": result})
# Serve the Flask API on a background thread so the main thread is free for
# the Gradio UI launched afterwards. NOTE(review): the thread is not marked
# daemon, so it keeps the process alive after the Gradio server returns.
f = Thread(target=api.run, kwargs=dict(host=config.api_host, port=config.api_port))
f.start()
# start the gradio interface: a chat UI whose replies come from chatbot_response.
# NOTE(review): chatbot_response annotates history as role/content dicts;
# confirm which history format ChatInterface passes in the installed Gradio.
demo = gr.ChatInterface(
    fn=chatbot_response,
    title="✨ Gemini Tooluse Prototype 🔨",
    description="A prototype of a Gemini 1.5 Flash chatbot with Tooluse using a demo API implemented in Flask. It can use a calculator to perform basic arithmetic, get the current date and time, and search the web.",
)
demo.launch(
    server_name=config.gradio_host,
    server_port=config.gradio_port,
    show_api=False,
)