import gradio as gr  # type: ignore
import os  # type: ignore
import numpy as np  # type: ignore
from dotenv import load_dotenv
from transformers import AutoTokenizer  # type: ignore
from sentence_transformers import SentenceTransformer  # type: ignore
from huggingface_hub import InferenceClient, login, HfApi, whoami  # type: ignore
from gradio.components import ChatMessage  # type: ignore
from typing import List, TypedDict
import json
from datetime import datetime
import time
import uuid
import tempfile
import shutil
import hashlib  # Used for password hashing

class Message(TypedDict):
    role: str
    content: str


if os.path.exists('.env'):
    load_dotenv()

hf_token = os.getenv("HF_TOKEN")
""" | |
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference | |
""" | |
client = InferenceClient("https://xk54gqdcp97za8n6.us-east-1.aws.endpoints.huggingface.cloud") | |
model = SentenceTransformer('all-MiniLM-L6-v2') # You can choose other models depending on your needs | |
MAX_HISTORY_LENGTH = 5000 # Keep the last 10 exchanges | |
MAX_TOKENS = 128000 # Token limit for your model (check your model's max tokens) | |
EMBEDDING_DIM = 384 # Dimension of embeddings, specific to the model you use (e.g., for 'all-MiniLM-L6-v2', it's 384) | |
login(token=hf_token) | |
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-Nemo-Base-2407") # Ersätt med din egen modell om det behövs | |


def load_persona():
    try:
        with open("profile.md", "r", encoding="utf-8") as profile_file:
            profile_content = profile_file.read()
        with open("instructions.md", "r", encoding="utf-8") as instructions_file:
            instructions_content = instructions_file.read()
        # Combine profile and instructions with a blank line in between
        content = profile_content + "\n\n" + instructions_content
        return content
    except FileNotFoundError as e:
        print(f"Warning: File not found: {e.filename}. Using default persona.")
        return """Act and roleplay as a literal horse."""


# Preloaded conversation state (initial history)
system_message: List[Message] = [Message(role="system", content=load_persona())]

# Persistent chat history storage (e.g. the /data mount on a Hugging Face Space)
CHAT_HISTORY_DIR = "/data/chat_history"
os.makedirs(CHAT_HISTORY_DIR, exist_ok=True)

# Generate a unique session ID when the app starts
SESSION_ID = f"chat_session_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"

# Create a temporary directory for file operations
TEMP_DIR = tempfile.mkdtemp()

# Global authentication state
is_user_authenticated = False

# Hashed password - this is a hash of "password123" using SHA-256
HASHED_PASSWORD = "f75778f7425be4db0369d09af37a6c2b9a83dea0e53e7bd57412e4b060e607f7"
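# To change the password, generate a new SHA-256 hex digest and paste it into
# HASHED_PASSWORD above, for example (the password string here is only an illustration):
#   python -c "import hashlib; print(hashlib.sha256('new-password'.encode()).hexdigest())"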


def get_session_file():
    """Get the current session's chat history file path."""
    return os.path.join(CHAT_HISTORY_DIR, f"{SESSION_ID}.json")


def save_chat_history(history: List[Message]):
    """Save chat history to the session file."""
    filename = get_session_file()
    # Convert history to a serializable format
    serializable_history = [
        {"role": msg["role"], "content": msg["content"]}
        for msg in history
    ]
    # Create a backup of the previous file if it exists
    if os.path.exists(filename):
        backup_filename = f"{filename}.bak"
        os.replace(filename, backup_filename)
    try:
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(serializable_history, f, ensure_ascii=False, indent=2)
    except Exception as e:
        # If saving fails, restore from backup
        if os.path.exists(f"{filename}.bak"):
            os.replace(f"{filename}.bak", filename)
        raise e


def load_chat_history(session_id: str) -> List[Message]:
    """Load chat history from a file."""
    # Use the same naming scheme as save_chat_history (files are saved as "<session_id>.json")
    filename = os.path.join(CHAT_HISTORY_DIR, f"{session_id}.json")
    if os.path.exists(filename):
        with open(filename, "r", encoding="utf-8") as f:
            history = json.load(f)
        return [Message(**msg) for msg in history]
    return []


# Folder that holds the saved chat histories (same location as CHAT_HISTORY_DIR)
DATA_FOLDER = "/data/chat_history"
HF_TOKEN = os.getenv("HF_TOKEN")


def is_authenticated():
    """Check if the user is authenticated."""
    # Use the global authentication state
    return is_user_authenticated


def list_chat_history_files():
    """List all chat history files in the data folder."""
    if not is_authenticated():
        return []
    try:
        files = [f for f in os.listdir(DATA_FOLDER) if f.endswith('.json')]
        # Sort files by modification time, newest first
        files.sort(key=lambda x: os.path.getmtime(os.path.join(DATA_FOLDER, x)), reverse=True)
        return files
    except Exception:
        return []


def download_chat_history(filename):
    """Download a specific chat history file."""
    if not is_authenticated():
        return None
    if not filename:  # Handle case when no file is selected
        return None
    source_path = os.path.join(DATA_FOLDER, filename)
    if os.path.exists(source_path):
        # Copy the file to the temp directory first
        temp_path = os.path.join(TEMP_DIR, filename)
        shutil.copy2(source_path, temp_path)
        return temp_path
    return None


def verify_password(password):
    """Verify password and update authentication state."""
    global is_user_authenticated
    # Hash the provided password using SHA-256
    hashed_input = hashlib.sha256(password.encode()).hexdigest()
    # Compare with the stored hash
    if hashed_input == HASHED_PASSWORD:
        is_user_authenticated = True
        return "**Status:** ✅ Authenticated"
    else:
        is_user_authenticated = False
        return "**Status:** ❌ Authentication failed"


def logout():
    """Log the user out by resetting the authentication state."""
    global is_user_authenticated
    is_user_authenticated = False
    return "**Status:** ❌ Not authenticated", gr.update(visible=False)


# Create the Gradio interface
with gr.Blocks() as iface:
    # Chat interface components
    chatbot_output = gr.Chatbot(label="Chat History", type="messages")
    chatbot_input = gr.Textbox(placeholder="Type your message here...", label="Your Message")

    def update_file_list():
        """Update the file list and return the updated dropdown."""
        return gr.update(choices=list_chat_history_files())

    # Authentication status indicator:
    # a lock icon button that expands into a password field
    with gr.Group() as auth_group:
        auth_lock_btn = gr.Button("🔒", elem_id="auth_lock_btn", scale=0)
        # gr.Group takes no title, so use an Accordion to show the "Enter Password" label
        with gr.Accordion("Enter Password", visible=False, elem_id="auth_accordion") as auth_accordion:
            auth_password = gr.Textbox(
                type="password",
                placeholder="Enter your password",
                label="Password",
                elem_id="auth_password",
                visible=not is_authenticated()
            )
            auth_submit = gr.Button("Login", elem_id="auth_submit", visible=not is_authenticated())
            auth_status = gr.Markdown(
                value="**Status:** " + ("✅ Authenticated" if is_authenticated() else "❌ Not authenticated"),
                elem_id="auth_status"
            )
            auth_logout = gr.Button("Logout", elem_id="auth_logout", visible=is_authenticated())

    # Reveal the password panel when the lock button is clicked
    auth_lock_btn.click(
        fn=lambda: gr.update(visible=True),
        outputs=[auth_accordion]
    )

    # Add download section (only visible when authenticated)
    with gr.Group(visible=is_authenticated()) as download_section:
        gr.Markdown("### Download Chat History")
        file_list = gr.Dropdown(
            choices=list_chat_history_files(),
            label="Select a chat history file",
            interactive=True,
            allow_custom_value=False
        )
        download_button = gr.Button("Download Chat History")
        download_output = gr.File(label="Downloaded file")

        # Update file list when refresh button is clicked
        refresh_button = gr.Button("Refresh File List")
        refresh_button.click(
            fn=update_file_list,
            outputs=file_list
        )

    # Handle password submission
    auth_submit.click(
        fn=verify_password,
        inputs=[auth_password],
        outputs=[auth_status]
    )
    # Clear password field after submission
    auth_submit.click(
        fn=lambda: "",
        outputs=[auth_password]
    )
    # Update download section visibility based on authentication status
    auth_submit.click(
        fn=lambda: gr.update(visible=is_authenticated()),
        outputs=[download_section]
    )
    # Handle logout button
    auth_logout.click(
        fn=logout,
        outputs=[auth_status, download_section]
    )
    # Update logout button visibility based on authentication status
    auth_submit.click(
        fn=lambda: gr.update(visible=is_authenticated()),
        outputs=[auth_logout]
    )
    auth_submit.click(
        fn=update_file_list,
        outputs=file_list
    )
    # Update file list when dropdown selection changes
    file_list.select(
        fn=update_file_list,
        outputs=file_list
    )

    def generate_embeddings(messages: List[str]):
        """Generate embeddings for the list of messages."""
        embeddings = model.encode(messages, show_progress_bar=False)
        return embeddings

    def summarize_conversation(conversation: List[Message]):
        """Embed the conversation history (one embedding per message)."""
        # Extract the text content from the conversation
        messages = [msg['content'] for msg in conversation]
        # Generate embeddings for the entire conversation
        conversation_embeddings = generate_embeddings(messages)
        # Averaging into a single compact embedding is left disabled for now
        #compact_representation = np.mean(conversation_embeddings, axis=0)
        #return compact_representation
        return conversation_embeddings

    def count_tokens(messages: List[str]) -> int:
        """Count the total number of tokens in the conversation."""
        return sum(len(tokenizer.encode(message)) for message in messages)

    def get_chat_completion(system_message, history, retry_attempt=0, max_retries=3):
        """Get chat completion from the model with retry logic for 503 errors."""
        try:
            # Common parameters
            params = {
                "model": "openerotica/writing-roleplay-20k-context-nemo-12b-v1.0-gguf",
                "messages": [*system_message, *history],
                "stream_options": {"enabled": True},
                "stream": True,
                "frequency_penalty": 1.0,
                "max_tokens": 2048,
                "n": 1,
                "presence_penalty": 1.0,
                "temperature": 1.0,
                "top_p": 1.0
            }
            return client.chat_completion(**params)
        except Exception as e:
            # Only some exceptions carry an HTTP response; check for a 503 without assuming one exists
            status_code = getattr(getattr(e, "response", None), "status_code", None)
            if status_code is not None and "503" in str(status_code):
                if retry_attempt < max_retries:
                    message = f"Agent is asleep, waking up... Trying again in 3 minutes... (Attempt {retry_attempt + 1}/{max_retries})"
                    gr.Warning(message, duration=180)
                    time.sleep(180)
                    gr.Info("Retrying...")
                    return get_chat_completion(system_message, history, retry_attempt + 1, max_retries)
                else:
                    gr.Error(f"Max retries ({max_retries}) reached. Giving up.")
                    return None
            else:
                gr.Error(f"Error getting chat completion: {e}")
                if retry_attempt < max_retries:
                    gr.Warning(f"Retrying after error... (Attempt {retry_attempt + 1}/{max_retries})", duration=10)
                    time.sleep(10)  # Wait a bit before retrying after an error
                    return get_chat_completion(system_message, history, retry_attempt + 1, max_retries)
                return None

    def user(user_message, history: List[Message]):
        new_history = history + [Message(role="user", content=user_message)]
        save_chat_history(new_history)
        return "", new_history

    def bot(history: list):
        #compact_history = summarize_conversation(preloaded_history)
        #compact_history = preloaded_history[-MAX_HISTORY_LENGTH:]
        #conversation = [msg["content"] for msg in compact_history]
        session_conversation = [msg["content"] for msg in history]
        system_context = [msg["content"] for msg in system_message]
        total_tokens = count_tokens(session_conversation) + count_tokens(system_context)
        #total_tokens = count_tokens(conversation) + session_tokens
        print(f"Total tokens: {total_tokens}")

        # Check whether the token count exceeds the limit
        if total_tokens > MAX_TOKENS:
            print("Token limit exceeded. Truncating history.")
            system_tokens = count_tokens(system_context)
            # Drop the oldest messages until history plus system prompt fits within MAX_TOKENS
            while history and (count_tokens([msg["content"] for msg in history]) + system_tokens) > MAX_TOKENS:
                history.pop(0)  # Remove the oldest message

        response = get_chat_completion(system_message, history)
        if response:
            # Initialize bot_message
            bot_message = ""
            history.append(Message(role="assistant", content=""))
            for chunk in response:
                # Append any new content carried by the streamed chunk
                if 'choices' in chunk and chunk['choices']:
                    choice = chunk['choices'][0]
                    if choice.get('delta') and choice['delta'].get('content'):
                        # Append the new content to bot_message
                        bot_message += choice['delta']['content']
                        history[-1]['content'] = bot_message
                        yield history
            save_chat_history(history)

    # Add download functionality
    download_button.click(
        fn=download_chat_history,
        inputs=file_list,
        outputs=download_output
    )

    chatbot_input.submit(user, [chatbot_input, chatbot_output], [chatbot_input, chatbot_output], queue=False).then(
        bot, chatbot_output, chatbot_output
    )


if __name__ == "__main__":
    iface.launch(
        allowed_paths=[DATA_FOLDER, TEMP_DIR]  # Add both the data folder and temp directory to allowed paths
    )
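# A minimal local-run sketch (assumptions: this script is saved as app.py, HF_TOKEN is set
# in .env, and a writable /data directory exists, as on a Hugging Face Space with
# persistent storage enabled):
#   pip install gradio python-dotenv numpy transformers sentence-transformers huggingface_hub
#   python app.py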