xarical committed on
Commit
f2a243f
·
0 Parent(s):

Initial commit

Browse files
.github/README.md ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Gemini-Tooluse-Prototype
2
+ A prototype of a Gemini 1.5 Flash chatbot with Tooluse using a demo API implemented in Flask. It can use a calculator to perform basic arithmetic, get the current date and time, and search the web (using Selenium). Built in Python using Gradio and Flask
3
+
4
+ https://huggingface.co/spaces/xarical/Gemini-Tooluse-Prototype
.github/workflows/checkfilesizes.yml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Check file sizes
2
+ on: # or directly `on: [push]` to run the action on every push on any branch
3
+ pull_request:
4
+ branches: [main]
5
+
6
+ # to run this workflow manually from the Actions tab
7
+ workflow_dispatch:
8
+
9
+ jobs:
10
+ check-file-size:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - name: Check file sizes
14
+ uses: ActionsDesk/lfs-warning@v2.0
15
+ with:
16
+ filesizelimit: 10485760 # this is 10MB so we can sync to HF Spaces
.github/workflows/deploytospace.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Deploy to HF Space
2
+ on:
3
+ push:
4
+ branches: [main]
5
+
6
+ # to run this workflow manually from the Actions tab
7
+ workflow_dispatch:
8
+
9
+ jobs:
10
+ deploy-to-space:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v3
14
+ with:
15
+ fetch-depth: 0
16
+ lfs: true
17
+ - name: Push to HF space
18
+ env:
19
+ HF_API_KEY: ${{ secrets.HF_API_KEY }}
20
+ SPACE_ID: ${{ secrets.SPACE_ID }}
21
+ run: git push https://xarical:$HF_API_KEY@huggingface.co/spaces/$SPACE_ID main
Dockerfile ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
2
+ # you will also find guides on how best to write your Dockerfile
3
+ FROM python:3
4
+
5
+ # Add new user 'user' (non-root)
6
+ RUN useradd -m -u 1000 user
7
+
8
+ # Set working dir to app
9
+ ENV HOME=/home/user \
10
+ PATH=/home/user/.local/bin:$PATH
11
+ RUN mkdir $HOME/app
12
+ WORKDIR $HOME/app
13
+
14
+ # Switch to root
15
+ USER root
16
+
17
+ # Install nginx and packages.txt
18
+ COPY --chown=root packages.txt packages.txt
19
+ RUN apt-get -y update && apt-get -y install nginx && xargs apt-get -y install < packages.txt
20
+
21
+ # Give app permissions to 'user' (non-root)
22
+ RUN chown user:user .
23
+
24
+ # Give nginx permissions to 'user' (non-root)
25
+ # See https://www.rockyourcode.com/run-docker-nginx-as-non-root-user/
26
+ RUN mkdir -p /var/cache/nginx \
27
+ /var/log/nginx \
28
+ /var/lib/nginx
29
+ RUN touch /var/run/nginx.pid
30
+ RUN chown -R user:user /var/cache/nginx \
31
+ /var/log/nginx \
32
+ /var/lib/nginx \
33
+ /var/run/nginx.pid
34
+
35
+ # Switch to 'user' (non-root)
36
+ USER user
37
+
38
+ # Install requirements.txt
39
+ COPY --chown=user requirements.txt requirements.txt
40
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
41
+
42
+ # Copy nginx configuration
43
+ COPY --chown=user nginx.conf /etc/nginx/sites-available/default
44
+
45
+ # Copy app
46
+ COPY --chown=user . .
47
+
48
+ # Run
49
+ CMD ["bash", "run.sh"]
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Gemini Tool Use
3
+ emoji: 🛠️
4
+ colorFrom: green
5
+ colorTo: indigo
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: true
9
+ short_description: A Gemini 1.5 Flash chatbot with Tool Use
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ from threading import Thread
5
+
6
+ from flask import Flask, Response, abort, jsonify, request
7
+ import google.generativeai as genai
8
+ import gradio as gr
9
+ import requests
10
+
11
+ import config
12
+
13
# Module-level logger used by the chatbot functions in this file.
app_logger = logging.getLogger(__name__)

