boning123 committed
Commit cffaee2 · verified · Parent: a392d62

Update app.py

Files changed (1)
  1. app.py +130 -68
app.py CHANGED
@@ -1,106 +1,168 @@
 import gradio as gr
 from llama_cpp import Llama
 from huggingface_hub import hf_hub_download
-import os

-# --- Model Configuration ---
-# The Hugging Face model repository ID
 MODEL_REPO_ID = "mradermacher/Sam-reason-v3-GGUF"
-# The specific GGUF filename within that repository
 MODEL_FILENAME = "Sam-reason-v3.Q4_K_M.gguf"
-# Maximum context window for the model (how much text it can 'remember')
-# Adjust this based on your needs and available memory.
 N_CTX = 2048
-# Maximum number of tokens the model will generate in a single response
 MAX_TOKENS = 500
-# Temperature for generation: higher values (e.g., 0.8-1.0) make output more random,
-# lower values (e.g., 0.2-0.5) make it more focused.
 TEMPERATURE = 0.7
-# Top-p sampling: controls diversity. Lower values focus on more probable tokens.
 TOP_P = 0.9
-# Stop sequences: the model will stop generating when it encounters any of these strings.
-# This prevents it from generating further turns or excessive boilerplate.
-STOP_SEQUENCES = ["USER:", "\n\n"]

-# --- Model Loading ---
 print(f"Downloading model: {MODEL_FILENAME} from {MODEL_REPO_ID}...")
 try:
-    # Download the GGUF model file from Hugging Face Hub
     model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
     print(f"Model downloaded to: {model_path}")
 except Exception as e:
     print(f"Error downloading model: {e}")
-    # Exit or handle the error appropriately if the model can't be downloaded
     exit(1)

 print("Initializing Llama model (this may take a moment)...")
 try:
-    # Initialize the Llama model
-    # n_gpu_layers=0 ensures the model runs entirely on the CPU,
-    # which is necessary for the free tier on Hugging Face Spaces.
     llm = Llama(
         model_path=model_path,
         n_gpu_layers=0, # Force CPU usage
-        n_ctx=N_CTX, # Set context window size
-        verbose=False # Suppress llama_cpp verbose output
     )
     print("Llama model initialized successfully.")
 except Exception as e:
     print(f"Error initializing Llama model: {e}")
     exit(1)

-# --- Inference Function ---
-def generate_word_by_word(prompt_text: str):
     """
-    Generates text from the LLM word by word (or token by token) and yields the output.
-    This provides a streaming experience in the Gradio UI and for API calls.
     """
-    # Define the prompt template. This model does not specify a strict chat format,
-    # so a simple instruction-following format is used.
     formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"

-    print(f"Starting generation for prompt: '{prompt_text[:50]}...'")
-    output_tokens = []
-    try:
-        # Use the create_completion method with stream=True for token-by-token generation
-        for chunk in llm.create_completion(
-            formatted_prompt,
-            max_tokens=MAX_TOKENS,
-            stop=STOP_SEQUENCES,
-            stream=True,
-            temperature=TEMPERATURE,
-            top_p=TOP_P,
-        ):
-            token = chunk["choices"][0]["text"]
-            output_tokens.append(token)
-            # Yield the accumulated text to update the UI/API response in real-time
-            yield "".join(output_tokens)
-    except Exception as e:
-        print(f"Error during text generation: {e}")
-        yield f"An error occurred during generation: {e}"
-
-# --- Gradio Interface ---
-# Create the Gradio Interface for the web UI and API endpoint
-iface = gr.Interface(
-    fn=generate_word_by_word,
-    inputs=gr.Textbox(
-        lines=5,
-        label="Enter your prompt here:",
-        placeholder="e.g., Explain the concept of quantum entanglement in simple terms."
-    ),
-    outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
-    title="SmilyAI: Sam-reason-v3-GGUF Word-by-Word Inference (CPU)",
-    description=(
-        "Enter a prompt and get a word-by-word response from the "
-        "Sam-reason-v3-GGUF model, running on Hugging Face Spaces' free CPU tier. "
-        "The response will stream as it's generated."
-    ),
-    live=True, # Enable live streaming updates in the UI
-    api_name="predict", # Expose this function as a REST API endpoint
-    theme=gr.themes.Soft(), # A modern, soft theme for better aesthetics
-)

 # Launch the Gradio application
 if __name__ == "__main__":
     print("Launching Gradio app...")
-    iface.launch(server_name="0.0.0.0", server_port=7860) # Standard ports for HF Spaces
 
 
 import gradio as gr
 from llama_cpp import Llama
 from huggingface_hub import hf_hub_download
+from transformers import pipeline
+import re # For sentence splitting

+# --- Model Configuration (same as before) ---
 MODEL_REPO_ID = "mradermacher/Sam-reason-v3-GGUF"
 MODEL_FILENAME = "Sam-reason-v3.Q4_K_M.gguf"
 N_CTX = 2048
 MAX_TOKENS = 500
 TEMPERATURE = 0.7
 TOP_P = 0.9
+STOP_SEQUENCES = ["USER:", "\n\n"] # Model will stop generating when it encounters these

+# --- Safety Configuration ---
+# Initialize the toxicity classifier pipeline
+# This model identifies if text is 'toxic' with a confidence score.
+print("Loading safety model (unitary/toxic-bert)...")
+try:
+    safety_classifier = pipeline(
+        "text-classification",
+        model="unitary/toxic-bert",
+        framework="pt" # Use PyTorch backend
+    )
+    print("Safety model loaded successfully.")
+except Exception as e:
+    print(f"Error loading safety model: {e}")
+    # Consider handling this error more gracefully, e.g., run without safety if model fails to load
+    exit(1)
+
+# Threshold for flagging content as unsafe (0.0 to 1.0)
+# A higher threshold means it's stricter (less likely to flag non-toxic content).
+TOXICITY_THRESHOLD = 0.9
+
+def is_text_safe(text: str) -> tuple[bool, str | None]:
+    """
+    Checks if the given text contains unsafe content using the safety classifier.
+    Returns (True, None) if safe, or (False, detected_label) if unsafe.
+    """
+    if not text.strip():
+        return True, None # Empty strings are safe
+
+    try:
+        # Classify the text. The model typically returns [{'label': 'toxic', 'score': X.XX}]
+        # or [{'label': 'nontoxic', 'score': X.XX}] depending on thresholding in the model.
+        # For unitary/toxic-bert, 'toxic' is the positive label.
+        results = safety_classifier(text)
+
+        if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
+            print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
+            return False, results[0]['label']
+
+        return True, None
+
+    except Exception as e:
+        print(f"Error during safety check: {e}")
+        # If the safety check fails, consider it unsafe by default or log and let it pass.
+        # For a robust solution, you might want to re-raise or yield an error message.
+        return False, "safety_check_failed"
+
+
+# --- Main Model Loading (same as before) ---
 print(f"Downloading model: {MODEL_FILENAME} from {MODEL_REPO_ID}...")
 try:
     model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
     print(f"Model downloaded to: {model_path}")
 except Exception as e:
     print(f"Error downloading model: {e}")
     exit(1)

 print("Initializing Llama model (this may take a moment)...")
 try:
     llm = Llama(
         model_path=model_path,
         n_gpu_layers=0, # Force CPU usage
+        n_ctx=N_CTX,
+        verbose=False
     )
     print("Llama model initialized successfully.")
 except Exception as e:
     print(f"Error initializing Llama model: {e}")
     exit(1)
+# --- Inference Function with Safety ---
+def generate_word_by_word_with_safety(prompt_text: str):
     """
+    Generates text word by word, checking each sentence for safety before yielding.
     """
     formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
+    current_sentence_buffer = ""
+    full_output_so_far = ""
+
+    # Stream tokens from the main LLM
+    token_stream = llm.create_completion(
+        formatted_prompt,
+        max_tokens=MAX_TOKENS,
+        stop=STOP_SEQUENCES,
+        stream=True,
+        temperature=TEMPERATURE,
+        top_p=TOP_P,
+    )
+
+    for chunk in token_stream:
+        token = chunk["choices"][0]["text"]
+        current_sentence_buffer += token
+        full_output_so_far += token # Keep track of full output for comprehensive check if needed
+
+        # Simple sentence detection (look for common sentence endings)
+        if re.search(r'[.!?]\s*$', current_sentence_buffer) or len(current_sentence_buffer) > 100: # Max sentence length fallback
+            is_safe, detected_label = is_text_safe(current_sentence_buffer)
+            if not is_safe:
+                print(f"Safety check failed for sentence: '{current_sentence_buffer.strip()}' (Detected: {detected_label})")
+                yield "[Content removed due to safety guidelines]" # Replace unsafe content
+                current_sentence_buffer = "" # Clear buffer for next tokens
+                # Optionally: return here to stop further generation if first unsafe content is found.
+                # If you return here, uncomment the `return` statement below.
+                # return
+            else:
+                yield current_sentence_buffer # Yield the safe sentence
+                current_sentence_buffer = "" # Clear buffer for next sentence
+
+    # After the loop, check and yield any remaining text in the buffer
+    if current_sentence_buffer.strip():
+        is_safe, detected_label = is_text_safe(current_sentence_buffer)
+        if not is_safe:
+            print(f"Safety check failed for remaining text: '{current_sentence_buffer.strip()}' (Detected: {detected_label})")
+            yield "[Content removed due to safety guidelines]"
+        else:
+            yield current_sentence_buffer

+
+# --- Gradio Blocks Interface ---
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
+        """
+        # SmilyAI: Sam-reason-v3-GGUF Word-by-Word Inference (CPU with Safety Filter)
+        Enter a prompt and get a word-by-word response from the Sam-reason-v3-GGUF model.
+        **Please note:** All generated sentences are checked for safety using an AI filter.
+        Potentially unsafe content will be replaced with `[Content removed due to safety guidelines]`.
+        Running on Hugging Face Spaces' free CPU tier.
+        """
+    )
+
+    with gr.Row():
+        user_prompt = gr.Textbox(
+            lines=5,
+            label="Enter your prompt here:",
+            placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
+            scale=4
+        )
+        generated_text = gr.Textbox(label="Generated Text", show_copy_button=True, scale=6)
+
+    send_button = gr.Button("Send", variant="primary")
+
+    # Connect the button click to the inference function with safety check
+    send_button.click(
+        fn=generate_word_by_word_with_safety, # Use the new safety-enabled function
+        inputs=user_prompt,
+        outputs=generated_text,
+        api_name="predict",
+    )

 # Launch the Gradio application
 if __name__ == "__main__":
     print("Launching Gradio app...")
+    demo.launch(server_name="0.0.0.0", server_port=7860)
+
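For reference, below is a minimal sketch of how the streaming endpoint introduced by this revision could be queried once the Space is running. It is not part of the commit: the Space ID is a placeholder, gradio_client must be installed separately, and only the api_name="predict" value is taken from the send_button.click(...) call in the diff above (the Python client addresses it as "/predict").

    # Sketch only: calling the Space's /predict endpoint with the official Gradio client.
    # "boning123/SPACE_NAME" is a placeholder -- substitute the real Space ID.
    import time
    from gradio_client import Client

    client = Client("boning123/SPACE_NAME")

    # submit() returns a Job; since the handler is a generator, partial results
    # accumulate in job.outputs() as each sentence clears the safety filter.
    job = client.submit(
        "Explain the concept of quantum entanglement in simple terms.",
        api_name="/predict",
    )

    while not job.done():
        time.sleep(0.5)

    for chunk in job.outputs():
        print(chunk)

Note also that the new `from transformers import pipeline` import implies the Space needs `transformers` plus a PyTorch backend (the pipeline is created with framework="pt") in its requirements; that file is not shown in this commit, so this is an assumption.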