# RAG_ChatBot / ui/components.py
# Jialun He
# UI
# efe948f
"""
Reusable UI components for the Gradio interface.
"""
import html
import json
import time
from typing import Any, Dict, List, Optional, Tuple, Callable

import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from .themes import create_info_card, create_status_indicator, format_search_result, create_progress_bar
def create_header() -> gr.HTML:
    """Return the page banner: the application title and a short description."""
    # One entry per line of markup; the leading/trailing "" entries give the
    # surrounding newlines that the banner string has always carried.
    banner_lines = [
        "",
        '<div class="header-container">',
        '<h1 class="header-title">Professional RAG Assistant</h1>',
        '<p class="header-description">',
        "Upload documents and ask questions with AI-powered retrieval and generation.",
        "Supports PDF, DOCX, and TXT files with advanced search capabilities.",
        "</p>",
        "</div>",
        "",
    ]
    return gr.HTML("\n".join(banner_lines))
def create_file_upload_section() -> Tuple[gr.File, gr.HTML, gr.Button]:
    """Build the document-upload widgets.

    Returns:
        ``(file_upload, upload_status, upload_button)``: a multi-file picker
        restricted to PDF/DOCX/TXT, an HTML status line initialised to the
        "ready" state, and a "Process Documents" button that starts disabled.
    """
    accepted_extensions = [".pdf", ".docx", ".txt"]
    uploader = gr.File(
        label="Upload Documents",
        file_types=accepted_extensions,
        file_count="multiple",
        interactive=True,
        height=150,
    )

    status_box = gr.HTML(
        create_status_indicator("ready", "Ready to upload documents"),
        visible=True,
    )

    # Centre the button: two narrower blank columns flank a wider middle one,
    # which also makes the button visually more prominent.
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("")
        with gr.Column(scale=2):
            process_btn = gr.Button(
                "🚀 Process Documents",
                variant="primary",
                size="lg",
                interactive=False,
                elem_classes=["process-button"],
            )
        with gr.Column(scale=1):
            gr.HTML("")

    return uploader, status_box, process_btn
def create_search_interface() -> Tuple[gr.Textbox, gr.Row, gr.Button, gr.Dropdown, gr.Slider, gr.Checkbox]:
    """Create the search query box, the search-options row and the search button.

    Returns:
        A 6-tuple ``(search_query, search_controls, search_button, search_mode,
        num_results, enable_reranking)``:

        * ``search_query``: multi-line query textbox.
        * ``search_controls``: the ``gr.Row`` holding the three option widgets.
        * ``search_button``: the "Search" button; it starts disabled.
        * ``search_mode``: dropdown of ``"hybrid"`` (default), ``"vector"``
          and ``"bm25"``.
        * ``num_results``: slider from 1 to 20, defaulting to 10.
        * ``enable_reranking``: checkbox, checked by default.
    """
    # The previous annotation advertised only the first three elements even
    # though six are returned; it now matches what callers actually unpack.
    search_query = gr.Textbox(
        label="Search Query",
        placeholder="Ask a question about your documents...",
        lines=2,
        max_lines=4,
        interactive=True,
        scale=4,
    )

    with gr.Row() as search_controls:
        with gr.Column(scale=1):
            search_mode = gr.Dropdown(
                choices=["hybrid", "vector", "bm25"],
                value="hybrid",
                label="Search Mode",
                interactive=True,
            )
        with gr.Column(scale=1):
            num_results = gr.Slider(
                minimum=1,
                maximum=20,
                value=10,
                step=1,
                label="Number of Results",
                interactive=True,
            )
        with gr.Column(scale=1):
            enable_reranking = gr.Checkbox(
                label="Enable Re-ranking",
                value=True,
                interactive=True,
            )

    search_button = gr.Button(
        "Search",
        variant="primary",
        size="lg",
        interactive=False,
    )

    return search_query, search_controls, search_button, search_mode, num_results, enable_reranking
def create_results_display() -> Tuple[gr.HTML, gr.JSON, gr.HTML]:
    """Create the widgets that show search output.

    Returns:
        ``(results_html, results_json, search_stats)``: a visible HTML panel
        holding a placeholder message, plus a JSON viewer and a stats HTML
        panel that both start hidden.
    """
    placeholder = (
        "<div style='text-align: center; color: #6b7280; padding: 2rem;'>"
        "No search results yet. Upload documents and try searching!</div>"
    )
    # Tuple items are evaluated left to right, so the components are created
    # (and laid out) in the same order as before.
    return (
        gr.HTML(placeholder, visible=True),
        gr.JSON(label="Detailed Results (JSON)", visible=False),
        gr.HTML(visible=False),
    )
