Spaces:
Sleeping
Sleeping
Krishna Prakash
Merge branches 'main' and 'main' of https://huggingface.co/spaces/Krishna086/SQL_Practice_Platform
14db70f
# Standard library
import json  # Added for json.load()
import os
import sqlite3
import uuid
from typing import Any, Dict

# Third-party
import sqlparse
from fastapi import FastAPI, Header, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from loguru import logger
from pydantic import BaseModel

# Local application
from app.database import close_session_db, create_session_db
from app.schemas import RunQueryRequest, ValidateQueryRequest
app = FastAPI()

# Serve bundled front-end assets from the ./static directory next to this module.
_static_dir = os.path.join(os.path.dirname(__file__), "static")
app.mount("/static", StaticFiles(directory=_static_dir), name="static")
async def custom_exception_handler(request: Request, exc: Exception):
    """Map any uncaught exception to a JSON 500 response carrying its message."""
    payload = {"detail": str(exc)}
    return JSONResponse(status_code=500, content=payload)
# In-memory session registry: session_id -> {"conn": sqlite3 connection, "domain": str or None}.
sessions: Dict[str, dict] = {}
# Directory containing this module; questions/, schemas/ and static/ live beneath it.
BASE_DIR = os.path.dirname(__file__)
def load_questions(domain: str):
    """Load the question list for *domain* from questions/<domain>.json.

    Returns the parsed JSON payload.
    Raises FileNotFoundError (with the resolved path) if no question file exists.
    """
    file_path = os.path.join(BASE_DIR, "questions", f"{domain}.json")
    try:
        # EAFP: opening directly avoids the exists()/open() race of the original.
        with open(file_path, "r") as f:
            return json.load(f)  # json.load (not eval) — safe for untrusted files
    except FileNotFoundError:
        raise FileNotFoundError(f"Question file not found: {file_path}") from None
def load_schema_sql(domain: str):
    """Read the raw SQL schema for *domain* from schemas/<domain>.sql.

    Returns the file contents as a string.
    Raises FileNotFoundError (with the resolved path) if no schema file exists.
    """
    file_path = os.path.join(BASE_DIR, "schemas", f"{domain}.sql")
    try:
        # EAFP: opening directly avoids the exists()/open() race of the original.
        with open(file_path, "r") as f:
            return f.read()
    except FileNotFoundError:
        raise FileNotFoundError(f"Schema file not found: {file_path}") from None
def is_safe_query(sql: str) -> bool:
    """Return True only for SELECT statements free of mutating/dangerous keywords.

    Fixes two defects in the original sqlparse-based version:
    - empty input crashed with IndexError (`sqlparse.parse("")[0]`);
    - only the FIRST parsed statement was scanned, so banned keywords in a
      second statement (e.g. "select 1; drop table t") slipped through.
    The substring keyword check is deliberately kept (it can reject column
    names like "updated_at", matching the original's conservative behavior).
    """
    normalized = sql.strip().lower()
    if not normalized.startswith("select"):
        return False
    banned = ("drop", "attach", "detach", "pragma", "insert", "update", "delete")
    return not any(keyword in normalized for keyword in banned)
