import gradio as gr
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
from transformers import pipeline
import re  # For sentence splitting

# --- Model Configuration ---
MODEL_REPO_ID = "mradermacher/Sam-reason-v3-GGUF"
MODEL_FILENAME = "Sam-reason-v3.Q4_K_M.gguf"
N_CTX = 2048
MAX_TOKENS = 500
TEMPERATURE = 0.7
TOP_P = 0.9
STOP_SEQUENCES = ["USER:", "\n\n"]  # Model will stop generating when it encounters these

# --- Safety Configuration ---
# Initialize the toxicity classifier pipeline
# This model identifies if text is 'toxic' with a confidence score.
print("Loading safety model (unitary/toxic-bert)...")
try:
    safety_classifier = pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        framework="pt"  # Use the PyTorch backend
    )
    print("Safety model loaded successfully.")
except Exception as e:
    print(f"Error loading safety model: {e}")
    # Consider handling this error more gracefully, e.g., running without the safety filter if the model fails to load.
    exit(1)

# Threshold for flagging content as unsafe (0.0 to 1.0).
# A higher threshold requires the classifier to be more confident before content is flagged,
# so non-toxic text is less likely to be removed, but borderline content is more likely to pass.
TOXICITY_THRESHOLD = 0.9

def is_text_safe(text: str) -> tuple[bool, str | None]:
    """
    Checks whether the given text contains unsafe content using the safety classifier.
    Returns (True, None) if safe, or (False, detected_label) if unsafe.
    """
    if not text.strip():
        return True, None  # Empty strings are safe

    try:
        # Classify the text. For unitary/toxic-bert, the pipeline returns the
        # highest-scoring label among its toxicity categories, with 'toxic' as
        # the label of interest here, e.g. [{'label': 'toxic', 'score': 0.98}].
        results = safety_classifier(text)
        if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
            print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
            return False, results[0]['label']
        return True, None
    except Exception as e:
        print(f"Error during safety check: {e}")
        # If the safety check itself fails, treat the text as unsafe by default.
        # A more robust solution might re-raise or surface an error message instead.
        return False, "safety_check_failed"
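
# Illustrative sketch of the expected behaviour (hypothetical inputs and scores, not
# executed at runtime), assuming the classifier's top label for toxic text is 'toxic':
#   is_text_safe("Have a nice day.")          -> (True, None)
#   is_text_safe("<some clearly toxic text>") -> (False, "toxic")  # when score > TOXICITY_THRESHOLD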

# --- Main Model Loading ---
print(f"Downloading model: {MODEL_FILENAME} from {MODEL_REPO_ID}...")
try:
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
    print(f"Model downloaded to: {model_path}")
except Exception as e:
    print(f"Error downloading model: {e}")
    exit(1)

print("Initializing Llama model (this may take a moment)...")
try:
    llm = Llama(
        model_path=model_path,
        n_gpu_layers=0,  # Force CPU usage
        n_ctx=N_CTX,
        verbose=False
    )
    print("Llama model initialized successfully.")
except Exception as e:
    print(f"Error initializing Llama model: {e}")
    exit(1)

# --- Inference Function with Safety ---
def generate_word_by_word_with_safety(prompt_text: str):
    """
    Streams the model's response, checking each completed sentence for safety
    before it is added to the output.
    """
    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
    current_sentence_buffer = ""
    full_output_so_far = ""  # Accumulates the safety-filtered text shown to the user

    # Stream tokens from the main LLM
    token_stream = llm.create_completion(
        formatted_prompt,
        max_tokens=MAX_TOKENS,
        stop=STOP_SEQUENCES,
        stream=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
    )

    for chunk in token_stream:
        token = chunk["choices"][0]["text"]
        current_sentence_buffer += token

        # Simple sentence detection: look for common sentence endings,
        # with a maximum buffer length as a fallback.
        if re.search(r'[.!?]\s*$', current_sentence_buffer) or len(current_sentence_buffer) > 100:
            is_safe, detected_label = is_text_safe(current_sentence_buffer)
            if not is_safe:
                print(f"Safety check failed for sentence: '{current_sentence_buffer.strip()}' (Detected: {detected_label})")
                full_output_so_far += "[Content removed due to safety guidelines]"  # Replace unsafe content
                # Optionally stop generation entirely once the first unsafe sentence
                # is found by returning here instead of continuing.
            else:
                full_output_so_far += current_sentence_buffer
            current_sentence_buffer = ""  # Clear the buffer for the next sentence
            # Gradio replaces the output component with each yielded value,
            # so yield the accumulated text rather than only the latest sentence.
            yield full_output_so_far

    # After the loop, check and yield any remaining text in the buffer
    if current_sentence_buffer.strip():
        is_safe, detected_label = is_text_safe(current_sentence_buffer)
        if not is_safe:
            print(f"Safety check failed for remaining text: '{current_sentence_buffer.strip()}' (Detected: {detected_label})")
            full_output_so_far += "[Content removed due to safety guidelines]"
        else:
            full_output_so_far += current_sentence_buffer
        yield full_output_so_far
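
# Illustrative usage sketch (hypothetical prompt, not executed at runtime):
#   for partial in generate_word_by_word_with_safety("Tell me a short story."):
#       print(partial)  # each value is the accumulated, safety-filtered text so far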

# --- Gradio Blocks Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # SmilyAI: Sam-reason-v3-GGUF Word-by-Word Inference (CPU with Safety Filter)

        Enter a prompt and get a word-by-word response from the Sam-reason-v3-GGUF model.

        **Please note:** All generated sentences are checked for safety using an AI filter.
        Potentially unsafe content will be replaced with `[Content removed due to safety guidelines]`.

        Running on Hugging Face Spaces' free CPU tier.
        """
    )
    with gr.Row():
        user_prompt = gr.Textbox(
            lines=5,
            label="Enter your prompt here:",
            placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
            scale=4
        )
        generated_text = gr.Textbox(label="Generated Text", show_copy_button=True, scale=6)
    send_button = gr.Button("Send", variant="primary")

    # Connect the button click to the inference function with safety check
    send_button.click(
        fn=generate_word_by_word_with_safety,  # Use the safety-enabled function
        inputs=user_prompt,
        outputs=generated_text,
        api_name="predict",
    )

# Launch the Gradio application
if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
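
# Illustrative client-side sketch (assumes the `gradio_client` package and a hypothetical
# Space URL; adjust to wherever this app is actually hosted):
#   from gradio_client import Client
#   client = Client("https://<your-space>.hf.space")
#   result = client.predict("Explain quantum entanglement in simple terms.", api_name="/predict")
#   print(result)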