# Source: Sam-S-3-api / app.py
# Last change: commit 7aef8f2 "Update app.py" by boning123 (verified).
import os
import queue
import re
import threading
import time

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig, pipeline # <-- Changed import here
from transformers import StoppingCriteria, StoppingCriteriaList
# --- Model Configuration ---
# Hugging Face Hub repository that provides the tokenizer, the model weights and
# the generation config loaded below.
MODEL_REPO_ID: str = "Smilyai-labs/Sam-reason-S3"

# Sampling settings copied into the generation config.
MAX_TOKENS: int = 500  # used as max_new_tokens
TEMPERATURE: float = 0.7
TOP_P: float = 0.9

# NOTE(review): N_CTX and STOP_SEQUENCES are not referenced anywhere else in this
# file; confirm whether they are still needed before relying on them.
N_CTX: int = 2048
STOP_SEQUENCES: list[str] = ["USER:", "\n\n"]
# --- Safety Configuration ---
# The toxicity classifier is loaded once at import time. Without it the app cannot
# filter model output, so a load failure terminates the process.
print("Loading safety model (unitary/toxic-bert)...")
try:
    safety_classifier = pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        framework="pt",
    )
    print("Safety model loaded successfully.")
except Exception as e:
    print(f"Error loading safety model: {e}")
    # `exit()` is an interactive-shell helper injected by the `site` module and is
    # not guaranteed to exist; raising SystemExit directly has the same effect.
    raise SystemExit(1)
# Minimum classifier score (strictly exceeded) for a 'toxic' prediction to be
# treated as unsafe by is_text_safe.
TOXICITY_THRESHOLD = 0.9
def is_text_safe(text: str) -> tuple[bool, str | None]:
if not text.strip():
return True, None
try:
results = safety_classifier(text)
if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
return False, results[0]['label']
return True, None
except Exception as e:
print(f"Error during safety check: {e}")
return False, "safety_check_failed"
# --- Main Model Loading (using Transformers) ---
# Tokenizer and model are loaded once at import time; either failure terminates
# the process with exit status 1 after printing a hint.
print(f"Loading tokenizer for {MODEL_REPO_ID}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
    print("Tokenizer loaded.")
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    print("Make sure the model ID is correct and, if it's a private repo, you've set the HF_TOKEN secret in your Space.")
    # `exit()` comes from the `site` module and is meant for interactive use only.
    raise SystemExit(1)
print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and might take a long time)...")
try:
    # Full-precision weights placed on the CPU.
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32)
    # Switch to evaluation mode: this model is only used for inference.
    model.eval()
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Ensure it's a standard Transformers model and you have HF_TOKEN secret if private.")
    raise SystemExit(1)
# Configure generation for streaming
# Start from the repo's generation_config.json. Repos that do not ship one fall back
# to defaults derived from the loaded model's config instead of crashing at startup.
try:
    generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
except OSError as e:
    print(f"No generation config found for {MODEL_REPO_ID} ({e}); using model defaults.")
    generation_config = GenerationConfig.from_model_config(model.config)
generation_config.max_new_tokens = MAX_TOKENS
generation_config.temperature = TEMPERATURE
generation_config.top_p = TOP_P
# temperature / top_p only take effect when sampling is enabled.
generation_config.do_sample = True
# Prefer the tokenizer's EOS token. Otherwise keep whatever the config already
# specifies, rather than overwriting it with an invalid -1 id.
if tokenizer.eos_token_id is not None:
    generation_config.eos_token_id = tokenizer.eos_token_id
# Padding id: the tokenizer's pad token, else its EOS token, else 0.
if tokenizer.pad_token_id is not None:
    generation_config.pad_token_id = tokenizer.pad_token_id
elif tokenizer.eos_token_id is not None:
    generation_config.pad_token_id = tokenizer.eos_token_id
else:
    generation_config.pad_token_id = 0
