daviondk
fix cache again
eb612a0
import os
import time
import gc
import threading
from datetime import datetime
import gradio as gr
import torch
from transformers import pipeline, TextIteratorStreamer
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import spaces # Import spaces early to enable ZeroGPU support
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig
)
# ------------------------------
# Global Cancellation Event
# ------------------------------
# Cleared at the start of every bot_response() call. NOTE(review): nothing in
# this file ever sets it, so it does not currently cancel anything.
cancel_event = threading.Event()

# ------------------------------
# Style-transfer model registry
# ------------------------------
_HUB_USER = "daviondk7131"
_T_LITE = "t-tech/T-lite-it-1.0"
_DEEPSEEK_8B = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
_YAGPT_8B = "yandex/YandexGPT-5-Lite-8B-instruct"


def _model_entry(key, description, reward_name, author, base_model):
    """Build one MODELS entry; *key* doubles as the LoRA repo name under _HUB_USER."""
    return {
        "repo_id": f"{_HUB_USER}/{key}",
        "description": description,
        "reward_repo_id": f"{_HUB_USER}/{reward_name}",
        "author": author,
        "base_model": base_model,
    }


# Maps a model key to: its fine-tuned repo, the dropdown description, the
# reward-model repo used to rank candidates, the author tag inserted into the
# prompt, and the base model whose tokenizer is loaded.
MODELS = {
    row[0]: _model_entry(*row)
    for row in (
        ("bodrunov-t-lite-lora-16", "С. Д. Бодрунов (T-lite)", "bodrunov-reward-model", "bodrunov", _T_LITE),
        ("shakespeare-deepseek-lora-16", "У. Шекспир (Deepseek)", "shakespeare-reward-model", "Shakespeare", _DEEPSEEK_8B),
        ("chekhov-t-lite-lora-16", "А. П. Чехов (T-lite)", "chekhov-reward-model", "chekhov_ru", _T_LITE),
        ("tolstoy-t-lite-lora-16", "Л. Н. Толстой (T-lite)", "tolstoy-reward-model", "tolstoy_ru", _T_LITE),
        ("dostoevsky-t-lite-lora-16", "Ф. М. Достоевский (T-lite)", "dostoevsky-reward-model", "dostoevsky_ru", _T_LITE),
        ("dostoevsky-yagpt-lora-16", "Ф. М. Достоевский (YaGPT)", "dostoevsky-reward-model", "dostoevsky_ru", _YAGPT_8B),
        ("tolstoy-yagpt-lora-16", "Л. Н. Толстой (YaGPT)", "tolstoy-reward-model", "tolstoy_ru", _YAGPT_8B),
    )
}

# Single-slot cache holding the most recently loaded model set; bot_response()
# reuses it while the same model key stays selected.
CACHE = {
    "model_name": None,
    "model": None,
    "tokenizer": None,
    "reward_model": None,
}
def get_model_name(full_selection):
    """Return the MODELS key from a dropdown label of the form "<key> - <description>".

    Everything before the first " - " separator is returned; a label without
    the separator comes back unchanged.
    """
    key, _separator, _description = full_selection.partition(" - ")
    return key
def user_input(user_message, history):
    """Queue the user's message as a new, not-yet-answered chat turn.

    Returns:
        A pair ("", new_history): the empty string clears the textbox, and
        new_history is a new list made of *history* followed by
        (user_message, None). The caller's list is left untouched.
    """
    pending_turn = (user_message, None)
    extended_history = history + [pending_turn]
    cleared_textbox = ""
    return cleared_textbox, extended_history
