import os
import gradio as gr
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
# --- Configuration and Model Loading ---
# Set your Hugging Face token (useful for Spaces)
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("Successfully logged in to Hugging Face Hub.")
    except Exception as e:
        print(f"Error logging in to Hugging Face Hub: {e}")
else:
    print("HF_TOKEN not set, proceeding without login.")
MODEL_ID = "google/gemma-3-270m-it"
bot_im = "https://huggingface.co/spaces/idzkha/Geo-Chat-Bert/resolve/main/bot.png"
user_im = "https://huggingface.co/spaces/idzkha/Geo-Chat-Bert/resolve/main/user.png"
try:
print(f"Loading tokenizer: {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print(f"Loading model: {MODEL_ID} for CPU...")
# Load the model with the default float32 precision
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
# No need for device_map or specific dtypes for CPU
)
# --- CPU OPTIMIZATION ---
print("Optimizing model for CPU inference with quantization...")
# 1. Set the model to evaluation mode
model.eval()
# 2. Apply dynamic quantization to the linear layers
# This converts float32 weights to int8, making it much faster on CPU
model_quantized = torch.quantization.quantize_dynamic(
model, {torch.nn.Linear}, dtype=torch.qint8
)
print("Model successfully quantized and optimized for CPU.")
except Exception as e:
print(f"Error loading or quantizing model: {e}")
exit()
# --- Gradio UI and Logic ---
def generate_response(message, history, system_prompt, max_new_tokens, temperature, top_p):
"""
Generates a streaming response from the CPU-quantized model.
"""
conversation = []
if system_prompt and system_prompt.strip():
conversation.append({"role": "system", "content": system_prompt})
for user_msg, model_msg in history:
conversation.append({"role": "user", "content": user_msg})
if model_msg is not None:
conversation.append({"role": "assistant", "content": model_msg})
conversation.append({"role": "user", "content": message})
inputs = tokenizer.apply_chat_template(
conversation,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
) # Inputs will be on CPU by default
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = {
**inputs,
"streamer": streamer,
"max_new_tokens": max_new_tokens,
"temperature": float(temperature),
"top_p": float(top_p),
"do_sample": True,
}
# Wrapper function to run generation within a no_grad context
def generation_thread_target(**kwargs):
# Use torch.no_grad() for inference to save memory and computations
with torch.no_grad():
model_quantized.generate(**kwargs)
# Run generation in a separate thread
thread = Thread(target=generation_thread_target, kwargs=generation_kwargs)
thread.start()
response = ""
for new_text in streamer:
response += new_text
yield response
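
# Illustrative call outside the UI (hypothetical argument values; the Gradio app
# below supplies these from its own widgets):
#   for partial in generate_response("Hi!", [], "You are a helpful assistant.", 64, 0.7, 0.95):
#       print(partial)
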
# Build the Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(f"""# Gradio Chat Demo (CPU Optimized) with {MODEL_ID}
Duplicate this space for private CPU/GPU (faster)""")
    chatbot = gr.Chatbot(label="Chat History", height=500, avatar_images=(user_im, bot_im))
    msg = gr.Textbox(
        label="Your Message",
        placeholder="Type your message here and press Enter...",
    )
    with gr.Accordion("Model Parameters", open=False):
        system_prompt = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
        max_new_tokens = gr.Slider(minimum=1, maximum=32000, value=2048, step=1, label="Max New Tokens")
        temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (Nucleus Sampling)")
    clear = gr.Button("Clear Chat History")

    def user_and_generate(user_message, history, system_prompt, max_new_tokens, temperature, top_p):
        # Append the new user turn with an empty assistant slot, then stream the reply into it.
        history.append([user_message, ""])
        # Pass only the completed turns (exclude the placeholder pair just appended).
        stream = generate_response(user_message, history[:-1], system_prompt, max_new_tokens, temperature, top_p)
        for new_text in stream:
            history[-1][1] = new_text
            yield history, ""
    msg.submit(
        user_and_generate,
        [msg, chatbot, system_prompt, max_new_tokens, temperature, top_p],
        [chatbot, msg]
    )
    clear.click(lambda: [], None, chatbot, queue=False)
# Launch the demo
demo.queue().launch(debug=True, share=True)