import os
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict

# Base model and adapters
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.1"
ADAPTER_REPOS = {
    "witty": "ai1-test/mixtral-lora-witty",
    "charming": "ai1-test/mixtral-lora-charming",
    "sarcastic": "ai1-test/mixtral-lora-sarcastic",
    "neutral": "ai1-test/mixtral-lora-neutral",
}

HF_TOKEN = os.environ.get("HF_TOKEN")

def _get_auth_kwargs():
    return {"use_auth_token": HF_TOKEN} if HF_TOKEN else {}

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, **_get_auth_kwargs())
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# LoRA state cache
_delta_cache = {}

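# Note: load_delta below extracts and memoises the LoRA-only state dict for one
# trait adapter (keys such as "...q_proj.lora_A.weight"). Blending these dicts
# key-by-key assumes all four adapter repos were trained with the same LoRA config.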
def load_delta(trait):
    if trait in _delta_cache:
        return _delta_cache[trait]
    repo_id = ADAPTER_REPOS[trait]
    base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="cpu", **_get_auth_kwargs())
    lora_model = PeftModel.from_pretrained(base, repo_id)
    delta = get_peft_model_state_dict(lora_model)
    _delta_cache[trait] = delta
    return delta

@spaces.GPU
def generate_response(prompt, weights):
    """
    Generate a response using a merged delta based on per-trait weights.
    weights: dict with keys "witty", "charming", "sarcastic", "neutral" and values in [0,1], summing to 1 (normalized).
    If weights is None or invalid, a sane default is used.
    """
    if weights is None:
        weights = {"witty": 0.0, "charming": 0.0, "sarcastic": 0.0, "neutral": 1.0}
    # Normalize/validate
    w_vals = [
        0.0 if v is None else max(float(v), 0.0)
        for v in (weights.get("witty"), weights.get("charming"), weights.get("sarcastic"), weights.get("neutral"))
    ]
    total = sum(w_vals)
    if total <= 0.0:
        normalised = [0.0, 0.0, 0.0, 1.0]
    else:
        normalised = [v / total for v in w_vals]

    weight_map = {
        "witty": normalised[0],
        "charming": normalised[1],
        "sarcastic": normalised[2],
        "neutral": normalised[3],
    }
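    # Example: raw slider values witty=0.5, charming=0.5, sarcastic=0.0, neutral=1.0
    # sum to 2.0 and normalise to 0.25, 0.25, 0.0 and 0.5 respectively.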

    # Linearly blend the cached LoRA deltas: merged[k] = sum over traits of weight_t * delta_t[k]
    merged_delta = {}
    for trait, weight in weight_map.items():
        delta = load_delta(trait)
        for key, tensor in delta.items():
            merged_delta[key] = merged_delta.get(key, 0) + tensor * weight

    # Load the base model fresh, attach one adapter so the LoRA modules exist
    # (the "neutral" repo is used only as a structural template), then overwrite
    # the adapter weights with the blended state dict.
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16,
        device_map="auto",
        **_get_auth_kwargs(),
    )
    model = PeftModel.from_pretrained(base_model, ADAPTER_REPOS["neutral"])
    set_peft_model_state_dict(model, merged_delta)

    inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    if response.startswith(prompt):
        response = response[len(prompt):].lstrip("\n\r ")
    response = response.split("\n")[0].strip()
    return response
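
# Illustrative call (hypothetical values; running it triggers a full GPU generation):
#   generate_response("Introduce yourself.",
#                     {"witty": 1.0, "charming": 0.0, "sarcastic": 0.0, "neutral": 0.0})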

def format_history(history):
    log_lines = []
    for user_msg, bot_msg in history:
        log_lines.append(f"**You:** {user_msg}\n**Bot:** {bot_msg}")
    return "\n\n".join(log_lines)
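
# For example, format_history([("Hi", "Hello!")]) returns "**You:** Hi\n**Bot:** Hello!".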

def handle_generate(prompt, history, weights_state):
    if history is None:
        history = []
    # weights_state may be a dict (per-session weights) or None
    if isinstance(weights_state, dict):
        w = weights_state
    else:
        # Fallback default if Gradio passes None or an unexpected type
        w = {
            "witty": 0.0,
            "charming": 0.0,
            "sarcastic": 0.0,
            "neutral": 1.0,
        }
    reply = generate_response(prompt, w)
    new_history = history + [(prompt, reply)]
    formatted_history = format_history(new_history)
    return reply, new_history, formatted_history

# Build the Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# AIVA: Customized Personality Trait Chatbot (slider-based per-session)")

    gr.Markdown(
        "Use sliders to blend witty, charming, sarcastic and neutral traits. "
        "The weights are per-session; you can generate multiple responses in the same session."
    )

    # Per-session weight state
    with gr.Column():
        gr.Markdown("## Trait Weights")
        witty_slider     = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.00, label="Witty")
        charming_slider  = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.00, label="Charming")
        sarcastic_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.00, label="Sarcastic")
        neutral_slider   = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.00, label="Neutral")

        # Per-session state: a dict of normalized weights
        weights_state = gr.State({
            "witty": 0.0,
            "charming": 0.0,
            "sarcastic": 0.0,
            "neutral": 1.0
        })

        # Update weights_state on slider changes
        def update_weights_from_sliders(witty, charming, sarcastic, neutral, state):
            vals = [
                max(0.0, float(witty)),
                max(0.0, float(charming)),
                max(0.0, float(sarcastic)),
                max(0.0, float(neutral)),
            ]
            total = sum(vals)
            if total <= 0.0:
                normalised = [0.0, 0.0, 0.0, 1.0]
            else:
                normalised = [v / total for v in vals]
            # Return the new dict so Gradio stores it as the new weights_state value
            # (mutating the incoming value has no effect on the session state).
            return {
                "witty": normalised[0],
                "charming": normalised[1],
                "sarcastic": normalised[2],
                "neutral": normalised[3],
            }

        witty_slider.change(
            fn=update_weights_from_sliders,
            inputs=[witty_slider, charming_slider, sarcastic_slider, neutral_slider, weights_state],
            outputs=weights_state
        )
        charming_slider.change(
            fn=update_weights_from_sliders,
            inputs=[witty_slider, charming_slider, sarcastic_slider, neutral_slider, weights_state],
            outputs=weights_state
        )
        sarcastic_slider.change(
            fn=update_weights_from_sliders,
            inputs=[witty_slider, charming_slider, sarcastic_slider, neutral_slider, weights_state],
            outputs=weights_state
        )
        neutral_slider.change(
            fn=update_weights_from_sliders,
            inputs=[witty_slider, charming_slider, sarcastic_slider, neutral_slider, weights_state],
            outputs=weights_state
        )

    # Chat area
    with gr.Column():
        gr.Markdown("## Chat")
        prompt_input = gr.Textbox(label="Your prompt", lines=2)
        generate_button = gr.Button("Generate")
        output_box = gr.Textbox(label="Model response", lines=6)
        chat_history_box = gr.Markdown("")
        conversation_state = gr.State([])

        generate_button.click(
            fn=handle_generate,
            inputs=[prompt_input, conversation_state, weights_state],
            outputs=[output_box, conversation_state, chat_history_box],
        )

if __name__ == "__main__":
    demo.launch()