# VayuChat - app.py
import streamlit as st
import os
import json
import pandas as pd
import random
from os.path import join
from datetime import datetime
from src import (
preprocess_and_load_df,
load_agent,
ask_agent,
decorate_with_code,
show_response,
get_from_user,
load_smart_df,
ask_question,
)
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from streamlit_feedback import streamlit_feedback
from huggingface_hub import HfApi
from datasets import load_dataset, get_dataset_config_info, Dataset
from PIL import Image
import time
import uuid
# Page config with beautiful theme
st.set_page_config(
page_title="VayuChat - AI Air Quality Assistant",
page_icon="🌬️",
layout="wide",
initial_sidebar_state="expanded"
)
# Custom CSS for beautiful styling
st.markdown("""
<style>
/* Clean app background */
.stApp {
background-color: #ffffff;
color: #212529;
font-family: 'Segoe UI', sans-serif;
}
/* Sidebar */
[data-testid="stSidebar"] {
background-color: #f8f9fa;
border-right: 1px solid #dee2e6;
padding: 1rem;
}
/* Main title */
.main-title {
text-align: center;
color: #343a40;
font-size: 2.5rem;
font-weight: 700;
margin-bottom: 0.5rem;
}
/* Subtitle */
.subtitle {
text-align: center;
color: #6c757d;
font-size: 1.1rem;
margin-bottom: 1.5rem;
}
/* Instructions */
.instructions {
background-color: #f1f3f5;
border-left: 4px solid #0d6efd;
padding: 1rem;
margin-bottom: 1.5rem;
border-radius: 6px;
color: #495057;
text-align: left;
}
/* Quick prompt buttons */
.quick-prompt-container {
display: flex;
flex-wrap: wrap;
gap: 8px;
margin-bottom: 1.5rem;
padding: 1rem;
background-color: #f8f9fa;
border-radius: 10px;
border: 1px solid #dee2e6;
}
.quick-prompt-btn {
background-color: #0d6efd;
color: white;
border: none;
padding: 8px 16px;
border-radius: 20px;
font-size: 0.9rem;
cursor: pointer;
transition: all 0.2s ease;
white-space: nowrap;
}
.quick-prompt-btn:hover {
background-color: #0b5ed7;
transform: translateY(-2px);
}
/* User message styling */
.user-message {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 15px 20px;
border-radius: 20px 20px 5px 20px;
margin: 10px 0;
margin-left: auto;
margin-right: 0;
max-width: 80%;
position: relative;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.user-info {
font-size: 0.8rem;
opacity: 0.8;
margin-bottom: 5px;
text-align: right;
}
/* Assistant message styling */
.assistant-message {
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
color: white;
padding: 15px 20px;
border-radius: 20px 20px 20px 5px;
margin: 10px 0;
margin-left: 0;
margin-right: auto;
max-width: 80%;
position: relative;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.assistant-info {
font-size: 0.8rem;
opacity: 0.8;
margin-bottom: 5px;
}
/* Processing indicator */
.processing-indicator {
background: linear-gradient(135deg, #a8edea 0%, #fed6e3 100%);
color: #333;
padding: 15px 20px;
border-radius: 20px 20px 20px 5px;
margin: 10px 0;
margin-left: 0;
margin-right: auto;
max-width: 80%;
position: relative;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
animation: pulse 2s infinite;
}
@keyframes pulse {
0% { opacity: 1; }
50% { opacity: 0.7; }
100% { opacity: 1; }
}
/* Feedback box */
.feedback-section {
background-color: #f8f9fa;
border: 1px solid #dee2e6;
padding: 1rem;
border-radius: 8px;
margin: 1rem 0;
}
/* Success and error messages */
.success-message {
background-color: #d1e7dd;
color: #0f5132;
padding: 1rem;
border-radius: 6px;
border: 1px solid #badbcc;
}
.error-message {
background-color: #f8d7da;
color: #842029;
padding: 1rem;
border-radius: 6px;
border: 1px solid #f5c2c7;
}
/* Chat input */
.stChatInput {
border-radius: 6px;
border: 1px solid #ced4da;
background: #ffffff;
}
/* Button */
.stButton > button {
background-color: #0d6efd;
color: white;
border-radius: 6px;
padding: 0.5rem 1.25rem;
border: none;
font-weight: 600;
transition: background-color 0.2s ease;
}
.stButton > button:hover {
background-color: #0b5ed7;
}
/* Code details styling */
.code-details {
background-color: #f8f9fa;
border: 1px solid #dee2e6;
border-radius: 8px;
padding: 10px;
margin-top: 10px;
}
/* Hide default menu and footer */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
/* Auto scroll */
.main-container {
height: 70vh;
overflow-y: auto;
}
</style>
""", unsafe_allow_html=True)
# Auto-scroll helper. <script> tags injected via st.markdown are inserted with
# innerHTML and do not execute, so the script is rendered through
# components.html (an invisible, same-origin iframe) and reaches the app page
# via window.parent.
import streamlit.components.v1 as components

def scroll_to_bottom():
    components.html(
        """
        <script>
        setTimeout(function() {
            const doc = window.parent.document;
            const container = doc.querySelector('.main-container');
            if (container) {
                container.scrollTop = container.scrollHeight;
            }
            window.parent.scrollTo(0, doc.body.scrollHeight);
        }, 100);
        </script>
        """,
        height=0,
    )
# FORCE reload environment variables
load_dotenv(override=True)
# Get API keys
Groq_Token = os.getenv("GROQ_API_KEY")
hf_token = os.getenv("HF_TOKEN")
gemini_token = os.getenv("GEMINI_TOKEN")
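# Expected .env entries (names taken from the os.getenv calls above):
#   GROQ_API_KEY  - unlocks the Groq-hosted models below
#   GEMINI_TOKEN  - unlocks the Gemini models
#   HF_TOKEN      - enables feedback/interaction logging to Hugging Face
# In `models`, the sidebar display name maps to the provider-side model ID.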
models = {
"gpt-oss-20b": "openai/gpt-oss-20b",
"gpt-oss-120b": "openai/gpt-oss-120b",
"llama3.1": "llama-3.1-8b-instant",
"llama3.3": "llama-3.3-70b-versatile",
"deepseek-R1": "deepseek-r1-distill-llama-70b",
"llama4 maverik":"meta-llama/llama-4-maverick-17b-128e-instruct",
"llama4 scout":"meta-llama/llama-4-scout-17b-16e-instruct",
"gemini-pro": "gemini-1.5-pro"
}
self_path = os.path.dirname(os.path.abspath(__file__))
# Initialize session ID for this session
if "session_id" not in st.session_state:
st.session_state.session_id = str(uuid.uuid4())
def upload_feedback(feedback, error, output, last_prompt, code, status):
"""Enhanced feedback upload function with better logging and error handling"""
try:
if not hf_token or hf_token.strip() == "":
st.warning("⚠️ Cannot upload feedback - HF_TOKEN not available")
return False
# Create comprehensive feedback data
feedback_data = {
"timestamp": datetime.now().isoformat(),
"session_id": st.session_state.session_id,
"feedback_score": feedback.get("score", ""),
"feedback_comment": feedback.get("text", ""),
"user_prompt": last_prompt,
"ai_output": str(output),
"generated_code": code or "",
"error_message": error or "",
"is_image_output": status.get("is_image", False),
"success": not bool(error)
}
# Create unique folder name with timestamp
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
random_id = str(uuid.uuid4())[:8]
folder_name = f"feedback_{timestamp_str}_{random_id}"
# Create markdown feedback file
markdown_content = f"""# VayuChat Feedback Report
## Session Information
- **Timestamp**: {feedback_data['timestamp']}
- **Session ID**: {feedback_data['session_id']}
## User Interaction
**Prompt**: {feedback_data['user_prompt']}
## AI Response
**Output**: {feedback_data['ai_output']}
## Generated Code
```python
{feedback_data['generated_code']}
```
## Technical Details
- **Error Message**: {feedback_data['error_message']}
- **Is Image Output**: {feedback_data['is_image_output']}
- **Success**: {feedback_data['success']}
## User Feedback
- **Score**: {feedback_data['feedback_score']}
- **Comments**: {feedback_data['feedback_comment']}
"""
# Save markdown file locally
markdown_filename = f"{folder_name}.md"
markdown_local_path = f"/tmp/{markdown_filename}"
with open(markdown_local_path, "w", encoding="utf-8") as f:
f.write(markdown_content)
# Upload to Hugging Face
api = HfApi(token=hf_token)
# Upload markdown feedback
api.upload_file(
path_or_fileobj=markdown_local_path,
path_in_repo=f"data/{markdown_filename}",
repo_id="SustainabilityLabIITGN/VayuChat_Feedback",
repo_type="dataset",
)
# Upload image if it exists and is an image output
if status.get("is_image", False) and isinstance(output, str) and os.path.exists(output):
try:
image_filename = f"{folder_name}_plot.png"
api.upload_file(
path_or_fileobj=output,
path_in_repo=f"data/{image_filename}",
repo_id="SustainabilityLabIITGN/VayuChat_Feedback",
repo_type="dataset",
)
except Exception as img_error:
print(f"Error uploading image: {img_error}")
# Clean up local files
if os.path.exists(markdown_local_path):
os.remove(markdown_local_path)
st.success("πŸŽ‰ Feedback uploaded successfully!")
return True
except Exception as e:
st.error(f"❌ Error uploading feedback: {e}")
print(f"Feedback upload error: {e}")
return False
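# Example invocation (hypothetical values), mirroring the call made from the
# feedback UI further below:
#   upload_feedback(
#       feedback={"score": "👍 Helpful", "text": "Clear answer"},
#       error="", output="/tmp/plot.png", last_prompt="Show PM2.5 trend",
#       code="df.plot(...)", status={"is_image": True},
#   )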
# Beautiful header
st.markdown("<h1 class='main-title'>🌬️ VayuChat</h1>", unsafe_allow_html=True)
st.markdown("""
<div class='subtitle'>
<strong>AI-Powered Air Quality Insights</strong><br>
Simplifying pollution analysis using conversational AI.
</div>
""", unsafe_allow_html=True)
st.markdown("""
<div class='instructions'>
<strong>How to Use:</strong><br>
Select a model from the sidebar and ask questions directly in the chat. Use quick prompts below for common queries.
</div>
""", unsafe_allow_html=True)
os.environ["PANDASAI_API_KEY"] = "$2a$10$gbmqKotzJOnqa7iYOun8eO50TxMD/6Zw1pLI2JEoqncwsNx4XeBS2"
# Load data with error handling
try:
df = preprocess_and_load_df(join(self_path, "Data.csv"))
st.success("βœ… Data loaded successfully!")
except Exception as e:
st.error(f"❌ Error loading data: {e}")
st.stop()
inference_server = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
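# Note: inference_server is defined here but never referenced below.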
image_path = "IITGN_Logo.png"
# Beautiful sidebar
with st.sidebar:
# Logo and title
col1, col2, col3 = st.columns([1, 2, 1])
with col2:
if os.path.exists(image_path):
st.image(image_path, use_column_width=True)
# Session info
st.markdown(f"**Session ID**: `{st.session_state.session_id[:8]}...`")
# Model selection
st.markdown("### πŸ€– AI Model Selection")
# Filter available models
available_models = []
model_names = list(models.keys())
groq_models = []
gemini_models = []
for model_name in model_names:
if "gemini" not in model_name:
groq_models.append(model_name)
else:
gemini_models.append(model_name)
if Groq_Token and Groq_Token.strip():
available_models.extend(groq_models)
if gemini_token and gemini_token.strip():
available_models.extend(gemini_models)
if not available_models:
st.error("❌ No API keys available! Please set up your API keys in the .env file")
st.stop()
model_name = st.selectbox(
"Choose your AI assistant:",
available_models,
help="Different models have different strengths. Try them all!"
)
# Model descriptions
    model_descriptions = {
        "llama3.1": "🦙 Fast and efficient for general queries",
        "llama3.3": "🦙 Most advanced LLaMA model for complex reasoning",
        "gemini-pro": "🧠 Google's most powerful model",
        "gpt-oss-20b": "📘 OpenAI's compact open-weight GPT for everyday tasks",
        "gpt-oss-120b": "📚 OpenAI's massive open-weight GPT for nuanced responses",
        "deepseek-R1": "🔍 DeepSeek's distilled LLaMA model for efficient reasoning",
        "llama4 maverick": "🚀 Meta's LLaMA 4 Maverick: high-performance instruction model",
        "llama4 scout": "🛰️ Meta's LLaMA 4 Scout: optimized for adaptive reasoning",
    }
if model_name in model_descriptions:
st.info(model_descriptions[model_name])
st.markdown("---")
# Logging status
st.markdown("### πŸ“Š Logging Status")
if hf_token and hf_token.strip():
st.success("βœ… Logging enabled")
st.caption("Interactions are being logged to HuggingFace")
else:
st.warning("⚠️ Logging disabled")
st.caption("HF_TOKEN not available")
st.markdown("---")
# Clear Chat Button
if st.button("🧹 Clear Chat"):
st.session_state.responses = []
st.session_state.processing = False
# Generate new session ID for new chat
st.session_state.session_id = str(uuid.uuid4())
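        # Newer Streamlit exposes st.rerun(); older releases only have
        # st.experimental_rerun(), hence the fallback below.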
try:
st.rerun()
except AttributeError:
st.experimental_rerun()
st.markdown("---")
# Chat History in Sidebar
with st.expander("πŸ“œ Chat History"):
for i, response in enumerate(st.session_state.get("responses", [])):
if response.get("role") == "user":
st.markdown(f"**You:** {response.get('content', '')[:50]}...")
elif response.get("role") == "assistant":
content = response.get('content', '')
if isinstance(content, str) and len(content) > 50:
st.markdown(f"**VayuChat:** {content[:50]}...")
else:
st.markdown(f"**VayuChat:** {str(content)[:50]}...")
st.markdown("---")
# Load quick prompts
questions = []
questions_file = join(self_path, "questions.txt")
if os.path.exists(questions_file):
try:
with open(questions_file, 'r', encoding='utf-8') as f:
content = f.read()
questions = [q.strip() for q in content.split("\n") if q.strip()]
print(f"Loaded {len(questions)} quick prompts") # Debug
except Exception as e:
st.error(f"Error loading questions: {e}")
questions = []
# Add some default prompts if file doesn't exist or is empty
if not questions:
questions = [
"What is the average PM2.5 level in the dataset?",
"Show me the air quality trend over time",
"Which pollutant has the highest concentration?",
"Create a correlation plot between different pollutants",
"What are the peak pollution hours?",
"Compare weekday vs weekend pollution levels"
]
# Quick prompts section (horizontal)
st.markdown("### πŸ’­ Quick Prompts")
# Create columns for horizontal layout
cols_per_row = 2  # two prompts per row keeps the button labels readable
rows = [questions[i:i + cols_per_row] for i in range(0, len(questions), cols_per_row)]
selected_prompt = None
for row_idx, row in enumerate(rows):
cols = st.columns(len(row))
for col_idx, question in enumerate(row):
with cols[col_idx]:
# Create unique key using row and column indices
unique_key = f"prompt_btn_{row_idx}_{col_idx}"
button_text = f"πŸ“ {question[:35]}{'...' if len(question) > 35 else ''}"
if st.button(button_text,
key=unique_key,
help=question,
use_container_width=True):
selected_prompt = question
st.markdown("---")
# Initialize chat history and processing state
if "responses" not in st.session_state:
st.session_state.responses = []
if "processing" not in st.session_state:
st.session_state.processing = False
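# "processing" drives a two-pass flow: the first st.rerun() paints the
# processing indicator right away, and the block near the bottom of the script
# then queries the model and appends the answer on the second pass.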
def show_custom_response(response):
"""Custom response display function"""
role = response.get("role", "assistant")
content = response.get("content", "")
if role == "user":
st.markdown(f"""
<div class='user-message'>
<div class='user-info'>You</div>
{content}
</div>
""", unsafe_allow_html=True)
elif role == "assistant":
st.markdown(f"""
<div class='assistant-message'>
<div class='assistant-info'>πŸ€– VayuChat</div>
{content if isinstance(content, str) else str(content)}
</div>
""", unsafe_allow_html=True)
# Show generated code if available
if response.get("gen_code"):
with st.expander("πŸ“‹ View Generated Code"):
st.code(response["gen_code"], language="python")
# Try to display image if content is a file path
try:
if isinstance(content, str) and (content.endswith('.png') or content.endswith('.jpg')):
if os.path.exists(content):
st.image(content)
return {"is_image": True}
        except Exception:
            pass
return {"is_image": False}
def show_processing_indicator(model_name, question):
"""Show processing indicator"""
st.markdown(f"""
<div class='processing-indicator'>
<div class='assistant-info'>πŸ€– VayuChat β€’ Processing with {model_name}</div>
<strong>Question:</strong> {question}<br>
<em>πŸ”„ Generating response...</em>
</div>
""", unsafe_allow_html=True)
# Main chat container
chat_container = st.container()
with chat_container:
# Display chat history
for response_id, response in enumerate(st.session_state.responses):
status = show_custom_response(response)
# Show feedback section for assistant responses
if response["role"] == "assistant":
feedback_key = f"feedback_{int(response_id/2)}"
error = response.get("error", "")
output = response.get("content", "")
last_prompt = response.get("last_prompt", "")
code = response.get("gen_code", "")
if "feedback" in st.session_state.responses[response_id]:
feedback_data = st.session_state.responses[response_id]["feedback"]
st.markdown(f"""
<div class='feedback-section'>
<strong>πŸ“ Your Feedback:</strong> {feedback_data.get('score', '')}
{f"- {feedback_data.get('text', '')}" if feedback_data.get('text') else ""}
</div>
""", unsafe_allow_html=True)
            else:
                # Feedback section
                st.markdown("---")
                st.markdown("**How was this response?**")
                col1, col2 = st.columns(2)
                with col1:
                    thumbs_up = st.button("👍 Helpful", key=f"{feedback_key}_up", use_container_width=True)
                with col2:
                    thumbs_down = st.button("👎 Not Helpful", key=f"{feedback_key}_down", use_container_width=True)
                # Button values reset to False on the rerun triggered by the Submit
                # button, so the thumbs choice is kept in session_state rather than
                # read back from the buttons.
                if thumbs_up or thumbs_down:
                    st.session_state[f"{feedback_key}_choice"] = "👍 Helpful" if thumbs_up else "👎 Not Helpful"
                if st.session_state.get(f"{feedback_key}_choice"):
                    thumbs = st.session_state[f"{feedback_key}_choice"]
                    comments = st.text_area(
                        "💬 Tell us more (optional):",
                        key=f"{feedback_key}_comments",
                        placeholder="What could be improved? Any suggestions?",
                        max_chars=500
                    )
                    if st.button("🚀 Submit Feedback", key=f"{feedback_key}_submit"):
                        feedback = {"score": thumbs, "text": comments}
                        # Upload feedback with enhanced error handling
                        if upload_feedback(feedback, error, output, last_prompt, code, status or {}):
                            st.session_state.responses[response_id]["feedback"] = feedback
                            del st.session_state[f"{feedback_key}_choice"]
                            time.sleep(1)  # Give the user time to see the success message
                            st.rerun()
                        else:
                            st.error("Failed to submit feedback. Please try again.")
# Show processing indicator if processing
if st.session_state.get("processing"):
show_processing_indicator(
st.session_state.get("current_model", "Unknown"),
st.session_state.get("current_question", "Processing...")
)
# Chat input (always visible at bottom)
prompt = st.chat_input("πŸ’¬ Ask me anything about air quality!", key="main_chat")
# Handle selected prompt from quick prompts
if selected_prompt:
prompt = selected_prompt
# Handle new queries
if prompt and not st.session_state.get("processing"):
# Prevent duplicate processing
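    # (Streamlit reruns the whole script on every interaction; this guard keeps
    # a stale prompt value from being processed twice with the same model.)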
if "last_prompt" in st.session_state:
last_prompt = st.session_state["last_prompt"]
last_model_name = st.session_state.get("last_model_name", "")
if (prompt == last_prompt) and (model_name == last_model_name):
prompt = None
if prompt:
# Add user input to chat history
user_response = get_from_user(prompt)
st.session_state.responses.append(user_response)
# Set processing state
st.session_state.processing = True
st.session_state.current_model = model_name
st.session_state.current_question = prompt
# Rerun to show processing indicator
st.rerun()
# Process the question if we're in processing state
if st.session_state.get("processing"):
prompt = st.session_state.get("current_question")
model_name = st.session_state.get("current_model")
try:
response = ask_question(model_name=model_name, question=prompt)
if not isinstance(response, dict):
response = {
"role": "assistant",
"content": "❌ Error: Invalid response format",
"gen_code": "",
"ex_code": "",
"last_prompt": prompt,
"error": "Invalid response format"
}
response.setdefault("role", "assistant")
response.setdefault("content", "No content generated")
response.setdefault("gen_code", "")
response.setdefault("ex_code", "")
response.setdefault("last_prompt", prompt)
response.setdefault("error", None)
except Exception as e:
response = {
"role": "assistant",
"content": f"Sorry, I encountered an error: {str(e)}",
"gen_code": "",
"ex_code": "",
"last_prompt": prompt,
"error": str(e)
}
st.session_state.responses.append(response)
st.session_state["last_prompt"] = prompt
st.session_state["last_model_name"] = model_name
st.session_state.processing = False
# Clear processing state
if "current_model" in st.session_state:
del st.session_state.current_model
if "current_question" in st.session_state:
del st.session_state.current_question
st.rerun()
# Auto-scroll to the newest message
if st.session_state.responses:
    scroll_to_bottom()
# Beautiful sidebar footer
# with st.sidebar:
# st.markdown("---")
# st.markdown("""
# <div class='contact-section'>
# <h4>πŸ“„ Paper on VayuChat</h4>
# <p>Learn more about VayuChat in our <a href='https://arxiv.org/abs/2411.12760' target='_blank'>Research Paper</a>.</p>
# </div>
# """, unsafe_allow_html=True)
# Statistics (if logging is enabled)
if hf_token and hf_token.strip():
st.markdown("### πŸ“ˆ Session Stats")
total_interactions = len([r for r in st.session_state.get("responses", []) if r.get("role") == "assistant"])
st.metric("Interactions", total_interactions)
feedbacks_given = len([r for r in st.session_state.get("responses", []) if r.get("role") == "assistant" and "feedback" in r])
st.metric("Feedbacks Given", feedbacks_given)
# Footer
st.markdown("""
<div style='text-align: center; margin-top: 3rem; padding: 2rem; background: rgba(255,255,255,0.1); border-radius: 15px;'>
<h3>🌍 Together for Cleaner Air</h3>
<p>VayuChat - Empowering environmental awareness through AI</p>
<small>Β© 2024 IIT Gandhinagar Sustainability Lab</small>
</div>
""", unsafe_allow_html=True)