Spaces:
Sleeping
Sleeping
""" | |
AI Database Assistant - Streamlit Chat Interface | |
""" | |
import streamlit as st | |
import pandas as pd | |
import base64 | |
import logging | |
from typing import List, Dict, Any | |
from api_client import APIClient | |
from themes import ThemeManager | |
from config import ( | |
BASE_URL, PAGE_TITLE, PAGE_LAYOUT, AVAILABLE_MODELS, AVAILABLE_AGENTS, | |
OUTLINE_INDIGO_USER, DARK_MODE_SLATE_AI, AVAILABLE_THEMES, | |
CHAT_INPUT_PLACEHOLDER, THINKING_MESSAGE, WORKING_MESSAGE, RETRY_BUTTON_TEXT, DOWNLOAD_BUTTON_TEXT | |
) | |
# ===============================
# Configuration
# ===============================
# Every setting comes from config.py; to point the app at another
# environment, change BASE_URL there. This section performs page setup and
# logging configuration, then builds the shared service objects.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
st.set_page_config(layout=PAGE_LAYOUT, page_title=PAGE_TITLE)
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Shared client for backend requests, bound to the configured BASE_URL.
api_client = APIClient(BASE_URL)

# ===============================
# Theme Management
# ===============================
# Used later to inject the selected theme's styling into the page.
theme_manager = ThemeManager()
# ===============================
# Database Status Functions
# ===============================
# NOTE(review): no caching decorator is applied to this helper, so every call
# reaches the API; st.cache_data.clear() has no effect on it.
def check_api_status():
    """Run the API client's lightweight health check.

    Returns:
        Whatever ``api_client.check_health()`` returns. The sidebar unpacks
        this as ``(status_text, status_delta, status_type)``. If the call
        raises, the error is logged and the fallback triple
        ``("🔴 Error", "Failed: <error>", "error")`` is returned instead.
    """
    try:
        health = api_client.check_health()
    except Exception as exc:
        logging.error(f"Health check failed: {exc}")
        return "🔴 Error", f"Failed: {str(exc)}", "error"
    return health
# Detailed health probe (no caching is applied to this helper).
def get_detailed_health_status():
    """Fetch the API's detailed health report.

    Returns:
        The value of ``api_client.get_detailed_health()``. If that call
        raises, the error is logged and a fallback dict is returned with
        ``status`` set to ``"error"``, a ``message`` describing the failure,
        and an empty ``checks`` mapping.
    """
    try:
        report = api_client.get_detailed_health()
    except Exception as exc:
        logging.error(f"Detailed health check failed: {exc}")
        report = {
            "status": "error",
            "message": f"Health check failed: {str(exc)}",
            "checks": {},
        }
    return report
def should_skip_api_call(force_refresh=False):
    """Decide whether the sidebar health check should be skipped this run.

    Intended for health checks only (query processing has its own gate).
    Rules are applied in this order:

    1. ``force_refresh`` is true -> never skip.
    2. ``force_immediate_health_check`` is set (System Status was just
       enabled) -> never skip. The flag is one-shot and is reset here.
    3. A query is being processed -> skip.
    4. The System Status section is hidden -> skip.
    5. A UI interaction was recorded less than 5 seconds ago -> skip.
    6. No cached status exists yet -> do not skip.
    7. The last status update is less than 30 seconds old -> skip;
       otherwise do not skip.

    Any unexpected exception is logged and treated as "do not skip".

    Args:
        force_refresh: Bypass every other rule when true.

    Returns:
        True to skip the API call, False to perform it.
    """
    try:
        if force_refresh:
            logging.info("API call allowed: Force refresh requested")
            return False

        state = st.session_state

        if state.get("force_immediate_health_check", False):
            logging.info("API call allowed: System Status just enabled")
            state["force_immediate_health_check"] = False  # consume one-shot flag
            return False

        if state.get("processing_query", False):
            logging.info("API call skipped: Query processing in progress")
            return True

        sidebar_settings = state.get("sidebar_settings", {})
        if not sidebar_settings.get("show_system_status", False):
            logging.info("API call skipped: System Status section is disabled")
            return True

        if "recent_ui_action" in state:
            elapsed = pd.Timestamp.now().timestamp() - state.get("recent_ui_action", 0)
            if elapsed < 5:
                logging.info(f"API call blocked: UI interaction {elapsed:.1f}s ago")
                return True

        now = pd.Timestamp.now().timestamp()
        last_update = state.get("last_status_update", 0)

        if "last_api_status" not in state:
            logging.info("API call allowed: No cached status available")
            return False

        since_update = now - last_update
        if since_update < 30:
            logging.info(f"API call skipped: Rate limit (updated {since_update:.1f}s ago)")
            return True

        logging.info("API call allowed: Rate limit passed")
        return False
    except Exception as e:
        logging.error(f"Error in should_skip_api_call: {e}")
        return False  # fail open: allow the API call
