File size: 5,579 Bytes
eacbbc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import logging
import os
import shutil
import time
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, Optional, Tuple, List

from fastapi import UploadFile

from app.config import settings

logger = logging.getLogger(__name__)

class Session:
    """State tracked for one user session.

    Holds the session identifier, the path of the image attached to the
    session, the creation and last-access timestamps (taken from
    ``datetime.now()``) and the list of question/answer records.
    """

    def __init__(self, session_id: str, image_path: str):
        self.session_id = session_id
        self.image_path = image_path
        self.created_at = datetime.now()
        self.last_accessed = datetime.now()
        self.questions = []  # Question/answer records, in insertion order

    def is_expired(self) -> bool:
        """Return True when the session has been idle for too long.

        The session expires ``settings.MAX_SESSION_AGE`` seconds after its
        last access.
        """
        max_idle = timedelta(seconds=settings.MAX_SESSION_AGE)
        deadline = self.last_accessed + max_idle
        return deadline < datetime.now()

    def update_access_time(self):
        """Record the current time as the moment of last access."""
        self.last_accessed = datetime.now()

    def add_question(self, question: str, answer: Dict):
        """Append a question/answer record to the history.

        The record carries an ISO-8601 timestamp, and storing it also
        refreshes the last-access time of the session.
        """
        record = {
            "question": question,
            "answer": answer,
            "timestamp": datetime.now().isoformat(),
        }
        self.questions.append(record)
        self.update_access_time()

class SessionService:
    """Service for managing user sessions.

    Sessions are kept in memory in ``self.sessions`` (session ID -> Session)
    and each one references an uploaded image stored under
    ``settings.UPLOAD_DIR``. No background worker is started: an expired
    session is dropped when it is requested through ``get_session``, and all
    expired sessions are swept every time ``create_session`` is called.
    """

    # Buffer size (bytes) used when streaming an upload to disk.
    _COPY_CHUNK_SIZE = 1024 * 1024

    def __init__(self):
        """Create an empty session table and ensure the upload directory exists."""
        self.sessions: Dict[str, Session] = {}
        self.ensure_upload_dir()

    def ensure_upload_dir(self):
        """Ensure the upload directory exists"""
        os.makedirs(settings.UPLOAD_DIR, exist_ok=True)

    def create_session(self, file: UploadFile) -> str:
        """
        Create a new session for the user

        The upload is streamed to a uniquely named file in
        ``settings.UPLOAD_DIR``. Expired sessions are swept first so that
        abandoned sessions do not accumulate in memory or on disk.

        Args:
            file (UploadFile): The uploaded image file

        Returns:
            str: The session ID

        Raises:
            OSError: If the image cannot be written; any partially written
                file is removed before the error propagates.
        """
        # Opportunistic sweep: this is the only place where sessions that
        # are never requested again get reclaimed.
        self._cleanup_sessions()

        # Generate a unique session ID
        session_id = str(uuid.uuid4())

        # Create a unique filename; an upload may carry no filename at all,
        # in which case the stored file simply has no extension.
        timestamp = int(time.time())
        file_extension = Path(file.filename or "").suffix
        filename = f"{timestamp}_{session_id}{file_extension}"
        file_path = os.path.join(settings.UPLOAD_DIR, filename)

        # Stream the upload to disk instead of loading it fully into memory
        try:
            with open(file_path, "wb") as f:
                shutil.copyfileobj(file.file, f, self._COPY_CHUNK_SIZE)
        except Exception:
            # Do not leave a truncated image behind; the original error is
            # re-raised whether or not this cleanup succeeds.
            try:
                self._delete_image(file_path)
            except OSError as cleanup_error:
                logger.warning(
                    f"Could not remove partial upload {file_path}: {cleanup_error}"
                )
            raise

        # Create and store the session
        self.sessions[session_id] = Session(session_id, file_path)

        logger.info(f"Created new session {session_id} with image {file_path}")
        return session_id

    def get_session(self, session_id: str) -> Optional[Session]:
        """
        Get a session by ID

        A session that is found and still valid has its last-access time
        refreshed; an expired one is removed together with its file.

        Args:
            session_id (str): The session ID

        Returns:
            Optional[Session]: The session, or None if not found or expired
        """
        session = self.sessions.get(session_id)

        if session is None:
            return None

        if session.is_expired():
            self._remove_session(session_id)
            return None

        session.update_access_time()
        return session

    def complete_session(self, session_id: str) -> bool:
        """
        Mark a session as complete and remove its image file

        The session entry itself stays in ``self.sessions`` (with
        ``image_path`` set to None) so its data remains available for any
        final operations until the session expires.

        Args:
            session_id (str): The session ID

        Returns:
            bool: True if the image is gone (removed now or already absent),
                False if the session does not exist or the removal failed
        """
        session = self.sessions.get(session_id)
        if not session:
            logger.warning(f"Cannot complete nonexistent session: {session_id}")
            return False

        logger.info(f"Completing session {session_id}")

        try:
            removed = self._delete_image(session.image_path)
        except OSError as e:
            logger.error(f"Error removing image file during session completion: {e}")
            return False

        if removed:
            logger.info(f"Removed image file for completed session {session.image_path}")

        # None marks the image as no longer available, whether it was removed
        # just now or had already disappeared.
        session.image_path = None
        return True

    def _remove_session(self, session_id: str):
        """
        Remove a session and its associated file

        Unknown IDs are ignored. A failure to delete the file is logged and
        does not prevent the session entry from being dropped.

        Args:
            session_id (str): The session ID
        """
        session = self.sessions.pop(session_id, None)
        if session is None:
            return

        try:
            if self._delete_image(session.image_path):
                logger.info(f"Removed session file {session.image_path}")
        except OSError as e:
            logger.error(f"Error removing session file: {e}")

    def _cleanup_sessions(self):
        """Clean up expired sessions"""
        # Snapshot the IDs first: _remove_session mutates self.sessions.
        expired_sessions = [
            session_id for session_id, session in self.sessions.items()
            if session.is_expired()
        ]

        for session_id in expired_sessions:
            self._remove_session(session_id)

        if expired_sessions:
            logger.info(f"Cleaned up {len(expired_sessions)} expired sessions")

    @staticmethod
    def _delete_image(image_path: Optional[str]) -> bool:
        """
        Delete an image file if it is present

        Args:
            image_path (Optional[str]): Path of the file, or None/empty when
                there is nothing to delete

        Returns:
            bool: True if a file was deleted, False if there was no file

        Raises:
            OSError: If the file exists but cannot be removed
        """
        if not image_path:
            return False
        try:
            os.remove(image_path)
        except FileNotFoundError:
            # Already gone (e.g. removed concurrently); nothing left to do.
            return False
        return True