Update app.py
app.py
CHANGED
Changes from the previous version (removed lines, summarized):
- Imports: the old import block (including "from transformers import pipeline" and two truncated "from ..." lines) is replaced by the transformers AutoModelForCausalLM / AutoTokenizer / TextIteratorStreamer / GenerationConfig imports, plus torch, threading, and time.
- Model loading: the old "Main Model Loading" block, which initialized a Llama model ("Llama model initialized successfully."), is replaced by the tokenizer and AutoModelForCausalLM loading code below.
- Streaming: the old inline token loop (current_sentence_buffer / full_output_so_far, a sentence-end check re.search(r'[.!?]\s*$', ...) with a 100-character fallback, and a per-sentence is_text_safe call) is replaced by the GradioSafetyStreamer class.
- UI text: the Markdown header ("SmilyAI: Sam-reason-... Running on Hugging Face Spaces' free CPU tier.") is rewritten for the Transformers backend, and several explanatory comments around the toxicity classifier and threshold are dropped.

The full updated app.py:
import gradio as gr
from huggingface_hub import hf_hub_download  # Still useful if model is private and needs custom token
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig
from transformers import pipeline
import re
import os
import torch  # Required for transformers models
import threading
import time  # For short sleeps in streamer

# --- Model Configuration ---
# Your SmilyAI model ID on Hugging Face Hub
MODEL_REPO_ID = "Smilyai-labs/Sam-reason-S3"
N_CTX = 2048  # Context window for the model (applies more to LLMs)
MAX_TOKENS = 500
TEMPERATURE = 0.7
TOP_P = 0.9
STOP_SEQUENCES = ["USER:", "\n\n"]  # Model will stop generating when it encounters these

# --- Safety Configuration ---
print("Loading safety model (unitary/toxic-bert)...")
try:
    safety_classifier = pipeline(
        # ... (unchanged pipeline arguments not shown in the diff) ...
    )
    print("Safety model loaded successfully.")
except Exception as e:
    print(f"Error loading safety model: {e}")
    exit(1)

TOXICITY_THRESHOLD = 0.9

def is_text_safe(text: str) -> tuple[bool, str | None]:
    """
    ... (unchanged docstring lines not shown in the diff) ...
    Returns (True, None) if safe, or (False, detected_label) if unsafe.
    """
    if not text.strip():
        return True, None

    try:
        results = safety_classifier(text)
        if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
            print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
            return False, results[0]['label']
        # ... (unchanged lines not shown in the diff) ...
    except Exception as e:
        print(f"Error during safety check: {e}")
        # If the safety check fails, consider it unsafe by default or log and let it pass.
        return False, "safety_check_failed"

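# Illustrative usage sketch (not part of the original app.py): once the classifier above has
# loaded, the helper can be exercised directly. Expected results, assuming the elided
# pipeline(...) call targets the unitary/toxic-bert text-classification model:
#
#     is_text_safe("Have a nice day!")  # -> (True, None)
#     is_text_safe("   ")               # -> (True, None), blank input is skipped
#     # text scored as 'toxic' above TOXICITY_THRESHOLD -> (False, 'toxic')
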
# --- Main Model Loading (using Transformers) ---
print(f"Loading tokenizer for {MODEL_REPO_ID}...")
try:
    # AutoTokenizer fetches the correct tokenizer for the model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
    print("Tokenizer loaded.")
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    print("Make sure the model ID is correct and, if it's a private repo, you've set the HF_TOKEN secret in your Space.")
    exit(1)

print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and might take a long time)...")
try:
    # AutoModelForCausalLM loads the language model.
    # device_map="cpu" ensures all model layers are loaded onto the CPU.
    # torch_dtype=torch.float32 is standard for CPU; float16 can save memory but might not be faster on all CPUs.
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32)
    model.eval()  # Set model to evaluation mode for inference
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Ensure it's a standard Transformers model and you have HF_TOKEN secret if private.")
    exit(1)

# Configure generation for streaming
# Use GenerationConfig from the model for default parameters, then override as needed.
generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
generation_config.max_new_tokens = MAX_TOKENS
generation_config.temperature = TEMPERATURE
generation_config.top_p = TOP_P
generation_config.do_sample = True  # Enable sampling for temperature/top_p
# Set EOS and PAD token IDs for proper generation stopping and padding
generation_config.eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1
generation_config.pad_token_id = (
    tokenizer.pad_token_id
    if tokenizer.pad_token_id is not None
    else tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1
)
# Fallback for pad_token_id if not explicitly set
if generation_config.pad_token_id == -1:
    generation_config.pad_token_id = 0  # Fallback to 0, though not ideal for all models

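# Quick sanity-check sketch (not part of the original app.py): with the objects loaded above,
# a plain non-streaming call confirms that the model and generation_config work before the
# custom streamer below is wired in:
#
#     test_ids = tokenizer("USER: Hi\nASSISTANT:", return_tensors="pt").input_ids
#     out = model.generate(test_ids, generation_config=generation_config)
#     print(tokenizer.decode(out[0], skip_special_tokens=True))
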
# --- Custom Streamer for Gradio and Safety Check ---
class GradioSafetyStreamer(TextIteratorStreamer):
    def __init__(self, tokenizer, safety_checker_fn, toxicity_threshold, skip_special_tokens=True, **kwargs):
        super().__init__(tokenizer, skip_special_tokens=skip_special_tokens, **kwargs)
        self.safety_checker_fn = safety_checker_fn
        self.toxicity_threshold = toxicity_threshold
        self.current_sentence_buffer = ""
        self.output_queue = []  # Queue to store safety-checked sentences to be yielded by Gradio
        self.sentence_regex = re.compile(r'[.!?]\s*')  # Regex for sentence end, simple version (note: re.split drops the matched punctuation)
        self.text_done = threading.Event()  # Event to signal when internal text processing is complete

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # This method is called by the superclass when a decoded token chunk is ready.
        self.current_sentence_buffer += text

        # Split buffer into sentences. Keep the last part in buffer if it's incomplete.
        sentences = self.sentence_regex.split(self.current_sentence_buffer)

        sentences_to_process = []
        if not stream_end and sentences and self.sentence_regex.search(sentences[-1]) is None:
            # If not end of stream and the last part is not a complete sentence, buffer it for next time
            sentences_to_process = sentences[:-1]
            self.current_sentence_buffer = sentences[-1]
        else:
            # Otherwise, process all segments and clear the buffer
            sentences_to_process = sentences
            self.current_sentence_buffer = ""

        for sentence in sentences_to_process:
            if not sentence.strip(): continue  # Skip empty strings from splitting

            is_safe, detected_label = self.safety_checker_fn(sentence)
            if not is_safe:
                print(f"Safety check failed for: '{sentence.strip()}' (Detected: {detected_label})")
                self.output_queue.append("[Content removed due to safety guidelines]")
                self.output_queue.append("__STOP_GENERATION__")  # Special signal to stop LLM generation
                return  # Stop processing further sentences from this chunk if unsafe

            else:
                self.output_queue.append(sentence)

        if stream_end:
            # If the stream ends and there's leftover text in the buffer, process it
            if self.current_sentence_buffer.strip():
                is_safe, detected_label = self.safety_checker_fn(self.current_sentence_buffer)
                if not is_safe:
                    self.output_queue.append("[Content removed due to safety guidelines]")
                else:
                    self.output_queue.append(self.current_sentence_buffer)
                self.current_sentence_buffer = ""  # Clear after final check
            self.text_done.set()  # Signal that all text processing is complete

    def __iter__(self):
        # This method allows Gradio to iterate over the safety-checked output.
        while True:
            if self.output_queue:
                item = self.output_queue.pop(0)
                if item == "__STOP_GENERATION__":
                    # Signal to the outer Gradio loop to stop yielding.
                    # (Use return, not raise StopIteration: PEP 479 turns StopIteration raised
                    #  inside a generator into a RuntimeError.)
                    return
                yield item
            elif self.text_done.is_set():  # Internal generation and safety processing is truly finished
                return  # End of generation and safety check
            else:
                time.sleep(0.01)  # Small sleep to prevent busy-waiting while waiting for new tokens

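# For reference (not part of the original app.py): the stock TextIteratorStreamer that the class
# above extends is normally consumed like this (assuming input_ids already holds a tokenized
# prompt); GradioSafetyStreamer keeps the same pattern but gates each sentence through the
# safety check before anything reaches the consumer:
#
#     plain_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
#     threading.Thread(
#         target=model.generate,
#         kwargs={"input_ids": input_ids, "streamer": plain_streamer, "max_new_tokens": 50},
#     ).start()
#     for chunk in plain_streamer:
#         print(chunk, end="", flush=True)
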
# --- Inference Function with Safety and Streaming ---
def generate_word_by_word_with_safety(prompt_text: str):
    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
    # Encode input on the model's device (CPU)
    input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)

    # Initialize the custom streamer
    streamer = GradioSafetyStreamer(tokenizer, is_text_safe, TOXICITY_THRESHOLD)

    # Use a separate thread for model generation because model.generate is a blocking call.
    # This allows the streamer to continuously fill its queue while Gradio yields.
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "generation_config": generation_config,
        # Explicitly pass these for clarity, even though they are in generation_config
        "do_sample": True,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "max_new_tokens": MAX_TOKENS,
        "eos_token_id": generation_config.eos_token_id,
        "pad_token_id": generation_config.pad_token_id,
    }

    # Start generation in a separate thread
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    # Yield safety-checked sentences from the streamer's output queue for Gradio to display progressively
    full_generated_text = ""
    try:
        for new_sentence_or_chunk in streamer:
            full_generated_text += new_sentence_or_chunk
            yield full_generated_text  # Gradio expects the accumulated string for streaming display
    except StopIteration:
        pass  # Streamer signaled end
    except Exception as e:
        print(f"Error during streaming: {e}")
        yield full_generated_text + f"\n\n[Error during streaming: {e}]"  # Show error in output
    finally:
        thread.join()  # Ensure the generation thread finishes gracefully

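# Background sketch (not part of the original app.py): Gradio treats a generator handler as a
# streaming endpoint; each successive yield replaces the bound output component's value. A
# minimal standalone illustration of the same convention:
#
#     def count_up(n):
#         text = ""
#         for i in range(int(n)):
#             text += f"{i} "
#             yield text  # each yield updates the output Textbox in place
#
#     gr.Interface(count_up, gr.Number(value=5), gr.Textbox()).launch()
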
# --- Gradio Blocks Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # SmilyAI: Sam-reason-S3 Inference (Transformers on CPU with Safety Filter)
        Enter a prompt and get a word-by-word response from the **Smilyai-labs/Sam-reason-S3** model.
        **⚠️ WARNING: This model is running on a free CPU tier via the `transformers` library. Inference will be VERY slow.**
        All generated sentences are checked for safety using an AI filter; unsafe content will be replaced.
        """
    )

    # ... (unchanged layout lines defining the user_prompt and generated_text components, not shown in the diff) ...

    send_button = gr.Button("Send", variant="primary")

    send_button.click(
        fn=generate_word_by_word_with_safety,
        inputs=user_prompt,
        outputs=generated_text,
        api_name="predict",
    )

if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
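Because the click handler is registered with api_name="predict" and the app serves on port 7860, the endpoint can also be called programmatically once the Space (or a local run of app.py) is up. A minimal sketch using gradio_client; the Space id here is an assumption, substitute the real one or a local URL:

    from gradio_client import Client

    client = Client("Smilyai-labs/Sam-reason-S3")  # assumed Space id; or Client("http://localhost:7860")
    print(client.predict("Hello, who are you?", api_name="/predict"))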