assitantchatbot / clear_user_sessions.py
aghaai's picture
🔒 Add GPT memory isolation with user-specific sessions and module separation
896d096
#!/usr/bin/env python3
"""
Utility script to clear user sessions and verify GPT memory isolation.
"""
import asyncio
import logging
from typing import Optional
from sqlalchemy import select, delete, and_
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_async_db
from models import GPTModeSession, ConversationMemory, CrossModuleMemory, User
logger = logging.getLogger(__name__)
async def clear_user_sessions(user_id: str, db: AsyncSession) -> bool:
    """Delete every stored session and memory row belonging to one user.

    Issues one DELETE per table (GPT mode sessions, conversation memory,
    cross-module memory), each filtered on ``user_id``, and commits them
    with a single commit.

    Args:
        user_id: Value matched against the ``user_id`` column of each table.
        db: Async database session used to execute and commit the deletes.

    Returns:
        True when the deletes were committed. False when any step raised;
        the error is logged and ``db.rollback()`` is awaited.
    """
    try:
        logger.info(f"Clearing all sessions for user {user_id}")
        # Same order as before: GPT mode sessions, conversation memory,
        # then cross-module memory.
        for model in (GPTModeSession, ConversationMemory, CrossModuleMemory):
            purge_stmt = delete(model).where(model.user_id == user_id)
            await db.execute(purge_stmt)
        await db.commit()
        logger.info(f"βœ… Successfully cleared all sessions for user {user_id}")
        return True
    except Exception as exc:
        logger.error(f"❌ Error clearing sessions for user {user_id}: {exc}")
        await db.rollback()
        return False
async def clear_all_sessions(db: AsyncSession) -> bool:
    """Wipe session and memory rows for every user (nuclear option).

    Runs an unfiltered DELETE against the GPT mode session, conversation
    memory and cross-module memory tables, then commits once.

    Args:
        db: Async database session used to execute and commit the deletes.

    Returns:
        True when the deletes were committed. False when any step raised;
        the error is logged and ``db.rollback()`` is awaited.
    """
    try:
        logger.info("Clearing all sessions for all users")
        # No WHERE clause: each statement empties the whole table.
        for table_model in (GPTModeSession, ConversationMemory, CrossModuleMemory):
            await db.execute(delete(table_model))
        await db.commit()
        logger.info("βœ… Successfully cleared all sessions for all users")
        return True
    except Exception as err:
        logger.error(f"❌ Error clearing all sessions: {err}")
        await db.rollback()
        return False
async def verify_user_isolation(db: AsyncSession) -> dict:
    """Build a per-user inventory of stored rows and flag orphaned records.

    For every ``User`` the function loads that user's GPT mode sessions,
    conversation memory and cross-module memory and records the counts
    (plus a short summary of each session). It then looks for rows whose
    ``user_id`` is NULL in all three tables; any such row marks the report
    as failed.

    Args:
        db: Async database session used for all queries.

    Returns:
        A dict with ``total_users``, ``user_sessions`` (keyed by user email)
        and ``isolation_status``. When orphaned rows exist, an
        ``orphaned_records`` dict with ``sessions``, ``memory`` and
        ``cross_memory`` counts is added. If any query raises, a dict with
        ``error`` and a failed ``isolation_status`` is returned instead.
    """

    async def rows_owned_by(model, owner_id):
        """Return all rows of *model* whose user_id equals *owner_id*."""
        result = await db.execute(select(model).where(model.user_id == owner_id))
        return result.scalars().all()

    async def orphaned_rows(model):
        """Return all rows of *model* that have no user_id."""
        result = await db.execute(select(model).where(model.user_id.is_(None)))
        return result.scalars().all()

    try:
        logger.info("Verifying user isolation...")

        # Get all users
        users_result = await db.execute(select(User))
        users = users_result.scalars().all()

        isolation_report = {
            "total_users": len(users),
            "user_sessions": {},
            "isolation_status": "βœ… PASSED",
        }

        for user in users:
            sessions = await rows_owned_by(GPTModeSession, user.id)
            memory = await rows_owned_by(ConversationMemory, user.id)
            cross_memory = await rows_owned_by(CrossModuleMemory, user.id)

            # NOTE(review): entries are keyed by email, so two users sharing
            # an email would overwrite each other here — confirm that
            # User.email is unique.
            isolation_report["user_sessions"][user.email] = {
                "user_id": user.id,
                "sessions_count": len(sessions),
                "memory_count": len(memory),
                "cross_memory_count": len(cross_memory),
                "sessions": [
                    {
                        "session_id": session.id,
                        "mode_name": session.mode_name,
                        "project_id": session.project_id,
                        "current_question": session.current_question,
                    }
                    for session in sessions
                ],
            }

        # Check for orphaned records (rows without a user_id) in every table
        # that is counted above, including cross-module memory.
        orphan_counts = {
            "sessions": len(await orphaned_rows(GPTModeSession)),
            "memory": len(await orphaned_rows(ConversationMemory)),
            "cross_memory": len(await orphaned_rows(CrossModuleMemory)),
        }
        if any(orphan_counts.values()):
            isolation_report["isolation_status"] = "❌ FAILED - Orphaned records found"
            isolation_report["orphaned_records"] = orphan_counts

        logger.info(f"Isolation verification completed: {isolation_report['isolation_status']}")
        return isolation_report
    except Exception as e:
        logger.error(f"❌ Error verifying user isolation: {e}")
        return {"error": str(e), "isolation_status": "❌ FAILED"}
