Spaces:
Running
Running
# chat_manager.py - Chat Session Management System
import json | |
import os | |
import uuid | |
from dataclasses import dataclass, asdict | |
from typing import List, Optional, Dict, Any | |
from datetime import datetime | |
from pathlib import Path | |
@dataclass
class ChatMessage:
    """Individual chat message structure.

    Declared as a dataclass so it can be built from keyword arguments
    (including ``ChatMessage(**stored_dict)``) and serialised with
    ``dataclasses.asdict``.
    """

    message_id: str
    role: str  # 'user' or 'assistant'
    content: str
    timestamp: str
    # 1 for thumbs up, -1 for thumbs down, None for no rating
    rating: Optional[int] = None
    is_bookmarked: bool = False
    # None is only a placeholder default; __post_init__ swaps it for a
    # fresh list so instances never share one mutable default.
    source_documents: Optional[List[str]] = None

    def __post_init__(self):
        """Replace a missing ``source_documents`` with a new empty list."""
        if self.source_documents is None:
            self.source_documents = []
@dataclass
class ChatSession:
    """Chat session data structure.

    Declared as a dataclass so it can be built from keyword arguments
    (including ``ChatSession(**stored_dict)``) and serialised with
    ``dataclasses.asdict``.
    """

    session_id: str
    user_id: str
    title: str
    created_at: str
    updated_at: str
    # None is only a placeholder default; __post_init__ swaps it for a
    # fresh list so instances never share one mutable default.
    messages: Optional[List["ChatMessage"]] = None
    is_archived: bool = False
    tags: Optional[List[str]] = None

    def __post_init__(self):
        """Replace missing ``messages`` / ``tags`` with new empty lists."""
        if self.messages is None:
            self.messages = []
        if self.tags is None:
            self.tags = []