def create_document_management() -> Tuple[gr.HTML, gr.Button, gr.Button]:
    """Create the document list panel and its control buttons.

    Returns:
        ``(document_list, refresh_docs_btn, clear_docs_btn)``; the two buttons
        sit side by side in a single row below the list.
    """
    empty_notice = (
        "<div style='text-align: center; color: #6b7280; padding: 1rem;'>"
        "No documents uploaded yet.</div>"
    )
    docs_panel = gr.HTML(empty_notice)

    with gr.Row():
        refresh_button = gr.Button("Refresh List", variant="secondary")
        clear_button = gr.Button("Clear All Documents", variant="stop")

    return docs_panel, refresh_button, clear_button
def create_system_status() -> gr.HTML:
    """Create the system-status panel, initially showing the "loading" state."""
    initial_markup = create_status_indicator("loading", "Initializing system...")
    status_panel = gr.HTML(initial_markup, visible=True)
    return status_panel
def create_analytics_dashboard() -> Tuple[gr.HTML, gr.Plot, gr.Plot, gr.Dataframe]:
    """Create the analytics dashboard widgets.

    Returns:
        ``(system_overview, query_chart, search_modes_chart, activity_table)``.
        The overview is an empty CSS-grid container for summary cards. The two
        charts and the activity table start hidden.
    """
    start_hidden = {"visible": False}

    overview_grid = gr.HTML(
        "<div style='display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); "
        "gap: 1rem; margin-bottom: 2rem;'></div>"
    )
    queries_plot = gr.Plot(label="Queries Over Time", **start_hidden)
    modes_plot = gr.Plot(label="Search Modes Distribution", **start_hidden)
    recent_activity = gr.Dataframe(
        headers=["Timestamp", "Activity", "Details", "Status"],
        label="Recent Activity",
        **start_hidden,
    )

    return overview_grid, queries_plot, modes_plot, recent_activity
def _format_file_size(num_bytes: float) -> str:
    """Format a byte count as "N bytes", "X.Y KB" or "X.Y MB" (1024-based)."""
    # ``>=`` so that exactly 1 KiB / 1 MiB render as "1.0 KB" / "1.0 MB"
    # rather than "1024 bytes" / "1024.0 KB".
    if num_bytes >= 1024 * 1024:
        return f"{num_bytes / (1024 * 1024):.1f} MB"
    if num_bytes >= 1024:
        return f"{num_bytes / 1024:.1f} KB"
    return f"{num_bytes} bytes"


def format_document_list(documents: List[Dict[str, Any]]) -> str:
    """Render uploaded-document metadata as a stack of HTML cards.

    Args:
        documents: Dicts with optional keys ``filename``, ``chunk_count``,
            ``file_type`` and ``file_size`` (in bytes). Missing or ``None``
            values fall back to ``"Unknown"``, ``0``, ``"UNKNOWN"`` and ``0``.

    Returns:
        An HTML string: one card per document, or a centred placeholder when
        *documents* is empty.
    """
    if not documents:
        return "<div style='text-align: center; color: #6b7280; padding: 1rem;'>No documents uploaded yet.</div>"

    # The wrapper previously carried ``space-y: 1rem``, which is a Tailwind
    # class name, not a CSS property. Card spacing already comes from each
    # card's ``margin-bottom``.
    html_parts = ["<div>"]
    for doc in documents:
        # Filenames (and types derived from them) come from user uploads, so
        # escape them before embedding them in markup.
        filename = html.escape(str(doc.get("filename") or "Unknown"))
        file_type = html.escape(str(doc.get("file_type") or "unknown").upper())
        chunk_count = doc.get("chunk_count") or 0
        size_str = _format_file_size(doc.get("file_size") or 0)

        html_parts.append(
            "\n"
            '<div class="result-card" style="margin-bottom: 1rem;">\n'
            f'<div class="result-title">📄 {filename}</div>\n'
            '<div class="result-metadata" style="margin-top: 0.5rem;">\n'
            f'<span class="metadata-tag">Type: {file_type}</span>\n'
            f'<span class="metadata-tag">Size: {size_str}</span>\n'
            f'<span class="metadata-tag">Chunks: {chunk_count}</span>\n'
            "</div>\n"
            "</div>\n"
        )
    html_parts.append("</div>")
    return "".join(html_parts)
