from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import uvicorn
import logging
import time
import os
import asyncio
from contextlib import asynccontextmanager
from pathlib import Path

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global RAG system state, written only by ``lifespan``.
rag_system = None  # dict of component name -> component once initialized, else None
system_loading = False  # True while ``lifespan`` is building the components
system_load_error = None  # str() of the initialization exception, else None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the RAG components at startup and log at shutdown.

    Initialization runs inline, before the ``yield`` (it is NOT a background
    task), so the server only starts serving once it has finished. As a
    consequence, request handlers will normally never observe
    ``system_loading == True``. A failure does not abort startup: it is logged
    and stored in ``system_load_error`` so ``/health`` can report it and the
    prediction endpoints answer 503.
    """
    global rag_system, system_loading, system_load_error

    logger.info("Starting Text-to-SQL RAG API with CodeLlama for HF Spaces...")

    system_loading = True
    system_load_error = None
    try:
        # Imported here rather than at module top so that an import failure is
        # captured in system_load_error instead of crashing module import.
        from rag_system import VectorStore, SQLRetriever, PromptEngine, SQLGenerator, DataProcessor

        logger.info("Initializing RAG system components...")

        logger.info("Initializing vector store...")
        vector_store = VectorStore()

        logger.info("Initializing SQL retriever...")
        sql_retriever = SQLRetriever(vector_store)

        logger.info("Initializing prompt engine...")
        prompt_engine = PromptEngine()

        # SQL generator (with CodeLlama as primary)
        logger.info("Initializing SQL generator with CodeLlama...")
        sql_generator = SQLGenerator(sql_retriever, prompt_engine)

        logger.info("Initializing data processor...")
        data_processor = DataProcessor()

        # Published only after every component was constructed successfully.
        rag_system = {
            "vector_store": vector_store,
            "sql_retriever": sql_retriever,
            "prompt_engine": prompt_engine,
            "sql_generator": sql_generator,
            "data_processor": data_processor,
        }

        logger.info("Loading sample data...")
        await load_or_create_sample_data(data_processor, vector_store)

        logger.info("All RAG system components initialized successfully!")
    except Exception as e:
        # Deliberately non-fatal: keep the app up so /health can explain why.
        logger.exception("Failed to initialize RAG system: %s", e)
        system_load_error = str(e)
    finally:
        system_loading = False

    yield

    # Shutdown
    logger.info("Shutting down Text-to-SQL RAG API...")


async def load_or_create_sample_data(data_processor, vector_store):
    """Load existing data or create sample dataset.

    Best effort; never raises for data problems. Previously processed examples
    are preferred; when there are none, the processor's sample dataset is used
    instead. If the first attempt raises, the sample dataset is tried once more
    and a second failure is only logged.

    Args:
        data_processor: object providing ``load_processed_data()`` and
            ``create_sample_dataset()``.
        vector_store: object whose ``add_examples()`` receives the examples.

    NOTE(review): declared ``async`` but every call inside is synchronous, so it
    runs to completion without yielding to the event loop.
    """
    try:
        examples = data_processor.load_processed_data()
        if examples:
            logger.info(f"Loaded {len(examples)} existing examples")
            vector_store.add_examples(examples)
        else:
            logger.info("Creating sample dataset...")
            sample_data = data_processor.create_sample_dataset()
            vector_store.add_examples(sample_data)
            logger.info(f"Added {len(sample_data)} sample examples to vector store")
    except Exception as e:
        logger.warning(f"Could not load sample data: {e}")
        # Fall back to the minimal sample dataset.
        try:
            sample_data = data_processor.create_sample_dataset()
            vector_store.add_examples(sample_data)
            logger.info(f"Added {len(sample_data)} sample examples to vector store")
        except Exception as e2:
            logger.exception("Failed to create sample data: %s", e2)


# Create FastAPI app
app = FastAPI(
    title="Text-to-SQL RAG API with CodeLlama",
    description="Advanced API for converting natural language questions to SQL queries using RAG and CodeLlama",
    version="2.0.0",
    lifespan=lifespan,
)


# Pydantic models for request/response
class SQLRequest(BaseModel):
    question: str
    table_headers: List[str]


class SQLResponse(BaseModel):
    question: str
    table_headers: List[str]
    sql_query: str
    model_used: str
    processing_time: float  # seconds spent generating this query
    retrieved_examples: List[Dict[str, Any]]
    status: str


class BatchRequest(BaseModel):
    queries: List[SQLRequest]


class BatchResponse(BaseModel):
    results: List[SQLResponse]
    total_queries: int
    successful_queries: int  # results whose status == "success"


class HealthResponse(BaseModel):
    status: str
    system_loaded: bool
    system_loading: bool
    system_error: Optional[str] = None
    model_info: Optional[Dict[str, Any]] = None
    timestamp: float


@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main HTML interface.

    Reads ``index.html`` relative to the process working directory; when it is
    missing, a short fallback page is returned instead of an error.
    """
    try:
        with open("index.html", "r", encoding="utf-8") as f:
            return HTMLResponse(content=f.read())
    except FileNotFoundError:
        fallback = (
            "\n\n"
            "Text-to-SQL RAG API with CodeLlama\n\n"
            "Advanced SQL generation using RAG and CodeLlama models\n\n"
            "index.html not found. Please ensure the file exists in the same directory.\n\n"
        )
        return HTMLResponse(content=fallback)


