# Author: zakerytclarke
# Commit: a14601b ("Update app.py", verified)
import streamlit as st
import os
import aiohttp
import asyncio
import discord
import pandas as pd
import requests
from teapotai import TeapotAI, TeapotAISettings
from pydantic import BaseModel, Field
st.set_page_config(
    page_title="TeapotAI Discord Bot",
    page_icon=":robot_face:",
    layout="wide",
)
# Discord bot token; os.getenv yields None when the variable is not set.
DISCORD_TOKEN = os.getenv("discord_key")
# ======= API KEYS =======
# Credentials for the services behind the websearch and weather tools.
BRAVE_API_KEY = os.getenv("brave_api_key")
WEATHER_API_KEY = os.getenv("weather_api_key")
# ======== TOOLS ===========
import requests
from typing import Optional
from teapotai import TeapotTool
import re
import math
import pandas as pd
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, logging
### SEARCH TOOL
class BraveWebSearch(BaseModel):
    # Argument schema for the "websearch" tool (see DEFAULT_TOOLS).
    # Required field: `default=...` is pydantic's "no default" marker.
    search_query: str = Field(
        default=...,
        description="the search string to answer the question",
    )
def brave_search_context(query, count=3):
    """
    Query the Brave Web Search API and return the top hits as plain text.

    This function is called directly with a query string (debug path) and is
    also registered as the ``websearch`` tool callback. The sibling tool
    callbacks read fields off the object they are given (e.g.
    ``weather.city_name``), so as a tool it presumably receives a
    ``BraveWebSearch`` instance. Its ``search_query`` attribute is therefore
    unwrapped when present; plain strings pass through unchanged.

    Args:
        query: Search text, or an object with a ``search_query`` attribute.
        count: Maximum number of results to request.

    Returns:
        One ``title``/``url``/``description`` entry (newline-separated) per
        result, entries separated by blank lines; ``""`` if the request fails.
    """
    query = getattr(query, "search_query", query)
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": BRAVE_API_KEY}
    params = {"q": query, "count": count}
    try:
        # Without a timeout, a stalled connection would block the caller forever.
        response = requests.get(url, headers=headers, params=params, timeout=10)
    except requests.RequestException as e:
        print(f"Error: {e}")
        return ""
    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return ""
    results = response.json().get("web", {}).get("results", [])
    entries = []
    for res in results:
        # Descriptions may carry <strong> highlight markup; strip it to plain text.
        description = res.get("description", "").replace("<strong>", "").replace("</strong>", "")
        entries.append(res.get("title", "") + "\n" + res.get("url", "") + "\n" + description)
    return "\n\n".join(entries)
