import os
import re
import sys
import threading
import time

import gradio as gr
import torch  # Required for transformers models
from huggingface_hub import hf_hub_download  # NOTE(review): not referenced in the visible code
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig
from transformers.pipelines import pipeline

# --- Model Configuration ---
# SmilyAI model ID on the Hugging Face Hub.
MODEL_REPO_ID = "Smilyai-labs/Sam-reason-S3"
N_CTX = 2048  # Intended context window; NOTE(review): not passed to the model or tokenizer here.
MAX_TOKENS = 500  # Maximum number of new tokens per response.
TEMPERATURE = 0.7
TOP_P = 0.9
# Intended stop sequences for generated text; NOTE(review): confirm the streaming code honours them.
STOP_SEQUENCES = ["USER:", "\n\n"]

# --- Safety Configuration ---
print("Loading safety model (unitary/toxic-bert)...")
try:
    safety_classifier = pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        framework="pt",  # Use PyTorch backend
    )
    print("Safety model loaded successfully.")
except Exception as e:
    print(f"Error loading safety model: {e}")
    sys.exit(1)  # The app must not run without its safety filter.

TOXICITY_THRESHOLD = 0.9


def is_text_safe(text: str, *, threshold: float = TOXICITY_THRESHOLD) -> tuple[bool, str | None]:
    """Check *text* for toxic content with the toxic-bert safety classifier.

    All label scores are requested and the text is flagged when the ``toxic``
    score exceeds *threshold*, whether or not ``toxic`` is the top-ranked label.
    toxic-bert scores several labels (``toxic``, ``obscene``, ``insult``, ...),
    so another label can outrank ``toxic`` on clearly toxic text.

    Args:
        text: Text to classify. Blank text is always considered safe.
        threshold: Score above which the ``toxic`` label flags the text.

    Returns:
        ``(True, None)`` if the text is safe, ``(False, "toxic")`` if it is
        flagged, or ``(False, "safety_check_failed")`` if the classifier raised.
        In the last case the check fails closed.
    """
    if not text.strip():
        return True, None
    try:
        # top_k=None returns every label's score. truncation=True keeps long text
        # within the encoder's limit instead of raising (which would fail closed).
        results = safety_classifier(text, top_k=None, truncation=True)
        if results and isinstance(results[0], list):  # Some versions nest single-input results.
            results = results[0]
        toxic_score = max((r["score"] for r in results if r["label"] == "toxic"), default=0.0)
    except Exception as e:
        print(f"Error during safety check: {e}")
        # Fail closed: if the classifier cannot run, treat the text as unsafe.
        return False, "safety_check_failed"
    if toxic_score > threshold:
        print(f"Detected unsafe content: '{text.strip()}' (Score: {toxic_score:.4f})")
        return False, "toxic"
    return True, None


# --- Main Model Loading (using Transformers) ---
print(f"Loading tokenizer for {MODEL_REPO_ID}...")
try:
    # AutoTokenizer fetches the tokenizer that matches the model.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
    print("Tokenizer loaded.")
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    print("Make sure the model ID is correct and, if it's a private repo, you've set the HF_TOKEN secret in your Space.")
    sys.exit(1)

print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and might take a long time)...")
try:
    # device_map="cpu" places every layer on the CPU. float32 is the safe default
    # dtype for CPU inference.
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32)
    model.eval()  # Inference only: disable dropout and other training-mode behaviour.
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Ensure it's a standard Transformers model and you have HF_TOKEN secret if private.")
    sys.exit(1)
import queue

from transformers import StoppingCriteria, StoppingCriteriaList

# --- Generation Configuration ---
# Start from the repo's generation_config.json. from_pretrained raises OSError when
# that file is missing, in which case defaults are derived from the loaded model.
try:
    generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
except OSError:
    generation_config = GenerationConfig.from_model_config(model.config)
generation_config.max_new_tokens = MAX_TOKENS
generation_config.temperature = TEMPERATURE
generation_config.top_p = TOP_P
generation_config.do_sample = True  # Required for temperature/top_p to take effect.
# Prefer the tokenizer's EOS id. Otherwise keep the model config's own EOS id so
# generation can still stop at end-of-sequence.
if tokenizer.eos_token_id is not None:
    generation_config.eos_token_id = tokenizer.eos_token_id
# Padding falls back to EOS, then to 0 (not ideal for every model, but generate() needs an int).
_pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
generation_config.pad_token_id = _pad_token_id if _pad_token_id is not None else 0


