from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextStreamer,
    StoppingCriteriaList,
    StoppingCriteria
)
import torch
import gradio as gr
import re
from duckduckgo_search import DDGS  # Import DuckDuckGo Search

# Dictionary of available models
AVAILABLE_MODELS = {
    "Sam-reason-A1": "Smilyai-labs/Sam-reason-A1"
}

# Load default model
DEFAULT_MODEL = "Sam-reason-A1"
MODEL_PATH = AVAILABLE_MODELS[DEFAULT_MODEL]

# Global variables for tokenizer and model (will be updated by load_model)
tokenizer = None
model = None

# Initialize tokenizer and model globally for the first run
# This ensures they are loaded when the script starts
print(f"Initializing model: {DEFAULT_MODEL}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype="auto", device_map="auto")
print("Model initialized.")

# Generation parameters
gen_kwargs = {
    "max_new_tokens": 512,
    "temperature": 0.7,
    "top_k": 50,
    "top_p": 0.95,
    "repetition_penalty": 1.2,
    "do_sample": True,
    "eos_token_id": tokenizer.eos_token_id,
    "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
}

# Prompt prefixes
REASON_PROMPT = "/think\n"
NO_REASON_PROMPT = "/no_thinking\n"

# System prompt
# (the literal <search></search> tag syntax below is an assumed reconstruction;
#  it must match the extraction regex used in respond())
SYSTEM_PROMPT = (
    "You are Sam, an AI assistant by SmilyAI. "
    "When you need to search for information, output your search query enclosed in <search> tags, "
    "like this: <search>your search query here</search>. "
    "The search results will then be provided to you. "
    "Never use the <search> tags unless you specifically need to perform a search for real-time data. "
    "Do not use the search tags in your reasoning or conversation, only for the search query itself."
)

# Stopping criteria class
class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_token_ids):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids, scores, **kwargs):
        # Stop as soon as the last generated token is one of the stop tokens
        return any(t.item() in self.stop_token_ids for t in input_ids[0][-1:])


# Note: streamer and stop_criteria need to be defined AFTER tokenizer is loaded.
# They will be recreated in `load_model` if the tokenizer changes.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
stop_criteria = StoppingCriteriaList([
    StopOnTokens([tokenizer.eos_token_id])
])


def search_duckduckgo(query):
    """
    Performs a DuckDuckGo search and returns the top result's snippet.
    """
    try:
        results = DDGS().text(keywords=query, max_results=1)
        if results:
            return results[0]['body']
        else:
            return "No search results found."
    except Exception as e:
        return f"An error occurred during search: {e}"


def load_model(selected_model_name):
    global model, tokenizer, streamer, stop_criteria

    model_path = AVAILABLE_MODELS[selected_model_name]
    print(f"Loading model: {selected_model_name} from {model_path}...")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")

    # Re-initialize streamer and stop_criteria with the new tokenizer
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    stop_criteria = StoppingCriteriaList([
        StopOnTokens([tokenizer.eos_token_id])
    ])

    print(f"Successfully loaded model: {selected_model_name}")
    return f"✅ Loaded model: {selected_model_name}"
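
# --- Optional sketch: true token streaming ---
# respond() below generates the full reply first and only then yields it to Gradio
# character by character. A minimal sketch of genuine token-by-token streaming with
# transformers' TextIteratorStreamer is included here for reference; the helper name
# stream_generate is hypothetical and the function is not wired into the app.
from threading import Thread
from transformers import TextIteratorStreamer


def stream_generate(prompt_text):
    # Streamer that yields decoded text chunks as generate() produces them
    iter_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    stream_inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    # Run generation in a background thread so the streamer can be consumed here
    Thread(
        target=model.generate,
        kwargs=dict(**stream_inputs, streamer=iter_streamer, **gen_kwargs),
    ).start()
    for new_text in iter_streamer:
        yield new_text
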
def respond(message, chat_history, use_reasoning):
    # Gradio's chat_history is a list of [user_message, bot_message] pairs.
    # We need to convert it to the format expected by the model's chat template.
    messages_for_template = [{"role": "system", "content": SYSTEM_PROMPT}]
    for user_msg, bot_msg in chat_history:
        if user_msg:  # Only add if not empty
            messages_for_template.append({"role": "user", "content": user_msg})
        if bot_msg:  # Only add if not empty
            messages_for_template.append({"role": "assistant", "content": bot_msg})

    # Add the current user message with the appropriate prefix
    current_user_message_content = (REASON_PROMPT if use_reasoning else NO_REASON_PROMPT) + message
    messages_for_template.append({"role": "user", "content": current_user_message_content})

    # Apply chat template
    prompt = tokenizer.apply_chat_template(messages_for_template, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # --- First Generation Pass (for potential search query) ---
    full_response_parts = []
    current_chat_history_for_yield = list(chat_history)  # Create a copy for yielding

    # For Gradio streaming, we need to manually yield tokens.
    # TextStreamer only prints to the console; here we collect the full response
    # first so we can check it for search tags, then yield it to the UI.
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=gen_kwargs["max_new_tokens"],
            temperature=gen_kwargs["temperature"],
            top_k=gen_kwargs["top_k"],
            top_p=gen_kwargs["top_p"],
            repetition_penalty=gen_kwargs["repetition_penalty"],
            do_sample=gen_kwargs["do_sample"],
            eos_token_id=gen_kwargs["eos_token_id"],
            pad_token_id=gen_kwargs["pad_token_id"],
            stopping_criteria=stop_criteria,
            streamer=TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)  # For console debugging
        )

    # Decode the full generated output
    full_generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
    # The actual response from the model (excluding the prompt)
    model_response_content = full_generated_text[len(prompt):].strip()

    # Try to extract content after 'content:' if in thinking mode
    if use_reasoning:
        try:
            # Find the last occurrence of "content:" after "thinking content:"
            thinking_content_start_index = model_response_content.find("thinking content:")
            if thinking_content_start_index != -1:
                content_start_index = model_response_content.rindex("content:", thinking_content_start_index)
                response_to_display = model_response_content[content_start_index + len("content:"):].strip()
            else:
                response_to_display = model_response_content
        except ValueError:
            response_to_display = model_response_content
    else:
        response_to_display = model_response_content

    # Check for a search query in the model's response
    # (the <search>...</search> tag name is assumed; it matches SYSTEM_PROMPT above)
    search_match = re.search(r'<search>(.*?)</search>', response_to_display, re.DOTALL)

    if search_match:
        search_query = search_match.group(1).strip()

        # Update history with the model's initial response containing the search query.
        # For display purposes in Gradio, we show the search request, then the results,
        # then the final answer.
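        # Illustrative example (assumed tag syntax, matching SYSTEM_PROMPT) of a model
        # turn that reaches this branch:
        #   "Let me check that. <search>current weather in Paris</search>"
        # The text between the tags becomes search_query and is passed to
        # search_duckduckgo(); the returned snippet is fed back to the model.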
        current_chat_history_for_yield.append([message, f"SAM: Detecting search query: '{search_query}'...\nSearching the web..."])
        yield current_chat_history_for_yield, current_chat_history_for_yield  # Yield interim state

        search_results = search_duckduckgo(search_query)

        # Now, prepare for the second generation pass with search results.
        # Append the model's response (with search tag) and the tool response to history.
        messages_for_template.append({"role": "assistant", "content": response_to_display})  # Model's initial response with search tag
        messages_for_template.append({"role": "tool_response", "content": f"Search results for \"{search_query}\": {search_results}"})

        prompt_with_search_results = tokenizer.apply_chat_template(messages_for_template, tokenize=False, add_generation_prompt=True)
        inputs_with_search = tokenizer([prompt_with_search_results], return_tensors="pt").to(model.device)

        current_chat_history_for_yield[-1][1] += f"\nSearch results for \"{search_query}\": {search_results}\nSAM: Thinking with results..."
        yield current_chat_history_for_yield, current_chat_history_for_yield  # Yield interim state

        # --- Second Generation Pass (with search results) ---
        final_response_parts = []
        with torch.no_grad():
            final_generated_ids = model.generate(
                input_ids=inputs_with_search["input_ids"],
                attention_mask=inputs_with_search.get("attention_mask"),
                max_new_tokens=gen_kwargs["max_new_tokens"],
                temperature=gen_kwargs["temperature"],
                top_k=gen_kwargs["top_k"],
                top_p=gen_kwargs["top_p"],
                repetition_penalty=gen_kwargs["repetition_penalty"],
                do_sample=gen_kwargs["do_sample"],
                eos_token_id=gen_kwargs["eos_token_id"],
                pad_token_id=gen_kwargs["pad_token_id"],
                stopping_criteria=stop_criteria,
                streamer=TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)  # For console debugging
            )

        final_full_generated_text = tokenizer.decode(final_generated_ids[0], skip_special_tokens=False)
        final_model_response_content = final_full_generated_text[len(prompt_with_search_results):].strip()

        # Extract content after 'content:' for the final response
        if use_reasoning:
            try:
                thinking_content_start_index = final_model_response_content.find("thinking content:")
                if thinking_content_start_index != -1:
                    content_start_index = final_model_response_content.rindex("content:", thinking_content_start_index)
                    final_response_to_display = final_model_response_content[content_start_index + len("content:"):].strip()
                else:
                    final_response_to_display = final_model_response_content
            except ValueError:
                final_response_to_display = final_model_response_content
        else:
            final_response_to_display = final_model_response_content

        # Yield character by character for the final response
        for char in final_response_to_display:
            final_response_parts.append(char)
            # Update the last message in the displayed history for streaming in Gradio
            current_chat_history_for_yield[-1][1] = "SAM: Thinking with results...\n" + "".join(final_response_parts)
            yield current_chat_history_for_yield, current_chat_history_for_yield

        # After streaming, update the actual chat_history for the next turn
        chat_history.append((message, final_response_to_display))

    else:
        # No search query detected in the first pass.
        # Yield character by character for the direct response.
        current_chat_history_for_yield.append([message, ""])
        for char in response_to_display:
            full_response_parts.append(char)
            # Update the last message in the displayed history for streaming in Gradio
            current_chat_history_for_yield[-1][1] = "".join(full_response_parts)
            yield current_chat_history_for_yield, current_chat_history_for_yield

        # After streaming, update the actual chat_history for the next turn
        chat_history.append((message, response_to_display))
    # Return the final chat history for the state
    return chat_history, chat_history


with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Sam - SmilyAI Assistant")
    gr.Markdown("Chat with **Sam**, an AI assistant built by [SmilyAI Labs](http://smilyai.rf.gd). Toggle reasoning mode or choose a model below.")

    with gr.Row():
        model_selector = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value=DEFAULT_MODEL,
            label="Select Model",
            interactive=True
        )
        load_status = gr.Textbox(label="Status", value=f"Loaded: {DEFAULT_MODEL}", interactive=False)

    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your Message")

    with gr.Row():
        clear = gr.Button("Clear Chat")
        use_reasoning = gr.Checkbox(label="Enable Reasoning Mode (/think)", value=False)

    state = gr.State([])  # Stores the full chat history

    # Load model when selector changes
    model_selector.change(fn=load_model, inputs=model_selector, outputs=load_status)

    # Use .submit to trigger the response function.
    # `msg` is the user input, `state` is the chat history, `use_reasoning` is the checkbox value.
    # The `respond` function will yield updates to `chatbot` and `state`.
    msg.submit(respond, [msg, state, use_reasoning], [chatbot, state])
    clear.click(lambda: ([], []), None, [chatbot, state])

demo.launch()
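
# Note: on Gradio 3.x, generator handlers such as respond() require the request queue
# to be enabled before launching (e.g. demo.queue() followed by demo.launch());
# on Gradio 4.x the queue is enabled by default, so demo.launch() alone streams fine.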