File size: 10,730 Bytes
63c1d33 e802c16 63c1d33 7abe8b9 e802c16 63c1d33 b5292e8 63c1d33 dcd8a1e f79726a dcd8a1e f79726a 28d3307 f79726a dcd8a1e f79726a 4d1bbea f79726a dcd8a1e 5e0f49f dcd8a1e 5e0f49f dcd8a1e 06e7d9e dcd8a1e a14601b dcd8a1e b5ca814 dcd8a1e eb80774 dcd8a1e c24585f 06e7d9e c24585f d11e749 c24585f dcd8a1e 598addc dcd8a1e 9f90d1c dcd8a1e 97a0bc4 dcd8a1e 7af607c dcd8a1e c24585f dcd8a1e 63c1d33 b5292e8 63c1d33 470c35a 598addc 903b87a beef5d8 63c1d33 8cdc643 e34da9a 8cdc643 de46855 63c1d33 b5292e8 63c1d33 6cc5afa 63c1d33 b5292e8 63c1d33 7e1e3d2 6cc5afa b5292e8 63c1d33 6cc5afa 63c1d33 ab8c50d 60f4e03 ab8c50d 021337d ab8c50d de46855 511cc6d de46855 ab8c50d b5292e8 63c1d33 b5292e8 beef5d8 b5292e8 beef5d8 63c1d33 b5292e8 beef5d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 |
import streamlit as st
import os
import aiohttp
import asyncio
import discord
import pandas as pd
import requests
from teapotai import TeapotAI, TeapotAISettings
from pydantic import BaseModel, Field
st.set_page_config(
    page_title="TeapotAI Discord Bot",
    page_icon=":robot_face:",
    layout="wide",
)
# Bot token handed to client.run() when the Discord client is started.
DISCORD_TOKEN = os.getenv("discord_key")
# ======= API KEYS =======
# Subscription token for the Brave Search API (web search tool).
BRAVE_API_KEY = os.getenv("brave_api_key")
# API key for OpenWeatherMap (weather tool).
WEATHER_API_KEY = os.getenv("weather_api_key")
# ======== TOOLS ===========
import requests
from typing import Optional
from teapotai import TeapotTool
import re
import math
import pandas as pd
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, logging
### SEARCH TOOL
class BraveWebSearch(BaseModel):
    # Argument schema for the "websearch" tool.
    # No class docstring on purpose: pydantic would publish it as the schema
    # description, which changes what the model is shown.
    search_query: str = Field(
        default=...,
        description="the search string to answer the question",
    )
def brave_search_context(query, count=3):
    """
    Query the Brave Search web API and return the top results as plain text.

    Args:
        query: the search text. Either a plain string (as
            debug_teapot_inference passes) or an object exposing a
            ``search_query`` attribute, such as a BraveWebSearch instance.
        count: maximum number of results to request.

    Returns:
        One entry per result (title, URL and description on separate lines),
        entries separated by a blank line, with <strong> highlight tags
        removed. Returns "" on an HTTP error status or a network failure; the
        error is printed.
    """
    # The other tool callbacks here read attributes off their schema object
    # (e.g. weather.city_name), so the tool path presumably passes a
    # BraveWebSearch instance; unwrap it so the API receives the raw text.
    query = getattr(query, "search_query", query)
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": BRAVE_API_KEY}
    params = {"q": query, "count": count}
    try:
        # A timeout keeps a stalled connection from pinning the caller forever.
        response = requests.get(url, headers=headers, params=params, timeout=10)
    except requests.RequestException as err:
        print(f"Error: {err}")
        return ""
    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return ""
    results = response.json().get("web", {}).get("results", [])
    return "\n\n".join(
        res.get("title", "")
        + "\n"
        + res.get("url", "")
        + "\n"
        + res.get("description", "").replace("<strong>", "").replace("</strong>", "")
        for res in results
    )
