import streamlit as st
import requests
import os
import json
import uuid
from datetime import datetime, timedelta
from sentence_transformers import SentenceTransformer
import chromadb
from langchain_text_splitters import RecursiveCharacterTextSplitter
import re
import shutil
from git import Repo

# Page configuration
st.set_page_config(
    page_title="RAG Chat Flow β",
    page_icon="β",
    initial_sidebar_state="expanded"
)

# Initialize dark mode state
if 'dark_mode' not in st.session_state:
    st.session_state.dark_mode = False

# Define personality questions - reduced to general ones
PERSONALITY_QUESTIONS = [
    "What is [name]'s personality like?",
    "What does [name] do for work?",
    "What are [name]'s hobbies?",
    "What makes [name] special?",
    "Tell me about [name]"
]
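
# Note: the "[name]" placeholder above is filled in the sidebar with a plain
# str.replace, e.g. "Tell me about [name]" -> "Tell me about Sara" ("Sara" is
# only an illustrative value, not a name taken from the documents).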
# Enhanced CSS styling with dark mode support
def get_css_styles():
    if st.session_state.dark_mode:
        return """
        <style>
        /* Dark Mode Styles */
        .stApp {
            background: #0e1117;
            color: #fafafa;
        }
        .main .block-container {
            max-width: 900px;
        }
        #MainMenu {visibility: hidden;}
        footer {visibility: hidden;}
        header {visibility: hidden;}
        .stDeployButton {display: none;}
        /* Sidebar dark mode */
        .css-1d391kg {
            background-color: #1e1e1e !important;
        }
        .css-1cypcdb {
            background-color: #1e1e1e !important;
        }
        /* Chat messages dark mode */
        .stChatMessage {
            background-color: #262730 !important;
            border: 1px solid #404040 !important;
        }
        /* Input fields dark mode */
        .stTextInput > div > div > input {
            background-color: #262730 !important;
            color: #fafafa !important;
            border-color: #404040 !important;
        }
        .stTextArea > div > div > textarea {
            background-color: #262730 !important;
            color: #fafafa !important;
            border-color: #404040 !important;
        }
        .model-id {
            color: #4ade80;
            font-family: monospace;
        }
        .model-attribution {
            color: #4ade80;
            font-size: 0.8em;
            font-style: italic;
        }
        .rag-attribution {
            color: #a78bfa;
            font-size: 0.8em;
            font-style: italic;
            background: #1f2937;
            padding: 8px;
            border-radius: 4px;
            border-left: 3px solid #a78bfa;
            margin-top: 8px;
        }
        /* Dark mode toggle button */
        .dark-mode-toggle {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            padding: 8px 16px;
            margin: 4px 0;
            border-radius: 8px;
            border: none;
            cursor: pointer;
            transition: all 0.3s ease;
            font-size: 0.9em;
            width: 100%;
            text-align: center;
        }
        .dark-mode-toggle:hover {
            transform: translateY(-1px);
            box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
        }
        /* NEW CHAT BUTTON - Black background for dark mode */
        .stButton > button[kind="primary"] {
            background-color: #1f2937 !important;
            border-color: #374151 !important;
            color: #fafafa !important;
        }
        .stButton > button[kind="primary"]:hover {
            background-color: #374151 !important;
            border-color: #4b5563 !important;
            color: #fafafa !important;
        }
        /* Regular buttons dark mode */
        .stButton > button {
            background-color: #374151 !important;
            border-color: #4b5563 !important;
            color: #fafafa !important;
        }
        .stButton > button:hover {
            background-color: #4b5563 !important;
            border-color: #6b7280 !important;
            color: #fafafa !important;
        }
        /* Personality Questions Styling Dark Mode */
        .personality-question {
            background: linear-gradient(135deg, #4f46e5 0%, #7c3aed 100%);
            color: white;
            padding: 8px 12px;
            margin: 4px 0;
            border-radius: 8px;
            border: none;
            cursor: pointer;
            transition: all 0.3s ease;
            font-size: 0.85em;
            width: 100%;
            text-align: left;
        }
        .personality-question:hover {
            transform: translateY(-2px);
            box-shadow: 0 4px 12px rgba(79, 70, 229, 0.4);
        }
        .personality-section {
            background: #1f2937;
            color: #e5e7eb;
            padding: 12px;
            border-radius: 8px;
            border-left: 4px solid #4f46e5;
            margin: 10px 0;
        }
        /* Chat history styling dark mode */
        .chat-history-item {
            padding: 8px 12px;
            margin: 4px 0;
            border-radius: 8px;
            border: 1px solid #374151;
            background: #1f2937;
            color: #e5e7eb;
            cursor: pointer;
            transition: all 0.2s;
        }
        .chat-history-item:hover {
            background: #374151;
            border-color: #4ade80;
        }
        .document-status {
            background: #1e3a8a;
            color: #bfdbfe;
            padding: 10px;
            border-radius: 8px;
            border-left: 4px solid #3b82f6;
            margin: 10px 0;
        }
        .github-status {
            background: #581c87;
            color: #e9d5ff;
            padding: 10px;
            border-radius: 8px;
            border-left: 4px solid #a78bfa;
            margin: 10px 0;
        }
        .rag-stats {
            background: #581c87;
            color: #e9d5ff;
            padding: 8px;
            border-radius: 6px;
            font-size: 0.85em;
        }
        /* Expander dark mode */
        .streamlit-expanderHeader {
            background-color: #1f2937 !important;
            color: #fafafa !important;
        }
        .streamlit-expanderContent {
            background-color: #111827 !important;
            color: #fafafa !important;
        }
        /* Checkbox dark mode */
        .stCheckbox {
            color: #fafafa !important;
        }
        /* Select box dark mode */
        .stSelectbox > div > div {
            background-color: #262730 !important;
            color: #fafafa !important;
        }
        /* File uploader dark mode */
        .stFileUploader {
            background-color: #1f2937 !important;
            border-color: #374151 !important;
        }
        /* Progress bar dark mode */
        .stProgress .st-bo {
            background-color: #374151 !important;
        }
        /* Success/Error/Warning messages dark mode */
        .stSuccess {
            background-color: #064e3b !important;
            color: #6ee7b7 !important;
        }
        .stError {
            background-color: #7f1d1d !important;
            color: #fca5a5 !important;
        }
        .stWarning {
            background-color: #78350f !important;
            color: #fcd34d !important;
        }
        .stInfo {
            background-color: #1e3a8a !important;
            color: #93c5fd !important;
        }
        /* Caption text dark mode */
        .caption {
            color: #9ca3af !important;
        }
        /* Divider dark mode */
        hr {
            border-color: #374151 !important;
        }
        </style>
        """
    else:
        return """
        <style>
        /* Light Mode Styles */
        .stApp {
            background: white;
            color: #000000;
        }
        .main .block-container {
            max-width: 900px;
        }
        #MainMenu {visibility: hidden;}
        footer {visibility: hidden;}
        header {visibility: hidden;}
        .stDeployButton {display: none;}
        .model-id {
            color: #28a745;
            font-family: monospace;
        }
        .model-attribution {
            color: #28a745;
            font-size: 0.8em;
            font-style: italic;
        }
        .rag-attribution {
            color: #6f42c1;
            font-size: 0.8em;
            font-style: italic;
            background: #f8f9fa;
            padding: 8px;
            border-radius: 4px;
            border-left: 3px solid #6f42c1;
            margin-top: 8px;
        }
        /* Light mode toggle button */
        .dark-mode-toggle {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            padding: 8px 16px;
            margin: 4px 0;
            border-radius: 8px;
            border: none;
            cursor: pointer;
            transition: all 0.3s ease;
            font-size: 0.9em;
            width: 100%;
            text-align: center;
        }
        .dark-mode-toggle:hover {
            transform: translateY(-1px);
            box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
        }
        /* NEW CHAT BUTTON - Black background */
        .stButton > button[kind="primary"] {
            background-color: #000000 !important;
            border-color: #000000 !important;
            color: #ffffff !important;
        }
        .stButton > button[kind="primary"]:hover {
            background-color: #333333 !important;
            border-color: #333333 !important;
            color: #ffffff !important;
        }
        /* Personality Questions Styling */
        .personality-question {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            padding: 8px 12px;
            margin: 4px 0;
            border-radius: 8px;
            border: none;
            cursor: pointer;
            transition: all 0.3s ease;
            font-size: 0.85em;
            width: 100%;
            text-align: left;
        }
        .personality-question:hover {
            transform: translateY(-2px);
            box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
        }
        .personality-section {
            background: #f8f9ff;
            padding: 12px;
            border-radius: 8px;
            border-left: 4px solid #667eea;
            margin: 10px 0;
        }
        /* Chat history styling */
        .chat-history-item {
            padding: 8px 12px;
            margin: 4px 0;
            border-radius: 8px;
            border: 1px solid #e0e0e0;
            background: #f8f9fa;
            cursor: pointer;
            transition: all 0.2s;
        }
        .chat-history-item:hover {
            background: #e9ecef;
            border-color: #28a745;
        }
        .document-status {
            background: #e3f2fd;
            padding: 10px;
            border-radius: 8px;
            border-left: 4px solid #2196f3;
            margin: 10px 0;
        }
        .github-status {
            background: #f3e5f5;
            padding: 10px;
            border-radius: 8px;
            border-left: 4px solid #6f42c1;
            margin: 10px 0;
        }
        .rag-stats {
            background: #f3e5f5;
            padding: 8px;
            border-radius: 6px;
            font-size: 0.85em;
            color: #4a148c;
        }
        </style>
        """
# Apply CSS styles
st.markdown(get_css_styles(), unsafe_allow_html=True)

# File paths
HISTORY_FILE = "rag_chat_history.json"
SESSIONS_FILE = "rag_chat_sessions.json"
USERS_FILE = "online_users.json"
# ================= GITHUB INTEGRATION =================
def clone_github_repo():
    """Clone or update GitHub repository with documents"""
    github_token = os.getenv("GITHUB_TOKEN")
    if not github_token:
        st.error("π GITHUB_TOKEN not found in environment variables")
        return False
    try:
        repo_url = f"https://{github_token}@github.com/Umer-K/family-profiles.git"
        repo_dir = "family_profiles"
        # Clean up existing directory if it exists
        if os.path.exists(repo_dir):
            shutil.rmtree(repo_dir)
        # Clone the repository
        with st.spinner("π Cloning private repository..."):
            Repo.clone_from(repo_url, repo_dir)
        # Copy txt files to documents folder
        documents_dir = "documents"
        os.makedirs(documents_dir, exist_ok=True)
        # Clear existing documents
        for file in os.listdir(documents_dir):
            if file.endswith('.txt'):
                os.remove(os.path.join(documents_dir, file))
        # Copy new txt files from repo
        txt_files_found = 0
        for root, dirs, files in os.walk(repo_dir):
            for file in files:
                if file.endswith('.txt'):
                    src_path = os.path.join(root, file)
                    dst_path = os.path.join(documents_dir, file)
                    shutil.copy2(src_path, dst_path)
                    txt_files_found += 1
        # Clean up repo directory
        shutil.rmtree(repo_dir)
        st.success(f"β Successfully synced {txt_files_found} documents from GitHub!")
        return True
    except Exception as e:
        st.error(f"β GitHub sync failed: {str(e)}")
        return False
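
# Note on the sync above: the personal-access token is embedded directly in the
# HTTPS clone URL that GitPython hands to git. Because the checkout (including
# .git/config, which retains that URL) is deleted right after the .txt files
# are copied out, the token is not left on disk beyond the sync itself.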
def check_github_status():
    """Check GitHub token availability and repo access"""
    github_token = os.getenv("GITHUB_TOKEN")
    if not github_token:
        return {
            "status": "missing",
            "message": "No GitHub token found",
            "color": "red"
        }
    try:
        # Test token by making a simple API call
        headers = {
            "Authorization": f"token {github_token}",
            "Accept": "application/vnd.github.v3+json"
        }
        response = requests.get(
            "https://api.github.com/repos/Umer-K/family-profiles",
            headers=headers,
            timeout=10
        )
        if response.status_code == 200:
            return {
                "status": "connected",
                "message": "GitHub access verified",
                "color": "green"
            }
        elif response.status_code == 404:
            return {
                "status": "not_found",
                "message": "Repository not found or no access",
                "color": "orange"
            }
        elif response.status_code == 401:
            return {
                "status": "unauthorized",
                "message": "Invalid GitHub token",
                "color": "red"
            }
        else:
            return {
                "status": "error",
                "message": f"GitHub API error: {response.status_code}",
                "color": "orange"
            }
    except Exception as e:
        return {
            "status": "error",
            "message": f"Connection error: {str(e)}",
            "color": "orange"
        }
# ================= RAG SYSTEM CLASS =================
@st.cache_resource  # cache the heavy model/DB setup across reruns, as the docstring promises
def initialize_rag_system():
    """Initialize RAG system with caching"""
    return ProductionRAGSystem()
class ProductionRAGSystem:
    def __init__(self, collection_name="streamlit_rag_docs"):
        self.collection_name = collection_name
        # Initialize embedding model
        try:
            self.model = SentenceTransformer('all-mpnet-base-v2')
        except Exception as e:
            st.error(f"Error loading embedding model: {e}")
            self.model = None
            return
        # Initialize ChromaDB
        try:
            self.client = chromadb.PersistentClient(path="./chroma_db")
            try:
                self.collection = self.client.get_collection(collection_name)
            except Exception:
                self.collection = self.client.create_collection(collection_name)
        except Exception as e:
            st.error(f"Error initializing ChromaDB: {e}")
            self.client = None
            return
        # Initialize text splitter
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=100,
            length_function=len,
            separators=["\n\n", "\n", ". ", " ", ""]
        )
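
    # Note on the splitter settings above: RecursiveCharacterTextSplitter tries
    # the separators in order (paragraph, line, sentence, word, character), so an
    # 800-char chunk with 100-char overlap usually breaks on a paragraph or
    # sentence boundary. Illustrative example (not from the documents): a
    # 2,000-char profile would yield roughly three chunks, each repeating ~100
    # trailing chars of its predecessor so facts that straddle a boundary stay
    # retrievable.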
    def get_collection_count(self):
        """Get number of documents in collection"""
        try:
            return self.collection.count() if self.collection else 0
        except Exception:
            return 0

    def load_documents_from_folder(self, folder_path="documents"):
        """Load documents from folder"""
        if not os.path.exists(folder_path):
            return []
        txt_files = [f for f in os.listdir(folder_path) if f.endswith('.txt')]
        if not txt_files:
            return []
        all_chunks = []
        for filename in txt_files:
            filepath = os.path.join(folder_path, filename)
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    content = f.read().strip()
                if content:
                    chunks = self.text_splitter.split_text(content)
                    for i, chunk in enumerate(chunks):
                        all_chunks.append({
                            'content': chunk,
                            'source_file': filename,
                            'chunk_index': i,
                            'char_count': len(chunk)
                        })
            except Exception as e:
                st.error(f"Error reading {filename}: {e}")
        return all_chunks
    def index_documents(self, document_folder="documents"):
        """Index documents with progress bar"""
        if not self.model or not self.client:
            return False
        chunks = self.load_documents_from_folder(document_folder)
        if not chunks:
            return False
        # Clear existing collection
        try:
            self.client.delete_collection(self.collection_name)
            self.collection = self.client.create_collection(self.collection_name)
        except Exception:
            pass
        # Create embeddings with progress bar
        progress_bar = st.progress(0)
        status_text = st.empty()
        chunk_texts = [chunk['content'] for chunk in chunks]
        try:
            status_text.text("Creating embeddings...")
            embeddings = self.model.encode(chunk_texts, show_progress_bar=False)
            status_text.text("Storing in database...")
            for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
                chunk_id = f"{chunk['source_file']}_{chunk['chunk_index']}"
                metadata = {
                    "source_file": chunk['source_file'],
                    "chunk_index": chunk['chunk_index'],
                    "char_count": chunk['char_count']
                }
                self.collection.add(
                    documents=[chunk['content']],
                    ids=[chunk_id],
                    embeddings=[embedding.tolist()],
                    metadatas=[metadata]
                )
                progress_bar.progress((i + 1) / len(chunks))
            progress_bar.empty()
            status_text.empty()
            return True
        except Exception as e:
            st.error(f"Error during indexing: {e}")
            progress_bar.empty()
            status_text.empty()
            return False
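
    # Note: chunks are written one at a time so the progress bar can tick per
    # chunk. collection.add also accepts lists, so a single batched call
    # (documents=chunk_texts, ids=[...], embeddings=embeddings.tolist(),
    # metadatas=[...]) would likely be faster for large folders, at the cost of
    # per-chunk progress updates.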
    def expand_query_with_family_terms(self, query):
        """Expand query to include family relationship synonyms"""
        family_mappings = {
            'mother': ['mama', 'mom', 'ammi'],
            'mama': ['mother', 'mom', 'ammi'],
            'father': ['papa', 'dad', 'abbu'],
            'papa': ['father', 'dad', 'abbu'],
            'brother': ['bhai', 'bro'],
            'bhai': ['brother', 'bro'],
            'sister': ['behn', 'sis'],
            'behn': ['sister', 'sis']
        }
        expanded_terms = [query]
        query_lower = query.lower()
        for key, synonyms in family_mappings.items():
            if key in query_lower:
                for synonym in synonyms:
                    expanded_terms.append(query_lower.replace(key, synonym))
        return expanded_terms
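
    # Illustrative example: expand_query_with_family_terms("who is my mother")
    # returns ["who is my mother", "who is my mama", "who is my mom",
    # "who is my ammi"], and search() below embeds and queries each variant
    # separately. Matching is plain substring replacement on the lowercased
    # query, so a short key like "mama" would also fire inside any unrelated
    # word that happens to contain it.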
    def search(self, query, n_results=5):
        """Search for relevant chunks with family relationship mapping"""
        if not self.model or not self.collection:
            return None
        try:
            # Expand query with family terms
            expanded_queries = self.expand_query_with_family_terms(query)
            all_results = []
            # Search with all expanded terms
            for search_query in expanded_queries:
                query_embedding = self.model.encode([search_query])[0].tolist()
                results = self.collection.query(
                    query_embeddings=[query_embedding],
                    n_results=n_results
                )
                if results['documents'][0]:
                    for chunk, distance, metadata in zip(
                        results['documents'][0],
                        results['distances'][0],
                        results['metadatas'][0]
                    ):
                        similarity = max(0, 1 - distance)
                        all_results.append({
                            'content': chunk,
                            'metadata': metadata,
                            'similarity': similarity,
                            'query_used': search_query
                        })
            if not all_results:
                return None
            # Remove duplicates and sort by similarity
            seen_chunks = set()
            unique_results = []
            for result in all_results:
                chunk_id = f"{result['metadata']['source_file']}_{result['content'][:50]}"
                if chunk_id not in seen_chunks:
                    seen_chunks.add(chunk_id)
                    unique_results.append(result)
            # Sort by similarity and take top results
            unique_results.sort(key=lambda x: x['similarity'], reverse=True)
            search_results = unique_results[:n_results]
            # Debug: Show search results for troubleshooting
            print(f"Search for '{query}' (expanded to {len(expanded_queries)} terms) found {len(search_results)} results")
            for i, result in enumerate(search_results[:3]):
                print(f"  {i+1}. Similarity: {result['similarity']:.3f} | Source: {result['metadata']['source_file']} | Query: {result['query_used']}")
                print(f"     Content preview: {result['content'][:100]}...")
            return search_results
        except Exception as e:
            st.error(f"Search error: {e}")
            return None
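
    # Note on scoring above: similarity = max(0, 1 - distance). The collection
    # is created with ChromaDB's default distance space, so treating
    # 1 - distance as a similarity is an approximation rather than a true
    # cosine score; that is presumably why the caller accepts matches down to
    # an ultra-low 0.001 threshold instead of a conventional cutoff like 0.7.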
    def extract_direct_answer(self, query, content):
        """Extract direct answer from content"""
        query_lower = query.lower()
        sentences = re.split(r'[.!?]+', content)
        sentences = [s.strip() for s in sentences if len(s.strip()) > 20]
        query_words = set(query_lower.split())
        scored_sentences = []
        for sentence in sentences:
            sentence_words = set(sentence.lower().split())
            exact_matches = len(query_words.intersection(sentence_words))
            # Bonus scoring for key terms
            bonus_score = 0
            if '401k' in query_lower and ('401' in sentence.lower() or 'retirement' in sentence.lower()):
                bonus_score += 3
            if 'sick' in query_lower and 'sick' in sentence.lower():
                bonus_score += 3
            if 'vacation' in query_lower and 'vacation' in sentence.lower():
                bonus_score += 3
            total_score = exact_matches * 2 + bonus_score
            if total_score > 0:
                scored_sentences.append((sentence, total_score))
        if scored_sentences:
            scored_sentences.sort(key=lambda x: x[1], reverse=True)
            best_sentence = scored_sentences[0][0]
            if not best_sentence.endswith('.'):
                best_sentence += '.'
            return best_sentence
        # Fallback
        for sentence in sentences:
            if len(sentence) > 30:
                return sentence + ('.' if not sentence.endswith('.') else '')
        return content[:200] + "..."
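
    # Illustrative example of the scoring above: for the query "how many
    # vacation days", the sentence "Employees get 15 vacation days" shares the
    # words {"vacation", "days"} (2 exact matches) and triggers the 'vacation'
    # bonus, scoring 2 * 2 + 3 = 7; a sentence with no overlapping words scores
    # 0 and is dropped before ranking.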
    def generate_answer(self, query, search_results, use_ai_enhancement=True, unlimited_tokens=False):
        """Generate both AI and extracted answers with proper token handling"""
        if not search_results:
            return {
                'ai_answer': "No information found in documents.",
                'extracted_answer': "No information found in documents.",
                'sources': [],
                'confidence': 0,
                'has_both': False
            }
        best_result = search_results[0]
        sources = list(set([r['metadata']['source_file'] for r in search_results[:2]]))
        avg_confidence = sum(r['similarity'] for r in search_results[:2]) / len(search_results[:2])
        # Always generate extracted answer
        extracted_answer = self.extract_direct_answer(query, best_result['content'])
        # Try AI answer if requested and API key available
        ai_answer = None
        openrouter_key = os.environ.get("OPENROUTER_API_KEY")
        if use_ai_enhancement and openrouter_key:
            # Build context from search results
            context = "\n\n".join([f"Source: {r['metadata']['source_file']}\nContent: {r['content']}"
                                   for r in search_results[:3]])
            # Create focused prompt for rich, engaging family responses
            if unlimited_tokens:
                prompt = f"""You are a warm, caring family assistant who knows everyone well. Based on the family information below, provide a rich, detailed, and engaging response.

Family Document Context:
{context}

Question: {query}

Instructions:
- Use the document information as your foundation
- Expand with logical personality traits and qualities someone like this would have
- Add 3-4 additional lines of thoughtful insights about their character
- Use 5-6 relevant emojis throughout the response to make it warm and engaging
- Write in a caring, family-friend tone
- If someone asks about relationships (like "mother" = "mama"), make those connections
- Make the response feel personal and detailed, not just a basic fact
- Include both strengths and endearing qualities
- Keep it warm but informative (4-6 sentences total)
- Sprinkle emojis naturally throughout, not just at the end

Remember: You're helping someone learn about their family members in a meaningful way! π"""
                max_tokens = 400  # Increased for richer responses
                temperature = 0.3  # Slightly more creative
            else:
                # Shorter but still enhanced prompt for conservative mode
                prompt = f"""Based on this family info: {extracted_answer}

Question: {query}

Give a warm, detailed answer with 3-4 emojis spread throughout. Add 2-3 more qualities this person likely has. Make it caring and personal! π"""
                max_tokens = 150  # Better than 50 for family context
                temperature = 0.2
            try:
                response = requests.post(
                    "https://openrouter.ai/api/v1/chat/completions",
                    headers={
                        "Authorization": f"Bearer {openrouter_key}",
                        "Content-Type": "application/json",
                        "HTTP-Referer": "https://huggingface.co/spaces",
                        "X-Title": "RAG Chatbot"
                    },
                    json={
                        "model": "openai/gpt-3.5-turbo",
                        "messages": [{"role": "user", "content": prompt}],
                        "max_tokens": max_tokens,
                        "temperature": temperature
                    },
                    timeout=15
                )
                if response.status_code == 200:
                    ai_response = response.json()['choices'][0]['message']['content'].strip()
                    ai_answer = ai_response if len(ai_response) > 10 else extracted_answer
                else:
                    # Log the actual error for debugging
                    error_detail = ""
                    try:
                        error_detail = response.json().get('error', {}).get('message', '')
                    except Exception:
                        pass
                    if response.status_code == 402:
                        st.warning("π³ OpenRouter credits exhausted. Using extracted answers only.")
                    elif response.status_code == 429:
                        st.warning("β±οΈ Rate limit reached. Using extracted answers only.")
                    elif response.status_code == 401:
                        st.error("π Invalid API key. Check your OpenRouter key.")
                    elif response.status_code == 400:
                        st.error(f"β Bad request: {error_detail}")
                    else:
                        st.warning(f"API Error {response.status_code}: {error_detail}. Using extracted answers only.")
            except requests.exceptions.Timeout:
                st.warning("β±οΈ API timeout. Using extracted answers only.")
            except Exception as e:
                st.warning(f"API Exception: {str(e)}. Using extracted answers only.")
        return {
            'ai_answer': ai_answer,
            'extracted_answer': extracted_answer,
            'sources': sources,
            'confidence': avg_confidence,
            'has_both': ai_answer is not None
        }
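
# Note: the request bodies sent above and below follow the OpenAI-compatible
# chat-completions schema that OpenRouter exposes at /api/v1/chat/completions;
# HTTP-Referer and X-Title are OpenRouter's optional app-attribution headers.
# Swapping the "model" string is the only change needed to target a different
# backing model.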
def get_general_ai_response(query, unlimited_tokens=False):
    """Get AI response for general questions with family-friendly enhancement"""
    openrouter_key = os.environ.get("OPENROUTER_API_KEY")
    if not openrouter_key:
        return "I can only answer questions about your family members from the uploaded documents. Please add an OpenRouter API key for general conversations. π"
    try:
        # Adjust parameters based on token availability
        if unlimited_tokens:
            max_tokens = 350  # Good limit for detailed family responses
            temperature = 0.5
            prompt = f"""You are a caring family assistant. Someone is asking about their family but I couldn't find specific information in their family documents.

Question: {query}

Please provide a warm, helpful response that:
- Acknowledges I don't have specific information about their family member
- Suggests they might want to add more details to their family profiles
- Offers to help in other ways
- Uses a caring, family-friendly tone with appropriate emojis
- Keep it supportive and understanding π"""
        else:
            max_tokens = 100  # Reasonable for conservative mode
            temperature = 0.4
            prompt = f"Family question: {query[:100]} - I don't have info about this family member. Give a caring, helpful response with emojis π"
        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {openrouter_key}",
                "Content-Type": "application/json",
                "HTTP-Referer": "https://huggingface.co/spaces",
                "X-Title": "RAG Chatbot"
            },
            json={
                "model": "openai/gpt-3.5-turbo",
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": max_tokens,
                "temperature": temperature
            },
            timeout=15
        )
        if response.status_code == 200:
            return response.json()['choices'][0]['message']['content'].strip()
        else:
            # Get detailed error information
            error_detail = ""
            try:
                error_detail = response.json().get('error', {}).get('message', '')
            except Exception:
                pass
            if response.status_code == 402:
                return "Sorry, OpenRouter credits exhausted. Please add more credits or top up your account."
            elif response.status_code == 429:
                return "Rate limit reached. Please try again in a moment."
            elif response.status_code == 401:
                return "Invalid API key. Please check your OpenRouter API key configuration."
            elif response.status_code == 400:
                return f"Bad request: {error_detail}. Please try rephrasing your question."
            else:
                return f"API error (Status: {response.status_code}): {error_detail}. Please try again."
    except requests.exceptions.Timeout:
        return "Request timeout. Please try again."
    except Exception as e:
        return f"Error: {str(e)}"
def get_user_id():
    """Get unique ID for this user session"""
    if 'user_id' not in st.session_state:
        st.session_state.user_id = str(uuid.uuid4())[:8]
    return st.session_state.user_id

def update_online_users():
    """Update user status"""
    try:
        users = {}
        if os.path.exists(USERS_FILE):
            with open(USERS_FILE, 'r') as f:
                users = json.load(f)
        user_id = get_user_id()
        users[user_id] = {
            'last_seen': datetime.now().isoformat(),
            'name': f'User-{user_id}',
            'session_start': users.get(user_id, {}).get('session_start', datetime.now().isoformat())
        }
        # Clean up old users
        current_time = datetime.now()
        active_users = {}
        for uid, data in users.items():
            try:
                last_seen = datetime.fromisoformat(data['last_seen'])
                if current_time - last_seen < timedelta(minutes=5):
                    active_users[uid] = data
            except Exception:
                continue
        with open(USERS_FILE, 'w') as f:
            json.dump(active_users, f, indent=2)
        return len(active_users)
    except Exception:
        return 1
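
# Note: presence tracking rewrites online_users.json on every interaction and
# drops entries not seen within the 5-minute window. Concurrent sessions can
# race on this read-modify-write (last writer wins), which is tolerable here
# because the count is purely cosmetic; any error falls back to reporting 1
# ("just you online").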
def load_chat_history():
    """Load chat history"""
    try:
        if os.path.exists(HISTORY_FILE):
            with open(HISTORY_FILE, 'r', encoding='utf-8') as f:
                return json.load(f)
    except Exception:
        pass
    return []

def save_chat_history(messages):
    """Save chat history"""
    try:
        with open(HISTORY_FILE, 'w', encoding='utf-8') as f:
            json.dump(messages, f, ensure_ascii=False, indent=2)
    except Exception as e:
        st.error(f"Error saving history: {e}")

def start_new_chat():
    """Start new chat session"""
    st.session_state.messages = []
    st.session_state.session_id = str(uuid.uuid4())
# ================= MAIN APP =================
# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = load_chat_history()
if "session_id" not in st.session_state:
    st.session_state.session_id = str(uuid.uuid4())

# Initialize RAG system
rag_system = initialize_rag_system()

# Header with dark mode toggle
col1, col2 = st.columns([4, 1])
with col1:
    st.title("RAG Chat Flow β")
    st.caption("Ask questions about your documents with AI-powered retrieval")
with col2:
    # Dark mode toggle button
    mode_text = "π Light" if st.session_state.dark_mode else "π Dark"
    if st.button(mode_text, use_container_width=True):
        st.session_state.dark_mode = not st.session_state.dark_mode
        st.rerun()
# Sidebar
with st.sidebar:
    # New Chat Button
    if st.button("β New Chat", use_container_width=True, type="primary"):
        start_new_chat()
        st.rerun()
    st.divider()

    # Dark Mode Toggle in Sidebar too
    st.header("π¨ Theme")
    theme_status = "Dark Mode β¨" if st.session_state.dark_mode else "Light Mode βοΈ"
    if st.button(f"π Switch to {'Light' if st.session_state.dark_mode else 'Dark'} Mode", use_container_width=True):
        st.session_state.dark_mode = not st.session_state.dark_mode
        st.rerun()
    st.info(f"Current: {theme_status}")
    st.divider()

    # Personality Questions Section
    st.header("π Personality Questions")
    # Name input for personalizing questions
    name_input = st.text_input("Enter name for personalized questions:", placeholder="First name only", help="Replace [name] in questions with this name")
    if name_input.strip():
        name = name_input.strip()
        st.markdown(f"""
        <div class="personality-section">
            <strong>π« Quick Questions for {name}:</strong><br>
            <small>Click any question to ask about {name}</small>
        </div>
        """, unsafe_allow_html=True)
        # Display personality questions as clickable buttons
        for i, question in enumerate(PERSONALITY_QUESTIONS):
            formatted_question = question.replace("[name]", name)
            if st.button(formatted_question, key=f"pq_{i}", use_container_width=True):
                # Add the question to chat and set flag to process it
                user_message = {"role": "user", "content": formatted_question}
                st.session_state.messages.append(user_message)
                st.session_state.process_personality_question = formatted_question
                st.rerun()
    else:
        st.markdown("""
        <div class="personality-section">
            <strong>π« Sample Questions:</strong><br>
            <small>Enter a name above to personalize these questions</small>
        </div>
        """, unsafe_allow_html=True)
        # Show sample questions without names
        for question in PERSONALITY_QUESTIONS[:5]:  # Show first 5 as examples
            st.markdown(f"β’ {question}")
    st.divider()

    # GitHub Integration
    st.header("π GitHub Integration")
    github_status = check_github_status()
    if github_status["status"] == "connected":
        st.markdown(f"""
        <div class="github-status">
            <strong>π’ GitHub:</strong> {github_status['message']}<br>
            <strong>π Repo:</strong> family-profiles (private)
        </div>
        """, unsafe_allow_html=True)
        # Sync from GitHub button
        if st.button("π Sync from GitHub", use_container_width=True):
            if clone_github_repo():
                # Auto-index after successful sync
                if rag_system and rag_system.model:
                    with st.spinner("Auto-indexing synced documents..."):
                        if rag_system.index_documents("documents"):
                            st.success("β Documents synced and indexed!")
                            st.rerun()
                        else:
                            st.warning("β οΈ Sync successful but indexing failed")
    else:
        color_map = {"red": "π΄", "orange": "π ", "green": "π’"}
        color_icon = color_map.get(github_status["color"], "π΄")
        st.markdown(f"""
        <div class="github-status">
            <strong>{color_icon} GitHub:</strong> {github_status['message']}<br>
            <strong>π Setup:</strong> Add GITHUB_TOKEN to Hugging Face secrets
        </div>
        """, unsafe_allow_html=True)
    st.divider()

    # Document Management
    st.header("π Document Management")
    if rag_system and rag_system.model:
        doc_count = rag_system.get_collection_count()
        if doc_count > 0:
            st.markdown(f"""
            <div class="document-status">
                <strong>π Documents Indexed:</strong> {doc_count} chunks<br>
                <strong>π Status:</strong> Ready for queries
            </div>
            """, unsafe_allow_html=True)
        else:
            st.warning("No documents indexed. Sync from GitHub or upload documents to get started.")
        # Document indexing
        if st.button("π Re-index Documents", use_container_width=True):
            with st.spinner("Indexing documents..."):
                if rag_system.index_documents("documents"):
                    st.success("Documents indexed successfully!")
                    st.rerun()
                else:
                    st.error("Failed to index documents. Check your documents folder.")
        # Show document count only (hidden)
        if os.path.exists("documents"):
            txt_files = [f for f in os.listdir("documents") if f.endswith('.txt')]
            if txt_files:
                st.info(f"π {len(txt_files)} documents loaded (hidden)")
        # Manual upload interface (fallback)
        st.subheader("π€ Manual Upload")
        uploaded_files = st.file_uploader(
            "Upload text files (fallback)",
            type=['txt'],
            accept_multiple_files=True,
            help="Upload .txt files if GitHub sync is not available"
        )
        if uploaded_files:
            if st.button("πΎ Save & Index Files"):
                os.makedirs("documents", exist_ok=True)
                saved_files = []
                for uploaded_file in uploaded_files:
                    file_path = os.path.join("documents", uploaded_file.name)
                    with open(file_path, "wb") as f:
                        f.write(uploaded_file.getbuffer())
                    saved_files.append(uploaded_file.name)
                st.success(f"Saved {len(saved_files)} files!")
                # Auto-index
                with st.spinner("Auto-indexing new documents..."):
                    if rag_system.index_documents("documents"):
                        st.success("Documents indexed successfully!")
                        st.rerun()
    else:
        st.error("RAG system initialization failed. Check your setup.")
    st.divider()

    # Online Users
    st.header("π₯ Online Users")
    online_count = update_online_users()
    if online_count == 1:
        st.success("π’ Just you online")
    else:
        st.success(f"π’ {online_count} people online")
    st.divider()

    # Settings
    st.header("βοΈ Settings")
    # API Status with better checking
    openrouter_key = os.environ.get("OPENROUTER_API_KEY")
    if openrouter_key:
        st.success("β API Connected")
        # Quick API test
        if st.button("Test API Connection", use_container_width=True):
            try:
                test_response = requests.post(
                    "https://openrouter.ai/api/v1/chat/completions",
                    headers={
                        "Authorization": f"Bearer {openrouter_key}",
                        "Content-Type": "application/json"
                    },
                    json={
                        "model": "openai/gpt-3.5-turbo",
                        "messages": [{"role": "user", "content": "test"}],
                        "max_tokens": 5
                    },
                    timeout=5
                )
                if test_response.status_code == 200:
                    st.success("β API working correctly!")
                elif test_response.status_code == 402:
                    st.error("β Credits exhausted")
                elif test_response.status_code == 429:
                    st.warning("β±οΈ Rate limited")
                else:
                    st.error(f"β API Error: {test_response.status_code}")
            except Exception as e:
                st.error(f"β API Test Failed: {str(e)}")
    else:
        st.error("β No OpenRouter API Key")
        st.info("Add OPENROUTER_API_KEY in Hugging Face Space settings β Variables and secrets")

    # Enhanced Settings
    st.subheader("π Token Settings")
    unlimited_tokens = st.checkbox("π₯ Unlimited Tokens Mode", value=True, help="Use higher token limits for detailed responses")
    use_ai_enhancement = st.checkbox("πΎ AI Enhancement", value=bool(openrouter_key), help="Enhance answers with AI when documents are found")
    st.subheader("ποΈ Display Settings")
    show_sources = st.checkbox("π Show Sources", value=True)
    show_confidence = st.checkbox("π― Show Confidence Scores", value=True)
    # Token mode indicator
    if unlimited_tokens:
        st.success("π₯ Unlimited mode: Detailed responses enabled")
    else:
        st.info("π° Conservative mode: Limited tokens to save credits")
    st.divider()

    # Chat History Controls
    st.header("πΎ Chat History")
    if st.session_state.messages:
        st.info(f"Messages: {len(st.session_state.messages)}")
        col1, col2 = st.columns(2)
        with col1:
            if st.button("πΎ Save", use_container_width=True):
                save_chat_history(st.session_state.messages)
                st.success("Saved!")
        with col2:
            if st.button("ποΈ Clear", use_container_width=True):
                start_new_chat()
                st.success("Cleared!")
                st.rerun()
# ================= MAIN CHAT AREA =================
# Display chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        if message["role"] == "assistant" and "rag_info" in message:
            # Display AI answer
            st.markdown(message["content"])
            # Display RAG information
            rag_info = message["rag_info"]
            if show_sources and rag_info.get("sources"):
                confidence_text = f"{rag_info['confidence']*100:.1f}%" if show_confidence else ""
                st.markdown(f"""
                <div class="rag-attribution">
                    <strong>π Sources:</strong> {', '.join(rag_info['sources'])}<br>
                    <strong>π― Confidence:</strong> {confidence_text}
                </div>
                """, unsafe_allow_html=True)
            # Show extracted answer if different
            if rag_info.get("extracted_answer") and rag_info["extracted_answer"] != message["content"]:
                st.markdown("**π Extracted Answer:**")
                st.markdown(f"_{rag_info['extracted_answer']}_")
        else:
            st.markdown(message["content"])
def generate_assistant_reply(prompt):
    """Shared response pipeline (used by both the personality-question buttons and
    the chat input, which previously duplicated this logic verbatim): search the
    knowledge base, render the answer, and persist it to history."""
    with st.chat_message("assistant"):
        if rag_system and rag_system.model and rag_system.get_collection_count() > 0:
            # Search documents first
            search_results = rag_system.search(prompt, n_results=5)
            # Debug output for troubleshooting
            if search_results:
                st.info(f"π Found {len(search_results)} potential matches. Best similarity: {search_results[0]['similarity']:.3f}")
            else:
                st.warning("π No search results returned from vector database")
            # Check if we found relevant documents (very low threshold)
            if search_results and search_results[0]['similarity'] > 0.001:  # Ultra-low threshold
                # Generate document-based answer
                result = rag_system.generate_answer(
                    prompt,
                    search_results,
                    use_ai_enhancement=use_ai_enhancement,
                    unlimited_tokens=unlimited_tokens
                )
                # Display AI answer or extracted answer
                if use_ai_enhancement and result['has_both']:
                    answer_text = result['ai_answer']
                    st.markdown(f"πΎ **AI Enhanced Answer:** {answer_text}")
                    # Also show extracted answer for comparison if different
                    if result['extracted_answer'] != answer_text:
                        with st.expander("π View Extracted Answer"):
                            st.markdown(result['extracted_answer'])
                else:
                    answer_text = result['extracted_answer']
                    st.markdown(f"π **Document Answer:** {answer_text}")
                    # Show why AI enhancement wasn't used
                    if use_ai_enhancement and not result['has_both']:
                        st.info("π‘ AI enhancement failed - showing extracted answer from documents")
                # Show RAG info with more details
                if show_sources and result['sources']:
                    confidence_text = f"{result['confidence']*100:.1f}%" if show_confidence else ""
                    st.markdown(f"""
                    <div class="rag-attribution">
                        <strong>π Sources:</strong> {', '.join(result['sources'])}<br>
                        <strong>π― Confidence:</strong> {confidence_text}<br>
                        <strong>π Found:</strong> {len(search_results)} relevant sections<br>
                        <strong>π Best Match:</strong> {search_results[0]['similarity']:.3f} similarity
                    </div>
                    """, unsafe_allow_html=True)
                # Add to messages with RAG info
                assistant_message = {
                    "role": "assistant",
                    "content": answer_text,
                    "rag_info": {
                        "sources": result['sources'],
                        "confidence": result['confidence'],
                        "extracted_answer": result['extracted_answer'],
                        "has_ai": result['has_both']
                    }
                }
            else:
                # No relevant documents found - show debug info
                if search_results:
                    st.warning(f"π Found documents but similarity too low (best: {search_results[0]['similarity']:.3f}). Using general AI...")
                else:
                    st.warning("π No documents found in search. Using general AI...")
                general_response = get_general_ai_response(prompt, unlimited_tokens=unlimited_tokens)
                st.markdown(f"π¬ **General AI:** {general_response}")
                assistant_message = {
                    "role": "assistant",
                    "content": general_response,
                    "rag_info": {"sources": [], "confidence": 0, "mode": "general"}
                }
        else:
            # RAG system not ready - use general AI
            if rag_system and rag_system.get_collection_count() == 0:
                st.warning("No documents indexed. Sync from GitHub or upload documents first...")
            else:
                st.error("RAG system not ready. Using general AI mode...")
            general_response = get_general_ai_response(prompt, unlimited_tokens=unlimited_tokens)
            st.markdown(f"π¬ **General AI:** {general_response}")
            assistant_message = {
                "role": "assistant",
                "content": general_response,
                "rag_info": {"sources": [], "confidence": 0, "mode": "general"}
            }
    # Add assistant message to history
    st.session_state.messages.append(assistant_message)
    # Auto-save
    save_chat_history(st.session_state.messages)

# Check if we need to process a personality question
if hasattr(st.session_state, 'process_personality_question'):
    prompt = st.session_state.process_personality_question
    del st.session_state.process_personality_question  # Clear the flag
    # Display user message (already appended by the sidebar button handler)
    with st.chat_message("user"):
        st.markdown(prompt)
    # Process the question using the same logic as chat input
    update_online_users()
    generate_assistant_reply(prompt)
# Chat input
if prompt := st.chat_input("Ask questions about your documents..."):
    # Update user tracking
    update_online_users()
    # Add user message
    user_message = {"role": "user", "content": prompt}
    st.session_state.messages.append(user_message)
    # Display user message
    with st.chat_message("user"):
        st.markdown(prompt)
    # Get RAG response via the shared pipeline
    generate_assistant_reply(prompt)
# Footer info
if rag_system and rag_system.model:
    doc_count = rag_system.get_collection_count()
    token_mode = "π₯ Unlimited" if unlimited_tokens else "π° Conservative"
    github_status = check_github_status()
    github_icon = "π’" if github_status["status"] == "connected" else "π΄"
    theme_icon = "π" if st.session_state.dark_mode else "βοΈ"
    st.caption(f"π Knowledge Base: {doc_count} indexed chunks | π RAG System Active | {token_mode} Token Mode | {github_icon} GitHub {github_status['status'].title()} | {theme_icon} {theme_status}")