import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig, pipeline
import re
import torch
import threading
import time

# --- Model Configuration ---
MODEL_REPO_ID = "Smilyai-labs/Sam-reason-S3"
N_CTX = 2048  # context window size (informational; not enforced below)
MAX_TOKENS = 500
TEMPERATURE = 0.7
TOP_P = 0.9
STOP_SEQUENCES = ["USER:", "\n\n"]  # currently unused; stopping relies on EOS

# --- Safety Configuration ---
print("Loading safety model (unitary/toxic-bert)...")
try:
    safety_classifier = pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        framework="pt"
    )
    print("Safety model loaded successfully.")
except Exception as e:
    print(f"Error loading safety model: {e}")
    exit(1)

TOXICITY_THRESHOLD = 0.9


def is_text_safe(text: str) -> tuple[bool, str | None]:
    """Return (is_safe, offending_label). Empty text is considered safe."""
    if not text.strip():
        return True, None
    try:
        results = safety_classifier(text)
        if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
            print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
            return False, results[0]['label']
        return True, None
    except Exception as e:
        # Fail closed: treat an unavailable safety check as unsafe.
        print(f"Error during safety check: {e}")
        return False, "safety_check_failed"


# --- Main Model Loading (using Transformers) ---
print(f"Loading tokenizer for {MODEL_REPO_ID}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
    print("Tokenizer loaded.")
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    print("Make sure the model ID is correct and, if it is a private repo, that the HF_TOKEN secret is set in your Space.")
    exit(1)

print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and may take a long time)...")
try:
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32)
    model.eval()
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Ensure it is a standard Transformers model and that the HF_TOKEN secret is set if the repo is private.")
    exit(1)

# Configure generation for streaming.
generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
generation_config.max_new_tokens = MAX_TOKENS
generation_config.temperature = TEMPERATURE
generation_config.top_p = TOP_P
generation_config.do_sample = True
generation_config.eos_token_id = tokenizer.eos_token_id
# Fall back to the EOS token (or token id 0) if the tokenizer defines no pad token.
if tokenizer.pad_token_id is not None:
    generation_config.pad_token_id = tokenizer.pad_token_id
elif tokenizer.eos_token_id is not None:
    generation_config.pad_token_id = tokenizer.eos_token_id
else:
    generation_config.pad_token_id = 0


# --- Custom Streamer for Gradio and Safety Check ---
class GradioSafetyStreamer(TextIteratorStreamer):
    """Buffers streamed text into sentences and safety-checks each one before yielding it."""

    def __init__(self, tokenizer, safety_checker_fn, toxicity_threshold, skip_special_tokens=True, **kwargs):
        super().__init__(tokenizer, skip_special_tokens=skip_special_tokens, **kwargs)
        self.safety_checker_fn = safety_checker_fn
        self.toxicity_threshold = toxicity_threshold
        self.current_sentence_buffer = ""
        self.output_queue = []
        # Split on whitespace that follows sentence-ending punctuation, so the
        # punctuation stays attached to each sentence instead of being discarded.
        self.sentence_regex = re.compile(r'(?<=[.!?])\s+')
        self.text_done = threading.Event()

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.current_sentence_buffer += text
        sentences = self.sentence_regex.split(self.current_sentence_buffer)
        if not stream_end and sentences and not sentences[-1].rstrip().endswith(('.', '!', '?')):
            # Hold back the trailing, still-incomplete sentence until more tokens arrive.
            sentences_to_process = sentences[:-1]
            self.current_sentence_buffer = sentences[-1]
        else:
            sentences_to_process = sentences
            self.current_sentence_buffer = ""

        for sentence in sentences_to_process:
            if not sentence.strip():
                continue
            is_safe, detected_label = self.safety_checker_fn(sentence)
            if not is_safe:
                print(f"Safety check failed for: '{sentence.strip()}' (Detected: {detected_label})")
                self.output_queue.append("[Content removed due to safety guidelines]")
                self.output_queue.append("__STOP_GENERATION__")
                return
            # Re-insert the separator consumed by the sentence split.
            self.output_queue.append(sentence + " ")

        if stream_end:
            if self.current_sentence_buffer.strip():
                is_safe, detected_label = self.safety_checker_fn(self.current_sentence_buffer)
                if not is_safe:
                    self.output_queue.append("[Content removed due to safety guidelines]")
                else:
                    self.output_queue.append(self.current_sentence_buffer)
                self.current_sentence_buffer = ""
            self.text_done.set()

    def __iter__(self):
        # Drain the queue until generation finishes. Note: PEP 479 forbids
        # raising StopIteration inside a generator; use `return` instead.
        while True:
            if self.output_queue:
                item = self.output_queue.pop(0)
                if item == "__STOP_GENERATION__":
                    return
                yield item
            elif self.text_done.is_set():
                return
            else:
                time.sleep(0.01)


# --- Inference Function with Safety and Streaming ---
def generate_word_by_word_with_safety(prompt_text: str):
    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
    input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)

    # skip_prompt=True keeps the prompt itself out of the streamed (and safety-checked) output.
    streamer = GradioSafetyStreamer(tokenizer, is_text_safe, TOXICITY_THRESHOLD, skip_prompt=True)

    # Sampling parameters (temperature, top_p, max_new_tokens, etc.) are
    # already set on generation_config above, so they are not repeated here.
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "generation_config": generation_config,
    }

    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    full_generated_text = ""
    try:
        for new_sentence_or_chunk in streamer:
            full_generated_text += new_sentence_or_chunk
            yield full_generated_text
    except Exception as e:
        print(f"Error during streaming: {e}")
        yield full_generated_text + f"\n\n[Error during streaming: {e}]"
    finally:
        thread.join()


# --- Gradio Blocks Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # SmilyAI: Sam-reason-S3 Inference (Transformers on CPU with Safety Filter)
        Enter a prompt and get a sentence-by-sentence streamed response from the **Smilyai-labs/Sam-reason-S3** model.

        **⚠️ WARNING: This model is running on a free CPU tier via the `transformers` library. Inference will be VERY slow.**

        All generated sentences are checked for safety using an AI filter; unsafe content will be replaced.
        """
    )
    with gr.Row():
        user_prompt = gr.Textbox(
            lines=5,
            label="Enter your prompt here:",
            placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
            scale=4
        )
        generated_text = gr.Textbox(label="Generated Text", show_copy_button=True, scale=6)

    send_button = gr.Button("Send", variant="primary")
    send_button.click(
        fn=generate_word_by_word_with_safety,
        inputs=user_prompt,
        outputs=generated_text,
        api_name="predict",
    )

if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
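
# --- Usage note (illustrative, not part of the app) ---
# A minimal sketch of calling this Space's "predict" endpoint from Python with
# gradio_client. The Space ID below is hypothetical; substitute the ID under
# which this app is actually deployed.
#
#   from gradio_client import Client
#   client = Client("Smilyai-labs/sam-reason-s3-demo")  # hypothetical Space ID
#   result = client.predict("Explain quantum entanglement simply.", api_name="/predict")
#   print(result)  # final text; use client.submit(...) to iterate streamed chunks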