Spaces:
Sleeping
Sleeping
import base64
import json
import os
import time
import traceback
from io import BytesIO

import gradio as gr
import pandas as pd
import requests
from PIL import Image
# 🔧 CONFIGURE: base URL of the Flask API backend (public URL).
# Set the FLASK_API_URL environment variable to point at a different backend
# without editing code; an empty value falls back to the default below.
# A trailing "/" is stripped so f"{FLASK_API_URL}/ask" stays well-formed.
# NOTE(review): the default is a bare IP over plain HTTP — consider HTTPS.
FLASK_API_URL = (os.environ.get("FLASK_API_URL") or "http://3.16.57.66:5000").rstrip("/")
| def query_database(question, dashboard_mode=False, chart_type=None): | |
| if not question.strip(): | |
| return "", None, None, "β οΈ Please enter a question.", "", "", "", "", "", "" | |
| try: | |
| start_time = time.time() | |
| print("\n" + "="*50) | |
| print("=== NEW QUERY REQUEST ===") | |
| print(f"Time: {time.strftime('%Y-%m-%d %H:%M:%S')}") | |
| print(f"Question: {question}") | |
| print(f"Chart type: {chart_type}") | |
| print("="*50) | |
| # Prepare the request | |
| payload = { | |
| "question": question, | |
| "visualize": True | |
| } | |
| # Add chart_type if specified | |
| if chart_type and chart_type != "auto": | |
| payload["viz_type"] = chart_type | |
| print(f"Request URL: {FLASK_API_URL}/ask") | |
| print(f"Request payload: {payload}") | |
| # Make the request | |
| print("Sending request...") | |
| response = requests.post( | |
| f"{FLASK_API_URL}/ask", | |
| json=payload, | |
| timeout=300 | |
| ) | |
| elapsed_time = time.time() - start_time | |
| print(f"\n=== RESPONSE RECEIVED ===") | |
| print(f"Response time: {elapsed_time:.2f} seconds") | |
| print(f"Response status: {response.status_code}") | |
| print(f"Response headers: {dict(response.headers)}") | |
| # Check if response is empty | |
| if not response.text: | |
| error_msg = "Empty response from server" | |
| print(f"ERROR: {error_msg}") | |
| return "", None, None, "β οΈ Please enter a question.", "", "", "", "", "", "" | |
| # Try to parse JSON response | |
| try: | |
| result = response.json() | |
| print(f"Parsed JSON successfully: {type(result)}") | |
| except json.JSONDecodeError as e: | |
| error_msg = f"Invalid JSON response: {str(e)}" | |
| print(f"ERROR: {error_msg}") | |
| print(f"Response text (first 1000 chars): {response.text[:1000]}") | |
| return "", None, None, "β οΈ Please enter a question.", "", "", "", "", "", "" | |
| # Check HTTP status | |
| if response.status_code != 200: | |
| error_msg = result.get("error", f"HTTP {response.status_code}") | |
| print(f"ERROR: HTTP status {response.status_code}: {error_msg}") | |
| # For 500 errors, show more details | |
| if response.status_code == 500: | |
| error_details = f"Server Error (500): {error_msg}\n" | |
| error_details += f"Response: {json.dumps(result, indent=2)}" | |
| return "", None, None, "β οΈ Please enter a question.", "", "", "", error_details, "", "", "" | |
| return "", None, None, "β οΈ Please enter a question.", "", "", "", f"HTTP error: {error_msg}", "", "", "" | |
| # Extract data - Updated to match new response structure | |
| sql = result.get("sql", "") | |
| chart = result.get("chart", {}) # Changed from "visualization" to "chart" | |
| chart_generated = result.get("chart_generated", False) # New field | |
| row_count = result.get("row_count", 0) | |
| # Extract result data if available (for backward compatibility) | |
| rows = result.get("result", []) | |
| print(f"\n=== EXTRACTED DATA ===") | |
| print(f"SQL: {sql[:100] if sql else 'None'}...") | |
| print(f"Chart generated: {chart_generated}") | |
| print(f"Row count: {row_count}") | |
| print(f"Chart data: {chart}") | |
| # Create DataFrame | |
| df = pd.DataFrame(rows) if rows else pd.DataFrame() | |
| print(f"DataFrame shape: {df.shape}") | |
| # Process visualization | |
| chart_image = None | |
| chart_html = None | |
| chart_title = "" | |
| chart_type_result = "" | |
| chart_error = None | |
| chart_format = "" | |
| show_image = True | |
| show_html = False | |
| # Only process chart if chart_generated is True | |
| if chart_generated and chart: | |
| try: | |
| if isinstance(chart, dict): | |
| # Check if it's an error response | |
| if "error" in chart: | |
| chart_error = chart["error"] | |
| print(f"Chart error: {chart_error}") | |
| else: | |
| # Extract the base64 image string | |
| chart_image_b64 = chart.get("image") | |
| if chart_image_b64: | |
| try: | |
| # Handle base64 prefix | |
| if chart_image_b64.startswith("data:image/"): | |
| chart_image_b64 = chart_image_b64.split(",")[1] | |
| image_bytes = base64.b64decode(chart_image_b64) | |
| chart_image = Image.open(BytesIO(image_bytes)) | |
| print("Chart decoded successfully") | |
| chart_format = "png" | |
| show_image = True | |
| show_html = False | |
| except Exception as e: | |
| print(f"Error decoding chart: {e}") | |
| chart_error = f"Chart decoding error: {str(e)}" | |
| # Extract chart metadata | |
| chart_title = chart.get("title", "") | |
| chart_type_result = chart.get("type", "") | |
| chart_format = chart.get("format", chart_format) | |
| print(f"Chart type: {chart_type_result}") | |
| print(f"Chart title: {chart_title}") | |
| print(f"Chart format: {chart_format}") | |
| elif isinstance(chart, str): | |
| # Fallback for backward compatibility | |
| try: | |
| # Handle base64 prefix | |
| if chart.startswith("data:image/"): | |
| chart = chart.split(",")[1] | |
| image_bytes = base64.b64decode(chart) | |
| chart_image = Image.open(BytesIO(image_bytes)) | |
| print("Chart decoded successfully (fallback)") | |
| chart_format = "png" | |
| show_image = True | |
| show_html = False | |
| except Exception as e: | |
| print(f"Error decoding chart (fallback): {e}") | |
| chart_error = f"Chart decoding error: {str(e)}" | |
| except Exception as e: | |
| print(f"Error processing chart: {e}") | |
| chart_error = f"Chart processing error: {str(e)}" | |
| elif not chart_generated: | |
| print("Chart was not generated") | |
| chart_error = "Chart generation was not successful" | |
| else: | |
| print("No chart data available") | |
| chart_error = "No chart data available" | |
| # Prepare details | |
| details = f"Request time: {elapsed_time:.2f}s\n" | |
| details += f"Status code: {response.status_code}\n" | |
| details += f"Rows returned: {row_count}\n" | |
| details += f"Chart generated: {chart_generated}\n" | |
| if chart_type_result: | |
| details += f"Chart type: {chart_type_result}\n" | |
| if chart_title: | |
| details += f"Chart title: {chart_title}\n" | |
| if chart_format: | |
| details += f"Chart format: {chart_format}\n" | |
| details += f"Response size: {len(response.text)} bytes" | |
| print(f"=== REQUEST COMPLETED SUCCESSFULLY ===") | |
| return sql, df, chart_image, chart_html, show_image, show_html, f"β Query completed successfully", details, "Success", chart_type_result, chart_title, chart_format, chart_error | |
| except requests.exceptions.ConnectionError as e: | |
| error_msg = f"Connection failed: {str(e)}" | |
| print(f"CONNECTION ERROR: {error_msg}") | |
| print(f"Traceback: {traceback.format_exc()}") | |
| return "", None, None, "β οΈ Please enter a question.", "", True, False, f"β {error_msg}", "Connection error", "Connection error", "", "", "", "" | |
| except requests.exceptions.Timeout: | |
| error_msg = "Request timed out after 300 seconds" | |
| print(f"TIMEOUT ERROR: {error_msg}") | |
| return "", None, None, "β οΈ Please enter a question.", "", True, False, f"β±οΈ {error_msg}", "Timeout error", "Timeout error", "", "", "", "" | |
| except requests.exceptions.RequestException as e: | |
| error_msg = f"Request exception: {str(e)}" | |
| print(f"REQUEST ERROR: {error_msg}") | |
| print(f"Traceback: {traceback.format_exc()}") | |
| return "", None, None, "β οΈ Please enter a question.", "", True, False, f"β {error_msg}", "Request error", "Request error", "", "", "", "" | |
| except Exception as e: | |
| error_msg = f"Unexpected error: {str(e)}" | |
| print(f"UNEXPECTED ERROR: {error_msg}") | |
| print(f"Traceback: {traceback.format_exc()}") | |
| return "", None, None, "β οΈ Please enter a question.", "", True, False, f"π¨ {error_msg}", f"Error: {str(e)}", "Unexpected error", "", "", "", "" | |
def check_health():
    """Call the backend ``/health`` endpoint and summarise the result.

    Returns:
        A ``(message, status)`` pair. ``message`` is Markdown for the
        "API Status" panel; ``status`` is ``"success"`` when a 200 response
        with a JSON object body was received, otherwise ``"error"``.
    """
    try:
        print("Checking API health...")
        print(f"API URL: {FLASK_API_URL}/health")
        response = requests.get(f"{FLASK_API_URL}/health", timeout=10)
        print(f"Health check response status: {response.status_code}")
        print(f"Health check response headers: {dict(response.headers)}")
        print(f"Health check response text: {response.text}")

        if response.status_code != 200:
            # Two trailing spaces before "\n" force a Markdown line break.
            return f"❌ API returned status {response.status_code}  \nResponse: {response.text}", "error"

        try:
            health_data = response.json()
        except ValueError as e:
            # Covers json / simplejson / requests JSONDecodeError variants.
            error_msg = f"Failed to parse health check response: {str(e)}"
            print(f"JSON PARSE ERROR: {error_msg}")
            return f"❌ {error_msg}", "error"
        print(f"Parsed health data: {health_data}")

        if not isinstance(health_data, dict):
            error_msg = (
                "Unexpected health check response: expected a JSON object, "
                f"got {type(health_data).__name__}"
            )
            print(f"HEALTH CHECK ERROR: {error_msg}")
            return f"❌ {error_msg}", "error"

        status = str(health_data.get("status", "unknown"))
        tables = health_data.get("tables", [])
        model = health_data.get("model", "unknown")
        data_rows = health_data.get("data_rows", 0)

        if isinstance(tables, list):
            # Entries may be plain names or {"name": ...} dicts; render either
            # as text rather than letting str.join fail on non-strings.
            names = [str(t.get("name", t)) if isinstance(t, dict) else str(t) for t in tables]
            tables_str = ", ".join(names) if names else "None"
        else:
            tables_str = str(tables)

        # The "," format spec raises ValueError on non-numbers (e.g. None).
        rows_str = f"{data_rows:,}" if isinstance(data_rows, (int, float)) else str(data_rows)

        health_lines = [
            f"✅ API Status: {status.upper()}",
            f"🤖 Model: {model}",
            f"Tables: {tables_str}",
            f"Data Rows: {rows_str}",
        ]
        # Markdown merges single newlines into one paragraph; "  \n" is a hard break.
        return "  \n".join(health_lines), "success"

    except requests.exceptions.ConnectionError as e:
        error_msg = f"Connection to API failed: {str(e)}"
        print(f"CONNECTION ERROR: {error_msg}")
        print(f"Traceback: {traceback.format_exc()}")
        return f"❌ {error_msg}", "error"
    except requests.exceptions.Timeout:
        error_msg = "Health check request timed out"
        print(f"TIMEOUT ERROR: {error_msg}")
        return f"❌ {error_msg}", "error"
    except Exception as e:
        # UI boundary: surface any other failure as an error message.
        error_msg = f"Health check failed: {str(e)}"
        print(f"HEALTH CHECK ERROR: {error_msg}")
        print(f"Traceback: {traceback.format_exc()}")
        return f"❌ {error_msg}", "error"