### CALCULATOR TOOL
import builtins
def evaluate_expression(expr) -> str:
    """
    Evaluate a simple algebraic expression string.

    Supports numeric literals, +, -, *, /, //, %, **, unary +/-, parentheses
    and the helper functions abs, round, sqrt and pow. The expression is
    parsed and checked against that whitelist before it is evaluated, so
    attribute access, subscripts, string literals and any other names are
    rejected.

    If the first attempt fails, evaluation is retried once after stripping
    every character that is not a digit, '.', '+', '-', '*', '/', a
    parenthesis or a space (e.g. "what is 2*3" is retried as "  2*3").

    Args:
        expr: the expression text; non-string values are converted with str().

    Returns:
        "<expression> = <result>" using whichever expression succeeded, or
        "Sorry, I am unable to calculate that." if both attempts fail.
    """
    import ast

    allowed_names = {k: getattr(builtins, k) for k in ("abs", "round")}
    allowed_names.update({k: getattr(math, k) for k in ("sqrt", "pow")})
    allowed_nodes = (
        ast.Expression, ast.BinOp, ast.UnaryOp, ast.Call, ast.keyword,
        ast.Name, ast.Load, ast.Constant,
        ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.Pow,
        ast.UAdd, ast.USub,
    )

    def safe_eval(expression):
        # eval() with "__builtins__": None alone is not a sandbox: input like
        # "().__class__.__bases__[0].__subclasses__()" escapes it. The text is
        # presumably extracted by the model from a Discord user's message, so
        # the AST is validated first and only whitelisted nodes are evaluated.
        # NOTE(review): huge powers such as 9**9**9 still pass the whitelist
        # and can tie up the worker thread; consider bounding exponents.
        # eval() itself strips leading spaces/tabs; mirror that for ast.parse.
        tree = ast.parse(expression.lstrip(" \t"), mode="eval")
        for node in ast.walk(tree):
            if not isinstance(node, allowed_nodes):
                raise ValueError(f"disallowed syntax: {type(node).__name__}")
            if isinstance(node, ast.Name) and node.id not in allowed_names:
                raise ValueError(f"unknown name: {node.id}")
            if isinstance(node, ast.Constant) and not isinstance(node.value, (int, float, complex)):
                raise ValueError(f"unsupported literal: {node.value!r}")
        return eval(compile(tree, "<calculator>", "eval"), {"__builtins__": None}, allowed_names)

    if not isinstance(expr, str):
        expr = str(expr)
    try:
        result = safe_eval(expr)
        return f"{expr} = {result}"
    except Exception as e:
        print(f"Initial evaluation failed: {e}")
    # Strip out any characters that are not numbers, parentheses, or valid operators
    cleaned_expr = re.sub(r"[^0-9\.\+\-\*/\(\) ]", "", expr)
    try:
        result = safe_eval(cleaned_expr)
        return f"{cleaned_expr} = {result}"
    except Exception as e2:
        print(f"Retry also failed: {e2}")
        return "Sorry, I am unable to calculate that."
class Calculator(BaseModel):
    # Argument schema for the "calculator" tool.
    # No class docstring on purpose: pydantic would publish it as the schema
    # description, which changes what the model is shown.
    expression: str = Field(
        default=...,
        description="mathematical expression",
    )
### Weather Tool
def get_weather(weather):
    """
    Look up the current weather for ``weather.city_name`` via OpenWeatherMap.

    Args:
        weather: object exposing ``city_name`` (the Weather schema).

    Returns:
        A one-sentence summary with the temperature rounded to whole °F
        (the request uses units=imperial), or
        "City not found or there was an error with the request." when the
        request fails, returns a non-200 status, or the payload lacks the
        expected fields.
    """
    # OpenWeatherMap API endpoint
    url = "https://api.openweathermap.org/data/2.5/weather"
    # Passing the city through params lets requests URL-encode it, so names
    # containing spaces, '&' or '#' cannot corrupt the query string.
    params = {"appid": WEATHER_API_KEY, "units": "imperial", "q": weather.city_name}
    error_message = "City not found or there was an error with the request."
    try:
        response = requests.get(url, params=params, timeout=10)
    except requests.RequestException as err:
        print(err)
        return error_message
    if response.status_code != 200:
        print(response.status_code)
        return error_message
    try:
        data = response.json()
        city = data['name']
        temperature = round(data['main']['temp'])
        weather_description = data['weather'][0]['description']
    except (ValueError, KeyError, IndexError, TypeError) as err:
        print(f"Unexpected weather payload: {err}")
        return error_message
    return f"The weather in {city} is {weather_description} with a temperature of {temperature}°F."
class Weather(BaseModel):
    # Argument schema for the "weather" tool.
    # No class docstring on purpose: pydantic would publish it as the schema
    # description, which changes what the model is shown.
    city_name: str = Field(
        default=...,
        description="The name of the city to pull the weather for",
    )
