boning123 committed
Commit f040238 · verified
1 Parent(s): da33fb2

Update app.py

Files changed (1)
  1. app.py +99 -120
app.py CHANGED
@@ -8,7 +8,7 @@ from transformers import (
  import torch
  import gradio as gr
  import re
- from duckduckgo_search import DDGS
+ from duckduckgo_search import DDGS # Import DuckDuckGo Search

  # Dictionary of available models
  AVAILABLE_MODELS = {
@@ -24,6 +24,7 @@ tokenizer = None
  model = None

  # Initialize tokenizer and model globally for the first run
+ # This ensures they are loaded when the script starts
  print(f"Initializing model: {DEFAULT_MODEL}...")
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
  model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype="auto", device_map="auto")
@@ -99,18 +100,13 @@ def load_model(selected_model_name):
      return f"✅ Loaded model: {selected_model_name}"

  def respond(message, chat_history, use_reasoning):
-     # Start by appending the user's message and an empty string for the bot's response
-     # This creates the new chat bubble for the user's input.
-     chat_history.append([message, ""])
-     yield chat_history, chat_history # Yield immediately to show the user's message
-
      # Gradio's chat_history is a list of [user_message, bot_message] pairs.
      # We need to convert it to the format expected by the model's chat template.
      messages_for_template = [{"role": "system", "content": SYSTEM_PROMPT}]
-     for user_msg, bot_msg in chat_history[:-1]: # Exclude the current, incomplete turn
-         if user_msg:
+     for user_msg, bot_msg in chat_history:
+         if user_msg: # Only add if not empty
              messages_for_template.append({"role": "user", "content": user_msg})
-         if bot_msg:
+         if bot_msg: # Only add if not empty
              messages_for_template.append({"role": "assistant", "content": bot_msg})

      # Add the current user message with the appropriate prefix
@@ -122,10 +118,17 @@ def respond(message, chat_history, use_reasoning):
      inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

      # --- First Generation Pass (for potential search query) ---
-     generated_text_buffer = ""
+     full_response_parts = []
+     current_chat_history_for_yield = list(chat_history) # Create a copy for yielding
+
+     # Ensure streamer is correctly set up for this generation
+     # For Gradio streaming, we need to manually yield tokens.
+     # TextStreamer is for console output; here we'll collect and yield.
      with torch.no_grad():
-         # Iterate over tokens from the model
-         for output_ids in model.generate(
+         # Use an iterable for token-by-token generation for Gradio
+         # This is a common pattern for streaming outputs in Gradio
+         # We will collect the full response to check for search tags.
+         generated_ids = model.generate(
              input_ids=inputs["input_ids"],
              attention_mask=inputs.get("attention_mask"),
              max_new_tokens=gen_kwargs["max_new_tokens"],
@@ -137,67 +140,57 @@ def respond(message, chat_history, use_reasoning):
              eos_token_id=gen_kwargs["eos_token_id"],
              pad_token_id=gen_kwargs["pad_token_id"],
              stopping_criteria=stop_criteria,
-             return_dict_in_generate=True,
-             output_scores=True,
-             # We don't use streamer directly here for Gradio output,
-             # but for internal token generation and decoding.
-         ).sequences.tolist():
-             # Decode the newly generated token
-             new_token_text = tokenizer.decode(output_ids[len(inputs["input_ids"][0]):], skip_special_tokens=True)
-             generated_text_buffer += new_token_text
-
-             # Update the last assistant message in chat_history
-             chat_history[-1][1] = generated_text_buffer
-             yield chat_history, chat_history # Yield to update the Gradio UI
-
-             # Check for early stop condition if the full response isn't needed yet
-             if "<search>" in generated_text_buffer and "</search>" in generated_text_buffer:
-                 break # Stop generating if a full search tag is found
-
-     # After the first pass, check for search query in the full generated buffer
-     model_response_content_first_pass = generated_text_buffer.strip()
-
-     # Try to extract content after 'content:' if in thinking mode
-     if use_reasoning:
-         try:
-             thinking_content_start_index = model_response_content_first_pass.find("thinking content:")
-             if thinking_content_start_index != -1:
-                 content_start_index = model_response_content_first_pass.rindex("content:", thinking_content_start_index)
-                 response_to_process = model_response_content_first_pass[content_start_index + len("content:"):].strip()
-             else:
-                 response_to_process = model_response_content_first_pass
-         except ValueError:
-             response_to_process = model_response_content_first_pass
-     else:
-         response_to_process = model_response_content_first_pass
-
-     search_match = re.search(r'<search>(.*?)</search>', response_to_process, re.DOTALL)
-
-     if search_match:
-         search_query = search_match.group(1).strip()
+             streamer=TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) # For console debugging
+         )

