File size: 10,730 Bytes
63c1d33
 
 
 
 
 
e802c16
63c1d33
7abe8b9
e802c16
63c1d33
b5292e8
 
 
63c1d33
dcd8a1e
f79726a
dcd8a1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f79726a
28d3307
f79726a
dcd8a1e
f79726a
 
 
 
 
 
4d1bbea
f79726a
 
 
 
dcd8a1e
5e0f49f
 
dcd8a1e
 
 
 
 
 
5e0f49f
dcd8a1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06e7d9e
 
dcd8a1e
 
a14601b
dcd8a1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5ca814
dcd8a1e
 
 
 
 
 
 
 
 
 
 
 
 
eb80774
 
dcd8a1e
 
 
 
 
c24585f
 
 
 
 
06e7d9e
c24585f
d11e749
c24585f
 
dcd8a1e
598addc
 
 
 
 
 
dcd8a1e
9f90d1c
 
 
 
dcd8a1e
 
 
97a0bc4
dcd8a1e
7af607c
dcd8a1e
c24585f
 
 
 
 
 
 
dcd8a1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63c1d33
 
 
b5292e8
 
63c1d33
 
 
470c35a
598addc
903b87a
beef5d8
63c1d33
8cdc643
 
 
 
 
e34da9a
8cdc643
de46855
63c1d33
b5292e8
 
 
 
 
 
 
 
63c1d33
6cc5afa
 
 
 
 
 
 
 
 
 
 
 
 
 
63c1d33
b5292e8
63c1d33
7e1e3d2
6cc5afa
b5292e8
63c1d33
6cc5afa
 
 
 
 
 
 
63c1d33
ab8c50d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60f4e03
ab8c50d
 
 
 
 
 
 
 
021337d
ab8c50d
de46855
511cc6d
de46855
ab8c50d
 
b5292e8
63c1d33
b5292e8
beef5d8
b5292e8
 
beef5d8
63c1d33
b5292e8
beef5d8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
import streamlit as st
import os
import aiohttp
import asyncio
import discord
import pandas as pd
import requests
from teapotai import TeapotAI, TeapotAISettings
from pydantic import BaseModel, Field

# Configure the Streamlit page first, before the script emits any other output.
st.set_page_config(
    layout="wide",
    page_title="TeapotAI Discord Bot",
    page_icon=":robot_face:",
)

# Bot token from the "discord_key" environment variable (None when unset).
DISCORD_TOKEN = os.getenv("discord_key")


#  ======= API KEYS =======
# Credentials for the Brave search and OpenWeatherMap tools below; each is
# None when its environment variable is unset.
BRAVE_API_KEY = os.getenv("brave_api_key")
WEATHER_API_KEY = os.getenv("weather_api_key")

# ======== TOOLS ===========
import requests
from typing import Optional
from teapotai import TeapotTool
import re
import math
import pandas as pd
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, logging

### SEARCH TOOL
class BraveWebSearch(BaseModel):
    # Argument schema for the "websearch" tool: a single required query string.
    search_query: str = Field(default=..., description="the search string to answer the question")

def brave_search_context(query, count=3):
    """Run a Brave web search and format the top hits as plain-text context.

    Args:
        query: the search string, or an object with a ``search_query``
            attribute (the parsed ``BraveWebSearch`` arguments — the other
            tools in this file receive their schema objects the same way,
            so the websearch tool presumably does too).
        count: number of results to request from the API.

    Returns:
        One paragraph per result (title, URL and description on separate
        lines), paragraphs separated by a blank line, with Brave's
        ``<strong>`` highlight tags removed. Returns "" on a non-200 status
        or a network error (the error is printed).
    """
    # Direct callers pass a plain string; tool calls pass the schema object.
    query = getattr(query, "search_query", query)

    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": BRAVE_API_KEY}
    params = {"q": query, "count": count}

    try:
        # Without a timeout a stalled connection would block the caller forever.
        response = requests.get(url, headers=headers, params=params, timeout=10)
    except requests.RequestException as err:
        print(f"Error: {err}")
        return ""

    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return ""

    results = response.json().get("web", {}).get("results", [])
    return "\n\n".join(
        "\n".join((
            res.get("title", ""),
            res.get("url", ""),
            res.get("description", "").replace("<strong>", "").replace("</strong>", ""),
        ))
        for res in results
    )

### CALCULATOR TOOL
import builtins

