# NOTE(review): removed non-code artifacts that were captured with this file
# (web page chrome, commit hashes, and a line-number gutter from a scrape).
import json  # Added for json.load()
import os
import re
import sqlite3
import uuid
from typing import Any, Dict, Optional

import sqlparse
from fastapi import FastAPI, Header, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from loguru import logger
from pydantic import BaseModel

from app.database import close_session_db, create_session_db
from app.schemas import RunQueryRequest, ValidateQueryRequest
app = FastAPI()
app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")), name="static")
@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the failure server-side and return a JSON 500.

    The original returned the error to the client but never logged it, so
    unexpected failures left no server-side trace.

    NOTE(review): echoing str(exc) to the client can leak internal details
    (paths, SQL) — consider a generic message in production.
    """
    logger.error(f"Unhandled exception on {request.method} {request.url.path}: {exc}")
    return JSONResponse(status_code=500, content={"detail": str(exc)})
sessions: Dict[str, dict] = {}
BASE_DIR = os.path.dirname(__file__)
def load_questions(domain: str):
    """Load the practice questions for *domain* from questions/<domain>.json.

    Returns the parsed JSON document (a list of question dicts, per the
    callers in this file).

    Raises:
        FileNotFoundError: if the question file does not exist.
    """
    file_path = os.path.join(BASE_DIR, "questions", f"{domain}.json")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Question file not found: {file_path}")
    # Explicit encoding: the platform default may not be UTF-8 (e.g. Windows).
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)  # Replaced eval with json.load()
def load_schema_sql(domain: str):
    """Return the raw DDL script for *domain* from schemas/<domain>.sql.

    Raises:
        FileNotFoundError: if the schema file does not exist.
    """
    file_path = os.path.join(BASE_DIR, "schemas", f"{domain}.sql")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Schema file not found: {file_path}")
    # Explicit encoding, consistent with load_questions.
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()
# Keywords that must never appear in a user query. Matched on word boundaries,
# so identifiers that merely contain a keyword (e.g. a column "updated_at")
# are not falsely rejected.
_FORBIDDEN_KEYWORDS = re.compile(r"\b(drop|attach|detach|pragma|insert|update|delete)\b")

def is_safe_query(sql: str) -> bool:
    """Return True iff *sql* looks like a read-only SELECT statement.

    Fixes two defects in the original implementation:
    - keywords were matched as plain substrings, so safe queries such as
      ``SELECT updated_at FROM t`` were rejected ("update" in "updated_at");
    - an empty/whitespace-only query crashed with IndexError via
      ``sqlparse.parse("")[0]``; it now returns False.
    """
    normalized = sql.strip().lower()
    if not normalized.startswith("select"):
        return False
    return _FORBIDDEN_KEYWORDS.search(normalized) is None
def extract_tables(sql: str) -> list:
    """Best-effort extraction of table names referenced by *sql*.

    Operates on a naive whitespace tokenization of the lower-cased query and
    recognizes FROM/JOIN/INTO/USING, MERGE INTO, SELECT INTO, WITH ... AS (...),
    VALUES (...) AS alias, and EXISTS/IN subqueries. Heuristic only: it assumes
    parentheses and keywords arrive as standalone whitespace-separated tokens,
    which real-world SQL often violates — TODO confirm inputs are pre-spaced.

    NOTE(review): the ``i += 1`` / ``i += 2`` and while-loop advances of ``i``
    below have NO skipping effect on a ``for i, token in enumerate(...)`` loop —
    ``i`` is rebound at the top of every iteration. Tokens are therefore NOT
    actually skipped; left unchanged to preserve the current behavior.
    """
    tables = set()
    # Flatten newlines and split on whitespace; all matching is case-insensitive.
    tokens = sql.replace("\n", " ").lower().split()
    in_subquery = in_openquery = in_values = False
    for i, token in enumerate(tokens):
        # Opening paren: decide whether it starts a VALUES list or a subquery.
        if token == "(" and not in_subquery and not in_values:
            in_values = i > 0 and tokens[i - 1] == "values"
            in_subquery = not in_values
        # Closing paren: leave whichever region we were tracking.
        if token == ")" and (in_subquery or in_values):
            if in_values and i + 1 < len(tokens) and tokens[i + 1] == "as": in_values = False
            elif in_subquery: in_subquery = False
        # OPENQUERY(...) bodies are pass-through remote SQL; skip their contents.
        if token == "openquery" and i + 1 < len(tokens) and tokens[i + 1] == "(": in_openquery = True
        if token == ")" and in_openquery: in_openquery = False
        if in_openquery: continue
        # Keywords directly followed by a table name.
        if token in ["from", "join", "update", "delete", "insert", "into", "using", "apply", "pivot", "table"]:
            next_token = tokens[i + 1].replace(",);", "") if i + 1 < len(tokens) else ""
            if next_token and next_token not in ["select", "where", "on", "order", "group", "having", "as", "("]:
                # NOTE(review): this branch is a self-assignment no-op; presumably
                # meant to strip/record an alias — confirm the intent.
                if i + 2 < len(tokens) and tokens[i + 2] == "as": next_token = next_token
                elif next_token not in ["left", "right", "inner", "outer", "cross", "full"]: tables.add(next_token)
            i += 1
        # MERGE INTO <target> ... USING <source>.
        elif token == "merge" and i + 1 < len(tokens) and tokens[i + 1] == "into":
            next_token = tokens[i + 2].replace(",);", "") if i + 2 < len(tokens) else ""
            if next_token and next_token not in ["using", "select", "where"]: tables.add(next_token)
            i += 2
            while i + 1 < len(tokens) and tokens[i + 1] != "using": i += 1
            if i + 2 < len(tokens) and (next_token := tokens[i + 2].replace(",);", "")) and next_token not in ["select", "where"]: tables.add(next_token)
        # SELECT ... INTO <target> ... FROM <source>.
        elif token == "select" and i + 1 < len(tokens) and tokens[i + 1] == "into":
            next_token = tokens[i + 2].replace(",);", "") if i + 2 < len(tokens) else ""
            if next_token and next_token not in ["from", "select"]: tables.add(next_token)
            i += 2
            while i + 1 < len(tokens) and tokens[i + 1] != "from": i += 1
            if i + 2 < len(tokens) and (next_token := tokens[i + 2].replace(",);", "")) and next_token not in ["where", "join"]: tables.add(next_token)
        # WITH <cte> AS ( ... ): recurse into the CTE body.
        elif token == "with":
            while i + 1 < len(tokens) and tokens[i + 1] != "as": i += 1
            if i + 2 < len(tokens) and tokens[i + 2] == "(":
                bracket_count = 1
                subquery_start = i + 2
                i += 2
                # Walk to the matching close paren of the CTE body.
                while i < len(tokens) and bracket_count > 0:
                    if tokens[i] == "(": bracket_count += 1
                    elif tokens[i] == ")": bracket_count -= 1
                    i += 1
                if bracket_count == 0 and i > subquery_start:
                    subquery = " ".join(tokens[subquery_start:i - 1])
                    tables.update(t for t in extract_tables(subquery) if t not in tables)
        # VALUES (...) AS <alias>: record the alias as a derived table name.
        elif token == "values" and i + 1 < len(tokens) and tokens[i + 1] == "(":
            while i + 1 < len(tokens) and tokens[i + 1] != "as": i += 1
            if i + 2 < len(tokens) and (alias := tokens[i + 2].replace(",);", "")): tables.add(alias)
        # EXISTS (...) / IN (...): recurse into the subquery.
        elif token in ["exists", "in"]:
            subquery_start = i + 1
            while i + 1 < len(tokens) and tokens[i + 1] != ")": i += 1
            if i > subquery_start:
                subquery = " ".join(tokens[subquery_start:i + 1])
                tables.update(t for t in extract_tables(subquery) if t not in tables)
    return list(tables)
@app.post("/api/session")
async def create_session():
    """Allocate a new session backed by its own database connection.

    No schema is loaded yet; "domain" stays None until /api/load-schema runs.
    """
    new_id = str(uuid.uuid4())
    sessions[new_id] = {"conn": create_session_db(), "domain": None}
    return {"session_id": new_id}
@app.get("/api/databases")
async def get_databases():
    """List the available practice databases (one per questions/<name>.json).

    Returns {"databases": []} when the questions directory is missing.

    NOTE(review): this route was defined twice verbatim in the original file;
    the duplicate registration has been removed.
    """
    questions_dir = os.path.join(BASE_DIR, "questions")
    logger.debug(f"Checking databases in directory: {questions_dir}")
    if not os.path.exists(questions_dir):
        logger.error(f"Questions directory not found: {questions_dir}")
        return {"databases": []}
    # splitext strips only the final extension (str.replace would also mangle
    # a ".json" that appears mid-name).
    databases = [os.path.splitext(f)[0] for f in os.listdir(questions_dir) if f.endswith(".json")]
    logger.debug(f"Found databases: {databases}")
    return {"databases": databases}
@app.post("/api/load-schema/{domain}")
async def load_schema(domain: str, session_id: str = Header(...)):
    """(Re)initialize the session's database with the DDL for *domain*.

    Replaces the session's connection with a fresh one, runs the schema
    script, and records the loaded domain. On failure the session is torn
    down and a 500 is returned.

    Raises:
        HTTPException(401): unknown session_id.
        HTTPException(500): schema file missing or DDL failed to execute.
    """
    logger.debug(f"Loading schema for domain: {domain}, session_id: {session_id}")
    if session_id not in sessions:
        logger.error(f"Invalid session: {session_id}")
        raise HTTPException(status_code=401, detail="Invalid session")
    # Fix: the original overwrote the session dict without closing the previous
    # connection, leaking one sqlite connection per (re)load.
    close_session_db(sessions[session_id]["conn"])
    sessions[session_id] = {"conn": create_session_db(), "domain": domain}
    try:
        schema_sql = load_schema_sql(domain)
        logger.debug(f"Schema SQL loaded for {domain}")
        sessions[session_id]["conn"].executescript(schema_sql)
        sessions[session_id]["conn"].commit()
        logger.info(f"Schema loaded successfully for {domain}")
    except FileNotFoundError as e:
        logger.error(f"Schema file not found: {str(e)}")
        close_session_db(sessions[session_id]["conn"])
        del sessions[session_id]
        raise HTTPException(status_code=500, detail=str(e))
    except sqlite3.Error as e:
        logger.error(f"Database error: {str(e)}")
        close_session_db(sessions[session_id]["conn"])
        del sessions[session_id]
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
    return {"message": f"Database {domain} loaded"}
@app.get("/api/schema/{domain}")
async def get_schema(domain: str, session_id: str = Header(...)):
    """Describe every table in the session DB as {table: [{name, type}, ...]}.

    Requires that *domain* is the schema currently loaded for this session.
    """
    session = sessions.get(session_id)
    if session is None or session["domain"] != domain:
        raise HTTPException(status_code=401, detail="Invalid session or domain not loaded")
    conn = session["conn"]
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    table_names = [row["name"] for row in cursor.fetchall()]
    schema = {}
    for table in table_names:
        columns = conn.execute(f"PRAGMA table_info({table});")
        schema[table] = [{"name": col["name"], "type": col["type"]} for col in columns]
    return {"schema": schema}
@app.get("/api/sample-data/{domain}")
async def get_sample_data(domain: str, session_id: str = Header(...)):
    """Return up to 5 preview rows (plus column names) for every table.

    Requires that *domain* is the schema currently loaded for this session.
    """
    session = sessions.get(session_id)
    if session is None or session["domain"] != domain:
        raise HTTPException(status_code=401, detail="Invalid session or domain not loaded")
    conn = session["conn"]
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    sample_data = {}
    for name_row in cursor.fetchall():
        table = name_row["name"]
        # Two executes, as in the original: one for the column descriptions,
        # one for the row data.
        preview = conn.execute(f"SELECT * FROM {table} LIMIT 5")
        columns = [desc[0] for desc in preview.description]
        rows = [dict(r) for r in conn.execute(f"SELECT * FROM {table} LIMIT 5")]
        sample_data[table] = {"columns": columns, "rows": rows}
    return {"sample_data": sample_data}
@app.post("/api/run-query")
async def run_query(request: RunQueryRequest, session_id: str = Header(...)):
    """Execute a user-supplied SELECT against the session's database.

    Raises:
        HTTPException(401): invalid session or no database loaded.
        HTTPException(400): non-SELECT query, or SQL that sqlite rejects.
    """
    if session_id not in sessions or not sessions[session_id]["domain"]:
        raise HTTPException(status_code=401, detail="Invalid session or no database loaded")
    if not is_safe_query(request.query):
        raise HTTPException(status_code=400, detail="Only SELECT queries are allowed")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    try:
        cursor.execute(request.query)
    except sqlite3.Error as e:
        # Fix: a malformed query is a client error; the original let the raw
        # sqlite3 exception bubble to the catch-all handler as a 500.
        raise HTTPException(status_code=400, detail=f"SQL error: {e}")
    if cursor.description:
        columns = [desc[0] for desc in cursor.description]
        return {"columns": columns, "rows": [dict(zip(columns, row)) for row in cursor.fetchall()]}
    return {"message": "Query executed successfully (no results)"}
@app.get("/api/questions/{domain}")
async def get_questions(domain: str, difficulty: Optional[str] = None):
    """Return the question list for *domain*, optionally filtered by difficulty.

    Fix: the original annotated ``difficulty: str = None`` (implicit Optional,
    rejected by type checkers); the runtime default is unchanged.

    NOTE(review): the response includes expected_sql (the answer) — confirm the
    client is meant to receive it.
    """
    questions = load_questions(domain)
    if difficulty:
        # Case-insensitive match against the question's stored difficulty.
        questions = [q for q in questions if q["difficulty"].lower() == difficulty.lower()]
    return [
        {
            "id": q["id"],
            "title": q["title"],
            "difficulty": q["difficulty"],
            "description": q["description"],
            "hint": q["hint"],
            "expected_sql": q["expected_sql"],
        }
        for q in questions
    ]
@app.post("/api/validate")
async def validate_query(request: ValidateQueryRequest, session_id: str = Header(...)):
    """Run the user's query and the expected query and compare their results.

    Results are compared as lists of lower-cased string tuples, in order.

    Raises:
        HTTPException(401): invalid session or no database loaded.
        HTTPException(400): either query fails the SELECT-only guard.
    """
    if session_id not in sessions or not sessions[session_id]["domain"]:
        raise HTTPException(status_code=401, detail="Invalid session or no database loaded")
    # Security fix: the original executed both client-supplied queries with NO
    # safety check, letting this endpoint bypass the SELECT-only policy that
    # /api/run-query enforces.
    if not is_safe_query(request.user_query) or not is_safe_query(request.expected_query):
        raise HTTPException(status_code=400, detail="Only SELECT queries are allowed")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    cursor.execute(request.user_query)
    user_result = [tuple(str(x).lower() for x in row) for row in cursor.fetchall()] if cursor.description else []
    cursor.execute(request.expected_query)
    expected_result = [tuple(str(x).lower() for x in row) for row in cursor.fetchall()] if cursor.description else []
    # NOTE(review): comparison is order-sensitive; queries without ORDER BY may
    # mismatch non-deterministically — confirm this is intended.
    return {"valid": user_result == expected_result, "error": "Results do not match." if user_result != expected_result else ""}
@app.on_event("shutdown")
async def cleanup():
    """On shutdown, close every per-session database connection and drop it."""
    # Snapshot the ids so the dict can be mutated while we walk it.
    for sid in list(sessions):
        close_session_db(sessions[sid]["conn"])
        del sessions[sid]
@app.get("/", response_class=HTMLResponse)
async def serve_frontend():
    """Serve the single-page frontend from static/index.html.

    Raises:
        HTTPException(500): the frontend file is missing.
    """
    file_path = os.path.join(BASE_DIR, "static", "index.html")
    if not os.path.exists(file_path):
        raise HTTPException(status_code=500, detail=f"Frontend file not found: {file_path}")
    # Explicit encoding: HTML is authored as UTF-8; the platform default may differ.
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()