async def verify_module_isolation(user_id: str, project_id: str, db: AsyncSession) -> dict:
    """Report one user's sessions and memory for a project, grouped by module.

    Loads the GPT mode sessions and conversation memory rows matching both
    ``user_id`` and ``project_id``. Sessions are grouped by ``mode_name``
    and memory rows by ``module_id``.

    The status in the returned report is only changed to a failure when a
    query or attribute access raises; the grouping itself performs no
    pass/fail check.

    Args:
        user_id: Value matched against the ``user_id`` column.
        project_id: Value matched against the ``project_id`` column.
        db: Async database session used for the queries.

    Returns:
        A dict with ``user_id``, ``project_id``, ``sessions_by_module``,
        ``memory_by_module`` and ``isolation_status``. If anything raises,
        a dict with ``error`` and a failed ``isolation_status`` is returned.
    """
    try:
        logger.info(f"Verifying module isolation for user {user_id}, project {project_id}")

        # Get all sessions for this user and project
        sessions_result = await db.execute(
            select(GPTModeSession).where(
                and_(
                    GPTModeSession.user_id == user_id,
                    GPTModeSession.project_id == project_id,
                )
            )
        )
        sessions = sessions_result.scalars().all()

        # Get all memory for this user and project
        memory_result = await db.execute(
            select(ConversationMemory).where(
                and_(
                    ConversationMemory.user_id == user_id,
                    ConversationMemory.project_id == project_id,
                )
            )
        )
        memory = memory_result.scalars().all()

        module_report = {
            "user_id": user_id,
            "project_id": project_id,
            "sessions_by_module": {},
            "memory_by_module": {},
            "isolation_status": "βœ… PASSED",
        }

        # Group sessions by module
        sessions_by_module = module_report["sessions_by_module"]
        for session in sessions:
            sessions_by_module.setdefault(session.mode_name, []).append({
                "session_id": session.id,
                "current_question": session.current_question,
                # A None value counts as zero answers instead of making
                # len() raise and failing the whole report.
                "answers_count": len(session.answers or ()),
            })

        # Group memory by module
        memory_by_module = module_report["memory_by_module"]
        for mem in memory:
            memory_by_module.setdefault(mem.module_id, []).append({
                "memory_id": mem.id,
                "session_id": mem.session_id,
                # Same None guard as for session answers above.
                "conversation_history_count": len(mem.conversation_history or ()),
            })

        logger.info(f"Module isolation verification completed: {module_report['isolation_status']}")
        return module_report
    except Exception as e:
        logger.error(f"❌ Error verifying module isolation: {e}")
        return {"error": str(e), "isolation_status": "❌ FAILED"}
async def list_users_with_sessions(db: AsyncSession) -> list:
    """Return one summary dict per user with session and memory counts.

    Args:
        db: Async database session used for the queries.

    Returns:
        A list of dicts with ``user_id``, ``email``, ``full_name``,
        ``sessions_count``, ``memory_count`` and ``created_at``. An empty
        list is returned (and the error logged) if anything raises.
    """
    try:
        all_users = (await db.execute(select(User))).scalars().all()

        summaries = []
        for account in all_users:
            # Rows are loaded in full and counted in Python.
            session_query = select(GPTModeSession).where(
                GPTModeSession.user_id == account.id
            )
            session_rows = (await db.execute(session_query)).scalars().all()

            memory_query = select(ConversationMemory).where(
                ConversationMemory.user_id == account.id
            )
            memory_rows = (await db.execute(memory_query)).scalars().all()

            summary = {
                "user_id": account.id,
                "email": account.email,
                "full_name": account.full_name,
                "sessions_count": len(session_rows),
                "memory_count": len(memory_rows),
                "created_at": account.created_at,
            }
            summaries.append(summary)

        return summaries
    except Exception as exc:
        logger.error(f"❌ Error listing users: {exc}")
        return []
