Spaces:
Running
Running
File size: 5,888 Bytes
e6cde35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import os
import requests
import time
import re
import json
from db_utils import get_schema, execute_sql
def query_gemini_api(prompt, max_retries=3, model="gemini-1.5-flash"):
    """Send *prompt* to the Google Gemini ``generateContent`` REST endpoint.

    Args:
        prompt: Text sent as the single user turn.
        max_retries: Number of HTTP attempts before giving up.
        model: Gemini model name placed in the endpoint path.

    Returns:
        The stripped text of the first candidate. Two sentinel strings are
        returned instead of raising: "Error: The response was blocked by
        safety filters." when the candidate finished with reason SAFETY, and
        "No valid response generated" when a 200 response carries no text.

    Raises:
        ValueError: if the GOOGLE_API_KEY environment variable is not set.
        Exception: when the last attempt fails (HTTP error, persistent rate
            limiting, timeout, or a transport error re-raised unchanged).
    """
    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("GOOGLE_API_KEY not found in environment variables")

    url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
    # The key goes in a header, not the query string: requests embeds the URL
    # in many exception messages, which callers may show to end users.
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": api_key,
    }
    payload = {
        "contents": [{
            "parts": [{
                "text": prompt
            }]
        }],
        "generationConfig": {
            "temperature": 0.1,
            "topK": 1,
            "topP": 0.8,
            "maxOutputTokens": 200,
            # NOTE(review): "```" stops generation as soon as the model opens a
            # Markdown fence, so a reply starting with ```sql comes back empty.
            # clean_sql_output already strips fences; consider dropping it.
            "stopSequences": ["```", "\n\n"]
        },
        "safetySettings": [
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_NONE"
            }
        ]
    }

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = requests.post(url, headers=headers, json=payload, timeout=30)

            if response.status_code == 200:
                result = response.json()
                candidates = result.get("candidates") or []
                if candidates:
                    candidate = candidates[0]
                    if candidate.get('finishReason') == 'SAFETY':
                        return "Error: The response was blocked by safety filters."
                    parts = (candidate.get("content") or {}).get("parts") or []
                    # Guard against an empty parts list or a part without text
                    # instead of letting KeyError/IndexError trigger a retry.
                    if parts and "text" in parts[0]:
                        return parts[0]["text"].strip()
                return "No valid response generated"

            if response.status_code == 429:
                # Rate limited: there is no point waiting after the final try.
                if is_last_attempt:
                    raise Exception("Gemini API rate limit (HTTP 429) persisted after all retries")
                time.sleep(60 * (attempt + 1))  # linear backoff
                continue

            error_msg = f"Gemini API Error {response.status_code}: {response.text}"
            if is_last_attempt:
                raise Exception(error_msg)
            time.sleep(5)  # brief pause before retrying other HTTP errors
        except requests.exceptions.Timeout as exc:
            if is_last_attempt:
                raise Exception("Request timed out after multiple attempts") from exc
            time.sleep(5)
        except Exception:
            if is_last_attempt:
                raise  # bare raise keeps the original traceback
            time.sleep(5)

    # Only reached when max_retries <= 0.
    raise Exception("Failed to get response after all retries")
def extract_user_requested_limit(nl_query):
    """Find an explicit result count in a natural-language query.

    The candidate patterns are tried in order, each searched case-insensitively
    anywhere in *nl_query*. The digits captured by the first pattern that
    matches are returned as an int.

    Args:
        nl_query: The user's question, e.g. "show me 5 ships".

    Returns:
        The requested count, or None when no pattern matches.
    """
    count_patterns = (
        # "<n> ships", "<n> records", ...
        r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
        # "show me the top <n>", "list first <n>", ...
        r'(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
        # "names of <n> ..."
        r'(?:names\s+of\s+)(\d+)\s+',
        # "<n> oldest", "<n> largest", ...
        r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
    )
    searches = (re.search(pat, nl_query, re.IGNORECASE) for pat in count_patterns)
    hit = next((m for m in searches if m is not None), None)
    if hit is None:
        return None
    return int(hit.group(1))
