#!/usr/bin/env python3
"""
Utility script to clear user sessions and verify GPT memory isolation.
"""
import asyncio | |
import logging | |
from typing import Optional | |
from sqlalchemy import select, delete, and_ | |
from sqlalchemy.ext.asyncio import AsyncSession | |
from database import get_async_db | |
from models import GPTModeSession, ConversationMemory, CrossModuleMemory, User | |
# Module-level logger shared by every function in this script; its level and
# format are set by the logging.basicConfig() call in main().
logger = logging.getLogger(__name__)
async def clear_user_sessions(user_id: str, db: AsyncSession) -> bool:
    """Delete every session and memory row belonging to one user.

    Rows are removed from ``GPTModeSession``, ``ConversationMemory`` and
    ``CrossModuleMemory`` (in that order) and committed in a single
    transaction.

    Args:
        user_id: Value compared against each model's ``user_id`` column.
        db: Async session used to execute the deletes.

    Returns:
        ``True`` once the commit has succeeded. ``False`` if any step raised;
        the error is logged and the session is rolled back first.
    """
    try:
        logger.info(f"Clearing all sessions for user {user_id}")

        # One DELETE per table, each restricted to this user's rows.
        for model in (GPTModeSession, ConversationMemory, CrossModuleMemory):
            purge_stmt = delete(model).where(model.user_id == user_id)
            await db.execute(purge_stmt)

        await db.commit()
        logger.info(f"β Successfully cleared all sessions for user {user_id}")
        return True
    except Exception as e:
        logger.error(f"β Error clearing sessions for user {user_id}: {e}")
        await db.rollback()
        return False
async def clear_all_sessions(db: AsyncSession) -> bool:
    """Delete every session and memory row for all users (nuclear option).

    Issues an unfiltered DELETE against ``GPTModeSession``,
    ``ConversationMemory`` and ``CrossModuleMemory`` (in that order) and
    commits them together.

    Args:
        db: Async session used to execute the deletes.

    Returns:
        ``True`` once the commit has succeeded. ``False`` if any step raised;
        the error is logged and the session is rolled back first.
    """
    try:
        logger.info("Clearing all sessions for all users")

        # No WHERE clause: each statement empties the whole table.
        for model in (GPTModeSession, ConversationMemory, CrossModuleMemory):
            wipe_stmt = delete(model)
            await db.execute(wipe_stmt)

        await db.commit()
        logger.info("β Successfully cleared all sessions for all users")
        return True
    except Exception as e:
        logger.error(f"β Error clearing all sessions: {e}")
        await db.rollback()
        return False
async def verify_user_isolation(db: AsyncSession) -> dict:
    """Inventory per-user sessions/memory and flag rows that have no owner.

    For every ``User`` the function loads that user's ``GPTModeSession``,
    ``ConversationMemory`` and ``CrossModuleMemory`` rows and records their
    counts plus a short summary of each session. It then looks for rows whose
    ``user_id`` is NULL in all three tables; any such row fails the check.

    Args:
        db: Async session used for all queries (read-only).

    Returns:
        On success a dict with:
          * ``total_users``: number of users found.
          * ``user_sessions``: mapping of user email to a dict with
            ``user_id``, ``sessions_count``, ``memory_count``,
            ``cross_memory_count`` and ``sessions`` (list of per-session
            dicts).
          * ``isolation_status``: the PASSED marker, or the FAILED marker
            when orphaned rows exist.
          * ``orphaned_records``: present only on failure; counts under
            ``sessions``, ``memory`` and ``cross_memory``.
        If any query raises, ``{"error": <message>, "isolation_status": ...}``
        with the FAILED marker.
    """

    async def fetch_all(stmt):
        """Execute *stmt* and return every ORM object it selects."""
        result = await db.execute(stmt)
        return result.scalars().all()

    try:
        logger.info("Verifying user isolation...")

        users = await fetch_all(select(User))

        isolation_report = {
            "total_users": len(users),
            "user_sessions": {},
            "isolation_status": "β PASSED",
        }

        for user in users:
            sessions = await fetch_all(
                select(GPTModeSession).where(GPTModeSession.user_id == user.id)
            )
            memory = await fetch_all(
                select(ConversationMemory).where(
                    ConversationMemory.user_id == user.id
                )
            )
            cross_memory = await fetch_all(
                select(CrossModuleMemory).where(
                    CrossModuleMemory.user_id == user.id
                )
            )

            # NOTE: entries are keyed by email, so two users with the same
            # (or a missing) email would overwrite each other in the report.
            isolation_report["user_sessions"][user.email] = {
                "user_id": user.id,
                "sessions_count": len(sessions),
                "memory_count": len(memory),
                "cross_memory_count": len(cross_memory),
                "sessions": [
                    {
                        "session_id": session.id,
                        "mode_name": session.mode_name,
                        "project_id": session.project_id,
                        "current_question": session.current_question,
                    }
                    for session in sessions
                ],
            }

        # Orphaned rows (user_id IS NULL) belong to nobody. All three tables
        # are checked, including CrossModuleMemory, so that none of the
        # stores counted above can hide ownerless data.
        orphaned_sessions = await fetch_all(
            select(GPTModeSession).where(GPTModeSession.user_id.is_(None))
        )
        orphaned_memory = await fetch_all(
            select(ConversationMemory).where(ConversationMemory.user_id.is_(None))
        )
        orphaned_cross_memory = await fetch_all(
            select(CrossModuleMemory).where(CrossModuleMemory.user_id.is_(None))
        )

        if orphaned_sessions or orphaned_memory or orphaned_cross_memory:
            isolation_report["isolation_status"] = "β FAILED - Orphaned records found"
            isolation_report["orphaned_records"] = {
                "sessions": len(orphaned_sessions),
                "memory": len(orphaned_memory),
                "cross_memory": len(orphaned_cross_memory),
            }

        logger.info(f"Isolation verification completed: {isolation_report['isolation_status']}")
        return isolation_report

    except Exception as e:
        logger.error(f"β Error verifying user isolation: {e}")
        return {"error": str(e), "isolation_status": "β FAILED"}