def _print_isolation_report(report: dict) -> None:
    """Print a user-level or module-level isolation report to stdout."""
    print("\n" + "="*50)
    print("ISOLATION VERIFICATION REPORT")
    print("="*50)
    print(f"Status: {report.get('isolation_status', 'UNKNOWN')}")

    # Failed verifications carry the exception text; show it instead of
    # leaving the operator with only a status line.
    if "error" in report:
        print(f"Error: {report['error']}")

    if "user_sessions" in report:
        print(f"\nTotal Users: {report['total_users']}")
        for email, data in report["user_sessions"].items():
            print(f"\nUser: {email}")
            print(f" Sessions: {data['sessions_count']}")
            print(f" Memory: {data['memory_count']}")
            print(f" Cross-Memory: {data['cross_memory_count']}")

    # Module-level reports group their entries per module; print the counts.
    for title, key in (
        ("Sessions by module", "sessions_by_module"),
        ("Memory by module", "memory_by_module"),
    ):
        if key in report:
            print(f"\n{title}:")
            for module, entries in report[key].items():
                print(f" {module}: {len(entries)}")

    if "orphaned_records" in report:
        orphaned = report["orphaned_records"]
        print(f"\n⚠️ Orphaned Records Found:")
        print(f" Sessions: {orphaned['sessions']}")
        print(f" Memory: {orphaned['memory']}")
        # Only present when the report producer counts cross-module orphans.
        if "cross_memory" in orphaned:
            print(f" Cross-Memory: {orphaned['cross_memory']}")


def _print_user_list(users: list) -> None:
    """Print the user/session-count listing to stdout."""
    print("\n" + "="*80)
    print("USERS AND SESSION COUNTS")
    print("="*80)
    for user in users:
        print(f"\nUser: {user['full_name']} ({user['email']})")
        print(f" User ID: {user['user_id']}")
        print(f" Sessions: {user['sessions_count']}")
        print(f" Memory: {user['memory_count']}")
        print(f" Created: {user['created_at']}")


async def _run_action(args, db) -> None:
    """Execute the action selected on the command line using session *db*."""
    if args.action == "clear-user":
        if not args.user_id and not args.email:
            logger.error("❌ Please provide either --user-id or --email")
            return

        user_id = args.user_id
        if args.email:
            # Find user by email; when both options are given the email
            # lookup takes precedence over --user-id.
            user_result = await db.execute(select(User).where(User.email == args.email))
            user = user_result.scalar_one_or_none()
            if not user:
                logger.error(f"❌ User with email {args.email} not found")
                return
            user_id = user.id

        if await clear_user_sessions(user_id, db):
            logger.info(f"βœ… Successfully cleared sessions for user {user_id}")
        else:
            logger.error(f"❌ Failed to clear sessions for user {user_id}")

    elif args.action == "clear-all":
        confirm = input("⚠️ This will clear ALL sessions for ALL users. Are you sure? (yes/no): ")
        if confirm.lower() != "yes":
            logger.info("Operation cancelled")
            return
        if await clear_all_sessions(db):
            logger.info("βœ… Successfully cleared all sessions")
        else:
            logger.error("❌ Failed to clear all sessions")

    elif args.action == "verify":
        if args.user_id and args.project_id:
            report = await verify_module_isolation(args.user_id, args.project_id, db)
        else:
            if args.user_id or args.project_id:
                # Make the fallback visible instead of silently ignoring
                # the single option that was supplied.
                logger.warning(
                    "Module verification needs both --user-id and --project-id; "
                    "running user-level verification instead"
                )
            report = await verify_user_isolation(db)
        _print_isolation_report(report)

    elif args.action == "list":
        _print_user_list(await list_users_with_sessions(db))


async def main():
    """Parse command-line arguments and run the requested session action.

    Actions (``--action``):
        clear-user: delete one user's sessions/memory; needs ``--user-id``
            or ``--email``.
        clear-all: delete sessions/memory for all users after an interactive
            "yes" confirmation.
        verify: print a module-level report when both ``--user-id`` and
            ``--project-id`` are given, otherwise a user-level report.
        list: print every user with session and memory counts.

    Exceptions raised while running an action are logged and not re-raised.
    """
    import argparse

    parser = argparse.ArgumentParser(description="GPT Session Management Utility")
    parser.add_argument("--action", choices=["clear-user", "clear-all", "verify", "list"], required=True)
    parser.add_argument("--user-id", help="User ID for specific actions")
    parser.add_argument("--project-id", help="Project ID for module verification")
    parser.add_argument("--email", help="User email for finding user ID")
    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    async for db in get_async_db():
        try:
            await _run_action(args, db)
        except Exception as e:
            logger.error(f"❌ Error in main: {e}")
        # Only the first yielded session is used. The break sits after the
        # try statement rather than in a ``finally`` clause, where it would
        # silently discard in-flight exceptions such as KeyboardInterrupt.
        break
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script.
    asyncio.run(main())