# Gradio Space entry point: stateless chat UI over the goonsai CivitAI prompt model.
import gradio as gr
import torch
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
# ---------------- CONFIG ----------------
REPO_ID = "goonsai-com/civitaiprompts"
SUBFOLDER = "gemma3-1B-goonsai-nsfw-100k"
# NOTE(review): display name says Qwen3-1.7B but SUBFOLDER points at a
# gemma3-1B checkpoint — confirm which model is actually being served.
MODEL_NAME = "Qwen3-1.7B-CivitAI"
# ---------------- LOGGING ----------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
logger.info("Starting Gradio chatbot...")
# ---------------- LOAD MODEL ----------------
# Lazy %-style args so formatting is skipped when the level is disabled.
logger.info("Loading tokenizer from %s/%s", REPO_ID, SUBFOLDER)
tokenizer = AutoTokenizer.from_pretrained(REPO_ID, subfolder=SUBFOLDER, trust_remote_code=True)
# bfloat16 halves memory on GPU; fall back to full precision on CPU.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
logger.info("Loading model with dtype %s", dtype)
model = AutoModelForCausalLM.from_pretrained(
    REPO_ID,
    subfolder=SUBFOLDER,
    torch_dtype=dtype,
    device_map="auto",  # let accelerate place weights on available devices
    trust_remote_code=True,
)
logger.info("Model loaded successfully.")
# ---------------- CHAT FUNCTION ----------------
def chat_fn(message):
    """Generate a single stateless assistant reply for *message*.

    Builds a ``User: ...\\nAssistant:`` prompt from the raw user text (no
    chat history is kept), samples a completion from the model, and returns
    only the text after the final ``Assistant:`` marker.

    Args:
        message: Raw user input from the Gradio textbox.

    Returns:
        The generated assistant reply as a plain string; "" for blank input.
    """
    # Guard: nothing to answer — avoid burning generation time on an empty prompt.
    if not message or not message.strip():
        return ""
    logger.info("Received message: %s", message)
    # Build prompt directly from user input.
    full_text = f"User: {message}\nAssistant:"
    logger.info("Full prompt for generation:\n%s", full_text)
    start_time = time.time()
    # Tokenize input (truncate overly long messages to the allowed context).
    inputs = tokenizer([full_text], return_tensors="pt", truncation=True, max_length=1024).to(model.device)
    logger.info("Tokenized input.")
    # Generate response; no_grad avoids building an autograd graph (inference only).
    logger.info("Generating response...")
    with torch.no_grad():
        reply_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
    response = tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
    # The decoded text echoes the prompt; keep only what follows the last
    # "Assistant:" marker.
    assistant_reply = response.split("Assistant:")[-1].strip()
    logger.info("Assistant reply: %s", assistant_reply)
    logger.info("Generation time: %.2fs", time.time() - start_time)
    return assistant_reply
# ---------------- GRADIO BLOCKS UI ----------------
with gr.Blocks() as demo:
    gr.Markdown(f"# 🤖 {MODEL_NAME} (Stateless)")
    with gr.Row():
        with gr.Column():
            user_box = gr.Textbox(label="Type your message...", placeholder="Hello!")
            submit_btn = gr.Button("Send")
        with gr.Column():
            reply_box = gr.Textbox(label="Assistant Response", lines=10)
    # Wire both triggers — button click and Enter in the textbox — to chat_fn.
    for wire in (submit_btn.click, user_box.submit):
        wire(chat_fn, inputs=[user_box], outputs=[reply_box])
logger.info("Launching Gradio app...")
demo.launch()