import gc
import threading

import gradio as gr
import torch
import spaces
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)


# Cleared at the start of every request. Nothing sets it yet, but keeping it
# here leaves room to wire up a Stop control later.
cancel_event = threading.Event()

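# Registry of style-transfer checkpoints: each entry pairs a fine-tuned causal
# LM with the reward model used to rank its candidate generations, plus the
# author tag that gets substituted into the prompt template below.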
MODELS = {
    "bodrunov-t-lite-lora-16": {"repo_id": "daviondk7131/bodrunov-t-lite-lora-16", "description": "S. D. Bodrunov (T-lite)", "reward_repo_id": "daviondk7131/bodrunov-reward-model", "author": "bodrunov", "base_model": "t-tech/T-lite-it-1.0"},
    "shakespeare-deepseek-lora-16": {"repo_id": "daviondk7131/shakespeare-deepseek-lora-16", "description": "W. Shakespeare (Deepseek)", "reward_repo_id": "daviondk7131/shakespeare-reward-model", "author": "Shakespeare", "base_model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"},
    "chekhov-t-lite-lora-16": {"repo_id": "daviondk7131/chekhov-t-lite-lora-16", "description": "A. P. Chekhov (T-lite)", "reward_repo_id": "daviondk7131/chekhov-reward-model", "author": "chekhov_ru", "base_model": "t-tech/T-lite-it-1.0"},
    "tolstoy-t-lite-lora-16": {"repo_id": "daviondk7131/tolstoy-t-lite-lora-16", "description": "L. N. Tolstoy (T-lite)", "reward_repo_id": "daviondk7131/tolstoy-reward-model", "author": "tolstoy_ru", "base_model": "t-tech/T-lite-it-1.0"},
    "dostoevsky-t-lite-lora-16": {"repo_id": "daviondk7131/dostoevsky-t-lite-lora-16", "description": "F. M. Dostoevsky (T-lite)", "reward_repo_id": "daviondk7131/dostoevsky-reward-model", "author": "dostoevsky_ru", "base_model": "t-tech/T-lite-it-1.0"},
    "dostoevsky-yagpt-lora-16": {"repo_id": "daviondk7131/dostoevsky-yagpt-lora-16", "description": "F. M. Dostoevsky (YaGPT)", "reward_repo_id": "daviondk7131/dostoevsky-reward-model", "author": "dostoevsky_ru", "base_model": "yandex/YandexGPT-5-Lite-8B-instruct"},
    "tolstoy-yagpt-lora-16": {"repo_id": "daviondk7131/tolstoy-yagpt-lora-16", "description": "L. N. Tolstoy (YaGPT)", "reward_repo_id": "daviondk7131/tolstoy-reward-model", "author": "tolstoy_ru", "base_model": "yandex/YandexGPT-5-Lite-8B-instruct"},
}

# Single-slot cache: keeps the most recently selected model, tokenizer, and
# reward model in memory so repeated requests skip reloading from the Hub.
CACHE = {
    "model_name": None,
    "model": None,
    "tokenizer": None,
    "reward_model": None,
}


def get_model_name(full_selection):
    """Extract the model key from a dropdown label of the form "<key> - <description>"."""
    return full_selection.split(" - ")[0]
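
# For example, given the label format built by the dropdown below:
#   get_model_name("tolstoy-t-lite-lora-16 - L. N. Tolstoy (T-lite)")
#   -> "tolstoy-t-lite-lora-16"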


def user_input(user_message, history):
    """Append the user's message to the chat history and clear the input textbox."""
    return "", history + [(user_message, None)]


class RewardModel:
    """Scores how closely a text matches the target author's style, using a
    sequence-classification head and reading logits[0, 0] as the score."""

    def __init__(self, model_name):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load onto self.device instead of hard-coding "cuda", so the class
        # also works on CPU-only machines.
        self.reward_model = AutoModelForSequenceClassification.from_pretrained(
            model_name
        ).to(self.device)
        self.reward_tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")

    def score(self, text):
        inputs = self.reward_tokenizer(text, truncation=True, return_tensors="pt").to(self.device)
        with torch.no_grad():
            value = self.reward_model(**inputs).logits[0, 0].item()
        return value
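
# A minimal usage sketch (the checkpoint name is one of the reward_repo_id
# values from MODELS; the weights are downloaded on first use):
#
#   rm = RewardModel("daviondk7131/tolstoy-reward-model")
#   best = max(["candidate one", "candidate two"], key=rm.score)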


STYLE_TEMPLATE_PROMPT = """Below is an instruction describing the task, combined with input data that provides further context. Write a response that completes the request accordingly.

### Instruction:
Write down the text from the input data in the style of the author {}.

### Input data:
{}

### Answer:
{}"""
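
# Example of a filled template (illustrative values):
#   STYLE_TEMPLATE_PROMPT.format("tolstoy_ru", "Some source text.", "")
# The empty third slot leaves "### Answer:" open for the model to complete.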


def generate(
    model,
    tokenizer,
    author: str,
    text: str,
    max_new_tokens: int = 2048,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
    do_sample: bool = True,
    **kwargs
) -> str:
    """Generate one style-transfer rewrite of `text` and return only the answer part."""
    input_text = STYLE_TEMPLATE_PROMPT.format(author, text, "")
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            **kwargs
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Strip the echoed prompt so only the newly generated answer remains.
    if generated_text.startswith(input_text):
        generated_text = generated_text[len(input_text):].strip()

    return generated_text
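
# Illustrative call (with a model and tokenizer loaded as in bot_response below):
#   generate(model, tokenizer, "tolstoy_ru", "Hello, world.", max_new_tokens=512)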


@spaces.GPU(duration=60)
def bot_response(history, model_selection, max_tokens, temperature, top_k, top_p, repetition_penalty):
    """Generate the assistant's reply to the most recent user message."""
    cancel_event.clear()

    model_name = get_model_name(model_selection)
    # Read the user message before the try block so the except handler can
    # always reference it.
    user_message = history[-1][0]

    try:
        # Load the model, tokenizer, and reward model for the selected
        # checkpoint, reusing the cached copies when the selection is unchanged.
        load_kwargs = {
            "pretrained_model_name_or_path": MODELS[model_name]["repo_id"],
            "device_map": "auto",
            "torch_dtype": torch.float16,
            "trust_remote_code": True,
        }

        if CACHE["model_name"] == model_name:
            tokenizer = CACHE["tokenizer"]
            model = CACHE["model"]
            reward_model = CACHE["reward_model"]
        else:
            tokenizer = AutoTokenizer.from_pretrained(MODELS[model_name]["base_model"])
            # device_map="auto" already places the weights; a further
            # .to("cuda") would be redundant and can break offloaded modules.
            model = AutoModelForCausalLM.from_pretrained(**load_kwargs)
            reward_model = RewardModel(model_name=MODELS[model_name]["reward_repo_id"])
            CACHE["model_name"] = model_name
            CACHE["tokenizer"] = tokenizer
            CACHE["model"] = model
            CACHE["reward_model"] = reward_model

        author = MODELS[model_name]["author"]

        # Best-of-3 sampling: draw three candidates, then keep the one the
        # reward model scores highest for the target author's style.
        results = []
        for _ in range(3):
            results.append(
                generate(
                    model,
                    tokenizer,
                    author,
                    user_message,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                )
            )
        response = max(results, key=reward_model.score)

        history[-1] = (user_message, response)
        return history
    except Exception as e:
        history[-1] = (user_message, f"Error: {e}")
        return history
    finally:
        gc.collect()
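
# A possible cancellation wiring (not built into the UI here): a Stop button
# could call cancel_event.set(), and the candidate loop above would then need
# to check cancel_event.is_set() between samples and bail out early.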


def clear_chat():
    """Reset the conversation history."""
    return []


css = """
.gradio-container {
    background-color: #f5f7fb !important;
}
.qwen-header {
    background: linear-gradient(90deg, #0099FF, #0066CC);
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
    text-align: center;
    color: white;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.qwen-container {
    border-radius: 10px;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
    background: white;
    padding: 20px;
    margin-bottom: 20px;
}
.controls-container {
    background: #f0f4fa;
    border-radius: 10px;
    padding: 15px;
    margin-bottom: 15px;
}
.model-select {
    border: 2px solid #0099FF !important;
    border-radius: 8px !important;
}
.button-primary {
    background-color: #0099FF !important;
    color: white !important;
}
.button-secondary {
    background-color: #6c757d !important;
    color: white !important;
}
.footer {
    text-align: center;
    margin-top: 20px;
    font-size: 0.8em;
    color: #666;
}
"""


with gr.Blocks(title="Chat", css=css) as demo:
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group(elem_classes="qwen-container"):
                model_dd = gr.Dropdown(
                    label="Select Model",
                    choices=[f"{k} - {v['description']}" for k, v in MODELS.items()],
                    value=f"{list(MODELS.keys())[0]} - {MODELS[list(MODELS.keys())[0]]['description']}",
                    elem_classes="model-select"
                )

            with gr.Group(elem_classes="controls-container"):
                gr.Markdown("### Generation Parameters")
                with gr.Row():
                    max_tok = gr.Slider(64, 1024, value=512, step=32, label="Max Tokens")
                with gr.Row():
                    temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
                    p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
                with gr.Row():
                    k = gr.Slider(1, 100, value=40, step=1, label="Top-K")
                    rp = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")

            clear_btn = gr.Button("Clear Chat", elem_classes="button-secondary")

        with gr.Column(scale=7):
            chatbot = gr.Chatbot()
            with gr.Row():
                txt = gr.Textbox(
                    show_label=False,
                    placeholder="Type your message here...",
                    lines=2
                )
                submit_btn = gr.Button("Send", variant="primary", elem_classes="button-primary")

    gr.HTML("""
    <div class="footer">
        <p>Interface powered by Gradio and ZeroGPU.</p>
    </div>
    """)

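    # Both triggers run the same two-step chain: user_input echoes the message
    # into the chat immediately (queue=False), then bot_response fills in the
    # reply; that second step is what @spaces.GPU schedules on a GPU worker.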
    submit_btn.click(
        user_input,
        inputs=[txt, chatbot],
        outputs=[txt, chatbot],
        queue=False
    ).then(
        bot_response,
        inputs=[chatbot, model_dd, max_tok, temp, k, p, rp],
        outputs=chatbot,
        api_name="generate"
    )

    txt.submit(
        user_input,
        inputs=[txt, chatbot],
        outputs=[txt, chatbot],
        queue=False
    ).then(
        bot_response,
        inputs=[chatbot, model_dd, max_tok, temp, k, p, rp],
        outputs=chatbot,
        # Registering "generate" twice would clash with the click handler
        # above, so this trigger stays off the API surface.
        api_name=False
    )

    clear_btn.click(
        clear_chat,
        outputs=[chatbot],
        queue=False
    )


if __name__ == "__main__":
    demo.launch()