File size: 2,609 Bytes
be189c3
913f152
d20382c
913f152
d20382c
3fa7074
913f152
650d7d1
717c596
b1c819e
f20a05d
d20382c
 
 
 
 
 
 
fb5acbf
d20382c
6dd4d01
d20382c
913f152
f6d85fe
 
6dd4d01
913f152
 
 
d20382c
5451932
 
d20382c
 
 
 
 
 
be189c3
d20382c
 
fb5acbf
d20382c
 
 
 
fb5acbf
913f152
 
 
 
 
 
fb5acbf
 
d20382c
 
e8da9e0
d20382c
31a4c7e
5451932
 
d20382c
5451932
 
 
 
 
 
d20382c
5451932
d20382c
 
 
5451932
d20382c
5451932
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
import torch
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM
import time

# ---------------- CONFIG ----------------
# Hugging Face Hub repository and subfolder the tokenizer/model are loaded from.
REPO_ID = "goonsai-com/civitaiprompts"
SUBFOLDER = "gemma3-1B-goonsai-nsfw-100k"
# Display name shown in the UI header only (not used for loading).
# NOTE(review): the name says "Qwen3-1.7B" but SUBFOLDER points at a
# "gemma3-1B" checkpoint — confirm which model is actually intended.
MODEL_NAME = "Qwen3-1.7B-CivitAI"

# ---------------- LOGGING ----------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
logger.info("Starting Gradio chatbot...")

# ---------------- LOAD MODEL ----------------
# Both loads run at import time; the first run downloads the weights from the Hub.
logger.info(f"Loading tokenizer from {REPO_ID}/{SUBFOLDER}")
tokenizer = AutoTokenizer.from_pretrained(REPO_ID, subfolder=SUBFOLDER, trust_remote_code=True)

# bfloat16 when a GPU is present (halves memory); float32 fallback on CPU.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
logger.info(f"Loading model with dtype {dtype}")
model = AutoModelForCausalLM.from_pretrained(
    REPO_ID,
    subfolder=SUBFOLDER,
    torch_dtype=dtype,
    device_map="auto",  # let accelerate place layers on available devices
    trust_remote_code=True
)
logger.info("Model loaded successfully.")

# ---------------- CHAT FUNCTION ----------------
def chat_fn(message):
    """Generate one stateless assistant reply for *message*.

    Builds a minimal "User: ...\\nAssistant:" prompt (no conversation
    history), samples up to 512 new tokens from the module-level model,
    and returns the text that follows the final "Assistant:" marker.

    Args:
        message: Raw user input string from the UI.

    Returns:
        The assistant's reply as a stripped string.
    """
    logger.info("Received message: %s", message)

    # Build prompt directly from user input; each call is independent.
    full_text = f"User: {message}\nAssistant:"
    logger.info("Full prompt for generation:\n%s", full_text)

    start_time = time.time()
    # Tokenize input (truncate long prompts so they fit the context window).
    inputs = tokenizer([full_text], return_tensors="pt", truncation=True, max_length=1024).to(model.device)
    logger.info("Tokenized input.")

    # Generate response. inference_mode disables autograd bookkeeping,
    # saving memory and time during generation.
    logger.info("Generating response...")
    with torch.inference_mode():
        reply_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )
    response = tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
    # Take the text after the LAST "Assistant:" marker in case the model
    # echoes the marker inside its own output.
    assistant_reply = response.split("Assistant:")[-1].strip()
    logger.info("Assistant reply: %s", assistant_reply)
    logger.info("Generation time: %.2fs", time.time() - start_time)

    return assistant_reply

# ---------------- GRADIO BLOCKS UI ----------------
# Two-column layout: prompt entry + button on the left, reply on the right.
with gr.Blocks() as demo:
    gr.Markdown(f"# 🤖 {MODEL_NAME} (Stateless)")

    with gr.Row():
        with gr.Column():
            user_box = gr.Textbox(label="Type your message...", placeholder="Hello!")
            submit_btn = gr.Button("Send")
        with gr.Column():
            reply_box = gr.Textbox(label="Assistant Response", lines=10)

    # Both clicking the button and pressing Enter in the textbox run chat_fn.
    for trigger in (submit_btn.click, user_box.submit):
        trigger(chat_fn, inputs=[user_box], outputs=[reply_box])

logger.info("Launching Gradio app...")
demo.launch()