# Authenticate the Gemini client with the key read from the environment,
# then build the model together with its generation settings.
genai.configure(api_key=os.environ["GEMINI_API_KEY"])
generation_config = dict(max_output_tokens=512, temperature=0.9)
model = genai.GenerativeModel(
    model_name="gemini-1.5-flash",
    generation_config=generation_config,
)
24
+
25
# Upper bound on model round-trips for a single user message, so a model that
# keeps producing unusable output cannot loop (and call the API) forever.
_MAX_MODEL_TURNS = 10


def _strip_code_fence(text: str) -> str:
    """Remove a Markdown code fence that wraps the whole reply, if present.

    Only the leading and trailing fence markers are removed, so backticks
    that appear inside the JSON payload (e.g. code in an answer) survive.
    """
    text = text.strip()
    text = text.removeprefix("```json").removeprefix("```")
    return text.removesuffix("```").strip()


def _system_error(message: str) -> dict[str, str]:
    """Build a system message whose content is a valid JSON error object."""
    return {"role": "system", "content": json.dumps({"error": message})}


# handle the chatbot's messages
def chatbot_response(user_input: str, history: list[dict[str, str]]) -> str:
    """Run the tool-use loop for one user message and return the final answer.

    The conversation (system prompt, prior history, new user message) is sent
    to the model as text. Each reply is expected to be a JSON object with a
    "type" of either "tool" (run the listed tools and ask again) or "answer"
    (return its "content"). Unusable replies are reported back to the model
    as system error messages so it can correct itself.

    Args:
        user_input: The new message from the user.
        history: Previous messages of the conversation.

    Returns:
        The "content" of the model's answer, or a fallback apology if no
        usable answer was produced within ``_MAX_MODEL_TURNS`` model calls.
    """
    messages = [
        {"role": "system", "content": config.system_prompt},
        *history,
        {"role": "user", "content": user_input},
    ]
    for _ in range(_MAX_MODEL_TURNS):
        completion = model.generate_content(str(messages))
        reply = _strip_code_fence(completion.text)
        app_logger.info(reply)
        # The reply is already a JSON object carrying its own "role", so it is
        # stored as-is rather than wrapped in another message dict.
        messages.append(reply)

        try:
            response = json.loads(reply)
        except json.JSONDecodeError as e:
            app_logger.error(e)
            error = _system_error(
                "Your previous message was invalid JSON and caused an error "
                "during parsing. (Hint: you may have hit the token limit, try "
                "separating your messages into multiple messages)"
            )
            app_logger.info(error["content"])
            messages.append(error)
            continue

        # Valid JSON is not enough: the reply must be an object with the two
        # fields the loop dispatches on.
        if (
            not isinstance(response, dict)
            or "type" not in response
            or "content" not in response
        ):
            error = _system_error(
                'Your previous message must be a JSON object with "type" and '
                '"content" fields.'
            )
            app_logger.info(error["content"])
            messages.append(error)
            continue

        if response["type"] == "tool":
            handle_tools(response["content"], messages)
        elif response["type"] == "answer":
            return response["content"]
        else:
            error = _system_error(
                'Your previous message had an invalid "type"; it must be '
                '"tool" or "answer".'
            )
            app_logger.info(error["content"])
            messages.append(error)

    app_logger.error("No usable answer after %d model turns", _MAX_MODEL_TURNS)
    return "Sorry, I wasn't able to produce a valid answer. Please try again."
45
+
46
# Names of the tools the model is allowed to call; each one is an endpoint
# under config.tools_external_url.
_VALID_TOOLS = frozenset({"datetime", "calculator", "websearch"})

# Seconds to wait for a tool endpoint, so a stuck tool cannot hang the chat.
_TOOL_TIMEOUT_SECONDS = 60