### CALCULATOR TOOL
import builtins
def evaluate_expression(expr) -> str:
    """
    Evaluate a simple arithmetic expression string safely.

    The expression comes from a model tool call driven by untrusted Discord
    messages. It is therefore parsed with ``ast`` and only an explicit
    whitelist of nodes is evaluated; ``eval`` is never used. ``eval`` with
    ``__builtins__`` set to ``None`` can still be escaped through attribute
    access such as ``().__class__``. Unbounded ``**`` such as ``9**9**9**9``
    would also hang the bot.

    Supported syntax:
        - int/float literals
        - the operators ``+ - * / // % **``
        - unary ``+``/``-``
        - parentheses
        - calls to ``abs``, ``round``, ``sqrt`` and ``pow``

    Integer powers whose result would exceed ``max_pow_bits`` bits are rejected.

    If the first attempt fails, every character that is not a digit, ``.``,
    ``+ - * /``, a parenthesis or a space is stripped, and evaluation is
    retried once.

    Args:
        expr: Expression text, e.g. ``"2 + 3 * 4"``.

    Returns:
        ``"<expression> = <result>"`` on success, otherwise
        ``"Sorry, I am unable to calculate that."``.
    """
    import ast
    import operator

    # Upper bound on the size (in bits) of an integer power's result.
    max_pow_bits = 10_000

    allowed_names = {k: getattr(builtins, k) for k in ("abs", "round")}
    allowed_names.update({k: getattr(math, k) for k in ("sqrt", "pow")})

    def checked_pow(base, exponent):
        # Refuse integer powers that would build enormous ints (a cheap DoS).
        if (
            isinstance(base, int)
            and isinstance(exponent, int)
            and abs(base) > 1
            and exponent > 0
            and base.bit_length() * exponent > max_pow_bits
        ):
            raise ValueError("exponent too large")
        return base ** exponent

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: checked_pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def evaluate_node(node):
        # Recursively evaluate a whitelisted AST node; anything else is rejected.
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float, complex)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            return binary_ops[type(node.op)](evaluate_node(node.left), evaluate_node(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate_node(node.operand))
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id in allowed_names
            and not node.keywords
        ):
            return allowed_names[node.func.id](*[evaluate_node(arg) for arg in node.args])
        raise ValueError(f"unsupported syntax: {type(node).__name__}")

    def safe_eval(expression):
        # ast.parse rejects leading whitespace (eval() silently strips it).
        tree = ast.parse(expression.strip(), mode="eval")
        return evaluate_node(tree.body)

    try:
        result = safe_eval(expr)
        return f"{expr} = {result}"
    except Exception as e:
        print(f"Initial evaluation failed: {e}")

    # Strip out any characters that are not numbers, parentheses, or valid operators.
    cleaned_expr = re.sub(r"[^0-9.+\-*/() ]", "", expr).strip()
    try:
        result = safe_eval(cleaned_expr)
        return f"{cleaned_expr} = {result}"
    except Exception as e2:
        print(f"Retry also failed: {e2}")
        return "Sorry, I am unable to calculate that."
class Calculator(BaseModel):
    # Argument schema for the "calculator" tool; the expression text is
    # handed to evaluate_expression.
    expression: str = Field(
        default=...,
        description="mathematical expression",
    )
### Weather Tool
def get_weather(weather):
    """
    Look up the current weather for a city via the OpenWeatherMap API.

    Args:
        weather: Object with a ``city_name`` attribute (the ``Weather`` tool schema).

    Returns:
        A one-sentence summary with the temperature rounded to whole °F, or a
        fixed error message if the city is unknown or the request fails.
    """
    error_message = "City not found or there was an error with the request."
    city_name = weather.city_name
    # OpenWeatherMap API endpoint
    url = "https://api.openweathermap.org/data/2.5/weather"
    # Let requests URL-encode the query; an unencoded city name containing '&',
    # '#' or '?' would corrupt the query string.
    params = {"appid": WEATHER_API_KEY, "units": "imperial", "q": city_name}
    try:
        response = requests.get(url, params=params, timeout=10)
    except requests.RequestException as e:
        print(e)
        return error_message
    if response.status_code != 200:
        print(response.status_code)
        return error_message
    data = response.json()
    city = data['name']
    temperature = round(data['main']['temp'])
    weather_description = data['weather'][0]['description']
    return f"The weather in {city} is {weather_description} with a temperature of {temperature}°F."
class Weather(BaseModel):
    # Argument schema for the "weather" tool; consumed by get_weather.
    city_name: str = Field(
        default=...,
        description="The name of the city to pull the weather for",
    )
### Stupid Question Tool
class CountNumberLetter(BaseModel):
    # Argument schema for the "letter_counter" tool; consumed by count_number_letters.
    word: str = Field(
        default=...,
        description="the word to count the number of letters in",
    )
    letter: str = Field(
        default=...,
        description="the letter to count the occurences of",
    )