def clean_sql_output(sql_text, user_limit=None):
    """Extract a single SQL query from raw model output and apply a row limit.

    Args:
        sql_text: Raw text returned by the model. It may be wrapped in a
            Markdown code fence (```sql ... ```) and may contain other text.
        user_limit: Row count requested by the user, or None. When truthy, a
            LIMIT ending the outer query is replaced by this value; otherwise
            one is appended.

    Returns:
        The extracted SQL without its trailing semicolon, lines joined by
        newlines, or "" when nothing query-like was found.
    """
    sql_text = sql_text.strip()

    # If the reply is fenced, keep only the lines between the fence markers.
    if sql_text.startswith("```"):
        fenced_lines = []
        in_fence = False
        for line in sql_text.split('\n'):
            if line.strip().startswith("```"):
                in_fence = not in_fence
                continue
            if in_fence:
                fenced_lines.append(line)
        sql_text = '\n'.join(fenced_lines)

    lines = [line.strip() for line in sql_text.split('\n')]

    # Take everything from the first line starting with SELECT up to and
    # including the first line ending with ';'. Joining with newlines rather
    # than spaces keeps a `--` comment from swallowing the lines after it.
    query_lines = []
    for line in lines:
        if line and (query_lines or line.upper().startswith('SELECT')):
            query_lines.append(line)
            if line.endswith(';'):
                break
    sql = '\n'.join(query_lines)

    if not sql:
        # Fallback: the first single line that mentions a SQL keyword anywhere.
        sql = next(
            (line for line in lines
             if line and any(kw in line.upper() for kw in ('SELECT', 'WITH', 'FROM'))),
            "",
        )

    sql = sql.strip().rstrip(';').rstrip()

    # Only apply the limit to an actual query; never return a bare "LIMIT n".
    if user_limit and sql:
        # Drop only a LIMIT that ends the outer query (optionally followed by
        # OFFSET), so LIMITs inside subqueries keep their meaning.
        sql = re.sub(
            r'\s+LIMIT\s+\d+(?=(?:\s+OFFSET\s+\d+)?\s*$)',
            '',
            sql,
            flags=re.IGNORECASE,
        )
        # Newline, not space, so a trailing `--` comment cannot hide the LIMIT.
        sql += f"\nLIMIT {user_limit}"
    return sql
def text_to_sql(nl_query):
    """Translate a natural-language question into SQL with Gemini and run it.

    Args:
        nl_query: The user's question in plain language.

    Returns:
        ``(sql, rows)`` on success, where rows is whatever ``execute_sql``
        returns. On failure, ``(message, [])``, where message is the sentinel
        text from ``query_gemini_api`` or a string starting with "Error:".
        Exceptions are never propagated; they are converted to such messages.
    """
    try:
        schema = get_schema()
        user_limit = extract_user_requested_limit(nl_query)
        # Only the first 1500 characters of the schema are sent, to keep the
        # prompt short; this may cut the schema mid-definition.
        prompt = f"""You are an expert PostgreSQL developer. Convert this natural language question to a precise SQL query.
Question: {nl_query}
Database Schema:
{schema[:1500]}
Requirements:
- Generate ONLY the SQL query, no explanation
- Use PostgreSQL syntax
- Be precise with table and column names from the schema
- Return a single SELECT statement
SQL Query:"""
        generated_sql = query_gemini_api(prompt)
        # query_gemini_api reports soft failures with these sentinel prefixes;
        # match them as prefixes so SQL merely containing "Error:" passes.
        if not generated_sql or generated_sql.startswith(("No valid response", "Error:")):
            return generated_sql, []

        sql = clean_sql_output(generated_sql, user_limit)
        if not sql or not sql.upper().strip().startswith('SELECT'):
            return f"Error: Invalid SQL generated: {sql}", []

        # The model's output is steered by untrusted user text. A ';' left after
        # cleaning means a second statement (e.g. "SELECT 1; DROP TABLE ..."),
        # so refuse it rather than hand it to the database. This also rejects
        # a ';' inside a string literal, which is the safe direction to err.
        # NOTE(review): execute_sql should ideally run under a read-only role.
        if ';' in sql:
            return f"Error: Multiple SQL statements are not allowed: {sql}", []

        results = execute_sql(sql)
        return sql, results
    except Exception as e:
        return f"Error: {str(e)}", []
#--end-of-script
|