boning123 committed on
Commit
c850ce2
·
verified ·
1 Parent(s): 496d826

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -84
app.py CHANGED
@@ -1,21 +1,23 @@
1
  import gradio as gr
2
- from llama_cpp import Llama
3
- from huggingface_hub import hf_hub_download
4
- from transformers import pipeline
5
- import re # For sentence splitting
6
-
7
- # --- Model Configuration (same as before) ---
8
- MODEL_REPO_ID = "mradermacher/Sam-reason-v3-GGUF"
9
- MODEL_FILENAME = "Sam-reason-v3.Q4_K_M.gguf"
10
- N_CTX = 2048
 
 
 
 
11
  MAX_TOKENS = 500
12
  TEMPERATURE = 0.7
13
  TOP_P = 0.9
14
  STOP_SEQUENCES = ["USER:", "\n\n"] # Model will stop generating when it encounters these
15
 
16
  # --- Safety Configuration ---
17
- # Initialize the toxicity classifier pipeline
18
- # This model identifies if text is 'toxic' with a confidence score.
19
  print("Loading safety model (unitary/toxic-bert)...")
20
  try:
21
  safety_classifier = pipeline(
@@ -26,11 +28,8 @@ try:
26
  print("Safety model loaded successfully.")
27
  except Exception as e:
28
  print(f"Error loading safety model: {e}")
29
- # Consider handling this error more gracefully, e.g., run without safety if model fails to load
30
  exit(1)
31
 
32
- # Threshold for flagging content as unsafe (0.0 to 1.0)
33
- # A higher threshold means it's stricter (less likely to flag non-toxic content).
34
  TOXICITY_THRESHOLD = 0.9
35
 
36
  def is_text_safe(text: str) -> tuple[bool, str | None]:
@@ -39,14 +38,10 @@ def is_text_safe(text: str) -> tuple[bool, str | None]:
39
  Returns (True, None) if safe, or (False, detected_label) if unsafe.
40
  """
41
  if not text.strip():
42
- return True, None # Empty strings are safe
43
 
44
  try:
45
- # Classify the text. The model typically returns [{'label': 'toxic', 'score': X.XX}]
46
- # or [{'label': 'nontoxic', 'score': X.XX}] depending on thresholding in the model.
47
- # For unitary/toxic-bert, 'toxic' is the positive label.
48
  results = safety_classifier(text)
49
-
50
  if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
51
  print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
52
  return False, results[0]['label']
@@ -56,89 +51,165 @@ def is_text_safe(text: str) -> tuple[bool, str | None]:
56
  except Exception as e:
57
  print(f"Error during safety check: {e}")
58
  # If the safety check fails, consider it unsafe by default or log and let it pass.
59
- # For a robust solution, you might want to re-raise or yield an error message.
60
  return False, "safety_check_failed"
61
 
62
 
63
- # --- Main Model Loading (same as before) ---
64
- print(f"Downloading model: {MODEL_FILENAME} from {MODEL_REPO_ID}...")
65
  try:
66
- model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
67
- print(f"Model downloaded to: {model_path}")
 
68
  except Exception as e:
69
- print(f"Error downloading model: {e}")
 
70
  exit(1)
71
 
72
- print("Initializing Llama model (this may take a moment)...")
73
  try:
74
- llm = Llama(
75
- model_path=model_path,
76
- n_gpu_layers=0, # Force CPU usage
77
- n_ctx=N_CTX,
78
- verbose=False
79
- )
80
- print("Llama model initialized successfully.")
81
  except Exception as e:
82
- print(f"Error initializing Llama model: {e}")
 
83
  exit(1)
84
 
85
- # --- Inference Function with Safety ---
86
- def generate_word_by_word_with_safety(prompt_text: str):
87
- """
88
- Generates text word by word, checking each sentence for safety before yielding.
89
- """
90
- formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
91
- current_sentence_buffer = ""
92
- full_output_so_far = ""
93
-
94
- # Stream tokens from the main LLM
95
- token_stream = llm.create_completion(
96
- formatted_prompt,
97
- max_tokens=MAX_TOKENS,
98
- stop=STOP_SEQUENCES,
99
- stream=True,
100
- temperature=TEMPERATURE,
101
- top_p=TOP_P,
102
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- for chunk in token_stream:
105
- token = chunk["choices"][0]["text"]
106
- current_sentence_buffer += token
107
- full_output_so_far += token # Keep track of full output for comprehensive check if needed
108
 
109
- # Simple sentence detection (look for common sentence endings)
110
- if re.search(r'[.!?]\s*$', current_sentence_buffer) or len(current_sentence_buffer) > 100: # Max sentence length fallback
111
- is_safe, detected_label = is_text_safe(current_sentence_buffer)
112
  if not is_safe:
113
- print(f"Safety check failed for sentence: '{current_sentence_buffer.strip()}' (Detected: {detected_label})")
114
- yield "[Content removed due to safety guidelines]" # Replace unsafe content
115
- current_sentence_buffer = "" # Clear buffer for next tokens
116
- # Optionally: return here to stop further generation if first unsafe content is found.
117
- # If you return here, uncomment the `return` statement below.
118
- # return
119
  else:
120
- yield current_sentence_buffer # Yield the safe sentence
121
- current_sentence_buffer = "" # Clear buffer for next sentence
122
-
123
- # After the loop, check and yield any remaining text in the buffer
124
- if current_sentence_buffer.strip():
125
- is_safe, detected_label = is_text_safe(current_sentence_buffer)
126
- if not is_safe:
127
- print(f"Safety check failed for remaining text: '{current_sentence_buffer.strip()}' (Detected: {detected_label})")
128
- yield "[Content removed due to safety guidelines]"
129
- else:
130
- yield current_sentence_buffer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
 
133
  # --- Gradio Blocks Interface ---
134
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
135
  gr.Markdown(
136
  """
137
- # SmilyAI: Sam-reason-v3-GGUF Word-by-Word Inference (CPU with Safety Filter)
138
- Enter a prompt and get a word-by-word response from the Sam-reason-v3-GGUF model.
139
- **Please note:** All generated sentences are checked for safety using an AI filter.
140
- Potentially unsafe content will be replaced with `[Content removed due to safety guidelines]`.
141
- Running on Hugging Face Spaces' free CPU tier.
142
  """
143
  )
144
 
@@ -153,15 +224,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
153
 
154
  send_button = gr.Button("Send", variant="primary")
155
 
156
- # Connect the button click to the inference function with safety check
157
  send_button.click(
158
- fn=generate_word_by_word_with_safety, # Use the new safety-enabled function
159
  inputs=user_prompt,
160
  outputs=generated_text,
161
  api_name="predict",
162
  )
163
 
164
- # Launch the Gradio application
165
  if __name__ == "__main__":
166
  print("Launching Gradio app...")
167
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
+ from huggingface_hub import hf_hub_download # Still useful if model is private and needs custom token
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig
4
+ from transformers.pipelines import pipeline
5
+ import re
6
+ import os
7
+ import torch # Required for transformers models
8
+ import threading
9
+ import time # For short sleeps in streamer
10
+
11
+ # --- Model Configuration ---
12
+ # Your SmilyAI model ID on Hugging Face Hub
13
+ MODEL_REPO_ID = "Smilyai-labs/Sam-reason-S3"
14
+ N_CTX = 2048 # Context window for the model (applies more to LLMs)
15
  MAX_TOKENS = 500
16
  TEMPERATURE = 0.7
17
  TOP_P = 0.9
18
  STOP_SEQUENCES = ["USER:", "\n\n"] # Model will stop generating when it encounters these
19
 
20
  # --- Safety Configuration ---
 
 
21
  print("Loading safety model (unitary/toxic-bert)...")
22
  try:
23
  safety_classifier = pipeline(
 
28
  print("Safety model loaded successfully.")
29
  except Exception as e:
30
  print(f"Error loading safety model: {e}")
 
31
  exit(1)
32
 
 
 
33
  TOXICITY_THRESHOLD = 0.9
34
 
35
  def is_text_safe(text: str) -> tuple[bool, str | None]:
 
38
  Returns (True, None) if safe, or (False, detected_label) if unsafe.
39
  """
40
  if not text.strip():
41
+ return True, None
42
 
43
  try:
 
 
 
44
  results = safety_classifier(text)
 
45
  if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
46
  print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
47
  return False, results[0]['label']
 
51
  except Exception as e:
52
  print(f"Error during safety check: {e}")
53
  # If the safety check fails, consider it unsafe by default or log and let it pass.
 
54
  return False, "safety_check_failed"
55
 
56
 
57
# --- Main Model Loading (using Transformers) ---
# The tokenizer and the model are loaded once at import time. If either fails
# the app cannot serve requests, so the process ends with exit status 1.
# SystemExit is raised directly: the bare exit() helper is injected by the
# `site` module for interactive use and is not guaranteed to exist in programs.
print(f"Loading tokenizer for {MODEL_REPO_ID}...")
try:
    # AutoTokenizer fetches the correct tokenizer for the model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    print("Make sure the model ID is correct and, if it's a private repo, you've set the HF_TOKEN secret in your Space.")
    raise SystemExit(1)
else:
    print("Tokenizer loaded.")

print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and might take a long time)...")
try:
    # device_map="cpu" keeps every layer on the CPU; float32 is the standard
    # CPU dtype (float16 can save memory but might not be faster on all CPUs).
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32)
    model.eval()  # Set model to evaluation mode for inference