def _stat_cell(value: str, label: str) -> str:
    """Return one value-over-label cell for the search statistics grid."""
    return (
        "<div>\n"
        f"<div style='font-weight: 600; color: #065f46;'>{value}</div>\n"
        f"<div style='font-size: 0.875rem; color: #047857;'>{label}</div>\n"
        "</div>\n"
    )


def format_search_results(results: List[Dict[str, Any]], search_time: float, query: str) -> Tuple[str, str]:
    """Render search results and a summary-statistics banner as HTML.

    Args:
        results: Result dicts. Each one is rendered by ``format_search_result``
            (imported from ``.themes``). The average score reads
            ``result["scores"]["final_score"]``; a missing or ``None`` value
            counts as 0.
        search_time: Elapsed search time in seconds, shown with two decimals.
        query: The user's query text. It is HTML-escaped before being placed
            in the results heading.

    Returns:
        ``(results_html, stats_html)``. When *results* is empty, a "no results"
        placeholder and an amber banner are returned instead.
    """
    if not results:
        results_html = "\n".join([
            "",
            "<div style='text-align: center; color: #6b7280; padding: 2rem;'>",
            "<div style='font-size: 1.5rem; margin-bottom: 1rem;'>🔍</div>",
            "<div>No results found for your query.</div>",
            "<div style='font-size: 0.875rem; margin-top: 0.5rem;'>Try different keywords or check your search settings.</div>",
            "</div>",
            "",
        ])
        stats_html = (
            "\n<div style='background: #fef3c7; border: 1px solid #f59e0b; border-radius: 8px; padding: 1rem; margin: 1rem 0;'>\n"
            f"<strong>Search completed</strong> in {search_time:.2f}s - No results found\n"
            "</div>\n"
        )
        return results_html, stats_html

    # The query is free text typed by the user; escape it so it cannot inject
    # markup or script into the page.
    safe_query = html.escape(query or "")
    results_parts = [f"<div style='margin-bottom: 1rem;'><h3 style='color: #374151;'>Search Results for: \"{safe_query}\"</h3></div>"]
    results_parts.extend(format_search_result(result, rank) for rank, result in enumerate(results, 1))
    results_html = "".join(results_parts)

    # ``or {}`` / ``or 0`` also cover entries whose value is explicitly None.
    scores = [(r.get("scores") or {}).get("final_score") or 0 for r in results]
    avg_score = sum(scores) / len(scores)

    stats_html = (
        "\n<div style='background: #d1fae5; border: 1px solid #10b981; border-radius: 8px; padding: 1rem; margin: 1rem 0;'>\n"
        "<div style='display: grid; grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); gap: 1rem; text-align: center;'>\n"
        + _stat_cell(str(len(results)), "Results Found")
        + _stat_cell(f"{search_time:.2f}s", "Search Time")
        + _stat_cell(f"{avg_score:.3f}", "Avg Score")
        + "</div>\n</div>\n"
    )
    return results_html, stats_html
def create_analytics_charts(analytics_data: Dict[str, Any]) -> Tuple[go.Figure, go.Figure]:
    """Build the "queries over time" line chart and the search-mode pie chart.

    Args:
        analytics_data: Mapping whose ``"queries_24h"`` entry may contain
            ``"queries_per_hour"`` (a sequence of counts, plotted against its
            list index) and ``"search_modes"`` (mode name -> count).

    Returns:
        ``(query_fig, modes_fig)``. When there is no search-mode data, the pie
        chart is replaced by an empty figure with a centred
        "No data available" note.
    """
    # ``or {}`` also tolerates an explicit ``None`` for the 24h section.
    queries_data = analytics_data.get("queries_24h") or {}
    # Layout options shared by both charts.
    common_layout = dict(template="plotly_white", height=300)

    queries_per_hour = queries_data.get("queries_per_hour", [])
    # NOTE(review): x is the list index while the axis is titled "Hours Ago".
    # Confirm the producer orders this list most-recent-first.
    hours = list(range(len(queries_per_hour)))
    query_fig = go.Figure()
    query_fig.add_trace(go.Scatter(
        x=hours,
        y=queries_per_hour,
        mode='lines+markers',
        name='Queries per Hour',
        line=dict(color='#667eea', width=3),
        marker=dict(size=8, color='#667eea'),
    ))
    query_fig.update_layout(
        title="Queries Over Time (24 Hours)",
        xaxis_title="Hours Ago",
        yaxis_title="Number of Queries",
        **common_layout,
    )

    search_modes = queries_data.get("search_modes", {})
    if search_modes:
        modes_fig = go.Figure(data=[
            go.Pie(
                labels=list(search_modes.keys()),
                values=list(search_modes.values()),
                hole=0.3,
                marker=dict(colors=['#667eea', '#8b5cf6', '#06b6d4']),
            )
        ])
        modes_fig.update_layout(title="Search Modes Distribution", **common_layout)
    else:
        modes_fig = go.Figure()
        modes_fig.update_layout(
            title="Search Modes Distribution",
            # Paper-relative coordinates centre the note regardless of the
            # (empty) data axes, which are hidden to avoid a bare grid.
            annotations=[dict(
                text="No data available",
                showarrow=False,
                xref="paper",
                yref="paper",
                x=0.5,
                y=0.5,
            )],
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            **common_layout,
        )

    return query_fig, modes_fig
