Spaces:
Paused
Paused
""" | |
Multi-Method RAG System - SIGHT | |
Enhanced Streamlit application with method comparison and analytics. | |
Directory structure: | |
/data/ # Original PDFs, HTML | |
/embeddings/ # FAISS, Chroma, DPR vector stores | |
/graph/ # Graph database files | |
/metadata/ # Image metadata (SQLite or MongoDB) | |
""" | |
import streamlit as st | |
import os | |
import logging | |
import tempfile | |
import time | |
import uuid | |
from typing import Tuple, List, Dict, Any, Optional | |
from pathlib import Path | |
# NEW: same-origin base path for the backend on Hugging Face Spaces | |
# The Docker/Nginx setup routes /api/* to your FastAPI. | |
API_BASE = os.getenv("BACKEND_BASE", "/api") # e.g., "/api" | |
# Import all query modules | |
from query_graph import query as graph_query, query_graph | |
from query_vanilla import query as vanilla_query | |
from query_dpr import query as dpr_query | |
from query_bm25 import query as bm25_query | |
from query_context import query as context_query | |
from query_vision import query as vision_query, query_image_only | |
from config import * | |
from analytics_db import log_query, get_analytics_stats, get_method_performance, analytics_db | |
import streamlit.components.v1 as components | |
import requests | |
logger = logging.getLogger(__name__) | |
# Check realtime server health | |
# Cache for 30 seconds | |
def check_realtime_server_health():
    """Report whether the realtime backend answers its health endpoint.

    Sends a GET request to ``{API_BASE}/health`` with a 2-second timeout.

    Returns:
        bool: True when the response status code is 200; False for any other
        status or when the request itself fails.
    """
    try:
        response = requests.get(f"{API_BASE}/health", timeout=2)
    except requests.RequestException as exc:
        # Only request-level failures are treated as "offline"; a bare
        # ``except:`` would also swallow KeyboardInterrupt/SystemExit.
        # NOTE(review): API_BASE defaults to the relative path "/api"; requests
        # needs an absolute URL, so confirm BACKEND_BASE is set to one for this
        # server-side check.
        logger.warning("Realtime server health check failed: %s", exc)
        return False
    return response.status_code == 200
# Query method dispatch | |
# Lookup table: retrieval-method key -> query callable imported above.
QUERY_DISPATCH = dict(
    graph=graph_query,
    vanilla=vanilla_query,
    dpr=dpr_query,
    bm25=bm25_query,
    context=context_query,
    vision=vision_query,
)

# Method keys offered by the speech interface (same keys, same order as the
# dispatch table above).
METHOD_OPTIONS = list(QUERY_DISPATCH)
def format_citations_html(chunks):
    """Render chunk tuples as collapsible HTML sections (legacy format).

    Each element of *chunks* is unpacked as ``(header, score, text, citation)``.
    Newlines in the text become ``<br>`` and the rendered sections are joined
    with ``<br>``.

    Note: a later definition with the same name in this module replaces this
    function once the module has been executed.
    """

    def _render(chunk):
        # Build one <details> element for a single chunk tuple.
        header, score, text, source = chunk
        text_html = text.replace("\n", "<br>")
        pieces = (
            "<details>",
            f"<summary>{header} (relevance score: {score:.3f})</summary>",
            "<div style='font-size:0.9em; margin-top:0.5em;'>",
            f"<strong>Source:</strong> {source} ",
            "</div>",
            f"<div style='font-size:0.8em; margin-left:1em; margin-top:0.5em;'>{text_html}</div>",
            "</details><br><br>",
        )
        return "".join(pieces)

    return "<br>".join([_render(chunk) for chunk in chunks])
def format_citations_html(citations: List[dict], method: str) -> str:
    """Format citation dicts as an HTML "Sources" list.

    Args:
        citations: Citation dicts. Entries without a ``'source'`` key are
            skipped. ``'type'`` selects the label layout (``'pdf'``, ``'html'``,
            ``'image'``, ``'image_analysis'``, anything else is generic).
            Optional score keys (``relevance_score``, ``bm25_score``,
            ``rerank_score``, ``similarity``, ``score``) are appended in small
            print.
        method: Retrieval method name. Accepted for interface compatibility;
            the rendering does not depend on it.

    Returns:
        An HTML fragment. A fixed "No citations available" paragraph is
        returned when *citations* is empty.
    """
    # Local import keeps this block self-contained.
    from html import escape

    if not citations:
        return "<p><em>No citations available</em></p>"

    def _esc(value) -> str:
        # The fragment is rendered with unsafe_allow_html=True by the callers
        # in this file, so every interpolated field is escaped (quotes too,
        # because the URL sits inside a single-quoted attribute).
        return escape(str(value), quote=True)

    html_parts = ["<div style='margin-top: 1em;'><strong>Sources:</strong><ul>"]
    for citation in citations:
        # Skip citations without source
        if 'source' not in citation:
            continue
        source = _esc(citation['source'])
        cite_type = citation.get('type', 'unknown')

        # Build citation text based on type
        if cite_type == 'pdf':
            cite_text = f"π {source} (PDF)"
        elif cite_type == 'html':
            url = citation.get('url', '')
            if url:
                cite_text = f"π <a href='{_esc(url)}' target='_blank'>{source}</a> (Web)"
            else:
                cite_text = f"π {source} (Web)"
        elif cite_type == 'image':
            page = _esc(citation.get('page', 'N/A'))
            cite_text = f"πΌοΈ {source} (Image, page {page})"
        elif cite_type == 'image_analysis':
            classification = _esc(citation.get('classification', 'N/A'))
            cite_text = f"π {source} - {classification}"
        else:
            cite_text = f"π {source}"

        # Add scores if available
        scores = []
        if 'relevance_score' in citation:
            scores.append(f"relevance: {_esc(citation['relevance_score'])}")
        if 'bm25_score' in citation:
            scores.append(f"BM25: {_esc(citation['bm25_score'])}")
        if 'rerank_score' in citation:
            scores.append(f"rerank: {_esc(citation['rerank_score'])}")
        if 'similarity' in citation:
            scores.append(f"similarity: {_esc(citation['similarity'])}")
        if 'score' in citation:
            # Numeric formatting only; a formatted float cannot carry markup.
            scores.append(f"score: {citation['score']:.3f}")
        if scores:
            cite_text += f" <small>({', '.join(scores)})</small>"

        html_parts.append(f"<li>{cite_text}</li>")

    html_parts.append("</ul></div>")
    return "".join(html_parts)
