""" | |
Main Gradio interface for the Professional RAG Assistant. | |
""" | |
import gradio as gr
import threading
import time
import json
import sys
import logging
from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
from .themes import get_theme, get_custom_css
from .components import (
    create_header, create_file_upload_section, create_search_interface,
    create_results_display, create_document_management, create_system_status,
    create_analytics_dashboard, format_document_list, format_search_results,
    create_analytics_charts, format_system_overview, create_error_display,
    create_success_display, create_loading_display
)
from .utils import (
    save_uploaded_files, validate_file_types, cleanup_temp_files,
    generate_session_id, parse_search_filters
)

# Make the project root importable so the src package resolves.
sys.path.append(str(Path(__file__).parent.parent))
from src.rag_system import RAGSystem


class RAGInterface:
    """Main interface class for the RAG system."""

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the RAG interface."""
        self.rag_system: Optional[RAGSystem] = None
        self.config_path = config_path or "config.yaml"
        self.active_sessions: Dict[str, Dict[str, Any]] = {}
        self._initialization_lock = threading.Lock()
        self._initialized = False

        # Setup logging
        self.logger = logging.getLogger(__name__)

        # Initialize RAG system
        self._initialize_rag_system()
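
    # The lock plus the _initialized flag below ensure RAGSystem is constructed
    # and warmed up at most once, even if initialization is triggered from
    # multiple threads concurrently.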
    def _initialize_rag_system(self) -> None:
        """Initialize the RAG system."""
        try:
            with self._initialization_lock:
                if not self._initialized:
                    self.logger.info("Initializing RAG system...")
                    self.rag_system = RAGSystem(config_path=self.config_path)

                    # Warm up the system
                    warmup_result = self.rag_system.warmup()
                    if warmup_result.get("success"):
                        self.logger.info("RAG system initialized and warmed up successfully")
                        self._initialized = True
                    else:
                        self.logger.error(
                            f"RAG system warmup failed: {warmup_result.get('error', {}).get('message')}"
                        )
        except Exception as e:
            self.logger.exception(f"Failed to initialize RAG system: {e}")
            self.rag_system = None

    def get_system_status(self) -> str:
        """Get current system status HTML."""
        if not self.rag_system or not self._initialized:
            return create_error_display("System not initialized. Please check configuration and restart.")

        try:
            stats_response = self.rag_system.get_system_stats()
            if not stats_response.get("success"):
                return create_error_display(
                    f"Failed to get system status: {stats_response.get('error', {}).get('message')}"
                )

            stats = stats_response["data"]
            status_info = stats.get("status", {})
            if status_info.get("ready"):
                status_message = f"System ready - {status_info.get('documents_indexed', 0)} documents indexed"
                return create_success_display(status_message)
            else:
                return create_error_display("System not ready")
        except Exception as e:
            return create_error_display(f"Error getting system status: {str(e)}")

    def process_documents(
        self,
        files: List[gr.File],
        session_id: str,
        progress=gr.Progress()
    ) -> Tuple[str, str, Any, str]:
        """Process uploaded documents.

        Returns (status HTML, upload-status HTML, search-button update, document list HTML).
        """
        if not files:
            return (
                create_error_display("No files uploaded"),
                create_error_display("Please select files to upload"),
                gr.update(interactive=False),  # keep search button disabled
                "No documents uploaded yet."
            )

        if not self.rag_system or not self._initialized:
            return (
                create_error_display("System not initialized"),
                create_error_display("Please restart the application"),
                gr.update(interactive=False),
                "No documents uploaded yet."
            )
        try:
            self.logger.info(f"Starting document upload process - {len(files)} files received")

            # Validate file types
            allowed_extensions = [".pdf", ".docx", ".txt"]
            valid_files, validation_errors = validate_file_types(files, allowed_extensions)
            if validation_errors:
                self.logger.warning(f"File validation errors: {validation_errors}")
                error_html = create_error_display("\n".join(validation_errors))
                return error_html, error_html, gr.update(interactive=len(valid_files) > 0), self.get_document_list()

            self.logger.info(f"File validation passed - {len(valid_files)} valid files")

            # Save uploaded files
            progress(0.1, desc="Saving uploaded files...")
            self.logger.info("Saving uploaded files to temporary directory...")
            saved_files = save_uploaded_files(valid_files)
            for file_path, original_name in saved_files:
                file_size = Path(file_path).stat().st_size / (1024 * 1024)  # Size in MB
                self.logger.info(f"Saved file: {original_name} ({file_size:.2f} MB) -> {file_path}")

            if not saved_files:
                return (
                    create_error_display("No valid files to process"),
                    create_error_display("Please check your files and try again"),
                    gr.update(interactive=False),
                    self.get_document_list()
                )

            # Process each file with a per-file timeout
            processed_count = 0
            total_files = len(saved_files)
            results = []
            timeout_seconds = 600  # 10 minutes
            batch_start_time = time.time()  # used for the summary log below

            def process_single_file(file_path, original_name, session_id):
                """Process a single file - to be run with timeout."""
                self.logger.info(f"Processing file: {original_name}")
                start_time = time.time()
                result = self.rag_system.add_document(
                    file_path=file_path,
                    filename=original_name,
                    user_session=session_id
                )
                processing_time = time.time() - start_time
                self.logger.info(f"File processing completed: {original_name} (took {processing_time:.2f}s)")
                return result
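
            # A single-worker executor lets future.result(timeout=...) bound each
            # file's processing time. Note that a timed-out worker thread is not
            # actually killed; it keeps running in the background while the loop
            # moves on to the next file.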
            self.logger.info(f"Starting processing of {total_files} files with {timeout_seconds//60}-minute timeout per file")
            with ThreadPoolExecutor(max_workers=1) as executor:
                for i, (file_path, original_name) in enumerate(saved_files):
                    progress((i + 1) / total_files * 0.8 + 0.1, desc=f"Processing {original_name}...")
                    self.logger.info(f"Processing file {i+1}/{total_files}: {original_name}")
                    try:
                        # Submit the task and wait with a timeout
                        future = executor.submit(process_single_file, file_path, original_name, session_id)
                        result = future.result(timeout=timeout_seconds)

                        if result.get("success"):
                            processed_count += 1
                            chunks_created = result['data']['chunks_created']

                            # Log detailed success info
                            self.logger.info(f"SUCCESS: {original_name} - {chunks_created} chunks created")

                            # Log sample chunk info if available
                            if 'sample_chunks' in result['data']:
                                sample_chunks = result['data']['sample_chunks']
                                self.logger.info(f"Sample chunks from {original_name}:")
                                for idx, chunk in enumerate(sample_chunks[:3]):  # Show first 3 chunks
                                    chunk_preview = chunk['content'][:100] + "..." if len(chunk['content']) > 100 else chunk['content']
                                    self.logger.info(f"  Chunk {idx}: {chunk_preview}")

                            results.append(f"✅ {original_name}: {chunks_created} chunks created")
                        else:
                            error_msg = result.get("error", {}).get("message", "Unknown error")
                            self.logger.error(f"FAILED: {original_name} - {error_msg}")
                            results.append(f"❌ {original_name}: {error_msg}")
                    except FutureTimeoutError:
                        self.logger.error(f"TIMEOUT: {original_name} exceeded {timeout_seconds//60} minute limit")
                        results.append(f"⏰ {original_name}: Processing timed out after {timeout_seconds//60} minutes")
                        future.cancel()  # Best-effort: a task that already started cannot be cancelled
                    except Exception as e:
                        self.logger.error(f"EXCEPTION: {original_name} - {str(e)}")
                        results.append(f"❌ {original_name}: {str(e)}")
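
            # Caveat: a worker that timed out above may still be reading one of
            # these temp files; deleting them here is acceptable because those
            # files were already reported as failed.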
            progress(1.0, desc="Cleaning up...")
            self.logger.info("Cleaning up temporary files...")

            # Clean up temporary files
            cleanup_temp_files([fp for fp, _ in saved_files])

            # Log final summary
            total_processing_time = time.time() - batch_start_time
            self.logger.info("Document upload process completed:")
            self.logger.info(f"  - Total files: {total_files}")
            self.logger.info(f"  - Successfully processed: {processed_count}")
            self.logger.info(f"  - Failed: {total_files - processed_count}")
            self.logger.info(f"  - Success rate: {(processed_count/total_files*100):.1f}%")
            self.logger.info(f"  - Total time: {total_processing_time:.2f}s")

            # Create result message
            if processed_count == total_files:
                self.logger.info(f"✅ ALL DOCUMENTS PROCESSED SUCCESSFULLY ({processed_count}/{total_files})")
                status_html = create_success_display(
                    f"Successfully processed {processed_count} documents:\n" + "\n".join(results)
                )
                upload_status = create_success_display(f"All {processed_count} documents processed successfully!")
            elif processed_count > 0:
                self.logger.warning(f"⚠️ PARTIAL SUCCESS ({processed_count}/{total_files} documents processed)")
                status_html = f"""
                <div style='background: #fef3c7; border: 1px solid #f59e0b; border-radius: 8px; padding: 1rem; margin: 1rem 0;'>
                    <div style='font-weight: 600; color: #92400e; margin-bottom: 0.5rem;'>
                        ⚠️ Partially successful ({processed_count}/{total_files} files processed)
                    </div>
                    <div style='color: #78350f; font-size: 0.9rem;'>{"<br>".join(results)}</div>
                </div>
                """
                upload_status = status_html
            else:
                self.logger.error(f"❌ NO DOCUMENTS PROCESSED SUCCESSFULLY (0/{total_files})")
                status_html = create_error_display(
                    "Failed to process any documents:\n" + "\n".join(results)
                )
                upload_status = create_error_display("Document processing failed")

            return (
                status_html,
                upload_status,
                gr.update(interactive=True),  # Enable search button
                self.get_document_list()
            )
        except Exception as e:
            # Clean up on error
            try:
                if 'saved_files' in locals():
                    cleanup_temp_files([fp for fp, _ in saved_files])
            except Exception:
                pass

            error_message = f"Document processing failed: {str(e)}"
            error_html = create_error_display(error_message)
            return error_html, error_html, gr.update(interactive=False), self.get_document_list()

    def perform_search(
        self,
        query: str,
        search_mode: str,
        num_results: int,
        enable_reranking: bool,
        metadata_filters: str,
        session_id: str
    ) -> Tuple[str, str, str]:
        """Perform search and return results."""
        if not self.rag_system or not self._initialized:
            error_html = create_error_display("System not initialized")
            return error_html, "{}", ""

        if not query or not query.strip():
            error_html = create_error_display("Please enter a search query")
            return error_html, "{}", ""
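
        # parse_search_filters is expected to accept either a JSON object or
        # comma-separated key:value pairs (matching the hint shown in the
        # Advanced Options textbox) and return a dict usable as metadata_filter.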
        try:
            # Parse metadata filters
            filters = parse_search_filters(metadata_filters) if metadata_filters else None

            # Perform search
            result = self.rag_system.search(
                query=query.strip(),
                k=num_results,
                search_mode=search_mode,
                enable_reranking=enable_reranking,
                metadata_filter=filters,
                user_session=session_id
            )

            if not result.get("success"):
                error_msg = result.get("error", {}).get("message", "Search failed")
                error_html = create_error_display(error_msg)
                return error_html, "{}", ""

            # Format results
            search_data = result["data"]
            results = search_data.get("results", [])
            search_time = search_data.get("search_time", 0)

            # Create HTML display
            results_html, stats_html = format_search_results(results, search_time, query)

            # Create JSON data for detailed view
            json_data = {
                "query": query,
                "search_mode": search_mode,
                "results_count": len(results),
                "search_time": search_time,
                "results": results[:5],  # Limit JSON display
                "query_suggestions": search_data.get("query_suggestions", [])
            }

            return results_html, json.dumps(json_data, indent=2), stats_html
        except Exception as e:
            error_html = create_error_display(f"Search failed: {str(e)}")
            return error_html, "{}", ""

    def get_document_list(self) -> str:
        """Get formatted document list."""
        if not self.rag_system or not self._initialized:
            return "<div style='text-align: center; color: #6b7280; padding: 1rem;'>System not initialized</div>"

        try:
            result = self.rag_system.get_document_list()
            if result.get("success"):
                documents = result["data"]["documents"]
                return format_document_list(documents)
            else:
                return create_error_display("Failed to load document list")
        except Exception as e:
            return create_error_display(f"Error loading documents: {str(e)}")

    def clear_documents(self) -> Tuple[str, str]:
        """Clear all documents."""
        if not self.rag_system or not self._initialized:
            error_html = create_error_display("System not initialized")
            return error_html, error_html

        try:
            result = self.rag_system.clear_all_documents()
            if result.get("success"):
                success_msg = f"Cleared {result['data']['documents_removed']} documents"
                success_html = create_success_display(success_msg)
                return success_html, self.get_document_list()
            else:
                error_msg = result.get("error", {}).get("message", "Failed to clear documents")
                error_html = create_error_display(error_msg)
                return error_html, self.get_document_list()
        except Exception as e:
            error_html = create_error_display(f"Error clearing documents: {str(e)}")
            return error_html, self.get_document_list()

    def get_analytics_data(self) -> Tuple[str, gr.Plot, gr.Plot, List[List[str]]]:
        """Get analytics dashboard data."""
        if not self.rag_system or not self._initialized:
            return (
                create_error_display("System not initialized"),
                gr.Plot(),
                gr.Plot(),
                []
            )

        try:
            result = self.rag_system.get_analytics_dashboard()
            if not result.get("success"):
                error_html = create_error_display("Failed to load analytics data")
                return error_html, gr.Plot(), gr.Plot(), []

            analytics_data = result["data"]

            # Format system overview
            overview_html = format_system_overview(analytics_data)

            # Create charts
            query_chart, modes_chart = create_analytics_charts(analytics_data)

            # Create activity table data
            activity_data = []
            system_data = analytics_data.get("system", {})
            activity_data.append([
                "System Started",
                "System Initialization",
                f"Uptime: {system_data.get('uptime_hours', 0):.1f} hours",
                "✅ Active"
            ])
            if system_data.get("total_queries", 0) > 0:
                activity_data.append([
                    "Recent",
                    "Search Queries",
                    f"{system_data.get('total_queries')} total queries",
                    "📊 Active"
                ])
            if system_data.get("total_documents_processed", 0) > 0:
                activity_data.append([
                    "Recent",
                    "Document Processing",
                    f"{system_data.get('total_documents_processed')} documents processed",
                    "📄 Complete"
                ])

            return overview_html, query_chart, modes_chart, activity_data
        except Exception as e:
            error_html = create_error_display(f"Error loading analytics: {str(e)}")
            return error_html, gr.Plot(), gr.Plot(), []

    def create_interface(self) -> gr.Blocks:
        """Create the main Gradio interface."""
        theme = get_theme()
        css = get_custom_css()

        with gr.Blocks(
            theme=theme,
            css=css,
            title="Professional RAG Assistant",
            analytics_enabled=False
        ) as interface:
            # Session state
            session_id_state = gr.State(value=generate_session_id())
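            # Caveat: generate_session_id() runs once at build time, so every
            # browser session receives a copy of the same id; refresh it in
            # interface.load if truly per-visitor ids are needed.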

            # Header
            create_header()

            # System status
            system_status = create_system_status()

            # Main tabs
            with gr.Tabs() as main_tabs:
                # Document Upload Tab
                with gr.Tab("📁 Document Upload", id="upload"):
                    gr.Markdown("Upload your documents to build the knowledge base. Supports PDF, DOCX, and TXT files.")
                    file_upload, upload_status, upload_button = create_file_upload_section()
                    with gr.Accordion("Upload Settings", open=False):
                        gr.Markdown("""
                        **Supported formats:** PDF, DOCX, TXT
                        **Maximum file size:** 50MB per file
                        **Processing:** Documents are split into chunks and indexed for search
                        """)

                # Search Tab
                with gr.Tab("🔍 Search", id="search"):
                    gr.Markdown("Search your uploaded documents using advanced AI-powered retrieval.")
                    with gr.Row():
                        with gr.Column(scale=4):
                            search_components = create_search_interface()
                            search_query, search_controls, search_button = search_components[:3]
                            search_mode, num_results, enable_reranking = search_components[3:]
                        with gr.Column(scale=1):
                            with gr.Accordion("Advanced Options", open=False):
                                metadata_filters = gr.Textbox(
                                    label="Metadata Filters",
                                    placeholder='{"source": "document.pdf"}',
                                    lines=3,
                                    info="JSON or key:value,key2:value2 format"
                                )

                    # Results display
                    results_html, results_json, search_stats = create_results_display()
                    with gr.Accordion("Detailed Results (JSON)", open=False):
                        # Place the JSON component inside this accordion; this
                        # assumes create_results_display() builds it with render=False.
                        results_json.render()

                # Document Management Tab
                with gr.Tab("📚 Documents", id="documents"):
                    gr.Markdown("Manage your uploaded documents and view indexing status.")
                    document_list, refresh_docs_btn, clear_docs_btn = create_document_management()

                # Analytics Tab
                with gr.Tab("📊 Analytics", id="analytics"):
                    gr.Markdown("View system performance and usage analytics.")
                    analytics_components = create_analytics_dashboard()
                    system_overview, query_chart, search_modes_chart, activity_table = analytics_components
                    # Place the pre-built components into the layout; as above,
                    # this assumes they were created with render=False.
                    with gr.Row():
                        with gr.Column():
                            query_chart.render()
                        with gr.Column():
                            search_modes_chart.render()
                    with gr.Accordion("Recent Activity", open=False):
                        activity_table.render()
                    refresh_analytics_btn = gr.Button("Refresh Analytics", variant="secondary")

            # Event handlers
            # File upload events
            file_upload.change(
                fn=lambda files: (
                    create_success_display(
                        f"✅ {len(files)} file(s) selected! Click the green '🚀 Process Documents' button below to continue."
                    ) if files and len(files) > 0 else create_loading_display("No files selected"),
                    gr.update(interactive=files is not None and len(files) > 0)
                ),
                inputs=[file_upload],
                outputs=[upload_status, upload_button],
                show_progress=False
            )

            upload_button.click(
                fn=self.process_documents,
                inputs=[file_upload, session_id_state],
                outputs=[upload_status, system_status, search_button, document_list],
                show_progress=True
            )

            # Search events
            search_query.change(
                fn=lambda query: gr.update(interactive=len(query.strip()) > 0 if query else False),
                inputs=[search_query],
                outputs=[search_button],
                show_progress=False
            )

            search_button.click(
                fn=lambda: create_loading_display("Searching..."),
                inputs=[],
                outputs=[results_html],
                show_progress=False
            ).then(
                fn=self.perform_search,
                inputs=[
                    search_query, search_mode, num_results,
                    enable_reranking, metadata_filters, session_id_state
                ],
                outputs=[results_html, results_json, search_stats],
                show_progress=True
            )

            # Document management events
            refresh_docs_btn.click(
                fn=self.get_document_list,
                inputs=[],
                outputs=[document_list],
                show_progress=False
            )

            clear_docs_btn.click(
                fn=self.clear_documents,
                inputs=[],
                outputs=[system_status, document_list],
                show_progress=True
            )

            # Analytics events
            refresh_analytics_btn.click(
                fn=self.get_analytics_data,
                inputs=[],
                outputs=[system_overview, query_chart, search_modes_chart, activity_table],
                show_progress=True
            )

            # Initialize interface
            interface.load(
                fn=lambda: (
                    self.get_system_status(),
                    gr.update(interactive=False),  # Upload button disabled initially
                    gr.update(interactive=False),  # Search button disabled initially
                    self.get_document_list(),
                    *self.get_analytics_data()
                ),
                inputs=[],
                outputs=[
                    system_status, upload_button, search_button, document_list,
                    system_overview, query_chart, search_modes_chart, activity_table
                ],
                show_progress=False
            )

        return interface


def create_interface(config_path: Optional[str] = None) -> gr.Blocks:
    """Create and return the RAG interface."""
    rag_interface = RAGInterface(config_path)
    return rag_interface.create_interface()


if __name__ == "__main__":
    # For testing
    interface = create_interface()
    interface.launch(debug=True)