"""TeapotAI Discord bot served from a Streamlit app.

This part of the module configures the Streamlit page, reads API keys from
environment variables, and defines the tool functions (Brave web search,
calculator, OpenWeatherMap weather) plus the pydantic schemas TeapotAI uses
to call them.
"""

import asyncio
import builtins
import math
import os
import re
from typing import Optional

import aiohttp
import discord
import pandas as pd
import requests
import streamlit as st
from pydantic import BaseModel, Field
from teapotai import TeapotAI, TeapotAISettings, TeapotTool
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, logging

st.set_page_config(page_title="TeapotAI Discord Bot", page_icon=":robot_face:", layout="wide")

DISCORD_TOKEN = os.environ.get("discord_key")

# ======= API KEYS =======
BRAVE_API_KEY = os.environ.get("brave_api_key")
WEATHER_API_KEY = os.environ.get("weather_api_key")

# Default seconds to wait on an external HTTP API. Without a timeout,
# requests.get can block a tool call (and the worker thread running it) forever.
HTTP_TIMEOUT_SECONDS = 10

# ======== TOOLS ===========


### SEARCH TOOL
class BraveWebSearch(BaseModel):
    search_query: str = Field(..., description="the search string to answer the question")


# Brave marks matched query terms with <strong>...</strong> inside result
# descriptions; strip the tags so the model receives plain text.
_BRAVE_HIGHLIGHT_TAGS = re.compile(r"</?strong>")


def brave_search_context(query, count: int = 3, timeout: float = HTTP_TIMEOUT_SECONDS) -> str:
    """Search the web with the Brave Search API and format the hits as text.

    Args:
        query: The search string, or an object with a ``search_query``
            attribute (e.g. a ``BraveWebSearch`` instance).
        count: Maximum number of results to request.
        timeout: Seconds to wait for the API before giving up.

    Returns:
        One "title\\nurl\\ndescription" paragraph per result, separated by
        blank lines; an empty string if the request or decoding fails.
    """
    # Other tools' fns read fields off their schema object (e.g. the
    # calculator reads ``.expression``), so tool dispatch presumably passes a
    # BraveWebSearch here, while debug_teapot_inference passes a plain
    # string. Accept both.
    query = getattr(query, "search_query", query)

    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": BRAVE_API_KEY}
    params = {"q": query, "count": count}
    try:
        response = requests.get(url, headers=headers, params=params, timeout=timeout)
    except requests.RequestException as e:
        print(f"Error: {e}")
        return ""

    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return ""

    try:
        results = response.json().get("web", {}).get("results", [])
    except ValueError as e:  # body was not valid JSON
        print(f"Error: {e}")
        return ""

    return "\n\n".join(
        res.get("title", "")
        + "\n"
        + res.get("url", "")
        + "\n"
        + _BRAVE_HIGHLIGHT_TAGS.sub("", res.get("description", ""))
        for res in results
    )


### CALCULATOR TOOL
def evaluate_expression(expr) -> str:
    """Evaluate a simple arithmetic expression string and describe the result.

    Supports numbers, ``+ - * / **``, parentheses and the helpers ``abs``,
    ``round``, ``sqrt`` and ``pow``. If the first attempt fails, every
    character other than digits, ``.``, operators, parentheses and spaces is
    stripped and evaluation is retried once.

    Returns:
        ``"<expression> = <result>"`` on success, otherwise
        ``"Sorry, I am unable to calculate that."``.
    """
    allowed_names = {k: getattr(builtins, k) for k in ("abs", "round")}
    allowed_names.update({k: getattr(math, k) for k in ("sqrt", "pow")})

    def safe_eval(expression):
        # NOTE(security): the expression is produced by the model from Discord
        # user input, and eval() with __builtins__=None is NOT a sandbox.
        # Dunder attribute chains such as ().__class__... can still reach
        # arbitrary objects. Arithmetic never needs an underscore, so refuse
        # any expression containing one (the numeric-only retry below still
        # runs). Very large powers (e.g. 9**9**9) can still hang the worker;
        # an AST-based evaluator would be needed to rule that out.
        if "_" in expression:
            raise ValueError("underscores are not allowed in expressions")
        return eval(expression, {"__builtins__": None}, allowed_names)

    try:
        return f"{expr} = {safe_eval(expr)}"
    except Exception as e:
        print(f"Initial evaluation failed: {e}")

    # Keep only digits, decimal points, + - * /, parentheses and spaces.
    cleaned_expr = re.sub(r"[^0-9.+\-*/() ]", "", expr)
    try:
        return f"{cleaned_expr} = {safe_eval(cleaned_expr)}"
    except Exception as e2:
        print(f"Retry also failed: {e2}")
        return "Sorry, I am unable to calculate that."