def should_skip_query_processing():
    """Decide whether query processing should be skipped this run.

    Intended for query processing only (health checks have their own gate).
    Rules are applied in this order:

    1. ``legitimate_query_time`` is less than 10 seconds old -> do not skip.
    2. A query is already being processed -> skip.
    3. There is no ``legitimate_query_time`` and a UI interaction was
       recorded less than 5 seconds ago -> skip.
    4. Otherwise -> do not skip.

    Any unexpected exception is logged and treated as "do not skip".

    Returns:
        True to skip processing, False to proceed.
    """
    try:
        state = st.session_state
        has_legit_query = "legitimate_query_time" in state

        if has_legit_query:
            since_query = pd.Timestamp.now().timestamp() - state.get("legitimate_query_time", 0)
            if since_query < 10:
                logging.info(f"Query processing allowed: Legitimate query {since_query:.1f}s ago")
                return False

        if state.get("processing_query", False):
            logging.info("Query processing skipped: Already processing a query")
            return True

        if "recent_ui_action" in state and not has_legit_query:
            since_action = pd.Timestamp.now().timestamp() - state.get("recent_ui_action", 0)
            if since_action < 5:
                logging.info(f"Query processing blocked: UI interaction {since_action:.1f}s ago without legitimate query")
                return True

        logging.info("Query processing allowed: No blocking conditions")
        return False
    except Exception as e:
        logging.error(f"Error in should_skip_query_processing: {e}")
        return False  # fail open: allow processing
def silent_status_update():
    """Best-effort status refresh that never surfaces errors in the UI.

    When ``should_skip_api_call()`` allows it, this function clears
    Streamlit's data cache and stamps ``last_status_update`` with the current
    time. It does not itself run a health check.

    NOTE: calling ``should_skip_api_call()`` consumes the one-shot
    ``force_immediate_health_check`` flag, even though no health check
    happens here.
    """
    try:
        if should_skip_api_call():
            return
        st.cache_data.clear()
        st.session_state["last_status_update"] = pd.Timestamp.now().timestamp()
    except Exception as exc:
        # Deliberately best-effort. A bare ``except:`` would also swallow
        # KeyboardInterrupt/SystemExit and other BaseExceptions; keep a
        # debug trace so failures are not invisible.
        logging.debug("Silent status update failed: %s", exc)
def _ensure_query_stats():
    """Return this session's query-stats dict, creating a zeroed one if absent."""
    if "query_stats" not in st.session_state:
        st.session_state["query_stats"] = {
            "total_queries": 0,
            "successful_queries": 0,
            "failed_queries": 0,
            "session_start": pd.Timestamp.now(),
        }
    return st.session_state["query_stats"]


def get_query_stats():
    """Return query statistics for the current session.

    Returns:
        A tuple ``(total_queries, success_rate, successful_queries,
        failed_queries)``. ``success_rate`` is a percentage rounded to one
        decimal place, and is ``0`` when no queries have been recorded.
    """
    stats = _ensure_query_stats()
    total = stats["total_queries"]
    success_rate = 0
    if total > 0:
        success_rate = round((stats["successful_queries"] / total) * 100, 1)
    return total, success_rate, stats["successful_queries"], stats["failed_queries"]


def update_query_stats(success=True):
    """Record one finished query in the session statistics.

    Args:
        success: Count the query as successful when true, otherwise as failed.
    """
    stats = _ensure_query_stats()
    stats["total_queries"] += 1
    outcome_key = "successful_queries" if success else "failed_queries"
    stats[outcome_key] += 1
