# chat / app.py
# Erik Hallros
# Updated chat history
# cba0a61
import gradio as gr # type: ignore
import os # type: ignore
import numpy as np #type: ignore
from dotenv import load_dotenv
from transformers import AutoTokenizer # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from huggingface_hub import InferenceClient, login, HfApi, whoami # type: ignore
from gradio.components import ChatMessage # type: ignore
from typing import List, TypedDict
import json
from datetime import datetime
import time
import uuid
import tempfile
import shutil
import hashlib # Added for password hashing
class Message(TypedDict):
    """One chat turn in the role/content dictionary shape used throughout this file."""

    # Speaker of the turn; this file uses "system", "user" and "assistant".
    role: str
    # Text of the turn.
    content: str
# Load environment variables from a local .env file when one is present.
if os.path.exists('.env'):
    load_dotenv()
# Hugging Face access token read from the environment (None when HF_TOKEN is unset).
hf_token = os.getenv("HF_TOKEN")
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Inference client bound to a fixed endpoint URL; used by get_chat_completion().
client = InferenceClient("https://xk54gqdcp97za8n6.us-east-1.aws.endpoints.huggingface.cloud")
model = SentenceTransformer('all-MiniLM-L6-v2') # Sentence-embedding model used by generate_embeddings(); other models can be substituted
MAX_HISTORY_LENGTH = 5000 # NOTE(review): the earlier comment said "last 10 exchanges", which does not match 5000; only referenced from commented-out code in bot()
MAX_TOKENS = 128000 # Token budget checked in bot() before a completion is requested (check your model's max tokens)
EMBEDDING_DIM = 384 # Embedding dimension noted for 'all-MiniLM-L6-v2'; not referenced elsewhere in the visible code
# NOTE(review): hf_token may be None here when HF_TOKEN is unset - confirm how login() should behave in that case.
login(token=hf_token)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-Nemo-Base-2407") # Tokenizer used by count_tokens(); replace with your own model if needed
def load_persona():
    """Build the persona text from ``profile.md`` and ``instructions.md``.

    Both files are opened by relative path, read as UTF-8 and joined with one
    blank line between them, profile first.

    Returns:
        The combined text, or a built-in fallback persona when either file is
        missing (a warning naming the missing file is printed in that case).
    """
    sections = []
    try:
        # Order matters: the profile precedes the instructions.
        for source_name in ("profile.md", "instructions.md"):
            with open(source_name, "r", encoding="utf-8") as source:
                sections.append(source.read())
    except FileNotFoundError as e:
        print(f"Warning: File not found: {e.filename}. Using default persona.")
        return """Act and roleplay as a literal horse."""
    return "\n\n".join(sections)
# System prompt placed ahead of every conversation; built once at import time from load_persona().
system_message: List[Message] = [Message(role="system", content=load_persona())]
# Directory where chat sessions are persisted as JSON files; created at import time if missing.
CHAT_HISTORY_DIR = "/data/chat_history"
os.makedirs(CHAT_HISTORY_DIR, exist_ok=True)
# Session ID generated once when the module is imported (timestamp plus 8 random hex characters).
# NOTE(review): being module-level, every conversation served by this process writes to the same session file - confirm this is intended.
SESSION_ID = f"chat_session_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
# Scratch directory that download_chat_history() copies files into before handing them to Gradio.
TEMP_DIR = tempfile.mkdtemp()
# Global authentication state: a single module-level flag toggled by verify_password() and logout().
is_user_authenticated = False
# SHA-256 hex digest that verify_password() compares the hashed input against.
# NOTE(review): the earlier comment named a plaintext password for this digest; that claim was not verified, and a plaintext should not be recorded in source.
HASHED_PASSWORD = "f75778f7425be4db0369d09af37a6c2b9a83dea0e53e7bd57412e4b060e607f7"
def get_session_file():
    """Return the path of the JSON file that stores the current session's chat history."""
    session_filename = f"{SESSION_ID}.json"
    return os.path.join(CHAT_HISTORY_DIR, session_filename)
def save_chat_history(history: List[Message]):
    """Write *history* to the current session file as indented JSON.

    An existing session file is first moved to ``<file>.bak``. If writing the
    new file fails, that backup (when present) is moved back into place and
    the original exception is re-raised.

    Args:
        history: Messages to persist; only ``role`` and ``content`` are stored.
    """
    target = get_session_file()
    backup = f"{target}.bak"

    # Reduce every message to the two serialisable fields.
    payload = []
    for entry in history:
        payload.append({"role": entry["role"], "content": entry["content"]})

    # Keep the previous version around until the new one is safely written.
    if os.path.exists(target):
        os.replace(target, backup)

    try:
        with open(target, "w", encoding="utf-8") as out:
            json.dump(payload, out, ensure_ascii=False, indent=2)
    except Exception as e:
        # Roll back to the previous version before reporting the failure.
        if os.path.exists(backup):
            os.replace(backup, target)
        raise e
