zakerytclarke commited on
Commit
00915d5
·
verified ·
1 Parent(s): 66ad49d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -27
app.py CHANGED
@@ -15,27 +15,26 @@ from pydantic import BaseModel
15
  from typing import List, Optional
16
  from tqdm import tqdm
17
  import re
 
18
 
19
  st.set_page_config(page_title="TeapotAI Discord Bot", page_icon=":robot_face:", layout="wide")
 
 
 
 
 
20
 
21
- # Load model only once
22
- if "tokenizer" not in st.session_state:
23
- model_name = "teapotai/teapotllm"
24
- st.session_state.tokenizer = AutoTokenizer.from_pretrained(model_name)
25
- st.session_state.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
26
-
27
- tokenizer = st.session_state.tokenizer
28
- model = st.session_state.model
29
 
30
  def log_time(func):
31
  async def wrapper(*args, **kwargs):
32
  start_time = time.time()
33
- result = await func(*args, **kwargs)
34
  end_time = time.time()
35
  print(f"{func.__name__} executed in {end_time - start_time:.4f} seconds")
36
  return result
37
  return wrapper
38
 
 
39
  API_KEY = os.environ.get("brave_api_key")
40
 
41
  @log_time
@@ -54,33 +53,45 @@ async def brave_search(query, count=1):
54
  print(f"Error: {response.status}, {await response.text()}")
55
  return []
56
 
 
57
  @traceable
58
  @log_time
59
  async def query_teapot(prompt, context, user_input):
60
  input_text = prompt + "\n" + context + "\n" + user_input
61
  print(input_text)
 
62
 
63
  inputs = tokenizer(input_text, return_tensors="pt")
64
  input_length = inputs["input_ids"].shape[1]
65
 
66
- output = await asyncio.to_thread(model.generate, **inputs, max_new_tokens=512)
 
67
 
68
  output_text = tokenizer.decode(output[0], skip_special_tokens=True)
69
- total_length = output.shape[1]
70
- output_length = total_length - input_length
 
 
 
 
 
71
 
72
  return output_text
73
 
 
74
  @log_time
75
  async def handle_chat(user_input):
 
76
  results = await brave_search(user_input)
 
77
 
78
  documents = [desc.replace('<strong>', '').replace('</strong>', '') for _, desc, _ in results]
 
79
  context = "\n".join(documents)
80
-
81
  prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization."""
82
-
83
  response = await query_teapot(prompt, context, user_input)
 
84
 
85
  debug_info = f"""
86
  Prompt:
@@ -89,51 +100,57 @@ Prompt:
89
  Context:
90
  {context}
91
 
 
 
92
  Response: {response}