class RewardModel(object):
    """Scalar reward model used to rank candidate generations.

    Wraps a sequence-classification checkpoint and scores a text by the first
    logit of the first sequence in the batch.
    """

    def __init__(self, model_name, tokenizer_name="FacebookAI/xlm-roberta-base"):
        """Load the reward checkpoint and the tokenizer used to feed it.

        Args:
            model_name: Hub repo id (or local path) of the sequence-classification
                reward checkpoint.
            tokenizer_name: Tokenizer used in score(). Defaults to the
                XLM-RoBERTa base tokenizer that was previously hard-coded.
                NOTE(review): presumably the reward models were trained with
                it; confirm before changing.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Keep the model on self.device, the same device score() sends inputs
        # to. A hard-coded .to('cuda') crashed CPU-only runs and could
        # disagree with self.device.
        self.reward_model = AutoModelForSequenceClassification.from_pretrained(
            model_name, device_map=self.device
        ).to(self.device)
        self.reward_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def score(self, text):
        """Return the reward for *text* as a Python float.

        Text longer than the tokenizer's maximum length is truncated.
        Gradients are disabled during the forward pass.
        """
        inputs = self.reward_tokenizer(text, truncation=True, return_tensors='pt').to(self.device)
        with torch.no_grad():
            logits = self.reward_model(**inputs).logits
        return logits[0, 0].item()
# Instruction prompt used by generate(). str.format() fills the three {}
# slots in order: author tag, source text, answer. generate() passes "" for the
# answer so the model continues right after "### Answer:".
# NOTE(review): presumably this must match the template the fine-tuned models
# were trained on; keep the text byte-for-byte.
STYLE_TEMPLATE_PROMPT = """Below is an instruction describing the task, combined with input data that provides further context. Write a response that completes the request accordingly.
### Instruction:
Write down the text from the input data in the style of the author {}.
### Input data:
{}
### Answer:
{}"""
def generate(
    model,
    tokenizer,
    author: str,
    text: str,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
    do_sample: bool = True,
    max_new_tokens: int = 2048,
    **kwargs
) -> str:
    """Rewrite *text* in the style of *author* with a causal language model.

    The prompt is STYLE_TEMPLATE_PROMPT filled with the author tag, the input
    text and an empty answer slot. Only the newly generated tokens are
    decoded, so the prompt is never echoed back.

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        tokenizer: tokenizer matching *model*.
        author: author tag inserted into the instruction (the ``author`` field
            of a MODELS entry).
        text: source text to restyle.
        temperature, top_p, top_k, repetition_penalty, do_sample: sampling
            settings forwarded to ``model.generate``.
        max_new_tokens: upper bound on generated tokens. The default of 2048
            is the value that used to be hard-coded.
        **kwargs: extra keyword arguments forwarded to ``model.generate``.

    Returns:
        The generated continuation, decoded without special tokens and
        stripped of surrounding whitespace.
    """
    input_text = STYLE_TEMPLATE_PROMPT.format(author, text, "")
    # Send the prompt to wherever the model lives instead of assuming "cuda".
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]
    # Some tokenizers define no pad token; fall back to EOS in that case.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            **kwargs
        )
    # Decode only the continuation. Decoding the full sequence and removing the
    # prompt by string prefix fails whenever detokenization does not reproduce
    # the prompt exactly, which leaked the prompt into the reply.
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


# Number of candidate generations sampled per request; the reward model keeps
# the highest-scoring one.
_NUM_CANDIDATES = 3


def _load_models(model_name):
    """Return (tokenizer, model, reward_model) for *model_name*, reusing CACHE if it matches."""
    if CACHE["model_name"] == model_name:
        return CACHE["tokenizer"], CACHE["model"], CACHE["reward_model"]
    spec = MODELS[model_name]
    # Release the previously cached models before loading the new ones so that
    # two sets of weights never have to fit in GPU memory at the same time.
    CACHE.update(model_name=None, model=None, tokenizer=None, reward_model=None)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    tokenizer = AutoTokenizer.from_pretrained(spec["base_model"])
    model = AutoModelForCausalLM.from_pretrained(
        spec["repo_id"],
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    ).to("cuda")
    reward_model = RewardModel(model_name=spec["reward_repo_id"])
    CACHE.update(
        model_name=model_name,
        model=model,
        tokenizer=tokenizer,
        reward_model=reward_model,
    )
    return tokenizer, model, reward_model


@spaces.GPU(duration=60)
def bot_response(history, model_selection, max_tokens, temperature, top_k, top_p, repetition_penalty):
    """Answer the newest chat turn with the best of several style-transfer samples.

    Args:
        history: chat history as (user, bot) pairs; the last pair's bot slot
            is filled in.
        model_selection: dropdown label "<model key> - <description>".
        max_tokens: maximum number of new tokens per candidate (UI slider).
        temperature, top_k, top_p, repetition_penalty: sampling settings.

    Returns:
        *history* with its last pair replaced by (user_message, response), or
        by (user_message, "Error: ...") if loading or generation failed. An
        empty history is returned unchanged.
    """
    cancel_event.clear()
    if not history:
        # Nothing to answer; indexing history[-1] would raise.
        return history
    # Read the message before the try block so the error handler can always use it
    # (previously a loading failure left it unbound and crashed the handler).
    user_message = history[-1][0]
    model_name = get_model_name(model_selection)
    try:
        tokenizer, model, reward_model = _load_models(model_name)
        author = MODELS[model_name]["author"]
        candidates = [
            generate(
                model,
                tokenizer,
                author,
                user_message,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                # Honor the "Max Tokens" slider, which used to be ignored.
                max_new_tokens=int(max_tokens),
            )
            for _ in range(_NUM_CANDIDATES)
        ]
        response = max(candidates, key=reward_model.score)
        history[-1] = (user_message, response)
    except Exception as e:
        # UI boundary: surface the failure in the chat instead of failing the event.
        history[-1] = (user_message, f"Error: {e}")
    finally:
        gc.collect()
    return history
#def get_default_system_prompt():
# today = datetime.now().strftime('%Y-%m-%d')
# return f"""You are Qwen3, a helpful and friendly AI assistat. Be concise, accurate, and helpful in your responses."""
def clear_chat():
    """Return a brand-new empty history, which resets the Chatbot component."""
    return list()
# Custom stylesheet passed to gr.Blocks(css=...). Components opt into these
# classes through elem_classes (or a class attribute in gr.HTML) in the UI
# section. NOTE(review): .qwen-header is only referenced by the commented-out
# header HTML, so it is currently unused.
css = """
.gradio-container {
background-color: #f5f7fb !important;
}
.qwen-header {
background: linear-gradient(90deg, #0099FF, #0066CC);
padding: 20px;
border-radius: 10px;
margin-bottom: 20px;
text-align: center;
color: white;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.qwen-container {
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
background: white;
padding: 20px;
margin-bottom: 20px;
}
.controls-container {
background: #f0f4fa;
border-radius: 10px;
padding: 15px;
margin-bottom: 15px;
}
.model-select {
border: 2px solid #0099FF !important;
border-radius: 8px !important;
}
.button-primary {
background-color: #0099FF !important;
color: white !important;
}
.button-secondary {
background-color: #6c757d !important;
color: white !important;
}
.footer {
text-align: center;
margin-top: 20px;
font-size: 0.8em;
color: #666;
}
"""
# ------------------------------
# Gradio UI
# ------------------------------
# Dropdown labels are "<MODELS key> - <description>"; get_model_name() recovers
# the key. The first entry is the default selection.
_MODEL_CHOICES = [f"{key} - {spec['description']}" for key, spec in MODELS.items()]

with gr.Blocks(title="Chat", css=css) as demo:
    with gr.Row():
        # Left column: model picker, sampling controls and the clear button.
        with gr.Column(scale=3):
            with gr.Group(elem_classes="qwen-container"):
                model_dd = gr.Dropdown(
                    label="Select Model",
                    choices=_MODEL_CHOICES,
                    value=_MODEL_CHOICES[0],
                    elem_classes="model-select"
                )
            with gr.Group(elem_classes="controls-container"):
                gr.Markdown("### Generation Parameters")
                with gr.Row():
                    max_tok = gr.Slider(64, 1024, value=512, step=32, label="Max Tokens")
                with gr.Row():
                    temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
                    p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
                with gr.Row():
                    k = gr.Slider(1, 100, value=40, step=1, label="Top-K")
                    rp = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
            clear_btn = gr.Button("Clear Chat", elem_classes="button-secondary")
        # Right column: conversation view plus the message box and send button.
        with gr.Column(scale=7):
            chatbot = gr.Chatbot()
            with gr.Row():
                txt = gr.Textbox(
                    show_label=False,
                    placeholder="Type your message here...",
                    lines=2
                )
                submit_btn = gr.Button("Send", variant="primary", elem_classes="button-primary")
    gr.HTML("""
<div class="footer">
<p>Interface powered by Gradio and ZeroGPU.</p>
</div>
""")
    # Clicking "Send" and submitting the textbox run the same two-step chain:
    # append the user turn (unqueued), then generate the reply.
    # NOTE(review): both chains request api_name="generate"; confirm how the
    # installed Gradio version disambiguates the duplicate name.
    for send_event in (submit_btn.click, txt.submit):
        send_event(
            user_input,
            inputs=[txt, chatbot],
            outputs=[txt, chatbot],
            queue=False
        ).then(
            bot_response,
            inputs=[chatbot, model_dd, max_tok, temp, k, p, rp],
            outputs=chatbot,
            api_name="generate"
        )
    clear_btn.click(
        clear_chat,
        outputs=[chatbot],
        queue=False
    )

if __name__ == "__main__":
    demo.launch()