# Spaces: Paused
import torch | |
import logging | |
import time | |
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# ---------------- CONFIG ---------------- | |
# Hub repository that provides both the tokenizer and the model weights.
MODEL_ID = "goonsai-com/civitaiprompts"
# Passed as `revision=` to both from_pretrained() calls below.
# NOTE(review): "Q4_K_M" reads like a GGUF quantization label; confirm the
# repository really exposes a branch/tag with this name.
MODEL_VARIANT = "Q4_K_M"
# Display name shown in the page heading.
MODEL_NAME = "CivitAI-Prompts-Q4_K_M"

# ---------------- LOGGING ----------------
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
logger.info("Starting Gradio chatbot...")

# ---------------- LOAD MODEL ----------------
# Both loaders run at import time, so the tokenizer and model are ready
# before the UI is built.
logger.info(f"Loading tokenizer from {MODEL_ID} (revision={MODEL_VARIANT})")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    revision=MODEL_VARIANT,
)

# Half precision when CUDA is available, full precision otherwise.
if torch.cuda.is_available():
    dtype = torch.float16
else:
    dtype = torch.float32
logger.info(f"Loading model with dtype {dtype}")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    revision=MODEL_VARIANT,
    torch_dtype=dtype,
    trust_remote_code=True,
)
logger.info("Model loaded successfully.")
# ---------------- CHAT FUNCTION ---------------- | |
def chat_fn(message):
    """Generate one stateless assistant reply for ``message``.

    The message is wrapped in a "User: ... / Assistant:" prompt, tokenized
    (truncated to 1024 tokens) and passed to ``model.generate`` with sampling
    enabled (temperature 0.7, top-p 0.9, at most 512 new tokens).

    Only the tokens produced after the prompt are decoded. Splitting the full
    decoded text on "Assistant:" would cut the reply short whenever the user
    text or the generated text contains that marker, and would return the
    whole prompt if truncation removed the marker.

    Args:
        message: Text entered by the user.

    Returns:
        The generated reply with surrounding whitespace stripped, or an empty
        string when ``message`` is ``None``, empty or only whitespace.
    """
    logger.info("Received message: %s", message)

    # Nothing to answer: skip the (expensive) generation call entirely.
    if message is None or not message.strip():
        logger.info("Empty message; skipping generation.")
        return ""

    # Build prompt
    full_text = f"User: {message}\nAssistant:"
    logger.info("Full prompt for generation:\n%s", full_text)

    # perf_counter is monotonic, so the elapsed time cannot go negative if
    # the wall clock is adjusted mid-request.
    start_time = time.perf_counter()

    # Tokenize input
    inputs = tokenizer(
        [full_text],
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    ).to(model.device)
    logger.info("Tokenized input.")

    # Generate response
    logger.info("Generating response...")
    reply_ids = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

    # A causal LM returns prompt tokens followed by the continuation; with a
    # single unpadded sequence the continuation starts at the prompt length.
    prompt_length = inputs["input_ids"].shape[-1]
    new_token_ids = reply_ids[0][prompt_length:]
    assistant_reply = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    logger.info("Assistant reply: %s", assistant_reply)
    logger.info("Generation time: %.2fs", time.perf_counter() - start_time)
    return assistant_reply
# ---------------- GRADIO BLOCKS UI ---------------- | |
# Single-turn UI: one row with two columns, input controls in the first and
# the reply box in the second.
with gr.Blocks() as demo:
    gr.Markdown(f"# 🤖 {MODEL_NAME} (Stateless)")

    with gr.Row():
        with gr.Column():
            message = gr.Textbox(label="Type your message...", placeholder="Hello!")
            send_btn = gr.Button("Send")
        with gr.Column():
            output = gr.Textbox(label="Assistant Response", lines=10)

    # Clicking the button and submitting the textbox both run chat_fn with
    # the same input and output components.
    for _register in (send_btn.click, message.submit):
        _register(chat_fn, inputs=[message], outputs=[output])

logger.info("Launching Gradio app...")
demo.launch()