import os
from threading import Thread

import gradio as gr
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

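# Authenticate with the Hugging Face Hub when a token is provided;
# the Gemma checkpoints are gated and require an authenticated download.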
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("Successfully logged in to Hugging Face Hub.")
    except Exception as e:
        print(f"Error logging in to Hugging Face Hub: {e}")
else:
    print("HF_TOKEN not set, proceeding without login.")

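# Model checkpoint and avatar images used by the chat UI.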
MODEL_ID = "google/gemma-3-270m-it"
bot_im = "https://huggingface.co/spaces/idzkha/Geo-Chat-Bert/resolve/main/bot.png"
user_im = "https://huggingface.co/spaces/idzkha/Geo-Chat-Bert/resolve/main/user.png"

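# Load the tokenizer and model, then optimize the model for CPU-only inference.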
try:
    print(f"Loading tokenizer: {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    print(f"Loading model: {MODEL_ID} for CPU...")
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

    print("Optimizing model for CPU inference with quantization...")
    model.eval()

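    # Dynamic quantization swaps the model's Linear layers for int8 versions,
    # cutting memory use and speeding up CPU inference with minimal quality loss.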
    model_quantized = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    print("Model successfully quantized and optimized for CPU.")
except Exception as e:
    print(f"Error loading or quantizing model: {e}")
    exit()


def generate_response(message, history, system_prompt, max_new_tokens, temperature, top_p):
    """
    Generates a streaming response from the CPU-quantized model.
    """
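    # Rebuild the conversation in the chat-message format expected by the chat template.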
    conversation = []
    if system_prompt and system_prompt.strip():
        conversation.append({"role": "system", "content": system_prompt})

    for user_msg, model_msg in history:
        conversation.append({"role": "user", "content": user_msg})
        if model_msg is not None:
            conversation.append({"role": "assistant", "content": model_msg})

    conversation.append({"role": "user", "content": message})

    inputs = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

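    # The streamer yields decoded text as tokens are generated, so the UI can update incrementally.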
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "do_sample": True,
    }

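    # Run generation in a background thread so this function can consume the streamer as it fills.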
    def generation_thread_target(**kwargs):
        with torch.no_grad():
            model_quantized.generate(**kwargs)

    thread = Thread(target=generation_thread_target, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response


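# Build the Gradio chat interface.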
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""# Gradio Chat Demo (CPU Optimized) with {MODEL_ID}
    Duplicate this Space for a private CPU/GPU instance (faster).""")

    chatbot = gr.Chatbot(label="Chat History", height=500, avatar_images=(user_im, bot_im))

    msg = gr.Textbox(
        label="Your Message",
        placeholder="Type your message here and press Enter...",
    )

    with gr.Accordion("Model Parameters", open=False):
        system_prompt = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
        max_new_tokens = gr.Slider(minimum=1, maximum=32000, value=2048, step=1, label="Max New Tokens")
        temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (Nucleus Sampling)")

    clear = gr.Button("Clear Chat History")

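    # Append the user's message with an empty assistant slot, then stream the model's reply into that slot.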
    def user_and_generate(user_message, history, system_prompt, max_new_tokens, temperature, top_p):
        history.append([user_message, ""])

        # Exclude the placeholder entry just appended; generate_response adds the new user message itself.
        stream = generate_response(user_message, history[:-1], system_prompt, max_new_tokens, temperature, top_p)

        for new_text in stream:
            history[-1][1] = new_text
            yield history, ""

    msg.submit(
        user_and_generate,
        [msg, chatbot, system_prompt, max_new_tokens, temperature, top_p],
        [chatbot, msg],
    )

    clear.click(lambda: [], None, chatbot, queue=False)

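# queue() enables the request queue, which Gradio needs to stream generator output to the client.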
demo.queue().launch(debug=True, share=True)