zakerytclarke commited on
Commit
66ad49d
·
verified ·
1 Parent(s): a83ab40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -42
app.py CHANGED
@@ -15,26 +15,27 @@ from pydantic import BaseModel
15
  from typing import List, Optional
16
  from tqdm import tqdm
17
  import re
18
- import os
19
 
20
  st.set_page_config(page_title="TeapotAI Discord Bot", page_icon=":robot_face:", layout="wide")
21
- tokenizer = None
22
- model = None
23
- model_name = "teapotai/teapotllm"
24
- tokenizer = AutoTokenizer.from_pretrained(model_name)
25
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
26
 
 
 
 
 
 
 
 
 
27
 
28
  def log_time(func):
29
  async def wrapper(*args, **kwargs):
30
  start_time = time.time()
31
- result = await func(*args, **kwargs) # Make it awaitable
32
  end_time = time.time()
33
  print(f"{func.__name__} executed in {end_time - start_time:.4f} seconds")
34
  return result
35
  return wrapper
36
 
37
-
38
  API_KEY = os.environ.get("brave_api_key")
39
 
40
  @log_time
@@ -53,13 +54,11 @@ async def brave_search(query, count=1):
53
  print(f"Error: {response.status}, {await response.text()}")
54
  return []
55
 
56
-
57
  @traceable
58
  @log_time
59
  async def query_teapot(prompt, context, user_input):
60
  input_text = prompt + "\n" + context + "\n" + user_input
61
  print(input_text)
62
- start_time = time.time()
63
 
64
  inputs = tokenizer(input_text, return_tensors="pt")
65
  input_length = inputs["input_ids"].shape[1]
@@ -67,30 +66,21 @@ async def query_teapot(prompt, context, user_input):
67
  output = await asyncio.to_thread(model.generate, **inputs, max_new_tokens=512)
68
 
69
  output_text = tokenizer.decode(output[0], skip_special_tokens=True)
70
- total_length = output.shape[1] # Includes both input and output tokens
71
- output_length = total_length - input_length # Extract output token count
72
-
73
- end_time = time.time()
74
-
75
- elapsed_time = end_time - start_time
76
- tokens_per_second = total_length / elapsed_time if elapsed_time > 0 else float("inf")
77
 
78
  return output_text
79
 
80
-
81
  @log_time
82
  async def handle_chat(user_input):
83
- search_start_time = time.time()
84
  results = await brave_search(user_input)
85
- search_end_time = time.time()
86
 
87
  documents = [desc.replace('<strong>', '').replace('</strong>', '') for _, desc, _ in results]
88
-
89
  context = "\n".join(documents)
 
90
  prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization."""
91
- generation_start_time = time.time()
92
  response = await query_teapot(prompt, context, user_input)
93
- generation_end_time = time.time()
94
 
95
  debug_info = f"""
96
  Prompt:
@@ -99,57 +89,51 @@ Prompt:
99
  Context:
100
  {context}
101
 
102
- Search time: {search_end_time - search_start_time:.2f} seconds
103
- Generation time: {generation_end_time - generation_start_time:.2f} seconds
104
  Response: {response}
