|
import streamlit as st |
|
import os |
|
import aiohttp |
|
import asyncio |
|
import discord |
|
import pandas as pd |
|
import requests |
|
from teapotai import TeapotAI, TeapotAISettings |
|
from pydantic import BaseModel, Field |
|
|
|
# Streamlit page metadata; kept as the first Streamlit command in the script.
st.set_page_config(
    layout="wide",
    page_icon=":robot_face:",
    page_title="TeapotAI Discord Bot",
)
|
|
|
# Credentials read from environment variables; each is None when its variable is unset.
DISCORD_TOKEN, BRAVE_API_KEY, WEATHER_API_KEY = (
    os.environ.get(variable)
    for variable in ("discord_key", "brave_api_key", "weather_api_key")
)
|
|
|
|
|
import requests |
|
from typing import Optional |
|
from teapotai import TeapotTool |
|
import re |
|
import math |
|
import pandas as pd |
|
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, logging |
|
|
|
|
|
class BraveWebSearch(BaseModel):
    # Argument schema for the "websearch" tool registered in DEFAULT_TOOLS.
    search_query: str = Field(
        default=...,
        description="the search string to answer the question",
    )
|
|
|
def brave_search_context(query, count=3):
    """Run a Brave web search and return the top results as plain text.

    Args:
        query: The search text. This function is also registered directly as
            the "websearch" TeapotTool callback, and the other tool callbacks in
            this file receive their schema object rather than a string, so an
            object with a ``search_query`` attribute is unwrapped first. Plain
            strings (as passed by ``debug_teapot_inference``) work unchanged.
        count: Maximum number of results to request from the API.

    Returns:
        One entry per result, separated by a blank line. Each entry holds the
        title, URL and description on separate lines, with Brave's
        ``<strong>`` highlight tags removed. Returns "" when the request fails
        or the API answers with a non-200 status.
    """
    query = getattr(query, "search_query", query)

    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": BRAVE_API_KEY}
    params = {"q": query, "count": count}

    try:
        # Without a timeout a stalled connection would block the caller forever.
        response = requests.get(url, headers=headers, params=params, timeout=10)
    except requests.RequestException as exc:
        # Same fallback as an HTTP error: log and return no context.
        print(f"Error: {exc}")
        return ""

    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return ""

    results = response.json().get("web", {}).get("results", [])
    entries = []
    for res in results:
        # Individual results may omit fields; use "" instead of raising KeyError.
        description = res.get("description", "").replace("<strong>", "").replace("</strong>", "")
        entries.append(res.get("title", "") + "\n" + res.get("url", "") + "\n" + description)
    return "\n\n".join(entries)