### Letter Counter Tool
class CountNumberLetter(BaseModel):
    # Argument schema for the "letter_counter" tool; field order is kept as-is.
    # No class docstring on purpose: pydantic would publish it as the schema
    # description, which changes what the model is shown.
    word: str = Field(
        default=...,
        description="the word to count the number of letters in",
    )
    letter: str = Field(
        default=...,
        description="the letter to count the occurences of",
    )
def count_number_letters(obj):
    """
    Count how often ``obj.letter`` occurs in ``obj.word`` (case-insensitive).

    Surrounding whitespace in ``obj.letter`` is ignored. If it is "None" (any
    case) or blank, the total number of characters in ``obj.word`` is reported
    instead. A multi-character ``letter`` counts non-overlapping occurrences
    of that substring.

    Args:
        obj: object exposing ``word`` and ``letter`` (the CountNumberLetter schema).

    Returns:
        A human-readable sentence with the count.
    """
    letter = obj.letter.strip().lower()
    expression = obj.word.lower()
    # `letter` is already lowercased, so compare against lowercase "none".
    if letter in ("none", ""):
        return f"There are {len(obj.word)} letters in '{expression}'"
    count = expression.count(letter)
    if count == 1:
        return f"There is 1 '{letter}' in '{expression}'"
    return f"There are {count} '{letter}'s in '{expression}'"
### Image Gen Tool
class ImageGen(BaseModel):
    # Argument schema for the "generate_image" tool.
    # No class docstring on purpose: pydantic would publish it as the schema
    # description, which changes what the model is shown.
    prompt: str = Field(
        default=...,
        description="The prompt to use to generate the image",
    )
def generate_image(prompt):
    """
    Placeholder image tool: no model is called, a canned reply is returned.

    Args:
        prompt: object exposing ``prompt`` (the ImageGen schema).

    Returns:
        A teapot-specific message when the prompt mentions "teapot"
        (case-insensitive), otherwise a generic explanation. Both link to the
        Teapot logo image.
    """
    asks_for_teapot = "teapot" in prompt.prompt.lower()
    if not asks_for_teapot:
        return "Ok I can't generate images, but you could easily hook up an image gen model to this tool call. Check out this image I did generate for you: https://teapotai.com/assets/teapotsmile.png"
    return "I generated an image of a teapot for you: https://teapotai.com/assets/teapotsmile.png"
### Tool Creation
# Tools registered with the default Teapot instance, in registration order.
DEFAULT_TOOLS = [
    # Brave web search.
    # NOTE(review): brave_search_context takes only a query and a result count;
    # the "pagination and filtering" wording below is not backed by the code.
    TeapotTool(
        name="websearch",
        description="Execute web searches with pagination and filtering",
        schema=BraveWebSearch,
        fn=brave_search_context,
    ),
    # Counts occurrences of a letter in a word.
    TeapotTool(
        name="letter_counter",
        description="Can count how many times a letter occurs in a word.",
        schema=CountNumberLetter,
        fn=count_number_letters,
    ),
    # Arithmetic; the lambda unwraps the Calculator object into its string.
    TeapotTool(
        name="calculator",
        description="Can perform calculations on numbers using addition, subtraction, multiplication, and division.",
        schema=Calculator,
        fn=lambda expression: evaluate_expression(expression.expression),
    ),
    # Canned image reply.
    # presumably directly_return_result sends the tool output to the user
    # verbatim instead of feeding it back to the model — confirm in TeapotTool.
    TeapotTool(
        name="generate_image",
        description="Can generate an image for a user based on a prompt",
        schema=ImageGen,
        fn=generate_image,
        directly_return_result=True,
    ),
    # Current weather from OpenWeatherMap.
    TeapotTool(
        name="weather",
        description="Can pull today's weather information for any city.",
        schema=Weather,
        fn=get_weather,
    ),
]
# ========= CONFIG =========
def _load_documents(csv_url, column="content", separator="\n\n"):
    """
    Download a CSV and split one of its text columns into RAG documents.

    Each cell of ``column`` is split on ``separator``. Empty cells (NaN) and
    chunks that are empty or whitespace-only are dropped, because they would
    otherwise reach TeapotAI as NaN floats or blank documents. Kept chunks are
    not modified.

    Args:
        csv_url: anything pandas.read_csv accepts (here a Google Sheets CSV export URL).
        column: name of the text column to split.
        separator: chunk delimiter inside a cell.

    Returns:
        list[str] of document chunks, in sheet order.
    """
    chunks = pd.read_csv(csv_url)[column].str.split(separator).explode().dropna()
    return [chunk for chunk in chunks if isinstance(chunk, str) and chunk.strip()]