except Exception as e:
    print(f"Error loading model: {e}")
    print("Ensure it's a standard Transformers model and you have HF_TOKEN secret if private.")
    raise SystemExit(1)
else:
    print("Model loaded successfully.")
80
 
81
# Configure generation for streaming.
# Start from the generation defaults published with the model. If the repo has
# no loadable generation config (from_pretrained reports that as an OSError),
# derive the defaults from the already-loaded model's own config instead of
# crashing at import time.
try:
    generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
except OSError:
    generation_config = GenerationConfig.from_model_config(model.config)

generation_config.max_new_tokens = MAX_TOKENS
generation_config.temperature = TEMPERATURE
generation_config.top_p = TOP_P
generation_config.do_sample = True  # Enable sampling for temperature/top_p

# EOS: prefer the tokenizer's id. When the tokenizer has none, keep whatever
# the generation config already holds rather than writing a fake id such as -1.
if tokenizer.eos_token_id is not None:
    generation_config.eos_token_id = tokenizer.eos_token_id

# PAD: tokenizer pad id, else tokenizer EOS id, else the config's own value,
# else 0 as a last resort (not ideal for all models).
if tokenizer.pad_token_id is not None:
    generation_config.pad_token_id = tokenizer.pad_token_id
elif tokenizer.eos_token_id is not None:
    generation_config.pad_token_id = tokenizer.eos_token_id
