File size: 5,888 Bytes
e6cde35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import os
import requests
import time
import re
import json
from db_utils import get_schema, execute_sql

def _extract_gemini_text(result):
    """Pull the generated text out of a decoded ``generateContent`` response.

    Returns the stripped text of the first part of the first candidate, the
    string ``"Error: The response was blocked by safety filters."`` when that
    candidate's ``finishReason`` is ``SAFETY``, or
    ``"No valid response generated"`` when no text can be found.
    """
    candidates = result.get("candidates") if isinstance(result, dict) else None
    if not candidates:
        return "No valid response generated"

    candidate = candidates[0]
    if candidate.get('finishReason') == 'SAFETY':
        return "Error: The response was blocked by safety filters."

    # Tolerate a missing/empty "content" or "parts" instead of raising.
    parts = (candidate.get("content") or {}).get("parts") or []
    text = parts[0].get("text") if parts else None
    if text is None:
        return "No valid response generated"
    return text.strip()


def query_gemini_api(prompt, max_retries=3):
    """Query the Google Gemini API and return the generated text.

    Args:
        prompt: Text sent as the single content part of the request.
        max_retries: Maximum number of HTTP attempts.

    Returns:
        The stripped generated text, or one of the sentinel strings
        ``"No valid response generated"`` /
        ``"Error: The response was blocked by safety filters."``.

    Raises:
        ValueError: If ``GOOGLE_API_KEY`` is not set in the environment.
        Exception: If the request keeps failing (timeouts, network errors,
            rate limiting, server errors) or the API answers with a
            non-retryable client error.
    """
    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("GOOGLE_API_KEY not found in environment variables")

    # The key travels in a header rather than the query string so it cannot
    # end up in exception messages or logs that echo the request URL.
    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"

    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": api_key,
    }

    payload = {
        "contents": [{
            "parts": [{
                "text": prompt
            }]
        }],
        "generationConfig": {
            "temperature": 0.1,
            "topK": 1,
            "topP": 0.8,
            "maxOutputTokens": 200,
            "stopSequences": ["```", "\n\n"]
        },
        "safetySettings": [
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_NONE"
            }
        ]
    }

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1

        try:
            response = requests.post(url, headers=headers, json=payload, timeout=30)
            status = response.status_code
            # Decode inside the try so a malformed 200 body is retried too.
            result = response.json() if status == 200 else None
        except requests.exceptions.Timeout as exc:
            if is_last_attempt:
                raise Exception("Request timed out after multiple attempts") from exc
            time.sleep(5)
            continue
        except (requests.exceptions.RequestException, ValueError):
            if is_last_attempt:
                raise
            time.sleep(5)
            continue

        if status == 200:
            return _extract_gemini_text(result)

        if status == 429:
            # Rate limited: back off linearly, but do not sleep when there is
            # no further attempt left to benefit from the wait.
            if not is_last_attempt:
                time.sleep(60 * (attempt + 1))
            continue

        error_msg = f"Gemini API Error {status}: {response.text}"
        # Client errors (bad request, bad key, ...) will not succeed on retry.
        if status < 500 or is_last_attempt:
            raise Exception(error_msg)
        time.sleep(5)

    raise Exception("Failed to get response after all retries")

def extract_user_requested_limit(nl_query):
    """Extract a user-requested row count from a natural language query.

    The query is matched case-insensitively against an ordered list of
    phrasings ("10 ships", "show me the top 5", "names of 3 ...",
    "7 oldest"); the number captured by the first phrasing that matches wins.

    Args:
        nl_query: The user's question as plain text.

    Returns:
        The requested count as an ``int``, or ``None`` when no phrasing
        matches.
    """
    # Word boundaries keep the verbs from matching inside other words
    # ("forget 12") and keep ordinals ("5th") from being read as counts.
    patterns = (
        r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
        r'\b(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)\b',
        r'\bnames\s+of\s+(\d+)\s+',
        r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
    )

    for pattern in patterns:
        match = re.search(pattern, nl_query, re.IGNORECASE)
        if match:
            return int(match.group(1))
    return None

