|
import gradio as gr |
|
import torch |
|
import torch._dynamo |
|
import pandas as pd |
|
import datetime |
|
import json |
|
import os |
|
from peft import PeftModel |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from huggingface_hub import login |
|
import spaces |
|
|
|
# Hub repositories for the base model and the LoRA adapter applied on top of it.
_BASE_MODEL_ID = "google/gemma-2-2b-it"
_ADAPTER_ID = "Nourivex/noura-2b-it-lora"

# Authenticate with the Hugging Face Hub using the Space secret "gemma_access"
# (presumably required because the Gemma base model is gated — confirm).
hf_token = os.getenv("gemma_access")
if not hf_token:
    raise RuntimeError("Missing access token. Add it under Space Settings > Secrets.")
login(hf_token)

# Switch TorchDynamo off entirely; suppress_errors additionally makes any
# compile failure fall back to eager execution instead of raising.
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.disable = True

# Load the base model, then wrap it with the LoRA adapter weights.
base_model = AutoModelForCausalLM.from_pretrained(_BASE_MODEL_ID)
model = PeftModel.from_pretrained(base_model, _ADAPTER_ID)

# Prefer the GPU when one is visible, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# The tokenizer is taken from the adapter repo, not the base model repo
# (presumably the one the adapter was trained with — confirm).
tokenizer = AutoTokenizer.from_pretrained(_ADAPTER_ID)
|
|
|
""" |
|
This part is for the setting of the character of the chatbot. |
|
""" |
|
def save_character(name, personality, background):
    """Persist a character definition to ``character_data.json`` and return it.

    The file lives in the current working directory and is overwritten on
    every call.

    Args:
        name: Character name.
        personality: Free-text personality description.
        background: Free-text background story.

    Returns:
        dict with the keys "name", "personality" and "background".
    """
    profile = dict(name=name, personality=personality, background=background)
    with open("character_data.json", "w") as out_file:
        json.dump(profile, out_file)
    return profile