def get_schema():
    """Fetch the table/column listing from the backend ``/tables`` endpoint.

    Returns:
        A ``(markdown, status)`` pair. On success ``markdown`` has one
        section per table with a single-column Markdown table of its columns,
        and ``status`` is ``"success"``; otherwise ``status`` is ``"error"``.
    """
    try:
        print("Fetching database schema...")
        print(f"API URL: {FLASK_API_URL}/tables")
        response = requests.get(f"{FLASK_API_URL}/tables", timeout=10)
        print(f"Schema response status: {response.status_code}")
        print(f"Schema response text: {response.text}")

        if response.status_code != 200:
            # Two trailing spaces before "\n" force a Markdown line break.
            return f"❌ API returned status {response.status_code}  \nResponse: {response.text}", "error"

        try:
            tables_data = response.json()
        except ValueError as e:
            # Covers json / simplejson / requests JSONDecodeError variants.
            error_msg = f"Failed to parse schema response: {str(e)}"
            print(f"JSON PARSE ERROR: {error_msg}")
            return f"❌ {error_msg}", "error"
        print(f"Parsed tables data: {tables_data}")

        if not isinstance(tables_data, dict):
            error_msg = (
                "Unexpected schema response: expected a JSON object, "
                f"got {type(tables_data).__name__}"
            )
            print(f"SCHEMA FETCH ERROR: {error_msg}")
            return f"❌ {error_msg}", "error"

        tables = tables_data.get("tables", []) or []
        parts = ["## Database Schema\n\n"]
        if not tables:
            parts.append("_No tables found._\n")
        for table in tables:
            if isinstance(table, dict):
                name = table.get("name", "Unknown")
                columns = table.get("columns", []) or []
            else:
                # Tolerate a bare table name instead of a {"name", "columns"} dict.
                name, columns = table, []
            parts.append(f"### {name}\n")
            parts.append("| Column |\n|--------|\n")
            for col in columns:
                # An unescaped "|" would split the cell and break the table.
                cell = str(col).replace("|", "\\|")
                parts.append(f"| {cell} |\n")
            parts.append("\n")
        return "".join(parts), "success"

    except requests.exceptions.ConnectionError as e:
        error_msg = f"Connection to API failed: {str(e)}"
        print(f"CONNECTION ERROR: {error_msg}")
        print(f"Traceback: {traceback.format_exc()}")
        return f"❌ {error_msg}", "error"
    except requests.exceptions.Timeout:
        error_msg = "Schema request timed out"
        print(f"TIMEOUT ERROR: {error_msg}")
        return f"❌ {error_msg}", "error"
    except Exception as e:
        # UI boundary: surface any other failure as an error message.
        error_msg = f"Failed to fetch schema: {str(e)}"
        print(f"SCHEMA FETCH ERROR: {error_msg}")
        print(f"Traceback: {traceback.format_exc()}")
        return f"❌ {error_msg}", "error"
