import json
import os
import re
import time

import requests

from db_utils import get_schema, execute_sql

# A LIMIT clause with a literal row count (the pattern used when overriding
# the model's LIMIT with the one the user asked for).
_LIMIT_CLAUSE_RE = re.compile(r'\s+LIMIT\s+\d+', re.IGNORECASE)


def _extract_candidate_text(result):
    """Return the generated text (or a sentinel message) from a decoded API response."""
    if "candidates" in result and len(result["candidates"]) > 0:
        candidate = result["candidates"][0]
        if candidate.get('finishReason') == 'SAFETY':
            return "Error: The response was blocked by safety filters."
        if "content" in candidate and "parts" in candidate["content"]:
            return candidate["content"]["parts"][0]["text"].strip()
    return "No valid response generated"


def query_gemini_api(prompt: str, max_retries: int = 3) -> str:
    """Query the Google Gemini API and return the generated text.

    Args:
        prompt: Text sent as the single content part of the request.
        max_retries: Maximum number of HTTP attempts.

    Returns:
        The stripped text of the first candidate, or one of the sentinel
        strings "Error: The response was blocked by safety filters." /
        "No valid response generated".

    Raises:
        ValueError: If GOOGLE_API_KEY is not set in the environment.
        Exception: If the request keeps failing, times out on the final
            attempt, or the API answers with a non-retryable HTTP error.
    """
    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("GOOGLE_API_KEY not found in environment variables")

    # The key travels in a header rather than the query string: `requests`
    # embeds the URL in connection-error messages, and text_to_sql() returns
    # str(exception) to its caller, which would otherwise expose the key.
    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": api_key,
    }

    harm_categories = (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
    payload = {
        "contents": [{
            "parts": [{
                "text": prompt
            }]
        }],
        "generationConfig": {
            "temperature": 0.1,
            "topK": 1,
            "topP": 0.8,
            "maxOutputTokens": 200,
            # NOTE(review): "```" as a stop sequence presumably ends generation
            # as soon as the model opens a code fence - confirm this is intended.
            "stopSequences": ["```", "\n\n"]
        },
        "safetySettings": [
            {"category": category, "threshold": "BLOCK_NONE"}
            for category in harm_categories
        ],
    }

    for attempt in range(max_retries):
        last_attempt = attempt == max_retries - 1
        try:
            response = requests.post(url, headers=headers, json=payload, timeout=30)
            status = response.status_code

            if status == 200:
                return _extract_candidate_text(response.json())

            if status == 429:
                # Rate limited: back off linearly, but do not sleep when no
                # further attempt will follow.
                if not last_attempt:
                    time.sleep(60 * (attempt + 1))
                continue

            error_msg = f"Gemini API Error {status}: {response.text}"
        except requests.exceptions.Timeout:
            if last_attempt:
                raise Exception("Request timed out after multiple attempts")
            time.sleep(5)
            continue
        except Exception:
            # Covers connection errors and malformed response bodies.
            if last_attempt:
                raise
            time.sleep(5)
            continue

        # Other HTTP errors: 4xx responses (bad request, bad key, ...) will not
        # succeed on retry, so fail immediately; anything else is retried
        # after a short pause instead of being re-sent back-to-back.
        if last_attempt or 400 <= status < 500:
            raise Exception(error_msg)
        time.sleep(5)

    raise Exception("Failed to get response after all retries")


def extract_user_requested_limit(nl_query: str):
    """Extract a user-requested row count from a natural language query.

    Patterns are tried in order and the first match wins.

    Returns:
        The number as an int, or None when no pattern matches.
    """
    patterns = [
        r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
        r'(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
        r'(?:names\s+of\s+)(\d+)\s+',
        r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
    ]

    for pattern in patterns:
        match = re.search(pattern, nl_query, re.IGNORECASE)
        if match:
            return int(match.group(1))

    return None


def _strip_top_level_limit(sql):
    """Remove `LIMIT <n>` clauses that are not nested inside parentheses.

    Depth is estimated by counting parentheses before each match, so a LIMIT
    inside a subquery is left alone. Parentheses inside string literals are
    not special-cased.
    """
    kept = []
    cursor = 0
    for match in _LIMIT_CLAUSE_RE.finditer(sql):
        prefix = sql[:match.start()]
        if prefix.count('(') == prefix.count(')'):
            kept.append(sql[cursor:match.start()])
            cursor = match.end()
    kept.append(sql[cursor:])
    return ''.join(kept)


def clean_sql_output(sql_text: str, user_limit=None) -> str:
    """Clean the model's raw output down to a single SQL statement.

    Args:
        sql_text: Raw text returned by the model.
        user_limit: Optional row count; when truthy it replaces any
            top-level `LIMIT <n>` in the statement.

    Returns:
        The SQL text without a trailing semicolon; empty if nothing that
        looks like SQL was found.
    """
    sql_text = sql_text.strip()

    # Keep only the content between Markdown code fences, if present.
    if sql_text.startswith("```"):
        fenced_lines = []
        inside_fence = False
        for raw_line in sql_text.split('\n'):
            if raw_line.strip().startswith("```"):
                inside_fence = not inside_fence
                continue
            if inside_fence:
                fenced_lines.append(raw_line)
        sql_text = '\n'.join(fenced_lines)

    lines = sql_text.split('\n')

    # Collect from the first line starting with SELECT up to the first line
    # that ends the statement with a semicolon.
    collected = []
    for raw_line in lines:
        line = raw_line.strip()
        if line and (collected or line.upper().startswith('SELECT')):
            collected.append(line)
            if line.endswith(';'):
                break
    sql = " ".join(collected)

    # Fallback: take the first line that mentions an SQL keyword at all.
    if not sql:
        for raw_line in lines:
            line = raw_line.strip()
            if line and any(keyword in line.upper() for keyword in ['SELECT', 'WITH', 'FROM']):
                sql = line
                break

    sql = sql.strip().rstrip(';').strip()

    if user_limit:
        # Only the outermost LIMIT is replaced; stripping every LIMIT would
        # silently change the meaning of subqueries.
        sql = _strip_top_level_limit(sql)
        sql += f" LIMIT {user_limit}"

    return sql


def text_to_sql(nl_query: str):
    """Convert natural language to SQL using Google Gemini and execute it.

    Args:
        nl_query: The user's question in natural language.

    Returns:
        A `(sql, results)` tuple. On failure the first element is an error or
        sentinel message and the second is an empty list. `results` is
        whatever `execute_sql` returns.
    """
    try:
        schema = get_schema()
        user_limit = extract_user_requested_limit(nl_query)

        # Assumes get_schema() returns a string; only its first 1500
        # characters are sent to the model.
        prompt = f"""You are an expert PostgreSQL developer. Convert this natural language question to a precise SQL query.

Question: {nl_query}

Database Schema:
{schema[:1500]}

Requirements:
- Generate ONLY the SQL query, no explanation
- Use PostgreSQL syntax
- Be precise with table and column names from the schema
- Return a single SELECT statement

SQL Query:"""

        generated_sql = query_gemini_api(prompt)

        # The sentinels are produced at the start of the returned text, so
        # match them there; a substring test would reject valid SQL that
        # merely contains e.g. 'Error:' inside a string literal.
        if (
            not generated_sql
            or generated_sql.startswith("No valid response")
            or generated_sql.startswith("Error:")
        ):
            return generated_sql, []

        sql = clean_sql_output(generated_sql, user_limit)

        if not sql or not sql.upper().strip().startswith('SELECT'):
            return f"Error: Invalid SQL generated: {sql}", []

        # The SQL comes from a model driven by untrusted user text. A leftover
        # semicolon means more than one statement (e.g. "SELECT 1; DROP ..."),
        # which must never reach the database.
        if ';' in sql:
            return f"Error: Multiple SQL statements are not allowed: {sql}", []

        results = execute_sql(sql)
        return sql, results

    except Exception as e:
        return f"Error: {str(e)}", []

#--end-of-script