Spaces:
Sleeping
Sleeping
File size: 4,900 Bytes
90537f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import logging
import uuid
from typing import Callable, Optional

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from ..database import db_manager
class SessionMiddleware(BaseHTTPMiddleware):
    """Middleware to handle session-based database management.

    For every non-OPTIONS request:

    * The session ID is read from the ``x-session-id`` request header, or a
      new UUID4 string is generated. It is stored on ``request.state`` and
      echoed back in the ``x-session-id`` response header.
    * For paths that need a hotel context (see ``_should_validate``), if
      ``db_manager`` has no hotel for the session, the middleware tries to
      authenticate using the ``x-hotel-name`` / ``x-hotel-password`` headers.
      When that is impossible it short-circuits with a JSON error response
      (400 ``HOTEL_NOT_SELECTED``, 401 ``HOTEL_AUTH_FAILED`` or
      500 ``HOTEL_VERIFICATION_ERROR``).

    OPTIONS requests (CORS preflight) are passed through untouched.
    """

    # Path prefixes of endpoints that are backed by the hotel database.
    # Kept as a tuple so a single str.startswith call checks all of them.
    _DATABASE_PATH_PREFIXES = (
        '/settings/',
        '/customer/api/',
        '/chef/',
        '/admin/',
        '/analytics/',
        '/tables/',
        '/feedback/',
        '/loyalty/',
        '/selection-offers/',
    )

    # Exact paths exempt from hotel-context validation (database/hotel
    # listing and switching endpoints).
    _SKIP_VALIDATION_ENDPOINTS = frozenset({
        '/settings/databases',
        '/settings/hotels',
        '/settings/switch-database',
        '/settings/switch-hotel',
        '/settings/current-database',
        '/settings/current-hotel',
    })

    # Admin and chef routes handle their own database selection.
    _SKIP_VALIDATION_PREFIXES = ('/admin/', '/chef/')

    _logger = logging.getLogger(__name__)

    def __init__(self, app, require_database: bool = True):
        """Wrap *app*.

        Args:
            app: The ASGI application being wrapped.
            require_database: When False, hotel-context validation is
                disabled and only session-ID handling is performed.
        """
        super().__init__(app)
        self.require_database = require_database

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Attach the session ID, enforce hotel context if required, then forward.

        Returns:
            The downstream response with an ``x-session-id`` header, or an
            error ``JSONResponse`` if hotel-context validation fails.
        """
        # Skip validation for OPTIONS requests (CORS preflight)
        if request.method == "OPTIONS":
            return await call_next(request)

        # Get or generate session ID, and expose it to route handlers
        session_id = request.headers.get('x-session-id') or str(uuid.uuid4())
        request.state.session_id = session_id

        if self._should_validate(request.url.path):
            error_response = self._ensure_hotel_context(request, session_id)
            if error_response is not None:
                return error_response

        response = await call_next(request)
        response.headers["x-session-id"] = session_id
        return response

    def _should_validate(self, path: str) -> bool:
        """Return True if *path* must run with a hotel context for the session."""
        if not self.require_database:
            return False
        if not path.startswith(self._DATABASE_PATH_PREFIXES):
            return False
        if path in self._SKIP_VALIDATION_ENDPOINTS:
            return False
        return not path.startswith(self._SKIP_VALIDATION_PREFIXES)

    def _ensure_hotel_context(
        self, request: Request, session_id: str
    ) -> Optional[JSONResponse]:
        """Ensure *session_id* has a hotel context, authenticating from headers if needed.

        Returns:
            None if the request may proceed, otherwise the error response to send.
        """
        # NOTE(review): db_manager is called synchronously from the async
        # dispatch; if these calls do blocking I/O they stall the event loop
        # — confirm and consider starlette.concurrency.run_in_threadpool.
        if db_manager.get_current_hotel_id(session_id):
            return None

        # No hotel context yet: fall back to credentials sent as headers
        stored_hotel_name = request.headers.get('x-hotel-name')
        stored_password = request.headers.get('x-hotel-password')
        if not (stored_hotel_name and stored_password):
            return JSONResponse(
                status_code=400,
                content={
                    "detail": "No hotel selected. Please select a hotel first.",
                    "error_code": "HOTEL_NOT_SELECTED"
                }
            )

        try:
            hotel_id = db_manager.authenticate_hotel(stored_hotel_name, stored_password)
            if not hotel_id:
                return JSONResponse(
                    status_code=401,
                    content={
                        "detail": "Invalid hotel credentials",
                        "error_code": "HOTEL_AUTH_FAILED"
                    }
                )
            db_manager.set_hotel_context(session_id, hotel_id)
        except Exception:
            # Log the real cause server-side; do not echo internal exception
            # text (DB/driver errors) back to an unauthenticated client.
            self._logger.exception("Hotel authentication error")
            return JSONResponse(
                status_code=500,
                content={
                    "detail": "Hotel authentication failed",
                    "error_code": "HOTEL_VERIFICATION_ERROR"
                }
            )
        return None
def get_session_id(request: Request) -> str:
    """Helper function to get session ID from request.

    Returns the ID stored on ``request.state.session_id`` by
    ``SessionMiddleware``. If none is present (e.g. for OPTIONS requests,
    which the middleware passes through without assigning one), a new UUID4
    string is generated and stored on ``request.state``. Repeated calls
    during the same request therefore return the same ID.
    """
    session_id = getattr(request.state, 'session_id', None)
    if session_id is None:
        # Generate lazily (not as an eager getattr default) and cache it so
        # the ID stays stable for the rest of this request.
        session_id = str(uuid.uuid4())
        request.state.session_id = session_id
    return session_id
|