Spaces: Running
Update app.py

app.py CHANGED
@@ -8,7 +8,7 @@ from transformers import (
 import torch
 import gradio as gr
 import re
-from duckduckgo_search import DDGS
+from duckduckgo_search import DDGS # Import DuckDuckGo Search
 
 # Dictionary of available models
 AVAILABLE_MODELS = {
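The only change in this hunk is the DDGS import; the search_duckduckgo(...) helper that uses it later in the diff is defined outside the shown hunks. A minimal sketch of what such a helper could look like (the function name comes from the diff, but the body, max_results, and result formatting here are assumptions):

from duckduckgo_search import DDGS

def search_duckduckgo(query, max_results=5):
    # Hypothetical implementation -- the Space's real helper is not shown in this diff.
    # DDGS().text() returns result dicts with "title", "href" and "body" fields.
    try:
        with DDGS() as ddgs:
            hits = ddgs.text(query, max_results=max_results)
        return " | ".join(f"{hit['title']}: {hit['body']}" for hit in hits)
    except Exception as exc:
        return f"Search failed: {exc}"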
@@ -24,6 +24,7 @@ tokenizer = None
 model = None
 
 # Initialize tokenizer and model globally for the first run
+# This ensures they are loaded when the script starts
 print(f"Initializing model: {DEFAULT_MODEL}...")
 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
 model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype="auto", device_map="auto")
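The from transformers import ( statement referenced in the first hunk's header is truncated in this view. Judging from the names used in this hunk and below (AutoTokenizer, AutoModelForCausalLM, TextStreamer, plus the stop_criteria passed to generate), the import block presumably looks roughly like this; the exact list is an inference, not shown in the diff:

from transformers import (
    AutoTokenizer,           # tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    AutoModelForCausalLM,    # model = AutoModelForCausalLM.from_pretrained(...)
    TextStreamer,            # streamer=TextStreamer(...) used for console debugging
    StoppingCriteria,        # likely used to build stop_criteria
    StoppingCriteriaList,
)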
@@ -99,18 +100,13 @@ def load_model(selected_model_name):
     return f"✅ Loaded model: {selected_model_name}"
 
 def respond(message, chat_history, use_reasoning):
-    # Start by appending the user's message and an empty string for the bot's response
-    # This creates the new chat bubble for the user's input.
-    chat_history.append([message, ""])
-    yield chat_history, chat_history # Yield immediately to show the user's message
-
     # Gradio's chat_history is a list of [user_message, bot_message] pairs.
     # We need to convert it to the format expected by the model's chat template.
     messages_for_template = [{"role": "system", "content": SYSTEM_PROMPT}]
-    for user_msg, bot_msg in chat_history
-        if user_msg:
+    for user_msg, bot_msg in chat_history:
+        if user_msg: # Only add if not empty
             messages_for_template.append({"role": "user", "content": user_msg})
-        if bot_msg:
+        if bot_msg: # Only add if not empty
             messages_for_template.append({"role": "assistant", "content": bot_msg})
 
     # Add the current user message with the appropriate prefix
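For reference, the prompt consumed in the next hunk is presumably produced from messages_for_template with the tokenizer's chat template, mirroring the call the search branch makes later in this diff; a sketch under that assumption:

# Flatten the role/content messages into a single prompt string.
# add_generation_prompt=True appends the assistant-turn marker so the model
# answers as the assistant instead of continuing the user's text.
prompt = tokenizer.apply_chat_template(
    messages_for_template,
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)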
@@ -122,10 +118,17 @@ def respond(message, chat_history, use_reasoning):
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
     # --- First Generation Pass (for potential search query) ---
-
+    full_response_parts = []
+    current_chat_history_for_yield = list(chat_history) # Create a copy for yielding
+
+    # Ensure streamer is correctly set up for this generation
+    # For Gradio streaming, we need to manually yield tokens.
+    # TextStreamer is for console output; here we'll collect and yield.
     with torch.no_grad():
-        #
-        for output_ids in model.generate(
+        # Use an iterable for token-by-token generation for Gradio
+        # This is a common pattern for streaming outputs in Gradio
+        # We will collect the full response to check for search tags.
+        generated_ids = model.generate(
             input_ids=inputs["input_ids"],
             attention_mask=inputs.get("attention_mask"),
             max_new_tokens=gen_kwargs["max_new_tokens"],
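As the added comments note, TextStreamer only prints to the console; the new code collects the full output and simulates streaming afterwards with per-character loops. If real token-level streaming into Gradio were wanted, transformers' TextIteratorStreamer can be consumed while generate runs in a background thread. This is a sketch of that alternative, not what the commit does:

from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
    input_ids=inputs["input_ids"],
    attention_mask=inputs.get("attention_mask"),
    max_new_tokens=gen_kwargs["max_new_tokens"],
    streamer=streamer,
)
Thread(target=model.generate, kwargs=generation_kwargs).start()

partial = ""
for new_text in streamer:  # yields decoded text chunks as tokens are produced
    partial += new_text
    # a Gradio generator would yield the updated chat history here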
@@ -137,67 +140,57 @@ def respond(message, chat_history, use_reasoning):
             eos_token_id=gen_kwargs["eos_token_id"],
             pad_token_id=gen_kwargs["pad_token_id"],
             stopping_criteria=stop_criteria,
-
-
-            # We don't use streamer directly here for Gradio output,
-            # but for internal token generation and decoding.
-        ).sequences.tolist():
-            # Decode the newly generated token
-            new_token_text = tokenizer.decode(output_ids[len(inputs["input_ids"][0]):], skip_special_tokens=True)
-            generated_text_buffer += new_token_text
-
-            # Update the last assistant message in chat_history
-            chat_history[-1][1] = generated_text_buffer
-            yield chat_history, chat_history # Yield to update the Gradio UI
-
-            # Check for early stop condition if the full response isn't needed yet
-            if "<search>" in generated_text_buffer and "</search>" in generated_text_buffer:
-                break # Stop generating if a full search tag is found
-
-    # After the first pass, check for search query in the full generated buffer
-    model_response_content_first_pass = generated_text_buffer.strip()
-
-    # Try to extract content after 'content:' if in thinking mode
-    if use_reasoning:
-        try:
-            thinking_content_start_index = model_response_content_first_pass.find("thinking content:")
-            if thinking_content_start_index != -1:
-                content_start_index = model_response_content_first_pass.rindex("content:", thinking_content_start_index)
-                response_to_process = model_response_content_first_pass[content_start_index + len("content:"):].strip()
-            else:
-                response_to_process = model_response_content_first_pass
-        except ValueError:
-            response_to_process = model_response_content_first_pass
-    else:
-        response_to_process = model_response_content_first_pass
-
-    search_match = re.search(r'<search>(.*?)</search>', response_to_process, re.DOTALL)
-
-    if search_match:
-        search_query = search_match.group(1).strip()
+            streamer=TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) # For console debugging
+        )
 
-        #
-
-        yield chat_history, chat_history
+        # Decode the full generated output
+        full_generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
 
-
+    # The actual response from the model (excluding the prompt)
+    model_response_content = full_generated_text[len(prompt):].strip()
+
+    # Try to extract content after 'content:' if in thinking mode
+    if use_reasoning:
+        try:
+            # Find the last occurrence of "content:" after "thinking content:"
+            thinking_content_start_index = model_response_content.find("thinking content:")
+            if thinking_content_start_index != -1:
+                content_start_index = model_response_content.rindex("content:", thinking_content_start_index)
+                response_to_display = model_response_content[content_start_index + len("content:"):].strip()
+            else:
+                response_to_display = model_response_content
+        except ValueError:
+            response_to_display = model_response_content
+    else:
+        response_to_display = model_response_content
 
-        #
-
-        messages_for_template.append({"role": "assistant", "content": response_to_process}) # Model's initial response with search tag
-        messages_for_template.append({"role": "tool_response", "content": f"Search results for \"{search_query}\": {search_results}"})
+    # Check for search query in the model's response
+    search_match = re.search(r'<search>(.*?)</search>', response_to_display, re.DOTALL)
 
-
-
+    if search_match:
+        search_query = search_match.group(1).strip()
+
+        # Update history with the model's initial response containing the search query
+        # For display purposes in Gradio, we'll show the search request, then results, then final answer.
+        current_chat_history_for_yield.append([message, f"SAM: Detecting search query: '{search_query}'...\nSearching the web..."])
+        yield current_chat_history_for_yield, current_chat_history_for_yield # Yield interim state
+
+        search_results = search_duckduckgo(search_query)
+
+        # Now, prepare for the second generation pass with search results
+        # Append model's response (with search tag) and tool response to history
+        messages_for_template.append({"role": "assistant", "content": response_to_display}) # Model's initial response with search tag
+        messages_for_template.append({"role": "tool_response", "content": f"Search results for \"{search_query}\": {search_results}"})
+
+        prompt_with_search_results = tokenizer.apply_chat_template(messages_for_template, tokenize=False, add_generation_prompt=True)
+        inputs_with_search = tokenizer([prompt_with_search_results], return_tensors="pt").to(model.device)
 
-
-
-        yield chat_history, chat_history
+        current_chat_history_for_yield[-1][1] += f"\nSearch results for \"{search_query}\": {search_results}\nSAM: Thinking with results..."
+        yield current_chat_history_for_yield, current_chat_history_for_yield # Yield interim state
 
-
-
-
-        for output_ids_search in model.generate(
+        # --- Second Generation Pass (with search results) ---
+        final_response_parts = []
+        final_generated_ids = model.generate(
             input_ids=inputs_with_search["input_ids"],
             attention_mask=inputs_with_search.get("attention_mask"),
             max_new_tokens=gen_kwargs["max_new_tokens"],
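Both generate calls pass stopping_criteria=stop_criteria, but its definition sits outside the shown hunks. One way such a criterion could be written so generation halts as soon as a closing </search> tag appears, which is what the removed streaming loop checked for manually (a sketch of the technique, not the Space's actual definition):

from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSubstring(StoppingCriteria):
    """Stop generation once a given substring appears in the decoded output."""

    def __init__(self, tokenizer, substring, prompt_length):
        self.tokenizer = tokenizer
        self.substring = substring
        self.prompt_length = prompt_length  # number of prompt tokens to skip

    def __call__(self, input_ids, scores, **kwargs):
        # Decoding on every step is simple but not the fastest approach.
        generated = input_ids[0][self.prompt_length:]
        text = self.tokenizer.decode(generated, skip_special_tokens=True)
        return self.substring in text

stop_criteria = StoppingCriteriaList(
    [StopOnSubstring(tokenizer, "</search>", inputs["input_ids"].shape[1])]
)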
@@ -209,64 +202,50 @@ def respond(message, chat_history, use_reasoning):
             eos_token_id=gen_kwargs["eos_token_id"],
             pad_token_id=gen_kwargs["pad_token_id"],
             stopping_criteria=stop_criteria,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                if thinking_content_start_index != -1:
-                    content_start_index = final_model_response_content.rindex("content:", thinking_content_start_index)
-                    final_response_to_display = final_model_response_content[content_start_index + len("content:"):].strip()
-                else:
+            streamer=TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) # For console debugging
+        )
+
+        final_full_generated_text = tokenizer.decode(final_generated_ids[0], skip_special_tokens=False)
+        final_model_response_content = final_full_generated_text[len(prompt_with_search_results):].strip()
+
+        # Extract content after 'content:' for the final response
+        if use_reasoning:
+            try:
+                thinking_content_start_index = final_model_response_content.find("thinking content:")
+                if thinking_content_start_index != -1:
+                    content_start_index = final_model_response_content.rindex("content:", thinking_content_start_index)
+                    final_response_to_display = final_model_response_content[content_start_index + len("content:"):].strip()
+                else:
+                    final_response_to_display = final_model_response_content
+            except ValueError:
                 final_response_to_display = final_model_response_content
-
+        else:
             final_response_to_display = final_model_response_content
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                direct_response_to_display = model_response_content_first_pass[content_start_index + len("content:"):].strip()
-            else:
-                direct_response_to_display = model_response_content_first_pass
-        except ValueError:
-            direct_response_to_display = model_response_content_first_pass
-    else:
-        direct_response_to_display = model_response_content_first_pass
-
-    # Ensure the last message in chat_history is the fully generated one.
-    chat_history[-1][1] = direct_response_to_display
-    yield chat_history, chat_history
+
+        # Yield token by token for the final response
+        for char in final_response_to_display:
+            final_response_parts.append(char)
+            # Update the last message in chat_history for streaming in Gradio
+            current_chat_history_for_yield[-1][1] = "SAM: Thinking with results...\n" + "".join(final_response_parts)
+            yield current_chat_history_for_yield, current_chat_history_for_yield
+
+        # After streaming, update the actual chat_history for the next turn
+        chat_history.append((message, final_response_to_display))
+
+    else: # No search query detected in the first pass
+        # Yield token by token for the direct response
+        for char in response_to_display:
+            full_response_parts.append(char)
+            # Update the last message in chat_history for streaming in Gradio
+            current_chat_history_for_yield.append([message, "".join(full_response_parts)])
+            yield current_chat_history_for_yield, current_chat_history_for_yield
 
-
-
+        # After streaming, update the actual chat_history for the next turn
+        chat_history.append((message, response_to_display))
+
+    # Return the final chat history for the state
     return chat_history, chat_history
 
-
 with gr.Blocks() as demo:
     gr.Markdown("## 🤖 Sam - SmilyAI Assistant")
     gr.Markdown("Chat with **Sam**, an AI assistant built by [SmilyAI Labs](https://smily.ai). Toggle reasoning mode or choose a model below.")
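The "thinking content:" / "content:" extraction now appears three times in respond (first pass, search pass, and the removed direct branch). A helper with the same logic could replace those copies; this is a refactoring sketch, not part of the commit:

def extract_final_content(raw_response, use_reasoning):
    # Mirrors the inline logic: with reasoning enabled, keep only the text after
    # the last "content:" that follows "thinking content:".
    if not use_reasoning:
        return raw_response
    try:
        thinking_start = raw_response.find("thinking content:")
        if thinking_start == -1:
            return raw_response
        content_start = raw_response.rindex("content:", thinking_start)
        return raw_response[content_start + len("content:"):].strip()
    except ValueError:
        return raw_response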