|
|
|
def load_character():
    """Load the saved character, falling back to a built-in default.

    Reads ``character_data.json`` from the current working directory (the file
    ``save_character`` writes). If the file is missing, is not valid JSON,
    or does not hold a JSON object, the default character is returned. Any
    of the keys "name", "personality" or "background" missing from the file
    are filled in from the default, so callers can always index them.

    Returns:
        dict with at least the keys "name", "personality" and "background".
    """
    default = {
        "name": "Wanting",
        "personality": "friendly and helpful",
        "background": "Wanting is a software engineer who loves to code and build new things."
    }
    try:
        with open("character_data.json", "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except FileNotFoundError:
        return default
    except (json.JSONDecodeError, UnicodeDecodeError):
        # A corrupt or half-written file must not crash the app at startup.
        return default
    if not isinstance(loaded, dict):
        return default
    # Fill any missing keys from the default so generate_system_message
    # never hits a KeyError.
    return {**default, **loaded}
|
|
|
def save_rating(rating, chat_history):
    """Append one rating record to ``ratings.json`` (one JSON object per line).

    The record is fully serialized before the file is touched. Otherwise a
    non-serializable ``chat_history`` makes ``json.dump`` leave a truncated,
    newline-less fragment that corrupts the line-delimited file.

    Args:
        rating: The score the user gave.
        chat_history: The chat history being rated; must be JSON-serializable.

    Returns:
        The dict that was written, including a local-time ISO-8601 timestamp.

    Raises:
        TypeError: If ``rating`` or ``chat_history`` cannot be serialized to
            JSON (nothing is written in that case).
    """
    rating_data = {
        "rating": rating,
        "timestamp": datetime.datetime.now().isoformat(),
        "chat_history": chat_history
    }
    line = json.dumps(rating_data) + "\n"
    with open("ratings.json", "a") as f:
        f.write(line)
    return rating_data
|
|
|
def generate_system_message(character):
    """Build the roleplay system prompt for ``character``.

    Args:
        character: Mapping with "name", "personality" and "background" entries.

    Returns:
        A multi-line prompt that names the character, states its personality
        and background, lists six numbered guidelines, and ends by telling
        the model to begin the conversation as the character.
    """
    char_name = character["name"]
    guidelines = [
        "Stay in character at all times",
        "Respond naturally and consistently with your personality",
        "Use appropriate emotional expressions based on your character",
        "Maintain your character's background and experiences in your responses",
        "If asked about something your character wouldn't know, respond appropriately in character",
        "You can add some reasonable settings to make yourself more vivid",
    ]
    numbered = [f"{idx}. {rule}" for idx, rule in enumerate(guidelines, start=1)]
    prompt_lines = [
        f"You are roleplaying as {char_name}.",
        f"Your personality: {character['personality']}",
        f"Your background: {character['background']}",
        "",
        "Important guidelines:",
        *numbered,
        "",
        f"Now, begin the conversation as {char_name}.",
    ]
    return "\n".join(prompt_lines)
|
|
|
# Plain-text role markers used to lay out the prompt for the model.
_ROLE_MARKERS = ("<|system|>", "<|user|>", "<|assistant|>")


def _build_prompt(message, history, system_message):
    """Render the system prompt, prior turns and the new message as one string.

    ``history`` may be Gradio "tuples" format ([user, assistant] pairs) or
    "messages" format (dicts with "role" and "content").
    """
    parts = []
    # The prompt is rebuilt from scratch on every call, so the persona must be
    # included on every turn, not only the first one.
    if system_message:
        parts.append(f"<|system|>\n{system_message}\n")
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            # Skip non-text content (files, components) and unknown roles.
            if role in ("user", "assistant") and isinstance(content, str) and content:
                parts.append(f"<|{role}|>\n{content}\n")
        else:
            user_msg, assistant_msg = turn
            if user_msg:
                parts.append(f"<|user|>\n{user_msg}\n")
            if assistant_msg:
                parts.append(f"<|assistant|>\n{assistant_msg}\n")
    parts.append(f"<|user|>\n{message}\n<|assistant|>\n")
    return "".join(parts)


def _clean_response(text):
    """Trim a raw completion down to the character's reply on a single line."""
    import re

    # Stop at the first role marker: anything after it is the model inventing
    # further turns of the conversation.
    cut_points = [idx for idx in (text.find(marker) for marker in _ROLE_MARKERS) if idx != -1]
    if cut_points:
        text = text[:min(cut_points)]
    # Unwrap print(...) calls the model sometimes emits, leaving any other
    # parentheses (e.g. "(smiles)") intact.
    text = re.sub(r"print\((.*?)\)", r"\1", text, flags=re.DOTALL)
    text = text.replace("print(", "")
    # Drop Markdown code fences and inline backticks.
    text = text.replace("`", "")
    # Collapse all whitespace, including newlines, to single spaces.
    return " ".join(text.split())


@spaces.GPU(duration=120)
def run(message, history, system_message):
    """Generate the character's reply to ``message``.

    Called by gr.ChatInterface with the new user message, the prior chat
    history, and the value of the hidden "System message" textbox.

    Returns:
        The cleaned reply text, or a fixed apology string if anything fails
        (the error is printed to the logs).
    """
    try:
        prompt = _build_prompt(message, history, system_message)
        inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            temperature=0.6,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
            num_return_sequences=1,
            no_repeat_ngram_size=3
        )
        # Decode only the newly generated tokens. Decoding outputs[0] whole
        # includes the prompt, which leaks the system message and history
        # whenever the "<|assistant|>" marker does not survive decoding.
        prompt_length = inputs["input_ids"].shape[-1]
        completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return _clean_response(completion)
    except Exception as e:
        print(f"❌ Error in run function: {str(e)}")
        return "I apologize, but I'm having trouble generating a response right now. Please try again."
