import gradio as gr
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig, pipeline
import re
import os
import torch
import threading
import time

# --- Model Configuration ---
MODEL_REPO_ID = "Smilyai-labs/Sam-reason-S3"
N_CTX = 2048  # context-window hint; not currently passed to the Transformers generation call below
MAX_TOKENS = 500
TEMPERATURE = 0.7
TOP_P = 0.9
STOP_SEQUENCES = ["USER:", "\n\n"]  # not currently wired into generate(); kept for reference

# --- Safety Configuration ---
print("Loading safety model (unitary/toxic-bert)...")
try:
    # Using the directly imported pipeline function
    safety_classifier = pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        framework="pt"
    )
    print("Safety model loaded successfully.")
except Exception as e:
    print(f"Error loading safety model: {e}")
    exit(1)

TOXICITY_THRESHOLD = 0.9

def is_text_safe(text: str) -> tuple[bool, str | None]:
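    """Return (is_safe, offending_label) for a chunk of text.

    The text-classification pipeline returns results shaped like
    [{'label': ..., 'score': ...}]; text is flagged when the top label is
    'toxic' and its score exceeds TOXICITY_THRESHOLD.
    """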
    if not text.strip():
        return True, None

    try:
        results = safety_classifier(text)
        if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
            print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
            return False, results[0]['label']
        
        return True, None

    except Exception as e:
        print(f"Error during safety check: {e}")
        return False, "safety_check_failed"


# --- Main Model Loading (using Transformers) ---
print(f"Loading tokenizer for {MODEL_REPO_ID}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
    print("Tokenizer loaded.")
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    print("Make sure the model ID is correct and, if it's a private repo, you've set the HF_TOKEN secret in your Space.")
    exit(1)

print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and might take a long time)...")
try:
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32)
    model.eval()
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Ensure it's a standard Transformers model and you have HF_TOKEN secret if private.")
    exit(1)

# Configure generation for streaming
generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
generation_config.max_new_tokens = MAX_TOKENS
generation_config.temperature = TEMPERATURE
generation_config.top_p = TOP_P
generation_config.do_sample = True
generation_config.eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1
# Fall back to the EOS token, or 0, when the tokenizer defines no pad token.
if tokenizer.pad_token_id is not None:
    generation_config.pad_token_id = tokenizer.pad_token_id
elif tokenizer.eos_token_id is not None:
    generation_config.pad_token_id = tokenizer.eos_token_id
else:
    generation_config.pad_token_id = 0

# --- Custom Streamer for Gradio and Safety Check ---
class GradioSafetyStreamer(TextIteratorStreamer):
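    """TextIteratorStreamer subclass that buffers decoded text into sentences, runs each
    complete sentence through the safety checker, and queues only approved sentences
    (or a redaction notice plus a stop sentinel) for the consuming iterator. The
    toxicity threshold itself is applied inside the checker function."""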
    def __init__(self, tokenizer, safety_checker_fn, toxicity_threshold, skip_special_tokens=True, **kwargs):
        super().__init__(tokenizer, skip_special_tokens=skip_special_tokens, **kwargs)
        self.safety_checker_fn = safety_checker_fn
        self.toxicity_threshold = toxicity_threshold
        self.current_sentence_buffer = ""
        self.output_queue = []
        # Split on the whitespace after a sentence terminator, keeping the punctuation with its sentence.
        self.sentence_regex = re.compile(r'(?<=[.!?])\s+')
        self.text_done = threading.Event()

    def on_finalized_text(self, text: str, stream_end: bool = False):
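        """Called by the base streamer with newly decoded text: accumulate it, split it into
        complete sentences, and safety-check each one before queueing it for output."""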
        self.current_sentence_buffer += text
        
        sentences = self.sentence_regex.split(self.current_sentence_buffer)
        
        # re.split never leaves a terminator match inside the pieces, so while the stream is
        # still running the last piece is an unfinished sentence: hold it back in the buffer.
        if not stream_end and sentences:
            sentences_to_process = sentences[:-1]
            self.current_sentence_buffer = sentences[-1]
        else:
            sentences_to_process = sentences
            self.current_sentence_buffer = ""

        for sentence in sentences_to_process:
            if not sentence.strip(): continue

            is_safe, detected_label = self.safety_checker_fn(sentence)
            if not is_safe:
                print(f"Safety check failed for: '{sentence.strip()}' (Detected: {detected_label})")
                self.output_queue.append("[Content removed due to safety guidelines]")
                self.output_queue.append("__STOP_GENERATION__")
                return

            else:
                self.output_queue.append(sentence)

        if stream_end:
            if self.current_sentence_buffer.strip():
                is_safe, detected_label = self.safety_checker_fn(self.current_sentence_buffer)
                if not is_safe:
                    self.output_queue.append("[Content removed due to safety guidelines]")
                else:
                    self.output_queue.append(self.current_sentence_buffer)
                self.current_sentence_buffer = ""
            self.text_done.set()

    def __iter__(self):
        # Drain sentences as they are approved; end iteration with a plain return
        # (raising StopIteration inside a generator is a RuntimeError under PEP 479).
        while True:
            if self.output_queue:
                item = self.output_queue.pop(0)
                if item == "__STOP_GENERATION__":
                    return
                yield item
            elif self.text_done.is_set():
                return
            else:
                time.sleep(0.01)


# --- Inference Function with Safety and Streaming ---
def generate_word_by_word_with_safety(prompt_text: str):
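    """Run generation on a worker thread and yield the accumulated, safety-filtered text after each approved sentence."""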
    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
    input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)

    # skip_prompt=True keeps the formatted "USER: ... ASSISTANT:" prefix out of the streamed output.
    streamer = GradioSafetyStreamer(tokenizer, is_text_safe, TOXICITY_THRESHOLD, skip_prompt=True)

    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "generation_config": generation_config,
        "do_sample": True,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "max_new_tokens": MAX_TOKENS,
        "eos_token_id": generation_config.eos_token_id,
        "pad_token_id": generation_config.pad_token_id,
    }
    
    # Run generate() in a background thread so this function can iterate the streamer concurrently.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    full_generated_text = ""
    try:
        for new_sentence_or_chunk in streamer:
            # The sentence splitter consumes the whitespace between sentences, so re-insert it here.
            full_generated_text += new_sentence_or_chunk + " "
            yield full_generated_text.strip()
    except Exception as e:
        print(f"Error during streaming: {e}")
        yield full_generated_text.strip() + f"\n\n[Error during streaming: {e}]"
    finally:
        thread.join()


# --- Gradio Blocks Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # SmilyAI: Sam-reason-S3 Inference (Transformers on CPU with Safety Filter)
        Enter a prompt and get a streamed, sentence-by-sentence response from the **Smilyai-labs/Sam-reason-S3** model.
        **⚠️ WARNING: This model is running on a free CPU tier via the `transformers` library. Inference will be VERY slow.**
        All generated sentences are checked for safety using an AI filter; unsafe content will be replaced.
        """
    )

    with gr.Row():
        user_prompt = gr.Textbox(
            lines=5,
            label="Enter your prompt here:",
            placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
            scale=4
        )
        generated_text = gr.Textbox(label="Generated Text", show_copy_button=True, scale=6)

    send_button = gr.Button("Send", variant="primary")

    send_button.click(
        fn=generate_word_by_word_with_safety,
        inputs=user_prompt,
        outputs=generated_text,
        api_name="predict",
    )
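
    # Hypothetical client-side usage sketch (assumes the gradio_client package and the
    # default local launch below; the URL and prompt are illustrative, not part of this app):
    #
    #   from gradio_client import Client
    #   client = Client("http://localhost:7860")
    #   print(client.predict("Explain entanglement simply.", api_name="/predict"))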

if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860)