-         # Update the current bubble to show the search action
-         chat_history[-1][1] = f"SAM: Detecting search query: '{search_query}'...\nSearching the web..."
-         yield chat_history, chat_history
+     # Decode the full generated output
+     full_generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=False)

-         search_results = search_duckduckgo(search_query)
+     # The actual response from the model (excluding the prompt)
+     model_response_content = full_generated_text[len(prompt):].strip()
+
+     # Try to extract content after 'content:' if in thinking mode
+     if use_reasoning:
+         try:
+             # Find the last occurrence of "content:" after "thinking content:"
+             thinking_content_start_index = model_response_content.find("thinking content:")
+             if thinking_content_start_index != -1:
+                 content_start_index = model_response_content.rindex("content:", thinking_content_start_index)
+                 response_to_display = model_response_content[content_start_index + len("content:"):].strip()
+             else:
+                 response_to_display = model_response_content
+         except ValueError:
+             response_to_display = model_response_content
+     else:
+         response_to_display = model_response_content

-         # Now, prepare for the second generation pass with search results
-         # Append model's response (with search tag) and tool response to history for the model's context
-         messages_for_template.append({"role": "assistant", "content": response_to_process}) # Model's initial response with search tag
-         messages_for_template.append({"role": "tool_response", "content": f"Search results for \"{search_query}\": {search_results}"})
+     # Check for search query in the model's response
+     search_match = re.search(r'<search>(.*?)</search>', response_to_display, re.DOTALL)

-         prompt_with_search_results = tokenizer.apply_chat_template(messages_for_template, tokenize=False, add_generation_prompt=True)
-         inputs_with_search = tokenizer([prompt_with_search_results], return_tensors="pt").to(model.device)
+     if search_match:
+         search_query = search_match.group(1).strip()
+
+         # Update history with the model's initial response containing the search query
+         # For display purposes in Gradio, we'll show the search request, then results, then final answer.
+         current_chat_history_for_yield.append([message, f"SAM: Detecting search query: '{search_query}'...\nSearching the web..."])
+         yield current_chat_history_for_yield, current_chat_history_for_yield # Yield interim state
+
+         search_results = search_duckduckgo(search_query)
+
+         # Now, prepare for the second generation pass with search results
+         # Append model's response (with search tag) and tool response to history
+         messages_for_template.append({"role": "assistant", "content": response_to_display}) # Model's initial response with search tag
+         messages_for_template.append({"role": "tool_response", "content": f"Search results for \"{search_query}\": {search_results}"})
+
+         prompt_with_search_results = tokenizer.apply_chat_template(messages_for_template, tokenize=False, add_generation_prompt=True)
+         inputs_with_search = tokenizer([prompt_with_search_results], return_tensors="pt").to(model.device)

-         # Update the current bubble again to indicate thinking with results
-         chat_history[-1][1] = f"SAM: Search results for \"{search_query}\": {search_results}\nSAM: Thinking with results..."
-         yield chat_history, chat_history
+         current_chat_history_for_yield[-1][1] += f"\nSearch results for \"{search_query}\": {search_results}\nSAM: Thinking with results..."
+         yield current_chat_history_for_yield, current_chat_history_for_yield # Yield interim state