# Teapot instances keyed by Discord server (guild) name; handlers fall back to
# "Teapot AI" for servers without their own entry.
CONFIG = {
    # Disabled example of a per-server instance:
    # "OneTrainer": TeapotAI(
    #     documents=_load_documents("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=361556791&format=csv"),
    #     settings=TeapotAISettings(rag_num_results=7)
    # ),
    "Teapot AI": TeapotAI(
        # Model weights pinned to a specific Hugging Face revision.
        model=AutoModelForSeq2SeqLM.from_pretrained(
            "teapotai/teapotllm",
            revision="5aa6f84b5bd59da85552d55cc00efb702869cbf8",
        ),
        documents=_load_documents("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=1617599323&format=csv"),
        settings=TeapotAISettings(rag_num_results=3, log_level="debug"),
        tools=DEFAULT_TOOLS,
    ),
}
# ========= DISCORD CLIENT =========
# Default intents (discord.py enables everything except the privileged
# presences, members and message_content intents).
# NOTE(review): without intents.message_content, Discord may deliver an empty
# message.content for messages that do not mention the bot (e.g. replies sent
# with the ping turned off) — confirm whether that intent should be enabled.
intents = discord.Intents.default()
# Guild and DM message events; already part of the default set, kept explicit.
intents.messages = True
# Single client shared by every @client.event handler and run by discord_loop().
client = discord.Client(intents=intents)
async def handle_teapot_inference(server_name, user_input):
    """
    Ask the Teapot instance configured for ``server_name`` to answer ``user_input``.

    Servers without their own CONFIG entry use the "Teapot AI" instance. The
    blocking ``query`` call runs in a worker thread via asyncio.to_thread so
    the event loop is not blocked while the model generates.

    Returns:
        Whatever the instance's ``query`` returns (sent back as the Discord reply).
    """
    system_prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization. You can use tools such as a web search, a calculator and an image generator to assist users."""
    fallback_instance = CONFIG["Teapot AI"]
    teapot_instance = CONFIG.get(server_name, fallback_instance)
    print(f"Using Teapot instance for server: {server_name}")
    return await asyncio.to_thread(
        teapot_instance.query,
        query=user_input,
        system_prompt=system_prompt,
    )
async def debug_teapot_inference(server_name, user_input):
    """
    Collect the retrieval context Teapot would see for ``user_input``.

    Uses the CONFIG instance for ``server_name`` (falling back to "Teapot AI").

    Returns:
        (rag_text, search_text): the instance's RAG passages joined by blank
        lines, and the Brave search context string for the same input.
    """
    teapot_instance = CONFIG.get(server_name, CONFIG["Teapot AI"])
    print(f"Using Teapot instance for server: {server_name}")
    # Both calls block (an HTTP request and model inference). Run them in
    # worker threads, concurrently, so the Discord event loop is not blocked.
    search_result, rag_results = await asyncio.gather(
        asyncio.to_thread(brave_search_context, user_input),
        asyncio.to_thread(teapot_instance.rag, query=user_input),
    )
    return "\n\n".join(rag_results), search_result
@client.event
async def on_ready():
    """Print the bot's identity once the Discord client reports it is ready."""
    print(f"Logged in as {client.user}")