def extract_tables(sql: str) -> list:
    """Best-effort extraction of table names referenced by *sql*.

    Operates on a naive whitespace tokenization of the lowered query, so only
    brackets/keywords that stand alone as tokens are recognized (e.g.
    "(select" is one token and never matches ``token == "("``).
    """
    tables = set()
    # Lowercase and split on whitespace; punctuation stays attached to tokens.
    tokens = sql.replace("\n", " ").lower().split()
    in_subquery = in_openquery = in_values = False
    # NOTE(review): `i` is rebound by enumerate on each iteration, so the
    # `i += 1` / `i += 2` / while-loop advances below do NOT skip tokens in
    # this outer loop — presumably they were meant to consume tokens; confirm.
    for i, token in enumerate(tokens):
        # Track entry into a parenthesized subquery vs. a VALUES (...) list.
        if token == "(" and not in_subquery and not in_values:
            in_values = i > 0 and tokens[i - 1] == "values"
            in_subquery = not in_values
        if token == ")" and (in_subquery or in_values):
            if in_values and i + 1 < len(tokens) and tokens[i + 1] == "as": in_values = False
            elif in_subquery: in_subquery = False
        # Skip everything inside a T-SQL OPENQUERY(...) call.
        if token == "openquery" and i + 1 < len(tokens) and tokens[i + 1] == "(": in_openquery = True
        if token == ")" and in_openquery: in_openquery = False
        if in_openquery: continue
        # Keywords normally followed directly by a table name.
        if token in ["from", "join", "update", "delete", "insert", "into", "using", "apply", "pivot", "table"]:
            next_token = tokens[i + 1].replace(",);", "") if i + 1 < len(tokens) else ""
            if next_token and next_token not in ["select", "where", "on", "order", "group", "having", "as", "("]:
                # NOTE(review): the "as" branch is a no-op (self-assignment) —
                # it only prevents the elif from adding an aliased name; confirm intent.
                if i + 2 < len(tokens) and tokens[i + 2] == "as": next_token = next_token
                elif next_token not in ["left", "right", "inner", "outer", "cross", "full"]: tables.add(next_token)
            i += 1
        # MERGE INTO <target> ... USING <source>: capture both names.
        elif token == "merge" and i + 1 < len(tokens) and tokens[i + 1] == "into":
            next_token = tokens[i + 2].replace(",);", "") if i + 2 < len(tokens) else ""
            if next_token and next_token not in ["using", "select", "where"]: tables.add(next_token)
            i += 2
            while i + 1 < len(tokens) and tokens[i + 1] != "using": i += 1
            if i + 2 < len(tokens) and (next_token := tokens[i + 2].replace(",);", "")) and next_token not in ["select", "where"]: tables.add(next_token)
        # SELECT ... INTO <new_table> FROM <source>: capture both names.
        elif token == "select" and i + 1 < len(tokens) and tokens[i + 1] == "into":
            next_token = tokens[i + 2].replace(",);", "") if i + 2 < len(tokens) else ""
            if next_token and next_token not in ["from", "select"]: tables.add(next_token)
            i += 2
            while i + 1 < len(tokens) and tokens[i + 1] != "from": i += 1
            if i + 2 < len(tokens) and (next_token := tokens[i + 2].replace(",);", "")) and next_token not in ["where", "join"]: tables.add(next_token)
        # WITH <cte> AS ( ... ): recurse into the CTE body.
        elif token == "with":
            while i + 1 < len(tokens) and tokens[i + 1] != "as": i += 1
            if i + 2 < len(tokens) and tokens[i + 2] == "(":
                bracket_count = 1
                subquery_start = i + 2
                i += 2
                # Walk to the matching close paren (tokens must stand alone).
                while i < len(tokens) and bracket_count > 0:
                    if tokens[i] == "(": bracket_count += 1
                    elif tokens[i] == ")": bracket_count -= 1
                    i += 1
                if bracket_count == 0 and i > subquery_start:
                    subquery = " ".join(tokens[subquery_start:i - 1])
                    tables.update(t for t in extract_tables(subquery) if t not in tables)
        # VALUES ( ... ) AS <alias>: record the alias name.
        elif token == "values" and i + 1 < len(tokens) and tokens[i + 1] == "(":
            while i + 1 < len(tokens) and tokens[i + 1] != "as": i += 1
            if i + 2 < len(tokens) and (alias := tokens[i + 2].replace(",);", "")): tables.add(alias)
        # EXISTS (...) / IN (...): recurse into the enclosed subquery.
        elif token in ["exists", "in"]:
            subquery_start = i + 1
            while i + 1 < len(tokens) and tokens[i + 1] != ")": i += 1
            if i > subquery_start:
                subquery = " ".join(tokens[subquery_start:i + 1])
                tables.update(t for t in extract_tables(subquery) if t not in tables)
    return list(tables)
async def create_session():
    """Allocate a fresh session backed by its own in-memory database."""
    new_id = str(uuid.uuid4())
    sessions[new_id] = {"conn": create_session_db(), "domain": None}
    return {"session_id": new_id}
async def get_databases():
    """Enumerate available question domains (one JSON file per domain)."""
    questions_dir = os.path.join(BASE_DIR, "questions")
    logger.debug(f"Checking databases in directory: {questions_dir}")
    if os.path.exists(questions_dir):
        names = [entry.replace(".json", "") for entry in os.listdir(questions_dir) if entry.endswith(".json")]
        logger.debug(f"Found databases: {names}")
        return {"databases": names}
    logger.error(f"Questions directory not found: {questions_dir}")
    return {"databases": []}
# NOTE(review): exact duplicate of get_databases() defined immediately above;
# at import time this second definition shadows the first. Likely a merge
# artifact ("Merge branches 'main' and 'main'") — confirm and remove one copy.
async def get_databases():
    """List available question domains by scanning the questions/ directory."""
    questions_dir = os.path.join(BASE_DIR, "questions")
    logger.debug(f"Checking databases in directory: {questions_dir}")
    if not os.path.exists(questions_dir):
        logger.error(f"Questions directory not found: {questions_dir}")
        return {"databases": []}
    databases = [f.replace(".json", "") for f in os.listdir(questions_dir) if f.endswith(".json")]
    logger.debug(f"Found databases: {databases}")
    return {"databases": databases}