def clean_sql_output(sql_text, user_limit=None):
    """Clean the model's raw output down to a single SQL statement.

    Steps:
      1. If the text contains Markdown code fences, keep only the fenced
         content (wherever the fence appears, not just at the start).
      2. Collect lines starting at the first one that begins with ``SELECT``
         and stop after a line ending in ``;``; lines are joined by spaces.
      3. If nothing was collected, fall back to the first line containing
         ``SELECT``, ``WITH`` or ``FROM``.
      4. Drop the trailing semicolon.
      5. When ``user_limit`` is truthy, replace the statement's trailing
         ``LIMIT n`` (keeping a following ``OFFSET m``) or append one.

    Args:
        sql_text: Raw text returned by the model.
        user_limit: Optional row count that overrides the query's own LIMIT.

    Returns:
        The cleaned SQL without a trailing semicolon, or ``""`` when no SQL
        could be found.
    """
    if not sql_text:
        return ""

    lines = sql_text.strip().split('\n')

    if any(line.strip().startswith("```") for line in lines):
        fenced, unfenced = [], []
        inside_fence = False
        for line in lines:
            if line.strip().startswith("```"):
                inside_fence = not inside_fence
                continue
            (fenced if inside_fence else unfenced).append(line)
        # Prefer fenced content; if the fences enclose nothing, use the rest
        # of the text with the fence markers themselves removed.
        lines = fenced if any(line.strip() for line in fenced) else unfenced

    collected = []
    for raw_line in lines:
        line = raw_line.strip()
        if not line:
            continue
        if collected or line.upper().startswith('SELECT'):
            collected.append(line)
            if line.endswith(';'):
                break
    sql = " ".join(collected)

    if not sql:
        keywords = ('SELECT', 'WITH', 'FROM')
        sql = next(
            (
                line.strip()
                for line in lines
                if any(keyword in line.upper() for keyword in keywords)
            ),
            "",
        )

    sql = sql.strip().rstrip(';').rstrip()

    if user_limit and sql:
        # Only the trailing, top-level LIMIT is replaced; LIMIT clauses
        # inside subqueries are part of the query's meaning and must stay.
        trailing = re.search(
            r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', sql, flags=re.IGNORECASE
        )
        offset_clause = ""
        if trailing:
            offset_clause = trailing.group(1) or ""
            sql = sql[:trailing.start()]
        sql = f"{sql} LIMIT {user_limit}{offset_clause}"

    return sql

def text_to_sql(nl_query):
    """Convert a natural language question to SQL with Gemini and run it.

    The schema from ``get_schema()`` (first 1500 characters) and the question
    are sent to the model; the reply is cleaned with ``clean_sql_output`` and
    executed with ``execute_sql`` only if it is a single SELECT statement.

    Args:
        nl_query: The user's question as plain text.

    Returns:
        A ``(sql_or_message, results)`` tuple. On success the first item is
        the executed SQL and the second is whatever ``execute_sql`` returned.
        On any failure the first item is the model's sentinel message or an
        ``"Error: ..."`` string and the second is ``[]``. This function does
        not raise; exceptions are converted to the error form.
    """
    try:
        schema = get_schema()
        user_limit = extract_user_requested_limit(nl_query)
        
        prompt = f"""You are an expert PostgreSQL developer. Convert this natural language question to a precise SQL query.

Question: {nl_query}

Database Schema:
{schema[:1500]}

Requirements:
- Generate ONLY the SQL query, no explanation
- Use PostgreSQL syntax
- Be precise with table and column names from the schema
- Return a single SELECT statement

SQL Query:"""

        generated_sql = query_gemini_api(prompt)

        if not generated_sql:
            return generated_sql, []
        # Match the sentinels by prefix: a substring test would also reject
        # valid SQL that merely mentions "Error:" inside a string literal.
        if generated_sql.startswith(("No valid response", "Error:")):
            return generated_sql, []

        sql = clean_sql_output(generated_sql, user_limit)

        if not sql or not re.match(r'SELECT\b', sql.strip(), flags=re.IGNORECASE):
            return f"Error: Invalid SQL generated: {sql}", []

        # The SQL comes from a model steered by user text, so refuse stacked
        # statements ("SELECT 1; DROP ..."). Semicolons inside single-quoted
        # literals are ignored for this check.
        # NOTE(review): this is a guard, not a sandbox - execute_sql should
        # still run with a read-only database role.
        sql_without_literals = re.sub(r"'(?:[^']|'')*'", "''", sql)
        if ';' in sql_without_literals:
            return f"Error: Invalid SQL generated: {sql}", []

        results = execute_sql(sql)
        return sql, results

    except Exception as e:
        return f"Error: {str(e)}", []

#--end-of-script