def count_number_letters(obj):
    """
    Count how often a letter occurs in a word, or how long the word is.

    Args:
        obj: Object with ``word`` and ``letter`` string attributes (the
            ``CountNumberLetter`` tool schema). A ``letter`` of ``"None"``
            (any case) or an empty string asks for the word's total length.

    Returns:
        A human-readable sentence with the count. The comparison is
        case-insensitive.
    """
    letter = obj.letter.lower()
    expression = obj.word.lower()
    # `letter` was lower-cased above, so the sentinel must be compared in lower
    # case; a comparison against "None" could never match.
    if letter in ("none", ""):
        return f"There are {len(obj.word)} letters in '{expression}'"
    count = sum(1 for ch in expression if ch == letter)
    if count == 1:
        return f"There is 1 '{letter}' in '{expression}'"
    return f"There are {count} '{letter}'s in '{expression}'"
### Image Gen Tool
class ImageGen(BaseModel):
    # Argument schema for the "generate_image" tool; consumed by generate_image.
    prompt: str = Field(
        default=...,
        description="The prompt to use to generate the image",
    )
def generate_image(prompt):
    """
    Placeholder image-generation tool; no image model is wired up.

    Prompts mentioning "teapot" (case-insensitive) get a fixed teapot image URL
    presented as the generated result. Any other prompt gets an explanation
    plus the same URL.

    Args:
        prompt: Object with a ``prompt`` string attribute (the ``ImageGen`` schema).
    """
    requested = prompt.prompt.lower()
    if "teapot" not in requested:
        return "Ok I can't generate images, but you could easily hook up an image gen model to this tool call. Check out this image I did generate for you: https://teapotai.com/assets/teapotsmile.png"
    return "I generated an image of a teapot for you: https://teapotai.com/assets/teapotsmile.png"
### Tool Creation
# Tool registry handed to TeapotAI. The callbacks read fields off their single
# argument (e.g. `weather.city_name`), so TeapotTool presumably passes them a
# parsed instance of `schema` — confirm against teapotai.
DEFAULT_TOOLS = [
    TeapotTool(
        name="websearch",
        schema=BraveWebSearch,
        fn=brave_search_context,
        description="Execute web searches with pagination and filtering",
    ),
    TeapotTool(
        name="letter_counter",
        schema=CountNumberLetter,
        fn=count_number_letters,
        description="Can count how many times a letter occurs in a word.",
    ),
    TeapotTool(
        name="calculator",
        schema=Calculator,
        # evaluate_expression takes the raw string, so unwrap the schema field.
        fn=lambda expression: evaluate_expression(expression.expression),
        description="Can perform calculations on numbers using addition, subtraction, multiplication, and division.",
    ),
    TeapotTool(
        name="generate_image",
        schema=ImageGen,
        fn=generate_image,
        description="Can generate an image for a user based on a prompt",
        # Presumably sends the tool output to the user verbatim instead of
        # feeding it back through the model — confirm in teapotai.
        directly_return_result=True,
    ),
    TeapotTool(
        name="weather",
        schema=Weather,
        fn=get_weather,
        description="Can pull today's weather information for any city.",
    ),
]
# ========= CONFIG =========
# Maps a Discord server (guild) name to the TeapotAI instance that answers in it.
# The message handlers fall back to the "Teapot AI" entry for unlisted servers.
CONFIG = {
    # Disabled example of a per-server instance:
    # "OneTrainer": TeapotAI(
    #     documents=pd.read_csv("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=361556791&format=csv").content.str.split('\n\n').explode().reset_index(drop=True).to_list(),
    #     settings=TeapotAISettings(rag_num_results=7)
    # ),
    "Teapot AI": TeapotAI(
        # Model weights pinned to a fixed revision hash.
        model=AutoModelForSeq2SeqLM.from_pretrained(
            "teapotai/teapotllm",
            revision="5aa6f84b5bd59da85552d55cc00efb702869cbf8",
        ),
        # RAG corpus: every cell of the sheet's `content` column, split on blank
        # lines into one flat list of paragraphs.
        documents=(
            pd.read_csv("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=1617599323&format=csv")
            .content
            .str.split("\n\n")
            .explode()
            .reset_index(drop=True)
            .tolist()
        ),
        settings=TeapotAISettings(rag_num_results=3, log_level="debug"),
        tools=DEFAULT_TOOLS,
    ),
}
# ========= DISCORD CLIENT =========
intents = discord.Intents.default()
# Receive message events from guild channels and DMs (discord.py's `messages`
# flag is shorthand for exactly these two).
intents.guild_messages = True
intents.dm_messages = True
client = discord.Client(intents=intents)
async def handle_teapot_inference(server_name, user_input):
    """
    Answer ``user_input`` with the TeapotAI instance configured for ``server_name``.

    Servers without their own CONFIG entry use the "Teapot AI" instance. The
    synchronous ``query`` call runs in a worker thread so the Discord event loop
    keeps processing other events while the model generates.
    """
    default_instance = CONFIG["Teapot AI"]
    teapot_instance = CONFIG.get(server_name, default_instance)
    print(f"Using Teapot instance for server: {server_name}")
    system_prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization. You can use tools such as a web search, a calculator and an image generator to assist users."""
    return await asyncio.to_thread(
        teapot_instance.query,
        query=user_input,
        system_prompt=system_prompt,
    )
async def debug_teapot_inference(server_name, user_input):
    """
    Gather the raw retrieval context behind an answer, for debugging.

    Args:
        server_name: Discord server name, used to pick the CONFIG instance
            (falls back to "Teapot AI").
        user_input: The text to retrieve context for.

    Returns:
        ``(rag_text, search_text)``:
            - ``rag_text``: the passages returned by the instance's ``rag``
              method, joined by blank lines.
            - ``search_text``: the Brave web-search context for the same input.
    """
    teapot_instance = CONFIG.get(server_name, CONFIG["Teapot AI"])
    print(f"Using Teapot instance for server: {server_name}")
    # Both calls are synchronous (an HTTP request and RAG retrieval). Run them in
    # worker threads so they do not stall the Discord event loop (heartbeats,
    # other messages) while they work.
    search_result = await asyncio.to_thread(brave_search_context, user_input)
    rag_results = await asyncio.to_thread(teapot_instance.rag, query=user_input)
    return "\n\n".join(rag_results), search_result
@client.event
async def on_ready():
    """Print the bot account's name once the Discord client has connected."""
    login_banner = f"Logged in as {client.user}"
    print(login_banner)
