from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from typing import List, Optional
import uvicorn
import logging
import time
from contextlib import asynccontextmanager
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model instance, guarded by the loading/error flags below
model = None
model_loading = False
model_load_error = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    global model, model_loading, model_load_error
    logger.info("Starting Text-to-SQL API...")

    # Start model loading in the background
    model_loading = True
    model_load_error = None
    try:
        # Import here to avoid startup delays
        from model_utils import get_model

        # Set a timeout for model loading (5 minutes)
        try:
            # Run model loading in a thread to avoid blocking the event loop.
            # Note: the executor's context-manager exit waits for the worker
            # thread, so a timeout below still blocks until the load finishes.
            import concurrent.futures
            with concurrent.futures.ThreadPoolExecutor() as executor:
                future = executor.submit(get_model)
                model = future.result(timeout=300)  # 5 minute timeout
            logger.info("Model loaded successfully!")
        except concurrent.futures.TimeoutError:
            logger.error("Model loading timed out after 5 minutes")
            model_load_error = "Model loading timed out"
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            model_load_error = str(e)
    except Exception as e:
        logger.error(f"Failed to import model_utils: {str(e)}")
        model_load_error = f"Import error: {str(e)}"
    finally:
        model_loading = False

    yield

    # Shutdown
    logger.info("Shutting down Text-to-SQL API...")
# Create FastAPI app
app = FastAPI(
    title="Text-to-SQL API",
    description="API for converting natural language questions to SQL queries",
    version="1.0.0",
    lifespan=lifespan
)
# Pydantic models for request/response
class SQLRequest(BaseModel):
    question: str
    table_headers: List[str]

class SQLResponse(BaseModel):
    question: str
    table_headers: List[str]
    sql_query: str
    processing_time: float

class BatchRequest(BaseModel):
    queries: List[SQLRequest]

class BatchResponse(BaseModel):
    results: List[SQLResponse]
    total_queries: int
    successful_queries: int

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    model_loading: bool
    model_error: Optional[str] = None
    timestamp: float
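# Example SQLRequest body (the same payload is served by /example below):
#   {"question": "How many employees are older than 30?",
#    "table_headers": ["id", "name", "age", "department", "salary"]}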
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main HTML interface"""
    try:
        with open("index.html", "r", encoding="utf-8") as f:
            return HTMLResponse(content=f.read())
    except FileNotFoundError:
        return HTMLResponse(content="""
        <html>
            <body>
                <h1>Text-to-SQL API</h1>
                <p>index.html not found. Please ensure the file exists in the same directory.</p>
            </body>
        </html>
        """)
@app.get("/api")
async def api_info():
    """API information endpoint"""
    return {
        "message": "Text-to-SQL API",
        "version": "1.0.0",
        "endpoints": {
            "/": "GET - Web interface",
            "/api": "GET - API information",
            "/predict": "POST - Generate SQL from a single question",
            "/batch": "POST - Generate SQL from multiple questions",
            "/health": "GET - Health check",
            "/example": "GET - Example request/response",
            "/docs": "GET - API documentation"
        }
    }
@app.post("/predict", response_model=SQLResponse)
async def predict_sql(request: SQLRequest):
    """
    Generate a SQL query from a natural language question.

    Args:
        request: SQLRequest containing the question and table headers

    Returns:
        SQLResponse with the generated SQL query
    """
    global model, model_loading, model_load_error

    if model_loading:
        raise HTTPException(status_code=503, detail="Model is still loading, please try again in a few minutes")

    if model is None:
        error_msg = model_load_error or "Model not loaded"
        raise HTTPException(status_code=503, detail=f"Model not available: {error_msg}")

    start_time = time.time()
    try:
        sql_query = model.predict(request.question, request.table_headers)
        processing_time = time.time() - start_time
        return SQLResponse(
            question=request.question,
            table_headers=request.table_headers,
            sql_query=sql_query,
            processing_time=processing_time
        )
    except Exception as e:
        logger.error(f"Error generating SQL: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating SQL: {str(e)}")
@app.post("/batch", response_model=BatchResponse)
async def batch_predict(request: BatchRequest):
    """
    Generate SQL queries from multiple questions.

    Args:
        request: BatchRequest containing a list of questions and table headers

    Returns:
        BatchResponse with the generated SQL queries
    """
    global model, model_loading, model_load_error

    if model_loading:
        raise HTTPException(status_code=503, detail="Model is still loading, please try again in a few minutes")

    if model is None:
        error_msg = model_load_error or "Model not loaded"
        raise HTTPException(status_code=503, detail=f"Model not available: {error_msg}")

    start_time = time.time()
    try:
        # Convert to the format expected by the model
        queries = [
            {"question": q.question, "table_headers": q.table_headers}
            for q in request.queries
        ]

        # Get predictions
        results = model.batch_predict(queries)

        # Convert to response format
        sql_responses = []
        successful_count = 0
        for result in results:
            if result['status'] == 'success':
                successful_count += 1
                sql_responses.append(SQLResponse(
                    question=result['question'],
                    table_headers=result['table_headers'],
                    sql_query=result['sql'],
                    processing_time=time.time() - start_time
                ))
            else:
                # For failed queries, return the error in the SQL field
                sql_responses.append(SQLResponse(
                    question=result['question'],
                    table_headers=result['table_headers'],
                    sql_query=f"ERROR: {result.get('error', 'Unknown error')}",
                    processing_time=time.time() - start_time
                ))

        return BatchResponse(
            results=sql_responses,
            total_queries=len(request.queries),
            successful_queries=successful_count
        )
    except Exception as e:
        logger.error(f"Error in batch prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error in batch prediction: {str(e)}")
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.

    Returns:
        HealthResponse with service status
    """
    global model, model_loading, model_load_error

    model_loaded = model is not None and model.health_check()
    if model_loaded:
        status = "healthy"
    elif model_loading:
        status = "loading"
    else:
        status = "unhealthy"

    return HealthResponse(
        status=status,
        model_loaded=model_loaded,
        model_loading=model_loading,
        model_error=model_load_error,
        timestamp=time.time()
    )
@app.get("/example")
async def get_example():
    """Get example usage"""
    return {
        "example_request": {
            "question": "How many employees are older than 30?",
            "table_headers": ["id", "name", "age", "department", "salary"]
        },
        "example_response": {
            "question": "How many employees are older than 30?",
            "table_headers": ["id", "name", "age", "department", "salary"],
            "sql_query": "SELECT COUNT(*) FROM table WHERE age > 30",
            "processing_time": 0.5
        }
    }
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
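# Example invocation once the server is running (the payload matches the
# /example endpoint above):
#
#   curl -X POST http://localhost:8000/predict \
#     -H "Content-Type: application/json" \
#     -d '{"question": "How many employees are older than 30?",
#          "table_headers": ["id", "name", "age", "department", "salary"]}'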