def load_chat_history(session_id: str) -> List[Message]:
    """Load a previously saved chat history.

    Args:
        session_id: Identifier of the session to load, i.e. the file name
            without the ``.json`` extension.

    Returns:
        The stored messages, or an empty list when no matching file exists.
    """
    # save_chat_history() writes "<SESSION_ID>.json", whereas this function
    # used to look only for "chat_history_<id>.json" and therefore never found
    # a saved session. Try the name that is actually written first and keep
    # the old pattern as a fallback for files created under that scheme.
    candidates = (f"{session_id}.json", f"chat_history_{session_id}.json")
    for candidate in candidates:
        path = os.path.join(CHAT_HISTORY_DIR, candidate)
        try:
            with open(path, "r", encoding="utf-8") as f:
                stored = json.load(f)
        except FileNotFoundError:
            continue
        return [Message(**msg) for msg in stored]
    return []
# Folder scanned by list_chat_history_files() and download_chat_history().
# NOTE(review): same path as CHAT_HISTORY_DIR defined earlier in this file - consider keeping a single constant.
DATA_FOLDER = "/data/chat_history"
# Hugging Face token read from the environment; same variable as hf_token above.
HF_TOKEN = os.getenv("HF_TOKEN")
def is_authenticated():
    """Report whether a successful login is currently recorded.

    Reads the module-level ``is_user_authenticated`` flag, which is set by
    ``verify_password`` and cleared by ``logout``.
    """
    authenticated = is_user_authenticated
    return authenticated
def list_chat_history_files():
    """Return the ``.json`` file names found in ``DATA_FOLDER``, newest first.

    Returns:
        File names sorted by modification time (most recent first). An empty
        list is returned when the caller is not authenticated or when listing
        the folder fails for any reason.
    """
    if not is_authenticated():
        return []

    def modified_at(name):
        # Modification timestamp of a file inside the data folder.
        return os.path.getmtime(os.path.join(DATA_FOLDER, name))

    try:
        names = [entry for entry in os.listdir(DATA_FOLDER) if entry.endswith('.json')]
        return sorted(names, key=modified_at, reverse=True)
    except Exception:
        return []
def download_chat_history(filename):
    """Copy a chat history file to the temp directory and return the copy's path.

    Args:
        filename: Bare name of a file inside ``DATA_FOLDER`` (as offered by
            the file dropdown).

    Returns:
        Path of the copy inside ``TEMP_DIR``, or ``None`` when the caller is
        not authenticated, no file is selected, the name contains a directory
        component, or the file does not exist.
    """
    if not is_authenticated():
        return None
    if not filename:  # Nothing selected in the dropdown.
        return None

    # Security: the value comes from the client. Accept bare file names only,
    # so that "../x" or an absolute path cannot make os.path.join() point
    # outside DATA_FOLDER (or outside TEMP_DIR for the copy).
    if os.path.basename(filename) != filename:
        return None

    source_path = os.path.join(DATA_FOLDER, filename)
    # isfile() rather than exists(): a directory cannot be copied with copy2().
    if not os.path.isfile(source_path):
        return None

    # Hand Gradio a copy from the temp directory rather than the stored file itself.
    temp_path = os.path.join(TEMP_DIR, filename)
    shutil.copy2(source_path, temp_path)
    return temp_path
def verify_password(password):
    """Check *password* against the stored hash and record the outcome.

    The module-level ``is_user_authenticated`` flag is set to the result of
    the check, so a failed attempt also clears an earlier successful login.

    Args:
        password: Plain-text password entered by the user. A missing value
            (``None``) is treated as an empty password instead of raising.

    Returns:
        A Markdown status line describing the outcome.
    """
    global is_user_authenticated
    # Local import: hmac is not among this file's top-level imports.
    import hmac

    candidate = password or ""
    hashed_input = hashlib.sha256(candidate.encode()).hexdigest()
    # compare_digest() takes time independent of where the digests differ,
    # unlike ==, which stops at the first mismatching character.
    is_user_authenticated = hmac.compare_digest(hashed_input, HASHED_PASSWORD)
    if is_user_authenticated:
        return "**Status:** ✅ Authenticated"
    return "**Status:** ❌ Authentication failed"
def logout():
    """Clear the global login flag.

    Returns:
        A 2-tuple: the Markdown status line to display, and a Gradio update
        that hides the component it is applied to.
    """
    global is_user_authenticated
    is_user_authenticated = False
    status_text = "**Status:** ❌ Not authenticated"
    hide_update = gr.update(visible=False)
    return status_text, hide_update