@app.get("/api", response_model=dict)
async def api_info():
    """API information endpoint"""
    return {
        "message": "Text-to-SQL RAG API with CodeLlama",
        "version": "2.0.0",
        "features": [
            "RAG-enhanced SQL generation",
            "CodeLlama as primary model",
            "Vector-based example retrieval",
            "Advanced prompt engineering",
        ],
        "endpoints": {
            "/": "GET - Web interface",
            "/api": "GET - API information",
            "/predict": "POST - Generate SQL from single question",
            "/batch": "POST - Generate SQL from multiple questions",
            "/health": "GET - Health check",
            "/docs": "GET - API documentation",
        },
    }


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the RAG system is loaded, plus model info when available.

    ``status`` is "healthy" only when the system dict exists and loading has
    finished. A failing ``get_model_info()`` is logged and reported as
    ``model_info=None`` rather than failing the health check.
    """
    model_info = None
    if rag_system and "sql_generator" in rag_system:
        try:
            model_info = rag_system["sql_generator"].get_model_info()
        except Exception as e:
            logger.warning(f"Could not get model info: {e}")

    return HealthResponse(
        status="healthy" if rag_system and not system_loading else "unhealthy",
        system_loaded=rag_system is not None,
        system_loading=system_loading,
        system_error=system_load_error,
        model_info=model_info,
        timestamp=time.time(),
    )


def _require_rag_system() -> Dict[str, Any]:
    """Return the initialized RAG system dict, or raise HTTP 503 explaining why not."""
    if system_loading:
        raise HTTPException(status_code=503, detail="System is still loading, please try again in a few minutes")
    if rag_system is None:
        error_msg = system_load_error or "RAG system not loaded"
        raise HTTPException(status_code=503, detail=f"System not available: {error_msg}")
    return rag_system


@app.post("/predict", response_model=SQLResponse)
async def predict_sql(request: SQLRequest):
    """
    Generate SQL query from a natural language question using RAG and CodeLlama

    Args:
        request: SQLRequest containing question and table headers

    Returns:
        SQLResponse with generated SQL query and metadata; ``processing_time``
        is measured here around the generator call.

    Raises:
        HTTPException: 503 if the RAG system is unavailable, 500 if generation fails.
    """
    system = _require_rag_system()

    start_time = time.perf_counter()
    try:
        # NOTE(review): generate_sql is called synchronously inside an async
        # endpoint, so it blocks the event loop (including /health) while it
        # runs; consider a thread pool if the generator is thread-safe.
        result = system["sql_generator"].generate_sql(
            question=request.question,
            table_headers=request.table_headers,
        )
        processing_time = time.perf_counter() - start_time

        return SQLResponse(
            question=request.question,
            table_headers=request.table_headers,
            sql_query=result["sql_query"],
            model_used=result["model_used"],
            processing_time=processing_time,
            retrieved_examples=result["retrieved_examples"],
            status=result["status"],
        )
    except Exception as e:
        logger.exception("Error generating SQL: %s", e)
        raise HTTPException(status_code=500, detail=f"Error generating SQL: {str(e)}") from e


@app.post("/batch", response_model=BatchResponse)
async def batch_predict(request: BatchRequest):
    """
    Generate SQL queries from multiple questions using RAG and CodeLlama

    Each query is handled independently: a failing query yields a response
    with ``status="error"``, an empty ``sql_query`` and ``processing_time=0.0``
    instead of failing the whole batch.

    Args:
        request: BatchRequest containing list of questions and table headers

    Returns:
        BatchResponse with one SQLResponse per query, in request order.

    Raises:
        HTTPException: 503 if the RAG system is unavailable, 500 on an
            unexpected batch-level failure.
    """
    system = _require_rag_system()

    try:
        sql_generator = system["sql_generator"]
        results = []
        successful_count = 0

        for query in request.queries:
            query_start = time.perf_counter()
            try:
                result = sql_generator.generate_sql(
                    question=query.question,
                    table_headers=query.table_headers,
                )
                # Timed locally, like /predict, rather than trusting an
                # optional "processing_time" key in the generator's result.
                sql_response = SQLResponse(
                    question=query.question,
                    table_headers=query.table_headers,
                    sql_query=result["sql_query"],
                    model_used=result["model_used"],
                    processing_time=time.perf_counter() - query_start,
                    retrieved_examples=result["retrieved_examples"],
                    status=result["status"],
                )
            except Exception as e:
                logger.exception("Error processing query '%s': %s", query.question, e)
                sql_response = SQLResponse(
                    question=query.question,
                    table_headers=query.table_headers,
                    sql_query="",
                    model_used="none",
                    processing_time=0.0,
                    retrieved_examples=[],
                    status="error",
                )

            results.append(sql_response)
            if sql_response.status == "success":
                successful_count += 1

        return BatchResponse(
            results=results,
            total_queries=len(request.queries),
            successful_queries=successful_count,
        )
    except Exception as e:
        logger.exception("Error in batch processing: %s", e)
        raise HTTPException(status_code=500, detail=f"Error in batch processing: {str(e)}") from e


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)