elif generation_config.pad_token_id is None:
    generation_config.pad_token_id = 0
94
+
95
+ # --- Custom Streamer for Gradio and Safety Check ---
96
class GradioSafetyStreamer(TextIteratorStreamer):
    """Streamer that releases generated text only after a per-sentence safety check.

    The superclass hands decoded text to :meth:`on_finalized_text`. Text is
    buffered until a sentence terminator (``.``, ``!`` or ``?``) is seen; each
    completed sentence is passed to ``safety_checker_fn`` and, if accepted,
    appended to ``output_queue``. Iterating the streamer yields the queued
    pieces in order.

    When a piece is rejected, a removal notice is queued followed by the
    ``"__STOP_GENERATION__"`` marker, which ends iteration; any text that
    arrives afterwards is discarded.
    """

    # Marker queued after a rejected sentence; the iterator stops when it sees it.
    _STOP_SIGNAL = "__STOP_GENERATION__"
    # Placeholder shown to the user instead of a rejected sentence.
    _REMOVED_NOTICE = "[Content removed due to safety guidelines]"

    def __init__(self, tokenizer, safety_checker_fn, toxicity_threshold, skip_special_tokens=True, **kwargs):
        """Create the streamer.

        Args:
            tokenizer: Tokenizer forwarded to ``TextIteratorStreamer``.
            safety_checker_fn: Callable taking a string and returning an
                ``(is_safe, detected_label)`` pair.
            toxicity_threshold: Stored on the instance; the checker function
                is responsible for applying its own threshold.
            skip_special_tokens: Forwarded to the superclass.
            **kwargs: Any further superclass arguments.
        """
        super().__init__(tokenizer, skip_special_tokens=skip_special_tokens, **kwargs)
        self.safety_checker_fn = safety_checker_fn
        self.toxicity_threshold = toxicity_threshold
        self.current_sentence_buffer = ""
        # Safety-checked pieces waiting to be yielded. One producer appends and
        # one consumer pops, each with a single list operation.
        self.output_queue = []
        # Matches one complete sentence INCLUDING its terminator and trailing
        # whitespace, so emitted pieces concatenate back to the original text.
        self.sentence_regex = re.compile(r'[^.!?]*[.!?]+\s*')
        # Set once no further text will be queued.
        self.text_done = threading.Event()
        # True after a rejection: everything that follows is dropped.
        self._blocked = False

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Buffer ``text``, then safety-check and queue every completed sentence.

        Args:
            text: Newly decoded text.
            stream_end: True on the final call; the unfinished tail of the
                buffer is then checked and queued as well.
        """
        if self._blocked:
            # A previous sentence was rejected; ignore the rest of the stream.
            if stream_end:
                self.text_done.set()
            return

        self.current_sentence_buffer += text

        # Collect complete sentences; whatever follows the last terminator
        # stays buffered until more text arrives.
        pieces = []
        consumed = 0
        for match in self.sentence_regex.finditer(self.current_sentence_buffer):
            pieces.append(match.group())
            consumed = match.end()
        remainder = self.current_sentence_buffer[consumed:]
        if stream_end:
            # Nothing more is coming, so the tail is final too.
            pieces.append(remainder)
            remainder = ""
        self.current_sentence_buffer = remainder

        for piece in pieces:
            if not piece.strip():
                continue  # Nothing to check or display

            is_safe, detected_label = self.safety_checker_fn(piece)
            if not is_safe:
                print(f"Safety check failed for: '{piece.strip()}' (Detected: {detected_label})")
                self._blocked = True
                self.current_sentence_buffer = ""
                self.output_queue.append(self._REMOVED_NOTICE)
                self.output_queue.append(self._STOP_SIGNAL)
                self.text_done.set()
                return
            self.output_queue.append(piece)

        if stream_end:
            self.text_done.set()  # Signal that all text processing is complete

    def __iter__(self):
        """Yield safety-checked pieces until the stream ends or is stopped."""
        while True:
            # Read the flag BEFORE looking at the queue: the producer always
            # appends before setting the flag, so "finished and empty" really
            # means nothing more will arrive.
            finished = self.text_done.is_set()
            if self.output_queue:
                item = self.output_queue.pop(0)
                if item == self._STOP_SIGNAL:
                    # `return`, not `raise StopIteration`: raising it inside a
                    # generator is turned into RuntimeError (PEP 479).
                    return
                yield item
            elif finished:
                return
            else:
                # Wait briefly instead of busy-looping; wakes early on completion.
                self.text_done.wait(0.01)
160
+
161
+
162
+ # --- Inference Function with Safety and Streaming ---
163
def generate_word_by_word_with_safety(prompt_text: str):
    """Stream a safety-filtered model reply for ``prompt_text``.

    Generation runs on a background thread that feeds a
    ``GradioSafetyStreamer``; this generator consumes the streamer and yields
    the accumulated reply after every new piece (Gradio expects the full text
    so far on each yield).

    Output is cut at the first entry of ``STOP_SEQUENCES`` that appears after
    the reply's leading whitespace. Whenever the consumer stops (stop sequence,
    streamer end, error, or the generator being closed), the background
    generation is asked to stop so the final ``join`` does not wait for all
    ``MAX_TOKENS`` to be produced.

    Args:
        prompt_text: The user's message.

    Yields:
        The reply text accumulated so far; on failure, the text plus an
        ``[Error during streaming: ...]`` note.
    """
    # Local import: keeps this function independent of the module's import list.
    from transformers import StoppingCriteria, StoppingCriteriaList

    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
    # Encode input on the model's device (CPU)
    encoded = tokenizer(formatted_prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(model.device)
    attention_mask = encoded.get("attention_mask")

    # skip_prompt=True: without it the streamer also emits the prompt itself.
    streamer = GradioSafetyStreamer(tokenizer, is_text_safe, TOXICITY_THRESHOLD, skip_prompt=True)

    stop_requested = threading.Event()

    class _StopOnRequest(StoppingCriteria):
        """Ends generation at the next step once ``stop_requested`` is set."""

        def __call__(self, input_ids, scores, **kwargs):
            # One flag per batch row, on the same device as the inputs.
            return torch.full(
                (input_ids.shape[0],),
                stop_requested.is_set(),
                dtype=torch.bool,
                device=input_ids.device,
            )

    # Sampling parameters, token limit and EOS/PAD ids all come from the
    # module-level generation_config.
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "generation_config": generation_config,
        "stopping_criteria": StoppingCriteriaList([_StopOnRequest()]),
    }
    if attention_mask is not None:
        generate_kwargs["attention_mask"] = attention_mask.to(model.device)

    generation_errors = []

    def _run_generation():
        # model.generate blocks, so it runs on its own thread.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            generation_errors.append(exc)
        finally:
            # If generate() failed before ending the stream, release the
            # consumer so it does not wait forever.
            streamer.text_done.set()

    thread = threading.Thread(target=_run_generation, daemon=True)
    thread.start()

    full_generated_text = ""
    try:
        for new_sentence_or_chunk in streamer:
            full_generated_text += new_sentence_or_chunk

            # Look for stop sequences after the leading whitespace so a reply
            # that merely starts with blank lines is not cut to nothing.
            search_from = len(full_generated_text) - len(full_generated_text.lstrip())
            cut_points = [
                idx
                for idx in (full_generated_text.find(stop, search_from) for stop in STOP_SEQUENCES)
                if idx != -1
            ]
            if cut_points:
                full_generated_text = full_generated_text[:min(cut_points)]
                yield full_generated_text
                break

            yield full_generated_text  # Gradio expects accumulated string for streaming display

        if generation_errors:
            raise generation_errors[0]
    except Exception as e:
        print(f"Error during streaming: {e}")
        yield full_generated_text + f"\n\n[Error during streaming: {e}]"  # Show error in output
    finally:
        stop_requested.set()  # Ask the model to stop producing tokens
        thread.join()  # Ensure the generation thread finishes gracefully
203
 
204
 
205
  # --- Gradio Blocks Interface ---
206
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
207
  gr.Markdown(
208
  """
209
+ # SmilyAI: Sam-reason-S3 Inference (Transformers on CPU with Safety Filter)
210
+ Enter a prompt and get a word-by-word response from the **Smilyai-labs/Sam-reason-S3** model.
211
+ **⚠️ WARNING: This model is running on a free CPU tier via the `transformers` library. Inference will be VERY slow.**
212
+ All generated sentences are checked for safety using an AI filter; unsafe content will be replaced.
 
213
  """
214
  )
215
 
 
224
 
225
  send_button = gr.Button("Send", variant="primary")
226
 
 
227
  send_button.click(
228
+ fn=generate_word_by_word_with_safety,
229
  inputs=user_prompt,
230
  outputs=generated_text,
231
  api_name="predict",
232
  )
233
 
 
234
  if __name__ == "__main__":
235
  print("Launching Gradio app...")
236
  demo.launch(server_name="0.0.0.0", server_port=7860)