# make the tool calls
def handle_tools(tools: list[dict[str, str | dict]], messages: list) -> None:
    """Call each requested tool and append its result to ``messages``.

    For every entry in ``tools`` one system message is appended: the decoded
    JSON response of the tool on success, or a JSON error string when the
    tool is unknown, the entry is malformed, or the HTTP call fails.

    Args:
        tools: Tool calls requested by the model, each a dict with a "name"
            and optional "data". A single dict is accepted as a one-item list.
        messages: The running conversation; mutated in place.
    """
    if isinstance(tools, dict):
        # The model sent one tool call without wrapping it in a list.
        tools = [tools]
    elif not isinstance(tools, list):
        response = json.dumps({"error": "Tool calls must be given as a list of tools"})
        app_logger.info(f"Tool response: {response}")
        messages.append({"role": "system", "content": response})
        return

    for tool in tools:
        # Read the name defensively: the error paths below must never raise
        # themselves because of a malformed tool entry.
        name = tool.get("name") if isinstance(tool, dict) else None
        if not isinstance(name, str) or name not in _VALID_TOOLS:
            response = json.dumps({"error": f"Tool {name} is an invalid tool"})
        else:
            try:
                url = f"{config.tools_external_url}/{name}"
                data = tool.get("data")
                app_logger.info(f"Making request to {url}{f' with {data}' if data else ''}")
                # Tools that take input are POSTed to; the others are fetched with GET.
                if data:
                    r = requests.post(url, json=data, timeout=_TOOL_TIMEOUT_SECONDS)
                else:
                    r = requests.get(url, timeout=_TOOL_TIMEOUT_SECONDS)
                r.raise_for_status()
                response = r.json()
            except Exception:
                # Best effort: report the failure to the model instead of
                # aborting the whole conversation.
                app_logger.exception("Tool call to %s failed", name)
                response = json.dumps({
                    "error": (
                        f"An error with making the API call to the tool {name} "
                        "has occurred, please inform the user of this"
                    )
                })
        app_logger.info(f"Tool response: {response}")
        messages.append({"role": "system", "content": response})
64
+
65
# start the api server in a thread (with its own logger)
api = Flask(__name__)
formatter = logging.Formatter("%(asctime)s - API - %(levelname)s - %(message)s")
api_logger = api.logger
# setFormatter() returns None, so keep the handler itself in ``handler``.
# The handler list can be empty (Flask only adds its default handler when no
# other handler would emit the records), hence the guard instead of [0].
handler = api_logger.handlers[0] if api_logger.handlers else None
if handler is not None:
    handler.setFormatter(formatter)
70
+
71
@api.route("/", methods=["POST"])
def query() -> Response:
    """Answer a chat request sent as JSON.

    The request body must be a JSON object with "user_input" and "history".
    Responds with ``{"result": <answer>}``.

    Aborts with 400 when the body is not valid JSON or lacks a required
    field, and with 500 when producing the answer fails.
    """
    # Validate the request on its own, so that a KeyError raised while
    # generating the answer is not misreported as a missing request field.
    try:
        data = json.loads(request.data)
        user_input, history = data["user_input"], data["history"]
    except KeyError as e:
        abort(400, description="Missing value in request body: " + str(e))
    except (TypeError, ValueError) as e:
        # ValueError covers malformed JSON; TypeError a body that is not an object.
        abort(400, description="Error: " + str(e))

    try:
        result = chatbot_response(user_input, history)
    except Exception as e:
        api.logger.exception("Failed to generate a chatbot response")
        abort(500, description="Error: " + str(e))
    return jsonify({"result": result})
80
+
81
# Serve the Flask API from a background thread. It is a daemon thread so it
# does not keep the process alive after the main thread has finished.
f = Thread(
    target=api.run,
    kwargs={
        "host": config.api_host,
        "port": config.api_port,
    },
    daemon=True,
)
f.start()

# start the gradio interface
demo = gr.ChatInterface(
    fn=chatbot_response,
    title="✨ Gemini Tooluse Prototype 🔨",
    description="A prototype of a Gemini 1.5 Flash chatbot with Tooluse using a demo API implemented in Flask. It can use a calculator to perform basic arithmetic, get the current date and time, and search the web.",
)