# 🎨 Theme: enterprise dark blue.
# Variable overrides applied on top of the Default theme: dark page/secondary
# surfaces, blue primary buttons, dark secondary buttons and inputs. The "*…"
# strings refer to palette entries of the hues chosen below (Gradio theme
# variable syntax) rather than literal colors.
_THEME_OVERRIDES = {
    "body_background_fill": "*neutral_950",
    "background_fill_secondary": "*neutral_900",
    "button_primary_background_fill": "*primary_600",
    "button_primary_background_fill_hover": "*primary_700",
    "button_secondary_background_fill": "*neutral_800",
    "button_secondary_background_fill_hover": "*neutral_700",
    "block_title_text_color": "*primary_400",
    "block_label_text_color": "*neutral_300",
    "input_background_fill": "*neutral_800",
    "input_border_color": "*neutral_700",
}

_base_theme = gr.themes.Default(
    primary_hue="blue",
    secondary_hue="gray",
    neutral_hue="slate",
    font=["Inter", "sans-serif"],
)
theme = _base_theme.set(**_THEME_OVERRIDES)
# UI layout.
# Custom CSS: example-query buttons, response/chart/schema panels, status
# colors, and readable input/placeholder text on the dark theme.
APP_CSS = """
.example-btn {
    background-color: #1e40af;
    color: white;
    border-radius: 8px;
    padding: 8px 12px;
    margin: 4px;
    font-size: 0.9em;
    border: none;
    cursor: pointer;
    transition: all 0.2s;
}
.example-btn:hover {
    background-color: #1d4ed8;
    transform: translateY(-1px);
}
.chatbot-container {
    border: 1px solid #334155;
    border-radius: 8px;
    padding: 15px;
    background-color: #1e293b;
    min-height: 200px;
    max-height: 400px;
    overflow-y: auto;
}
.status-success {
    color: #10b981;
    font-weight: bold;
}
.status-error {
    color: #ef4444;
    font-weight: bold;
}
.status-warning {
    color: #f59e0b;
    font-weight: bold;
}
.chart-container {
    border: 1px solid #334155;
    border-radius: 8px;
    padding: 10px;
    background-color: #1e293b;
}
.chart-metadata {
    background-color: #1e293b;
    border: 1px solid #334155;
    border-radius: 8px;
    padding: 10px;
    margin-bottom: 10px;
}
.schema-container {
    background-color: #1e293b;
    border: 1px solid #334155;
    border-radius: 8px;
    padding: 15px;
    max-height: 400px;
    overflow-y: auto;
}
/* Fix input text color */
.gradio-container input[type="text"],
.gradio-container textarea {
    color: #f3f4f6 !important;
}
/* Fix placeholder color */
.gradio-container input::placeholder,
.gradio-container textarea::placeholder {
    color: #9ca3af !important;
}
"""