async def verify_module_isolation(user_id: str, project_id: str, db: AsyncSession) -> dict:
    """Group one user's sessions and memory for a project by module.

    Loads the ``GPTModeSession`` and ``ConversationMemory`` rows matching both
    ``user_id`` and ``project_id`` and buckets them by ``mode_name`` and
    ``module_id`` respectively.

    Args:
        user_id: Owner whose rows are inspected.
        project_id: Project the rows must belong to.
        db: Async session used for the queries (read-only).

    Returns:
        On success a dict with ``user_id``, ``project_id``,
        ``sessions_by_module`` (mode name -> list of ``session_id`` /
        ``current_question`` / ``answers_count`` dicts), ``memory_by_module``
        (module id -> list of ``memory_id`` / ``session_id`` /
        ``conversation_history_count`` dicts) and ``isolation_status``.
        If anything raises, ``{"error": <message>, "isolation_status": ...}``
        with the FAILED marker.
    """
    try:
        logger.info(f"Verifying module isolation for user {user_id}, project {project_id}")

        sessions_result = await db.execute(
            select(GPTModeSession).where(
                and_(
                    GPTModeSession.user_id == user_id,
                    GPTModeSession.project_id == project_id,
                )
            )
        )
        sessions = sessions_result.scalars().all()

        memory_result = await db.execute(
            select(ConversationMemory).where(
                and_(
                    ConversationMemory.user_id == user_id,
                    ConversationMemory.project_id == project_id,
                )
            )
        )
        memory = memory_result.scalars().all()

        module_report = {
            "user_id": user_id,
            "project_id": project_id,
            "sessions_by_module": {},
            "memory_by_module": {},
            "isolation_status": "β PASSED",
        }

        sessions_by_module = module_report["sessions_by_module"]
        for session in sessions:
            sessions_by_module.setdefault(session.mode_name, []).append({
                "session_id": session.id,
                "current_question": session.current_question,
                # A row whose answers value is None counts as zero answers
                # instead of raising TypeError and failing the whole report.
                "answers_count": len(session.answers or []),
            })

        memory_by_module = module_report["memory_by_module"]
        for mem in memory:
            memory_by_module.setdefault(mem.module_id, []).append({
                "memory_id": mem.id,
                "session_id": mem.session_id,
                # Same None-safety as answers_count above.
                "conversation_history_count": len(mem.conversation_history or []),
            })

        logger.info(f"Module isolation verification completed: {module_report['isolation_status']}")
        return module_report

    except Exception as e:
        logger.error(f"β Error verifying module isolation: {e}")
        return {"error": str(e), "isolation_status": "β FAILED"}
async def list_users_with_sessions(db: AsyncSession) -> list:
    """Return one summary dict per user with session and memory counts.

    Args:
        db: Async session used for the queries (read-only).

    Returns:
        A list of dicts with ``user_id``, ``email``, ``full_name``,
        ``sessions_count``, ``memory_count`` and ``created_at``. An empty
        list is returned (after logging the error) if anything raises.
    """

    async def count_owned(model, owner_id) -> int:
        """Load *model* rows whose user_id equals *owner_id* and count them."""
        owned = await db.execute(select(model).where(model.user_id == owner_id))
        return len(owned.scalars().all())

    try:
        all_users = (await db.execute(select(User))).scalars().all()

        summaries = []
        for account in all_users:
            session_total = await count_owned(GPTModeSession, account.id)
            memory_total = await count_owned(ConversationMemory, account.id)
            summaries.append({
                "user_id": account.id,
                "email": account.email,
                "full_name": account.full_name,
                "sessions_count": session_total,
                "memory_count": memory_total,
                "created_at": account.created_at,
            })
        return summaries
    except Exception as e:
        logger.error(f"β Error listing users: {e}")
        return []