async def _get_replied_message(message):
    """
    Return the message that ``message`` replies to, or ``None``.

    Uses the copy Discord attached to the event (``reference.resolved``) when
    present, which avoids an API call. Otherwise the message is fetched. A
    deleted or inaccessible message yields ``None`` instead of an exception.
    """
    reference = message.reference
    if reference is None or reference.message_id is None:
        return None
    if isinstance(reference.resolved, discord.Message):
        return reference.resolved
    try:
        return await message.channel.fetch_message(reference.message_id)
    except discord.HTTPException as e:  # NotFound and Forbidden are subclasses
        print(f"Could not fetch replied-to message: {e}")
        return None


@client.event
async def on_message(message):
    """
    Answer a message with TeapotAI when it mentions the bot or replies to it.

    When the message replies to one of the bot's own messages, that message is
    prepended as ``"agent: <text>\\n"`` so the model sees the previous turn.

    The reply is handled as follows:
        - It is trimmed to Discord's 2000-character message limit.
        - An empty answer is replaced by a short fallback.
        - @everyone/@here and role pings are suppressed, because the text can
          echo user-controlled content.
    """
    max_reply_length = 2000  # Discord rejects longer message content.

    if message.author == client.user:
        return

    # Accept both the current <@id> and the legacy <@!id> user-mention forms.
    mention_tags = (f"<@{client.user.id}>", f"<@!{client.user.id}>")
    mentioned = any(tag in message.content for tag in mention_tags)

    # Check if the message is a reply to the bot
    replied_to_bot = False
    previous_message = ""
    if message.reference:
        replied_message = await _get_replied_message(message)
        if replied_message is not None and replied_message.author == client.user:
            replied_to_bot = True
            previous_message = "agent: " + replied_message.content + "\n"

    # If not mentioned and not replying to the bot, ignore
    if not (mentioned or replied_to_bot):
        return

    server_name = message.guild.name if message.guild else "Teapot AI"
    print(server_name, message.author, message.content)
    async with message.channel.typing():
        cleaned_message = message.content
        for tag in mention_tags:
            cleaned_message = cleaned_message.replace(tag, "")
        full_context = previous_message + cleaned_message.strip()
        response = await handle_teapot_inference(server_name, full_context)
        reply_text = "" if response is None else str(response)
        if not reply_text.strip():
            # Discord refuses to send an empty message.
            reply_text = "Sorry, I couldn't come up with an answer."
        await message.reply(
            reply_text[:max_reply_length],
            allowed_mentions=discord.AllowedMentions(everyone=False, roles=False),
        )