|
|
|
def _history_pairs(history):
    """Return (user, assistant) pairs from tuples- or messages-format chat history."""
    pairs = []
    pending_user = None
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            if role == "user":
                pending_user = turn.get("content")
            elif role == "assistant" and pending_user is not None:
                pairs.append((pending_user, turn.get("content")))
                pending_user = None
        else:
            user_msg, assistant_msg = turn
            pairs.append((user_msg, assistant_msg))
    return pairs


def save_session(character, history, rating, feedback, filename="dataset.jsonl"):
    """Append one finished chat session as a JSON line to ``filename``.

    Args:
        character: Either a character dict with "name", "personality" and
            "background" keys, or a ready-made system-prompt string that is
            stored as-is.
        history: Chat history in Gradio "tuples" or "messages" format (or
            None). Only turns where both sides are non-empty are stored.
        rating: The user's score for the session.
        feedback: Free-text feedback from the user.
        filename: Path of the JSONL dataset file; created if missing.
    """
    if isinstance(character, str):
        system = character
    else:
        system = (
            f"You are roleplaying as {character['name']}. "
            f"Personality: {character['personality']} "
            f"Background: {character['background']}"
        )

    formatted = {
        "system": system,
        "conversation": [
            {"user": u, "assistant": a} for u, a in _history_pairs(history) if u and a
        ],
        "rating": rating,
        "feedback": feedback
    }

    # Serialize first so a failure never leaves a partial line; "a" mode
    # creates the file when it does not exist yet.
    line = json.dumps(formatted) + "\n"
    with open(filename, "a") as f:
        f.write(line)
|
|
|
def preview_character(name, personality, background):
    """Render an HTML preview card for the character being edited.

    All three values are user-typed text, so they are HTML-escaped before
    interpolation. Otherwise markup typed into a field (e.g. ``<img onerror=...>``)
    would be injected into the page, and a stray ``<`` would break the layout.

    Returns:
        An HTML string for the preview ``gr.HTML`` component.
    """
    from html import escape

    name, personality, background = (escape(str(value)) for value in (name, personality, background))
    preview_text = f"""
<div class="character-preview">
<div class="character-card">
<div class="character-avatar">👤</div>
<h3 class="character-name">{name}</h3>
<div class="character-details">
<div class="detail-section">
<span class="detail-label">Personality:</span>
<span class="detail-text">{personality}</span>
</div>
<div class="detail-section">
<span class="detail-label">Background:</span>
<span class="detail-text">{background}</span>
</div>
</div>
<div class="preview-footer">
Ready to start chatting with {name}!
</div>
</div>
</div>
"""
    return preview_text
