import os

import torch
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict


BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.1"
ADAPTER_REPOS = {
    "witty": "ai1-test/mixtral-lora-witty",
    "charming": "ai1-test/mixtral-lora-charming",
    "sarcastic": "ai1-test/mixtral-lora-sarcastic",
    "neutral": "ai1-test/mixtral-lora-neutral",
}

HF_TOKEN = os.environ.get("HF_TOKEN")


def _get_auth_kwargs():
    # Pass the token only when one is configured so public repos still load.
    return {"token": HF_TOKEN} if HF_TOKEN else {}


tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, **_get_auth_kwargs())
if tokenizer.pad_token is None:
    # The Mistral tokenizer ships without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

# Per-trait LoRA state dicts, cached so each adapter repo is only fetched once.
_delta_cache = {}

# Normalised trait weights shared between the UI callbacks and generation.
current_weights = {"witty": 0.0, "charming": 0.0, "sarcastic": 0.0, "neutral": 1.0}

def apply_weights(witty, charming, sarcastic, neutral):
    # Clamp to non-negative values; empty number boxes arrive as None.
    vals = [0.0 if v is None else max(float(v), 0.0) for v in (witty, charming, sarcastic, neutral)]
    total = sum(vals)
    if total <= 0.0:
        # Nothing selected: fall back to the purely neutral persona.
        normalised = [0.0, 0.0, 0.0, 1.0]
    else:
        normalised = [v / total for v in vals]
    current_weights["witty"], current_weights["charming"], current_weights["sarcastic"], current_weights["neutral"] = normalised
    return "Trait weights updated."

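# Example: apply_weights(1.0, 1.0, 0.0, 0.0) normalises to
# {"witty": 0.5, "charming": 0.5, "sarcastic": 0.0, "neutral": 0.0}
# in current_weights.
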
def load_delta(trait):
    if trait in _delta_cache:
        return _delta_cache[trait]
    repo_id = ADAPTER_REPOS[trait]
    # Instantiate the adapter on CPU just long enough to read its LoRA
    # tensors; only the small state dict is cached, not the base weights.
    base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="cpu", **_get_auth_kwargs())
    lora_model = PeftModel.from_pretrained(base, repo_id)
    delta = get_peft_model_state_dict(lora_model)
    _delta_cache[trait] = delta
    return delta

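# The four adapters are blended by averaging their LoRA factors (lora_A and
# lora_B) under the current trait weights. This is a simple merging heuristic,
# not an exact composition of the adapters: the blended update
# (sum_i w_i * B_i) @ (sum_j w_j * A_j) contains cross-trait terms that no
# individual adapter produces. It assumes all four adapters share the same
# rank and target modules, so their tensors have matching shapes and keys.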
@spaces.GPU  # Request a GPU for the duration of this call on Spaces.
def generate_response(prompt):
    weights = current_weights.copy()
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16,
        device_map="auto",
        **_get_auth_kwargs(),
    )
    # Weighted average of the per-trait LoRA tensors.
    merged_delta = {}
    for trait, weight in weights.items():
        delta = load_delta(trait)
        for key, tensor in delta.items():
            merged_delta[key] = merged_delta.get(key, 0) + tensor * weight
    # Attach one adapter (neutral, used purely as scaffolding) so the model
    # has LoRA modules whose keys match the merged state dict, then overwrite
    # them with the blended tensors. Loading the delta into the bare base
    # model would be silently ignored under strict=False.
    model = PeftModel.from_pretrained(base_model, ADAPTER_REPOS["neutral"])
    set_peft_model_state_dict(model, merged_delta)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Drop the echoed prompt and keep only the first line of the reply.
    if response.startswith(prompt):
        response = response[len(prompt):].lstrip("\n\r ")
    response = response.split("\n")[0].strip()
    return response

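# Note: peft also offers add_weighted_adapter() for this kind of blend; the
# manual state-dict average above keeps each trait's tensors in a simple
# cache and avoids re-attaching four adapters on every request.
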
def handle_generate(prompt, history):
    if history is None:
        history = []
    reply = generate_response(prompt)
    new_history = history + [(prompt, reply)]
    # Render the running conversation as markdown for the history panel.
    log_lines = []
    for user_msg, bot_msg in new_history:
        log_lines.append(f"**You:** {user_msg}\n**Bot:** {bot_msg}")
    formatted_history = "\n\n".join(log_lines)
    return reply, new_history, formatted_history

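# Each handle_generate exchange renders in the history panel as, e.g.:
#   **You:** Hello
#   **Bot:** Hi there!
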
with gr.Blocks() as demo:
    gr.Markdown("# AIVA: Customized Personality Trait Chatbot")
    gr.Markdown(
        "Enter a prompt and adjust the weights to blend witty, charming, sarcastic and neutral traits.\n"
        "Set your desired weights once, then generate multiple responses without re-entering them."
    )

    with gr.Column():
        gr.Markdown("## Trait Weights")
        witty_input = gr.Number(value=current_weights["witty"], minimum=0.0, maximum=1.0, label="Witty", precision=2)
        charming_input = gr.Number(value=current_weights["charming"], minimum=0.0, maximum=1.0, label="Charming", precision=2)
        sarcastic_input = gr.Number(value=current_weights["sarcastic"], minimum=0.0, maximum=1.0, label="Sarcastic", precision=2)
        neutral_input = gr.Number(value=current_weights["neutral"], minimum=0.0, maximum=1.0, label="Neutral", precision=2)
        apply_button = gr.Button("Apply trait weights")
        status_message = gr.Markdown("")

    # Sliders offer an alternative way to set the same trait weights.
    with gr.Row():
        witty_slider = gr.Slider(minimum=0, maximum=1, step=0.05, value=0, label="Witty")
        charming_slider = gr.Slider(minimum=0, maximum=1, step=0.05, value=0, label="Charming")
        sarcastic_slider = gr.Slider(minimum=0, maximum=1, step=0.05, value=0, label="Sarcastic")
        neutral_slider = gr.Slider(minimum=0, maximum=1, step=0.05, value=1, label="Neutral")

    update_btn = gr.Button("Update Weights")
    output = gr.Textbox(label="Status")

    update_btn.click(
        fn=apply_weights,
        inputs=[witty_slider, charming_slider, sarcastic_slider, neutral_slider],
        outputs=output,
    )

    with gr.Column():
        gr.Markdown("## Chat")
        output_box = gr.Textbox(label="Model response", lines=6)
        chat_history_box = gr.Markdown("")
        prompt_input = gr.Textbox(label="Your prompt", lines=2)
        generate_button = gr.Button("Generate")

    # Per-session list of (prompt, reply) pairs.
    conversation_state = gr.State([])

    apply_button.click(
        apply_weights,
        inputs=[witty_input, charming_input, sarcastic_input, neutral_input],
        outputs=status_message,
    )
    generate_button.click(
        handle_generate,
        inputs=[prompt_input, conversation_state],
        outputs=[output_box, conversation_state, chat_history_box],
    )


if __name__ == "__main__":
    demo.launch()