93
  """
 
94
  return response, debug_info
95
 
 
96
  st.write("418 I'm a teapot")
97
 
98
  DISCORD_TOKEN = os.environ.get("discord_key")
99
 
100
- # Only start the bot once
101
- if "discord_bot_started" not in st.session_state:
102
- st.session_state.discord_bot_started = False
103
-
104
  # Create an instance of Intents and enable the required ones
105
- intents = discord.Intents.default()
106
- intents.messages = True
107
 
108
  # Create an instance of a client with the intents
109
  client = discord.Client(intents=intents)
110
 
 
111
  @client.event
112
  async def on_ready():
113
  print(f'Logged in as {client.user}')
114
 
 
115
  @client.event
116
  async def on_message(message):
 
117
  if message.author == client.user:
118
  return
119
 
 
120
  if f'<@{client.user.id}>' not in message.content:
121
  return
122
-
123
  print(message.content)
 
124
  is_debug = "debug:" in message.content
125
-
126
  async with message.channel.typing():
127
- cleaned_message = message.content.replace("debug:", "").replace(f'<@{client.user.id}>',"")
128
  response, debug_info = await handle_chat(cleaned_message)
129
  print(response)
130
  sent_message = await message.reply(response)
131
 
 
132
  if is_debug:
133
  thread = await sent_message.create_thread(name=f"""Debug Thread: '{cleaned_message}'""", auto_archive_duration=60)
 
 
134
  await thread.send(debug_info)
135
 
136
- # Start Discord bot only once
137
- if not st.session_state.discord_bot_started:
138
- st.session_state.discord_bot_started = True
139
- asyncio.run(client.start(DISCORD_TOKEN))
 
15
  from typing import List, Optional
16
  from tqdm import tqdm
17
  import re
18
+ import os
19
 
20
  st.set_page_config(page_title="TeapotAI Discord Bot", page_icon=":robot_face:", layout="wide")
21
+ tokenizer = None
22
+ model = None
23
+ model_name = "teapotai/teapotllm"
24
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
25
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
26
 
 
 
 
 
 
 
 
 
27
 
28
  def log_time(func):
29
  async def wrapper(*args, **kwargs):
30
  start_time = time.time()
31
+ result = await func(*args, **kwargs) # Make it awaitable
32
  end_time = time.time()
33
  print(f"{func.__name__} executed in {end_time - start_time:.4f} seconds")
34
  return result
35
  return wrapper
36
 
37
+
38
  API_KEY = os.environ.get("brave_api_key")
39
 
40
  @log_time
 
53
  print(f"Error: {response.status}, {await response.text()}")
54
  return []
55
 
56
+
57
  @traceable
58
  @log_time
59
  async def query_teapot(prompt, context, user_input):
60
  input_text = prompt + "\n" + context + "\n" + user_input
61
  print(input_text)
62
+ start_time = time.time()
63
 
64
  inputs = tokenizer(input_text, return_tensors="pt")
65
  input_length = inputs["input_ids"].shape[1]
66
 
67
+ # output = await asyncio.to_thread(model.generate, **inputs, max_new_tokens=512)
68
+ output = model.generate(inputs, max_new_tokens=512)
69
 
70
  output_text = tokenizer.decode(output[0], skip_special_tokens=True)
71
+ total_length = output.shape[1] # Includes both input and output tokens
72
+ output_length = total_length - input_length # Extract output token count
73
+
74
+ end_time = time.time()
75
+
76
+ elapsed_time = end_time - start_time
77
+ tokens_per_second = total_length / elapsed_time if elapsed_time > 0 else float("inf")
78
 
79
  return output_text
80
 
81
+
82
  @log_time
83
  async def handle_chat(user_input):
84
+ search_start_time = time.time()
85
  results = await brave_search(user_input)
86
+ search_end_time = time.time()
87
 
88
  documents = [desc.replace('<strong>', '').replace('</strong>', '') for _, desc, _ in results]
89
+
90
  context = "\n".join(documents)
 
91
  prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization."""
92
+ generation_start_time = time.time()
93
  response = await query_teapot(prompt, context, user_input)
94
+ generation_end_time = time.time()
95
 
96
  debug_info = f"""
97
  Prompt:
 
100
  Context:
101
  {context}
102
 
103
+ Search time: {search_end_time - search_start_time:.2f} seconds
104
+ Generation time: {generation_end_time - generation_start_time:.2f} seconds
105
  Response: {response}
106
  """
107
+
108
  return response, debug_info
109
 
110
+
111
  st.write("418 I'm a teapot")
112
 
113
  DISCORD_TOKEN = os.environ.get("discord_key")
114
 
 
 
 
 
115
  # Create an instance of Intents and enable the required ones
116
+ intents = discord.Intents.default() # Default intents enable basic functionality
117
+ intents.messages = True # Enable message-related events
118
 
119
  # Create an instance of a client with the intents
120
  client = discord.Client(intents=intents)
121
 
122
+ # Event when the bot has connected to the server
123
  @client.event
124
  async def on_ready():
125
  print(f'Logged in as {client.user}')
126
 
127
+ # Event when a message is received
128
  @client.event
129
  async def on_message(message):
130
+ # Check if the message is from the bot itself to prevent a loop
131
  if message.author == client.user:
132
  return
133
 
134
+ # Exit the function if the bot is not mentioned
135
  if f'<@{client.user.id}>' not in message.content:
136
  return
137
+
138
  print(message.content)
139
+
140
  is_debug = "debug:" in message.content
 
141
  async with message.channel.typing():
142
+ cleaned_message=message.content.replace("debug:", "").replace(f'<@{client.user.id}>',"")
143
  response, debug_info = await handle_chat(cleaned_message)
144
  print(response)
145
  sent_message = await message.reply(response)
146
 
147
+ # Create a thread from the sent message
148
  if is_debug:
149
  thread = await sent_message.create_thread(name=f"""Debug Thread: '{cleaned_message}'""", auto_archive_duration=60)
150
+
151
+ # Send a message in the created thread
152
  await thread.send(debug_info)
153
 
154
+
155
+ # Run the bot with your token
156
+ client.run(DISCORD_TOKEN)