@client.event
async def on_reaction_add(reaction, user):
    """
    Post retrieval debug output when someone reacts ❓ or ❔ to a bot reply.

    The handler works as follows:
        - It re-runs RAG retrieval and the Brave web search for the message the
          bot was answering, with bot mentions stripped.
        - It posts the last 900 characters of each result in a thread on the
          bot's message.
        - If a thread cannot be created, the output goes to the same channel.
    """
    if user == client.user:
        return
    if str(reaction.emoji) not in ["❓", "❔"]:
        return
    message = reaction.message
    # Only bot messages that were sent as a reply qualify.
    if message.author != client.user or not message.reference:
        return

    def fence_safe(text):
        # Inside a ``` block Discord renders markdown escapes literally, so rather
        # than escaping, only break up ``` runs that would close the block early.
        return text.replace("```", "`\u200b``")

    mention_tags = (f"<@{client.user.id}>", f"<@!{client.user.id}>")
    cleaned_message = message.content.replace(f'<@{client.user.id}>', "").strip()

    # Fetch the original message that this bot message replied to
    try:
        original_message = await message.channel.fetch_message(message.reference.message_id)
    except discord.HTTPException as e:  # e.g. the question was deleted
        print(f"Could not fetch the original message: {e}")
        return
    # Strip bot mentions (as on_message does before querying) so retrieval runs
    # on text close to what the bot actually answered.
    user_input = original_message.content
    for tag in mention_tags:
        user_input = user_input.replace(tag, "")
    user_input = user_input.strip()
    server_name = message.guild.name if message.guild else "Teapot AI"

    # Create a thread or use existing one
    thread = message.thread
    if thread is None:
        try:
            thread = await message.create_thread(
                name=f"Debug Thread: '{cleaned_message[0:30]}...'",
                auto_archive_duration=60,
            )
        except (discord.HTTPException, ValueError) as e:
            # Threads are unavailable in some places (e.g. DMs); post in the channel.
            print(f"Could not create debug thread, posting in channel instead: {e}")
            thread = message.channel

    rag_result, search_result = await debug_teapot_inference(server_name, user_input)
    # The newline after each opening fence stops the first word of the output
    # from being taken as a code-block language tag.
    debug_response = (
        "## RAG:\n```\n" + fence_safe(rag_result)[-900:] + "```\n\n"
        + "## Search:\n```\n" + fence_safe(search_result)[-900:] + "```"
    )
    # Search results and user text are untrusted: never let them ping anyone.
    await thread.send(debug_response, allowed_mentions=discord.AllowedMentions.none())
# ========= STREAMLIT =========
@st.cache_resource
def discord_loop():
    """
    Run the Discord bot; ``client.run`` blocks until the bot shuts down.

    Wrapped in ``st.cache_resource``, presumably so that Streamlit script
    reruns do not start a second client. Confirm how concurrent sessions
    behave while the first call is still running.
    """
    st.session_state["initialized"] = True
    client.run(DISCORD_TOKEN)


# Render the page text before starting the bot: client.run blocks for the whole
# lifetime of the bot, so anything written after it would never be shown.
st.write("418 I'm a teapot")
if DISCORD_TOKEN:
    discord_loop()
else:
    st.error("The `discord_key` environment variable is not set, so the Discord bot was not started.")