demo.launch(server_name=config.gradio_host, server_port=config.gradio_port, show_api=False)
config.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# Address the Gradio chat interface listens on.
gradio_host = "localhost"
gradio_port = 8080

# Address the Flask chat API listens on.
api_host = "localhost"
api_port = 4000

# Address the Flask tools API listens on.
tools_host = "localhost"
tools_port = 3000

# Base URL the chatbot uses to reach the tools API. When SPACE_HOST is set
# the tools are addressed through the public host under /tools; otherwise
# fall back to the local tools server instead of failing on import.
# This should be changed based on where you're hosting.
_space_host = os.environ.get("SPACE_HOST")
tools_external_url = (
    f"https://{_space_host}/tools"
    if _space_host
    else f"http://{tools_host}:{tools_port}"
)
12
+
13
# System prompt sent to the model at the start of every conversation.
# The example conversation and the tool specifications are kept as valid JSON
# objects, because the model is told to answer with valid JSON modelled on them.
system_prompt = """**Prompt:**
You are a helpful assistant chatbot. You are talking to the user through a chat interface.
Use the tools detailed in the Tool Specifications section below to augment your built-in ability to perform tasks. Use different tools together to answer questions. If the tool call returns unhelpful results or an error or if an applicable tool is not listed, answer the user's question to the best of your knowledge. Separate lists of more than 2 tools into separate messages and send those messages one at a time.
Your response must be a valid JSON object similar to the 'assistant' messages in the following example conversation:
**Example:**
{
    "role": "user",
    "content": "Hi!"
}
{
    "role": "assistant",
    "type": "answer",
    "content": "Hello! 👋 How can I assist you today?"
}
{
    "role": "user",
    "content": "What is the time in India?"
}
{
    "role": "assistant",
    "type": "tool",
    "content": [
        {"name": "datetime", "data": {}},
        {"name": "websearch", "data": {"query": "Time difference between India and UTC"}}
    ]
}
{
    "role": "system",
    "content": {"date": "2024-08-13", "time": "9:47:13"}
}
{
    "role": "system",
    "content": "Time Difference between UTC and IST Indian. Indian Standard Time is 5 hours 30 minutes ahead from the UTC universal time. UTC to IST Indian Time Conversion"
}
{
    "role": "assistant",
    "type": "tool",
    "content": [
        {"name": "calculator", "data": {"operator": "add", "num1": 9, "num2": 5}},
        {"name": "calculator", "data": {"operator": "add", "num1": 47, "num2": 30}}
    ]
}
{
    "role": "system",
    "content": "14"
}
{
    "role": "system",
    "content": "77"
}
{
    "role": "assistant",
    "type": "answer",
    "content": "It's currently 15:17 in IST (Indian Standard Time). This is equivalent to 3:17 PM in IST. Is there anything else I can assist you with? 😊"
}
**Tool Specifications:**
{
    "name": "datetime",
    "description": "A tool to get the current date and time in form YYYY-MM-DD and HH:MM:SS, in UTC with a 24-hour clock. It accepts nothing as the data.",
    "data": {}
}
{
    "name": "calculator",
    "description": "A simple calculator tool that can do the four arithmetic operations. It accepts an operator and two integers.",
    "data": {"operator": "string, add/subtract/multiply/divide", "num1": "any int", "num2": "any int"}
}
{
    "name": "websearch",
    "description": "A search tool that can do a limited search of the web for answers. Results are not guaranteed to be accurate. IT DOES NOT WORK TO SEARCH ANYTHING TIME-SENSITIVE. It accepts a search query.",
    "data": {"query": "any string"}
}
"""
nginx.conf ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ server {
2
+ listen 7860 default_server;
3
+ listen [::]:7860 default_server;
4
+
5
+ server_name _;
6
+
7
+ location / {
8
+ # Serve Gradio from port 8080
9
+ proxy_pass http://localhost:8080;
10
+ proxy_http_version 1.1;
11
+ proxy_set_header Upgrade $http_upgrade;
12
+ proxy_set_header Connection 'upgrade';
13
+ proxy_set_header Host $host;
14
+ proxy_set_header X-Real-IP $remote_addr;
15
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
16
+ proxy_cache_bypass $http_upgrade;
17
+ proxy_read_timeout 86400;
18
+ proxy_redirect off;
19
+ }
20
+
21
+ location /query {
22
+ # Serve Flask app API server from port 4000
23
+ rewrite ^/query/?(.*)$ /$1 break; # strip the /query/
24
+ proxy_pass http://localhost:4000;
25
+ proxy_http_version 1.1;
26
+ proxy_set_header Upgrade $http_upgrade;
27
+ proxy_set_header Connection 'upgrade';
28
+ proxy_set_header Host $host;
29
+ proxy_set_header X-Real-IP $remote_addr;
30
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
31
+ proxy_cache_bypass $http_upgrade;
32
+ proxy_read_timeout 86400;
33
+ proxy_redirect off;
34
+ }
35
+
36
+ location /tools {
37
+ # Serve Flask tools API server from port 3000
38
+ rewrite ^/tools/?(.*)$ /$1 break; # strip the /tools/
39
+ proxy_pass http://localhost:3000;
40
+ proxy_http_version 1.1;
41
+ proxy_set_header Upgrade $http_upgrade;
42
+ proxy_set_header Connection 'upgrade';
43
+ proxy_set_header Host $host;
44
+ proxy_set_header X-Real-IP $remote_addr;
45
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
46
+ proxy_cache_bypass $http_upgrade;
47
+ proxy_read_timeout 86400;
48
+ proxy_redirect off;
49
+ }
50
+ }
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ chromium-driver
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ flask
2
+ google-generativeai
3
+ gradio
4
+ requests
5
+ selenium
run.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # start nginx
4
+ service nginx start
5
+
6
+ # start the processes
7
+ python tools.py > /dev/stdout 2>&1 & echo $! > tools.pid
8
+ python app.py > /dev/stdout 2>&1 # blocking
9
+
10
+ # when unblocked, kill other processes and clean up
11
+ pkill -F tools.pid
12
+ rm tools.pid
tools.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import time
4
+ from datetime import datetime, timezone
5
+
6
+ from flask import Flask, Response, abort, json, jsonify, request
7
+ from selenium import webdriver
8
+ from selenium.webdriver.chrome.options import Options
9
+ from selenium.webdriver.common.by import By
10
+ from selenium.webdriver.common.keys import Keys
11
+
12
+ import config
13
+
14
# selenium options
chrome_options = Options()
chrome_options.add_argument('--no-sandbox')
chrome_options.add_argument('--disable-dev-shm-usage')
chrome_options.add_argument('--disable-gpu')
chrome_options.add_argument('--headless')

