# Stateless Gradio chatbot for the goonsai-com/civitaiprompts model.
# (Removed non-code page residue from the original copy: space status lines,
# file-size line, commit hashes, and a column-number ruler.)
import torch
import logging
import time
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# ---------------- CONFIG ----------------
# Hugging Face model repository to load the tokenizer/model from.
MODEL_ID = "goonsai-com/civitaiprompts"
# NOTE(review): this value is passed as `revision=` when loading below, so it
# must exist as a git branch/tag on the repo. "Q4_K_M" reads like a GGUF
# quantization level, not a git revision — confirm the repo actually has a
# branch/tag with this name.
MODEL_VARIANT = "Q4_K_M" # This is the HF tag for the quantized model
# Display name shown in the Gradio UI header only.
MODEL_NAME = "CivitAI-Prompts-Q4_K_M"
# ---------------- LOGGING ----------------
# Root-logger config: timestamped INFO-level messages to stderr.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
logger.info("Starting Gradio chatbot...")
# ---------------- LOAD MODEL ----------------
# Both loads below hit the Hugging Face Hub at module import time (network I/O)
# and may execute repo-provided code because of trust_remote_code=True — only
# safe if the model repo is trusted.
logger.info(f"Loading tokenizer from {MODEL_ID} (revision={MODEL_VARIANT})")
tokenizer = AutoTokenizer.from_pretrained(
MODEL_ID,
revision=MODEL_VARIANT,
trust_remote_code=True
)
# Half precision on GPU to cut memory; full fp32 on CPU (fp16 CPU inference is
# slow/unsupported for many ops).
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
logger.info(f"Loading model with dtype {dtype}")
# device_map="auto" lets accelerate place weights across available devices.
# NOTE(review): if the "Q4_K_M" revision holds a GGUF file, AutoModel may need
# the `gguf_file=` argument rather than `revision=` — verify against the repo.
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
revision=MODEL_VARIANT,
torch_dtype=dtype,
device_map="auto",
trust_remote_code=True
)
logger.info("Model loaded successfully.")
# ---------------- CHAT FUNCTION ----------------
def chat_fn(message):
    """Generate one stateless assistant reply for *message*.

    No chat history is kept between calls; each message is wrapped in a
    fresh ``User: ... Assistant:`` prompt.

    Parameters
    ----------
    message : str
        Raw user text from the UI.

    Returns
    -------
    str
        The model's completion after the final ``Assistant:`` marker.
    """
    # Lazy %-style args so formatting is skipped when INFO logging is off.
    logger.info("Received message: %s", message)
    full_text = f"User: {message}\nAssistant:"
    logger.info("Full prompt for generation:\n%s", full_text)
    start_time = time.time()
    # Truncate long inputs so the prompt fits within 1024 tokens.
    inputs = tokenizer([full_text], return_tensors="pt", truncation=True, max_length=1024).to(model.device)
    logger.info("Tokenized input.")
    logger.info("Generating response...")
    reply_ids = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        # Many causal-LM tokenizers define no pad token; falling back to EOS
        # avoids the "Setting pad_token_id" warning and inconsistent padding.
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )
    response = tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
    # generate() echoes the prompt, so keep only the text after the last
    # "Assistant:" marker (the newly generated completion).
    assistant_reply = response.split("Assistant:")[-1].strip()
    logger.info("Assistant reply: %s", assistant_reply)
    logger.info("Generation time: %.2fs", time.time() - start_time)
    return assistant_reply
# ---------------- GRADIO BLOCKS UI ----------------
# Two-column layout: input + Send button on the left, response box on the right.
with gr.Blocks() as demo:
    gr.Markdown(f"# 🤖 {MODEL_NAME} (Stateless)")
    with gr.Row():
        with gr.Column():
            message = gr.Textbox(label="Type your message...", placeholder="Hello!")
            send_btn = gr.Button("Send")
        with gr.Column():
            output = gr.Textbox(label="Assistant Response", lines=10)
    # Both the Send button and Enter-in-textbox route through the same handler.
    send_btn.click(chat_fn, inputs=[message], outputs=[output])
    message.submit(chat_fn, inputs=[message], outputs=[output])
logger.info("Launching Gradio app...")
# Fixed: the original line ended with a stray " |" (copy/paste artifact),
# which was a SyntaxError.
demo.launch()