# Create a Gradio interface
with gr.Blocks() as iface:
    # Chat transcript and the textbox used to submit a new message.
    chatbot_output = gr.Chatbot(label="Chat History", type="messages")
    chatbot_input = gr.Textbox(placeholder="Type your message here...", label="Your Message")

    def update_file_list():
        """Return a dropdown update carrying the current list of history files."""
        return gr.update(choices=list_chat_history_files())

    def _sync_auth_ui():
        """Bring the auth-dependent widgets in line with the current login state.

        Returns updates for, in order: the password box (cleared), the
        download section, the logout button and the file dropdown.
        """
        authenticated = is_authenticated()
        return (
            "",
            gr.update(visible=authenticated),
            gr.update(visible=authenticated),
            gr.update(choices=list_chat_history_files()),
        )

    # Lock icon button that reveals the password form.
    # NOTE(review): the source this was taken from had lost its indentation;
    # the nesting of the layout groups below is reconstructed - confirm.
    with gr.Group() as auth_group:
        auth_lock_btn = gr.Button("🔒", elem_id="auth_lock_btn", scale=0)
        # gr.Group is configured through keyword arguments and renders no
        # title, so the former positional "Enter Password" argument is dropped.
        with gr.Group(visible=False, elem_id="auth_accordion") as auth_accordion:
            auth_password = gr.Textbox(
                type="password",
                placeholder="Enter your password",
                label="Password",
                elem_id="auth_password",
                visible=not is_authenticated(),
            )
            auth_submit = gr.Button("Login", elem_id="auth_submit", visible=not is_authenticated())
            auth_status = gr.Markdown(
                value="**Status:** " + ("✅ Authenticated" if is_authenticated() else "❌ Not authenticated"),
                elem_id="auth_status",
            )
            auth_logout = gr.Button("Logout", elem_id="auth_logout", visible=is_authenticated())
        # Reveal the password form when the lock button is clicked.
        auth_lock_btn.click(
            fn=lambda: gr.update(visible=True),
            outputs=[auth_accordion],
        )

    # Download section (only visible when authenticated).
    with gr.Group(visible=is_authenticated()) as download_section:
        gr.Markdown("### Download Chat History")
        file_list = gr.Dropdown(
            choices=list_chat_history_files(),
            label="Select a chat history file",
            interactive=True,
            allow_custom_value=False,
        )
        download_button = gr.Button("Download Chat History")
        download_output = gr.File(label="Downloaded file")
        # Re-read the folder contents on demand.
        refresh_button = gr.Button("Refresh File List")
        refresh_button.click(
            fn=update_file_list,
            outputs=file_list,
        )

    # Password submission. This used to be five independent click listeners;
    # independent listeners are not ordered, so the ones reading
    # is_authenticated() could run before verify_password() had updated the
    # flag. Chaining with .then() runs the UI refresh strictly afterwards.
    auth_submit.click(
        fn=verify_password,
        inputs=[auth_password],
        outputs=[auth_status],
    ).then(
        fn=_sync_auth_ui,
        outputs=[auth_password, download_section, auth_logout, file_list],
    )

    # Logout resets the status line and hides the download section.
    auth_logout.click(
        fn=logout,
        outputs=[auth_status, download_section],
    )

    # Refresh the dropdown choices whenever a selection is made.
    file_list.select(
        fn=update_file_list,
        outputs=file_list,
    )
def generate_embeddings(messages: List[str]):
    """Encode *messages* with the module-level sentence-transformer ``model``.

    The progress bar is disabled. Returns whatever ``model.encode`` produces
    for the given list.
    """
    return model.encode(messages, show_progress_bar=False)
def summarize_conversation(conversation: List[Message]):
    """Embed the text content of every message in *conversation*.

    Returns the result of ``generate_embeddings`` for the message texts as-is;
    no averaging or other reduction is applied to it.
    """
    texts = []
    for turn in conversation:
        texts.append(turn['content'])
    return generate_embeddings(texts)
def count_tokens(messages: List[str]) -> int:
    """Return the total number of tokens across all strings in *messages*.

    Each string is encoded separately with the module-level ``tokenizer`` and
    the lengths of the encodings are added up.
    """
    total = 0
    for text in messages:
        total += len(tokenizer.encode(text))
    return total