class Calculator(BaseModel):
    expression: str = Field(..., description="mathematical expression")


### Weather Tool
def get_weather(weather, timeout: float = HTTP_TIMEOUT_SECONDS) -> str:
    """Describe the current weather for a city via the OpenWeatherMap API.

    Args:
        weather: An object with a ``city_name`` attribute (a ``Weather`` model).
        timeout: Seconds to wait for the API before giving up.

    Returns:
        A sentence with the description and temperature in °F (units=imperial),
        or an error sentence if the request fails.
    """
    url = "https://api.openweathermap.org/data/2.5/weather"
    # Pass the query via params so requests URL-encodes the city name
    # (spaces, '&', non-ASCII characters).
    params = {"appid": WEATHER_API_KEY, "units": "imperial", "q": weather.city_name}
    try:
        response = requests.get(url, params=params, timeout=timeout)
    except requests.RequestException as e:
        print(e)
        return "City not found or there was an error with the request."

    if response.status_code != 200:
        print(response.status_code)
        return "City not found or there was an error with the request."

    data = response.json()
    city = data["name"]
    temperature = round(data["main"]["temp"])
    weather_description = data["weather"][0]["description"]
    return f"The weather in {city} is {weather_description} with a temperature of {temperature}°F."
class Weather(BaseModel):
    city_name: str = Field(..., description="The name of the city to pull the weather for")


### Stupid Question Tool
class CountNumberLetter(BaseModel):
    word: str = Field(..., description="the word to count the number of letters in")
    letter: str = Field(..., description="the letter to count the occurences of")


def count_number_letters(obj):
    """Answer "how many <letter>s are in <word>?" questions, case-insensitively.

    Args:
        obj: A ``CountNumberLetter``. A ``letter`` of "None" (any case) or an
            empty string means "count every character of the word".

    Returns:
        A sentence stating the count.
    """
    letter = obj.letter.lower()
    expression = obj.word.lower()
    # `letter` is already lowercased, so compare against the lowercase
    # sentinel (a check for "None" could never match here).
    if letter in ("none", ""):
        return f"There are {len(obj.word)} letters in '{expression}'"
    count = sum(1 for ch in expression if ch == letter)
    if count == 1:
        return f"There is 1 '{letter}' in '{expression}'"
    return f"There are {count} '{letter}'s in '{expression}'"


### Image Gen Tool
class ImageGen(BaseModel):
    prompt: str = Field(..., description="The prompt to use to generate the image")


def generate_image(prompt):
    """Placeholder image tool: always answers with a link to the Teapot logo.

    Args:
        prompt: An ``ImageGen`` instance; only ``prompt.prompt`` is read.
    """
    if "teapot" in prompt.prompt.lower():
        return "I generated an image of a teapot for you: https://teapotai.com/assets/teapotsmile.png"
    return (
        "Ok I can't generate images, but you could easily hook up an image gen model to this tool call. "
        "Check out this image I did generate for you: https://teapotai.com/assets/teapotsmile.png"
    )


