"""Request handlers for the RAG service: token check, query, upload, health."""

import asyncio
import logging
import secrets
from typing import Dict, Any

from fastapi import HTTPException, UploadFile, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

from config import Config

from .rag_pipeline import route_and_process_query, add_document_to_rag, check_system_health
from .document_handler import extract_text_from_file

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

security = HTTPBearer()

# Limits and accepted upload types, all taken from Config.
SUPPORTED_CONTENT_TYPES = Config.RAG_SUPPORTED_CONTENT_TYPES
MAX_FILE_SIZE = Config.RAG_MAX_FILE_SIZE
MAX_QUERY_LENGTH = Config.RAG_MAX_QUERY_LENGTH


async def verify_token(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> str:
    """Verify the Bearer token from the Authorization header.

    Args:
        credentials: Bearer credentials resolved by the ``HTTPBearer`` dependency.

    Returns:
        The presented token, once it has matched ``Config.SECRET_TOKEN``.

    Raises:
        HTTPException: 500 if no secret token is configured, 403 if the
            presented token does not match it.
    """
    token = credentials.credentials
    expected_token = Config.SECRET_TOKEN

    if not expected_token:
        logger.error("MY_SECRET_TOKEN not configured")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Server configuration error",
        )

    # Constant-time comparison so response timing does not reveal how much of
    # the token matched. Comparing bytes keeps compare_digest from raising on
    # non-ASCII header values.
    # NOTE(review): assumes Config.SECRET_TOKEN is a str -- confirm.
    tokens_match = secrets.compare_digest(
        token.encode("utf-8"), expected_token.encode("utf-8")
    )
    if not tokens_match:
        # Deliberately do not log any part of the presented token: a mistyped
        # but nearly valid secret would otherwise end up in the logs.
        logger.warning("Invalid token attempt")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid or expired token",
        )

    return token


async def handle_rag_query(query: str) -> Dict[str, Any]:
    """Validate a query, run it through the RAG pipeline and return the result.

    Args:
        query: The user's query text.

    Returns:
        Whatever ``route_and_process_query`` returns for the query.

    Raises:
        HTTPException: 400 if the query is empty/blank or longer than
            ``MAX_QUERY_LENGTH``; 500 if processing fails unexpectedly.
    """
    # Input validation
    if not query or not query.strip():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Query cannot be empty",
        )

    if len(query) > MAX_QUERY_LENGTH:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Query too long. Please limit to {MAX_QUERY_LENGTH} characters.",
        )

    try:
        logger.info("Processing query: %s...", query[:50])

        # The pipeline entry point is a plain (blocking) callable, so run it
        # in a worker thread instead of on the event loop.
        response = await asyncio.to_thread(route_and_process_query, query)

        logger.info(
            "Query processed successfully. Route: %s",
            response.get("route", "Unknown"),
        )
        return response
    except HTTPException:
        # Same policy as the upload and health handlers: deliberate HTTP
        # errors pass through instead of being rewrapped as a generic 500.
        raise
    except Exception as e:
        logger.exception("Error processing query: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing your query. Please try again.",
        ) from e


async def handle_document_upload(file: UploadFile) -> Dict[str, str]:
    """Validate an uploaded file, extract its text and add it to the RAG store.

    Args:
        file: The uploaded file.

    Returns:
        A dict with a confirmation ``message``, the ``filename``, the
        extracted ``text_length`` and the ``content_type``.

    Raises:
        HTTPException: 400 for a missing filename or empty/too-short text,
            415 for an unsupported content type, 413 for an oversized file,
            500 if extraction or indexing fails.
    """
    # File validation
    if not file.filename:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No file provided",
        )

    if file.content_type not in SUPPORTED_CONTENT_TYPES:
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail=f"Unsupported file type: {file.content_type}. "
            f"Supported types: {', '.join(SUPPORTED_CONTENT_TYPES)}",
        )

    # Check file size. Read at most one byte past the limit: enough to detect
    # an oversized upload without pulling an arbitrarily large file into
    # memory. When the check passes, the whole file has been read, so
    # len(contents) is still its true size.
    contents = await file.read(int(MAX_FILE_SIZE) + 1)
    if len(contents) > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail=f"File too large. Maximum size: {MAX_FILE_SIZE / (1024*1024):.1f}MB",
        )

    # Reset file pointer so text extraction starts from the beginning.
    await file.seek(0)

    try:
        logger.info("Processing file upload: %s", file.filename)

        # Extract text from file
        text = await extract_text_from_file(file)

        if not text or not text.strip():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="The file appears to be empty or could not be read.",
            )

        if len(text) < 50:  # Too short to be meaningful
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="The extracted text is too short to be meaningful.",
            )

        # Add to RAG system (blocking call, so run it in a worker thread).
        success = await asyncio.to_thread(
            add_document_to_rag,
            text,
            {
                "source": file.filename,
                "content_type": file.content_type,
                "size": len(contents),
            },
        )

        if not success:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to add document to the knowledge base",
            )

        logger.info("Successfully processed file: %s", file.filename)

        return {
            "message": f"Successfully uploaded and processed '{file.filename}'. "
            f"It is now available for querying.",
            "filename": file.filename,
            "text_length": len(text),
            "content_type": file.content_type,
        }
    except HTTPException:
        # Validation errors raised above already carry the right status code.
        raise
    except Exception as e:
        logger.exception("Error processing file %s: %s", file.filename, e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing the file. Please try again.",
        ) from e


async def handle_health_check() -> Dict[str, Any]:
    """Run the system health check and return its report.

    Returns:
        The dict produced by ``check_system_health``.

    Raises:
        HTTPException: 503 if the report's ``status`` is ``"unhealthy"`` or
            if the health check itself fails.
    """
    try:
        health_status = await asyncio.to_thread(check_system_health)

        if health_status["status"] == "unhealthy":
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Service is currently unhealthy",
            )

        return health_status
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Health check failed: %s", e)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Health check failed",
        ) from e