import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig, pipeline
import re
import os
import torch
import threading
import time
# --- Model Configuration ---
MODEL_REPO_ID = "Smilyai-labs/Sam-reason-S3"
N_CTX = 2048  # Kept for reference; not used by the transformers generate() path below
MAX_TOKENS = 500
TEMPERATURE = 0.7
TOP_P = 0.9
STOP_SEQUENCES = ["USER:", "\n\n"]  # Kept for reference; generate() is not given stop strings here
# --- Safety Configuration ---
print("Loading safety model (unitary/toxic-bert)...")
try:
    # Use the pipeline function imported directly from transformers
    safety_classifier = pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        framework="pt"
    )
    print("Safety model loaded successfully.")
except Exception as e:
    print(f"Error loading safety model: {e}")
    exit(1)
TOXICITY_THRESHOLD = 0.9
def is_text_safe(text: str) -> tuple[bool, str | None]:
    """Return (is_safe, detected_label). Fails closed if the classifier errors."""
    if not text.strip():
        return True, None
    try:
        results = safety_classifier(text)
        if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
            print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
            return False, results[0]['label']
        return True, None
    except Exception as e:
        print(f"Error during safety check: {e}")
        # Fail closed: treat a classifier error as unsafe rather than letting text through
        return False, "safety_check_failed"
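# For reference, the text-classification pipeline returns a list of
# label/score dicts; the scores below are illustrative, not measured:
#
#     safety_classifier("some hateful text")
#     # -> [{'label': 'toxic', 'score': 0.97}]    (illustrative)
#     safety_classifier("have a nice day")
#     # -> [{'label': 'toxic', 'score': 0.0008}]  (illustrative)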
# --- Main Model Loading (using Transformers) ---
print(f"Loading tokenizer for {MODEL_REPO_ID}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
    print("Tokenizer loaded.")
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    print("Make sure the model ID is correct and, if it's a private repo, you've set the HF_TOKEN secret in your Space.")
    exit(1)
print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and might take a long time)...")
try:
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32)
    model.eval()
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Ensure it's a standard Transformers model and that the HF_TOKEN secret is set if the repo is private.")
    exit(1)
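# Note: float32 is the safe default for CPU inference. Depending on the host
# CPU, torch_dtype=torch.bfloat16 could roughly halve memory use, but that is
# an assumption about the deployment hardware, not something verified here.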
# Configure generation for streaming
generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
generation_config.max_new_tokens = MAX_TOKENS
generation_config.temperature = TEMPERATURE
generation_config.top_p = TOP_P
generation_config.do_sample = True
generation_config.eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1
# Prefer the tokenizer's pad token, fall back to EOS, then to 0 as a last resort
if tokenizer.pad_token_id is not None:
    generation_config.pad_token_id = tokenizer.pad_token_id
elif tokenizer.eos_token_id is not None:
    generation_config.pad_token_id = tokenizer.eos_token_id
else:
    generation_config.pad_token_id = 0
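# For orientation, the configuration above is roughly equivalent to calling
# (sketch, not executed here):
#
#     model.generate(input_ids, max_new_tokens=500, do_sample=True,
#                    temperature=0.7, top_p=0.9)
#
# do_sample=True is what makes temperature/top_p take effect; with greedy
# decoding (do_sample=False) both would be ignored.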
# --- Custom Streamer for Gradio and Safety Check ---
class GradioSafetyStreamer(TextIteratorStreamer):
    """Buffers streamed tokens into sentences and safety-checks each one before yielding it."""

    def __init__(self, tokenizer, safety_checker_fn, toxicity_threshold, skip_special_tokens=True, **kwargs):
        super().__init__(tokenizer, skip_special_tokens=skip_special_tokens, **kwargs)
        self.safety_checker_fn = safety_checker_fn
        self.toxicity_threshold = toxicity_threshold
        self.current_sentence_buffer = ""
        self.output_queue = []
        # Split on sentence-ending punctuation; the lookbehind keeps the
        # punctuation attached to its sentence so reassembled text stays readable.
        self.sentence_regex = re.compile(r'(?<=[.!?])\s+')
        self.text_done = threading.Event()

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.current_sentence_buffer += text
        sentences = self.sentence_regex.split(self.current_sentence_buffer)
        if not stream_end and sentences and not sentences[-1].rstrip().endswith(('.', '!', '?')):
            # The last chunk is an unfinished sentence; hold it back in the buffer.
            sentences_to_process = sentences[:-1]
            self.current_sentence_buffer = sentences[-1]
        else:
            sentences_to_process = sentences
            self.current_sentence_buffer = ""
        for sentence in sentences_to_process:
            if not sentence.strip():
                continue
            is_safe, detected_label = self.safety_checker_fn(sentence)
            if not is_safe:
                print(f"Safety check failed for: '{sentence.strip()}' (Detected: {detected_label})")
                self.output_queue.append("[Content removed due to safety guidelines]")
                # Stop the consumer; note the generate() thread itself keeps
                # running until it finishes and is joined by the caller.
                self.output_queue.append("__STOP_GENERATION__")
                return
            self.output_queue.append(sentence + " ")  # re-add the space the split consumed
        if stream_end:
            if self.current_sentence_buffer.strip():
                is_safe, detected_label = self.safety_checker_fn(self.current_sentence_buffer)
                if not is_safe:
                    self.output_queue.append("[Content removed due to safety guidelines]")
                else:
                    self.output_queue.append(self.current_sentence_buffer)
            self.current_sentence_buffer = ""
            self.text_done.set()

    def __iter__(self):
        while True:
            if self.output_queue:
                item = self.output_queue.pop(0)
                if item == "__STOP_GENERATION__":
                    return  # PEP 479: raising StopIteration inside a generator is a RuntimeError
                yield item
            elif self.text_done.is_set():
                return
            else:
                time.sleep(0.01)  # Poll briefly until the producer enqueues more text
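# Minimal usage sketch (mirrors what generate_word_by_word_with_safety does below):
#
#     streamer = GradioSafetyStreamer(tokenizer, is_text_safe, TOXICITY_THRESHOLD)
#     threading.Thread(target=model.generate,
#                      kwargs={"input_ids": ids, "streamer": streamer}).start()
#     for checked_sentence in streamer:  # yields only sentences that passed the filter
#         ...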
# --- Inference Function with Safety and Streaming ---
def generate_word_by_word_with_safety(prompt_text: str):
    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
    input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)
    # skip_prompt=True keeps the echoed prompt out of the streamed output
    streamer = GradioSafetyStreamer(tokenizer, is_text_safe, TOXICITY_THRESHOLD, skip_prompt=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "generation_config": generation_config,
        # The explicit kwargs below mirror generation_config and take precedence over it
        "do_sample": True,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "max_new_tokens": MAX_TOKENS,
        "eos_token_id": generation_config.eos_token_id,
        "pad_token_id": generation_config.pad_token_id,
    }
    # Run generation in a background thread so this generator can yield as text arrives
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    full_generated_text = ""
    try:
        for new_sentence_or_chunk in streamer:
            full_generated_text += new_sentence_or_chunk
            yield full_generated_text
    except Exception as e:
        print(f"Error during streaming: {e}")
        yield full_generated_text + f"\n\n[Error during streaming: {e}]"
    finally:
        thread.join()
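# Because the function above is a generator, Gradio treats each yielded string
# as a streaming update to the output component instead of waiting for a
# single final return value.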
# --- Gradio Blocks Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # SmilyAI: Sam-reason-S3 Inference (Transformers on CPU with Safety Filter)

        Enter a prompt and get a word-by-word response from the **Smilyai-labs/Sam-reason-S3** model.

        **⚠️ WARNING: This model is running on a free CPU tier via the `transformers` library. Inference will be VERY slow.**

        All generated sentences are checked for safety using an AI filter; unsafe content will be replaced.
        """
    )
    with gr.Row():
        user_prompt = gr.Textbox(
            lines=5,
            label="Enter your prompt here:",
            placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
            scale=4
        )
        generated_text = gr.Textbox(label="Generated Text", show_copy_button=True, scale=6)
    send_button = gr.Button("Send", variant="primary")
    send_button.click(
        fn=generate_word_by_word_with_safety,
        inputs=user_prompt,
        outputs=generated_text,
        api_name="predict",
    )
if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
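# The endpoint can also be called programmatically with gradio_client
# (sketch; the Space ID below is a placeholder for wherever this app is hosted):
#
#     from gradio_client import Client
#     client = Client("your-username/your-space-name")  # hypothetical Space ID
#     print(client.predict("Hello!", api_name="/predict"))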