import gradio as gr
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
from transformers import pipeline
import re  # For sentence splitting

# --- Model Configuration ---
MODEL_REPO_ID = "mradermacher/Sam-reason-v3-GGUF"
MODEL_FILENAME = "Sam-reason-v3.Q4_K_M.gguf"
N_CTX = 2048
MAX_TOKENS = 500
TEMPERATURE = 0.7
TOP_P = 0.9
STOP_SEQUENCES = ["USER:", "\n\n"]  # Generation stops when the model emits any of these

# --- Safety Configuration ---
# Initialize the toxicity classifier pipeline.
# This model scores text for toxicity with a confidence value.
print("Loading safety model (unitary/toxic-bert)...")
try:
    safety_classifier = pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        framework="pt"  # Use the PyTorch backend
    )
    print("Safety model loaded successfully.")
except Exception as e:
    print(f"Error loading safety model: {e}")
    # Consider handling this more gracefully, e.g., running without the safety filter if the model fails to load.
    exit(1)

# Threshold for flagging content as unsafe (0.0 to 1.0).
# A higher threshold requires more confidence before text is flagged:
# fewer false positives, but more borderline content slips through.
TOXICITY_THRESHOLD = 0.9


def is_text_safe(text: str) -> tuple[bool, str | None]:
    """
    Checks whether the given text contains unsafe content using the safety classifier.
    Returns (True, None) if safe, or (False, detected_label) if unsafe.
    """
    if not text.strip():
        return True, None  # Empty strings are safe

    try:
        # Classify the text. The pipeline returns the top label with its score,
        # e.g. [{'label': 'toxic', 'score': 0.98}]. For unitary/toxic-bert,
        # 'toxic' is the label we act on.
        results = safety_classifier(text)
        if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
            print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
            return False, results[0]['label']
        return True, None
    except Exception as e:
        print(f"Error during safety check: {e}")
        # Fail closed: if the safety check itself errors, treat the text as unsafe.
        # For a more robust solution, you might re-raise or surface an error message to the user.
        return False, "safety_check_failed"


# --- Main Model Loading ---
print(f"Downloading model: {MODEL_FILENAME} from {MODEL_REPO_ID}...")
try:
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
    print(f"Model downloaded to: {model_path}")
except Exception as e:
    print(f"Error downloading model: {e}")
    exit(1)

print("Initializing Llama model (this may take a moment)...")
try:
    llm = Llama(
        model_path=model_path,
        n_gpu_layers=0,  # Force CPU usage
        n_ctx=N_CTX,
        verbose=False
    )
    print("Llama model initialized successfully.")
except Exception as e:
    print(f"Error initializing Llama model: {e}")
    exit(1)


# --- Inference Function with Safety ---
def generate_word_by_word_with_safety(prompt_text: str):
    """
    Streams the model's response, buffering tokens into sentences and checking
    each sentence for safety before it is added to the displayed output.
    """
    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
    current_sentence_buffer = ""
    displayed_output = ""  # Accumulated vetted text; each yield replaces the Textbox value in Gradio

    # Stream tokens from the main LLM
    token_stream = llm.create_completion(
        formatted_prompt,
        max_tokens=MAX_TOKENS,
        stop=STOP_SEQUENCES,
        stream=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
    )

    for chunk in token_stream:
        token = chunk["choices"][0]["text"]
        current_sentence_buffer += token

        # Simple sentence detection (look for common sentence endings),
        # with a length cap as a fallback for sentences that never terminate.
        if re.search(r'[.!?]\s*$', current_sentence_buffer) or len(current_sentence_buffer) > 100:
            is_safe, detected_label = is_text_safe(current_sentence_buffer)
            if not is_safe:
                print(f"Safety check failed for sentence: '{current_sentence_buffer.strip()}' (Detected: {detected_label})")
                displayed_output += "[Content removed due to safety guidelines] "  # Replace unsafe content
                # Optionally: `return` here to stop generation entirely once unsafe content is found.
            else:
                displayed_output += current_sentence_buffer  # Append the safe sentence
            current_sentence_buffer = ""  # Clear buffer for the next sentence
            yield displayed_output

    # After the loop, check and yield any remaining text in the buffer
    if current_sentence_buffer.strip():
        is_safe, detected_label = is_text_safe(current_sentence_buffer)
        if not is_safe:
            print(f"Safety check failed for remaining text: '{current_sentence_buffer.strip()}' (Detected: {detected_label})")
            displayed_output += "[Content removed due to safety guidelines]"
        else:
            displayed_output += current_sentence_buffer
        yield displayed_output


# --- Gradio Blocks Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # SmilyAI: Sam-reason-v3-GGUF Streaming Inference (CPU with Safety Filter)

        Enter a prompt and the Sam-reason-v3-GGUF model will stream its response sentence by sentence.

        **Please note:** Each generated sentence is checked by an AI safety filter before it is shown.
        Potentially unsafe content is replaced with `[Content removed due to safety guidelines]`.

        Running on Hugging Face Spaces' free CPU tier.
        """
    )

    with gr.Row():
        user_prompt = gr.Textbox(
            lines=5,
            label="Enter your prompt here:",
            placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
            scale=4
        )
        generated_text = gr.Textbox(label="Generated Text", show_copy_button=True, scale=6)

    send_button = gr.Button("Send", variant="primary")

    # Connect the button click to the safety-filtered streaming function
    send_button.click(
        fn=generate_word_by_word_with_safety,
        inputs=user_prompt,
        outputs=generated_text,
        api_name="predict",
    )

# Launch the Gradio application
if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860)