-         # --- Second Generation Pass (with search results) ---
-         final_generated_text_buffer = ""
-         with torch.no_grad():
-             for output_ids_search in model.generate(
+         # --- Second Generation Pass (with search results) ---
+         final_response_parts = []
+         final_generated_ids = model.generate(
              input_ids=inputs_with_search["input_ids"],
              attention_mask=inputs_with_search.get("attention_mask"),
              max_new_tokens=gen_kwargs["max_new_tokens"],
@@ -209,64 +202,50 @@ def respond(message, chat_history, use_reasoning):
              eos_token_id=gen_kwargs["eos_token_id"],
              pad_token_id=gen_kwargs["pad_token_id"],
              stopping_criteria=stop_criteria,
-             return_dict_in_generate=True,
-             output_scores=True,
-             ).sequences.tolist():
-                 new_token_text_search = tokenizer.decode(output_ids_search[len(inputs_with_search["input_ids"][0]):], skip_special_tokens=True)
-                 final_generated_text_buffer += new_token_text_search
-
-                 # Update the *same* last message in chat_history
-                 chat_history[-1][1] = f"SAM: Search results for \"{search_query}\": {search_results}\nSAM: Thinking with results...\n" + final_generated_text_buffer
-                 yield chat_history, chat_history
-
-         final_model_response_content = final_generated_text_buffer.strip()
-
-         # Extract content after 'content:' for the final response
-         if use_reasoning:
-             try:
-                 thinking_content_start_index = final_model_response_content.find("thinking content:")
-                 if thinking_content_start_index != -1:
-                     content_start_index = final_model_response_content.rindex("content:", thinking_content_start_index)
-                     final_response_to_display = final_model_response_content[content_start_index + len("content:"):].strip()
-                 else:
+             streamer=TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) # For console debugging
+         )
+
+         final_full_generated_text = tokenizer.decode(final_generated_ids[0], skip_special_tokens=False)
+         final_model_response_content = final_full_generated_text[len(prompt_with_search_results):].strip()
+
+         # Extract content after 'content:' for the final response
+         if use_reasoning:
+             try:
+                 thinking_content_start_index = final_model_response_content.find("thinking content:")
+                 if thinking_content_start_index != -1:
+                     content_start_index = final_model_response_content.rindex("content:", thinking_content_start_index)
+                     final_response_to_display = final_model_response_content[content_start_index + len("content:"):].strip()
+                 else:
+                     final_response_to_display = final_model_response_content
+             except ValueError:
                  final_response_to_display = final_model_response_content
-             except ValueError:
+         else:
              final_response_to_display = final_model_response_content
-         else:
-             final_response_to_display = final_model_response_content
-
-         # Final update to the chat history for the completed response
-         chat_history[-1][1] = final_response_to_display
-         yield chat_history, chat_history
-
-     else: # No search query detected in the first pass
-         # The generated_text_buffer already holds the full response from the first pass
-         # This part handles the case where no search was needed.
-         # Ensure the last message in chat_history is the fully generated one.
-
-         # Extract content after 'content:' for the direct response
-         if use_reasoning:
-             try:
-                 thinking_content_start_index = model_response_content_first_pass.find("thinking content:")
-                 if thinking_content_start_index != -1:
-                     content_start_index = model_response_content_first_pass.rindex("content:", thinking_content_start_index)
-                     direct_response_to_display = model_response_content_first_pass[content_start_index + len("content:"):].strip()
-                 else:
-                     direct_response_to_display = model_response_content_first_pass
-             except ValueError:
-                 direct_response_to_display = model_response_content_first_pass
-         else:
-             direct_response_to_display = model_response_content_first_pass
-
-         # Ensure the last message in chat_history is the fully generated one.
-         chat_history[-1][1] = direct_response_to_display
-         yield chat_history, chat_history
+
+         # Yield token by token for the final response
+         for char in final_response_to_display:
+             final_response_parts.append(char)
+             # Update the last message in chat_history for streaming in Gradio
+             current_chat_history_for_yield[-1][1] = "SAM: Thinking with results...\n" + "".join(final_response_parts)
+             yield current_chat_history_for_yield, current_chat_history_for_yield
+
+         # After streaming, update the actual chat_history for the next turn
+         chat_history.append((message, final_response_to_display))
+
+     else: # No search query detected in the first pass
+         # Yield token by token for the direct response
+         for char in response_to_display:
+             full_response_parts.append(char)
+             # Update the last message in chat_history for streaming in Gradio
+             current_chat_history_for_yield.append([message, "".join(full_response_parts)])
+             yield current_chat_history_for_yield, current_chat_history_for_yield

-     # The final `yield` at the end of the function ensures the state is updated
-     # for the next turn in Gradio.
+         # After streaming, update the actual chat_history for the next turn
+         chat_history.append((message, response_to_display))
+
+     # Return the final chat history for the state
      return chat_history, chat_history

-
  with gr.Blocks() as demo:
      gr.Markdown("## 🤖 Sam - SmilyAI Assistant")
      gr.Markdown("Chat with **Sam**, an AI assistant built by [SmilyAI Labs](https://smily.ai). Toggle reasoning mode or choose a model below.")