import torch
import logging
import time
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# ---------------- CONFIG ----------------
MODEL_ID = "goonsai-com/civitaiprompts"
MODEL_VARIANT = "Q4_K_M"  # Hub revision (branch/tag) that holds the quantized weights
MODEL_NAME = "CivitAI-Prompts-Q4_K_M"

# ---------------- LOGGING ----------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
logger.info("Starting Gradio chatbot...")

# ---------------- LOAD MODEL ----------------
logger.info(f"Loading tokenizer from {MODEL_ID} (revision={MODEL_VARIANT})")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    revision=MODEL_VARIANT,
    trust_remote_code=True
)
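
# Some causal-LM tokenizers ship without a pad token (an assumption here, not
# verified against this repo's config); fall back to EOS so generate() can pad.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token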

dtype = torch.float16 if torch.cuda.is_available() else torch.float32
logger.info(f"Loading model with dtype {dtype}")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    revision=MODEL_VARIANT,
    torch_dtype=dtype,
    device_map="auto",
    trust_remote_code=True
)
logger.info("Model loaded successfully.")

# ---------------- CHAT FUNCTION ----------------
def chat_fn(message):
    logger.info(f"Received message: {message}")
    
    # Build prompt
    full_text = f"User: {message}\nAssistant:"
    logger.info(f"Full prompt for generation:\n{full_text}")

    start_time = time.time()
    # Tokenize input
    inputs = tokenizer([full_text], return_tensors="pt", truncation=True, max_length=1024).to(model.device)
    logger.info("Tokenized input.")

    # Generate response; inference_mode avoids autograd bookkeeping during decoding
    logger.info("Generating response...")
    with torch.inference_mode():
        reply_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
        )
    response = tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
    # The decoded text echoes the prompt, so keep only what follows "Assistant:"
    assistant_reply = response.split("Assistant:")[-1].strip()
    logger.info(f"Assistant reply: {assistant_reply}")
    logger.info(f"Generation time: {time.time() - start_time:.2f}s")

    return assistant_reply

# ---------------- GRADIO BLOCKS UI ----------------
with gr.Blocks() as demo:
    gr.Markdown(f"# 🤖 {MODEL_NAME} (Stateless)")

    with gr.Row():
        with gr.Column():
            message = gr.Textbox(label="Your message", placeholder="Type your message...")
            send_btn = gr.Button("Send")
        with gr.Column():
            output = gr.Textbox(label="Assistant Response", lines=10)

    send_btn.click(chat_fn, inputs=[message], outputs=[output])
    message.submit(chat_fn, inputs=[message], outputs=[output])

logger.info("Launching Gradio app...")
demo.launch()