def get_chat_completion(system_message, history, retry_attempt=0, max_retries=3):
    """Request a streamed chat completion, retrying on failure.

    Args:
        system_message: Messages placed ahead of the conversation.
        history: Conversation messages sent after the system messages.
        retry_attempt: Number of retries already used; counting starts here.
        max_retries: Maximum number of retries before giving up.

    Returns:
        The streaming response from ``client.chat_completion``, or ``None``
        once the retries are exhausted.

    A failure whose ``response.status_code`` contains "503" waits 180 seconds
    before the next attempt; any other failure waits 10 seconds. Failures are
    surfaced with ``gr.Warning``: the previous ``gr.Error(...)`` calls only
    constructed the exception without raising it, so nothing was ever shown,
    and raising it would break the "return None" contract callers rely on.
    """
    # NOTE(review): confirm the endpoint accepts this "stream_options" shape.
    params = {
        "model": "openerotica/writing-roleplay-20k-context-nemo-12b-v1.0-gguf",
        "messages": [*system_message, *history],
        "stream_options": {"enabled": True},
        "stream": True,
        "frequency_penalty": 1.0,
        "max_tokens": 2048,
        "n": 1,
        "presence_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
    }

    attempt = retry_attempt
    while True:
        try:
            return client.chat_completion(**params)
        except Exception as e:
            # Not every exception carries a .response; reading e.response
            # directly raised AttributeError inside this handler.
            response = getattr(e, "response", None)
            status_code = getattr(response, "status_code", None)
            service_unavailable = status_code is not None and "503" in str(status_code)
            can_retry = attempt < max_retries

            if service_unavailable:
                if not can_retry:
                    gr.Warning(f"Max retries ({max_retries}) reached. Giving up.")
                    return None
                message = f"Agent is asleep, waking up... Trying again in 3 minutes... (Attempt {attempt + 1}/{max_retries})"
                gr.Warning(message, duration=180)
                time.sleep(180)
                gr.Info("Retrying...")
            else:
                gr.Warning(f"Error getting chat completion: {e}")
                if not can_retry:
                    return None
                gr.Warning(f"Retrying after error... (Attempt {attempt + 1}/{max_retries})", duration=10)
                time.sleep(10)  # Wait a bit before retrying after an error
            attempt += 1
def user(user_message, history: List[Message]):
    """Record the user's message and clear the input box.

    Builds a new history list with the message appended (the incoming list is
    not modified) and saves it via ``save_chat_history``.

    Returns:
        A 2-tuple: an empty string for the textbox and the extended history.
    """
    user_turn = Message(role="user", content=user_message)
    new_history = history + [user_turn]
    save_chat_history(new_history)
    return "", new_history
def bot(history: list):
    """Stream the assistant's reply for the conversation in *history*.

    The oldest messages are dropped from *history* (in place) until the system
    prompt plus the remaining messages fit within ``MAX_TOKENS``. A completion
    is then requested and an empty assistant message is appended and filled in
    as chunks arrive.

    Args:
        history: Conversation so far; modified in place.

    Yields:
        *history* after each received chunk that carries content.

    Once the stream is exhausted the history is saved. If no response could
    be obtained, nothing is yielded and nothing is saved.
    """
    system_tokens = count_tokens([msg["content"] for msg in system_message])
    # Tokenise every message once so truncation does not re-encode the history.
    message_tokens = [count_tokens([msg["content"]]) for msg in history]
    total_tokens = system_tokens + sum(message_tokens)
    print(f"Total tokens: {total_tokens}")

    # Check whether the token count exceeds the limit.
    if total_tokens > MAX_TOKENS:
        print("Token limit exceeded. Truncating history.")
        # The earlier loop added the stale total to a fresh recount, so its
        # condition never became false: it emptied the history and then raised
        # IndexError on pop(0). Keep a running total and stop once it fits or
        # nothing is left to drop.
        drop = 0
        while drop < len(message_tokens) and total_tokens > MAX_TOKENS:
            total_tokens -= message_tokens[drop]
            drop += 1
        del history[:drop]  # Remove the oldest messages in place.

    response = get_chat_completion(system_message, history)
    if response:
        bot_message = ""
        history.append(Message(role="assistant", content=""))
        for chunk in response:
            if 'choices' in chunk and chunk['choices']:
                choice = chunk['choices'][0]
                if choice.get('delta') and choice['delta'].get('content'):
                    # Grow the reply and push the partial text to the UI.
                    bot_message += choice['delta']['content']
                    history[-1]['content'] = bot_message
                    yield history
        save_chat_history(history)
# Event listeners can only be registered while a Blocks context is active.
# The helper functions above are defined at module level, which ends the
# `with gr.Blocks() as iface:` block, so the context is re-entered here.
with iface:
    # Add download functionality
    download_button.click(
        fn=download_chat_history,
        inputs=file_list,
        outputs=download_output,
    )
    # Submitting a message first records it (user), then streams the reply (bot).
    chatbot_input.submit(
        user,
        [chatbot_input, chatbot_output],
        [chatbot_input, chatbot_output],
        queue=False,
    ).then(
        bot, chatbot_output, chatbot_output
    )
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch(
        allowed_paths=[DATA_FOLDER, TEMP_DIR] # Add both the data folder and temp directory to allowed paths
    )