async def _resolve_target_user_id(args, db):
    """Return the user id chosen via --user-id/--email, or None (after logging) if unusable."""
    if not args.user_id and not args.email:
        logger.error("β Please provide either --user-id or --email")
        return None

    if not args.email:
        return args.user_id

    # When --email is given it takes precedence over --user-id: the id is
    # always taken from the matching User row.
    user_result = await db.execute(select(User).where(User.email == args.email))
    user = user_result.scalar_one_or_none()
    if not user:
        logger.error(f"β User with email {args.email} not found")
        return None
    return user.id


def _print_isolation_report(report: dict) -> None:
    """Print a report dict returned by either verify_* function."""
    print("\n" + "=" * 50)
    print("ISOLATION VERIFICATION REPORT")
    print("=" * 50)
    print(f"Status: {report.get('isolation_status', 'UNKNOWN')}")

    # Per-user section: only present in verify_user_isolation reports.
    if "user_sessions" in report:
        print(f"\nTotal Users: {report['total_users']}")
        for email, data in report["user_sessions"].items():
            print(f"\nUser: {email}")
            print(f" Sessions: {data['sessions_count']}")
            print(f" Memory: {data['memory_count']}")
            print(f" Cross-Memory: {data['cross_memory_count']}")

    # Per-module section: only present in verify_module_isolation reports.
    # Without it the module check would print nothing but the status line.
    for key, label in (("sessions_by_module", "Sessions"), ("memory_by_module", "Memory")):
        if key in report:
            print(f"\n{label} by module:")
            for module, entries in report[key].items():
                print(f" {module}: {len(entries)}")

    if "orphaned_records" in report:
        print("\nβ οΈ Orphaned Records Found:")
        print(f" Sessions: {report['orphaned_records']['sessions']}")
        print(f" Memory: {report['orphaned_records']['memory']}")


def _print_user_list(users: list) -> None:
    """Print the summaries produced by list_users_with_sessions."""
    print("\n" + "=" * 80)
    print("USERS AND SESSION COUNTS")
    print("=" * 80)
    for user in users:
        print(f"\nUser: {user['full_name']} ({user['email']})")
        print(f" User ID: {user['user_id']}")
        print(f" Sessions: {user['sessions_count']}")
        print(f" Memory: {user['memory_count']}")
        print(f" Created: {user['created_at']}")


async def _run_action(args, db) -> None:
    """Execute the single action selected by --action against *db*."""
    if args.action == "clear-user":
        user_id = await _resolve_target_user_id(args, db)
        if user_id is None:
            return
        if await clear_user_sessions(user_id, db):
            logger.info(f"β Successfully cleared sessions for user {user_id}")
        else:
            logger.error(f"β Failed to clear sessions for user {user_id}")

    elif args.action == "clear-all":
        # Destructive: require the operator to type "yes" explicitly.
        confirm = input("β οΈ This will clear ALL sessions for ALL users. Are you sure? (yes/no): ")
        if confirm.lower() != "yes":
            logger.info("Operation cancelled")
            return
        if await clear_all_sessions(db):
            logger.info("β Successfully cleared all sessions")
        else:
            logger.error("β Failed to clear all sessions")

    elif args.action == "verify":
        # The module-level check needs both ids; otherwise fall back to the
        # global per-user check.
        if args.user_id and args.project_id:
            report = await verify_module_isolation(args.user_id, args.project_id, db)
        else:
            report = await verify_user_isolation(db)
        _print_isolation_report(report)

    elif args.action == "list":
        _print_user_list(await list_users_with_sessions(db))


async def main():
    """Parse command-line arguments and run the requested session-management action.

    Options:
        --action      One of ``clear-user``, ``clear-all``, ``verify``, ``list`` (required).
        --user-id     User to clear, or (with --project-id) to verify per module.
        --project-id  Project for the per-module verification.
        --email       Alternative to --user-id for ``clear-user``.

    Ordinary exceptions raised while running the action are logged and
    suppressed; KeyboardInterrupt/SystemExit propagate to the caller.
    """
    import argparse

    parser = argparse.ArgumentParser(description="GPT Session Management Utility")
    parser.add_argument("--action", choices=["clear-user", "clear-all", "verify", "list"], required=True)
    parser.add_argument("--user-id", help="User ID for specific actions")
    parser.add_argument("--project-id", help="Project ID for module verification")
    parser.add_argument("--email", help="User email for finding user ID")
    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    # get_async_db() is consumed as an async iterator; only the first object
    # it produces is used.
    async for db in get_async_db():
        try:
            await _run_action(args, db)
        except Exception as e:
            logger.error(f"β Error in main: {e}")
        # Leave the loop after the first session. The break is deliberately
        # NOT inside a `finally`: a break there discards any in-flight
        # exception, including KeyboardInterrupt and SystemExit.
        break
# Script entry point: run the async CLI only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())