import gradio as gr
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
from transformers import pipeline
import re  # For sentence splitting

# --- Model Configuration (same as before) ---
MODEL_REPO_ID = "mradermacher/Sam-reason-v3-GGUF"
MODEL_FILENAME = "Sam-reason-v3.Q4_K_M.gguf"
N_CTX = 2048
MAX_TOKENS = 500
TEMPERATURE = 0.7
TOP_P = 0.9
STOP_SEQUENCES = ["USER:", "\n\n"]  # Generation stops when any of these sequences appears
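# Note: with the "USER: ...\nASSISTANT:" prompt template used below, stopping on
# "USER:" should keep the model from inventing a follow-up user turn on its own.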

# --- Safety Configuration ---
# Initialize the toxicity classifier pipeline.
# This model scores text as 'toxic' with a confidence value.
print("Loading safety model (unitary/toxic-bert)...")
try:
    safety_classifier = pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        framework="pt"  # Use PyTorch backend
    )
    print("Safety model loaded successfully.")
except Exception as e:
    print(f"Error loading safety model: {e}")
    # Consider handling this error more gracefully, e.g., running without the
    # safety filter if the model fails to load.
    exit(1)

# Threshold for flagging content as unsafe (0.0 to 1.0).
# A higher threshold is stricter about what gets flagged: the classifier must be
# more confident before content is removed, so non-toxic text is less likely to
# be flagged by mistake.
TOXICITY_THRESHOLD = 0.9


def is_text_safe(text: str) -> tuple[bool, str | None]:
    """
    Checks whether the given text contains unsafe content using the safety classifier.
    Returns (True, None) if safe, or (False, detected_label) if unsafe.
    """
    if not text.strip():
        return True, None  # Empty strings are safe
    try:
        # Classify the text. unitary/toxic-bert is a multi-label toxicity classifier
        # (labels include 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult' and
        # 'identity_hate'); the pipeline returns the top-scoring label, e.g.
        # [{'label': 'toxic', 'score': 0.98}]. Only the 'toxic' label is checked here.
        results = safety_classifier(text)
        if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
            print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
            return False, results[0]['label']
        return True, None
    except Exception as e:
        print(f"Error during safety check: {e}")
        # If the safety check itself fails, treat the text as unsafe by default.
        # A more robust solution might re-raise or surface an error message instead.
        return False, "safety_check_failed"
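# Illustrative example of the expected behavior (scores are hypothetical):
#   >>> safety_classifier("Have a great day!")
#   [{'label': 'toxic', 'score': 0.0007}]
#   >>> is_text_safe("Have a great day!")
#   (True, None)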

# --- Main Model Loading (same as before) ---
print(f"Downloading model: {MODEL_FILENAME} from {MODEL_REPO_ID}...")
try:
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
    print(f"Model downloaded to: {model_path}")
except Exception as e:
    print(f"Error downloading model: {e}")
    exit(1)

print("Initializing Llama model (this may take a moment)...")
try:
    llm = Llama(
        model_path=model_path,
        n_gpu_layers=0,  # Force CPU usage
        n_ctx=N_CTX,
        verbose=False
    )
    print("Llama model initialized successfully.")
except Exception as e:
    print(f"Error initializing Llama model: {e}")
    exit(1)


# --- Inference Function with Safety ---
def generate_word_by_word_with_safety(prompt_text: str):
    """
    Streams generated text, buffering it into sentences and checking each sentence
    for safety before it is shown.
    """
    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
    current_sentence_buffer = ""
    full_output_so_far = ""
    displayed_text = ""  # Accumulated safe text shown in the UI

    # Stream tokens from the main LLM
    token_stream = llm.create_completion(
        formatted_prompt,
        max_tokens=MAX_TOKENS,
        stop=STOP_SEQUENCES,
        stream=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
    )
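    # Each streamed chunk from llama-cpp-python roughly has the shape
    # {'choices': [{'text': ' next token', ...}], ...}; only the text piece is used below.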
    for chunk in token_stream:
        token = chunk["choices"][0]["text"]
        current_sentence_buffer += token
        full_output_so_far += token  # Keep track of the full output for a comprehensive check if needed

        # Simple sentence detection (look for common sentence endings), with a
        # length fallback so overly long sentences are still checked.
        if re.search(r'[.!?]\s*$', current_sentence_buffer) or len(current_sentence_buffer) > 100:
            is_safe, detected_label = is_text_safe(current_sentence_buffer)
            if not is_safe:
                print(f"Safety check failed for sentence: '{current_sentence_buffer.strip()}' (Detected: {detected_label})")
                displayed_text += "[Content removed due to safety guidelines]"  # Replace unsafe content
                # Optionally: `return` here to stop generation entirely once the
                # first unsafe sentence is found.
            else:
                displayed_text += current_sentence_buffer  # Append the safe sentence
            current_sentence_buffer = ""  # Clear buffer for the next sentence
            # Gradio replaces the output component with each yielded value, so yield
            # the accumulated text rather than only the latest sentence.
            yield displayed_text

    # After the loop, check and yield any remaining text in the buffer
    if current_sentence_buffer.strip():
        is_safe, detected_label = is_text_safe(current_sentence_buffer)
        if not is_safe:
            print(f"Safety check failed for remaining text: '{current_sentence_buffer.strip()}' (Detected: {detected_label})")
            displayed_text += "[Content removed due to safety guidelines]"
        else:
            displayed_text += current_sentence_buffer
        yield displayed_text


# --- Gradio Blocks Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # SmilyAI: Sam-reason-v3-GGUF Word-by-Word Inference (CPU with Safety Filter)
        Enter a prompt and get a streamed, sentence-by-sentence response from the Sam-reason-v3-GGUF model.
        **Please note:** All generated sentences are checked for safety using an AI filter.
        Potentially unsafe content will be replaced with `[Content removed due to safety guidelines]`.
        Running on Hugging Face Spaces' free CPU tier.
        """
    )
    with gr.Row():
        user_prompt = gr.Textbox(
            lines=5,
            label="Enter your prompt here:",
            placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
            scale=4
        )
        generated_text = gr.Textbox(label="Generated Text", show_copy_button=True, scale=6)

    send_button = gr.Button("Send", variant="primary")

    # Connect the button click to the safety-enabled inference function
    send_button.click(
        fn=generate_word_by_word_with_safety,
        inputs=user_prompt,
        outputs=generated_text,
        api_name="predict",
    )

# Launch the Gradio application
if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
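
# Minimal client-side usage sketch (assumes the `gradio_client` package is installed
# and the app is reachable at http://localhost:7860; the endpoint name follows the
# api_name set on the click handler above):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   print(client.predict("Explain quantum entanglement simply.", api_name="/predict"))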