|
|
|
custom_css = """ |
|
/* Modern blue and white color scheme */ |
|
:root { |
|
--primary: #2563eb; |
|
--primary-light: #3b82f6; |
|
--primary-dark: #1d4ed8; |
|
--secondary: #f8fafc; |
|
--accent: #e0f2fe; |
|
--text: #1e293b; |
|
--text-light: #64748b; |
|
--border: #e2e8f0; |
|
--success: #10b981; |
|
} |
|
|
|
/* Global styling */ |
|
.gradio-container { |
|
background: white; |
|
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif; |
|
color: var(--text); |
|
line-height: 1.5; |
|
} |
|
|
|
/* Main header */ |
|
#header-title { |
|
text-align: center; |
|
color: var(--primary); |
|
font-size: 2.5rem; |
|
font-weight: 700; |
|
margin: 20px 0 10px 0; |
|
background: linear-gradient(90deg, var(--primary), var(--primary-light)); |
|
-webkit-background-clip: text; |
|
background-clip: text; |
|
-webkit-text-fill-color: transparent; |
|
} |
|
|
|
#subtitle { |
|
text-align: center; |
|
color: var(--text-light); |
|
font-size: 1.1rem; |
|
margin-bottom: 30px; |
|
max-width: 600px; |
|
margin-left: auto; |
|
margin-right: auto; |
|
} |
|
|
|
/* Character creation section */ |
|
.gr-group { |
|
background: white; |
|
border-radius: 12px; |
|
box-shadow: 0 4px 6px rgba(0,0,0,0.05); |
|
border: 1px solid var(--border); |
|
margin: 20px auto; |
|
max-width: 900px; |
|
transition: all 0.2s ease; |
|
} |
|
|
|
.gr-group:hover { |
|
box-shadow: 0 10px 15px rgba(0,0,0,0.1); |
|
} |
|
|
|
.gr-form { |
|
padding: 30px; |
|
} |
|
|
|
/* Input fields */ |
|
.gr-textbox, .gr-textbox textarea { |
|
border: 1px solid var(--border); |
|
border-radius: 8px; |
|
background: white; |
|
font-family: inherit; |
|
padding: 12px 16px; |
|
transition: all 0.2s ease; |
|
margin-bottom: 16px; |
|
} |
|
|
|
.gr-textbox:focus, .gr-textbox textarea:focus { |
|
border-color: var(--primary); |
|
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1); |
|
outline: none; |
|
} |
|
|
|
/* Button styling */ |
|
#preview_btn, #save_btn, #start_btn { |
|
background: var(--primary); |
|
color: white; |
|
border: none; |
|
border-radius: 8px; |
|
padding: 12px 24px; |
|
font-size: 1rem; |
|
font-weight: 500; |
|
cursor: pointer; |
|
transition: all 0.2s ease; |
|
margin: 8px 0; |
|
} |
|
|
|
#preview_btn:hover, #save_btn:hover, #start_btn:hover { |
|
background: var(--primary-dark); |
|
transform: translateY(-1px); |
|
} |
|
|
|
#start_btn { |
|
background: var(--primary-dark); |
|
font-size: 1.1rem; |
|
padding: 14px 28px; |
|
margin-top: 16px; |
|
} |
|
|
|
/* Character preview card */ |
|
.character-preview { |
|
padding: 20px; |
|
} |
|
|
|
.character-card { |
|
background: white; |
|
border: 1px solid var(--border); |
|
border-radius: 12px; |
|
padding: 24px; |
|
box-shadow: 0 4px 6px rgba(0,0,0,0.05); |
|
} |
|
|
|
.character-avatar { |
|
font-size: 3rem; |
|
margin-bottom: 16px; |
|
color: var(--primary); |
|
} |
|
|
|
.character-name { |
|
color: var(--primary); |
|
font-size: 1.5rem; |
|
margin-bottom: 16px; |
|
font-weight: 600; |
|
} |
|
|
|
.detail-section { |
|
margin: 12px 0; |
|
text-align: left; |
|
padding: 12px; |
|
background: var(--accent); |
|
border-radius: 8px; |
|
} |
|
|
|
.detail-label { |
|
font-weight: 600; |
|
color: var(--primary); |
|
display: block; |
|
margin-bottom: 4px; |
|
font-size: 0.9rem; |
|
} |
|
|
|
.detail-text { |
|
color: var(--text); |
|
font-size: 0.95rem; |
|
} |
|
|
|
.preview-footer { |
|
margin-top: 16px; |
|
color: var(--primary); |
|
font-weight: 500; |
|
font-size: 0.95rem; |
|
text-align: center; |
|
padding-top: 12px; |
|
border-top: 1px solid var(--border); |
|
} |
|
|
|
/* Chat interface styling */ |
|
.chat-header { |
|
background: var(--primary); |
|
color: white; |
|
padding: 16px; |
|
border-radius: 12px 12px 0 0; |
|
font-weight: 500; |
|
font-size: 1.1rem; |
|
} |
|
|
|
#chatbox { |
|
background: white; |
|
border: 1px solid var(--border); |
|
border-radius: 0 0 12px 12px; |
|
min-height: 500px; |
|
} |
|
|
|
/* Chat messages */ |
|
.message.user { |
|
background: var(--accent) !important; |
|
border-radius: 12px 12px 0 12px !important; |
|
border: 1px solid var(--border) !important; |
|
color: var(--text) !important; |
|
padding: 12px 16px !important; |
|
} |
|
|
|
.message.bot { |
|
background: white !important; |
|
border-radius: 12px 12px 12px 0 !important; |
|
border: 1px solid var(--border) !important; |
|
color: var(--text) !important; |
|
padding: 12px 16px !important; |
|
} |
|
|
|
/* Finish button */ |
|
#finish_btn { |
|
background: var(--primary); |
|
color: white; |
|
border: none; |
|
border-radius: 8px; |
|
padding: 12px 24px; |
|
font-size: 1rem; |
|
font-weight: 500; |
|
margin-top: 20px; |
|
cursor: pointer; |
|
transition: all 0.2s ease; |
|
} |
|
|
|
#finish_btn:hover { |
|
background: var(--primary-dark); |
|
transform: translateY(-1px); |
|
} |
|
|
|
/* Rating section */ |
|
.rating-header { |
|
text-align: center; |
|
font-size: 1.5rem; |
|
color: var(--primary); |
|
font-weight: 600; |
|
margin-bottom: 24px; |
|
} |
|
|
|
#rating-slider { |
|
margin: 20px 0; |
|
} |
|
|
|
#feedback-box textarea { |
|
border: 1px solid var(--border); |
|
border-radius: 8px; |
|
background: white; |
|
padding: 12px 16px; |
|
min-height: 120px; |
|
} |
|
|
|
#submit-rating-btn { |
|
background: var(--primary); |
|
color: white; |
|
border: none; |
|
border-radius: 8px; |
|
padding: 12px 24px; |
|
font-size: 1rem; |
|
font-weight: 500; |
|
cursor: pointer; |
|
transition: all 0.2s ease; |
|
margin: 16px auto; |
|
display: block; |
|
} |
|
|
|
#submit-rating-btn:hover { |
|
background: var(--primary-dark); |
|
transform: translateY(-1px); |
|
} |
|
|
|
/* Responsive design */ |
|
@media (max-width: 768px) { |
|
#header-title { |
|
font-size: 2rem !important; |
|
} |
|
|
|
.gr-group { |
|
margin: 10px; |
|
} |
|
|
|
.gr-form { |
|
padding: 20px; |
|
} |
|
} |
|
|
|
/* Animations */ |
|
.fade-in { |
|
animation: fadeIn 0.3s ease-in; |
|
} |
|
|
|
@keyframes fadeIn { |
|
from { opacity: 0; transform: translateY(10px); } |
|
to { opacity: 1; transform: translateY(0); } |
|
} |
|
|
|
/* Loading animation */ |
|
@keyframes spin { |
|
0% { transform: rotate(0deg); } |
|
100% { transform: rotate(360deg); } |
|
} |
|
""" |
|
|
|
# Three-screen UI: character setup -> chat -> feedback. Screens are gr.Group
# containers whose visibility is toggled by the button handlers at the end.
with gr.Blocks(title="Roleplay Chatbot", css=custom_css) as demo:
    # Active character dict for this session. It is seeded once when the UI
    # is built and replaced whenever "Save Character" is clicked.
    state = gr.State(load_character())
    gr.HTML("""
<div id="header-title">Roleplay AI</div>
<div id="subtitle">Create and chat with custom AI characters. Powered by fine-tuned Gemma 2B model.</div>
""")

    # --- Screen 1: character setup -------------------------------------------
    with gr.Group(visible=True, elem_classes=["fade-in"]) as tab_setup:
        gr.HTML("""
<div style="text-align: center; margin-bottom: 30px;">
<h2 style="color: var(--primary); font-size: 1.8rem; margin-bottom: 8px;">
Character Settings
</h2>
<p style="color: var(--text-light); margin-top: 0;">
Define your character's personality and background
</p>
</div>
""")

        with gr.Row():
            with gr.Column(scale=1):
                name = gr.Textbox(
                    label="Character Name",
                    placeholder="Enter your character's name",
                    elem_classes=["character-input"]
                )
                personality = gr.Textbox(
                    label="Personality Traits",
                    placeholder="Describe personality traits (e.g., friendly, witty, mysterious)",
                    elem_classes=["character-input"]
                )
                background = gr.Textbox(
                    label="Background Story",
                    lines=4,
                    placeholder="Describe your character's background and history",
                    elem_classes=["character-input"]
                )
                with gr.Row():
                    btn_preview = gr.Button("Preview Character", elem_id="preview_btn")
                    btn_save = gr.Button("Save Character", elem_id="save_btn")

                btn_start = gr.Button("Start Chatting", elem_id="start_btn", size="lg")

            with gr.Column(scale=1):
                preview = gr.HTML(
                    """
<div style="text-align: center; padding: 40px; color: var(--text-light);">
<div style="font-size: 3rem; margin-bottom: 20px; color: var(--primary);">👤</div>
<p>Preview will appear here</p>
</div>
""",
                    elem_classes=["preview-placeholder"]
                )

        btn_preview.click(preview_character, [name, personality, background], preview)
        btn_save.click(save_character, [name, personality, background], state)

    # --- Screen 2: chat --------------------------------------------------------
    with gr.Group(visible=False, elem_classes=["fade-in"]) as tab_chat:
        gr.HTML('<div class="chat-header">Chat with your character</div>')
        chatbot = gr.ChatInterface(
            run,
            # Hidden textbox carrying the system prompt; filled in by btn_start.
            additional_inputs=[
                gr.Textbox(value="", label="System message", visible=False),
            ],
            chatbot=gr.Chatbot(
                elem_id="chatbox",
                height=500,
                show_label=False,
                container=True
            ),
        )

        btn_finish = gr.Button("Finish Chat", elem_id="finish_btn")

    # --- Screen 3: feedback ----------------------------------------------------
    with gr.Group(visible=False, elem_classes=["fade-in"]) as tab_rating:
        gr.HTML("""
<div style="text-align: center; margin-bottom: 30px;">
<div class="rating-header">Feedback</div>
<p style="color: var(--text-light);">How was your experience?</p>
</div>
""")
        with gr.Row():
            with gr.Column():
                rating = gr.Slider(
                    minimum=1,
                    maximum=5,
                    step=1,
                    label="Rating (1-5)",
                    elem_id="rating-slider",
                    value=3
                )
                feedback = gr.Textbox(
                    label="Additional Feedback (optional)",
                    placeholder="Share your thoughts about the experience...",
                    lines=4,
                    elem_id="feedback-box"
                )
                btn_submit_rating = gr.Button("Submit Feedback", elem_id="submit-rating-btn")

        gr.HTML("""
<div style="text-align: center; margin-top: 20px; color: var(--text-light);">
Thank you for your feedback!
</div>
""")

    def start_chat_with_character(character):
        """Fill the hidden system-prompt box and switch to the chat screen."""
        system_msg = generate_system_message(character)
        return (system_msg, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False))

    btn_start.click(
        fn=start_chat_with_character,
        inputs=state,
        outputs=[chatbot.additional_inputs[0], tab_setup, tab_chat, tab_rating]
    )

    btn_finish.click(
        lambda: [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)],
        inputs=[],
        outputs=[tab_setup, tab_chat, tab_rating]
    )

    def submit_feedback(rating_value, feedback_text, history, character):
        """Record the rating and the session, then return to the setup screen.

        Returns exactly one update per output component. The previous lambda
        returned five values for four outputs, which shifted the visibility
        updates. It also passed the system-prompt string where save_session
        expects the character dict.
        """
        save_rating(rating_value, history)
        save_session(character, history, rating_value, feedback_text)
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)

    btn_submit_rating.click(
        fn=submit_feedback,
        inputs=[rating, feedback, chatbot.chatbot_state, state],
        outputs=[tab_setup, tab_chat, tab_rating]
    )
|
|
|
# Start the Gradio server only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()