Update app.py
app.py
CHANGED
@@ -1,106 +1,168 @@
Previous version (lines prefixed with `-` were removed in this commit; truncated or hidden lines are noted as such):

```diff
 import gradio as gr
 from llama_cpp import Llama
 from huggingface_hub import hf_hub_download
-import
 
-# --- Model Configuration ---
-# The Hugging Face model repository ID
 MODEL_REPO_ID = "mradermacher/Sam-reason-v3-GGUF"
-# The specific GGUF filename within that repository
 MODEL_FILENAME = "Sam-reason-v3.Q4_K_M.gguf"
-# Maximum context window for the model (how much text it can 'remember')
-# Adjust this based on your needs and available memory.
 N_CTX = 2048
-# Maximum number of tokens the model will generate in a single response
 MAX_TOKENS = 500
-# Temperature for generation: higher values (e.g., 0.8-1.0) make output more random,
-# lower values (e.g., 0.2-0.5) make it more focused.
 TEMPERATURE = 0.7
-# Top-p sampling: controls diversity. Lower values focus on more probable tokens.
 TOP_P = 0.9
-
-# This prevents it from generating further turns or excessive boilerplate.
-STOP_SEQUENCES = ["USER:", "\n\n"]
 
-# ---
 print(f"Downloading model: {MODEL_FILENAME} from {MODEL_REPO_ID}...")
 try:
-    # Download the GGUF model file from Hugging Face Hub
     model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
     print(f"Model downloaded to: {model_path}")
 except Exception as e:
     print(f"Error downloading model: {e}")
-    # Exit or handle the error appropriately if the model can't be downloaded
     exit(1)
 
 print("Initializing Llama model (this may take a moment)...")
 try:
-    # Initialize the Llama model
-    # n_gpu_layers=0 ensures the model runs entirely on the CPU,
-    # which is necessary for the free tier on Hugging Face Spaces.
     llm = Llama(
         model_path=model_path,
         n_gpu_layers=0,  # Force CPU usage
-        n_ctx=N_CTX,
-        verbose=False
     )
     print("Llama model initialized successfully.")
 except Exception as e:
     print(f"Error initializing Llama model: {e}")
     exit(1)
 
-# --- Inference Function ---
-def
     """
-    Generates text
-    This provides a streaming experience in the Gradio UI and for API calls.
     """
-    # Define the prompt template. This model does not specify a strict chat format,
-    # so a simple instruction-following format is used.
     formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
 
-    (old lines 62-92 were removed; their content is not visible in this view)
-        description=(
-            "Enter a prompt and get a word-by-word response from the "
-            "Sam-reason-v3-GGUF model, running on Hugging Face Spaces' free CPU tier. "
-            "The response will stream as it's generated."
-        ),
-        live=True,  # Enable live streaming updates in the UI
-        api_name="predict",  # Expose this function as a REST API endpoint
-        theme=gr.themes.Soft(),  # A modern, soft theme for better aesthetics
-    )
 
 # Launch the Gradio application
 if __name__ == "__main__":
     print("Launching Gradio app...")
-    (old line 106 was removed; its content is not visible in this view)
```
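The block removed at old lines 62-92 is hidden in this view. Judging only from the surviving fragment (`description=...`, `live=True`, `api_name="predict"`, `theme=gr.themes.Soft()`), the old UI was most likely a single `gr.Interface` wrapping the streaming generator. The sketch below is a hypothetical reconstruction, not the actual removed code; the function name `generate_word_by_word`, the input/output components, and the title are assumptions.

```python
# Hypothetical reconstruction -- NOT the actual removed code, which this diff view hides.
# Only the kwargs shown in the fragment above are taken from the source.
iface = gr.Interface(
    fn=generate_word_by_word,  # assumed name of the old streaming generator
    inputs=gr.Textbox(lines=5, label="Enter your prompt here:"),  # assumed
    outputs=gr.Textbox(label="Generated Text"),                   # assumed
    description=(
        "Enter a prompt and get a word-by-word response from the "
        "Sam-reason-v3-GGUF model, running on Hugging Face Spaces' free CPU tier. "
        "The response will stream as it's generated."
    ),
    live=True,               # Enable live streaming updates in the UI
    api_name="predict",      # Expose this function as a REST API endpoint
    theme=gr.themes.Soft(),  # A modern, soft theme for better aesthetics
)
```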
Updated version (lines prefixed with `+` were added in this commit):

```diff
 import gradio as gr
 from llama_cpp import Llama
 from huggingface_hub import hf_hub_download
+from transformers import pipeline
+import re  # For sentence splitting
 
+# --- Model Configuration (same as before) ---
 MODEL_REPO_ID = "mradermacher/Sam-reason-v3-GGUF"
 MODEL_FILENAME = "Sam-reason-v3.Q4_K_M.gguf"
 N_CTX = 2048
 MAX_TOKENS = 500
 TEMPERATURE = 0.7
 TOP_P = 0.9
+STOP_SEQUENCES = ["USER:", "\n\n"]  # The model stops generating when it emits one of these
 
+# --- Safety Configuration ---
+# Initialize the toxicity classifier pipeline.
+# This model scores whether text is 'toxic' with a confidence value.
+print("Loading safety model (unitary/toxic-bert)...")
+try:
+    safety_classifier = pipeline(
+        "text-classification",
+        model="unitary/toxic-bert",
+        framework="pt"  # Use the PyTorch backend
+    )
+    print("Safety model loaded successfully.")
+except Exception as e:
+    print(f"Error loading safety model: {e}")
+    # Consider handling this more gracefully, e.g., running without the safety filter if the model fails to load.
+    exit(1)
+
+# Threshold for flagging content as unsafe (0.0 to 1.0).
+# A higher threshold requires more confidence before flagging, so benign content is less likely to be removed.
+TOXICITY_THRESHOLD = 0.9
+
+def is_text_safe(text: str) -> tuple[bool, str | None]:
+    """
+    Checks whether the given text contains unsafe content using the safety classifier.
+    Returns (True, None) if safe, or (False, detected_label) if unsafe.
+    """
+    if not text.strip():
+        return True, None  # Empty strings are safe
+
+    try:
+        # Classify the text. The pipeline returns a list like [{'label': 'toxic', 'score': X.XX}];
+        # for unitary/toxic-bert, 'toxic' is the positive label.
+        results = safety_classifier(text)
+
+        if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
+            print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
+            return False, results[0]['label']
+
+        return True, None
+
+    except Exception as e:
+        print(f"Error during safety check: {e}")
+        # If the safety check itself fails, treat the text as unsafe by default.
+        # For a more robust solution, you might re-raise or yield an error message instead.
+        return False, "safety_check_failed"
+
```
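A quick way to exercise the helper above (illustrative only: the test sentence is made up and the commented results are placeholders, not real model output):

```python
# Illustrative sanity check of the safety helper; the sentence and the commented
# values are placeholders, not real model output.
print(is_text_safe("The weather is lovely today."))
# -> (True, None) for benign text

print(safety_classifier("The weather is lovely today.")[0])
# -> a dict like {'label': 'toxic', 'score': ...}; the text is only treated as
#    unsafe when the 'toxic' score exceeds TOXICITY_THRESHOLD (0.9 here)
```

The updated file continues with the model loading and the safety-aware inference function.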
```diff
+# --- Main Model Loading (same as before) ---
 print(f"Downloading model: {MODEL_FILENAME} from {MODEL_REPO_ID}...")
 try:
     model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
     print(f"Model downloaded to: {model_path}")
 except Exception as e:
     print(f"Error downloading model: {e}")
     exit(1)
 
 print("Initializing Llama model (this may take a moment)...")
 try:
     llm = Llama(
         model_path=model_path,
         n_gpu_layers=0,  # Force CPU usage
+        n_ctx=N_CTX,
+        verbose=False
     )
     print("Llama model initialized successfully.")
 except Exception as e:
     print(f"Error initializing Llama model: {e}")
     exit(1)
 
+# --- Inference Function with Safety ---
+def generate_word_by_word_with_safety(prompt_text: str):
     """
+    Streams tokens from the LLM, buffering them into sentences and checking each
+    sentence for safety before yielding it.
     """
     formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
+    current_sentence_buffer = ""
+    full_output_so_far = ""
+
+    # Stream tokens from the main LLM
+    token_stream = llm.create_completion(
+        formatted_prompt,
+        max_tokens=MAX_TOKENS,
+        stop=STOP_SEQUENCES,
+        stream=True,
+        temperature=TEMPERATURE,
+        top_p=TOP_P,
+    )
+
+    for chunk in token_stream:
+        token = chunk["choices"][0]["text"]
+        current_sentence_buffer += token
+        full_output_so_far += token  # Keep the full output in case a comprehensive check is needed
+
+        # Simple sentence detection (look for common sentence endings),
+        # with a maximum buffer length as a fallback.
+        if re.search(r'[.!?]\s*$', current_sentence_buffer) or len(current_sentence_buffer) > 100:
+            is_safe, detected_label = is_text_safe(current_sentence_buffer)
+            if not is_safe:
+                print(f"Safety check failed for sentence: '{current_sentence_buffer.strip()}' (Detected: {detected_label})")
+                yield "[Content removed due to safety guidelines]"  # Replace unsafe content
+                current_sentence_buffer = ""  # Clear buffer for the next tokens
+                # Optionally stop further generation once unsafe content is found
+                # by uncommenting the `return` below.
+                # return
+            else:
+                yield current_sentence_buffer  # Yield the safe sentence
+                current_sentence_buffer = ""  # Clear buffer for the next sentence
+
+    # After the loop, check and yield any remaining text in the buffer
+    if current_sentence_buffer.strip():
+        is_safe, detected_label = is_text_safe(current_sentence_buffer)
+        if not is_safe:
+            print(f"Safety check failed for remaining text: '{current_sentence_buffer.strip()}' (Detected: {detected_label})")
+            yield "[Content removed due to safety guidelines]"
+        else:
+            yield current_sentence_buffer
 
+
```
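Each yield from this generator is a single vetted sentence (or the removal marker), not the accumulated response; Gradio treats every yielded value as the new value of the output Textbox, so the UI shows the most recent sentence. Joining the yields recovers the full text. A minimal sketch, run outside the app; the prompt is just an example:

```python
# Illustrative only: call the generator directly (e.g. from a REPL) and join the
# yielded pieces into one response string.
parts = list(generate_word_by_word_with_safety("Explain photosynthesis in two sentences."))
print("".join(parts))
```

The remainder of the updated file wires the generator into a Gradio Blocks UI.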
```diff
+# --- Gradio Blocks Interface ---
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
+        """
+        # SmilyAI: Sam-reason-v3-GGUF Word-by-Word Inference (CPU with Safety Filter)
+        Enter a prompt and get a word-by-word response from the Sam-reason-v3-GGUF model.
+        **Please note:** All generated sentences are checked for safety using an AI filter.
+        Potentially unsafe content will be replaced with `[Content removed due to safety guidelines]`.
+        Running on Hugging Face Spaces' free CPU tier.
+        """
+    )
+
+    with gr.Row():
+        user_prompt = gr.Textbox(
+            lines=5,
+            label="Enter your prompt here:",
+            placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
+            scale=4
+        )
+        generated_text = gr.Textbox(label="Generated Text", show_copy_button=True, scale=6)
+
+    send_button = gr.Button("Send", variant="primary")
+
+    # Connect the button click to the inference function with safety check
+    send_button.click(
+        fn=generate_word_by_word_with_safety,  # Use the new safety-enabled function
+        inputs=user_prompt,
+        outputs=generated_text,
+        api_name="predict",
+    )
 
 # Launch the Gradio application
 if __name__ == "__main__":
     print("Launching Gradio app...")
+    demo.launch(server_name="0.0.0.0", server_port=7860)
```
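Because the click handler is exposed with `api_name="predict"`, the endpoint can also be called programmatically. A minimal sketch using `gradio_client`; the URL is a placeholder (substitute the actual Space URL, or `http://localhost:7860` when running the app locally):

```python
# Minimal sketch, assuming the app above is reachable at the placeholder URL.
from gradio_client import Client

client = Client("http://localhost:7860/")  # placeholder; substitute the real Space URL
result = client.predict(
    "Explain the concept of quantum entanglement in simple terms.",
    api_name="/predict",
)
print(result)  # for a streaming endpoint, predict() returns the final yielded value
```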