# Flask app serving the tools, with a TOOLS-specific log format.
tools = Flask(__name__)
formatter = logging.Formatter("%(asctime)s - TOOLS - %(levelname)s - %(message)s")
tools_logger = tools.logger
# setFormatter() returns None, so keep the handler itself in ``handler``.
# The handler list can be empty (Flask only adds its default handler when no
# other handler would emit the records), hence the guard instead of [0].
handler = tools_logger.handlers[0] if tools_logger.handlers else None
if handler is not None:
    handler.setFormatter(formatter)
25
+
26
@tools.route("/datetime", methods=["GET"])
def get_datetime() -> Response:
    """Return the current UTC date (YYYY-MM-DD) and time (HH:MM:SS) as JSON."""
    # Take a single snapshot so the date and the time describe the same
    # instant, even when the request arrives right at midnight.
    now = datetime.now(timezone.utc)
    return jsonify({
        "date": now.strftime("%Y-%m-%d"),
        "time": now.strftime("%H:%M:%S"),
    })
31
+
32
# Supported operations, keyed by the "operator" field of the request body.
_OPERATIONS = {
    "add": lambda a, b: a + b,
    "subtract": lambda a, b: a - b,
    "multiply": lambda a, b: a * b,
    "divide": lambda a, b: a / b,
}


