from fastapi import FastAPI, HTTPException, Header, Request
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from loguru import logger
import sqlite3
import sqlparse
import os
import uuid
import json # Added for json.load()
from typing import Dict, Optional
from app.database import create_session_db, close_session_db
from app.schemas import RunQueryRequest, ValidateQueryRequest
app = FastAPI()
app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")), name="static")
@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
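    """Return any unhandled exception as a JSON 500 response with its message as the detail."""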
return JSONResponse(status_code=500, content={"detail": str(exc)})
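# In-memory session store: maps session_id -> {"conn": SQLite connection, "domain": currently loaded domain or None}.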
sessions: Dict[str, dict] = {}
BASE_DIR = os.path.dirname(__file__)
def load_questions(domain: str):
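    """Load the list of practice questions for a domain from app/questions/<domain>.json."""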
file_path = os.path.join(BASE_DIR, "questions", f"{domain}.json")
if not os.path.exists(file_path):
raise FileNotFoundError(f"Question file not found: {file_path}")
with open(file_path, "r") as f:
return json.load(f) # Replaced eval with json.load()
def load_schema_sql(domain: str):
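    """Read the SQL script that defines a domain's schema from app/schemas/<domain>.sql."""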
file_path = os.path.join(BASE_DIR, "schemas", f"{domain}.sql")
if not os.path.exists(file_path):
raise FileNotFoundError(f"Schema file not found: {file_path}")
with open(file_path, "r") as f:
return f.read()
def is_safe_query(sql: str) -> bool:
    """Allow only read-only queries: must start with SELECT and contain no blocked keyword (simple substring blocklist)."""
    normalized = str(sqlparse.parse(sql.lower())[0]).lower().strip()
    blocked = ("drop", "attach", "detach", "pragma", "insert", "update", "delete")
    return normalized.startswith("select") and all(kw not in normalized for kw in blocked)
def extract_tables(sql: str) -> list:
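    """Best-effort, token-based extraction of the table names referenced in a SQL statement.

    Covers FROM/JOIN sources, INSERT/UPDATE/MERGE/SELECT INTO targets, CTEs, VALUES aliases,
    and EXISTS/IN subqueries; it is a heuristic scan, not a full SQL parser.
    """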
tables = set()
tokens = sql.replace("\n", " ").lower().split()
in_subquery = in_openquery = in_values = False
for i, token in enumerate(tokens):
if token == "(" and not in_subquery and not in_values:
in_values = i > 0 and tokens[i - 1] == "values"
in_subquery = not in_values
if token == ")" and (in_subquery or in_values):
if in_values and i + 1 < len(tokens) and tokens[i + 1] == "as": in_values = False
elif in_subquery: in_subquery = False
if token == "openquery" and i + 1 < len(tokens) and tokens[i + 1] == "(": in_openquery = True
if token == ")" and in_openquery: in_openquery = False
if in_openquery: continue
if token in ["from", "join", "update", "delete", "insert", "into", "using", "apply", "pivot", "table"]:
            next_token = tokens[i + 1].strip(",();") if i + 1 < len(tokens) else ""
            if next_token and next_token not in ["select", "where", "on", "order", "group", "having", "as", "("]:
                # An aliased table ("FROM orders AS o") is still a real table; bare join modifiers are not.
                if i + 2 < len(tokens) and tokens[i + 2] == "as": tables.add(next_token)
                elif next_token not in ["left", "right", "inner", "outer", "cross", "full"]: tables.add(next_token)
i += 1
elif token == "merge" and i + 1 < len(tokens) and tokens[i + 1] == "into":
            next_token = tokens[i + 2].strip(",();") if i + 2 < len(tokens) else ""
if next_token and next_token not in ["using", "select", "where"]: tables.add(next_token)
i += 2
while i + 1 < len(tokens) and tokens[i + 1] != "using": i += 1
            if i + 2 < len(tokens) and (next_token := tokens[i + 2].strip(",();")) and next_token not in ["select", "where"]: tables.add(next_token)
elif token == "select" and i + 1 < len(tokens) and tokens[i + 1] == "into":
            next_token = tokens[i + 2].strip(",();") if i + 2 < len(tokens) else ""
if next_token and next_token not in ["from", "select"]: tables.add(next_token)
i += 2
while i + 1 < len(tokens) and tokens[i + 1] != "from": i += 1
            if i + 2 < len(tokens) and (next_token := tokens[i + 2].strip(",();")) and next_token not in ["where", "join"]: tables.add(next_token)
elif token == "with":
while i + 1 < len(tokens) and tokens[i + 1] != "as": i += 1
if i + 2 < len(tokens) and tokens[i + 2] == "(":
bracket_count = 1
subquery_start = i + 2
i += 2
while i < len(tokens) and bracket_count > 0:
if tokens[i] == "(": bracket_count += 1
elif tokens[i] == ")": bracket_count -= 1
i += 1
if bracket_count == 0 and i > subquery_start:
subquery = " ".join(tokens[subquery_start:i - 1])
tables.update(t for t in extract_tables(subquery) if t not in tables)
elif token == "values" and i + 1 < len(tokens) and tokens[i + 1] == "(":
while i + 1 < len(tokens) and tokens[i + 1] != "as": i += 1
            if i + 2 < len(tokens) and (alias := tokens[i + 2].strip(",();")): tables.add(alias)
elif token in ["exists", "in"]:
subquery_start = i + 1
while i + 1 < len(tokens) and tokens[i + 1] != ")": i += 1
if i > subquery_start:
subquery = " ".join(tokens[subquery_start:i + 1])
tables.update(t for t in extract_tables(subquery) if t not in tables)
return list(tables)
@app.post("/api/session")
async def create_session():
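    """Create a new practice session with its own database connection and return the session id."""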
session_id = str(uuid.uuid4())
sessions[session_id] = {"conn": create_session_db(), "domain": None}
return {"session_id": session_id}
@app.get("/api/databases")
async def get_databases():
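    """List the available practice databases, derived from the JSON files in app/questions/."""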
questions_dir = os.path.join(BASE_DIR, "questions")
logger.debug(f"Checking databases in directory: {questions_dir}")
if not os.path.exists(questions_dir):
logger.error(f"Questions directory not found: {questions_dir}")
return {"databases": []}
databases = [f.replace(".json", "") for f in os.listdir(questions_dir) if f.endswith(".json")]
logger.debug(f"Found databases: {databases}")
return {"databases": databases}
@app.get("/api/databases")
async def get_databases():
questions_dir = os.path.join(BASE_DIR, "questions")
logger.debug(f"Checking databases in directory: {questions_dir}")
if not os.path.exists(questions_dir):
logger.error(f"Questions directory not found: {questions_dir}")
return {"databases": []}
databases = [f.replace(".json", "") for f in os.listdir(questions_dir) if f.endswith(".json")]
logger.debug(f"Found databases: {databases}")
return {"databases": databases}
@app.post("/api/load-schema/{domain}")
async def load_schema(domain: str, session_id: str = Header(...)):
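    """(Re)create the session's database and load the schema for the selected domain into it."""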
logger.debug(f"Loading schema for domain: {domain}, session_id: {session_id}")
if session_id not in sessions:
logger.error(f"Invalid session: {session_id}")
raise HTTPException(status_code=401, detail="Invalid session")
sessions[session_id] = {"conn": create_session_db(), "domain": domain}
try:
schema_sql = load_schema_sql(domain)
logger.debug(f"Schema SQL loaded for {domain}")
sessions[session_id]["conn"].executescript(schema_sql)
sessions[session_id]["conn"].commit()
logger.info(f"Schema loaded successfully for {domain}")
except FileNotFoundError as e:
logger.error(f"Schema file not found: {str(e)}")
close_session_db(sessions[session_id]["conn"])
del sessions[session_id]
raise HTTPException(status_code=500, detail=str(e))
except sqlite3.Error as e:
logger.error(f"Database error: {str(e)}")
close_session_db(sessions[session_id]["conn"])
del sessions[session_id]
raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
return {"message": f"Database {domain} loaded"}
@app.get("/api/schema/{domain}")
async def get_schema(domain: str, session_id: str = Header(...)):
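    """Return the column names and types of every table in the currently loaded domain."""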
if session_id not in sessions or sessions[session_id]["domain"] != domain: raise HTTPException(status_code=401, detail="Invalid session or domain not loaded")
conn = sessions[session_id]["conn"]
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
return {"schema": {table: [{"name": row["name"], "type": row["type"]} for row in conn.execute(f"PRAGMA table_info({table});")] for table in [row["name"] for row in cursor.fetchall()]}}
@app.get("/api/sample-data/{domain}")
async def get_sample_data(domain: str, session_id: str = Header(...)):
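    """Return the columns and up to five sample rows for every table in the loaded domain."""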
if session_id not in sessions or sessions[session_id]["domain"] != domain: raise HTTPException(status_code=401, detail="Invalid session or domain not loaded")
conn = sessions[session_id]["conn"]
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
return {"sample_data": {table: {"columns": [desc[0] for desc in conn.execute(f"SELECT * FROM {table} LIMIT 5").description], "rows": [dict(row) for row in conn.execute(f"SELECT * FROM {table} LIMIT 5")]} for table in [row["name"] for row in cursor.fetchall()]}}
@app.post("/api/run-query")
async def run_query(request: RunQueryRequest, session_id: str = Header(...)):
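    """Execute a read-only user query against the session database and return its result set."""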
if session_id not in sessions or not sessions[session_id]["domain"]: raise HTTPException(status_code=401, detail="Invalid session or no database loaded")
if not is_safe_query(request.query): raise HTTPException(status_code=400, detail="Only SELECT queries are allowed")
conn = sessions[session_id]["conn"]
cursor = conn.cursor()
    try:
        cursor.execute(request.query)
    except sqlite3.Error as e:
        raise HTTPException(status_code=400, detail=f"SQL error: {str(e)}")
if cursor.description:
columns = [desc[0] for desc in cursor.description]
return {"columns": columns, "rows": [dict(zip(columns, row)) for row in cursor.fetchall()]}
return {"message": "Query executed successfully (no results)"}
@app.get("/api/questions/{domain}")
async def get_questions(domain: str, difficulty: Optional[str] = None):
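    """Return the questions for a domain, optionally filtered by difficulty."""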
questions = load_questions(domain)
if difficulty: questions = [q for q in questions if q["difficulty"].lower() == difficulty.lower()]
return [{"id": q["id"], "title": q["title"], "difficulty": q["difficulty"], "description": q["description"], "hint": q["hint"], "expected_sql": q["expected_sql"]} for q in questions]
@app.post("/api/validate")
async def validate_query(request: ValidateQueryRequest, session_id: str = Header(...)):
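    """Validate a user's answer by comparing its result rows against those of the expected query."""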
if session_id not in sessions or not sessions[session_id]["domain"]: raise HTTPException(status_code=401, detail="Invalid session or no database loaded")
conn = sessions[session_id]["conn"]
cursor = conn.cursor()
cursor.execute(request.user_query)
user_result = [tuple(str(x).lower() for x in row) for row in cursor.fetchall()] if cursor.description else []
cursor.execute(request.expected_query)
expected_result = [tuple(str(x).lower() for x in row) for row in cursor.fetchall()] if cursor.description else []
return {"valid": user_result == expected_result, "error": "Results do not match." if user_result != expected_result else ""}
@app.on_event("shutdown")
async def cleanup():
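    """Close every open session database connection on application shutdown."""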
    for session_id in list(sessions):
        close_session_db(sessions[session_id]["conn"])
        del sessions[session_id]
@app.get("/", response_class=HTMLResponse)
async def serve_frontend():
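    """Serve the single-page frontend from app/static/index.html."""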
file_path = os.path.join(BASE_DIR, "static", "index.html")
if not os.path.exists(file_path): raise HTTPException(status_code=500, detail=f"Frontend file not found: {file_path}")
with open(file_path, "r") as f: return f.read()
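
# Optional local entry point: a minimal sketch that assumes uvicorn is installed and
# uses port 7860 (the conventional Hugging Face Spaces port); adjust host/port as needed.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)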