import re
import threading
import time

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TextIteratorStreamer,
    pipeline,
)
# --- Model Configuration ---
MODEL_REPO_ID = "Smilyai-labs/Sam-reason-S3"
MAX_TOKENS = 500
TEMPERATURE = 0.7
TOP_P = 0.9
# --- Safety Configuration ---
print("Loading safety model (unitary/toxic-bert)...")
try:
    # unitary/toxic-bert scores text for toxicity; it is used to screen each sentence.
    safety_classifier = pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        framework="pt",
    )
    print("Safety model loaded successfully.")
except Exception as e:
    print(f"Error loading safety model: {e}")
    exit(1)

TOXICITY_THRESHOLD = 0.9  # sentences scoring above this are suppressed
def is_text_safe(text: str) -> tuple[bool, str | None]:
    """Return (is_safe, offending_label), flagging text the classifier
    labels 'toxic' with a score above TOXICITY_THRESHOLD."""
    if not text.strip():
        return True, None
    try:
        results = safety_classifier(text)
        if results and results[0]["label"] == "toxic" and results[0]["score"] > TOXICITY_THRESHOLD:
            print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
            return False, results[0]["label"]
        return True, None
    except Exception as e:
        # Fail closed: treat a failed check as unsafe.
        print(f"Error during safety check: {e}")
        return False, "safety_check_failed"
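# Illustrative behaviour (exact scores depend on the toxic-bert checkpoint,
# so these outcomes are expectations, not guarantees):
#   is_text_safe("Have a lovely day!")  # -> (True, None)
#   is_text_safe("")                    # -> (True, None); empty text is skipped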
# --- Main Model Loading (using Transformers) ---
print(f"Loading tokenizer for {MODEL_REPO_ID}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
    print("Tokenizer loaded.")
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    print("Make sure the model ID is correct and, if the repo is private, that the HF_TOKEN secret is set in your Space.")
    exit(1)
print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and might take a long time)...") | |
try: | |
model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32) | |
model.eval() | |
print("Model loaded successfully.") | |
except Exception as e: | |
print(f"Error loading model: {e}") | |
print("Ensure it's a standard Transformers model and you have HF_TOKEN secret if private.") | |
exit(1) | |
# Configure generation for streaming.
generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
generation_config.max_new_tokens = MAX_TOKENS
generation_config.temperature = TEMPERATURE
generation_config.top_p = TOP_P
generation_config.do_sample = True
if tokenizer.eos_token_id is not None:
    generation_config.eos_token_id = tokenizer.eos_token_id
# Fall back to the EOS token, then to token id 0, if no pad token is defined.
if tokenizer.pad_token_id is not None:
    generation_config.pad_token_id = tokenizer.pad_token_id
elif tokenizer.eos_token_id is not None:
    generation_config.pad_token_id = tokenizer.eos_token_id
else:
    generation_config.pad_token_id = 0
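# With do_sample=True, TEMPERATURE rescales the token logits before sampling
# and TOP_P keeps only the smallest set of tokens whose cumulative probability
# exceeds 0.9 (nucleus sampling); lowering either makes output more deterministic.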
# --- Custom Streamer for Gradio and Safety Check ---
class GradioSafetyStreamer(TextIteratorStreamer):
    """Buffers streamed text into sentences and runs each sentence through
    the safety classifier before releasing it to the Gradio UI."""

    def __init__(self, tokenizer, safety_checker_fn, skip_special_tokens=True, **kwargs):
        # skip_prompt=True avoids echoing the USER/ASSISTANT prompt back into the output.
        super().__init__(tokenizer, skip_prompt=True, skip_special_tokens=skip_special_tokens, **kwargs)
        self.safety_checker_fn = safety_checker_fn
        self.current_sentence_buffer = ""
        self.output_queue = []
        # Split after sentence-ending punctuation so the punctuation stays
        # attached to its sentence (a plain split on [.!?] would discard it).
        self.sentence_regex = re.compile(r"(?<=[.!?])\s+")
        self.text_done = threading.Event()

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.current_sentence_buffer += text
        sentences = self.sentence_regex.split(self.current_sentence_buffer)
        if not stream_end and sentences and not re.search(r"[.!?]\s*$", sentences[-1]):
            # The last fragment is an unfinished sentence; hold it back.
            sentences_to_process = sentences[:-1]
            self.current_sentence_buffer = sentences[-1]
        else:
            sentences_to_process = sentences
            self.current_sentence_buffer = ""
        for sentence in sentences_to_process:
            if not sentence.strip():
                continue
            is_safe, detected_label = self.safety_checker_fn(sentence)
            if not is_safe:
                print(f"Safety check failed for: '{sentence.strip()}' (Detected: {detected_label})")
                self.output_queue.append("[Content removed due to safety guidelines]")
                self.output_queue.append("__STOP_GENERATION__")
                return
            self.output_queue.append(sentence + " ")  # re-insert the space consumed by the split
        if stream_end:
            if self.current_sentence_buffer.strip():
                is_safe, detected_label = self.safety_checker_fn(self.current_sentence_buffer)
                if not is_safe:
                    self.output_queue.append("[Content removed due to safety guidelines]")
                else:
                    self.output_queue.append(self.current_sentence_buffer)
                self.current_sentence_buffer = ""
            self.text_done.set()

    def __iter__(self):
        # Drain the sentence queue until generation finishes or is cut off.
        while True:
            if self.output_queue:
                item = self.output_queue.pop(0)
                if item == "__STOP_GENERATION__":
                    return  # raising StopIteration inside a generator is a RuntimeError (PEP 479)
                yield item
            elif self.text_done.is_set():
                return
            else:
                time.sleep(0.01)  # avoid a hot spin while waiting for new sentences
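# How the pieces fit together: model.generate() runs on a worker thread and
# feeds decoded text into the streamer, which calls on_finalized_text(); the
# Gradio handler iterates over the streamer on its own thread, so sentences
# reach the UI as soon as they pass the safety check.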
# --- Inference Function with Safety and Streaming ---
def generate_word_by_word_with_safety(prompt_text: str):
    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
    input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)
    streamer = GradioSafetyStreamer(tokenizer, is_text_safe)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "generation_config": generation_config,  # sampling settings (temperature, top_p, etc.) live here
    }
    # Run generation on a background thread so the stream can be consumed here.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    full_generated_text = ""
    try:
        for new_sentence_or_chunk in streamer:
            full_generated_text += new_sentence_or_chunk
            yield full_generated_text
    except Exception as e:
        print(f"Error during streaming: {e}")
        yield full_generated_text + f"\n\n[Error during streaming: {e}]"
    finally:
        # Note: if output was cut off by the safety filter, generate() still
        # runs to completion in the background before this join returns.
        thread.join()
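# Because the handler above is a generator, Gradio treats each yielded value
# as an incremental update to the output textbox (built-in streaming support).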
# --- Gradio Blocks Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # SmilyAI: Sam-reason-S3 Inference (Transformers on CPU with Safety Filter)
        Enter a prompt and get a sentence-by-sentence streamed response from the **Smilyai-labs/Sam-reason-S3** model.
        **⚠️ WARNING: This model is running on a free CPU tier via the `transformers` library. Inference will be VERY slow.**
        All generated sentences are checked for safety by an AI filter; unsafe content is replaced.
        """
    )
    with gr.Row():
        user_prompt = gr.Textbox(
            lines=5,
            label="Enter your prompt here:",
            placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
            scale=4,
        )
        generated_text = gr.Textbox(label="Generated Text", show_copy_button=True, scale=6)
    send_button = gr.Button("Send", variant="primary")
    send_button.click(
        fn=generate_word_by_word_with_safety,
        inputs=user_prompt,
        outputs=generated_text,
        api_name="predict",
    )
if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
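# --- Example: calling the Space programmatically ---
# A minimal sketch using gradio_client (install with `pip install gradio_client`).
# The Space id below is a hypothetical placeholder; substitute the actual
# "<user>/<space>" of the deployed app.
#
#   from gradio_client import Client
#   client = Client("Smilyai-labs/your-space-name")  # hypothetical Space id
#   result = client.predict("Explain quantum entanglement simply.", api_name="/predict")
#   print(result)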