@client.event
async def on_message(message):
    """
    Answer messages that mention the bot or reply to one of its messages.

    When the message replies to a bot message, that message's text is
    prepended as "agent: ..." context. The bot mention is removed from the
    user's text, and the answer comes from handle_teapot_inference using the
    Teapot instance for the message's server ("Teapot AI" in DMs).
    """
    if message.author == client.user:
        return
    # Plain (<@id>) and legacy nickname (<@!id>) mention forms.
    mention_tags = (f'<@{client.user.id}>', f'<@!{client.user.id}>')
    mentioned = any(tag in message.content for tag in mention_tags)

    # Check if the message is a reply to the bot
    replied_to_bot = False
    previous_message = ""
    reference = message.reference
    if reference is not None and reference.message_id is not None:
        # Prefer the message discord.py already resolved; only call the API otherwise.
        replied_message = reference.resolved
        if not isinstance(replied_message, discord.Message):
            try:
                replied_message = await message.channel.fetch_message(reference.message_id)
            except discord.HTTPException as err:
                # Deleted or inaccessible: treat as "not a reply" instead of
                # letting the exception abort the handler.
                print(f"Could not fetch referenced message: {err}")
                replied_message = None
        if replied_message is not None and replied_message.author == client.user:
            replied_to_bot = True
            previous_message = "agent: " + replied_message.content + "\n"

    # If not mentioned and not replying to the bot, ignore
    if not (mentioned or replied_to_bot):
        return

    server_name = message.guild.name if message.guild else "Teapot AI"
    print(server_name, message.author, message.content)
    async with message.channel.typing():
        cleaned_message = message.content
        for tag in mention_tags:
            cleaned_message = cleaned_message.replace(tag, "")
        cleaned_message = cleaned_message.strip()
        full_context = previous_message + cleaned_message
        response = await handle_teapot_inference(server_name, full_context)
        if isinstance(response, str):
            # Discord rejects messages longer than 2000 characters.
            response = response[:2000]
        await message.reply(response)
@client.event
async def on_reaction_add(reaction, user):
    """
    Post retrieval debug info when someone reacts with ❓ or ❔ to a bot reply.

    The reacted-to message must be a bot message that replied to another
    message. That original message (with the bot mention removed, as
    on_message does) is re-run through debug_teapot_inference. The last 900
    characters of the RAG and search context are posted in a thread under
    the bot message. If no thread can be created, they go to the channel.

    NOTE(review): discord.py only dispatches on_reaction_add for messages in
    its message cache; reactions on older messages would need
    on_raw_reaction_add.
    """
    if user == client.user:
        return
    if str(reaction.emoji) not in ["❓", "❔"]:
        return
    message = reaction.message
    # Make sure it's a bot message that was a reply
    if message.author != client.user or not message.reference or message.reference.message_id is None:
        return

    mention_tags = (f'<@{client.user.id}>', f'<@!{client.user.id}>')

    def strip_mentions(text):
        # Remove every bot-mention form and surrounding whitespace.
        for tag in mention_tags:
            text = text.replace(tag, "")
        return text.strip()

    cleaned_message = strip_mentions(message.content)
    # Fetch the original message that this bot message replied to
    try:
        original_message = await message.channel.fetch_message(message.reference.message_id)
    except discord.HTTPException as err:
        print(f"Could not fetch the original message: {err}")
        return
    # Strip the mention so the debug query matches the text on_message sends to Teapot.
    user_input = strip_mentions(original_message.content)
    server_name = message.guild.name if message.guild else "Teapot AI"

    # Create a thread or use existing one
    thread = message.thread
    if thread is None:
        try:
            thread = await message.create_thread(name=f"Debug Thread: '{cleaned_message[0:30]}...'", auto_archive_duration=60)
        except (discord.HTTPException, ValueError) as err:
            # e.g. DMs or messages already inside a thread cannot host a new thread
            print(f"Could not create debug thread, posting in channel instead: {err}")
            thread = message.channel
    rag_result, search_result = await debug_teapot_inference(server_name, user_input)
    debug_response = "## RAG:\n```"+discord.utils.escape_markdown(rag_result)[-900:]+"```\n\n## Search:\n```"+discord.utils.escape_markdown(search_result)[-900:]+"```"
    await thread.send(debug_response)
# ========= STREAMLIT =========
@st.cache_resource
def discord_loop():
    """
    Start the Discord client once per Streamlit server process.

    ``client.run`` blocks until the bot shuts down. Calling it inline meant
    this cached function never returned, so the page text was never written.
    It therefore runs in a daemon thread. st.cache_resource makes later script
    reruns reuse the already-started thread instead of logging in again.

    Returns:
        The thread running the Discord client.
    """
    import threading

    st.session_state["initialized"] = True
    bot_thread = threading.Thread(
        target=client.run,
        args=(DISCORD_TOKEN,),
        name="discord-client",
        daemon=True,
    )
    bot_thread.start()
    return bot_thread


discord_loop()
st.write("418 I'm a teapot")
|