File size: 4,900 Bytes
90537f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
import uuid
from typing import Callable
from ..database import db_manager


class SessionMiddleware(BaseHTTPMiddleware):
    """Middleware that tags requests with a session ID and checks hotel context.

    For every non-OPTIONS request the middleware:

    1. Reads the ``x-session-id`` request header, or generates a UUID4 string
       when the header is missing or empty, and stores it on
       ``request.state.session_id``.
    2. For database-related paths (when ``require_database`` is true), makes
       sure ``db_manager`` has a current hotel for the session. If it has
       none, the ``x-hotel-name`` / ``x-hotel-password`` request headers are
       used to authenticate and set the hotel context.
    3. Echoes the session ID back in the ``x-session-id`` response header,
       including on the 400/401/500 responses produced by step 2.
    """

    # Path prefixes treated as database-related endpoints.
    _DATABASE_PREFIXES = (
        '/settings/',
        '/customer/api/',
        '/chef/',
        '/admin/',
        '/analytics/',
        '/tables/',
        '/feedback/',
        '/loyalty/',
        '/selection-offers/',
    )

    # Exact paths that are never validated (compared with ``==``, so a
    # trailing slash or sub-path does NOT match).
    _SKIP_VALIDATION_ENDPOINTS = frozenset({
        '/settings/databases',
        '/settings/hotels',
        '/settings/switch-database',
        '/settings/switch-hotel',
        '/settings/current-database',
        '/settings/current-hotel',
    })

    # Admin and chef routes handle their own database selection, so every
    # path under these prefixes skips validation.
    _SKIP_VALIDATION_PREFIXES = (
        '/admin/',
        '/chef/',
    )

    def __init__(self, app, require_database: bool = True):
        """Store the wrapped app and the validation switch.

        Args:
            app: The ASGI application passed through to ``BaseHTTPMiddleware``.
            require_database: When false, hotel-context validation is
                disabled for all paths; session IDs are still assigned.
        """
        super().__init__(app)
        self.require_database = require_database

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Assign a session ID, validate hotel context, then forward the request.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response, or a ``JSONResponse`` with status
            400, 401 or 500 when hotel-context validation fails.
        """
        # CORS preflight requests pass straight through: no session ID is
        # assigned and no header is added.
        if request.method == "OPTIONS":
            return await call_next(request)

        # Reuse the client's session ID, or mint a new one if absent/empty.
        session_id = request.headers.get('x-session-id') or str(uuid.uuid4())
        request.state.session_id = session_id

        if self._requires_hotel_context(request.url.path):
            error_response = self._ensure_hotel_context(request, session_id)
            if error_response is not None:
                # Return the session ID on failures too, so a client that was
                # just issued a generated ID can reuse it on its next request.
                error_response.headers["x-session-id"] = session_id
                return error_response

        response = await call_next(request)
        response.headers["x-session-id"] = session_id
        return response

    def _requires_hotel_context(self, path: str) -> bool:
        """Return True when ``path`` must have a hotel context to proceed."""
        if not self.require_database:
            return False
        return (
            path.startswith(self._DATABASE_PREFIXES)
            and path not in self._SKIP_VALIDATION_ENDPOINTS
            and not path.startswith(self._SKIP_VALIDATION_PREFIXES)
        )

    def _ensure_hotel_context(self, request: Request, session_id: str):
        """Make sure the session has a hotel context, authenticating if needed.

        Returns:
            ``None`` when the session already has a hotel or the header
            credentials were accepted; otherwise a ``JSONResponse``
            describing the failure (400 no credentials, 401 rejected
            credentials, 500 exception during authentication).
        """
        if db_manager.get_current_hotel_id(session_id):
            return None

        hotel_name = request.headers.get('x-hotel-name')
        hotel_password = request.headers.get('x-hotel-password')

        if not (hotel_name and hotel_password):
            # No hotel selected
            return JSONResponse(
                status_code=400,
                content={
                    "detail": "No hotel selected. Please select a hotel first.",
                    "error_code": "HOTEL_NOT_SELECTED"
                }
            )

        try:
            hotel_id = db_manager.authenticate_hotel(hotel_name, hotel_password)
            if hotel_id:
                # Valid credentials: bind the hotel to this session.
                db_manager.set_hotel_context(session_id, hotel_id)
                return None
        except Exception as e:
            # Boundary handler: any failure while authenticating or setting
            # the context becomes a 500 instead of propagating.
            # NOTE(review): the raw exception text is sent to the client;
            # consider logging it and returning a generic message instead.
            return JSONResponse(
                status_code=500,
                content={
                    "detail": f"Hotel authentication failed: {str(e)}",
                    "error_code": "HOTEL_VERIFICATION_ERROR"
                }
            )

        # authenticate_hotel returned a falsy ID: invalid credentials.
        return JSONResponse(
            status_code=401,
            content={
                "detail": "Invalid hotel credentials",
                "error_code": "HOTEL_AUTH_FAILED"
            }
        )


def get_session_id(request: Request) -> str:
    """Return the session ID stored on ``request.state``.

    If ``request.state`` has no ``session_id`` attribute (or it is ``None``),
    a new UUID4 string is generated, saved on ``request.state.session_id``
    and returned, so repeated calls for the same request yield the same ID.

    Args:
        request: The current request; only ``request.state`` is accessed.

    Returns:
        The session ID string.
    """
    session_id = getattr(request.state, 'session_id', None)
    if session_id is None:
        # Generate lazily (not as an eager getattr default) and persist it so
        # later calls on this request do not each mint a different ID.
        session_id = str(uuid.uuid4())
        request.state.session_id = session_id
    return session_id