import os

import torch
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict

# Base model and adapters
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.1"
ADAPTER_REPOS = {
    "witty": "ai1-test/mixtral-lora-witty",
    "charming": "ai1-test/mixtral-lora-charming",
    "sarcastic": "ai1-test/mixtral-lora-sarcastic",
    "neutral": "ai1-test/mixtral-lora-neutral",
}

HF_TOKEN = os.environ.get("HF_TOKEN")


def _get_auth_kwargs():
    # `token` supersedes the deprecated `use_auth_token` argument in recent
    # transformers/peft releases.
    return {"token": HF_TOKEN} if HF_TOKEN else {}


# Tokeniser
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, **_get_auth_kwargs())
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# LoRA state cache: trait name -> adapter state dict
_delta_cache = {}

# Persistent trait weights
current_weights = {"witty": 0.0, "charming": 0.0, "sarcastic": 0.0, "neutral": 1.0}


def apply_weights(witty, charming, sarcastic, neutral):
    """Clamp the trait weights to be non-negative, normalise them to sum to 1, and store them."""
    vals = [0.0 if v is None else max(float(v), 0.0) for v in (witty, charming, sarcastic, neutral)]
    total = sum(vals)
    if total <= 0.0:
        # Fall back to a purely neutral persona when every weight is zero.
        normalised = [0.0, 0.0, 0.0, 1.0]
    else:
        normalised = [v / total for v in vals]
    for trait, value in zip(("witty", "charming", "sarcastic", "neutral"), normalised):
        current_weights[trait] = value
    return "Trait weights updated."


def load_delta(trait):
    """Fetch a trait adapter's LoRA state dict, caching it after the first load."""
    if trait in _delta_cache:
        return _delta_cache[trait]
    repo_id = ADAPTER_REPOS[trait]
    base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="cpu", **_get_auth_kwargs())
    lora_model = PeftModel.from_pretrained(base, repo_id, **_get_auth_kwargs())
    delta = get_peft_model_state_dict(lora_model)
    _delta_cache[trait] = delta
    return delta


@spaces.GPU
def generate_response(prompt):
    weights = current_weights.copy()
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16,
        device_map="auto",
        **_get_auth_kwargs(),
    )
    # Blend the adapter state dicts linearly. This assumes all four adapters
    # share the same LoRA config (rank and target modules), so their keys line up.
    merged_delta = {}
    for trait, weight in weights.items():
        if weight == 0.0:
            continue  # skip adapters that contribute nothing
        delta = load_delta(trait)
        for key, tensor in delta.items():
            merged_delta[key] = merged_delta.get(key, 0) + tensor * weight
    # Wrap the base model in a PeftModel so the LoRA modules exist, then
    # overwrite their weights with the blend. (Loading the LoRA keys straight
    # into the bare base model would silently discard them, since it has no
    # LoRA parameters.)
    lora_model = PeftModel.from_pretrained(base_model, ADAPTER_REPOS["neutral"], **_get_auth_kwargs())
    set_peft_model_state_dict(lora_model, merged_delta)
    inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
    with torch.no_grad():
        output_ids = lora_model.generate(**inputs, max_new_tokens=128, do_sample=False)
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Drop the echoed prompt and keep only the first line of the reply.
    if response.startswith(prompt):
        response = response[len(prompt):].lstrip("\n\r ")
    response = response.split("\n")[0].strip()
    return response


# Chat handler to maintain history
def handle_generate(prompt, history):
    if history is None:
        history = []
    reply = generate_response(prompt)
    new_history = history + [(prompt, reply)]
    log_lines = []
    for user_msg, bot_msg in new_history:
        log_lines.append(f"**You:** {user_msg}\n**Bot:** {bot_msg}")
    formatted_history = "\n\n".join(log_lines)
    return reply, new_history, formatted_history


# Build the Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# AIVA: Customized Personality Trait Chatbot")
    gr.Markdown(
        "Enter a prompt and adjust the weights to blend witty, charming, sarcastic and neutral traits.\n"
        "Set your desired weights once, then generate multiple responses without re-entering them."
    )

    # Trait weights
    with gr.Column():
        gr.Markdown("## Trait Weights")
        witty_input = gr.Number(value=current_weights["witty"], minimum=0.0, maximum=1.0, label="Witty", precision=2)
        charming_input = gr.Number(value=current_weights["charming"], minimum=0.0, maximum=1.0, label="Charming", precision=2)
        sarcastic_input = gr.Number(value=current_weights["sarcastic"], minimum=0.0, maximum=1.0, label="Sarcastic", precision=2)
        neutral_input = gr.Number(value=current_weights["neutral"], minimum=0.0, maximum=1.0, label="Neutral", precision=2)
        apply_button = gr.Button("Apply trait weights")
        status_message = gr.Markdown("")

    # Chat area
    with gr.Column():
        gr.Markdown("## Chat")
        prompt_input = gr.Textbox(label="Your prompt", lines=2)
        generate_button = gr.Button("Generate")
        output_box = gr.Textbox(label="Model response", lines=6)
        chat_history_box = gr.Markdown("")

    # Conversation state to hold the history
    conversation_state = gr.State([])

    apply_button.click(
        apply_weights,
        inputs=[witty_input, charming_input, sarcastic_input, neutral_input],
        outputs=status_message,
    )
    generate_button.click(
        handle_generate,
        inputs=[prompt_input, conversation_state],
        outputs=[output_box, conversation_state, chat_history_box],
    )


if __name__ == "__main__":
    demo.launch()
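# Local usage (assumes HF_TOKEN is exported when the adapter repos are private):
#   HF_TOKEN=... python app.py
# On Hugging Face Spaces, the @spaces.GPU decorator allocates ZeroGPU time for
# each generate_response call.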