File size: 11,483 Bytes
e7cf806
 
 
 
f68a38d
e7cf806
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f68a38d
e7cf806
 
 
f68a38d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7cf806
 
 
f68a38d
 
 
 
e7cf806
 
f68a38d
 
 
e7cf806
f68a38d
 
 
 
 
 
e7cf806
f68a38d
 
e7cf806
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import json  # Added for json.load()
import os
import re
import sqlite3
import uuid
from typing import Dict, Any

import sqlparse
from fastapi import FastAPI, HTTPException, Header, Request
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from loguru import logger
from pydantic import BaseModel

from app.database import create_session_db, close_session_db
from app.schemas import RunQueryRequest, ValidateQueryRequest

# FastAPI application instance: serves the JSON API and the bundled frontend.
app = FastAPI()
# Expose frontend assets (JS/CSS/HTML) from this package's "static" directory.
app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")), name="static")

@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the failure, then answer with a JSON 500.

    Previously the exception was returned to the client without being
    logged anywhere, making server-side failures invisible in the logs.

    NOTE(review): echoing ``str(exc)`` to the client can leak internal
    details (paths, SQL); consider a generic message for production.
    """
    logger.error(f"Unhandled exception for {request.url}: {exc}")
    return JSONResponse(status_code=500, content={"detail": str(exc)})

# In-memory session registry: session_id -> {"conn": sqlite3 connection, "domain": str | None}.
# NOTE(review): process-local and unbounded — sessions are lost on restart and
# only reclaimed by the shutdown hook; confirm this is acceptable for deployment.
sessions: Dict[str, dict] = {}

# Directory containing this module; questions/, schemas/ and static/ paths
# are resolved relative to it.
BASE_DIR = os.path.dirname(__file__)

def load_questions(domain: str):
    """Read the question list for *domain* from ``questions/<domain>.json``.

    Raises:
        FileNotFoundError: when no question file exists for the domain.
    """
    question_file = os.path.join(BASE_DIR, "questions", f"{domain}.json")
    if not os.path.exists(question_file):
        raise FileNotFoundError(f"Question file not found: {question_file}")
    with open(question_file, "r") as handle:
        return json.load(handle)

def load_schema_sql(domain: str):
    """Read the raw SQL schema script for *domain* from ``schemas/<domain>.sql``.

    Raises:
        FileNotFoundError: when no schema file exists for the domain.
    """
    schema_file = os.path.join(BASE_DIR, "schemas", f"{domain}.sql")
    if not os.path.exists(schema_file):
        raise FileNotFoundError(f"Schema file not found: {schema_file}")
    with open(schema_file, "r") as handle:
        return handle.read()

# Whole-word SQL keywords that must never appear in a user-supplied query.
_FORBIDDEN_SQL_KEYWORDS = frozenset(
    {"drop", "attach", "detach", "pragma", "insert", "update", "delete"}
)


def is_safe_query(sql: str) -> bool:
    """Return True when *sql* looks like a single read-only SELECT statement.

    Fixes two defects of the previous substring-based check:
    - identifiers such as ``updated_at`` or ``deleted_flag`` no longer
      trigger a false rejection (keywords now match on word boundaries);
    - an empty or whitespace-only query no longer raises ``IndexError``.

    sqlite3's ``Cursor.execute`` itself rejects multi-statement strings,
    so stacked statements (``select 1; drop ...``) cannot execute even if
    the leading SELECT passes this check — and the word scan rejects the
    forbidden keyword anyway.
    """
    normalized = sql.strip().lower()
    if not normalized.startswith("select"):
        return False
    words = set(re.findall(r"[a-z_]\w*", normalized))
    return words.isdisjoint(_FORBIDDEN_SQL_KEYWORDS)

def extract_tables(sql: str) -> list:
    """Best-effort, whitespace-token scan for table names referenced in *sql*.

    Looks for FROM/JOIN/INTO/USING/APPLY/PIVOT/TABLE targets, MERGE and
    SELECT INTO targets, CTE bodies (``WITH ... AS (...)``), VALUES aliases,
    and EXISTS/IN subqueries (recursively). Heuristic only: the input is
    lowercased and split on whitespace, so bare "(" / ")" tokens are only
    recognized when space-separated, and punctuation glued to identifiers
    (e.g. ``from t1,t2``) is not handled.

    NOTE(review): every ``i += 1`` / ``i += 2`` / ``while``-advance of ``i``
    below has no effect on the enclosing ``for i, token in enumerate(...)``
    loop — Python rebinds ``i`` on the next iteration — so tokens are
    re-scanned rather than skipped as apparently intended. Left unchanged
    pending behavioral review.
    """
    tables = set()
    # Lowercased, newline-flattened, whitespace-split token stream.
    tokens = sql.replace("\n", " ").lower().split()
    in_subquery = in_openquery = in_values = False

    for i, token in enumerate(tokens):
        # Track standalone "(" / ")" to distinguish subqueries from VALUES lists.
        if token == "(" and not in_subquery and not in_values:
            in_values = i > 0 and tokens[i - 1] == "values"
            in_subquery = not in_values
        if token == ")" and (in_subquery or in_values):
            if in_values and i + 1 < len(tokens) and tokens[i + 1] == "as": in_values = False
            elif in_subquery: in_subquery = False
        # Skip everything inside OPENQUERY( ... ) pass-through sections.
        if token == "openquery" and i + 1 < len(tokens) and tokens[i + 1] == "(": in_openquery = True
        if token == ")" and in_openquery: in_openquery = False
        if in_openquery: continue

        if token in ["from", "join", "update", "delete", "insert", "into", "using", "apply", "pivot", "table"]:
            # Candidate table name is the token after the keyword; strip the
            # literal ",);" sequence that may be glued onto it.
            next_token = tokens[i + 1].replace(",);", "") if i + 1 < len(tokens) else ""
            if next_token and next_token not in ["select", "where", "on", "order", "group", "having", "as", "("]:
                if i + 2 < len(tokens) and tokens[i + 2] == "as": next_token = next_token  # self-assignment: no-op
                elif next_token not in ["left", "right", "inner", "outer", "cross", "full"]: tables.add(next_token)
            i += 1  # NOTE(review): no effect — see docstring
        elif token == "merge" and i + 1 < len(tokens) and tokens[i + 1] == "into":
            # MERGE INTO <target> ... USING <source>
            next_token = tokens[i + 2].replace(",);", "") if i + 2 < len(tokens) else ""
            if next_token and next_token not in ["using", "select", "where"]: tables.add(next_token)
            i += 2  # NOTE(review): no effect — see docstring
            while i + 1 < len(tokens) and tokens[i + 1] != "using": i += 1
            if i + 2 < len(tokens) and (next_token := tokens[i + 2].replace(",);", "")) and next_token not in ["select", "where"]: tables.add(next_token)
        elif token == "select" and i + 1 < len(tokens) and tokens[i + 1] == "into":
            # SELECT ... INTO <target> ... FROM <source>
            next_token = tokens[i + 2].replace(",);", "") if i + 2 < len(tokens) else ""
            if next_token and next_token not in ["from", "select"]: tables.add(next_token)
            i += 2  # NOTE(review): no effect — see docstring
            while i + 1 < len(tokens) and tokens[i + 1] != "from": i += 1
            if i + 2 < len(tokens) and (next_token := tokens[i + 2].replace(",);", "")) and next_token not in ["where", "join"]: tables.add(next_token)
        elif token == "with":
            # CTE: advance to "as (", then recurse into the parenthesized body.
            while i + 1 < len(tokens) and tokens[i + 1] != "as": i += 1
            if i + 2 < len(tokens) and tokens[i + 2] == "(":
                bracket_count = 1
                subquery_start = i + 2
                i += 2
                while i < len(tokens) and bracket_count > 0:
                    if tokens[i] == "(": bracket_count += 1
                    elif tokens[i] == ")": bracket_count -= 1
                    i += 1
                if bracket_count == 0 and i > subquery_start:
                    subquery = " ".join(tokens[subquery_start:i - 1])
                    tables.update(t for t in extract_tables(subquery) if t not in tables)
        elif token == "values" and i + 1 < len(tokens) and tokens[i + 1] == "(":
            # VALUES ( ... ) AS <alias> — record the alias as a table-like name.
            while i + 1 < len(tokens) and tokens[i + 1] != "as": i += 1
            if i + 2 < len(tokens) and (alias := tokens[i + 2].replace(",);", "")): tables.add(alias)
        elif token in ["exists", "in"]:
            # Recurse into EXISTS(...) / IN(...) subqueries.
            subquery_start = i + 1
            while i + 1 < len(tokens) and tokens[i + 1] != ")": i += 1
            if i > subquery_start:
                subquery = " ".join(tokens[subquery_start:i + 1])
                tables.update(t for t in extract_tables(subquery) if t not in tables)
    return list(tables)

@app.post("/api/session")
async def create_session():
    session_id = str(uuid.uuid4())
    sessions[session_id] = {"conn": create_session_db(), "domain": None}
    return {"session_id": session_id}


@app.get("/api/databases")
async def get_databases():
    questions_dir = os.path.join(BASE_DIR, "questions")
    logger.debug(f"Checking databases in directory: {questions_dir}")
    if not os.path.exists(questions_dir):
        logger.error(f"Questions directory not found: {questions_dir}")
        return {"databases": []}
    databases = [f.replace(".json", "") for f in os.listdir(questions_dir) if f.endswith(".json")]
    logger.debug(f"Found databases: {databases}")
    return {"databases": databases}

@app.get("/api/databases")
async def get_databases():
    questions_dir = os.path.join(BASE_DIR, "questions")
    logger.debug(f"Checking databases in directory: {questions_dir}")
    if not os.path.exists(questions_dir):
        logger.error(f"Questions directory not found: {questions_dir}")
        return {"databases": []}
    databases = [f.replace(".json", "") for f in os.listdir(questions_dir) if f.endswith(".json")]
    logger.debug(f"Found databases: {databases}")
    return {"databases": databases}

@app.post("/api/load-schema/{domain}")
async def load_schema(domain: str, session_id: str = Header(...)):
    logger.debug(f"Loading schema for domain: {domain}, session_id: {session_id}")
    if session_id not in sessions:
        logger.error(f"Invalid session: {session_id}")
        raise HTTPException(status_code=401, detail="Invalid session")
    sessions[session_id] = {"conn": create_session_db(), "domain": domain}
    try:
        schema_sql = load_schema_sql(domain)
        logger.debug(f"Schema SQL loaded for {domain}")
        sessions[session_id]["conn"].executescript(schema_sql)
        sessions[session_id]["conn"].commit()
        logger.info(f"Schema loaded successfully for {domain}")
    except FileNotFoundError as e:
        logger.error(f"Schema file not found: {str(e)}")
        close_session_db(sessions[session_id]["conn"])
        del sessions[session_id]
        raise HTTPException(status_code=500, detail=str(e))
    except sqlite3.Error as e:
        logger.error(f"Database error: {str(e)}")
        close_session_db(sessions[session_id]["conn"])
        del sessions[session_id]
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
    return {"message": f"Database {domain} loaded"}

@app.get("/api/schema/{domain}")
async def get_schema(domain: str, session_id: str = Header(...)):
    if session_id not in sessions or sessions[session_id]["domain"] != domain: raise HTTPException(status_code=401, detail="Invalid session or domain not loaded")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    return {"schema": {table: [{"name": row["name"], "type": row["type"]} for row in conn.execute(f"PRAGMA table_info({table});")] for table in [row["name"] for row in cursor.fetchall()]}}

@app.get("/api/sample-data/{domain}")
async def get_sample_data(domain: str, session_id: str = Header(...)):
    if session_id not in sessions or sessions[session_id]["domain"] != domain: raise HTTPException(status_code=401, detail="Invalid session or domain not loaded")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    return {"sample_data": {table: {"columns": [desc[0] for desc in conn.execute(f"SELECT * FROM {table} LIMIT 5").description], "rows": [dict(row) for row in conn.execute(f"SELECT * FROM {table} LIMIT 5")]} for table in [row["name"] for row in cursor.fetchall()]}}

@app.post("/api/run-query")
async def run_query(request: RunQueryRequest, session_id: str = Header(...)):
    if session_id not in sessions or not sessions[session_id]["domain"]: raise HTTPException(status_code=401, detail="Invalid session or no database loaded")
    if not is_safe_query(request.query): raise HTTPException(status_code=400, detail="Only SELECT queries are allowed")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    cursor.execute(request.query)
    if cursor.description:
        columns = [desc[0] for desc in cursor.description]
        return {"columns": columns, "rows": [dict(zip(columns, row)) for row in cursor.fetchall()]}
    return {"message": "Query executed successfully (no results)"}

@app.get("/api/questions/{domain}")
async def get_questions(domain: str, difficulty: str = None):
    questions = load_questions(domain)
    if difficulty: questions = [q for q in questions if q["difficulty"].lower() == difficulty.lower()]
    return [{"id": q["id"], "title": q["title"], "difficulty": q["difficulty"], "description": q["description"], "hint": q["hint"], "expected_sql": q["expected_sql"]} for q in questions]

@app.post("/api/validate")
async def validate_query(request: ValidateQueryRequest, session_id: str = Header(...)):
    if session_id not in sessions or not sessions[session_id]["domain"]: raise HTTPException(status_code=401, detail="Invalid session or no database loaded")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    cursor.execute(request.user_query)
    user_result = [tuple(str(x).lower() for x in row) for row in cursor.fetchall()] if cursor.description else []
    cursor.execute(request.expected_query)
    expected_result = [tuple(str(x).lower() for x in row) for row in cursor.fetchall()] if cursor.description else []
    return {"valid": user_result == expected_result, "error": "Results do not match." if user_result != expected_result else ""}

@app.on_event("shutdown")
async def cleanup():
    for session_id in list(sessions): close_session_db(sessions[session_id]["conn"]); del sessions[session_id]

@app.get("/", response_class=HTMLResponse)
async def serve_frontend():
    file_path = os.path.join(BASE_DIR, "static", "index.html")
    if not os.path.exists(file_path): raise HTTPException(status_code=500, detail=f"Frontend file not found: {file_path}")
    with open(file_path, "r") as f: return f.read()