|
|
|
|
|
import builtins |
|
|
|
def evaluate_expression(expr) -> str:
    """Evaluate a simple arithmetic expression string without using ``eval``.

    Supported syntax: int/float literals, binary ``+ - * / // % **``, unary
    ``+``/``-``, parentheses, and calls to ``abs`` and ``round`` (builtins)
    and ``sqrt`` and ``pow`` (from :mod:`math`).

    The text reaches this function from Discord users via the calculator
    tool. It is therefore parsed with :mod:`ast` and only the node types
    above are evaluated. ``eval`` with ``__builtins__`` set to ``None`` is not
    a sandbox, because attribute chains such as ``().__class__.__bases__``
    escape it; attribute access, subscripts and unknown names are rejected
    here. Exponents larger than 10,000 in magnitude are rejected as well, so
    input like ``9**9**9`` cannot tie up the CPU.

    If the first attempt fails, every character other than digits, ``.``,
    ``+ - * /``, parentheses and spaces is stripped and evaluation is retried
    once.

    Args:
        expr: The expression text.

    Returns:
        ``"<expression> = <result>"`` for whichever attempt succeeded (the
        cleaned text is shown for the retry), or
        ``"Sorry, I am unable to calculate that."`` if both attempts fail.
    """
    import ast
    import operator

    allowed_names = {k: getattr(builtins, k) for k in ("abs", "round")}
    allowed_names.update({k: getattr(math, k) for k in ("sqrt", "pow")})

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    max_exponent = 10_000

    def evaluate_node(node):
        # Walk the parsed tree, evaluating only whitelisted node types.
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate_node(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = evaluate_node(node.left)
            right = evaluate_node(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > max_exponent:
                raise ValueError(f"exponent {right} is too large")
            return binary_ops[type(node.op)](left, right)
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id in allowed_names
        ):
            if any(kw.arg is None for kw in node.keywords):
                raise ValueError("**kwargs are not supported")
            args = [evaluate_node(arg) for arg in node.args]
            kwargs = {kw.arg: evaluate_node(kw.value) for kw in node.keywords}
            return allowed_names[node.func.id](*args, **kwargs)
        raise ValueError(f"unsupported syntax: {type(node).__name__}")

    def safe_eval(expression):
        # eval() tolerated leading spaces; ast.parse raises IndentationError on them.
        tree = ast.parse(expression.strip(), mode="eval")
        return evaluate_node(tree.body)

    try:
        result = safe_eval(expr)
        return f"{expr} = {result}"
    except Exception as e:
        print(f"Initial evaluation failed: {e}")

    # Same character class as before: digits, '.', + - * /, parentheses, space.
    cleaned_expr = re.sub(r"[^0-9.+\-*/() ]", "", expr)
    try:
        result = safe_eval(cleaned_expr)
        return f"{cleaned_expr} = {result}"
    except Exception as e2:
        print(f"Retry also failed: {e2}")
        return "Sorry, I am unable to calculate that."
|
|
|
|
|
class Calculator(BaseModel):
    # Argument schema for the "calculator" tool registered in DEFAULT_TOOLS.
    expression: str = Field(
        default=...,
        description="mathematical expression",
    )
|
|
|
|
|
|
|
def get_weather(weather):
    """Describe the current weather for ``weather.city_name`` via OpenWeatherMap.

    Args:
        weather: Object with a ``city_name`` attribute (the ``Weather`` schema
            when called as a tool).

    Returns:
        A sentence with the city, the weather description and the rounded
        temperature in °F (the request asks for imperial units). On a network
        failure or a non-200 response, a fixed error sentence is returned.
    """
    url = "https://api.openweathermap.org/data/2.5/weather"
    # Passing the values via ``params`` lets requests URL-encode them, so city
    # names containing spaces, '&' or '#' cannot corrupt the query string.
    params = {"appid": WEATHER_API_KEY, "units": "imperial", "q": weather.city_name}

    try:
        # Bound the wait so a stalled request cannot hang the tool call.
        response = requests.get(url, params=params, timeout=10)
    except requests.RequestException as exc:
        print(exc)
        return "City not found or there was an error with the request."

    if response.status_code != 200:
        print(response.status_code)
        return "City not found or there was an error with the request."

    data = response.json()
    city = data['name']
    temperature = round(data['main']['temp'])
    weather_description = data['weather'][0]['description']
    return f"The weather in {city} is {weather_description} with a temperature of {temperature}°F."
|
|
|
class Weather(BaseModel):
    # Argument schema for the "weather" tool registered in DEFAULT_TOOLS.
    city_name: str = Field(
        default=...,
        description="The name of the city to pull the weather for",
    )
|
|
|
|
|
class CountNumberLetter(BaseModel):
    # Argument schema for the "letter_counter" tool registered in DEFAULT_TOOLS.
    word: str = Field(
        default=...,
        description="the word to count the number of letters in",
    )
    letter: str = Field(
        default=...,
        description="the letter to count the occurences of",
    )
|
|
|
def count_number_letters(obj):
    """Count a letter in a word, or count all of the word's characters.

    Args:
        obj: Object with ``word`` and ``letter`` attributes (the
            ``CountNumberLetter`` schema when called as a tool). A ``letter`` of
            "None" (any case) or "" means "how many letters are in the word".

    Returns:
        A sentence with the count. Matching is case-insensitive, and the word
        is echoed in lower case.
    """
    letter = obj.letter.lower()
    expression = obj.word.lower()
    # ``letter`` is already lower-cased, so compare against "none". The old
    # comparison with "None" could never match and this branch was dead.
    if letter in ("none", ""):
        return f"There are {len(obj.word)} letters in '{expression}'"
    count = sum(1 for ch in expression if ch == letter)
    if count == 1:
        return f"There is 1 '{letter}' in '{expression}'"
    return f"There are {count} '{letter}'s in '{expression}'"
|
|
|
|
|
class ImageGen(BaseModel):
    # Argument schema for the "generate_image" tool registered in DEFAULT_TOOLS.
    prompt: str = Field(
        default=...,
        description="The prompt to use to generate the image",
    )
|
|
|
def generate_image(prompt):
    """Stand-in for the "generate_image" tool; no image model is called.

    The same teapot image URL is always returned. Only the wording changes,
    depending on whether the prompt mentions a teapot (case-insensitive).

    Args:
        prompt: Object whose ``prompt`` attribute holds the requested image text.

    Returns:
        A reply string containing the image URL.
    """
    image_url = "https://teapotai.com/assets/teapotsmile.png"
    if "teapot" not in prompt.prompt.lower():
        return (
            "Ok I can't generate images, but you could easily hook up an image gen model "
            f"to this tool call. Check out this image I did generate for you: {image_url}"
        )
    return f"I generated an image of a teapot for you: {image_url}"
|
|
|
|
|
def _run_calculator(expression):
    """Unwrap the ``Calculator`` arguments and evaluate the expression text."""
    return evaluate_expression(expression.expression)


# Tools passed to the TeapotAI instance in CONFIG. Each pairs a pydantic
# argument schema with the callback that executes the tool.
DEFAULT_TOOLS = [
    TeapotTool(
        name="websearch",
        schema=BraveWebSearch,
        fn=brave_search_context,
        description="Execute web searches with pagination and filtering",
    ),
    TeapotTool(
        name="letter_counter",
        schema=CountNumberLetter,
        fn=count_number_letters,
        description="Can count how many times a letter occurs in a word.",
    ),
    TeapotTool(
        name="calculator",
        schema=Calculator,
        fn=_run_calculator,
        description="Can perform calculations on numbers using addition, subtraction, multiplication, and division.",
    ),
    TeapotTool(
        name="generate_image",
        schema=ImageGen,
        fn=generate_image,
        description="Can generate an image for a user based on a prompt",
        # The tool's text is used as the reply as-is.
        directly_return_result=True,
    ),
    TeapotTool(
        name="weather",
        schema=Weather,
        fn=get_weather,
        description="Can pull today's weather information for any city.",
    ),
]
|
|
|
|
|
# CSV export of the Google Sheet whose ``content`` column supplies the RAG documents.
DOCUMENTS_CSV_URL = "https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=1617599323&format=csv"


def _load_documents(csv_url):
    """Return the non-blank paragraphs from the ``content`` column of the CSV at *csv_url*.

    Every cell is split on blank lines so each paragraph becomes its own
    retrieval document; order is preserved. Empty cells (read by pandas as
    NaN) and whitespace-only paragraphs are dropped, so only real text
    reaches TeapotAI. Without the filtering, a blank cell became a float
    ``nan`` entry and a trailing blank line became an empty document.
    """
    paragraphs = (
        pd.read_csv(csv_url)["content"]
        .dropna()
        .astype(str)
        .str.split("\n\n")
        .explode()
    )
    return paragraphs[paragraphs.str.strip() != ""].reset_index(drop=True).to_list()


# Assistants keyed by Discord server name; handle_teapot_inference falls back to
# the "Teapot AI" entry for servers without their own configuration.
CONFIG = {
    "Teapot AI": TeapotAI(
        # Model weights pinned to a specific Hugging Face revision.
        model=AutoModelForSeq2SeqLM.from_pretrained(
            "teapotai/teapotllm",
            revision="5aa6f84b5bd59da85552d55cc00efb702869cbf8",
        ),
        documents=_load_documents(DOCUMENTS_CSV_URL),
        settings=TeapotAISettings(rag_num_results=3, log_level="debug"),
        tools=DEFAULT_TOOLS,
    ),
}
|
|
|
|
|
|
|
|
|
def _build_intents():
    """Return discord.py's default gateway intents with message events switched on."""
    flags = discord.Intents.default()
    flags.messages = True
    return flags


# NOTE(review): the privileged ``message_content`` intent is not requested here;
# Discord may withhold message text for messages that do not mention the bot.
# Confirm this is intended for reply-only conversations.
intents = _build_intents()
client = discord.Client(intents=intents)
|
|
|
async def handle_teapot_inference(server_name, user_input):
    """Answer *user_input* with the TeapotAI instance configured for *server_name*.

    Servers without their own CONFIG entry share the "Teapot AI" instance.
    The synchronous ``query`` call runs in a worker thread via
    ``asyncio.to_thread`` so the Discord event loop is not blocked.

    Returns:
        Whatever ``query`` returns for the prompt.
    """
    default_instance = CONFIG["Teapot AI"]
    teapot_instance = CONFIG.get(server_name, default_instance)
    print(f"Using Teapot instance for server: {server_name}")

    system_prompt = (
        "You are Teapot, an open-source AI assistant optimized for low-end devices, "
        "providing short, accurate responses without hallucinating while excelling at "
        "information extraction and text summarization. You can use tools such as a web "
        "search, a calculator and an image generator to assist users."
    )
    return await asyncio.to_thread(
        teapot_instance.query,
        query=user_input,
        system_prompt=system_prompt,
    )
|
|
|
async def debug_teapot_inference(server_name, user_input):
    """Collect the raw context behind an answer, for the ❓/❔ debug reaction.

    Args:
        server_name: Discord server name used to pick the CONFIG entry
            (falls back to "Teapot AI").
        user_input: The question text to look up.

    Returns:
        ``(rag_text, search_text)``. ``rag_text`` is the instance's ``rag``
        results for *user_input*, joined by blank lines. ``search_text`` is the
        Brave search context for the same text.
    """
    teapot_instance = CONFIG.get(server_name, CONFIG["Teapot AI"])
    print(f"Using Teapot instance for server: {server_name}")

    # Both calls are synchronous (an HTTP request and TeapotAI's rag lookup).
    # Calling them directly inside this coroutine stalled the Discord event
    # loop, so run them in worker threads, as handle_teapot_inference does.
    search_result, rag_results = await asyncio.gather(
        asyncio.to_thread(brave_search_context, user_input),
        asyncio.to_thread(teapot_instance.rag, query=user_input),
    )
    return "\n\n".join(rag_results), search_result
|
|
|
@client.event
async def on_ready():
    """Log which bot account is connected once the Discord client is ready."""
    print("Logged in as", client.user)
|
|
|
# Discord rejects message content longer than this many characters.
DISCORD_MESSAGE_LIMIT = 2000


async def _fetch_replied_bot_message(message):
    """Return the message that *message* replies to if the bot wrote it, else ``None``.

    The referenced message already attached as ``reference.resolved`` is used
    when present; the API is queried only otherwise. A referenced message that
    was deleted or cannot be fetched counts as "not a reply to the bot", so the
    handler no longer aborts with NotFound/Forbidden.
    """
    reference = message.reference
    if reference is None or reference.message_id is None:
        return None

    replied_message = reference.resolved
    if not isinstance(replied_message, discord.Message):
        try:
            replied_message = await message.channel.fetch_message(reference.message_id)
        except discord.HTTPException as exc:
            print(f"Could not fetch replied-to message {reference.message_id}: {exc}")
            return None

    return replied_message if replied_message.author == client.user else None


@client.event
async def on_message(message):
    """Answer messages that mention the bot or reply to one of its messages.

    When the user replies to a bot message, that message is prepended as
    ``"agent: <text>"`` so the model sees the previous turn. Messages sent by
    the bot itself are ignored.
    """
    if message.author == client.user:
        return

    # Match both mention forms: <@id> and the legacy nickname form <@!id>.
    mention_pattern = re.compile(rf"<@!?{client.user.id}>")
    mentioned = mention_pattern.search(message.content) is not None

    replied_message = await _fetch_replied_bot_message(message)
    previous_message = ""
    if replied_message is not None:
        previous_message = "agent: " + replied_message.content + "\n"

    if not (mentioned or replied_message is not None):
        return

    server_name = message.guild.name if message.guild else "Teapot AI"
    print(server_name, message.author, message.content)

    async with message.channel.typing():
        cleaned_message = mention_pattern.sub("", message.content).strip()
        full_context = previous_message + cleaned_message
        response = await handle_teapot_inference(server_name, full_context)

        # Discord rejects empty messages and content over the length limit.
        reply_text = str(response or "")[:DISCORD_MESSAGE_LIMIT]
        if not reply_text.strip():
            print("Teapot returned an empty response; not replying.")
            return
        await message.reply(reply_text)
|
|
|
|
|
|
|
@client.event
async def on_reaction_add(reaction, user):
    """Post RAG and web-search debug output when a user reacts ❓ or ❔ to a bot reply.

    The output is sent to the thread started from the bot's reply, which is
    created if it does not exist yet. It is computed for the question the bot
    was answering, i.e. the message the reply references.
    """
    if user == client.user:
        return
    if str(reaction.emoji) not in ["❓", "❔"]:
        return

    message = reaction.message
    # Only bot messages that reply to another message can be debugged.
    if message.author != client.user or not message.reference:
        return

    mention_pattern = re.compile(rf"<@!?{client.user.id}>")
    cleaned_message = mention_pattern.sub("", message.content).strip()

    # Prefer the referenced message discord.py already attached; fetch it only
    # if missing, and give up quietly if it was deleted or is inaccessible.
    original_message = message.reference.resolved
    if not isinstance(original_message, discord.Message):
        try:
            original_message = await message.channel.fetch_message(message.reference.message_id)
        except discord.HTTPException as exc:
            print(f"Could not fetch the message being debugged: {exc}")
            return
    # Remove bot mention markup so it does not pollute the search and RAG queries.
    user_input = mention_pattern.sub("", original_message.content).strip()

    server_name = message.guild.name if message.guild else "Teapot AI"

    # Reuse the thread already started from this message, if any.
    thread = message.thread
    if thread is None:
        thread = await message.create_thread(
            name=f"Debug Thread: '{cleaned_message[0:30]}...'",
            auto_archive_duration=60,
        )

    rag_result, search_result = await debug_teapot_inference(server_name, user_input)
    # Keep only the last 900 characters of each section so the whole message
    # stays under Discord's 2000-character limit.
    rag_excerpt = discord.utils.escape_markdown(rag_result)[-900:]
    search_excerpt = discord.utils.escape_markdown(search_result)[-900:]
    debug_response = (
        "## RAG:\n```" + rag_excerpt + "```\n\n## Search:\n```" + search_excerpt + "```"
    )
    await thread.send(debug_response)
|
|
|
|
|
|
|
|
|
@st.cache_resource
def discord_loop():
    """Start the Discord bot once per Streamlit server process and show a status line.

    ``client.run`` blocks until the bot disconnects. Calling it directly here
    kept the Streamlit script run from ever finishing, so the status line
    below was unreachable while the bot was online. The bot now runs in a
    daemon thread, and ``st.cache_resource`` is meant to make sure only one
    such thread is started across reruns and sessions.

    Raises:
        RuntimeError: if the ``discord_key`` environment variable is not set.
    """
    import threading

    if not DISCORD_TOKEN:
        # Surface the problem in the page instead of inside a background thread.
        raise RuntimeError("The discord_key environment variable is not set.")

    st.session_state["initialized"] = True
    bot_thread = threading.Thread(
        target=client.run,
        args=(DISCORD_TOKEN,),
        name="discord-bot",
        daemon=True,
    )
    bot_thread.start()
    st.write("418 I'm a teapot")
    return


discord_loop()
|
|