105
  """
106
-
107
  return response, debug_info
108
 
109
-
110
  st.write("418 I'm a teapot")
111
 
112
  DISCORD_TOKEN = os.environ.get("discord_key")
113
 
 
 
 
 
114
  # Create an instance of Intents and enable the required ones
115
- intents = discord.Intents.default() # Default intents enable basic functionality
116
- intents.messages = True # Enable message-related events
117
 
118
  # Create an instance of a client with the intents
119
  client = discord.Client(intents=intents)
120
 
121
- # Event when the bot has connected to the server
122
  @client.event
123
  async def on_ready():
124
  print(f'Logged in as {client.user}')
125
 
126
- # Event when a message is received
127
  @client.event
128
  async def on_message(message):
129
- # Check if the message is from the bot itself to prevent a loop
130
  if message.author == client.user:
131
  return
132
 
133
- # Exit the function if the bot is not mentioned
134
  if f'<@{client.user.id}>' not in message.content:
135
  return
136
-
137
- print(message.content)
138
 
 
139
  is_debug = "debug:" in message.content
 
140
  async with message.channel.typing():
141
- cleaned_message=message.content.replace("debug:", "").replace(f'<@{client.user.id}>',"")
142
  response, debug_info = await handle_chat(cleaned_message)
143
  print(response)
144
  sent_message = await message.reply(response)
145
 
146
- # Create a thread from the sent message
147
  if is_debug:
148
  thread = await sent_message.create_thread(name=f"""Debug Thread: '{cleaned_message}'""", auto_archive_duration=60)
149
-
150
- # Send a message in the created thread
151
  await thread.send(debug_info)
152
 
153
-
154
- # Run the bot with your token
155
- client.run(DISCORD_TOKEN)
 
 
15
  from typing import List, Optional
16
  from tqdm import tqdm
17
  import re
 
18
 
19
  st.set_page_config(page_title="TeapotAI Discord Bot", page_icon=":robot_face:", layout="wide")
 
 
 
 
 
20
 
21
+ # Load model only once
22
+ if "tokenizer" not in st.session_state:
23
+ model_name = "teapotai/teapotllm"
24
+ st.session_state.tokenizer = AutoTokenizer.from_pretrained(model_name)
25
+ st.session_state.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
26
+
27
+ tokenizer = st.session_state.tokenizer
28
+ model = st.session_state.model
29
 
30
  def log_time(func):
31
  async def wrapper(*args, **kwargs):
32
  start_time = time.time()
33
+ result = await func(*args, **kwargs)
34
  end_time = time.time()
35
  print(f"{func.__name__} executed in {end_time - start_time:.4f} seconds")
36
  return result
37
  return wrapper
38
 
 
39
  API_KEY = os.environ.get("brave_api_key")
40
 
41
  @log_time
 
54
  print(f"Error: {response.status}, {await response.text()}")
55
  return []
56
 
 
57
  @traceable
58
  @log_time
59
  async def query_teapot(prompt, context, user_input):
60
  input_text = prompt + "\n" + context + "\n" + user_input
61
  print(input_text)
 
62
 
63
  inputs = tokenizer(input_text, return_tensors="pt")
64
  input_length = inputs["input_ids"].shape[1]
 
66
  output = await asyncio.to_thread(model.generate, **inputs, max_new_tokens=512)
67
 
68
  output_text = tokenizer.decode(output[0], skip_special_tokens=True)
69
+ total_length = output.shape[1]
70
+ output_length = total_length - input_length
 
 
 
 
 
71
 
72
  return output_text
73
 
 
74
  @log_time
75
  async def handle_chat(user_input):
 
76
  results = await brave_search(user_input)
 
77
 
78
  documents = [desc.replace('<strong>', '').replace('</strong>', '') for _, desc, _ in results]
 
79
  context = "\n".join(documents)
80
+
81
  prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization."""
82
+
83
  response = await query_teapot(prompt, context, user_input)
 
84
 
85
  debug_info = f"""
86
  Prompt:
 
89
  Context:
90
  {context}
91
 
 
 
92
  Response: {response}
93
  """
 
94
  return response, debug_info
95
 
 
96
  st.write("418 I'm a teapot")
97
 
98
  DISCORD_TOKEN = os.environ.get("discord_key")
99
 
100
+ # Only start the bot once
101
+ if "discord_bot_started" not in st.session_state:
102
+ st.session_state.discord_bot_started = False
103
+
104
  # Create an instance of Intents and enable the required ones
105
+ intents = discord.Intents.default()
106
+ intents.messages = True
107
 
108
  # Create an instance of a client with the intents
109
  client = discord.Client(intents=intents)
110
 
 
111
  @client.event
112
  async def on_ready():
113
  print(f'Logged in as {client.user}')
114
 
 
115
  @client.event
116
  async def on_message(message):
 
117
  if message.author == client.user:
118
  return
119
 
 
120
  if f'<@{client.user.id}>' not in message.content:
121
  return
 
 
122
 
123
+ print(message.content)
124
  is_debug = "debug:" in message.content
125
+
126
  async with message.channel.typing():
127
+ cleaned_message = message.content.replace("debug:", "").replace(f'<@{client.user.id}>',"")
128
  response, debug_info = await handle_chat(cleaned_message)
129
  print(response)
130
  sent_message = await message.reply(response)
131
 
 
132
  if is_debug:
133
  thread = await sent_message.create_thread(name=f"""Debug Thread: '{cleaned_message}'""", auto_archive_duration=60)
 
 
134
  await thread.send(debug_info)
135
 
136
+ # Start Discord bot only once
137
+ if not st.session_state.discord_bot_started:
138
+ st.session_state.discord_bot_started = True
139
+ asyncio.run(client.start(DISCORD_TOKEN))