# --- Custom Streamer for Gradio and Safety Check ---
class GradioSafetyStreamer(TextIteratorStreamer):
    """Text streamer that releases model output one safety-checked sentence at a time.

    ``model.generate`` (normally running in a worker thread) pushes decoded text into
    :meth:`on_finalized_text`. Text is buffered until a sentence boundary is seen. A
    boundary is one or more of ``.``, ``!``, ``?`` followed by any whitespace. Each
    complete sentence is passed to ``safety_checker_fn`` and, if safe, queued verbatim
    (punctuation and spacing included) for the consumer, which iterates the streamer.
    The first unsafe sentence is replaced by a removal notice. Iteration then ends,
    and any further model output is discarded.

    Args:
        tokenizer: Tokenizer the parent ``TextIteratorStreamer`` uses to decode tokens.
        safety_checker_fn: Callable ``text -> (is_safe, detected_label)``.
        toxicity_threshold: Stored as ``self.toxicity_threshold``; not used by this
            class itself (``safety_checker_fn`` applies its own threshold).
        skip_special_tokens: Forwarded to the parent as a decode option.
        **kwargs: Forwarded to ``TextIteratorStreamer`` (e.g. ``skip_prompt``, ``timeout``).
    """

    # Queued to tell the consuming iterator that no more text will follow.
    _END_OF_STREAM = object()
    _REMOVED_NOTICE = "[Content removed due to safety guidelines]"

    def __init__(self, tokenizer, safety_checker_fn, toxicity_threshold, skip_special_tokens=True, **kwargs):
        super().__init__(tokenizer, skip_special_tokens=skip_special_tokens, **kwargs)
        self.safety_checker_fn = safety_checker_fn
        self.toxicity_threshold = toxicity_threshold
        # Text received from the model that does not yet end at a sentence boundary.
        self.current_sentence_buffer = ""
        # Thread-safe hand-off from the generation thread to the consumer. A blocking
        # queue with an end sentinel cannot lose the last chunk the way a list plus a
        # separately checked "done" flag can.
        self.output_queue = queue.Queue()
        # A sentence ends at a run of terminal punctuation plus any trailing whitespace.
        self.sentence_regex = re.compile(r'[.!?]+\s*')
        # Set once the end of the stream has been reached.
        self.text_done = threading.Event()
        # True after an unsafe sentence was found; later text is dropped unchecked.
        self.blocked = False

    def _take_sentences(self, flush: bool) -> list[str]:
        """Remove and return complete sentences from the buffer; with ``flush``, also the remainder."""
        buffer = self.current_sentence_buffer
        sentences = []
        start = 0
        for match in self.sentence_regex.finditer(buffer):
            # Slice the original text so punctuation and whitespace are preserved.
            sentences.append(buffer[start:match.end()])
            start = match.end()
        remainder = buffer[start:]
        if flush and remainder:
            sentences.append(remainder)
            remainder = ""
        self.current_sentence_buffer = remainder
        return sentences

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Receive decoded text from ``generate``, safety-check complete sentences and queue them."""
        if not self.blocked:
            self.current_sentence_buffer += text
            for sentence in self._take_sentences(flush=stream_end):
                # Whitespace-only pieces need no check and are passed through as-is.
                if sentence.strip():
                    is_safe, detected_label = self.safety_checker_fn(sentence)
                    if not is_safe:
                        print(f"Safety check failed for: '{sentence.strip()}' (Detected: {detected_label})")
                        self.blocked = True
                        self.current_sentence_buffer = ""
                        self.output_queue.put(self._REMOVED_NOTICE)
                        # End iteration now, even though generation may still be running.
                        self.output_queue.put(self._END_OF_STREAM)
                        break
                self.output_queue.put(sentence)
        if stream_end:
            self.text_done.set()
            self.output_queue.put(self._END_OF_STREAM)

    def __iter__(self):
        """Yield queued sentences until the end of the stream or a safety stop."""
        while True:
            # Blocks until text arrives; honours the parent's ``timeout`` option
            # (None waits forever, otherwise queue.Empty is raised).
            item = self.output_queue.get(timeout=self.timeout)
            if item is self._END_OF_STREAM:
                # `return`, not `raise StopIteration`: inside a generator the latter
                # is turned into RuntimeError (PEP 479).
                return
            yield item
# --- Inference Function with Safety and Streaming ---
class _StopWhenSet(StoppingCriteria):
    """Stopping criterion that halts ``generate`` once the given event is set."""

    def __init__(self, event: threading.Event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs):
        # One flag per sequence; recent transformers versions combine criteria as bool tensors.
        return torch.full((input_ids.shape[0],), self.event.is_set(), dtype=torch.bool, device=input_ids.device)


def generate_word_by_word_with_safety(prompt_text: str):
    """Stream a safety-filtered model reply to ``prompt_text`` for Gradio.

    Runs ``model.generate`` in a background thread and yields the accumulated reply
    after every chunk released by ``GradioSafetyStreamer``. Streaming or generation
    errors are appended to the text as a bracketed notice instead of being raised.

    Args:
        prompt_text: The user's message, wrapped in the ``USER:`` / ``ASSISTANT:`` template.

    Yields:
        The full reply text generated so far.
    """
    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
    input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)
    # generate() first pushes the prompt ids through the streamer. skip_prompt=True
    # stops the prompt from being echoed (and safety-checked) as model output.
    streamer = GradioSafetyStreamer(tokenizer, is_text_safe, TOXICITY_THRESHOLD, skip_prompt=True)
    # Set when we stop consuming, so generation halts instead of running to max_new_tokens.
    stop_event = threading.Event()
    generation_errors = []

    def run_generation():
        try:
            # Sampling, length and special-token settings all come from generation_config.
            model.generate(
                input_ids=input_ids,
                streamer=streamer,
                generation_config=generation_config,
                stopping_criteria=StoppingCriteriaList([_StopWhenSet(stop_event)]),
            )
        except Exception as e:
            print(f"Error during generation: {e}")
            generation_errors.append(e)
            # generate() only signals end-of-stream when it completes normally;
            # without this the consumer loop below would wait forever.
            streamer.end()

    thread = threading.Thread(target=run_generation)
    thread.start()
    full_generated_text = ""
    try:
        for new_sentence_or_chunk in streamer:
            full_generated_text += new_sentence_or_chunk
            yield full_generated_text
        if generation_errors:
            yield full_generated_text + f"\n\n[Error during generation: {generation_errors[0]}]"
    except Exception as e:
        print(f"Error during streaming: {e}")
        yield full_generated_text + f"\n\n[Error during streaming: {e}]"
    finally:
        # Also reached after a safety stop or a client disconnect (GeneratorExit):
        # ask generate() to stop so join() does not wait for the full token budget.
        stop_event.set()
        thread.join()
# --- Gradio Blocks Interface ---
# Header text shown at the top of the page.
_INTRO_MARKDOWN = """
# SmilyAI: Sam-reason-S3 Inference (Transformers on CPU with Safety Filter)
Enter a prompt and get a word-by-word response from the **Smilyai-labs/Sam-reason-S3** model.
**⚠️ WARNING: This model is running on a free CPU tier via the `transformers` library. Inference will be VERY slow.**
All generated sentences are checked for safety using an AI filter; unsafe content will be replaced.
"""

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(_INTRO_MARKDOWN)
    # Prompt box (scale 4) and output box (scale 6) share one row.
    with gr.Row():
        user_prompt = gr.Textbox(
            label="Enter your prompt here:",
            placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
            lines=5,
            scale=4,
        )
        generated_text = gr.Textbox(scale=6, label="Generated Text", show_copy_button=True)
    send_button = gr.Button("Send", variant="primary")
    # The handler is a generator, so each yielded value updates the output box;
    # api_name sets the endpoint name API clients call.
    send_button.click(
        generate_word_by_word_with_safety,
        inputs=user_prompt,
        outputs=generated_text,
        api_name="predict",
    )
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860 when run as a script.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    print("Launching Gradio app...")
    demo.launch(**launch_options)