def save_uploaded_file(uploaded_file) -> Optional[str]:
    """Write an uploaded file to a named temporary file and return its path.

    The temporary file keeps the upload's extension (taken from
    ``uploaded_file.name``) and is created with ``delete=False``, so the caller
    is responsible for removing it.

    Args:
        uploaded_file: Object exposing ``.name`` and ``.getvalue()`` (bytes).

    Returns:
        The path of the temporary file, or ``None`` when saving failed (the
        error is reported through ``st.error``).
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=Path(uploaded_file.name).suffix) as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(uploaded_file.getvalue())
        return tmp_path
    except Exception as e:
        # A failed write would otherwise leave the delete=False file behind.
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.unlink(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the original error is reported below
        st.error(f"Error saving file: {e}")
        return None
# Page configuration | |
# Browser-tab title, icon and wide layout for the whole app.
st.set_page_config(page_title="Multi-Method RAG System - SIGHT", page_icon="π", layout="wide")

# ---- Sidebar: retrieval method and tuning controls ----
st.sidebar.title("Configuration")
st.sidebar.markdown("### Retrieval Method")
selected_method = st.sidebar.radio(
    "Choose retrieval method:",
    options=METHOD_OPTIONS,  # same six method keys, same order
    format_func=str.capitalize,
    help="Select different RAG methods to compare results",
)

# Short description of the chosen method (text comes from config).
st.sidebar.info(METHOD_DESCRIPTIONS[selected_method])

with st.sidebar.expander("Advanced Settings"):
    top_k = st.slider("Number of chunks to retrieve", min_value=1, max_value=10, value=DEFAULT_TOP_K)
    # The hybrid-search controls are only offered for BM25.
    if selected_method == 'bm25':
        use_hybrid = st.checkbox("Use hybrid search (BM25 + semantic)", value=False)
        if use_hybrid:
            alpha = st.slider("BM25 weight (alpha)", min_value=0.0, max_value=1.0, value=0.5)

# ---- Sidebar: project information ----
st.sidebar.markdown("---")
st.sidebar.markdown("### About")
st.sidebar.markdown("**Authors:** [The SIGHT Project Team](https://sites.miamioh.edu/sight/)")
st.sidebar.markdown(f"**Version:** V. {VERSION}")
st.sidebar.markdown(f"**Date:** {DATE}")
st.sidebar.markdown(f"**Model:** {OPENAI_CHAT_MODEL}")
st.sidebar.markdown("---")
st.sidebar.markdown(
    "**Funding:** SIGHT is funded by [OHBWC WSIC](https://info.bwc.ohio.gov/for-employers/safety-services/workplace-safety-innovation-center/wsic-overview)"
)

# ---- Page header: title on the left, live status on the right ----
col1, col2 = st.columns([3, 1])
with col1:
    st.title("π Multi-Method RAG System - SIGHT")
    st.markdown("### Compare different retrieval methods for machine safety Q&A")
with col2:
    # Session query counter; no delta is shown while the history is empty.
    if 'chat_history' in st.session_state:
        total_queries = len(st.session_state.chat_history)
        queries_delta = f"+{total_queries}" if total_queries else None
        st.metric("Session Queries", total_queries, delta=queries_delta)
    # Badge shown while a voice session is flagged active.
    if st.session_state.get('voice_session_active', False):
        st.success("π΄ Voice LIVE")

# One tab per interface.
tab1, tab2, tab3, tab4 = st.tabs(["π¬ Chat", "π Method Comparison", "π Voice Chat", "π Analytics"])
with tab1:
    # Collapsible panel of sample prompts, laid out in two columns.
    with st.expander("π Example Questions", expanded=False):
        general_col, specific_col = st.columns(2)
        with general_col:
            st.markdown(
                "**General Safety:**\n"
                "- What are general machine guarding requirements?\n"
                "- How do I perform lockout/tagout?\n"
                "- What is required for emergency stops?"
            )
        with specific_col:
            st.markdown(
                "**Specific Topics:**\n"
                "- Summarize robot safety requirements from OSHA\n"
                "- Compare guard types: fixed vs interlocked\n"
                "- What are the ANSI standards for machine safety?"
            )

    # The image uploader is only offered for the vision method.
    uploaded_file = None
    if selected_method == 'vision':
        st.markdown("#### πΌοΈ Upload an image for analysis")
        uploaded_file = st.file_uploader(
            "Choose an image file",
            type=['png', 'jpg', 'jpeg', 'bmp', 'gif'],
            help="Upload an image of safety equipment, signs, or machinery",
        )
        if uploaded_file:
            # Preview in the narrower left column.
            col1, col2 = st.columns([1, 2])
            with col1:
                st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)

    # Per-session state: chat transcript and a short id used for analytics.
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'session_id' not in st.session_state:
        st.session_state.session_id = str(uuid.uuid4())[:8]

    query = st.text_input(
        "Ask a question:",
        placeholder="E.g., What are the safety requirements for collaborative robots?",
        key="chat_input",
    )

    col1, col2, col3 = st.columns([1, 1, 8])
    with col1:
        send_button = st.button("π Send", type="primary", use_container_width=True)
    with col2:
        clear_button = st.button("ποΈ Clear", use_container_width=True)

    # Clearing wipes the transcript and reruns the script immediately.
    if clear_button:
        st.session_state.chat_history = []
        st.rerun()

    if send_button and query:
        # Persist the upload to disk only for the vision method.
        wants_image = bool(uploaded_file) and selected_method == 'vision'
        image_path = save_uploaded_file(uploaded_file) if wants_image else None

        with st.spinner(f"Searching using {selected_method.upper()} method..."):
            start_time = time.time()
            error_message = None
            answer = ""
            citations = []
            try:
                query_func = QUERY_DISPATCH[selected_method]
                image_missing = selected_method == 'vision' and not image_path
                if not image_missing:
                    answer, citations = query_func(query, image_path=image_path, top_k=top_k)
                    # Record the successful turn in the transcript.
                    turn_record = {
                        'query': query,
                        'answer': answer,
                        'citations': citations,
                        'method': selected_method,
                        'image_path': image_path,
                    }
                    st.session_state.chat_history.append(turn_record)
                else:
                    error_message = "Please upload an image for vision-based search"
                    st.error(error_message)
            except Exception as e:
                error_message = str(e)
                answer = f"Error: {error_message}"
                st.error(f"Error processing query: {error_message}")
                st.info("Make sure you've run preprocess.py to generate the required indices.")
            finally:
                # Analytics logging runs for every attempt, failed ones included.
                response_time = (time.time() - start_time) * 1000  # milliseconds
                try:
                    log_query(
                        user_query=query,
                        method=selected_method,
                        answer=answer,
                        citations=citations,
                        response_time=response_time,
                        image_path=image_path,
                        error_message=error_message,
                        top_k=top_k,
                        session_id=st.session_state.session_id,
                    )
                except Exception as log_error:
                    logger.error(f"Failed to log query: {log_error}")

        # Remove the temporary image once the query has been handled.
        if image_path and os.path.exists(image_path):
            os.unlink(image_path)

    # Transcript, newest turn first, with a rule between turns.
    if st.session_state.chat_history:
        st.markdown("---")
        st.markdown("### Chat History")
        for position, turn in enumerate(reversed(st.session_state.chat_history)):
            with st.container():
                st.markdown(f"**π§ You** ({turn['method'].upper()}):")
                st.markdown(turn['query'])
                st.markdown("**π€ Assistant:**")
                st.markdown(turn['answer'])
                citations_html = format_citations_html(turn['citations'], turn['method'])
                st.markdown(citations_html, unsafe_allow_html=True)
                is_last_turn = position >= len(st.session_state.chat_history) - 1
                if not is_last_turn:
                    st.markdown("---")
with tab2:
    st.markdown("### Method Comparison")
    st.markdown("Compare results from different retrieval methods for the same query.")

    comparison_query = st.text_input(
        "Enter a query to compare across methods:",
        placeholder="E.g., What are the requirements for machine guards?",
        key="comparison_input",
    )
    # Vision is left out of the choices: it needs an image.
    methods_to_compare = st.multiselect(
        "Select methods to compare:",
        options=['graph', 'vanilla', 'dpr', 'bm25', 'context'],
        default=['vanilla', 'bm25'],
        help="Vision method requires an image and is not included in comparison",
    )

    col1, col2 = st.columns([3, 1])
    with col1:
        compare_button = st.button("π Compare Methods", type="primary")
    with col2:
        # The full-screen button appears only once results are stored.
        has_stored_results = (
            'comparison_results' in st.session_state
            and st.session_state.comparison_results
        )
        if has_stored_results:
            if st.button("πͺ Full Screen View", help="View results in a dedicated comparison window"):
                st.session_state.show_comparison_window = True
                st.rerun()

    if compare_button:
        if not (comparison_query and methods_to_compare):
            st.warning("Please enter a query and select at least one method to compare.")
        else:
            results = {}
            progress_bar = st.progress(0)
            method_count = len(methods_to_compare)

            for idx, method in enumerate(methods_to_compare):
                with st.spinner(f"Running {method.upper()}..."):
                    start_time = time.time()
                    error_message = None
                    try:
                        query_func = QUERY_DISPATCH[method]
                        answer, citations = query_func(comparison_query, top_k=top_k)
                    except Exception as e:
                        # A failing method still gets a result entry so the
                        # comparison can be displayed.
                        error_message = str(e)
                        answer = f"Error: {error_message}"
                        citations = []
                    results[method] = {'answer': answer, 'citations': citations}

                    # Comparison runs are logged to analytics as well.
                    response_time = (time.time() - start_time) * 1000
                    try:
                        log_query(
                            user_query=comparison_query,
                            method=method,
                            answer=results[method]['answer'],
                            citations=results[method]['citations'],
                            response_time=response_time,
                            error_message=error_message,
                            top_k=top_k,
                            session_id=st.session_state.session_id,
                            additional_settings={'comparison_mode': True},
                        )
                    except Exception as log_error:
                        logger.error(f"Failed to log comparison query: {log_error}")
                progress_bar.progress((idx + 1) / method_count)

            # Keep the run in session state for the full-screen view.
            st.session_state.comparison_results = {
                'query': comparison_query,
                'methods': methods_to_compare,
                'results': results,
                'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
            }

            # One column per method, side by side.
            for method, col in zip(methods_to_compare, st.columns(method_count)):
                with col:
                    st.markdown(f"#### {method.upper()}")
                    method_answer = results[method]['answer']
                    if len(method_answer) <= 800:
                        st.markdown(method_answer)
                    else:
                        # Long answers: 300-character preview plus an expander
                        # holding the full text.
                        st.markdown(method_answer[:300] + "...")
                        with st.expander("π Show full answer"):
                            st.markdown(method_answer)
                    st.markdown(
                        format_citations_html(results[method]['citations'], method),
                        unsafe_allow_html=True,
                    )
with tab3: | |
st.markdown("### π Voice Chat - Hands-free AI Assistant") | |
# Server status check | |
server_healthy = check_realtime_server_health() | |
if server_healthy: | |
st.success("β **Voice Server Online** - Ready for voice interactions") | |
else: | |
st.error("β **Voice Server Offline** - Please start the realtime server: `python realtime_server.py`") | |
st.code("python realtime_server.py", language="bash") | |
st.stop() | |
st.info( | |
"π€ **Real-time Voice Interaction**: Speak naturally and get instant responses from your chosen RAG method. " | |
"The AI will automatically transcribe your speech, search the knowledge base, and respond with synthesized voice." | |
) | |
# Voice Chat Status and Configuration | |
col1, col2 = st.columns([2, 1]) | |
with col1: | |
# Use the same method from sidebar | |
st.info(f"π **Voice using {selected_method.upper()} method** (change in sidebar)") | |
with col2: | |
# Voice settings (simplified) | |
voice_choice = st.selectbox( | |
"ποΈ AI Voice:", | |
["alloy", "echo", "fable", "onyx", "nova", "shimmer"], | |
index=0, | |
help="Select the AI voice for responses" | |
) | |
response_speed = st.slider( | |
"β±οΈ Response Speed (seconds):", | |
min_value=1, max_value=5, value=2, | |
help="How quickly the AI should respond after you stop speaking" | |
) | |
# CHANGED: same-origin base for the JS voice client (used as `serverBase` in the HTML below) | |
server_url = API_BASE # e.g., "/api" | |
# Voice Chat Interface | |
st.markdown("---") | |
# Initialize voice chat session state | |
if 'voice_chat_history' not in st.session_state: | |
st.session_state.voice_chat_history = [] | |
if 'voice_session_active' not in st.session_state: | |
st.session_state.voice_session_active = False | |
# Simple Status Display | |
if st.session_state.voice_session_active: | |
st.success("π΄ **LIVE** - Voice chat active using " + selected_method.upper()) | |
# Enhanced Voice Interface with better UX | |
components.html(f""" | |
<!DOCTYPE html> | |
<html> | |
<head> | |
<meta charset="utf-8" /> | |
<style> | |
body {{ | |
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; | |
padding: 20px; | |
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
color: white; | |
border-radius: 10px; | |
}} | |
.container {{ | |
max-width: 800px; | |
margin: 0 auto; | |
background: rgba(255,255,255,0.1); | |
padding: 30px; | |
border-radius: 15px; | |
backdrop-filter: blur(10px); | |
}} | |
.controls {{ | |
display: flex; | |
gap: 20px; | |
align-items: center; | |
justify-content: center; | |
margin-bottom: 30px; | |
}} | |
.status-display {{ | |
text-align: center; | |
margin: 20px 0; | |
padding: 15px; | |
border-radius: 10px; | |
background: rgba(255,255,255,0.2); | |
}} | |
.status-idle {{ background: rgba(108, 117, 125, 0.3); }} | |
.status-connecting {{ background: rgba(255, 193, 7, 0.3); }} | |
.status-active {{ background: rgba(40, 167, 69, 0.3); }} | |
.status-error {{ background: rgba(220, 53, 69, 0.3); }} | |
button {{ | |
padding: 12px 24px; | |
font-size: 16px; | |
border: none; | |
border-radius: 25px; | |
cursor: pointer; | |
transition: all 0.3s ease; | |
font-weight: bold; | |
}} | |
.start-btn {{ | |
background: linear-gradient(45deg, #28a745, #20c997); | |
color: white; | |
}} | |
.start-btn:hover {{ transform: translateY(-2px); box-shadow: 0 4px 12px rgba(40,167,69,0.4); }} | |
.start-btn:disabled {{ | |
background: #6c757d; | |
cursor: not-allowed; | |
transform: none; | |
box-shadow: none; | |
}} | |
.stop-btn {{ | |
background: linear-gradient(45deg, #dc3545, #fd7e14); | |
color: white; | |
}} | |
.stop-btn:hover {{ transform: translateY(-2px); box-shadow: 0 4px 12px rgba(220,53,69,0.4); }} | |
.stop-btn:disabled {{ | |
background: #6c757d; | |
cursor: not-allowed; | |
transform: none; | |
box-shadow: none; | |
}} | |
.log {{ | |
height: 200px; | |
overflow-y: auto; | |
border: 1px solid rgba(255,255,255,0.3); | |
padding: 15px; | |
background: rgba(0,0,0,0.2); | |
border-radius: 10px; | |
font-family: 'Monaco', 'Menlo', monospace; | |
font-size: 13px; | |
line-height: 1.4; | |
}} | |
.audio-controls {{ | |
text-align: center; | |
margin: 20px 0; | |
}} | |
.pulse {{ | |
animation: pulse 2s infinite; | |
}} | |
@keyframes pulse {{ | |
0% {{ transform: scale(1); }} | |
50% {{ transform: scale(1.05); }} | |
100% {{ transform: scale(1); }} | |
}} | |
.visualizer {{ | |
width: 100%; | |
height: 60px; | |
background: rgba(0,0,0,0.2); | |
border-radius: 10px; | |
margin: 10px 0; | |
display: flex; | |
align-items: center; | |
justify-content: center; | |
font-size: 14px; | |
}} | |
</style> | |
</head> | |
<body> | |
<div class="container"> | |
<div class="status-display status-idle" id="statusDisplay"> | |
<h3 id="statusTitle">π€ Voice Chat</h3> | |
<p id="statusText">Click "Start Listening" to begin</p> | |
</div> | |
<div class="controls"> | |
<button id="startBtn" class="start-btn">π€ Start Listening</button> | |
<button id="stopBtn" class="stop-btn" disabled>βΉοΈ Stop</button> | |
</div> | |
<div class="audio-controls"> | |
<audio id="remoteAudio" autoplay style="width: 100%; max-width: 400px;"></audio> | |
</div> | |
<div class="visualizer" id="visualizer"> | |
π Audio will appear here when active | |
</div> | |
<div class="log" id="log"></div> | |
</div> | |
<script> | |
(async () => {{ | |
// CHANGED: use same-origin base (e.g., "/api") | |
const serverBase = {server_url!r}; | |
const chosenMethod = {selected_method!r}; | |
const voiceChoice = {voice_choice!r}; | |
const responseSpeed = {response_speed!r}; | |
const logEl = document.getElementById('log'); | |
const statusDisplay = document.getElementById('statusDisplay'); | |
const statusTitle = document.getElementById('statusTitle'); | |
const statusText = document.getElementById('statusText'); | |
const startBtn = document.getElementById('startBtn'); | |
const stopBtn = document.getElementById('stopBtn'); | |
const visualizer = document.getElementById('visualizer'); | |
let pc, dc, micStream; | |
let isConnected = false; | |
let questionStartTime = null; | |
function updateStatus(status, title, text, className) {{ | |
statusDisplay.className = `status-display ${{className}}`; | |
statusTitle.textContent = title; | |
statusText.textContent = text; | |
}} | |
function log(msg, type = 'info') {{ | |
const timestamp = new Date().toLocaleTimeString(); | |
const icon = type === 'error' ? 'β' : type === 'success' ? 'β ' : type === 'warning' ? 'β οΈ' : 'βΉοΈ'; | |
logEl.innerHTML += `<div>${{timestamp}} ${{icon}} ${{msg}}</div>`; | |
logEl.scrollTop = logEl.scrollHeight; | |
}} | |
async function start() {{ | |
startBtn.disabled = true; | |
stopBtn.disabled = false; | |
updateStatus('connecting', 'π Connecting...', 'Establishing secure connection to voice services', 'status-connecting'); | |
try {{ | |
log('Initializing voice session...', 'info'); | |
// 1) Fetch ephemeral session token | |
const sessResp = await fetch(serverBase + "/session", {{ | |
method: "POST", | |
headers: {{ "Content-Type": "application/json" }}, | |
body: JSON.stringify({{ voice: voiceChoice }}) | |
}}); | |
if (!sessResp.ok) {{ | |
throw new Error(`Server error: ${{sessResp.status}} ${{sessResp.statusText}}`); | |
}} | |
const sess = await sessResp.json(); | |
if (sess.error) throw new Error(sess.error); | |
const EPHEMERAL_KEY = sess.client_secret; | |
if (!EPHEMERAL_KEY) throw new Error("No ephemeral token from server"); | |
log('β Session token obtained', 'success'); | |
// 2) Setup WebRTC | |
pc = new RTCPeerConnection(); | |
const remoteAudio = document.getElementById('remoteAudio'); | |
pc.ontrack = (event) => {{ | |
log('π Audio track received from OpenAI', 'success'); | |
const stream = event.streams[0]; | |
if (stream && stream.getAudioTracks().length > 0) {{ | |
remoteAudio.srcObject = stream; | |
visualizer.textContent = 'π Audio stream connected - AI can speak'; | |
log(`π΅ Audio tracks: ${{stream.getAudioTracks().length}}`, 'success'); | |
}} else {{ | |
log('β οΈ No audio tracks in stream', 'warning'); | |
visualizer.textContent = 'β οΈ No audio stream received'; | |
}} | |
}}; | |
// 3) Create data channel | |
dc = pc.createDataChannel("oai-data"); | |
dc.onopen = () => {{ | |
log('π Data channel established', 'success'); | |
}}; | |
dc.onerror = (error) => {{ | |
log('β Data channel error: ' + error, 'error'); | |
}}; | |
dc.onmessage = (e) => handleDataMessage(e); | |
// 4) Get microphone | |
log('π€ Requesting microphone access...', 'info'); | |
micStream = await navigator.mediaDevices.getUserMedia({{ audio: true }}); | |
log('β Microphone access granted', 'success'); | |
visualizer.textContent = 'π€ Microphone active - speak naturally'; | |
for (const track of micStream.getTracks()) {{ | |
pc.addTrack(track, micStream); | |
}} | |
// 5) Setup audio receiving | |
pc.addTransceiver("audio", {{ direction: "recvonly" }}); | |
log('π Audio receiver configured', 'success'); | |
// 6) Create and set local description | |
const offer = await pc.createOffer(); | |
await pc.setLocalDescription(offer); | |
log('π‘ WebRTC offer created', 'success'); | |
// 7) Exchange SDP with OpenAI Realtime | |
const baseUrl = "https://api.openai.com/v1/realtime"; | |
const model = sess.model || "gpt-4o-realtime-preview"; | |
const sdpResp = await fetch(`${{baseUrl}}?model=${{encodeURIComponent(model)}}`, {{ | |
method: "POST", | |
body: offer.sdp, | |
headers: {{ | |
Authorization: `Bearer ${{EPHEMERAL_KEY}}`, | |
"Content-Type": "application/sdp" | |
}} | |
}}); | |
if (!sdpResp.ok) throw new Error(`WebRTC setup failed: ${{sdpResp.status}}`); | |
const answer = {{ type: "answer", sdp: await sdpResp.text() }}; | |
await pc.setRemoteDescription(answer); | |
// 8) Configure the session with tools and faster response | |
setTimeout(() => {{ | |
if (dc.readyState === 'open') {{ | |
const toolDecl = {{ | |
type: "session.update", | |
session: {{ | |
tools: [{{ | |
"type": "function", | |
"name": "ask_rag", | |
"description": "Search the safety knowledge base for accurate, authoritative information. Call this immediately when users ask safety questions to get current, reliable information with proper citations.", | |
"parameters": {{ | |
"type": "object", | |
"properties": {{ | |
"query": {{ "type": "string", "description": "User's safety question" }}, | |
"top_k": {{ "type": "integer", "minimum": 1, "maximum": 20, "default": 5 }} | |
}}, | |
"required": ["query"] | |
}} | |
}}], | |
turn_detection: {{ | |
type: "server_vad", | |
threshold: 0.5, | |
prefix_padding_ms: 300, | |
silence_duration_ms: {response_speed * 1000} | |
}}, | |
input_audio_transcription: {{ | |
model: "whisper-1" | |
}}, | |
voice: voiceChoice, | |
temperature: 0.7, | |
max_response_output_tokens: 1000, | |
modalities: ["audio", "text"], | |
response_format: "audio" | |
}} | |
}}; | |
dc.send(JSON.stringify(toolDecl)); | |
log('π οΈ RAG tools configured', 'success'); | |
const initialMessage = {{ | |
type: "conversation.item.create", | |
item: {{ | |
type: "message", | |
role: "user", | |
content: [{{ | |
type: "input_text", | |
text: "Hello! I'm ready to ask you questions about machine safety. Please speak naturally like a safety expert - no need to mention specific documents or sources, just give me the information as your expertise." | |
}}] | |
}} | |
}}; | |
dc.send(JSON.stringify(initialMessage)); | |
const responseRequest = {{ | |
type: "response.create", | |
response: {{ | |
modalities: ["audio"], | |
instructions: "Acknowledge briefly that you're ready to help with safety questions. Speak naturally and confidently as a safety expert - no citations or document references needed." | |
}} | |
}}; | |
dc.send(JSON.stringify(responseRequest)); | |
}} else {{ | |
log('β οΈ Data channel not ready, retrying...', 'warning'); | |
}} | |
}}, 500); | |
isConnected = true; | |
updateStatus('active', 'π€ Live - Speak Now!', `Using ${{chosenMethod.toUpperCase()}} method β’ Voice: ${{voiceChoice}} β’ Response: ${{responseSpeed}}s`, 'status-active'); | |
startBtn.classList.add('pulse'); | |
}} catch (error) {{ | |
log(`β Connection failed: ${{ | |
(error && (error.message || error.toString())) || 'Unknown error' | |
}}`, 'error'); | |
updateStatus('error', 'β Connection Failed', (error && (error.message || error.toString())) || 'Unknown error', 'status-error'); | |
startBtn.disabled = false; | |
stopBtn.disabled = true; | |
cleanup(); | |
}} | |
}} | |
function cleanup() {{ | |
try {{ | |
if (dc && dc.readyState === 'open') dc.close(); | |
if (pc) pc.close(); | |
if (micStream) micStream.getTracks().forEach(t => t.stop()); | |
}} catch (e) {{ /* ignore cleanup errors */ }} | |
startBtn.classList.remove('pulse'); | |
visualizer.textContent = 'π Audio inactive'; | |
}} | |
async function stop() {{ | |
startBtn.disabled = false; | |
stopBtn.disabled = true; | |
isConnected = false; | |
updateStatus('idle', 'βͺ Session Ended', 'Click "Start Listening" to begin a new voice session', 'status-idle'); | |
log('π Voice session terminated', 'info'); | |
cleanup(); | |
}} | |
// Handle realtime events | |
async function handleDataMessage(e) {{ | |
if (!isConnected) return; | |
try {{ | |
const msg = JSON.parse(e.data); | |
if (msg.type === "response.function_call") {{ | |
const {{ name, call_id, arguments: args }} = msg; | |
if (name === "ask_rag") {{ | |
visualizer.textContent = 'β Question received - searching...'; | |
const query = JSON.parse(args || "{{}}").query; | |
log(`β AI heard: "${{query}}"`, 'success'); | |
log('π Searching knowledge base...', 'info'); | |
const payload = JSON.parse(args || "{{}}"); | |
const ragResp = await fetch(serverBase + "/rag", {{ | |
method: "POST", | |
headers: {{ "Content-Type": "application/json" }}, | |
body: JSON.stringify({{ | |
query: payload.query, | |
top_k: payload.top_k ?? 5, | |
method: chosenMethod | |
}}) | |
}}); | |
const rag = await ragResp.json(); | |
if (dc && dc.readyState === 'open') {{ | |
dc.send(JSON.stringify({{ | |
type: "response.function_call_result", | |
call_id, | |
output: JSON.stringify({{ | |
answer: rag.answer, | |
instruction: "Speak this information naturally as your expertise. Do not mention sources or documents." | |
}}) | |
}})); | |
}} else {{ | |
log('β οΈ Data channel closed, cannot send result', 'warning'); | |
}} | |
const searchTime = ((Date.now() - questionStartTime) / 1000).toFixed(1); | |
log(`β Found ${{rag.citations?.length || 0}} citations in ${{searchTime}}s`, 'success'); | |
visualizer.textContent = 'ποΈ AI is speaking your answer...'; | |
}} | |
}} | |
if (msg.type === "input_audio_buffer.speech_started") {{ | |
questionStartTime = Date.now(); | |
visualizer.textContent = 'ποΈ Listening to you...'; | |
log('π€ Speech detected', 'info'); | |
}} | |
if (msg.type === "input_audio_buffer.speech_stopped") {{ | |
visualizer.textContent = 'π€ Processing your question...'; | |
log('βΈοΈ Processing speech...', 'info'); | |
}} | |
if (msg.type === "response.audio.delta") {{ | |
visualizer.textContent = 'π AI speaking...'; | |
}} | |
if (msg.type === "response.done") {{ | |
if (questionStartTime) {{ | |
const totalTime = ((Date.now() - questionStartTime) / 1000).toFixed(1); | |
visualizer.textContent = 'π€ Your turn - speak now'; | |
log(`β Response complete in ${{totalTime}}s`, 'success'); | |
questionStartTime = null; | |
}} else {{ | |
visualizer.textContent = 'π€ Your turn - speak now'; | |
log('β Response complete', 'success'); | |
}} | |
}} | |
}} catch (err) {{ | |
// Ignore non-JSON messages | |
}} | |
}} | |
startBtn.onclick = start; | |
stopBtn.onclick = stop; | |
// Initialize | |
log('π Voice chat interface loaded', 'success'); | |
}})(); | |
</script> | |
</body> | |
</html> | |
""", height=600, scrolling=True) | |
# Voice Chat History
# Shows up to five entries from the end of the session's voice history, in
# reverse list order (presumably newest first), each inside its own expander.
if st.session_state.voice_chat_history:
    st.markdown("### π£οΈ Recent Voice Conversations")
    voice_history = st.session_state.voice_chat_history
    for i, entry in enumerate(reversed(voice_history[-5:])):
        # `or` fallbacks also cover keys that are present but hold None,
        # which would otherwise break .upper(), slicing and len().
        method_label = (entry.get('method') or 'unknown').upper()
        answer = entry.get('answer') or 'N/A'
        citations = entry.get('citations') or []
        # Only add the ellipsis when the response was actually cut off.
        answer_preview = answer[:200] + "..." if len(answer) > 200 else answer
        with st.expander(f"π€ Conversation {len(voice_history)-i} - {method_label}"):
            st.write(f"**Query**: {entry.get('query', 'N/A')}")
            st.write(f"**Response**: {answer_preview}")
            st.write(f"**Citations**: {len(citations)}")
            st.write(f"**Timestamp**: {entry.get('timestamp', 'N/A')}")
with tab4:
    # Analytics dashboard tab.
    #
    # Loads persisted statistics and renders, in order: overview metrics,
    # per-method performance, method usage (with a voice summary), voice
    # analytics, citation sources, popular keywords and recent queries.
    # Any exception raised while loading or rendering is reported to the
    # user and replaced by a small summary of the current session.

    # Imported locally: stored answers are embedded in raw HTML further down
    # and must be escaped before being rendered with unsafe_allow_html=True.
    from html import escape as _escape_html

    st.markdown("### π Analytics Dashboard")
    st.markdown("*Persistent analytics from all user interactions*")

    # Time period selector (the wide left column only holds an empty placeholder)
    col1, col2 = st.columns([3, 1])
    with col1:
        st.markdown("")
    with col2:
        days_filter = st.selectbox("Time Period", [7, 30, 90, 365], index=1, format_func=lambda x: f"Last {x} days")

    # Get analytics data
    try:
        stats = get_analytics_stats(days=days_filter)
        performance = get_method_performance()
        recent_queries = analytics_db.get_recent_queries(limit=10)

        # Overview Metrics
        st.markdown("#### π Overview")
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric(
                "Total Queries",
                stats.get('total_queries', 0),
                help="All queries processed in the selected time period"
            )
        with col2:
            # `or 0` covers a key that is present but None, which would
            # otherwise make the float format below raise TypeError.
            avg_citations = stats.get('avg_citations', 0) or 0
            st.metric(
                "Avg Citations",
                f"{avg_citations:.1f}",
                help="Average number of citations per query"
            )
        with col3:
            error_rate = stats.get('error_rate', 0) or 0
            st.metric(
                "Success Rate",
                f"{100 - error_rate:.1f}%",
                delta=f"-{error_rate:.1f}% errors" if error_rate > 0 else None,
                help="Percentage of successful queries"
            )
        with col4:
            total_citations = stats.get('total_citations', 0)
            st.metric(
                "Total Citations",
                total_citations,
                help="Total citations generated across all queries"
            )

        # Method Performance Comparison
        if performance:
            st.markdown("#### β‘ Method Performance")
            # None-valued averages are rendered as 0 instead of crashing.
            perf_data = [
                {
                    'Method': method.upper(),
                    'Avg Response Time (ms)': f"{(metrics['avg_response_time'] or 0):.0f}",
                    'Avg Citations': f"{(metrics['avg_citations'] or 0):.1f}",
                    'Avg Answer Length': f"{(metrics['avg_answer_length'] or 0):.0f}",
                    'Query Count': int(metrics['query_count'])
                }
                for method, metrics in performance.items()
            ]
            if perf_data:
                st.dataframe(perf_data, use_container_width=True, hide_index=True)

        # Voice statistics are fetched once and shared by the usage summary
        # and the dedicated voice section below (previously queried twice).
        try:
            voice_queries = analytics_db.get_voice_interaction_stats()
        except Exception as e:
            logger.error(f"Voice stats error: {e}")
            voice_queries = None

        # Method Usage with Voice Interaction Indicator
        method_usage = stats.get('method_usage', {})
        if method_usage:
            st.markdown("#### π― Method Usage Distribution")
            # Computed once; also guards the percentage divisions against zero.
            total_usage = sum(method_usage.values())
            col1, col2 = st.columns([2, 1])
            with col1:
                st.bar_chart(method_usage)
            with col2:
                st.markdown("**Most Popular Methods:**")
                sorted_methods = sorted(method_usage.items(), key=lambda x: x[1], reverse=True)
                for i, (method, count) in enumerate(sorted_methods[:3], 1):
                    percentage = (count / total_usage) * 100 if total_usage else 0.0
                    st.markdown(f"{i}. **{method.upper()}** - {count} queries ({percentage:.1f}%)")

                # Voice interaction stats (best effort: failures are only logged)
                try:
                    if voice_queries and voice_queries.get('total_voice_queries', 0) > 0:
                        st.markdown("---")
                        st.markdown("**π€ Voice Interactions:**")
                        st.markdown(f"π Voice queries: {voice_queries['total_voice_queries']}")
                        if voice_queries.get('avg_voice_response_time', 0) > 0:
                            st.markdown(f"β±οΈ Avg response time: {voice_queries['avg_voice_response_time']:.1f}ms")
                        if total_usage > 0:
                            voice_percentage = (voice_queries['total_voice_queries'] / total_usage) * 100
                            st.markdown(f"π Voice usage: {voice_percentage:.1f}%")
                except Exception as e:
                    logger.error(f"Voice stats error: {e}")

        # Voice Analytics Section (if voice interactions exist)
        try:
            if voice_queries and voice_queries.get('total_voice_queries', 0) > 0:
                st.markdown("#### π€ Voice Interaction Analytics")
                col1, col2 = st.columns([2, 1])
                with col1:
                    voice_by_method = voice_queries.get('voice_by_method', {})
                    if voice_by_method:
                        st.bar_chart(voice_by_method)
                    else:
                        st.info("No voice method breakdown available yet")
                with col2:
                    st.markdown("**Voice Stats:**")
                    total_voice = voice_queries['total_voice_queries']
                    st.markdown(f"π Total voice queries: {total_voice}")
                    avg_response = voice_queries.get('avg_voice_response_time', 0)
                    if avg_response > 0:
                        st.markdown(f"β±οΈ Avg response: {avg_response:.1f}ms")
                    # Most used voice method
                    if voice_by_method:
                        most_used_voice = max(voice_by_method.items(), key=lambda x: x[1])
                        st.markdown(f"π― Top voice method: {most_used_voice[0].upper()}")
        except Exception as e:
            logger.error(f"Voice analytics error: {e}")

        # Citation Analysis
        citation_types = stats.get('citation_types', {})
        if citation_types:
            st.markdown("#### π Citation Sources")
            # Filter out empty/null citation types once; chart and list share it.
            filtered_citations = {k: v for k, v in citation_types.items() if k and k.strip()}
            # Percentages stay relative to every recorded citation, including
            # those with an empty type. Named so it does not shadow the
            # overview's total_citations.
            citation_total = sum(citation_types.values())
            source_icons = {"pdf": "π", "html": "π", "image": "πΌοΈ"}
            col1, col2 = st.columns([2, 1])
            with col1:
                if filtered_citations:
                    st.bar_chart(filtered_citations)
            with col2:
                st.markdown("**Source Breakdown:**")
                for cite_type, count in sorted(filtered_citations.items(), key=lambda x: x[1], reverse=True):
                    percentage = (count / citation_total) * 100 if citation_total else 0.0
                    icon = source_icons.get(cite_type, "π")
                    st.markdown(f"{icon} **{cite_type.title()}**: {count} ({percentage:.1f}%)")

        # Popular Keywords
        keywords = stats.get('top_keywords', {})
        if keywords:
            st.markdown("#### π Popular Query Topics")
            keyword_columns = st.columns(3)
            # Top 9 keywords, dealt round-robin across the three columns
            for i, (word, count) in enumerate(list(keywords.items())[:9]):
                with keyword_columns[i % 3]:
                    st.metric(word.title(), count)

        # Recent Queries with Responses
        if recent_queries:
            st.markdown("#### π Recent Queries & Responses")
            for query in recent_queries[:5]:  # Show last 5
                # Create expander title with query preview
                query_text = query['query']
                query_preview = query_text[:60] + "..." if len(query_text) > 60 else query_text
                expander_title = f"π§ **{query['method'].upper()}**: {query_preview}"
                with st.expander(expander_title):
                    # Query details
                    st.markdown(f"**π Full Query:** {query['query']}")

                    # Metrics row
                    col1, col2, col3, col4 = st.columns(4)
                    with col1:
                        st.metric("Answer Length", f"{query['answer_length']} chars")
                    with col2:
                        st.metric("Citations", query['citations'])
                    with col3:
                        if query['response_time']:
                            st.metric("Response Time", f"{query['response_time']:.0f}ms")
                        else:
                            st.metric("Response Time", "N/A")
                    with col4:
                        status = "β Error" if query.get('error_message') else "β Success"
                        st.markdown(f"**Status:** {status}")

                    # Show error message if exists
                    if query.get('error_message'):
                        st.error(f"**Error:** {query['error_message']}")
                    else:
                        # Show answer in a styled container
                        st.markdown("**π€ Response:**")
                        # `or` also covers a stored None, which len() cannot handle.
                        answer = query.get('answer') or 'No answer available'
                        # Truncate very long answers for better UX
                        if len(answer) > 1000:
                            # Slice first, then escape, so no HTML entity is cut in half.
                            preview_html = _escape_html(answer[:800]).replace(chr(10), "<br>")
                            st.markdown(
                                f'<div style="background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #28a745;">'
                                f'{preview_html}<br><br>'
                                f'<i>... (truncated, showing first 800 chars of {len(answer)} total)</i>'
                                f'</div>',
                                unsafe_allow_html=True
                            )
                            # Option to view full answer
                            if st.button(f"π View Full Answer", key=f"full_answer_{query['query_id']}"):
                                full_html = _escape_html(answer).replace(chr(10), "<br>")
                                st.markdown("**Full Answer:**")
                                st.markdown(
                                    f'<div style="background-color: #f8f9fa; padding: 15px; border-radius: 8px; max-height: 400px; overflow-y: auto;">'
                                    f'{full_html}'
                                    f'</div>',
                                    unsafe_allow_html=True
                                )
                        else:
                            # Short answers display fully
                            answer_html = _escape_html(answer).replace(chr(10), "<br>")
                            st.markdown(
                                f'<div style="background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #28a745;">'
                                f'{answer_html}'
                                f'</div>',
                                unsafe_allow_html=True
                            )

                    # Show detailed citation info
                    if query['citations'] > 0:
                        if st.button(f"π View Citations", key=f"citations_{query['query_id']}"):
                            detailed_query = analytics_db.get_query_with_citations(query['query_id'])
                            if detailed_query and 'citations' in detailed_query:
                                st.markdown("**Citations:**")
                                for i, citation in enumerate(detailed_query['citations'], 1):
                                    # Only scores that are present and non-zero are listed.
                                    scores = []
                                    if citation.get('relevance_score'):
                                        scores.append(f"relevance: {citation['relevance_score']:.3f}")
                                    if citation.get('bm25_score'):
                                        scores.append(f"BM25: {citation['bm25_score']:.3f}")
                                    if citation.get('rerank_score'):
                                        scores.append(f"rerank: {citation['rerank_score']:.3f}")
                                    score_text = f" ({', '.join(scores)})" if scores else ""
                                    st.markdown(f"{i}. **{citation['source']}** {score_text}")

                    st.markdown(f"**π Timestamp:** {query['timestamp']}")
                    st.markdown("---")

        # Session Info
        st.markdown("---")
        col1, col2 = st.columns([3, 1])
        with col1:
            st.markdown("*Analytics are updated in real-time and persist across sessions*")
        with col2:
            st.markdown(f"**Session ID:** `{st.session_state.session_id}`")

    except Exception as e:
        st.error(f"Error loading analytics: {e}")
        st.info("Analytics data will appear after your first query. The database is created automatically.")

        # Fallback to session analytics
        if st.session_state.chat_history:
            st.markdown("#### π Current Session")
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Session Queries", len(st.session_state.chat_history))
            with col2:
                methods_used = [entry['method'] for entry in st.session_state.chat_history]
                most_used = max(set(methods_used), key=methods_used.count) if methods_used else "N/A"
                st.metric("Most Used Method", most_used.upper() if most_used != "N/A" else most_used)
# Full Screen Comparison Window (Modal-like)
# When the session flag is set, renders the stored comparison results for
# every method at full width, followed by a summary table, and then stops
# the script so nothing below this block is rendered.
if st.session_state.get('show_comparison_window', False):
    # Imported locally: answers are embedded in raw HTML below and must be
    # escaped before being rendered with unsafe_allow_html=True.
    from html import escape as _escape_html

    st.markdown("---")

    comparison_data = st.session_state.comparison_results
    results = comparison_data['results']
    methods = comparison_data['methods']

    # Header with close button
    col1, col2 = st.columns([4, 1])
    with col1:
        st.markdown("## πͺ Full Screen Comparison")
        st.markdown(f"**Query:** {comparison_data['query']}")
        st.markdown(f"**Generated:** {comparison_data['timestamp']} | **Methods:** {', '.join([m.upper() for m in methods])}")
    with col2:
        if st.button("βοΈ Close", help="Close full screen view"):
            st.session_state.show_comparison_window = False
            st.rerun()

    st.markdown("---")

    # Full-width comparison display
    for method in methods:
        st.markdown(f"### πΈ {method.upper()} Method")

        answer = results[method]['answer']
        citations = results[method]['citations']

        # Answer
        st.markdown("**Answer:**")
        # Use a container with custom styling for better readability
        with st.container():
            # Escape first, then turn newlines into <br>, so markup inside the
            # answer is shown as text instead of being interpreted as HTML.
            answer_html = _escape_html(answer).replace(chr(10), "<br>")
            st.markdown(
                f'<div style="background-color: #f0f2f6; padding: 20px; border-radius: 10px; margin: 10px 0; border-left: 5px solid #1f77b4;">'
                f'{answer_html}'
                f'</div>',
                unsafe_allow_html=True
            )

        # Citations
        st.markdown("**Citations:**")
        st.markdown(format_citations_html(citations, method), unsafe_allow_html=True)

        # Statistics (computed on the raw, unescaped answer)
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Answer Length", f"{len(answer)} chars")
        with col2:
            st.metric("Citations", len(citations))
        with col3:
            word_count = len(answer.split())
            st.metric("Word Count", word_count)

        if method != methods[-1]:  # Not the last method
            st.markdown("---")

    # Summary comparison table
    st.markdown("### π Method Comparison Summary")
    summary_data = []
    for method in methods:
        method_answer = results[method]['answer']
        method_citations = results[method]['citations']
        # Mean of each citation's relevance_score, falling back to score and
        # then 0; an empty citation list yields 0.
        if method_citations:
            avg_score = sum(
                float(c.get('relevance_score', 0) or c.get('score', 0) or 0)
                for c in method_citations
            ) / len(method_citations)
        else:
            avg_score = 0
        summary_data.append({
            'Method': method.upper(),
            'Answer Length (chars)': len(method_answer),
            'Word Count': len(method_answer.split()),
            'Citations': len(method_citations),
            'Avg Citation Score': round(avg_score, 3)
        })
    st.dataframe(summary_data, use_container_width=True, hide_index=True)

    st.markdown("---")

    # Return to normal view button
    col1, col2, col3 = st.columns([2, 1, 2])
    with col2:
        if st.button("β¬ οΈ Back to Comparison Tab", type="primary", use_container_width=True):
            st.session_state.show_comparison_window = False
            st.rerun()

    st.stop()  # Stop rendering the rest of the app when in full screen mode
# Footer
# A horizontal rule followed by the disclaimer and the funding
# acknowledgment, each rendered as its own markdown element.
_footer_notices = (
    "**β οΈ Disclaimer:** *This system uses AI to retrieve and generate responses. "
    "While we strive for accuracy, please verify critical safety information with official sources.*",
    "**π Acknowledgment:** *We thank [Ohio BWC/WSIC](https://info.bwc.ohio.gov/) "
    "for funding that made this multi-method RAG system possible.*",
)
st.markdown("---")
for _notice in _footer_notices:
    st.markdown(_notice)