class _EventStoppingCriteria(StoppingCriteria):
    """Stopping criterion that ends ``model.generate`` once ``event`` is set.

    Two things set the event: the streamer, after it withholds unsafe text or
    cuts at a stop sequence, and the UI handler, when it exits. Either way, no
    further tokens are generated that nobody would see.
    """

    def __init__(self, event: threading.Event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs):
        # One flag per batch row; recent transformers versions expect a bool tensor here.
        return torch.full(
            (input_ids.shape[0],), self.event.is_set(), dtype=torch.bool, device=input_ids.device
        )


# --- Custom Streamer for Gradio and Safety Check ---
class GradioSafetyStreamer(TextIteratorStreamer):
    """Streamer that releases generated text one safety-checked sentence at a time.

    ``model.generate``, running on a worker thread, feeds decoded text into
    :meth:`on_finalized_text`. Each complete sentence is passed to
    ``safety_checker_fn`` and put on ``output_queue``. Iterating the streamer
    yields those sentences in order until one of three things happens:

    * generation ends;
    * a sentence is flagged, in which case a removal notice is queued instead;
    * a stop sequence appears, in which case the text is cut just before it.

    In the last two cases ``stop_event`` is set so that an
    ``_EventStoppingCriteria`` can halt generation.

    Args:
        tokenizer: Tokenizer used to decode generated tokens.
        safety_checker_fn: Callable ``text -> (is_safe, label)``.
        toxicity_threshold: Stored for reference only. The threshold that applies
            is whatever ``safety_checker_fn`` uses.
        skip_special_tokens: Forwarded to the tokenizer's ``decode``.
        stop_sequences: Strings at which the generated text is truncated.
        **kwargs: Forwarded to ``TextIteratorStreamer``. ``skip_prompt`` defaults
            to True so the prompt is not echoed into the output.
    """

    REMOVED_NOTICE = "[Content removed due to safety guidelines]"
    _DONE = object()  # Queue sentinel marking the end of the stream.
    _SENTENCE_END = re.compile(r"[.!?]+\s*")  # Terminal punctuation plus trailing whitespace.

    def __init__(self, tokenizer, safety_checker_fn, toxicity_threshold, skip_special_tokens=True,
                 *, stop_sequences=None, **kwargs):
        kwargs.setdefault("skip_prompt", True)
        super().__init__(tokenizer, skip_special_tokens=skip_special_tokens, **kwargs)
        self.safety_checker_fn = safety_checker_fn
        self.toxicity_threshold = toxicity_threshold
        self.stop_sequences = [seq for seq in (stop_sequences or ()) if seq]
        self.output_queue = queue.Queue()  # Thread-safe hand-off from the generation thread to the UI.
        self.stop_event = threading.Event()  # Set to ask model.generate to stop early.
        self.text_done = threading.Event()  # Set once no more output will be queued.
        self.error = None  # Exception raised by model.generate, if any.
        self._text = ""  # All generated text seen so far (truncated at a stop sequence).
        self._consumed = 0  # Index in _text up to which text has been checked and queued.
        self._finished = False

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Buffer newly decoded *text* and queue every sentence that is now complete.

        Called on the generation thread by ``TextStreamer.put`` and ``end``. At
        ``stream_end``, or once a stop sequence is found, the unfinished tail is
        flushed as well and the stream is finished.
        """
        if self._finished:
            return  # Already stopped; ignore tokens produced before the stop took effect.
        self._text += text
        final = stream_end
        stop_index = self._find_stop_index()
        if stop_index is not None:
            self._text = self._text[:stop_index]
            final = True

        sentences, remainder = self._split_sentences(self._text[self._consumed:])
        if final:
            sentences.append(remainder)
            remainder = ""
        self._consumed = len(self._text) - len(remainder)

        for sentence in sentences:
            if not sentence:
                continue
            if sentence.strip():
                is_safe, detected_label = self.safety_checker_fn(sentence)
                if not is_safe:
                    print(f"Safety check failed for: '{sentence.strip()}' (Detected: {detected_label})")
                    self.output_queue.put(self.REMOVED_NOTICE)
                    self._finish(stop_generation=True)
                    return
            self.output_queue.put(sentence)

        if final:
            self._finish(stop_generation=stop_index is not None)

    def _find_stop_index(self):
        """Return the index at which ``_text`` should be cut for the earliest stop sequence, or None."""
        if not self.stop_sequences:
            return None
        longest = max(map(len, self.stop_sequences))
        # Skip leading whitespace so a response that opens with a blank line is not
        # cut to nothing. Also back up a little before _consumed, because a stop
        # sequence can straddle text that was already queued.
        content_start = len(self._text) - len(self._text.lstrip())
        start = max(content_start, self._consumed - longest + 1, 0)
        hits = [idx for idx in (self._text.find(seq, start) for seq in self.stop_sequences) if idx != -1]
        if not hits:
            return None
        # Text that was already queued cannot be retracted.
        return max(min(hits), self._consumed)

    @classmethod
    def _split_sentences(cls, text):
        """Split *text* into complete sentences and the unfinished tail.

        Each sentence keeps its punctuation and trailing whitespace.
        """
        sentences = []
        start = 0
        for match in cls._SENTENCE_END.finditer(text):
            sentences.append(text[start:match.end()])
            start = match.end()
        return sentences, text[start:]

    def _finish(self, stop_generation=False):
        """Mark the stream complete, optionally request a generation stop, and wake the consumer."""
        if self._finished:
            return
        self._finished = True
        if stop_generation:
            self.stop_event.set()
        self.text_done.set()
        self.output_queue.put(self._DONE)

    def close(self, error=None):
        """Finish the stream from outside the token callbacks, e.g. after ``generate`` raised.

        Safe to call more than once. It guarantees that iteration terminates.
        """
        if error is not None and self.error is None:
            self.error = error
        self._finish()

    def __iter__(self):
        """Yield queued sentences, blocking between them, until the stream is finished."""
        while (item := self.output_queue.get()) is not self._DONE:
            yield item


def _run_generation(streamer: GradioSafetyStreamer, generate_kwargs: dict) -> None:
    """Run ``model.generate`` on a worker thread and always close *streamer* afterwards.

    Closing ensures the UI loop iterating the streamer terminates even when
    generation raises. The exception is kept on ``streamer.error`` for display.
    """
    error = None
    try:
        model.generate(**generate_kwargs)
    except Exception as exc:
        print(f"Error during generation: {exc}")
        error = exc
    finally:
        streamer.close(error)


# --- Inference Function with Safety and Streaming ---
def generate_word_by_word_with_safety(prompt_text: str):
    """Stream a safety-filtered model response to *prompt_text* for Gradio.

    Yields the accumulated response after each safety-checked sentence. Gradio
    replaces the textbox contents with every yielded string. Generation runs on
    a background thread and stops early on any of these:

    * unsafe content;
    * a stop sequence;
    * Gradio closing this generator.
    """
    if not prompt_text or not prompt_text.strip():
        yield ""
        return

    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
    encoded = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    streamer = GradioSafetyStreamer(tokenizer, is_text_safe, TOXICITY_THRESHOLD, stop_sequences=STOP_SEQUENCES)
    generate_kwargs = {
        "input_ids": encoded["input_ids"],
        # An explicit mask avoids generate() guessing it when pad_token_id == eos_token_id.
        "attention_mask": encoded.get("attention_mask"),
        "streamer": streamer,
        "generation_config": generation_config,
        "stopping_criteria": StoppingCriteriaList([_EventStoppingCriteria(streamer.stop_event)]),
    }
    # model.generate blocks, so run it on a worker thread while this generator streams.
    thread = threading.Thread(target=_run_generation, args=(streamer, generate_kwargs), daemon=True)
    thread.start()

    full_generated_text = ""
    try:
        for sentence in streamer:
            full_generated_text += sentence
            yield full_generated_text  # Gradio expects the accumulated string for streaming display.
        if streamer.error is not None:
            yield full_generated_text + f"\n\n[Error during streaming: {streamer.error}]"
    finally:
        # Stop the model if we leave early (e.g. the request was cancelled), then reap the thread.
        streamer.stop_event.set()
        thread.join()


# --- Gradio Blocks Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # SmilyAI: Sam-reason-S3 Inference (Transformers on CPU with Safety Filter)

        Enter a prompt and get a word-by-word response from the **Smilyai-labs/Sam-reason-S3** model.

        **⚠️ WARNING: This model is running on a free CPU tier via the `transformers` library. Inference will be VERY slow.**

        All generated sentences are checked for safety using an AI filter; unsafe content will be replaced.
        """
    )
    with gr.Row():
        user_prompt = gr.Textbox(
            lines=5,
            label="Enter your prompt here:",
            placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
            scale=4,
        )
        generated_text = gr.Textbox(label="Generated Text", show_copy_button=True, scale=6)
    send_button = gr.Button("Send", variant="primary")
    send_button.click(
        fn=generate_word_by_word_with_safety,
        inputs=user_prompt,
        outputs=generated_text,
        api_name="predict",
    )

if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860)