# Travel_Assistant/modules/session_manager.py
# Author: Eliot0110
# Last commit: "improve: session manage" (49aff10)
import time
import uuid
from typing import Any, Dict, Optional

from utils.logger import log
class SessionManager:
    """In-memory store for travel-planning conversation sessions.

    Each session is a plain dict holding the trip fields ``destination``,
    ``duration``, ``budget`` and ``persona`` plus bookkeeping fields
    (``session_id``, ``stage``, ``created_at``, ``last_updated``).
    ``get_or_create_session`` returns that dict by reference, so callers may
    mutate it directly.

    A parallel ``session_metadata`` dict tracks, per session, the last access
    time, an update counter, the persona key and per-field completion flags.
    Nothing is persisted; all state lives in this instance.
    """

    # A session counts as "active" if it was accessed within this many seconds.
    ACTIVE_WINDOW_SECONDS = 1800
    # Trip fields that must be set for a session to be fully specified.
    REQUIRED_FIELDS = ("destination", "duration", "budget", "persona")
    # Number of hex characters kept from a UUID4 when generating session IDs.
    SESSION_ID_LENGTH = 8

    def __init__(self):
        self.sessions: Dict[str, Dict[str, Any]] = {}
        # Per-session tracking data, keyed by the same session IDs.
        self.session_metadata: Dict[str, Dict[str, Any]] = {}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _generate_session_id(self) -> str:
        """Return a short random session ID not already used by a session."""
        # A truncated UUID only carries SESSION_ID_LENGTH * 4 bits of
        # randomness, so retry on the (rare) collision instead of silently
        # overwriting an existing session.
        while True:
            candidate = uuid.uuid4().hex[:self.SESSION_ID_LENGTH]
            if candidate not in self.sessions:
                return candidate

    def _completion_status(self, session: Dict[str, Any]) -> Dict[str, bool]:
        """Map each required field to True when it is set (not None) in ``session``."""
        return {field: session.get(field) is not None for field in self.REQUIRED_FIELDS}

    def _new_metadata(self, timestamp: float) -> Dict[str, Any]:
        """Build a fresh metadata record whose timestamps are ``timestamp``."""
        return {
            "created_at": timestamp,
            "last_accessed": timestamp,
            "message_count": 0,
            "frontend_chat_id": None,  # slot for the frontend's conversation ID
            "persona_type": None,
            "completion_status": {field: False for field in self.REQUIRED_FIELDS},
        }

    def _metadata_for(self, session_id: str) -> Dict[str, Any]:
        """Return the metadata of an existing session, recreating it if missing.

        ``session_id`` must be a key of ``self.sessions``.
        """
        metadata = self.session_metadata.get(session_id)
        if metadata is None:
            session = self.sessions[session_id]
            metadata = self._new_metadata(session.get("created_at", time.time()))
            metadata["completion_status"] = self._completion_status(session)
            self.session_metadata[session_id] = metadata
        return metadata

    def _is_active(self, metadata: Dict[str, Any], now: float) -> bool:
        """Return True if ``metadata`` records an access within the active window."""
        return (now - metadata.get("last_accessed", 0)) < self.ACTIVE_WINDOW_SECONDS

    @staticmethod
    def _format_timestamp(timestamp: float) -> str:
        """Format an epoch timestamp as local ``YYYY-MM-DD HH:MM:SS``."""
        return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_or_create_session(self, session_id: Optional[str] = None) -> Dict[str, Any]:
        """Return the session for ``session_id``, creating a new one if needed.

        If ``session_id`` names an existing session, its last-access time is
        refreshed and that session is returned. Otherwise (no ID given, or an
        unknown ID) a new session with a freshly generated ID is created. The
        caller's unknown ID is *not* reused.

        Args:
            session_id: ID of the session to look up, or None.

        Returns:
            The live session dict (not a copy).
        """
        if session_id and session_id in self.sessions:
            self._metadata_for(session_id)["last_accessed"] = time.time()
            log.info(f"📂 使用现有会话: {session_id}")
            log.info(f"📊 会话状态: {self._get_session_summary(session_id)}")
            return self.sessions[session_id]

        new_session_id = self._generate_session_id()
        current_time = time.time()
        session = {
            "session_id": new_session_id,
            "destination": None,
            "duration": None,
            "budget": None,
            "persona": None,
            "stage": "greeting",
            "created_at": current_time,
            "last_updated": current_time,
        }
        self.sessions[new_session_id] = session
        self.session_metadata[new_session_id] = self._new_metadata(current_time)

        if session_id:
            log.info(f"⚠️ 会话ID {session_id} 不存在,创建新会话: {new_session_id}")
        else:
            log.info(f"🆕 创建新会话: {new_session_id}")
        return session

    def update_session(self, session_id: str, updates: Dict[str, Any]):
        """Merge ``updates`` into a session and refresh its metadata.

        This bumps ``last_updated``, ``last_accessed`` and ``message_count``,
        and recomputes the completion flags. When ``updates`` contains
        ``persona``, the cached ``persona_type`` is set from its ``key``. It is
        cleared to None if the persona was reset to a falsy value. Unknown
        session IDs are logged as an error and otherwise ignored.

        Args:
            session_id: ID of the session to update.
            updates: Field values to merge into the session dict.
        """
        session = self.sessions.get(session_id)
        if session is None:
            log.error(f"❌ 尝试更新不存在的会话: {session_id}")
            return

        now = time.time()
        session.update(updates)
        session["last_updated"] = now

        metadata = self._metadata_for(session_id)
        metadata["last_accessed"] = now
        metadata["message_count"] = metadata.get("message_count", 0) + 1
        metadata["completion_status"] = self._completion_status(session)
        if "persona" in updates:
            persona = updates["persona"]
            # Assumes a persona is a dict with a "key" entry. A reset persona
            # clears the cached type so it cannot go stale.
            metadata["persona_type"] = persona.get("key") if persona else None

        log.info(f"📝 更新会话 {session_id}: {list(updates.keys())}")
        log.info(f"📊 更新后状态: {self._get_session_summary(session_id)}")

    def _get_session_summary(self, session_id: str) -> str:
        """Return a one-line (Chinese) summary of a session's trip fields.

        The "n/4" completion count is derived from the session data itself,
        so it stays correct even when the session dict was mutated directly
        instead of through ``update_session``.
        """
        session = self.sessions.get(session_id)
        if session is None:
            return "会话不存在"

        destination = session.get('destination')
        duration = session.get('duration')
        budget = session.get('budget')
        persona = session.get('persona')

        dest_text = destination.get('name', '未设置') if destination else '未设置'
        duration_text = f"{duration.get('days', '?')}天" if duration else '未设置'
        budget_text = budget.get('description', '未设置') if budget else '未设置'
        persona_text = persona.get('name', '未设置') if persona else '未设置'
        completed = sum(self._completion_status(session).values())

        return (f"目的地:{dest_text}, 天数:{duration_text}, 预算:{budget_text}, "
                f"风格:{persona_text} ({completed}/4完成)")

    def get_all_sessions_summary(self) -> Dict[str, Any]:
        """Return counts and a per-session overview of every stored session.

        Returns:
            A dict with ``total_sessions``, ``active_sessions`` (accessed
            within ``ACTIVE_WINDOW_SECONDS``) and ``sessions``. The latter maps
            each session ID to its summary text, formatted timestamps, message
            count, persona type, completion flags and activity flag.
        """
        now = time.time()
        summary = {
            "total_sessions": len(self.sessions),
            "active_sessions": 0,
            "sessions": {},
        }

        for session_id in self.sessions:
            metadata = self.session_metadata.get(session_id, {})
            is_active = self._is_active(metadata, now)
            if is_active:
                summary["active_sessions"] += 1

            summary["sessions"][session_id] = {
                "summary": self._get_session_summary(session_id),
                "created_at": self._format_timestamp(metadata.get('created_at', 0)),
                "last_accessed": self._format_timestamp(metadata.get('last_accessed', 0)),
                "message_count": metadata.get('message_count', 0),
                "persona_type": metadata.get('persona_type'),
                "completion_status": metadata.get('completion_status', {}),
                "is_active": is_active,
            }

        return summary

    def cleanup_old_sessions(self, max_age_hours: int = 24) -> int:
        """Delete sessions not accessed within ``max_age_hours``.

        Sessions without metadata count as never accessed and are removed.

        Args:
            max_age_hours: Maximum idle time, in hours, before a session is removed.

        Returns:
            The number of sessions removed.
        """
        now = time.time()
        max_age_seconds = max_age_hours * 3600

        # Collect first, then delete, so the dict is not mutated while iterating.
        old_sessions = [
            session_id
            for session_id in self.sessions
            if (now - self.session_metadata.get(session_id, {}).get('last_accessed', 0))
            > max_age_seconds
        ]
        for session_id in old_sessions:
            del self.sessions[session_id]
            self.session_metadata.pop(session_id, None)

        if old_sessions:
            log.info(f"🧹 清理了 {len(old_sessions)} 个旧会话")
        return len(old_sessions)

    def format_session_info(self, session_state: dict) -> dict:
        """Return a detailed, per-field status report for ``session_state``.

        Each trip field is reported as ``{'status': 'pending'}`` when unset
        (falsy). Otherwise its sub-values are reported with
        ``'status': 'completed'``. A ``progress`` section counts the completed
        fields. If metadata exists for the session's ID, a ``metadata`` section
        with message count, persona type and activity flag is added.

        Args:
            session_state: A session dict as returned by ``get_or_create_session``.

        Returns:
            The status report dict.
        """
        session_id = session_state.get('session_id', '')

        # Basic information (timestamps are passed through as stored).
        info = {
            "session_id": session_id,
            "created_at": session_state.get('created_at', ''),
            "last_updated": session_state.get('last_updated', ''),
        }

        # Destination
        destination = session_state.get('destination')
        if destination:
            info['destination'] = {
                'name': destination.get('name'),
                'country': destination.get('country', ''),
                'status': 'completed',
            }
        else:
            info['destination'] = {'status': 'pending'}

        # Trip duration
        duration = session_state.get('duration')
        if duration:
            info['duration'] = {
                'days': duration.get('days'),
                'description': duration.get('description', ''),
                'status': 'completed',
            }
        else:
            info['duration'] = {'status': 'pending'}

        # Budget
        budget = session_state.get('budget')
        if budget:
            info['budget'] = {
                'type': budget.get('type', ''),
                'amount': budget.get('amount', ''),
                'currency': budget.get('currency', ''),
                'description': budget.get('description', ''),
                'status': 'completed',
            }
        else:
            info['budget'] = {'status': 'pending'}

        # Travel persona
        persona = session_state.get('persona')
        if persona:
            info['persona'] = {
                'key': persona.get('key'),
                'name': persona.get('name', ''),
                'style': persona.get('style', ''),
                'source': persona.get('source', ''),
                'status': 'completed',
            }
        else:
            info['persona'] = {'status': 'pending'}

        # Completion statistics
        total_fields = len(self.REQUIRED_FIELDS)
        completed_fields = sum(
            1 for field in self.REQUIRED_FIELDS if info[field]['status'] == 'completed'
        )
        info['progress'] = {
            'completed': completed_fields,
            'total': total_fields,
            'percentage': (completed_fields / total_fields) * 100,
        }

        # Tracking metadata, when available
        metadata = self.session_metadata.get(session_id)
        if metadata is not None:
            info['metadata'] = {
                'message_count': metadata.get('message_count', 0),
                'persona_type': metadata.get('persona_type'),
                'is_active': self._is_active(metadata, time.time()),
            }

        return info

    def reset(self, session_id: str) -> bool:
        """Delete a session and its metadata.

        Args:
            session_id: ID of the session to delete.

        Returns:
            True if the session existed and was removed, False otherwise.
        """
        if session_id not in self.sessions:
            log.warning(f"⚠️ 尝试删除不存在的会话: {session_id}")
            return False

        del self.sessions[session_id]
        self.session_metadata.pop(session_id, None)
        log.info(f"🗑️ 删除会话: {session_id}")
        return True