from fastapi import FastAPI, HTTPException, Header, Request
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from loguru import logger
import sqlite3
import sqlparse
import os
import uuid
import json
from typing import Dict, Any, Optional

from app.database import create_session_db, close_session_db
from app.schemas import RunQueryRequest, ValidateQueryRequest

app = FastAPI()
app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")), name="static")


@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
    return JSONResponse(status_code=500, content={"detail": str(exc)})


# In-memory session store: one SQLite connection and one loaded domain per session.
sessions: Dict[str, dict] = {}
BASE_DIR = os.path.dirname(__file__)


def load_questions(domain: str):
    file_path = os.path.join(BASE_DIR, "questions", f"{domain}.json")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Question file not found: {file_path}")
    with open(file_path, "r") as f:
        return json.load(f)


def load_schema_sql(domain: str):
    file_path = os.path.join(BASE_DIR, "schemas", f"{domain}.sql")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Schema file not found: {file_path}")
    with open(file_path, "r") as f:
        return f.read()


def is_safe_query(sql: str) -> bool:
    """Allow only plain SELECT statements; reject queries containing write/DDL keywords."""
    parsed = str(sqlparse.parse(sql.lower())[0]).lower()
    blocked = ["drop", "attach", "detach", "pragma", "insert", "update", "delete"]
    return parsed.strip().startswith("select") and all(kw not in parsed for kw in blocked)


def extract_tables(sql: str) -> list:
    """Best-effort extraction of table names from a SQL string using a whitespace tokenizer."""
    tables = set()
    tokens = sql.replace("\n", " ").lower().split()
    in_subquery = in_openquery = in_values = False
    # Manual index so the keyword handlers below can skip over the tokens they consume.
    i = 0
    while i < len(tokens):
        token = tokens[i]
        if token == "(" and not in_subquery and not in_values:
            in_values = i > 0 and tokens[i - 1] == "values"
            in_subquery = not in_values
        if token == ")" and (in_subquery or in_values):
            if in_values and i + 1 < len(tokens) and tokens[i + 1] == "as":
                in_values = False
            elif in_subquery:
                in_subquery = False
        if token == "openquery" and i + 1 < len(tokens) and tokens[i + 1] == "(":
            in_openquery = True
        if token == ")" and in_openquery:
            in_openquery = False
        if in_openquery:
            i += 1
            continue
        if token in ["from", "join", "update", "delete", "insert", "into", "using", "apply", "pivot", "table"]:
            next_token = tokens[i + 1].replace(",);", "") if i + 1 < len(tokens) else ""
            if next_token and next_token not in ["select", "where", "on", "order", "group", "having", "as", "("]:
                if i + 2 < len(tokens) and tokens[i + 2] == "as":
                    tables.add(next_token)  # aliased table, e.g. "from users as u"
                elif next_token not in ["left", "right", "inner", "outer", "cross", "full"]:
                    tables.add(next_token)
            i += 1  # skip the table-name token
        elif token == "merge" and i + 1 < len(tokens) and tokens[i + 1] == "into":
            next_token = tokens[i + 2].replace(",);", "") if i + 2 < len(tokens) else ""
            if next_token and next_token not in ["using", "select", "where"]:
                tables.add(next_token)
            i += 2
            while i + 1 < len(tokens) and tokens[i + 1] != "using":
                i += 1
            if i + 2 < len(tokens) and (next_token := tokens[i + 2].replace(",);", "")) and next_token not in ["select", "where"]:
                tables.add(next_token)
        elif token == "select" and i + 1 < len(tokens) and tokens[i + 1] == "into":
            next_token = tokens[i + 2].replace(",);", "") if i + 2 < len(tokens) else ""
            if next_token and next_token not in ["from", "select"]:
                tables.add(next_token)
            i += 2
            while i + 1 < len(tokens) and tokens[i + 1] != "from":
                i += 1
            if i + 2 < len(tokens) and (next_token := tokens[i + 2].replace(",);", "")) and next_token not in ["where", "join"]:
                tables.add(next_token)
        elif token == "with":
            # CTE: scan ahead to the parenthesised body after "as" and recurse into it.
            while i + 1 < len(tokens) and tokens[i + 1] != "as":
                i += 1
            if i + 2 < len(tokens) and tokens[i + 2] == "(":
                bracket_count = 1
                subquery_start = i + 2
                i += 2
                while i < len(tokens) and bracket_count > 0:
                    if tokens[i] == "(":
                        bracket_count += 1
                    elif tokens[i] == ")":
                        bracket_count -= 1
                    i += 1
                if bracket_count == 0 and i > subquery_start:
                    subquery = " ".join(tokens[subquery_start:i - 1])
                    tables.update(t for t in extract_tables(subquery) if t not in tables)
        elif token == "values" and i + 1 < len(tokens) and tokens[i + 1] == "(":
            while i + 1 < len(tokens) and tokens[i + 1] != "as":
                i += 1
            if i + 2 < len(tokens) and (alias := tokens[i + 2].replace(",);", "")):
                tables.add(alias)
        elif token in ["exists", "in"]:
            subquery_start = i + 1
            while i + 1 < len(tokens) and tokens[i + 1] != ")":
                i += 1
            if i > subquery_start:
                subquery = " ".join(tokens[subquery_start:i + 1])
                tables.update(t for t in extract_tables(subquery) if t not in tables)
        i += 1
    return list(tables)


@app.post("/api/session")
async def create_session():
    session_id = str(uuid.uuid4())
    sessions[session_id] = {"conn": create_session_db(), "domain": None}
    return {"session_id": session_id}


@app.get("/api/databases")
async def get_databases():
    questions_dir = os.path.join(BASE_DIR, "questions")
    logger.debug(f"Checking databases in directory: {questions_dir}")
    if not os.path.exists(questions_dir):
        logger.error(f"Questions directory not found: {questions_dir}")
        return {"databases": []}
    databases = [f.replace(".json", "") for f in os.listdir(questions_dir) if f.endswith(".json")]
    logger.debug(f"Found databases: {databases}")
    return {"databases": databases}


@app.post("/api/load-schema/{domain}")
async def load_schema(domain: str, session_id: str = Header(...)):
    logger.debug(f"Loading schema for domain: {domain}, session_id: {session_id}")
    if session_id not in sessions:
        logger.error(f"Invalid session: {session_id}")
        raise HTTPException(status_code=401, detail="Invalid session")
    # Close the session's previous connection before replacing it with a fresh database.
    close_session_db(sessions[session_id]["conn"])
    sessions[session_id] = {"conn": create_session_db(), "domain": domain}
    try:
        schema_sql = load_schema_sql(domain)
        logger.debug(f"Schema SQL loaded for {domain}")
        sessions[session_id]["conn"].executescript(schema_sql)
        sessions[session_id]["conn"].commit()
        logger.info(f"Schema loaded successfully for {domain}")
    except FileNotFoundError as e:
        logger.error(f"Schema file not found: {str(e)}")
        close_session_db(sessions[session_id]["conn"])
        del sessions[session_id]
        raise HTTPException(status_code=500, detail=str(e))
    except sqlite3.Error as e:
        logger.error(f"Database error: {str(e)}")
        close_session_db(sessions[session_id]["conn"])
        del sessions[session_id]
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
    return {"message": f"Database {domain} loaded"}


@app.get("/api/schema/{domain}")
async def get_schema(domain: str, session_id: str = Header(...)):
    if session_id not in sessions or sessions[session_id]["domain"] != domain:
        raise HTTPException(status_code=401, detail="Invalid session or domain not loaded")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [row["name"] for row in cursor.fetchall()]
    schema = {
        table: [
            {"name": row["name"], "type": row["type"]}
            for row in conn.execute(f"PRAGMA table_info({table});")
        ]
        for table in tables
    }
    return {"schema": schema}


@app.get("/api/sample-data/{domain}")
async def get_sample_data(domain: str, session_id: str = Header(...)):
    if session_id not in sessions or sessions[session_id]["domain"] != domain:
        raise HTTPException(status_code=401, detail="Invalid session or domain not loaded")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [row["name"] for row in cursor.fetchall()]
    sample_data = {}
    for table in tables:
        result = conn.execute(f"SELECT * FROM {table} LIMIT 5")
        columns = [desc[0] for desc in result.description]
        sample_data[table] = {"columns": columns, "rows": [dict(row) for row in result.fetchall()]}
    return {"sample_data": sample_data}


@app.post("/api/run-query")
async def run_query(request: RunQueryRequest, session_id: str = Header(...)):
    if session_id not in sessions or not sessions[session_id]["domain"]:
        raise HTTPException(status_code=401, detail="Invalid session or no database loaded")
    if not is_safe_query(request.query):
        raise HTTPException(status_code=400, detail="Only SELECT queries are allowed")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    cursor.execute(request.query)
    if cursor.description:
        columns = [desc[0] for desc in cursor.description]
        return {"columns": columns, "rows": [dict(zip(columns, row)) for row in cursor.fetchall()]}
    return {"message": "Query executed successfully (no results)"}


@app.get("/api/questions/{domain}")
async def get_questions(domain: str, difficulty: Optional[str] = None):
    questions = load_questions(domain)
    if difficulty:
        questions = [q for q in questions if q["difficulty"].lower() == difficulty.lower()]
    return [
        {
            "id": q["id"],
            "title": q["title"],
            "difficulty": q["difficulty"],
            "description": q["description"],
            "hint": q["hint"],
            "expected_sql": q["expected_sql"],
        }
        for q in questions
    ]


@app.post("/api/validate")
async def validate_query(request: ValidateQueryRequest, session_id: str = Header(...)):
    if session_id not in sessions or not sessions[session_id]["domain"]:
        raise HTTPException(status_code=401, detail="Invalid session or no database loaded")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    # Compare the user's result set with the expected one, case-insensitively.
    cursor.execute(request.user_query)
    user_result = [tuple(str(x).lower() for x in row) for row in cursor.fetchall()] if cursor.description else []
    cursor.execute(request.expected_query)
    expected_result = [tuple(str(x).lower() for x in row) for row in cursor.fetchall()] if cursor.description else []
    return {"valid": user_result == expected_result, "error": "Results do not match." if user_result != expected_result else ""}


@app.on_event("shutdown")
async def cleanup():
    for session_id in list(sessions):
        close_session_db(sessions[session_id]["conn"])
        del sessions[session_id]


@app.get("/", response_class=HTMLResponse)
async def serve_frontend():
    file_path = os.path.join(BASE_DIR, "static", "index.html")
    if not os.path.exists(file_path):
        raise HTTPException(status_code=500, detail=f"Frontend file not found: {file_path}")
    with open(file_path, "r") as f:
        return f.read()
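
# Example client flow (illustrative sketch; assumes the app is served with uvicorn on the
# default port 8000 and that this module is importable as app.main -- adjust as needed):
#
#   uvicorn app.main:app --reload
#
#   curl -X POST http://localhost:8000/api/session
#   curl -X POST http://localhost:8000/api/load-schema/<domain> -H "session-id: <session_id>"
#   curl -X POST http://localhost:8000/api/run-query \
#        -H "session-id: <session_id>" -H "Content-Type: application/json" \
#        -d '{"query": "SELECT 1;"}'
#
# Note: FastAPI maps the `session_id` Header(...) parameter to the `session-id` HTTP header
# by default (convert_underscores=True), which is why the examples use a hyphen.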