"""
AI Database Assistant - Streamlit Chat Interface
"""
import streamlit as st
import pandas as pd
import base64
import logging
from typing import List, Dict, Any
from api_client import APIClient
from themes import ThemeManager
from config import (
BASE_URL, PAGE_TITLE, PAGE_LAYOUT, AVAILABLE_MODELS, AVAILABLE_AGENTS,
OUTLINE_INDIGO_USER, DARK_MODE_SLATE_AI, AVAILABLE_THEMES,
CHAT_INPUT_PLACEHOLDER, THINKING_MESSAGE, WORKING_MESSAGE, RETRY_BUTTON_TEXT, DOWNLOAD_BUTTON_TEXT
)
# ===============================
# Configuration
# ===============================
# Every setting (BASE_URL, page title/layout, model/agent/theme lists, UI text)
# comes from config.py; switching environments only requires editing BASE_URL.

# Page config goes first: Streamlit expects set_page_config before any other
# st.* command that renders output.
st.set_page_config(layout=PAGE_LAYOUT, page_title=PAGE_TITLE)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Backend client bound to the configured base URL.
api_client = APIClient(BASE_URL)

# ===============================
# Theme Management
# ===============================
theme_manager = ThemeManager()
# ===============================
# Database Status Functions
# ===============================
@st.cache_data(ttl=30)  # short TTL keeps the health indicator responsive
def check_api_status():
    """Return the backend's lightweight health summary.

    Delegates to ``api_client.check_health()``; the result is cached by
    Streamlit for 30 seconds. If the call raises, the error is logged and a
    fallback ``(label, detail, "error")`` tuple is returned instead.
    """
    try:
        summary = api_client.check_health()
    except Exception as e:
        logging.error(f"Health check failed: {e}")
        summary = ("đ´ Error", f"Failed: {str(e)}", "error")
    return summary
@st.cache_data(ttl=30)  # short TTL so the detailed report stays current
def get_detailed_health_status():
    """Return the backend's detailed health report.

    Delegates to ``api_client.get_detailed_health()``; the result is cached by
    Streamlit for 30 seconds. If the call raises, the error is logged and a
    dict with ``status="error"``, a ``message`` and empty ``checks`` is
    returned instead.
    """
    try:
        report = api_client.get_detailed_health()
    except Exception as e:
        logging.error(f"Detailed health check failed: {e}")
        report = {
            "status": "error",
            "message": f"Health check failed: {str(e)}",
            "checks": {},
        }
    return report
def should_skip_api_call(force_refresh=False):
    """Decide whether the sidebar health check should be skipped this run.

    For health checks only (chat queries use ``should_skip_query_processing``).
    Returns ``True`` to skip and ``False`` to allow the call. Rules, in order:

    1. ``force_refresh`` -> allow.
    2. ``force_immediate_health_check`` flag set -> allow once, clearing it.
    3. a chat query is being processed -> skip.
    4. the System Status section is hidden -> skip.
    5. a UI interaction was recorded less than 5 s ago -> skip.
    6. no cached status yet -> allow.
    7. last update less than 30 s ago -> skip; otherwise allow.

    Unexpected errors are logged and the call is allowed (fail open).
    """
    try:
        state = st.session_state

        # An explicit refresh always goes through.
        if force_refresh:
            logging.info("API call allowed: Force refresh requested")
            return False

        # One-shot pass right after System Status is switched on.
        if state.get("force_immediate_health_check", False):
            logging.info("API call allowed: System Status just enabled")
            state["force_immediate_health_check"] = False
            return False

        # Never compete with an in-flight chat query.
        if state.get("processing_query", False):
            logging.info("API call skipped: Query processing in progress")
            return True

        # Only check while the status panel is visible.
        settings = state.get("sidebar_settings", {})
        if not settings.get("show_system_status", False):
            logging.info("API call skipped: System Status section is disabled")
            return True

        # Back off for 5 seconds after any UI interaction.
        if "recent_ui_action" in state:
            since_action = pd.Timestamp.now().timestamp() - state.get("recent_ui_action", 0)
            if since_action < 5:
                logging.info(f"API call blocked: UI interaction {since_action:.1f}s ago")
                return True

        # Nothing cached yet: fetch regardless of the rate limit.
        if "last_api_status" not in state:
            logging.info("API call allowed: No cached status available")
            return False

        # Rate limit: at most one refresh per 30 seconds.
        since_update = pd.Timestamp.now().timestamp() - state.get("last_status_update", 0)
        if since_update < 30:
            logging.info(f"API call skipped: Rate limit (updated {since_update:.1f}s ago)")
            return True

        logging.info("API call allowed: Rate limit passed")
        return False
    except Exception as e:
        logging.error(f"Error in should_skip_api_call: {e}")
        return False  # fail open: allow the call