# ===============================
# Session State Management
# ===============================
# Seed per-session defaults on the first run; existing values are left as-is.
for _state_key, _state_default in (
    ("messages", []),
    (
        "query_stats",
        {
            "total_queries": 0,
            "successful_queries": 0,
            "failed_queries": 0,
            "session_start": pd.Timestamp.now(),
        },
    ),
    ("processing_query", False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# ===============================
# Main App Setup
# ===============================
# Page header at the top-left of the chat area: the configured title,
# followed by a one-line subtitle.
_header_html = f"""
<div style=" display: flex; align-items: center; padding-left: 0.5rem;">
<h2 style="margin: 0; font-size: 1.5rem; font-weight: 600; color: inherit;">
💬 {PAGE_TITLE}
</h2>
</div>
"""
_subtitle_html = """
<div style=" padding-left: 0.5rem;">
<p style="margin: 0; font-size: 0.9rem; opacity: 0.7; color: inherit;">
Ask questions in plain English to generate and run SQL queries.
</p>
</div>
"""
st.markdown(_header_html, unsafe_allow_html=True)
st.markdown(_subtitle_html, unsafe_allow_html=True)
# ===============================
# Sidebar
# ===============================
def _mark_ui_action():
    """Stamp ``recent_ui_action`` with the current time.

    Used as the ``on_change`` callback of the sidebar selectors, and called
    by the quick-action buttons, so the timestamp only moves on a real user
    interaction. ``should_skip_api_call`` and ``should_skip_query_processing``
    treat an interaction newer than 5 seconds as a reason to skip work.
    """
    st.session_state["recent_ui_action"] = pd.Timestamp.now().timestamp()


with st.sidebar:
    # Placeholder for the one-line visibility summary at the top of the
    # sidebar. It is filled in after the settings checkboxes are read, so it
    # reflects this run's values instead of lagging one rerun behind.
    quick_status = st.empty()

    # ---- Sidebar Display Settings (user configurable) ----
    if "settings_interaction_count" not in st.session_state:
        st.session_state["settings_interaction_count"] = 0
    # Keep the settings expander open once the user has changed a setting.
    keep_expanded = st.session_state.get("settings_interaction_count", 0) > 0
    settings_expander = st.expander(
        "⚙️ Sidebar Settings",
        expanded=keep_expanded
    )
    with settings_expander:
        st.markdown("**Choose what to display:**")
        if "sidebar_settings" not in st.session_state:
            st.session_state["sidebar_settings"] = {
                "show_model_selection": True,   # shown by default
                "show_agent_selection": False,  # hidden by default
                "show_theme_selection": True,   # shown by default
                # Hidden by default: health checks only run while visible.
                "show_system_status": False,
                "show_tips": True,              # shown by default
            }
        # Snapshot the previous values to detect what changed this run.
        prev_settings = st.session_state["sidebar_settings"].copy()

        col1, col2 = st.columns(2)
        with col1:
            show_model = st.checkbox(
                "🤖 AI Model",
                value=st.session_state["sidebar_settings"]["show_model_selection"],
                help="Show/hide AI model selection",
                key="settings_model"
            )
            show_agent = st.checkbox(
                "🎯 Agent Type",
                value=st.session_state["sidebar_settings"]["show_agent_selection"],
                help="Show/hide agent selection",
                key="settings_agent"
            )
            show_theme = st.checkbox(
                "🎨 Theme",
                value=st.session_state["sidebar_settings"]["show_theme_selection"],
                help="Show/hide theme selection",
                key="settings_theme"
            )
        with col2:
            show_status = st.checkbox(
                "📊 System Status",
                value=st.session_state["sidebar_settings"]["show_system_status"],
                help="Show/hide system status",
                key="settings_status"
            )
            show_tips = st.checkbox(
                "💡 Tips & Help",
                value=st.session_state["sidebar_settings"]["show_tips"],
                help="Show/hide tips section",
                key="settings_tips"
            )

        new_settings = {
            "show_model_selection": show_model,
            "show_agent_selection": show_agent,
            "show_theme_selection": show_theme,
            "show_system_status": show_status,
            "show_tips": show_tips
        }
        settings_changed = any(
            prev_settings.get(name) != new_settings[name]
            for name in new_settings
        )
        # System Status going from hidden to visible gets special handling.
        status_just_enabled = (
            not prev_settings.get("show_system_status", False) and
            new_settings["show_system_status"]
        )
        st.session_state["sidebar_settings"].update(new_settings)

        if settings_changed:
            st.session_state["settings_interaction_count"] += 1
            # Treat other settings changes as UI actions so they do not
            # trigger health checks; enabling System Status must not be blocked.
            if not status_just_enabled:
                _mark_ui_action()
            # Keep the counter bounded.
            if st.session_state["settings_interaction_count"] > 10:
                st.session_state["settings_interaction_count"] = 5

        # System Status just enabled: drop cached status and request an
        # immediate health check on this run.
        if status_just_enabled:
            st.cache_data.clear()
            for key in ["last_api_status", "last_api_delta", "last_api_type", "last_detailed_health"]:
                if key in st.session_state:
                    del st.session_state[key]
            st.session_state["last_status_update"] = 0
            st.session_state["force_immediate_health_check"] = True
            if "recent_ui_action" in st.session_state:
                del st.session_state["recent_ui_action"]

        visible_count = sum([show_model, show_agent, show_theme, show_status, show_tips])
        if visible_count == 5:
            st.success(f"✅ All {visible_count} sections visible")
        elif visible_count > 0:
            st.info(f"ℹ️ {visible_count}/5 sections visible")
        else:
            st.warning("⚠️ No sections visible")

    # Fill the top-of-sidebar summary now that this run's settings are known.
    if visible_count == 5:
        quick_status.caption("🟢 All sections visible")
    elif visible_count > 0:
        quick_status.caption(f"🟡 {visible_count}/5 sections visible")
    else:
        quick_status.caption("🔴 No sections visible")

    # Use the stored settings for conditional rendering below.
    show_model = st.session_state["sidebar_settings"]["show_model_selection"]
    show_agent = st.session_state["sidebar_settings"]["show_agent_selection"]
    show_theme = st.session_state["sidebar_settings"]["show_theme_selection"]
    show_status = st.session_state["sidebar_settings"]["show_system_status"]
    show_tips = st.session_state["sidebar_settings"]["show_tips"]
    st.markdown("---")

    # ---- Model Selection ----
    if show_model:
        st.markdown("### 🤖 AI Model")
        # Static list of available models - no API calls needed.
        # NOTE(review): this list is independent of config.AVAILABLE_MODELS,
        # whose first entry is used when this section is hidden; keep them in sync.
        available_models = [
            "gpt-4o-mini",
            "gpt-3.5-turbo",
            "gemini-1.5-pro",
            "gemini-1.5-flash",
            "claude-3-haiku",
            "claude-3-sonnet",
            "mistral-small",
            "mistral-medium",
            "mistral-large",
        ]
        # A selectbox always returns a value, so testing its result marked a
        # "UI action" on every rerun and suppressed health checks. Only an
        # actual change now records one, via on_change.
        model = st.selectbox(
            "Choose your AI model:",
            available_models,
            index=0,  # Default to first model
            help="Select the AI model for processing your queries",
            key="model_selector",
            on_change=_mark_ui_action,
        )
        model_descriptions = {
            "gpt-4o-mini": "⚡ Fast & cost-effective OpenAI model",
            "gpt-3.5-turbo": "🔥 Reliable & quick OpenAI model",
            "gemini-pro": "💎 Google's powerful Gemini model",
            "gemini-1.5-pro": "🔬 Google's latest Gemini model",
            "gemini-1.5-flash": "⚡ Google's fast Gemini model",
            "claude-3-haiku": "🌸 Anthropic's efficient Claude model",
            "claude-3-sonnet": "🎵 Anthropic's balanced Claude model",
            "mistral-small": "🎯 Mistral's efficient model",
            "mistral-medium": "⚖️ Mistral's balanced model",
            "mistral-large": "🦾 Mistral's most capable model",
        }
        description = model_descriptions.get(model, "🤖 Advanced AI model")
        st.info(description)
        provider_info = {
            "gpt-4o-mini": "🏢 OpenAI",
            "gpt-3.5-turbo": "🏢 OpenAI",
            "gemini-pro": "🔍 Google",
            "gemini-1.5-pro": "🔍 Google",
            "gemini-1.5-flash": "🔍 Google",
            "claude-3-haiku": "🤖 Anthropic",
            "claude-3-sonnet": "🤖 Anthropic",
            "mistral-small": "⚡ Mistral AI",
            "mistral-medium": "⚡ Mistral AI",
            "mistral-large": "⚡ Mistral AI",
        }
        provider = provider_info.get(model, "🤖 AI Provider")
        st.caption(f"Provider: {provider}")
        # Setup hints per provider.
        if model.startswith("gemini"):
            st.caption("💡 Requires GOOGLE_API_KEY in .env")
        elif model.startswith("claude"):
            st.caption("💡 Requires ANTHROPIC_API_KEY in .env")
        elif model.startswith("mistral"):
            st.caption("💡 Requires MISTRAL_API_KEY in .env")
        elif model in ["llama3.2", "llama3.1", "codellama", "phi3"]:
            st.caption("💡 Requires Ollama installed locally")
        else:
            st.caption("💡 Requires OPENAI_API_KEY in .env")
        st.markdown("---")
    else:
        # Use default model when hidden
        model = AVAILABLE_MODELS[0]

    # ---- Agent Selection ----
    if show_agent:
        st.markdown("### 🎯 Agent Type")
        agent = st.selectbox(
            "Choose your agent:",
            AVAILABLE_AGENTS,
            help="Select the specialized agent for your database tasks",
            key="agent_selector",
            on_change=_mark_ui_action,
        )
        agent_info = {
            "default": "🔧 General-purpose database assistant",
            "sql-agent": "💾 Specialized in SQL optimization",
            "custom-agent": "🎨 Customized for specific workflows"
        }
        st.info(agent_info.get(agent, "Specialized database agent"))
        st.markdown("---")
    else:
        # Use default agent when hidden
        agent = AVAILABLE_AGENTS[0]

    # ---- Theme Selection ----
    if show_theme:
        st.markdown("### 🎨 Appearance")
        theme = st.selectbox(
            "Choose your theme:",
            AVAILABLE_THEMES,
            help="Customize the visual appearance",
            key="theme_selector",
            on_change=_mark_ui_action,
        )
        st.markdown("---")
    else:
        # Use default theme when hidden
        theme = AVAILABLE_THEMES[0]

    # ---- System Status ----
    if show_status:
        st.markdown("### 📊 System Status")
        if st.session_state.get("processing_query", False):
            # Skip the health check while a query runs to avoid slowing it down.
            st.info("⏳ Status check paused during query processing")
            total_queries, success_rate, successful, failed = get_query_stats()
            col1, col2 = st.columns(2)
            with col1:
                st.metric("API Status", "⏳ Processing", delta="Query in progress")
            with col2:
                st.metric("Total Queries", total_queries, delta=f"+{total_queries}")
        else:
            is_force_refresh = st.session_state.get("force_refresh_requested", False)
            if not should_skip_api_call(force_refresh=is_force_refresh):
                logging.info("=== HEALTH CHECK API CALL STARTING ===")
                try:
                    status_text, status_delta, status_type = check_api_status()
                    detailed_health = get_detailed_health_status()
                    # Cache the results for reruns that skip the API call.
                    st.session_state["last_api_status"] = status_text
                    st.session_state["last_api_delta"] = status_delta
                    st.session_state["last_api_type"] = status_type
                    st.session_state["last_detailed_health"] = detailed_health
                    st.session_state["last_status_update"] = pd.Timestamp.now().timestamp()
                    logging.info("=== HEALTH CHECK API CALL COMPLETED ===")
                    if is_force_refresh:
                        st.session_state["force_refresh_requested"] = False
                except Exception as e:
                    logging.error(f"Status check failed: {e}")
                    status_text = "🔴 Failed"
                    status_delta = f"Error: {str(e)[:50]}..."
                    status_type = "error"
                    detailed_health = {
                        "status": "error",
                        "message": f"Status check failed: {str(e)}",
                        "checks": {}
                    }
                    if is_force_refresh:
                        st.session_state["force_refresh_requested"] = False
            else:
                # Reuse the cached values, with placeholders before the first check.
                status_text = st.session_state.get("last_api_status", "🟡 Loading...")
                status_delta = st.session_state.get("last_api_delta", "Initializing...")
                status_type = st.session_state.get("last_api_type", "normal")
                detailed_health = st.session_state.get("last_detailed_health", {
                    "status": "unknown",
                    "message": "Loading system status...",
                    "checks": {}
                })

            total_queries, success_rate, successful, failed = get_query_stats()
            col1, col2 = st.columns(2)
            with col1:
                st.metric("API Status", status_text, delta=status_delta)
            with col2:
                st.metric("Total Queries", total_queries, delta=f"+{total_queries}")

            if status_type == "success":
                st.success("✅ All systems operational")
            elif status_type == "warning":
                st.warning("⚠️ Limited functionality - some features may be slow")
            else:
                st.error("❌ API connection failed - please check your backend service")

            with st.expander("🔍 Detailed System Health", expanded=False):
                if detailed_health.get("status") in ["healthy", "unhealthy"]:
                    if "timestamp" in detailed_health:
                        st.caption(f"Last checked: {detailed_health['timestamp']}")
                    checks = detailed_health.get("checks", {})
                    if "database" in checks:
                        db_check = checks["database"]
                        db_status = db_check.get("status", "unknown")
                        db_message = db_check.get("message", "No information")
                        if db_status == "healthy":
                            st.success(f"🗃️ Database: {db_message}")
                        elif db_status == "unhealthy":
                            st.error(f"🗃️ Database: {db_message}")
                        else:
                            st.warning(f"🗃️ Database: {db_message}")
                    if "openai_api" in checks:
                        api_check = checks["openai_api"]
                        api_status = api_check.get("status", "unknown")
                        api_message = api_check.get("message", "No information")
                        if api_status == "configured":
                            st.success(f"🤖 OpenAI API: {api_message}")
                        elif api_status == "error":
                            st.error(f"🤖 OpenAI API: {api_message}")
                        else:
                            st.warning(f"🤖 OpenAI API: {api_message}")
                    if "version" in detailed_health:
                        st.info(f"📦 Version: {detailed_health['version']}")
                else:
                    st.error(f"❌ Health check failed: {detailed_health.get('message', 'Unknown error')}")
                    st.caption("Unable to retrieve detailed system status")

        # Additional stats, shown whether or not a query is processing.
        if total_queries > 0:
            col3, col4 = st.columns(2)
            with col3:
                st.metric("Success Rate", f"{success_rate}%", delta=f"{successful} successful")
            with col4:
                session_duration = pd.Timestamp.now() - st.session_state["query_stats"]["session_start"]
                hours = int(session_duration.total_seconds() // 3600)
                minutes = int((session_duration.total_seconds() % 3600) // 60)
                st.metric("Session Time", f"{hours}h {minutes}m", delta="Active")

        # Show when the status was actually last fetched, not the current time.
        # last_status_update is written with pd.Timestamp.now().timestamp(),
        # so converting back with unit="s" yields the same wall-clock time.
        last_status_ts = st.session_state.get("last_status_update", 0)
        if last_status_ts:
            last_updated_label = pd.Timestamp(last_status_ts, unit="s").strftime('%H:%M:%S')
        else:
            last_updated_label = "not yet"
        col_refresh1, col_refresh2 = st.columns([2, 1])
        with col_refresh1:
            st.caption(f"🔄 Last updated: {last_updated_label}")
        with col_refresh2:
            if st.button("🔄", help="Force Refresh Status", key="refresh_status"):
                # Bypass validation on the next run and drop every cached value.
                st.session_state["force_refresh_requested"] = True
                st.cache_data.clear()
                st.session_state["last_status_update"] = 0
                if "recent_ui_action" in st.session_state:
                    del st.session_state["recent_ui_action"]
                for key in ["last_api_status", "last_api_delta", "last_api_type", "last_detailed_health"]:
                    if key in st.session_state:
                        del st.session_state[key]
                st.rerun()
        st.markdown("---")

    # ---- Tips and Help ----
    if show_tips:
        st.markdown("### 💡 Quick Tips")
        with st.expander("📝 How to ask questions"):
            st.markdown("""
            - **"Show me all customers from Chicago"**
            - **"What are the top 5 branches by transactions?"**
            - **"Calculate total transactions by month"**
            - **"Find customers who haven't done any transactions recently"**
            """)
        with st.expander("⚡ Pro Tips"):
            st.markdown("""
            - Be specific about what data you want
            - Mention date ranges when relevant
            - Ask for summaries or aggregations
            - Use natural language - no SQL needed!
            """)
        with st.expander("🔧 Troubleshooting"):
            st.markdown("""
            - **No results?** Try rephrasing your question
            - **Error message?** Use the retry button
            - **Slow response?** Check your connection
            - **Wrong data?** Be more specific in your query
            """)
        st.markdown("---")

    # ---- Quick actions ----
    st.markdown("### 🚀 Quick Actions")
    is_processing = st.session_state.get("processing_query", False)
    if is_processing:
        st.info("⚡ Query in progress... All action buttons are temporarily disabled.")
    col_action1, col_action2 = st.columns(2)
    with col_action1:
        clear_chat_disabled = is_processing
        clear_chat_help = "Cannot clear chat while query is processing" if is_processing else "Clear all chat messages"
        if st.button("🗑️ Clear Chat",
                     use_container_width=True,
                     disabled=clear_chat_disabled,
                     help=clear_chat_help):
            _mark_ui_action()  # blocks API calls for 5 seconds
            st.session_state["messages"] = []
            logging.info("Clear Chat button clicked - blocking API calls for 5 seconds")
            st.rerun()
    with col_action2:
        reset_stats_disabled = is_processing
        reset_stats_help = "Cannot reset stats while query is processing" if is_processing else "Reset query statistics"
        if st.button("📊 Reset Stats",
                     use_container_width=True,
                     disabled=reset_stats_disabled,
                     help=reset_stats_help):
            _mark_ui_action()  # blocks API calls for 5 seconds
            st.session_state["query_stats"] = {
                "total_queries": 0,
                "successful_queries": 0,
                "failed_queries": 0,
                "session_start": pd.Timestamp.now()
            }
            logging.info("Reset Stats button clicked - blocking API calls for 5 seconds")
            st.rerun()

    sample_disabled = is_processing
    sample_help = "Cannot show samples while query is processing" if is_processing else "Show sample queries you can copy"
    if st.button("📋 Sample Queries",
                 use_container_width=True,
                 disabled=sample_disabled,
                 help=sample_help):
        _mark_ui_action()  # blocks API calls for 5 seconds
        logging.info("Sample Queries button clicked - blocking API calls for 5 seconds")
        sample_queries = [
            "Show me the top 10 customers by transactions",
            "Which branches collected more transactions last month?",
            "Calculate average transactions value",
            "List all active customers"
        ]
        # Shown in the sidebar rather than added to the chat.
        st.markdown("**Sample queries you can copy and paste:**")
        for query in sample_queries:
            st.code(query)

    # ---- Footer ----
    st.markdown("### ℹ️ About")
    st.markdown("""
    **AI Database Assistant** v2.0
    🚀 Powered by advanced AI
    💬 Natural language to SQL
    📈 Real-time analytics
    """)

# Apply theme
theme_manager.inject_theme(theme)
# ===============================
# Chat Rendering Function
# ===============================
def _chat_row_html(avatar_src, avatar_alt, avatar_border, bubble_class, content, align_end=False):
    """Return the HTML for one chat row: a round avatar followed by a message bubble.

    Args:
        avatar_src: Image source for the avatar.
        avatar_alt: Alt text for the avatar image.
        avatar_border: CSS colour of the avatar's border.
        bubble_class: CSS class of the bubble (e.g. "user-bubble", "ai-bubble").
        content: Bubble contents. It is inserted verbatim, so callers must
            escape any untrusted text first.
        align_end: Right-align the row (used for user turns).
    """
    justify = " justify-content: flex-end;" if align_end else ""
    return (
        f'<div style="display: flex; align-items: flex-start;{justify} margin-bottom: 0.5em;">'
        '<div style="margin-right: 0.5em;">'
        f'<img src="{avatar_src}" alt="{avatar_alt}" style="width: 2.3rem; height: 2.3rem; '
        f'border-radius: 50%; border: 2px solid {avatar_border}; background: #fff; object-fit: cover;" />'
        '</div>'
        f'<div class="{bubble_class}">{content}</div>'
        '</div>'
    )


def _render_retry_button(msg, index):
    """Show a Retry button under the error reply ``msg`` at history position ``index``.

    On click, the history is truncated back to the user turn recorded on the
    error reply. A fresh placeholder is appended and the script reruns, so the
    same query is sent again.
    """
    # The narrow first column mirrors the avatar so the button lines up with the bubble.
    spacer_col, button_col = st.columns([0.03, 0.85])
    with spacer_col:
        st.empty()
    with button_col:
        if not st.button(RETRY_BUTTON_TEXT, key=f"retry_error_{index}"):
            return
        stored_user_query = msg.get("user_query")
        stored_user_index = msg.get("user_query_index")
        # Both fields are filled in only on error replies written by the API handler.
        if not stored_user_query or stored_user_index is None:
            return
        # Drop everything after the user turn, then queue a new placeholder reply.
        st.session_state["messages"] = st.session_state["messages"][:stored_user_index + 1]
        st.session_state["messages"].append(
            {"role": "assistant", "content": "� Processing your query...", "is_placeholder": True}
        )
        # Mark this as a legitimate retry, not a UI interaction, and clear any
        # recent UI-action marker so the query is allowed to process.
        st.session_state["legitimate_query_time"] = pd.Timestamp.now().timestamp()
        st.session_state.pop("recent_ui_action", None)
        st.rerun()


def _render_result_attachments(msg, index):
    """Render a reply's optional result table (with CSV download) and chart."""
    if msg.get("data"):
        df = pd.DataFrame(msg["data"])
        st.dataframe(df, use_container_width=True)
        csv = df.to_csv(index=False).encode("utf-8")
        # Deterministic key based on history position, matching the retry-button keys.
        st.download_button(
            DOWNLOAD_BUTTON_TEXT, csv, "results.csv", "text/csv", key=f"download_csv_{index}"
        )
    if msg.get("chart"):
        # The chart is stored base64-encoded on the message.
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width; check the pinned Streamlit version.
        st.image(base64.b64decode(msg["chart"]), use_column_width=True)


def render_chat():
    """Render the chat history stored in ``st.session_state["messages"]``.

    - User turns are HTML-escaped before embedding. They come straight from
      the chat box and are rendered with ``unsafe_allow_html=True``.
    - Assistant placeholders are skipped; a spinner is shown instead while the
      query runs.
    - Error replies get a Retry button.
    - Successful replies show their result table, CSV download and chart when
      present.
    """
    import html  # local import: the module's import block is outside this section

    # NOTE(review): each st.markdown call is rendered as its own element, so this
    # wrapper div probably does not enclose the rows below; confirm against the theme CSS.
    st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
    for i, msg in enumerate(st.session_state["messages"]):
        role = msg["role"]
        if role == "user":
            safe_text = html.escape(str(msg["content"]))
            st.markdown(
                _chat_row_html(OUTLINE_INDIGO_USER, "User", "#e3f2fd", "user-bubble", safe_text, align_end=True),
                unsafe_allow_html=True,
            )
        elif role == "assistant" and not msg.get("is_placeholder"):
            is_error = bool(msg.get("is_error"))
            bubble_class = "error-bubble" if is_error else "ai-bubble"
            # NOTE(review): assistant text is inserted unescaped so markdown/markup
            # from the API keeps rendering; escape it too if the backend is untrusted.
            st.markdown(
                _chat_row_html(DARK_MODE_SLATE_AI, "AI", "#b2dfdb", bubble_class, msg["content"]),
                unsafe_allow_html=True,
            )
            if is_error:
                _render_retry_button(msg, i)
            else:
                _render_result_attachments(msg, i)
    st.markdown("</div>", unsafe_allow_html=True)
# Render chat
render_chat()

# ===============================
# User Input Handling
# ===============================
# A trailing assistant placeholder means a query is still waiting for its
# answer, so the chat box stays locked until the API handler replaces it.
_last_message = st.session_state["messages"][-1] if st.session_state["messages"] else None
pending = (
    _last_message is not None
    and _last_message["role"] == "assistant"
    and _last_message.get("is_placeholder", False)
)

user_query = st.chat_input(CHAT_INPUT_PLACEHOLDER, disabled=pending)
if not pending and user_query:
    # Queue the user turn plus a placeholder reply for the API handler to fill in.
    st.session_state["messages"].extend(
        [
            {"role": "user", "content": user_query},
            {"role": "assistant", "content": "� Processing ...", "is_placeholder": True},
        ]
    )
    # Tag this rerun as a genuine query and drop any recent UI-action marker,
    # so the query-skipping guard does not hold it back.
    st.session_state["legitimate_query_time"] = pd.Timestamp.now().timestamp()
    st.session_state.pop("recent_ui_action", None)
    st.rerun()
# ===============================
# API Response Handling
# ===============================
def _fetch_query_response(query, selected_model, selected_agent):
    """Send one query to the backend and normalise the reply for the chat history.

    Args:
        query: The user's natural-language question.
        selected_model: Model choice forwarded to ``api_client.send_query``.
        selected_agent: Agent choice forwarded to ``api_client.send_query``.

    Returns:
        ``(answer_text, rows, chart, is_error, model_used, status)``.

    Any exception raised by ``api_client.send_query`` is turned into an error
    tuple. Errors while reading the reply are deliberately not caught here.
    They propagate, so the caller reports the real exception instead of a
    misleading "API is not available" message.
    """
    try:
        result = api_client.send_query(query, selected_model, selected_agent)
    except Exception as e:
        logging.error(f"API connectivity error: {e}")
        return (
            "❌ API is not available. Please check your connection or try again later.",
            [], None, True, "unknown", "error",
        )

    if not result:
        return (
            "❌ Error: Unable to process your request. Please try again.",
            [], None, True, "unknown", "error",
        )

    # `or` also covers keys that are present but null. A null message would
    # otherwise break the string concatenation below.
    answer_text = result.get("message") or "No response received"
    rows = result.get("rows") or []
    chart = result.get("chart")
    is_error = result.get("error", False)
    model_used = result.get("model_used") or "unknown"
    status = result.get("status", "unknown")

    # Credit the model on successful answers only.
    if not is_error and model_used != "unknown":
        answer_text += f"\n\n*Powered by: {model_used}*"
    return answer_text, rows, chart, is_error, model_used, status


# CRITICAL: Only process API calls for legitimate user queries, not UI interactions.
# A query is waiting when the newest message is a placeholder that directly
# follows a user turn.
has_placeholder = bool(
    st.session_state["messages"]
    and st.session_state["messages"][-1].get("is_placeholder")
    and len(st.session_state["messages"]) >= 2
    and st.session_state["messages"][-2]["role"] == "user"
)
# Reruns triggered by recent UI actions (sidebar interactions) are filtered out
# by should_skip_query_processing, which is defined elsewhere in this file.
should_process_query = has_placeholder and not should_skip_query_processing()

if has_placeholder:
    logging.info("=== QUERY PROCESSING CHECK ===")
    logging.info(f"Has placeholder: {has_placeholder}")
    logging.info(f"Should process query: {should_process_query}")
    if should_process_query:
        logging.info("ALLOWED: Query processing allowed")
    else:
        # NOTE(review): a blocked placeholder stays last in the history, which keeps
        # the chat input disabled until a later rerun lets the query through.
        logging.info("BLOCKED: Query processing blocked by should_skip_query_processing")

if should_process_query:
    user_query = st.session_state["messages"][-2]["content"]
    # Position of the user turn, stored on error replies so Retry can truncate back to it.
    user_query_index = len(st.session_state["messages"]) - 2
    # Set processing flag to pause status checks.
    st.session_state["processing_query"] = True
    try:
        with st.spinner(THINKING_MESSAGE):
            logging.info(f"Sending query to API: {user_query}")
            answer_text, rows, chart, is_error, model_used, status = _fetch_query_response(
                user_query, model, agent
            )
    except Exception as e:
        logging.exception(f"Exception: {e}")
        answer_text, rows, chart, is_error = f"⚠️ Exception: {str(e)}", [], None, True
        model_used, status = "unknown", "error"
    finally:
        # Clear the flag on every exit path, so an exception or an interrupted run
        # cannot leave it stuck at True.
        st.session_state["processing_query"] = False

    # Update query statistics.
    update_query_stats(success=not is_error)
    # This query has been handled; the legitimacy marker is no longer needed.
    st.session_state.pop("legitimate_query_time", None)
    # Replace placeholder with final response.
    st.session_state["messages"][-1] = {
        "role": "assistant",
        "content": answer_text,
        "data": rows,
        "chart": chart,
        "is_error": is_error,
        "model_used": model_used,
        "status": status,
        "user_query": user_query if is_error else None,  # Store user query for retry
        "user_query_index": user_query_index if is_error else None,  # Store user query index for retry
    }
    st.rerun()