# AIVA-Demo / app.py
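"""Gradio demo that blends LoRA personality adapters ("witty", "charming",
"sarcastic", "neutral") on top of Mistral-7B-Instruct using user-set weights."""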
import os
import torch
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from peft.utils import get_peft_model_state_dict
# Base model and adapters
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.1"
ADAPTER_REPOS = {
    "witty": "ai1-test/mixtral-lora-witty",
    "charming": "ai1-test/mixtral-lora-charming",
    "sarcastic": "ai1-test/mixtral-lora-sarcastic",
    "neutral": "ai1-test/mixtral-lora-neutral",
}
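# Each repo is assumed to host a PEFT LoRA adapter fine-tuned for that trait.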
HF_TOKEN = os.environ.get("HF_TOKEN")
def _get_auth_kwargs():
    # `token` replaces the deprecated `use_auth_token` keyword in recent
    # transformers/peft releases; it gates access to private Hub repos.
    return {"token": HF_TOKEN} if HF_TOKEN else {}
# Tokeniser
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, **_get_auth_kwargs())
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# LoRA state cache
_delta_cache = {}
# Persistent trait weights
current_weights = {"witty": 0.0, "charming": 0.0, "sarcastic": 0.0, "neutral": 1.0}
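# At generation time the applied update is the convex combination
#   delta_W = w_witty * dW_witty + w_charming * dW_charming + ...
# with the weights normalised to sum to 1 by apply_weights() below.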
def apply_weights(witty, charming, sarcastic, neutral):
    # Clamp negatives, treat missing values as zero, then normalise to sum to 1.
    vals = [0.0 if v is None else max(float(v), 0.0) for v in (witty, charming, sarcastic, neutral)]
    total = sum(vals)
    if total <= 0.0:
        # All-zero input falls back to a purely neutral persona.
        normalised = [0.0, 0.0, 0.0, 1.0]
    else:
        normalised = [v / total for v in vals]
    for trait, value in zip(("witty", "charming", "sarcastic", "neutral"), normalised):
        current_weights[trait] = value
    return "Trait weights updated."
def load_delta(trait):
    if trait in _delta_cache:
        return _delta_cache[trait]
    base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="cpu", **_get_auth_kwargs())
    lora_model = PeftModel.from_pretrained(base, ADAPTER_REPOS[trait], **_get_auth_kwargs())
    lora_state = get_peft_model_state_dict(lora_model)
    cfg = lora_model.peft_config["default"]
    scale = cfg.lora_alpha / cfg.r
    # Materialise each LoRA pair as a dense delta, delta_W = (alpha / r) * B @ A,
    # keyed by the base model's own parameter name so it can be added directly.
    delta = {}
    for key, lora_a in lora_state.items():
        if "lora_A" not in key:
            continue
        lora_b = lora_state[key.replace("lora_A", "lora_B")]
        base_key = key.replace(".lora_A.weight", ".weight").replace("base_model.model.", "", 1)
        delta[base_key] = (lora_b.float() @ lora_a.float()) * scale
    _delta_cache[trait] = delta
    return delta
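# Usage sketch, assuming the adapter repos above are reachable:
#   delta = load_delta("witty")  # dict of dense fp32 deltas keyed by base parameter names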
@spaces.GPU
def generate_response(prompt):
    weights = current_weights.copy()
    # Reload the fp16 base model for each ZeroGPU call, then fold the blended
    # trait deltas directly into its weights.
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16,
        device_map="auto",
        **_get_auth_kwargs(),
    )
    merged_delta = {}
    for trait, weight in weights.items():
        if weight == 0.0:
            continue
        for key, tensor in load_delta(trait).items():
            merged_delta[key] = merged_delta.get(key, 0) + tensor * weight
    # Add the blended deltas onto the matching base parameters; the original
    # load_state_dict(merged_delta, strict=False) silently loaded nothing
    # because the LoRA keys never matched the base model's parameter names.
    with torch.no_grad():
        for name, param in base_model.named_parameters():
            if name in merged_delta:
                param.add_(merged_delta[name].to(device=param.device, dtype=param.dtype))
    inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
    with torch.no_grad():
        output_ids = base_model.generate(**inputs, max_new_tokens=128, do_sample=False)
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Drop the echoed prompt and keep only the first line of the completion.
    if response.startswith(prompt):
        response = response[len(prompt):].lstrip("\n\r ")
    return response.split("\n")[0].strip()
# Chat handler to maintain history
def handle_generate(prompt, history):
    if history is None:
        history = []
    reply = generate_response(prompt)
    new_history = history + [(prompt, reply)]
    log_lines = []
    for user_msg, bot_msg in new_history:
        log_lines.append(f"**You:** {user_msg}\n**Bot:** {bot_msg}")
    formatted_history = "\n\n".join(log_lines)
    return reply, new_history, formatted_history
# Build the Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# AIVA: Customized Personality Trait Chatbot")
    gr.Markdown(
        "Enter a prompt and adjust the weights to blend witty, charming, sarcastic, and neutral traits.\n"
        "Set your desired weights once, then generate multiple responses without re-entering them."
    )
    # Trait weights
    with gr.Column():
        gr.Markdown("## Trait Weights")
        witty_input = gr.Number(value=current_weights["witty"], minimum=0.0, maximum=1.0, label="Witty", precision=2)
        charming_input = gr.Number(value=current_weights["charming"], minimum=0.0, maximum=1.0, label="Charming", precision=2)
        sarcastic_input = gr.Number(value=current_weights["sarcastic"], minimum=0.0, maximum=1.0, label="Sarcastic", precision=2)
        neutral_input = gr.Number(value=current_weights["neutral"], minimum=0.0, maximum=1.0, label="Neutral", precision=2)
        apply_button = gr.Button("Apply trait weights")
        status_message = gr.Markdown("")
    # Chat area
    with gr.Column():
        gr.Markdown("## Chat")
        output_box = gr.Textbox(label="Model response", lines=6)
        chat_history_box = gr.Markdown("")
        prompt_input = gr.Textbox(label="Your prompt", lines=2)
        generate_button = gr.Button("Generate")
    # Conversation state to hold the history
    conversation_state = gr.State([])
    apply_button.click(
        apply_weights,
        inputs=[witty_input, charming_input, sarcastic_input, neutral_input],
        outputs=status_message,
    )
    generate_button.click(
        handle_generate,
        inputs=[prompt_input, conversation_state],
        outputs=[output_box, conversation_state, chat_history_box],
    )
if __name__ == "__main__":
    demo.launch()