def format_system_overview(analytics_data: Dict[str, Any]) -> str:
    """Render the four system summary cards inside a responsive CSS grid.

    Reads ``total_queries``, ``total_documents_processed``, ``uptime_hours``
    and ``active_sessions`` from ``analytics_data["system"``], each defaulting
    to 0. Uptime is shown in hours below 24h and in days otherwise.
    """
    stats = analytics_data.get("system", {})

    hours_up = stats.get("uptime_hours", 0)
    if hours_up < 24:
        uptime_label = f"{hours_up:.1f}h"
    else:
        uptime_label = f"{hours_up/24:.1f}d"

    # (title, value, caption) for each card, in display order.
    card_specs = [
        ("Total Queries", str(stats.get("total_queries", 0)), "All-time search queries"),
        ("Documents", str(stats.get("total_documents_processed", 0)), "Successfully processed"),
        ("Uptime", uptime_label, "System running time"),
        ("Active Users", str(stats.get("active_sessions", 0)), "Current sessions"),
    ]

    grid_open = (
        "\n<div style='display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); "
        "gap: 1rem; margin-bottom: 2rem;'>\n"
    )
    cards = "".join(create_info_card(title, value, caption) for title, value, caption in card_specs)
    return grid_open + cards + "</div>"
def create_progress_callback() -> Callable:
    """Return a ``(message, progress) -> html`` callback that renders a progress bar.

    The returned function forwards its arguments to ``create_progress_bar``
    with the order swapped, i.e. ``create_progress_bar(progress, message)``.
    """
    def progress_callback(message: str, progress: float) -> str:
        rendered_bar = create_progress_bar(progress, message)
        return rendered_bar

    return progress_callback
def create_error_display(error_message: str) -> str:
    """Return a red banner titled "Error" with *error_message* as its body.

    The message is inserted verbatim, without HTML escaping.
    """
    banner_rows = [
        "",
        "<div style='background: #ef4444; border: 1px solid #dc2626; border-radius: 8px; padding: 1rem; margin: 1rem 0; box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);'>",
        "<div style='display: flex; align-items: center; gap: 0.5rem; color: #ffffff; font-weight: 700; margin-bottom: 0.5rem; text-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);'>",
        "<span>⚠️</span>",
        "<span>Error</span>",
        "</div>",
        f"<div style='color: #ffffff; font-weight: 500; opacity: 0.95;'>{error_message}</div>",
        "</div>",
        "",
    ]
    return "\n".join(banner_rows)
def create_success_display(message: str) -> str:
    """Return a green banner titled "Success" with *message* as its body.

    The message is inserted verbatim, without HTML escaping.
    """
    banner_rows = [
        "",
        "<div style='background: #10b981; border: 1px solid #059669; border-radius: 8px; padding: 1rem; margin: 1rem 0; box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);'>",
        "<div style='display: flex; align-items: center; gap: 0.5rem; color: #ffffff; font-weight: 700; margin-bottom: 0.5rem; text-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);'>",
        "<span>✅</span>",
        "<span>Success</span>",
        "</div>",
        f"<div style='color: #ffffff; font-weight: 500; opacity: 0.95;'>{message}</div>",
        "</div>",
        "",
    ]
    return "\n".join(banner_rows)
def create_loading_display(message: str = "Processing...") -> str:
    """Return a centred spinner element with *message* shown underneath it."""
    spinner = "<div class='loading-spinner' style='margin-bottom: 1rem;'></div>"
    caption = f"<div style='color: #6b7280; font-weight: 500;'>{message}</div>"
    return (
        "\n<div style='text-align: center; padding: 2rem;'>\n"
        + spinner + "\n"
        + caption + "\n"
        + "</div>\n"
    )