class ChatManager:
    """Manages chat sessions and messages.

    Every session is stored in one JSON file, ``<data_dir>/sessions.json``,
    as a mapping of session id to the session's dict form. Each mutating
    method loads the whole file, edits it in memory and writes it back.
    """

    def __init__(self, data_dir: str):
        """Create the data directory and an empty sessions file if needed.

        Args:
            data_dir: Directory that holds ``sessions.json``. Missing parent
                directories are created as well.
        """
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.sessions_file = self.data_dir / "sessions.json"
        self.ensure_sessions_file()

    def ensure_sessions_file(self):
        """Ensure sessions file exists (an empty JSON object if created)."""
        if not self.sessions_file.exists():
            self._save_sessions({})

    def _save_sessions(self, sessions: Dict[str, Dict]) -> None:
        """Write the full session mapping back to the sessions file."""
        # Write to a sibling temp file and swap it in, so an interrupted
        # write cannot leave a truncated sessions.json behind (which
        # load_all_sessions would read back as "no sessions").
        tmp_file = self.sessions_file.with_name(self.sessions_file.name + ".tmp")
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(sessions, f, indent=2)
        os.replace(tmp_file, self.sessions_file)

    @staticmethod
    def _session_from_dict(session_data: Dict[str, Any]) -> "ChatSession":
        """Build a ChatSession (with ChatMessage objects) from its dict form."""
        fields = dict(session_data)
        fields['messages'] = [
            ChatMessage(**msg_data)
            for msg_data in session_data.get('messages', [])
        ]
        return ChatSession(**fields)

    def _update_session(self, session_id: str, **changes: Any) -> bool:
        """Set top-level fields on a stored session and bump ``updated_at``."""
        try:
            sessions = self.load_all_sessions()
            if session_id not in sessions:
                return False
            sessions[session_id].update(changes)
            sessions[session_id]['updated_at'] = datetime.now().isoformat()
            self._save_sessions(sessions)
            return True
        except Exception:
            # Best-effort API: callers get False instead of an exception.
            return False

    def _update_message(self, session_id: str, message_id: str,
                        field_name: str, value: Any) -> bool:
        """Set one field on a stored message and bump the session's ``updated_at``."""
        try:
            sessions = self.load_all_sessions()
            if session_id not in sessions:
                return False
            for message in sessions[session_id]['messages']:
                if message['message_id'] == message_id:
                    message[field_name] = value
                    sessions[session_id]['updated_at'] = datetime.now().isoformat()
                    self._save_sessions(sessions)
                    return True
            return False
        except Exception:
            # Best-effort API: callers get False instead of an exception.
            return False

    def create_session(self, user_id: str, title: Optional[str] = None) -> str:
        """Create a new chat session.

        Args:
            user_id: Owner of the session.
            title: Session title; when empty or None a title of the form
                ``Chat YYYY-MM-DD HH:MM`` is generated.

        Returns:
            The new session's id (a UUID4 string).

        Raises:
            Exception: If the session could not be stored; the original
                error is attached as ``__cause__``.
        """
        session_id = str(uuid.uuid4())
        # One clock reading so the default title and created_at agree.
        now = datetime.now()
        timestamp = now.isoformat()
        if not title:
            title = f"Chat {now.strftime('%Y-%m-%d %H:%M')}"
        session = ChatSession(
            session_id=session_id,
            user_id=user_id,
            title=title,
            created_at=timestamp,
            updated_at=timestamp
        )
        try:
            sessions = self.load_all_sessions()
            sessions[session_id] = asdict(session)
            self._save_sessions(sessions)
            return session_id
        except Exception as e:
            raise Exception(f"Failed to create session: {str(e)}") from e

    def load_all_sessions(self) -> Dict[str, Dict]:
        """Load all sessions from storage.

        Returns:
            The stored mapping of session id to session dict, or an empty
            dict when the file is missing, unreadable as JSON, or does not
            contain a JSON object.
        """
        try:
            with open(self.sessions_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (FileNotFoundError, ValueError):
            # ValueError covers json.JSONDecodeError and undecodable bytes.
            return {}
        # Every other method indexes the result by session id, so anything
        # that is not a JSON object is treated like an empty store.
        return data if isinstance(data, dict) else {}

    def get_session(self, session_id: str) -> Optional["ChatSession"]:
        """Get chat session by ID, or None when it is not stored."""
        session_data = self.load_all_sessions().get(session_id)
        if not session_data:
            return None
        return self._session_from_dict(session_data)

    def get_user_sessions(self, user_id: str, include_archived: bool = False) -> List["ChatSession"]:
        """Get all sessions for a user, most recently updated first.

        Args:
            user_id: Only sessions whose ``user_id`` equals this are returned.
            include_archived: When False (default), archived sessions are
                left out.
        """
        user_sessions = [
            self._session_from_dict(session_data)
            for session_data in self.load_all_sessions().values()
            if session_data.get('user_id') == user_id
            and (include_archived or not session_data.get('is_archived', False))
        ]
        # Sort by updated_at descending
        user_sessions.sort(key=lambda x: x.updated_at, reverse=True)
        return user_sessions

    def add_message(self, session_id: str, role: str, content: str,
                    source_documents: Optional[List[str]] = None) -> str:
        """Add a message to a chat session.

        Args:
            session_id: Session to append to.
            role: Message role, e.g. 'user' or 'assistant'.
            content: Message text.
            source_documents: Optional list of source document names.

        Returns:
            The new message's id (a UUID4 string).

        Raises:
            Exception: If the session does not exist or the message could
                not be stored; the original error is attached as
                ``__cause__``.
        """
        message_id = str(uuid.uuid4())
        timestamp = datetime.now().isoformat()
        message = ChatMessage(
            message_id=message_id,
            role=role,
            content=content,
            timestamp=timestamp,
            source_documents=source_documents or []
        )
        try:
            sessions = self.load_all_sessions()
            if session_id not in sessions:
                raise ValueError(f"Session {session_id} not found")
            session_data = sessions[session_id]
            # Messages are stored as plain dicts; tolerate a stored session
            # that has no 'messages' key yet.
            session_data.setdefault('messages', []).append(asdict(message))
            session_data['updated_at'] = timestamp
            self._save_sessions(sessions)
            return message_id
        except Exception as e:
            raise Exception(f"Failed to add message: {str(e)}") from e

    def rate_message(self, session_id: str, message_id: str, rating: int) -> bool:
        """Rate a message (1 for thumbs up, -1 for thumbs down).

        Returns:
            True if the message was found and updated, False otherwise
            (including on any storage error).
        """
        return self._update_message(session_id, message_id, 'rating', rating)

    def bookmark_message(self, session_id: str, message_id: str, is_bookmarked: bool = True) -> bool:
        """Bookmark or unbookmark a message.

        Returns:
            True if the message was found and updated, False otherwise
            (including on any storage error).
        """
        return self._update_message(session_id, message_id, 'is_bookmarked', is_bookmarked)

    def get_bookmarked_messages(self, user_id: str) -> List[Dict[str, Any]]:
        """Get all bookmarked messages for a user, newest first.

        Returns:
            A list of dicts with keys ``session_id``, ``session_title``,
            ``message`` (the stored message dict) and ``timestamp``.
        """
        bookmarked = []
        for session_data in self.load_all_sessions().values():
            if session_data.get('user_id') != user_id:
                continue
            for message in session_data.get('messages', []):
                if message.get('is_bookmarked', False):
                    bookmarked.append({
                        'session_id': session_data['session_id'],
                        'session_title': session_data['title'],
                        'message': message,
                        'timestamp': message['timestamp']
                    })
        # Sort by timestamp descending
        bookmarked.sort(key=lambda x: x['timestamp'], reverse=True)
        return bookmarked

    def update_session_title(self, session_id: str, title: str) -> bool:
        """Update session title; returns False if the session is unknown or saving fails."""
        return self._update_session(session_id, title=title)

    def archive_session(self, session_id: str, is_archived: bool = True) -> bool:
        """Archive or unarchive a session; returns False if the session is unknown or saving fails."""
        return self._update_session(session_id, is_archived=is_archived)

    def delete_session(self, session_id: str) -> bool:
        """Delete a chat session; returns False if it is unknown or saving fails."""
        try:
            sessions = self.load_all_sessions()
            if session_id not in sessions:
                return False
            del sessions[session_id]
            self._save_sessions(sessions)
            return True
        except Exception:
            # Best-effort API: callers get False instead of an exception.
            return False

    def export_chat_history(self, user_id: str, session_id: Optional[str] = None) -> Dict[str, Any]:
        """Export chat history for a user or specific session.

        Args:
            user_id: Owner whose data is exported.
            session_id: When given, only that session is exported, and only
                if it belongs to ``user_id``.

        Returns:
            A ``single_session`` or ``all_sessions`` export dict, or an
            empty dict when the requested session is missing or owned by
            another user.
        """
        if not session_id:
            sessions = self.get_user_sessions(user_id, include_archived=True)
            return {
                'export_type': 'all_sessions',
                'sessions': [asdict(session) for session in sessions],
                'exported_at': datetime.now().isoformat(),
                'total_sessions': len(sessions)
            }
        session = self.get_session(session_id)
        if session and session.user_id == user_id:
            return {
                'export_type': 'single_session',
                'session': asdict(session),
                'exported_at': datetime.now().isoformat()
            }
        return {}

    def get_chat_statistics(self, user_id: str) -> Dict[str, Any]:
        """Get chat statistics for a user (archived sessions included).

        Any message whose role is not 'user' is counted as an assistant
        message.
        """
        sessions = self.get_user_sessions(user_id, include_archived=True)
        total_messages = 0
        total_user_messages = 0
        total_assistant_messages = 0
        bookmarked_count = 0
        rated_messages = {'positive': 0, 'negative': 0}
        for session in sessions:
            total_messages += len(session.messages)
            for message in session.messages:
                if message.role == 'user':
                    total_user_messages += 1
                else:
                    total_assistant_messages += 1
                if message.is_bookmarked:
                    bookmarked_count += 1
                if message.rating == 1:
                    rated_messages['positive'] += 1
                elif message.rating == -1:
                    rated_messages['negative'] += 1
        return {
            'total_sessions': len(sessions),
            'total_messages': total_messages,
            'user_messages': total_user_messages,
            'assistant_messages': total_assistant_messages,
            'bookmarked_messages': bookmarked_count,
            'message_ratings': rated_messages,
            # Guard against division by zero for users without sessions.
            'average_messages_per_session': total_messages / len(sessions) if sessions else 0
        }