def _is_number(value: object) -> bool:
    """Return True for ints and floats; booleans are not treated as numbers."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


@tools.route("/calculator", methods=["POST"])
def calculate() -> Response:
    """Apply one arithmetic operation to two numbers from the JSON body.

    The body must be a JSON object with "operator" (add, subtract, multiply
    or divide), "num1" and "num2". Responds with ``{"result": <value>}``.

    Aborts with 400 on malformed JSON, a missing field, an unknown operator,
    non-numeric operands, or an arithmetic error such as division by zero.
    """
    try:
        data = json.loads(request.data)
        op_name, num1, num2 = data["operator"], data["num1"], data["num2"]
    except KeyError as e:
        abort(400, description="Missing value in request body: " + str(e))
    except (TypeError, ValueError) as e:
        # ValueError covers malformed JSON; TypeError a body that is not an object.
        abort(400, description="Error: " + str(e))

    # The abort() calls below sit outside any try block on purpose: abort()
    # raises an exception that a broad handler would catch and re-wrap.
    operation = _OPERATIONS.get(op_name) if isinstance(op_name, str) else None
    if operation is None:
        abort(400, description=f"Invalid operator: {op_name}")
    if not (_is_number(num1) and _is_number(num2)):
        # Without this check "add" would happily concatenate two strings.
        abort(400, description="Error: num1 and num2 must be numbers")

    try:
        result = operation(num1, num2)
    except ArithmeticError as e:
        abort(400, description="Error: " + str(e))
    return jsonify(result=result)
50
+
51
@tools.route("/websearch", methods=["POST"])
def google_search() -> Response:
    """Search Google for the "query" in the JSON body and return one snippet.

    Responds with ``{"result": <text>}``, where the text is the quick-answer
    box if the page has one, otherwise the first sentence of the first result
    that has a snippet, otherwise "No search results found".

    Aborts with 400 when the body is malformed or has no string "query", and
    with 500 when driving the browser fails.
    """
    # Validate the request before launching a browser for it.
    try:
        query = json.loads(request.data)["query"]
    except KeyError as e:
        abort(400, description="Missing value in request body: " + str(e))
    except (TypeError, ValueError) as e:
        abort(400, description="Error: " + str(e))
    if not isinstance(query, str):
        abort(400, description="Error: query must be a string")

    # A local (not global) driver: each request owns its own browser, and
    # the cleanup below can tell whether one was actually started.
    driver = None
    try:
        driver = webdriver.Chrome(options=chrome_options)
        driver.get("https://www.google.com/")
        search_bar = driver.find_element(By.NAME, "q")
        search_bar.send_keys(query)
        search_bar.send_keys(Keys.RETURN)
        time.sleep(1)  # give the results page a moment to load

        # check if google quick answer box exists
        quick_answer = driver.find_elements(By.CSS_SELECTOR, "div.kno-rdesc span span")
        if quick_answer:
            return jsonify(result=quick_answer[0].text)

        # otherwise, find list of search results
        for result in driver.find_elements(By.CSS_SELECTOR, "div.g"):
            snippet_spans = result.find_elements(By.CSS_SELECTOR, "div.VwiC3b span")
            if snippet_spans:
                # first sentence of last span element in result
                return jsonify(result=snippet_spans[-1].text.split("...")[0])
            # if no span element in result, go to next result
        return jsonify(result="No search results found")
    except Exception as e:
        tools.logger.exception("Web search failed")
        abort(500, description="Error: " + str(e))
    finally:
        if driver is not None:
            try:
                # quit selenium and kill the vnc
                driver.quit()
                os.system("pkill -1 Xvnc")
            except Exception:
                # Cleanup is best effort, but leave a trace instead of
                # swallowing the failure silently.
                tools.logger.exception("Failed to shut down the browser")
86
+
87
@tools.route("/", methods=["GET"])
def home() -> str:
    """Answer requests to the root URL with a fixed greeting."""
    greeting = "hello"
    return greeting
90
+
91
# Start the server only when this file is executed as a script
# (``python tools.py``), so that importing the module has no side effects.
if __name__ == "__main__":
    tools.run(host=config.tools_host, port=config.tools_port)