### Tool Creation
DEFAULT_TOOLS = [
    TeapotTool(
        name="websearch",
        description="Execute web searches with pagination and filtering",
        schema=BraveWebSearch,
        fn=brave_search_context,
    ),
    TeapotTool(
        name="letter_counter",
        description="Can count how many times a letter occurs in a word.",
        schema=CountNumberLetter,
        fn=count_number_letters,
    ),
    TeapotTool(
        name="calculator",
        description="Can perform calculations on numbers using addition, subtraction, multiplication, and division.",
        schema=Calculator,
        fn=lambda expression: evaluate_expression(expression.expression),
    ),
    TeapotTool(
        name="generate_image",
        description="Can generate an image for a user based on a prompt",
        schema=ImageGen,
        fn=generate_image,
        directly_return_result=True,
    ),
    TeapotTool(
        name="weather",
        description="Can pull today's weather information for any city.",
        schema=Weather,
        fn=get_weather,
    ),
]


# ========= CONFIG =========
def _load_documents(csv_url):
    """Download a CSV with a ``content`` column and split it into RAG passages.

    Each cell is split on blank lines; missing cells and blank passages are
    dropped so only non-empty strings are handed to TeapotAI.
    """
    passages = pd.read_csv(csv_url)["content"].dropna().astype(str).str.split("\n\n").explode()
    return passages[passages.str.strip() != ""].reset_index(drop=True).to_list()


CONFIG = {
    # "OneTrainer": TeapotAI(
    #     documents=_load_documents("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=361556791&format=csv"),
    #     settings=TeapotAISettings(rag_num_results=7)
    # ),
    "Teapot AI": TeapotAI(
        model=AutoModelForSeq2SeqLM.from_pretrained(
            "teapotai/teapotllm",
            revision="5aa6f84b5bd59da85552d55cc00efb702869cbf8",
        ),
        documents=_load_documents(
            "https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=1617599323&format=csv"
        ),
        settings=TeapotAISettings(rag_num_results=3, log_level="debug"),
        tools=DEFAULT_TOOLS,
    ),
}