def evaluate_expression(expr) -> str:
    """
    Evaluate a simple arithmetic expression string without calling ``eval``.

    Supported: int/float literals, + - * / // % ** (and unary + / -),
    parentheses, and calls to ``abs``, ``round``, ``sqrt`` and ``pow``
    (``pow`` is ``math.pow``). Anything else — names, attribute access,
    subscripts, strings, tuples, ... — is rejected. The expression comes from
    chat text via the model, so it must not be able to reach arbitrary
    Python objects the way ``eval`` could.

    Exact integer powers whose result would exceed ``max_int_bits`` bits are
    refused, so inputs such as ``9**9**9**9`` fail fast instead of hanging.

    If the first attempt fails, evaluation is retried after stripping every
    character that is not a digit, '.', an operator, a parenthesis or a space.

    Returns:
        "<expression> = <result>" on success (the cleaned expression is shown
        when the retry was needed), otherwise
        "Sorry, I am unable to calculate that."
    """
    import ast
    import operator

    allowed_names = {k: getattr(builtins, k) for k in ("abs", "round")}
    allowed_names.update({k: getattr(math, k) for k in ("sqrt", "pow")})
    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    max_int_bits = 10_000

    def checked_pow(base, exponent):
        # Exact integer powers can consume unbounded time and memory, so
        # estimate the result size before computing. Float powers overflow
        # quickly on their own.
        if (
            isinstance(base, int)
            and isinstance(exponent, int)
            and exponent > 0
            and abs(base) > 1
            and abs(base).bit_length() * exponent > max_int_bits
        ):
            raise OverflowError("result too large")
        return operator.pow(base, exponent)

    def evaluate_node(node):
        if isinstance(node, ast.Expression):
            return evaluate_node(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate_node(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = evaluate_node(node.left)
            right = evaluate_node(node.right)
            if isinstance(node.op, ast.Pow):
                return checked_pow(left, right)
            return binary_ops[type(node.op)](left, right)
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id in allowed_names
            and all(kw.arg is not None for kw in node.keywords)
        ):
            args = [evaluate_node(arg) for arg in node.args]
            kwargs = {kw.arg: evaluate_node(kw.value) for kw in node.keywords}
            return allowed_names[node.func.id](*args, **kwargs)
        raise ValueError(f"unsupported syntax: {type(node).__name__}")

    def safe_eval(expression):
        # eval() ignores leading spaces/tabs; ast.parse would raise on them.
        return evaluate_node(ast.parse(expression.strip(" \t"), mode="eval"))

    try:
        result = safe_eval(expr)
        return f"{expr} = {result}"
    except Exception as e:
        print(f"Initial evaluation failed: {e}")
        # Strip out any characters that are not numbers, parentheses, or valid operators
        cleaned_expr = re.sub(r"[^0-9.+\-*/() ]", "", str(expr))
        try:
            result = safe_eval(cleaned_expr)
            return f"{cleaned_expr} = {result}"
        except Exception as e2:
            print(f"Retry also failed: {e2}")
            return "Sorry, I am unable to calculate that."


class Calculator(BaseModel):
    # Argument schema for the "calculator" tool: the raw expression text.
    expression: str = Field(default=..., description="mathematical expression")


### Weather Tool
def get_weather(weather):
    """Fetch current weather for a city from OpenWeatherMap (imperial units).

    Args:
        weather: the parsed ``Weather`` tool arguments; ``weather.city_name``
            is the city to look up.

    Returns:
        A sentence with the city name, weather description and temperature
        rounded to whole °F, or a generic error sentence when the request
        fails or returns a non-200 status (the status or error is printed).
    """
    city_name = weather.city_name
    url = "https://api.openweathermap.org/data/2.5/weather"
    # Passing values via ``params`` lets requests URL-encode them. The city
    # used to be spliced into the URL raw, so "&", "#" or spaces in the name
    # could corrupt the query string or inject extra parameters.
    params = {"appid": WEATHER_API_KEY, "units": "imperial", "q": city_name}

    try:
        # Bound the wait so a stalled API call can't hang the tool forever.
        response = requests.get(url, params=params, timeout=10)
    except requests.RequestException as err:
        print(err)
        return "City not found or there was an error with the request."

    if response.status_code != 200:
        print(response.status_code)
        return "City not found or there was an error with the request."

    data = response.json()

    # Extract relevant weather information
    city = data['name']
    temperature = round(data['main']['temp'])
    weather_description = data['weather'][0]['description']

    return f"The weather in {city} is {weather_description} with a temperature of {temperature}°F."

class Weather(BaseModel):
    # Argument schema for the "weather" tool: which city to look up.
    city_name: str = Field(default=..., description="The name of the city to pull the weather for")

### Stupid Question Tool
class CountNumberLetter(BaseModel):
    # Argument schema for the "letter_counter" tool.
    # word: the text to inspect; letter: the character whose occurrences are counted.
    word: str = Field(default=..., description="the word to count the number of letters in")
    letter: str = Field(default=..., description="the letter to count the occurences of")

def count_number_letters(obj):
    """Answer "how many <letter>s are in <word>" for the letter_counter tool.

    Args:
        obj: the parsed ``CountNumberLetter`` arguments; ``obj.word`` and
            ``obj.letter`` are read.

    Returns:
        A sentence giving the case-insensitive number of occurrences of
        ``letter`` in ``word``. When ``letter`` is empty or the string "None"
        in any case (presumably what the model fills in when the question is
        about the total number of letters), the length of ``word`` is
        reported instead.
    """
    letter = obj.letter.lower()
    expression = obj.word.lower()
    # ``letter`` is already lower-cased, so compare against "none". The old
    # check against "None" could never match, which made this branch unreachable.
    if letter in ("", "none"):
        return f"There are {len(obj.word)} letters in '{expression}'"
    count = expression.count(letter)
    if count == 1:
        return f"There is 1 '{letter}' in '{expression}'"
    return f"There are {count} '{letter}'s in '{expression}'"

### Image Gen Tool
class ImageGen(BaseModel):
    # Argument schema for the "generate_image" tool: the text prompt.
    prompt: str = Field(default=..., description="The prompt to use to generate the image")

def generate_image(prompt):
    """Placeholder image-generation tool; no image is actually generated.

    ``prompt`` is the parsed ``ImageGen`` arguments; its ``.prompt`` text is
    checked case-insensitively for "teapot". Teapot prompts get the stock
    teapot image link. Anything else gets a note that a real image model could
    be wired in here, plus the same link.
    """
    wants_teapot = "teapot" in prompt.prompt.lower()
    if not wants_teapot:
        return "Ok I can't generate images, but you could easily hook up an image gen model to this tool call. Check out this image I did generate for you: https://teapotai.com/assets/teapotsmile.png"
    return "I generated an image of a teapot for you: https://teapotai.com/assets/teapotsmile.png"
    
### Tool Creation
# Tools offered to the model. Each pairs a pydantic argument schema with the
# callable that handles a call; every fn below reads attributes of a single
# argument, presumably the parsed schema instance.
DEFAULT_TOOLS = [
    TeapotTool(
        name="websearch",
        schema=BraveWebSearch,
        fn=brave_search_context,
        description="Execute web searches with pagination and filtering",
    ),
    TeapotTool(
        name="letter_counter",
        schema=CountNumberLetter,
        fn=count_number_letters,
        description="Can count how many times a letter occurs in a word.",
    ),
    TeapotTool(
        name="calculator",
        schema=Calculator,
        # Unwrap the Calculator arguments to the raw expression string.
        fn=lambda expression: evaluate_expression(expression.expression),
        description="Can perform calculations on numbers using addition, subtraction, multiplication, and division.",
    ),
    TeapotTool(
        name="generate_image",
        schema=ImageGen,
        fn=generate_image,
        description="Can generate an image for a user based on a prompt",
        # NOTE(review): presumably sends fn's output to the user verbatim
        # instead of passing it back through the model — confirm in teapotai.
        directly_return_result=True,
    ),
    TeapotTool(
        name="weather",
        schema=Weather,
        fn=get_weather,
        description="Can pull today's weather information for any city.",
    ),
]

# ========= CONFIG =========
# Maps a Discord server (guild) name to the TeapotAI instance that answers
# there; handle_teapot_inference falls back to the "Teapot AI" entry.
# NOTE(review): as module-level code in a Streamlit script this presumably
# re-runs (model load + sheet download) on every script run — consider
# wrapping it in st.cache_resource.
CONFIG = {
    # Example of a server-specific instance with its own document sheet:
    # "OneTrainer": TeapotAI(
    #     documents=pd.read_csv("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=361556791&format=csv").content.str.split('\n\n').explode().reset_index(drop=True).to_list(), 
    #     settings=TeapotAISettings(rag_num_results=7)
    # ),
    "Teapot AI": TeapotAI(
        # Pin an exact model revision so upstream pushes can't change answers.
        model=AutoModelForSeq2SeqLM.from_pretrained(
            "teapotai/teapotllm",
            revision="5aa6f84b5bd59da85552d55cc00efb702869cbf8",
        ),
        # RAG documents: every cell of the sheet's "content" column, split
        # into separate documents on blank lines.
        documents=(
            pd.read_csv("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=1617599323&format=csv")
            .content.str.split("\n\n")
            .explode()
            .reset_index(drop=True)
            .to_list()
        ),
        settings=TeapotAISettings(rag_num_results=3, log_level="debug"),
        tools=DEFAULT_TOOLS,
    ),
}



# ========= DISCORD CLIENT =========
# Gateway intents: start from discord.py's defaults and make sure message
# events are enabled, since the handlers below react to messages and
# reactions. The intents must be set before the Client is constructed.
# NOTE(review): the privileged message_content intent is not requested, so
# Discord presumably delivers message text only for messages that mention the
# bot (or DMs). Replies to the bot sent without a ping may arrive with empty
# content — confirm against the Discord intents documentation.
intents = discord.Intents.default()
intents.messages = True
client = discord.Client(intents=intents)

async def handle_teapot_inference(server_name, user_input):
    """Answer ``user_input`` with the TeapotAI instance configured for ``server_name``.

    Servers without their own CONFIG entry use the "Teapot AI" instance. The
    ``query`` call runs in a worker thread via ``asyncio.to_thread`` so model
    generation doesn't stall the Discord event loop. (An earlier variant also
    passed ``context=brave_search_context(user_input)``.)
    """
    system_prompt = (
        "You are Teapot, an open-source AI assistant optimized for low-end devices, "
        "providing short, accurate responses without hallucinating while excelling at "
        "information extraction and text summarization. You can use tools such as a "
        "web search, a calculator and an image generator to assist users."
    )
    teapot_instance = CONFIG.get(server_name, CONFIG["Teapot AI"])
    print(f"Using Teapot instance for server: {server_name}")
    return await asyncio.to_thread(
        teapot_instance.query,
        query=user_input,
        system_prompt=system_prompt,
    )

async def debug_teapot_inference(server_name, user_input):
    """Collect the raw retrieval context for ``user_input`` (used by the ❓ debug reaction).

    Returns:
        A tuple ``(rag_text, search_text)``. ``rag_text`` is the configured
        instance's RAG results joined by blank lines; ``search_text`` is the
        formatted Brave search output ("" on error).
    """
    teapot_instance = CONFIG.get(server_name, CONFIG["Teapot AI"])
    print(f"Using Teapot instance for server: {server_name}")
    # Both lookups are synchronous: an HTTP request and the instance's RAG
    # retrieval. Run them in worker threads so the Discord event loop keeps
    # servicing the gateway meanwhile. They used to run inline and block it.
    search_result = await asyncio.to_thread(brave_search_context, user_input)
    rag_results = await asyncio.to_thread(teapot_instance.rag, query=user_input)

    return "\n\n".join(rag_results), search_result

@client.event
async def on_ready():
    """Log the bot's identity once the Discord connection is ready."""
    print("Logged in as", client.user)

@client.event
async def on_message(message):
    """Reply to messages that mention the bot or reply to one of its messages.

    When the message is a reply to the bot, the bot's earlier text is
    prepended as "agent: ..." so the model sees one turn of context. The
    bot's mention token is removed from the user's text before inference, and
    the answer is posted as a reply to the triggering message.
    """
    if message.author == client.user:
        return

    # Discord rejects empty messages and messages longer than this.
    max_message_length = 2000

    bot_id = client.user.id
    # A user mention may be written as <@id> or, by some clients, in the
    # legacy nickname form <@!id>; both address the bot.
    mention_tokens = (f'<@{bot_id}>', f'<@!{bot_id}>')
    mentioned = any(token in message.content for token in mention_tokens)

    # Check if the message is a reply to the bot
    replied_to_bot = False
    previous_message = ""
    if message.reference:
        # Prefer the copy Discord attached to the event. Otherwise fetch it,
        # tolerating a deleted or inaccessible original instead of aborting.
        replied_message = message.reference.resolved
        if not isinstance(replied_message, discord.Message):
            try:
                replied_message = await message.channel.fetch_message(message.reference.message_id)
            except discord.HTTPException as err:
                print(f"Could not fetch referenced message: {err}")
                replied_message = None
        if replied_message is not None and replied_message.author == client.user:
            replied_to_bot = True
            previous_message = "agent: " + replied_message.content + "\n"

    # If not mentioned and not replying to the bot, ignore
    if not (mentioned or replied_to_bot):
        return

    server_name = message.guild.name if message.guild else "Teapot AI"
    print(server_name, message.author, message.content)

    async with message.channel.typing():
        cleaned_message = message.content
        for token in mention_tokens:
            cleaned_message = cleaned_message.replace(token, "")
        cleaned_message = cleaned_message.strip()

        full_context = previous_message + cleaned_message

        response = await handle_teapot_inference(server_name, full_context)
        reply_text = str(response).strip() if response else ""
        if not reply_text:
            reply_text = "Sorry, I don't have an answer for that."
        await message.reply(reply_text[:max_message_length])



@client.event
async def on_reaction_add(reaction, user):
    """Post retrieval debug info when someone reacts ❓/❔ to one of the bot's replies.

    The message the bot was replying to is re-run through the RAG lookup and a
    Brave search. The last 900 characters of each result are posted in a
    thread on the bot's reply, created if it doesn't exist yet. If a thread
    can't be created, the debug text goes to the reply's channel instead.
    """
    if user == client.user:
        return

    if str(reaction.emoji) not in ["❓", "❔"]:
        return

    message = reaction.message

    # Make sure it's a bot message that was a reply
    if message.author != client.user or not message.reference:
        return

    bot_id = client.user.id
    mention_tokens = (f'<@{bot_id}>', f'<@!{bot_id}>')

    def strip_mentions(text):
        # Drop the bot's mention token(s) so they don't leak into titles or queries.
        for token in mention_tokens:
            text = text.replace(token, "")
        return text.strip()

    cleaned_message = strip_mentions(message.content)

    # Fetch the original message that this bot message replied to
    original_message = message.reference.resolved
    if not isinstance(original_message, discord.Message):
        try:
            original_message = await message.channel.fetch_message(message.reference.message_id)
        except discord.HTTPException as err:
            print(f"Could not fetch the message being debugged: {err}")
            return
    # on_message removes the bot mention before inference, so do the same
    # here; otherwise the raw "<@id>" token would be part of the debug query.
    user_input = strip_mentions(original_message.content)

    server_name = message.guild.name if message.guild else "Teapot AI"

    # Create a thread or use existing one
    thread = message.thread
    if thread is None:
        try:
            thread = await message.create_thread(name=f"Debug Thread: '{cleaned_message[0:30]}...'", auto_archive_duration=60)
        except (ValueError, discord.HTTPException) as err:
            # Thread creation can fail, e.g. in DMs, inside an existing thread,
            # or without permission. Fall back to the message's channel.
            print(f"Could not create debug thread: {err}")
            thread = message.channel

    rag_result, search_result = await debug_teapot_inference(server_name, user_input)
    debug_response = "## RAG:\n```"+discord.utils.escape_markdown(rag_result)[-900:]+"```\n\n## Search:\n```"+discord.utils.escape_markdown(search_result)[-900:]+"```"
    await thread.send(debug_response)



# ========= STREAMLIT =========
@st.cache_resource
def discord_loop():
    """Start the Discord bot in the background and show a status line.

    ``st.cache_resource`` keeps later script runs from starting a second bot.
    ``client.run`` blocks until the bot disconnects, so calling it directly
    would stop this script run from ever finishing. The status line below
    would never be written, and other sessions reaching this cached call
    could be left waiting on it. Running the bot on a daemon thread avoids
    all three.

    Returns:
        The started bot thread.

    Raises:
        RuntimeError: if the ``discord_key`` environment variable is not set.
    """
    import threading

    if not DISCORD_TOKEN:
        raise RuntimeError("The discord_key environment variable is not set.")

    st.session_state["initialized"] = True
    bot_thread = threading.Thread(
        target=client.run,
        args=(DISCORD_TOKEN,),
        name="discord-bot",
        daemon=True,
    )
    bot_thread.start()
    st.write("418 I'm a teapot")
    return bot_thread


discord_loop()