def should_skip_query_processing():
    """Decide whether a pending chat query should be held back this run.

    For chat queries only (health checks use ``should_skip_api_call``).
    Returns ``False`` (process) when the query was flagged legitimate less
    than 10 s ago. Returns ``True`` (skip) while another query is processing,
    or when a UI interaction was recorded less than 5 s ago and no
    legitimate-query flag exists. Otherwise, and on any error (which is
    logged), returns ``False``.
    """
    try:
        state = st.session_state

        # A freshly submitted (or retried) query always goes through.
        if "legitimate_query_time" in state:
            query_age = pd.Timestamp.now().timestamp() - state.get("legitimate_query_time", 0)
            if query_age < 10:
                logging.info(f"Query processing allowed: Legitimate query {query_age:.1f}s ago")
                return False

        # Guard against concurrent processing.
        if state.get("processing_query", False):
            logging.info("Query processing skipped: Already processing a query")
            return True

        # A rerun caused by a UI interaction is not a query submission.
        ui_rerun_without_query = (
            "recent_ui_action" in state and "legitimate_query_time" not in state
        )
        if ui_rerun_without_query:
            ui_age = pd.Timestamp.now().timestamp() - state.get("recent_ui_action", 0)
            if ui_age < 5:
                logging.info(f"Query processing blocked: UI interaction {ui_age:.1f}s ago without legitimate query")
                return True

        logging.info("Query processing allowed: No blocking conditions")
        return False
    except Exception as e:
        logging.error(f"Error in should_skip_query_processing: {e}")
        return False  # fail open: allow processing
def silent_status_update():
    """Invalidate cached health data so the next status render refetches it.

    Does nothing when ``should_skip_api_call()`` says to skip. Otherwise it
    clears Streamlit's data cache and stamps ``last_status_update`` with the
    current time.

    Failures never surface in the UI. They are logged, and only ``Exception``
    is caught: a bare ``except`` would also swallow ``BaseException``
    subclasses such as KeyboardInterrupt/SystemExit and Streamlit's
    script-control (rerun/stop) signals.
    """
    try:
        if should_skip_api_call():
            return
        st.cache_data.clear()
        st.session_state["last_status_update"] = pd.Timestamp.now().timestamp()
    except Exception as e:
        logging.warning(f"silent_status_update failed: {e}")
def get_query_stats():
    """Return ``(total, success_rate, successful, failed)`` for this session.

    Creates the ``query_stats`` session entry if it is missing: zeroed
    counters plus a ``session_start`` timestamp. ``success_rate`` is a
    percentage rounded to one decimal place, or ``0`` before any query has
    been counted.
    """
    state = st.session_state
    if "query_stats" not in state:
        state["query_stats"] = {
            "total_queries": 0,
            "successful_queries": 0,
            "failed_queries": 0,
            "session_start": pd.Timestamp.now(),
        }
    counters = state["query_stats"]
    total = counters["total_queries"]
    succeeded = counters["successful_queries"]
    rate = round(succeeded / total * 100, 1) if total > 0 else 0
    return total, rate, succeeded, counters["failed_queries"]
def update_query_stats(success=True):
    """Count one finished query as a success (default) or a failure.

    Creates the ``query_stats`` session entry if it is missing, then bumps
    ``total_queries`` and either ``successful_queries`` or ``failed_queries``.
    """
    state = st.session_state
    if "query_stats" not in state:
        state["query_stats"] = {
            "total_queries": 0,
            "successful_queries": 0,
            "failed_queries": 0,
            "session_start": pd.Timestamp.now(),
        }
    counters = state["query_stats"]
    counters["total_queries"] += 1
    outcome_key = "successful_queries" if success else "failed_queries"
    counters[outcome_key] += 1