# Two-column layout: question input, examples and system tools on the left;
# the response, SQL/results/chart tabs and request/error details on the right.
with gr.Blocks(theme=theme, title="Enterprise SQL Assistant", css=APP_CSS) as demo:
    gr.HTML("""
    <div style="text-align: center; padding: 20px;">
        <h1 style="color: #3B82F6; margin-bottom: 5px;">Enterprise SQL Assistant</h1>
        <p style="color: #9CA3AF; font-size: 1.1em; margin-top: 0;">
            Ask questions about your data. Get SQL, results, and insights.
        </p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Ask a Question")
            question_input = gr.Textbox(
                placeholder="E.g., How many members are there?",
                label="Natural Language Query",
                lines=3,
            )
            with gr.Row():
                # No dashboard_mode control: the backend no longer supports it.
                chart_type_dropdown = gr.Dropdown(
                    label="Chart Type (Optional)",
                    choices=[
                        "auto", "bar", "line", "scatter", "pie", "histogram",
                        "time_series", "correlation",
                    ],
                    value="auto",
                    info="Force a specific chart type",
                )
            submit_btn = gr.Button("Generate SQL & Results", variant="primary", size="lg")

            gr.Markdown("### 💡 Example Queries")
            with gr.Row():
                with gr.Column():
                    example1 = gr.Button("How many members are there?", elem_classes=["example-btn"])
                    example2 = gr.Button("What is the total transaction amount?", elem_classes=["example-btn"])
                    example3 = gr.Button("Show members with their account balances", elem_classes=["example-btn"])
                    example4 = gr.Button("Which member has the highest balance?", elem_classes=["example-btn"])
                with gr.Column():
                    example5 = gr.Button("Show transaction trends over time", elem_classes=["example-btn"])
                    example6 = gr.Button("Count of members by status", elem_classes=["example-btn"])
                    example7 = gr.Button("Show distribution of transaction amounts", elem_classes=["example-btn"])
                    example8 = gr.Button("Show correlations between numeric fields", elem_classes=["example-btn"])

            with gr.Accordion("Advanced Options", open=False):
                gr.Markdown("### 🩺 System Information")
                with gr.Row():
                    health_btn = gr.Button("Check API Health", variant="secondary", size="sm")
                    schema_btn = gr.Button("Get Database Schema", variant="secondary", size="sm")
                health_output = gr.Markdown(label="API Status")
                schema_output = gr.Markdown(label="Database Schema", elem_classes="schema-container")

        with gr.Column(scale=2):
            gr.Markdown("### 🤖 AI Assistant Response")
            chatbot_output = gr.Markdown(
                label="AI Response",
                elem_classes=["chatbot-container"],
            )
            with gr.Tabs():
                with gr.Tab("SQL Query"):
                    sql_output = gr.Code(label="", language="sql")
                with gr.Tab("Data Results"):
                    results_output = gr.Dataframe(
                        label="Query Results",
                        interactive=False,
                        wrap=True,
                    )
                with gr.Tab("Visual Insights"):
                    with gr.Row():
                        with gr.Column(scale=3):
                            chart_output = gr.Image(
                                label="Chart",
                                type="pil",
                                height=400,
                                elem_classes="chart-container",
                            )
                            # HTML fallback for interactive charts; hidden until used.
                            html_output = gr.HTML(
                                label="Interactive Chart",
                                visible=False,
                            )
                        with gr.Column(scale=1):
                            gr.Markdown("### Chart Information")
                            chart_type_output = gr.Markdown(label="Chart Type", elem_classes="chart-metadata")
                            chart_title_output = gr.Markdown(label="Chart Title", elem_classes="chart-metadata")
                            chart_format_output = gr.Markdown(label="Chart Format", elem_classes="chart-metadata")
                            chart_error_output = gr.Markdown(label="Chart Error", elem_classes="chart-metadata")
            gr.Markdown("### Request Details")
            request_details = gr.Textbox(label="Request Details", interactive=False, lines=6)
            gr.Markdown("### Error Details")
            error_details = gr.Textbox(label="Error Details", interactive=False, lines=4)

    def set_example_query(example_text):
        """Copy the clicked example button's label into the question box."""
        return example_text

    def run_query(question, chart_type):
        """Adapter between the submit button and ``query_database``.

        Passes the dropdown value by keyword (positionally it would land in
        query_database's unused ``dashboard_mode`` parameter) and turns the
        returned show-image / show-HTML flags into visibility updates for the
        chart components instead of writing them into hidden placeholder
        fields. Returns 11 values, one per component in the submit outputs.
        """
        sql, df, chart_image, chart_html, show_image, show_html, *rest = query_database(
            question, chart_type=chart_type
        )
        return (
            sql,
            df,
            gr.update(value=chart_image, visible=bool(show_image)),
            gr.update(value=chart_html, visible=bool(show_html)),
            *rest,
        )

    # Events
    health_btn.click(
        fn=check_health,
        inputs=[],
        outputs=[health_output, error_details],
    )
    schema_btn.click(
        fn=get_schema,
        inputs=[],
        outputs=[schema_output, error_details],
    )
    # A Button used as an input supplies its label as the value.
    for example_btn in (example1, example2, example3, example4, example5, example6, example7, example8):
        example_btn.click(
            fn=set_example_query,
            inputs=[example_btn],
            outputs=[question_input],
        )
    submit_btn.click(
        fn=run_query,
        inputs=[question_input, chart_type_dropdown],
        outputs=[
            sql_output,
            results_output,
            chart_output,
            html_output,
            chatbot_output,
            request_details,
            error_details,
            chart_type_output,
            chart_title_output,
            chart_format_output,
            chart_error_output,
        ],
    )
# Launch: start the Gradio server only when this file is run as a script,
# so importing the module does not start a server.
def main():
    """Start the Gradio server for the ``demo`` Blocks app."""
    demo.launch()


if __name__ == "__main__":
    main()