async def load_schema(domain: str, session_id: str = Header(...)):
    """(Re)initialize the session's in-memory DB with the schema for *domain*.

    Raises HTTPException 401 for unknown sessions, and 500 for a missing
    schema file or a SQLite error (tearing the broken session down first).
    """
    logger.debug(f"Loading schema for domain: {domain}, session_id: {session_id}")
    if session_id not in sessions:
        logger.error(f"Invalid session: {session_id}")
        raise HTTPException(status_code=401, detail="Invalid session")
    # Close the session's previous connection before replacing it; the original
    # overwrote the entry and leaked the old in-memory database connection.
    close_session_db(sessions[session_id]["conn"])
    sessions[session_id] = {"conn": create_session_db(), "domain": domain}
    try:
        schema_sql = load_schema_sql(domain)
        logger.debug(f"Schema SQL loaded for {domain}")
        sessions[session_id]["conn"].executescript(schema_sql)
        sessions[session_id]["conn"].commit()
        logger.info(f"Schema loaded successfully for {domain}")
    except FileNotFoundError as e:
        logger.error(f"Schema file not found: {str(e)}")
        close_session_db(sessions[session_id]["conn"])
        del sessions[session_id]
        raise HTTPException(status_code=500, detail=str(e))
    except sqlite3.Error as e:
        logger.error(f"Database error: {str(e)}")
        close_session_db(sessions[session_id]["conn"])
        del sessions[session_id]
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
    return {"message": f"Database {domain} loaded"}
async def get_schema(domain: str, session_id: str = Header(...)):
    """Return column name/type metadata for every table in the loaded database."""
    if session_id not in sessions or sessions[session_id]["domain"] != domain:
        raise HTTPException(status_code=401, detail="Invalid session or domain not loaded")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    schema = {}
    for table_row in cursor.fetchall():
        table = table_row["name"]
        # PRAGMA table_info yields one row per column with its name and type.
        columns = [{"name": col["name"], "type": col["type"]} for col in conn.execute(f"PRAGMA table_info({table});")]
        schema[table] = columns
    return {"schema": schema}
async def get_sample_data(domain: str, session_id: str = Header(...)):
    """Return up to five preview rows (plus column names) for each table."""
    if session_id not in sessions or sessions[session_id]["domain"] != domain:
        raise HTTPException(status_code=401, detail="Invalid session or domain not loaded")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    sample_data = {}
    for table_row in cursor.fetchall():
        table = table_row["name"]
        # One execute for the column names, a second for the preview rows
        # (mirrors the original's double execution).
        columns = [desc[0] for desc in conn.execute(f"SELECT * FROM {table} LIMIT 5").description]
        rows = [dict(record) for record in conn.execute(f"SELECT * FROM {table} LIMIT 5")]
        sample_data[table] = {"columns": columns, "rows": rows}
    return {"sample_data": sample_data}
async def run_query(request: RunQueryRequest, session_id: str = Header(...)):
    """Execute a read-only SELECT in the session database and return its rows.

    Raises HTTPException 401 without a loaded session DB, 400 for
    non-SELECT/unsafe SQL, and 400 (instead of an opaque 500) when SQLite
    rejects the statement.
    """
    if session_id not in sessions or not sessions[session_id]["domain"]:
        raise HTTPException(status_code=401, detail="Invalid session or no database loaded")
    if not is_safe_query(request.query):
        raise HTTPException(status_code=400, detail="Only SELECT queries are allowed")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    try:
        cursor.execute(request.query)
    except sqlite3.Error as e:
        # Surface the user's SQL mistake as a client error with the message.
        raise HTTPException(status_code=400, detail=f"SQL error: {str(e)}")
    if cursor.description:
        columns = [desc[0] for desc in cursor.description]
        return {"columns": columns, "rows": [dict(zip(columns, row)) for row in cursor.fetchall()]}
    return {"message": "Query executed successfully (no results)"}
async def get_questions(domain: str, difficulty: str = None):
    """Return questions for *domain*, optionally filtered by difficulty (case-insensitive)."""
    questions = load_questions(domain)
    if difficulty:
        wanted = difficulty.lower()
        questions = [q for q in questions if q["difficulty"].lower() == wanted]
    # Project only the client-facing fields of each question.
    fields = ("id", "title", "difficulty", "description", "hint", "expected_sql")
    return [{key: q[key] for key in fields} for q in questions]
async def validate_query(request: ValidateQueryRequest, session_id: str = Header(...)):
    """Check the user's query results against the expected query's results.

    Every cell is stringified and lowercased before the (order-sensitive)
    comparison, so case differences in values are ignored.
    """
    if session_id not in sessions or not sessions[session_id]["domain"]:
        raise HTTPException(status_code=401, detail="Invalid session or no database loaded")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()

    def _normalized(sql: str) -> list:
        # Execute the statement and lowercase each cell for comparison.
        cursor.execute(sql)
        if not cursor.description:
            return []
        return [tuple(str(cell).lower() for cell in row) for row in cursor.fetchall()]

    user_result = _normalized(request.user_query)
    expected_result = _normalized(request.expected_query)
    matched = user_result == expected_result
    return {"valid": matched, "error": "" if matched else "Results do not match."}
async def cleanup():
    """Close every session's database connection and empty the registry."""
    while sessions:
        _sid, state = sessions.popitem()
        close_session_db(state["conn"])
async def serve_frontend():
    """Serve the single-page front-end (static/index.html) as raw HTML text."""
    index_path = os.path.join(BASE_DIR, "static", "index.html")
    if not os.path.exists(index_path):
        raise HTTPException(status_code=500, detail=f"Frontend file not found: {index_path}")
    with open(index_path, "r") as page:
        return page.read()