# ===============================
# Session State Management
# ===============================
# Seed per-session defaults on the first run only; reruns keep existing values.
for _state_key, _make_default in (
    ("messages", list),
    ("query_stats", lambda: {
        "total_queries": 0,
        "successful_queries": 0,
        "failed_queries": 0,
        "session_start": pd.Timestamp.now(),
    }),
    ("processing_query", lambda: False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# ===============================
# Main App Setup
# ===============================
# Page header: the app title (PAGE_TITLE) followed by a one-line usage hint.
# NOTE(review): both calls pass unsafe_allow_html=True but the strings contain
# no HTML tags; the intended header markup may have been lost - confirm.
st.markdown(f"""
đŦ {PAGE_TITLE}
""", unsafe_allow_html=True)
st.markdown("""
Ask questions in plain English to generate and run SQL queries.
""", unsafe_allow_html=True)
# Sidebar
def _mark_ui_action():
    """Stamp ``recent_ui_action`` with the current time.

    ``should_skip_api_call`` backs off health checks for 5 s after this stamp.
    It is wired as the ``on_change`` callback of the sidebar selectboxes, so
    only a real change of selection counts as a UI interaction. A plain rerun
    that merely re-renders a selectbox must not suppress health checks, which
    is what happened when the stamp was written unconditionally on every run.
    """
    st.session_state["recent_ui_action"] = pd.Timestamp.now().timestamp()


with st.sidebar:
    # Quick status indicator at the top. It reflects the settings stored on the
    # previous run, because the checkboxes below have not been read yet.
    if "sidebar_settings" in st.session_state:
        visible_count = sum([
            st.session_state["sidebar_settings"]["show_model_selection"],
            st.session_state["sidebar_settings"]["show_agent_selection"],
            st.session_state["sidebar_settings"]["show_theme_selection"],
            st.session_state["sidebar_settings"]["show_system_status"],
            st.session_state["sidebar_settings"]["show_tips"]
        ])
        if visible_count == 5:
            st.caption("đĸ All sections visible")
        elif visible_count > 0:
            st.caption(f"đĄ {visible_count}/5 sections visible")
        else:
            st.caption("đ´ No sections visible")

    # ---- Sidebar display settings (user configurable) ----
    if "settings_interaction_count" not in st.session_state:
        st.session_state["settings_interaction_count"] = 0

    # Keep the settings expander open once the user has changed any setting.
    keep_expanded = st.session_state.get("settings_interaction_count", 0) > 0

    settings_expander = st.expander(
        "âī¸ Sidebar Settings",
        expanded=keep_expanded
    )
    with settings_expander:
        st.markdown("**Choose what to display:**")

        # First run: default visibility of each sidebar section.
        if "sidebar_settings" not in st.session_state:
            st.session_state["sidebar_settings"] = {
                "show_model_selection": True,
                "show_agent_selection": False,
                "show_theme_selection": True,
                "show_system_status": False,  # health checks only run while this is on
                "show_tips": True
            }

        # Snapshot of last run's settings, used to detect changes below.
        prev_settings = st.session_state["sidebar_settings"].copy()

        col1, col2 = st.columns(2)
        with col1:
            show_model = st.checkbox(
                "đ¤ AI Model",
                value=st.session_state["sidebar_settings"]["show_model_selection"],
                help="Show/hide AI model selection",
                key="settings_model"
            )
            show_agent = st.checkbox(
                "đ¯ Agent Type",
                value=st.session_state["sidebar_settings"]["show_agent_selection"],
                help="Show/hide agent selection",
                key="settings_agent"
            )
            show_theme = st.checkbox(
                "đ¨ Theme",
                value=st.session_state["sidebar_settings"]["show_theme_selection"],
                help="Show/hide theme selection",
                key="settings_theme"
            )
        with col2:
            show_status = st.checkbox(
                "đ System Status",
                value=st.session_state["sidebar_settings"]["show_system_status"],
                help="Show/hide system status",
                key="settings_status"
            )
            show_tips = st.checkbox(
                "đĄ Tips & Help",
                value=st.session_state["sidebar_settings"]["show_tips"],
                help="Show/hide tips section",
                key="settings_tips"
            )

        new_settings = {
            "show_model_selection": show_model,
            "show_agent_selection": show_agent,
            "show_theme_selection": show_theme,
            "show_system_status": show_status,
            "show_tips": show_tips
        }
        settings_changed = any(
            prev_settings.get(key) != new_settings[key]
            for key in new_settings.keys()
        )
        # Turning System Status on is handled specially: it should load at once.
        status_just_enabled = (
            not prev_settings.get("show_system_status", False) and
            new_settings["show_system_status"]
        )

        st.session_state["sidebar_settings"].update(new_settings)

        if settings_changed:
            st.session_state["settings_interaction_count"] += 1
            # A settings toggle counts as a UI interaction (health checks back
            # off for 5 s) - except enabling System Status, handled below.
            if not status_just_enabled:
                _mark_ui_action()
            # Cap the counter; any value > 0 keeps the expander open.
            if st.session_state["settings_interaction_count"] > 10:
                st.session_state["settings_interaction_count"] = 5

        # System Status just enabled: drop cached status and force one
        # immediate health check (consumed by should_skip_api_call).
        if status_just_enabled:
            st.cache_data.clear()
            for key in ["last_api_status", "last_api_delta", "last_api_type", "last_detailed_health"]:
                if key in st.session_state:
                    del st.session_state[key]
            st.session_state["last_status_update"] = 0
            st.session_state["force_immediate_health_check"] = True
            if "recent_ui_action" in st.session_state:
                del st.session_state["recent_ui_action"]

        visible_count = sum([show_model, show_agent, show_theme, show_status, show_tips])
        if visible_count == 5:
            st.success(f"✅ All {visible_count} sections visible")
        elif visible_count > 0:
            st.info(f"âšī¸ {visible_count}/5 sections visible")
        else:
            st.warning("â ī¸ No sections visible")

    # Drive the conditional sections below from the saved settings.
    show_model = st.session_state["sidebar_settings"]["show_model_selection"]
    show_agent = st.session_state["sidebar_settings"]["show_agent_selection"]
    show_theme = st.session_state["sidebar_settings"]["show_theme_selection"]
    show_status = st.session_state["sidebar_settings"]["show_system_status"]
    show_tips = st.session_state["sidebar_settings"]["show_tips"]
    st.markdown("---")

    # ---- Model selection ----
    if show_model:
        st.markdown("### đ¤ AI Model")
        # Static list - no API call needed.
        # NOTE(review): this list is independent of config.AVAILABLE_MODELS,
        # which supplies the model when this section is hidden; keep them in sync.
        available_models = [
            "gpt-4o-mini",
            "gpt-3.5-turbo",
            "gemini-1.5-pro",
            "gemini-1.5-flash",
            "claude-3-haiku",
            "claude-3-sonnet",
            "mistral-small",
            "mistral-medium",
            "mistral-large",
        ]
        model = st.selectbox(
            "Choose your AI model:",
            available_models,
            index=0,  # default to the first model
            help="Select the AI model for processing your queries",
            key="model_selector",
            on_change=_mark_ui_action,
        )
        model_descriptions = {
            "gpt-4o-mini": "⥠Fast & cost-effective OpenAI model",
            "gpt-3.5-turbo": "đĨ Reliable & quick OpenAI model",
            "gemini-pro": "đ Google's powerful Gemini model",
            "gemini-1.5-pro": "đŦ Google's latest Gemini model",
            "gemini-1.5-flash": "⥠Google's fast Gemini model",
            "claude-3-haiku": "đ¸ Anthropic's efficient Claude model",
            "claude-3-sonnet": "đĩ Anthropic's balanced Claude model",
            "mistral-small": "đ¯ Mistral's efficient model",
            "mistral-medium": "âī¸ Mistral's balanced model",
            "mistral-large": "đĻž Mistral's most capable model",
        }
        description = model_descriptions.get(model, "đ¤ Advanced AI model")
        st.info(description)
        provider_info = {
            "gpt-4o-mini": "đĸ OpenAI",
            "gpt-3.5-turbo": "đĸ OpenAI",
            "gemini-pro": "đ Google",
            "gemini-1.5-pro": "đ Google",
            "gemini-1.5-flash": "đ Google",
            "claude-3-haiku": "đ¤ Anthropic",
            "claude-3-sonnet": "đ¤ Anthropic",
            "mistral-small": "⥠Mistral AI",
            "mistral-medium": "⥠Mistral AI",
            "mistral-large": "⥠Mistral AI",
        }
        provider = provider_info.get(model, "đ¤ AI Provider")
        st.caption(f"Provider: {provider}")
        # Setup hint per provider. NOTE(review): the llama/codellama/phi3 branch
        # is unreachable with the static list above.
        if model.startswith("gemini"):
            st.caption("đĄ Requires GOOGLE_API_KEY in .env")
        elif model.startswith("claude"):
            st.caption("đĄ Requires ANTHROPIC_API_KEY in .env")
        elif model.startswith("mistral"):
            st.caption("đĄ Requires MISTRAL_API_KEY in .env")
        elif model in ["llama3.2", "llama3.1", "codellama", "phi3"]:
            st.caption("đĄ Requires Ollama installed locally")
        else:
            st.caption("đĄ Requires OPENAI_API_KEY in .env")
        st.markdown("---")
    else:
        model = AVAILABLE_MODELS[0]  # section hidden: use the configured default

    # ---- Agent selection ----
    if show_agent:
        st.markdown("### đ¯ Agent Type")
        agent = st.selectbox(
            "Choose your agent:",
            AVAILABLE_AGENTS,
            help="Select the specialized agent for your database tasks",
            key="agent_selector",
            on_change=_mark_ui_action,
        )
        agent_info = {
            "default": "đ§ General-purpose database assistant",
            "sql-agent": "đž Specialized in SQL optimization",
            "custom-agent": "đ¨ Customized for specific workflows"
        }
        st.info(agent_info.get(agent, "Specialized database agent"))
        st.markdown("---")
    else:
        agent = AVAILABLE_AGENTS[0]  # section hidden: use the configured default

    # ---- Theme selection ----
    if show_theme:
        st.markdown("### đ¨ Appearance")
        theme = st.selectbox(
            "Choose your theme:",
            AVAILABLE_THEMES,
            help="Customize the visual appearance",
            key="theme_selector",
            on_change=_mark_ui_action,
        )
        st.markdown("---")
    else:
        theme = AVAILABLE_THEMES[0]  # section hidden: use the configured default

    # ---- System status ----
    if show_status:
        st.markdown("### đ System Status")
        if st.session_state.get("processing_query", False):
            # Don't slow a running query down with health checks; show stats only.
            st.info("âŗ Status check paused during query processing")
            total_queries, success_rate, successful, failed = get_query_stats()
            col1, col2 = st.columns(2)
            with col1:
                st.metric("API Status", "âŗ Processing", delta="Query in progress")
            with col2:
                st.metric("Total Queries", total_queries, delta=f"+{total_queries}")
        else:
            is_force_refresh = st.session_state.get("force_refresh_requested", False)
            # should_skip_api_call applies the rate limit / UI back-off rules.
            if not should_skip_api_call(force_refresh=is_force_refresh):
                logging.info("=== HEALTH CHECK API CALL STARTING ===")
                try:
                    status_text, status_delta, status_type = check_api_status()
                    detailed_health = get_detailed_health_status()
                    # Cache the results for reruns that skip the API.
                    st.session_state["last_api_status"] = status_text
                    st.session_state["last_api_delta"] = status_delta
                    st.session_state["last_api_type"] = status_type
                    st.session_state["last_detailed_health"] = detailed_health
                    st.session_state["last_status_update"] = pd.Timestamp.now().timestamp()
                    logging.info("=== HEALTH CHECK API CALL COMPLETED ===")
                    if is_force_refresh:
                        st.session_state["force_refresh_requested"] = False
                except Exception as e:
                    logging.error(f"Status check failed: {e}")
                    status_text = "đ´ Failed"
                    status_delta = f"Error: {str(e)[:50]}..."
                    status_type = "error"
                    detailed_health = {
                        "status": "error",
                        "message": f"Status check failed: {str(e)}",
                        "checks": {}
                    }
                    if is_force_refresh:
                        st.session_state["force_refresh_requested"] = False
            else:
                # Reuse the last fetched values (or loading placeholders).
                status_text = st.session_state.get("last_api_status", "đĄ Loading...")
                status_delta = st.session_state.get("last_api_delta", "Initializing...")
                status_type = st.session_state.get("last_api_type", "normal")
                detailed_health = st.session_state.get("last_detailed_health", {
                    "status": "unknown",
                    "message": "Loading system status...",
                    "checks": {}
                })

            total_queries, success_rate, successful, failed = get_query_stats()
            col1, col2 = st.columns(2)
            with col1:
                st.metric("API Status", status_text, delta=status_delta)
            with col2:
                st.metric("Total Queries", total_queries, delta=f"+{total_queries}")

            if status_type == "success":
                st.success("✅ All systems operational")
            elif status_type == "warning":
                st.warning("â ī¸ Limited functionality - some features may be slow")
            else:
                st.error("â API connection failed - please check your backend service")

            with st.expander("đ Detailed System Health", expanded=False):
                if detailed_health.get("status") in ["healthy", "unhealthy"]:
                    if "timestamp" in detailed_health:
                        st.caption(f"Last checked: {detailed_health['timestamp']}")
                    checks = detailed_health.get("checks", {})
                    if "database" in checks:
                        db_check = checks["database"]
                        db_status = db_check.get("status", "unknown")
                        db_message = db_check.get("message", "No information")
                        if db_status == "healthy":
                            st.success(f"đī¸ Database: {db_message}")
                        elif db_status == "unhealthy":
                            st.error(f"đī¸ Database: {db_message}")
                        else:
                            st.warning(f"đī¸ Database: {db_message}")
                    if "openai_api" in checks:
                        api_check = checks["openai_api"]
                        api_status = api_check.get("status", "unknown")
                        api_message = api_check.get("message", "No information")
                        if api_status == "configured":
                            st.success(f"đ¤ OpenAI API: {api_message}")
                        elif api_status == "error":
                            st.error(f"đ¤ OpenAI API: {api_message}")
                        else:
                            st.warning(f"đ¤ OpenAI API: {api_message}")
                    if "version" in detailed_health:
                        st.info(f"đĻ Version: {detailed_health['version']}")
                else:
                    st.error(f"â Health check failed: {detailed_health.get('message', 'Unknown error')}")
                    st.caption("Unable to retrieve detailed system status")

        # Additional stats, shown regardless of processing state.
        if total_queries > 0:
            col3, col4 = st.columns(2)
            with col3:
                st.metric("Success Rate", f"{success_rate}%", delta=f"{successful} successful")
            with col4:
                session_duration = pd.Timestamp.now() - st.session_state["query_stats"]["session_start"]
                hours = int(session_duration.total_seconds() // 3600)
                minutes = int((session_duration.total_seconds() % 3600) // 60)
                st.metric("Session Time", f"{hours}h {minutes}m", delta="Active")

        col_refresh1, col_refresh2 = st.columns([2, 1])
        with col_refresh1:
            # NOTE(review): this shows the render time, not when the status was
            # actually fetched (st.session_state["last_status_update"]).
            st.caption(f"đ Last updated: {pd.Timestamp.now().strftime('%H:%M:%S')}")
        with col_refresh2:
            if st.button("đ", help="Force Refresh Status", key="refresh_status"):
                # Bypass validation, drop every cached value and refetch.
                st.session_state["force_refresh_requested"] = True
                st.cache_data.clear()
                st.session_state["last_status_update"] = 0
                if "recent_ui_action" in st.session_state:
                    del st.session_state["recent_ui_action"]
                for key in ["last_api_status", "last_api_delta", "last_api_type", "last_detailed_health"]:
                    if key in st.session_state:
                        del st.session_state[key]
                st.rerun()
        st.markdown("---")

    # ---- Tips and help ----
    if show_tips:
        st.markdown("### đĄ Quick Tips")
        with st.expander("đ How to ask questions"):
            st.markdown("""
- **"Show me all customers from Chicago"**
- **"What are the top 5 branches by transactions?"**
- **"Calculate total transactions by month"**
- **"Find customers who haven't done any transactions recently"**
""")
        with st.expander("⥠Pro Tips"):
            st.markdown("""
- Be specific about what data you want
- Mention date ranges when relevant
- Ask for summaries or aggregations
- Use natural language - no SQL needed!
""")
        with st.expander("đ§ Troubleshooting"):
            st.markdown("""
- **No results?** Try rephrasing your question
- **Error message?** Use the retry button
- **Slow response?** Check your connection
- **Wrong data?** Be more specific in your query
""")
        st.markdown("---")

    # ---- Quick actions (disabled while a query is running) ----
    st.markdown("### đ Quick Actions")
    is_processing = st.session_state.get("processing_query", False)
    if is_processing:
        st.info("⥠Query in progress... All action buttons are temporarily disabled.")
    col_action1, col_action2 = st.columns(2)
    with col_action1:
        clear_chat_disabled = is_processing
        clear_chat_help = "Cannot clear chat while query is processing" if is_processing else "Clear all chat messages"
        if st.button("đī¸ Clear Chat",
                     use_container_width=True,
                     disabled=clear_chat_disabled,
                     help=clear_chat_help):
            _mark_ui_action()  # blocks health-check API calls for 5 seconds
            st.session_state["messages"] = []
            logging.info("Clear Chat button clicked - blocking API calls for 5 seconds")
            st.rerun()
    with col_action2:
        reset_stats_disabled = is_processing
        reset_stats_help = "Cannot reset stats while query is processing" if is_processing else "Reset query statistics"
        if st.button("đ Reset Stats",
                     use_container_width=True,
                     disabled=reset_stats_disabled,
                     help=reset_stats_help):
            _mark_ui_action()  # blocks health-check API calls for 5 seconds
            st.session_state["query_stats"] = {
                "total_queries": 0,
                "successful_queries": 0,
                "failed_queries": 0,
                "session_start": pd.Timestamp.now()
            }
            logging.info("Reset Stats button clicked - blocking API calls for 5 seconds")
            st.rerun()
    sample_disabled = is_processing
    sample_help = "Cannot show samples while query is processing" if is_processing else "Show sample queries you can copy"
    if st.button("đ Sample Queries",
                 use_container_width=True,
                 disabled=sample_disabled,
                 help=sample_help):
        _mark_ui_action()  # blocks health-check API calls for 5 seconds
        logging.info("Sample Queries button clicked - blocking API calls for 5 seconds")
        sample_queries = [
            "Show me the top 10 customers by transactions",
            "Which branches collected more transactions last month?",
            "Calculate average transactions value",
            "List all active customers"
        ]
        # Shown in the sidebar rather than added to the chat history.
        st.markdown("**Sample queries you can copy and paste:**")
        for query in sample_queries:
            st.code(query)

    # ---- Footer ----
    st.markdown("### âšī¸ About")
    st.markdown("""
**AI Database Assistant** v2.0
đ Powered by advanced AI
đŦ Natural language to SQL
đ Real-time analytics
""")
# Apply theme
# `theme` is the sidebar selection, or AVAILABLE_THEMES[0] when the theme picker
# is hidden. ThemeManager.inject_theme presumably injects that theme's styling
# into the page - confirm in themes.py.
theme_manager.inject_theme(theme)
# ===============================
# Chat Rendering Function
# ===============================
def render_chat():
    """Render the chat history stored in ``st.session_state["messages"]``.

    Messages are processed in order:

    * ``role == "user"``: written with ``st.markdown``.
    * assistant placeholder (``is_placeholder``): skipped; the spinner shown
      while the query runs stands in for it.
    * assistant error (``is_error``): written, followed by a retry button.
      Retrying truncates the history just after the originating user message
      (``user_query_index``), appends a fresh placeholder, flags the query as
      legitimate and reruns the script.
    * any other non-placeholder assistant message: written.

    Every non-placeholder assistant message then shows an optional results
    table with a CSV download (``data``) and an optional base64-encoded chart
    image (``chart``).
    """
    # NOTE(review): the st.markdown(..., unsafe_allow_html=True) calls here carry
    # no markup and ``bubble_class`` is never interpolated, so message text is
    # not actually displayed - the chat-bubble HTML appears to have been lost.
    # Confirm against version history and restore it.
    st.markdown("", unsafe_allow_html=True)
    messages = st.session_state["messages"]
    for i, msg in enumerate(messages):
        if msg["role"] == "user":
            st.markdown(
                f'''
''',
                unsafe_allow_html=True)
        elif msg["role"] == "assistant":
            # Placeholders are represented by the spinner, not a bubble.
            if msg.get("is_placeholder"):
                continue
            bubble_class = "error-bubble" if msg.get("is_error") else "ai-bubble"
            if msg.get("is_error"):
                st.markdown(
                    f'''
''',
                    unsafe_allow_html=True)
                # Retry button sits under the error, offset by a narrow empty
                # column where an avatar would be.
                cols = st.columns([0.03, 0.85])
                with cols[0]:
                    st.empty()
                with cols[1]:
                    if st.button(RETRY_BUTTON_TEXT, key=f"retry_error_{i}"):
                        # Retry uses the query/index stored with the error message.
                        stored_user_query = msg.get("user_query")
                        stored_user_index = msg.get("user_query_index")
                        if stored_user_query and stored_user_index is not None:
                            st.session_state["messages"] = st.session_state["messages"][:stored_user_index+1]
                            st.session_state["messages"].append({"role": "assistant", "content": "īŋŊ Processing your query...", "is_placeholder": True})
                            # A retry is a legitimate query, not a UI interaction.
                            st.session_state["legitimate_query_time"] = pd.Timestamp.now().timestamp()
                            if "recent_ui_action" in st.session_state:
                                del st.session_state["recent_ui_action"]
                            st.rerun()
            else:
                st.markdown(
                    f'''
''',
                    unsafe_allow_html=True)
            if msg.get("data"):
                df = pd.DataFrame(msg["data"])
                st.dataframe(df, use_container_width=True)
                csv = df.to_csv(index=False).encode("utf-8")
                # Keyed by position (like the retry button): deterministic and
                # independent of the dict's object identity.
                st.download_button(DOWNLOAD_BUTTON_TEXT, csv, "results.csv", "text/csv", key=f"download_csv_{i}")
            if msg.get("chart"):
                img_data = base64.b64decode(msg["chart"])
                st.image(img_data, use_column_width=True)
    # Previously a string literal broken across two physical lines (a
    # SyntaxError); its visible value was a single newline.
    st.markdown("\n", unsafe_allow_html=True)
# Render chat
render_chat()

# ===============================
# User Input Handling
# ===============================
# The chat box is locked while the newest message is an assistant placeholder,
# i.e. while an earlier question is still waiting for its answer.
pending = (
    st.session_state["messages"][-1].get("is_placeholder", False)
    if st.session_state["messages"] and st.session_state["messages"][-1]["role"] == "assistant"
    else False
)

user_query = st.chat_input(CHAT_INPUT_PLACEHOLDER, disabled=pending)
if user_query and not pending:
    # Queue the question with a placeholder answer; the response handler picks
    # it up on the rerun.
    st.session_state["messages"].extend([
        {"role": "user", "content": user_query},
        {"role": "assistant", "content": "īŋŊ Processing ...", "is_placeholder": True},
    ])
    # Flag a genuine submission so query processing is not mistaken for a
    # UI-triggered rerun, and lift any UI back-off.
    st.session_state["legitimate_query_time"] = pd.Timestamp.now().timestamp()
    if "recent_ui_action" in st.session_state:
        del st.session_state["recent_ui_action"]
    st.rerun()
# ===============================
# API Response Handling
# ===============================
def _fetch_answer(query, model_name, agent_name):
    """Send one chat query to the backend and normalise the reply.

    Returns ``(answer_text, rows, chart, is_error, model_used, status)``.
    A falsy reply becomes a generic error tuple. Any ``Exception`` raised by
    the client call or while reading the reply is logged and reported as the
    API being unavailable. Non-error answers get a "Powered by" footer when
    the reply names the model used.
    """
    try:
        result = api_client.send_query(query, model_name, agent_name)
        if not result:
            return (
                "â Error: Unable to process your request. Please try again.",
                [], None, True, "unknown", "error",
            )
        answer_text = result.get("message", "No response received")
        rows = result.get("rows", [])
        chart = result.get("chart", None)
        is_error = result.get("error", False)
        model_used = result.get("model_used", "unknown")
        status = result.get("status", "unknown")
        if not is_error and model_used != "unknown":
            answer_text += f"\n\n*Powered by: {model_used}*"
        return answer_text, rows, chart, is_error, model_used, status
    except Exception as e:
        logging.error(f"API connectivity error: {e}")
        return (
            "â API is not available. Please check your connection or try again later.",
            [], None, True, "unknown", "error",
        )


# CRITICAL: only call the backend for genuine user queries, never for reruns
# triggered by sidebar/UI interactions. A query is pending when the newest
# message is an assistant placeholder directly preceded by the user's question.
has_placeholder = bool(
    st.session_state["messages"]
    and st.session_state["messages"][-1].get("is_placeholder")
    and len(st.session_state["messages"]) >= 2
    and st.session_state["messages"][-2]["role"] == "user"
)
# should_skip_query_processing() is consulted only when a query is pending.
should_process_query = has_placeholder and not should_skip_query_processing()

if has_placeholder:
    logging.info("=== QUERY PROCESSING CHECK ===")
    logging.info(f"Has placeholder: {has_placeholder}")
    logging.info(f"Should process query: {should_process_query}")
    if not should_process_query:
        logging.info("BLOCKED: Query processing blocked by should_skip_query_processing")
    else:
        logging.info("ALLOWED: Query processing allowed")

if should_process_query:
    user_query = st.session_state["messages"][-2]["content"]
    user_query_index = len(st.session_state["messages"]) - 2  # kept for retry
    # Pauses sidebar health checks and disables the quick-action buttons.
    st.session_state["processing_query"] = True
    try:
        with st.spinner(THINKING_MESSAGE):
            logging.info(f"Sending query to API: {user_query}")
            answer_text, rows, chart, is_error, model_used, status = _fetch_answer(user_query, model, agent)
    except Exception as e:
        logging.error(f"Exception: {e}")
        answer_text, rows, chart, is_error = f"â ī¸ Exception: {str(e)}", [], None, True
        model_used, status = "unknown", "error"
    finally:
        # Always release the flag. Streamlit's rerun/stop signals are not
        # Exception subclasses, so without `finally` an interrupted run would
        # leave processing_query=True for the rest of the session (buttons
        # disabled, query possibly never reprocessed). The placeholder is left
        # untouched in that case, so the query may be picked up again.
        st.session_state["processing_query"] = False

    update_query_stats(success=not is_error)

    if "legitimate_query_time" in st.session_state:
        del st.session_state["legitimate_query_time"]

    # Swap the placeholder for the real answer. On errors, keep the question
    # and its index so the retry button can resend it.
    st.session_state["messages"][-1] = {
        "role": "assistant",
        "content": answer_text,
        "data": rows,
        "chart": chart,
        "is_error": is_error,
        "model_used": model_used,
        "status": status,
        "user_query": user_query if is_error else None,
        "user_query_index": user_query_index if is_error else None,
    }
    st.rerun()