TEAPOT_SYSTEM_PROMPT = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization. You can use tools such as a web search, a calculator and an image generator to assist users."""

# ========= DISCORD CLIENT =========
# Discord rejects messages longer than this many characters.
DISCORD_MESSAGE_LIMIT = 2000

intents = discord.Intents.default()
intents.messages = True
# NOTE(review): in discord.py 2.x, message.content is only populated for DMs
# and messages that mention the bot unless the privileged
# `intents.message_content` is enabled (here and in the developer portal).
# Replies to the bot that do not ping it may arrive with empty content -- confirm.
client = discord.Client(intents=intents)


def _mention_tokens():
    """Return both raw text forms of a mention of the bot: <@id> and legacy <@!id>."""
    return (f"<@{client.user.id}>", f"<@!{client.user.id}>")


def _strip_mentions(text):
    """Remove mentions of the bot from ``text`` and trim surrounding whitespace."""
    for token in _mention_tokens():
        text = text.replace(token, "")
    return text.strip()


async def _fetch_referenced_message(message):
    """Return the message that ``message`` replies to, or None if unavailable.

    Uses the copy Discord attached to the reply when present, otherwise fetches
    it. Deleted or inaccessible messages yield None instead of raising.
    """
    reference = message.reference
    if reference is None or reference.message_id is None:
        return None
    if isinstance(reference.resolved, discord.Message):
        return reference.resolved
    try:
        return await message.channel.fetch_message(reference.message_id)
    except discord.HTTPException as e:
        print(f"Could not fetch referenced message: {e}")
        return None


async def handle_teapot_inference(server_name, user_input):
    """Answer ``user_input`` with the TeapotAI instance configured for the server.

    Servers without their own CONFIG entry use the "Teapot AI" instance.
    """
    teapot_instance = CONFIG.get(server_name, CONFIG["Teapot AI"])
    print(f"Using Teapot instance for server: {server_name}")
    # Run the blocking model call in a worker thread so the event loop stays responsive.
    return await asyncio.to_thread(
        teapot_instance.query, query=user_input, system_prompt=TEAPOT_SYSTEM_PROMPT
    )


async def debug_teapot_inference(server_name, user_input):
    """Return ``(rag_text, search_text)`` showing the context gathered for a query.

    ``rag_text`` joins the retrieved passages with blank lines; ``search_text``
    is the formatted Brave search result.
    """
    teapot_instance = CONFIG.get(server_name, CONFIG["Teapot AI"])
    print(f"Using Teapot instance for server: {server_name}")
    # Both calls block (HTTP request / retrieval), so keep them off the event loop.
    search_result = await asyncio.to_thread(brave_search_context, user_input)
    rag_results = await asyncio.to_thread(teapot_instance.rag, query=user_input)
    return "\n\n".join(rag_results), search_result


@client.event
async def on_ready():
    print(f"Logged in as {client.user}")


@client.event
async def on_message(message):
    """Answer messages that mention the bot or reply to one of its messages."""
    if message.author == client.user:
        return

    mentioned = any(token in message.content for token in _mention_tokens())

    # When the user replies to the bot, prepend the bot's earlier answer as context.
    replied_to_bot = False
    previous_message = ""
    replied_message = await _fetch_referenced_message(message)
    if replied_message is not None and replied_message.author == client.user:
        replied_to_bot = True
        previous_message = "agent: " + replied_message.content + "\n"

    if not (mentioned or replied_to_bot):
        return

    server_name = message.guild.name if message.guild else "Teapot AI"
    print(server_name, message.author, message.content)

    async with message.channel.typing():
        full_context = previous_message + _strip_mentions(message.content)
        response = await handle_teapot_inference(server_name, full_context)
        # Discord rejects empty messages and messages over the length limit.
        reply_text = str(response or "")[:DISCORD_MESSAGE_LIMIT] or "Sorry, I don't have an answer for that."
        await message.reply(reply_text)


@client.event
async def on_reaction_add(reaction, user):
    """Post RAG/search debug output when someone reacts with ❓ or ❔ to a bot reply."""
    if user == client.user:
        return
    if str(reaction.emoji) not in ("❓", "❔"):
        return

    message = reaction.message
    # Only bot messages that were themselves replies have a question to debug.
    if message.author != client.user or not message.reference:
        return

    original_message = await _fetch_referenced_message(message)
    if original_message is None:
        return
    user_input = _strip_mentions(original_message.content)
    server_name = message.guild.name if message.guild else "Teapot AI"

    if message.guild is None:
        # Threads cannot be created in DMs; post in the conversation itself.
        target = message.channel
    elif message.thread is not None:
        target = message.thread
    else:
        cleaned_message = _strip_mentions(message.content)
        target = await message.create_thread(
            name=f"Debug Thread: '{cleaned_message[0:30]}...'",
            auto_archive_duration=60,
        )

    rag_result, search_result = await debug_teapot_inference(server_name, user_input)
    # Keep only the tail of each section so the message stays under Discord's limit.
    rag_text = discord.utils.escape_markdown(rag_result)[-900:]
    search_text = discord.utils.escape_markdown(search_result)[-900:]
    debug_response = f"## RAG:\n```{rag_text}```\n\n## Search:\n```{search_text}```"
    await target.send(debug_response)


# ========= STREAMLIT =========
@st.cache_resource
def discord_loop():
    """Run the Discord client; blocks for as long as the bot stays connected.

    NOTE(review): because client.run() does not return while the bot is
    online, the cached value is never stored. Whether a rerun during that
    time waits or re-enters this function depends on Streamlit's cache
    locking -- confirm.
    """
    st.session_state["initialized"] = True
    client.run(DISCORD_TOKEN)


# Render the page before starting the bot: client.run() blocks, so anything
# placed after it in the script would never be shown.
st.write("418 I'm a teapot")

if DISCORD_TOKEN:
    discord_loop()
else:
    st.error("Set the discord_key environment variable to start the Discord bot.")