"""Main Gradio interface for the Professional RAG Assistant."""

import gradio as gr
import threading
import time
import json
import sys
import logging
from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError

from .themes import get_theme, get_custom_css
from .components import (
    create_header, create_file_upload_section, create_search_interface,
    create_results_display, create_document_management, create_system_status,
    create_analytics_dashboard, format_document_list, format_search_results,
    create_analytics_charts, format_system_overview, create_error_display,
    create_success_display, create_loading_display
)
from .utils import (
    save_uploaded_files, validate_file_types, cleanup_temp_files,
    generate_session_id, parse_search_filters
)
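
# The project root is appended to sys.path so the absolute `src.*` imports
# below resolve when this module is run straight from the repository checkout.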
sys.path.append(str(Path(__file__).parent.parent))

from src.rag_system import RAGSystem
from src.error_handler import RAGError


class RAGInterface:
    """Main interface class for the RAG system."""

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the RAG interface."""
        self.rag_system: Optional[RAGSystem] = None
        self.config_path = config_path or "config.yaml"
        self.active_sessions: Dict[str, Dict[str, Any]] = {}
        self._initialization_lock = threading.Lock()
        self._initialized = False

        self.logger = logging.getLogger(__name__)

        self._initialize_rag_system()
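
    # Initialization uses check-then-set under a lock so that concurrent
    # callers cannot construct two RAGSystem instances.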
    def _initialize_rag_system(self) -> None:
        """Initialize the RAG system."""
        try:
            with self._initialization_lock:
                if not self._initialized:
                    self.logger.info("Initializing RAG system...")
                    self.rag_system = RAGSystem(config_path=self.config_path)

                    warmup_result = self.rag_system.warmup()
                    if warmup_result.get("success"):
                        self.logger.info("RAG system initialized and warmed up successfully")
                        self._initialized = True
                    else:
                        self.logger.error(f"RAG system warmup failed: {warmup_result.get('error', {}).get('message')}")
        except Exception as e:
            self.logger.error(f"Failed to initialize RAG system: {e}")
            self.rag_system = None

    def get_system_status(self) -> str:
        """Get current system status HTML."""
        if not self.rag_system or not self._initialized:
            return create_error_display("System not initialized. Please check configuration and restart.")

        try:
            stats_response = self.rag_system.get_system_stats()
            if not stats_response.get("success"):
                return create_error_display(f"Failed to get system status: {stats_response.get('error', {}).get('message')}")

            stats = stats_response["data"]
            status_info = stats.get("status", {})

            if status_info.get("ready"):
                status_message = f"System ready - {status_info.get('documents_indexed', 0)} documents indexed"
                return create_success_display(status_message)
            else:
                return create_error_display("System not ready")
        except Exception as e:
            return create_error_display(f"Error getting system status: {str(e)}")

    def process_documents(
        self,
        files: List[gr.File],
        session_id: str,
        progress=gr.Progress()
    ) -> Tuple[str, str, Any, str]:
        """Process uploaded documents."""
        if not files:
            return (
                create_error_display("No files uploaded"),
                create_error_display("Please select files to upload"),
                gr.update(interactive=False),
                "No documents uploaded yet."
            )

        if not self.rag_system or not self._initialized:
            return (
                create_error_display("System not initialized"),
                create_error_display("Please restart the application"),
                gr.update(interactive=False),
                "No documents uploaded yet."
            )

        try:
            batch_start_time = time.time()
            self.logger.info(f"Starting document upload process - {len(files)} files received")

            allowed_extensions = [".pdf", ".docx", ".txt"]
            valid_files, validation_errors = validate_file_types(files, allowed_extensions)

            if validation_errors:
                self.logger.warning(f"File validation errors: {validation_errors}")
                error_html = create_error_display("\n".join(validation_errors))
                return error_html, error_html, gr.update(interactive=len(valid_files) > 0), self.get_document_list()

            self.logger.info(f"File validation passed - {len(valid_files)} valid files")

            progress(0.1, desc="Saving uploaded files...")
            self.logger.info("Saving uploaded files to temporary directory...")
            saved_files = save_uploaded_files(valid_files)

            for file_path, original_name in saved_files:
                file_size = Path(file_path).stat().st_size / (1024 * 1024)
                self.logger.info(f"Saved file: {original_name} ({file_size:.2f} MB) -> {file_path}")

            if not saved_files:
                return (
                    create_error_display("No valid files to process"),
                    create_error_display("Please check your files and try again"),
                    gr.update(interactive=False),
                    self.get_document_list()
                )

            processed_count = 0
            total_files = len(saved_files)
            results = []
            timeout_seconds = 600

            def process_single_file(file_path, original_name, session_id):
                """Process a single file - to be run with timeout."""
                self.logger.info(f"Processing file: {original_name}")
                start_time = time.time()

                result = self.rag_system.add_document(
                    file_path=file_path,
                    filename=original_name,
                    user_session=session_id
                )

                processing_time = time.time() - start_time
                self.logger.info(f"File processing completed: {original_name} (took {processing_time:.2f}s)")

                return result

            self.logger.info(f"Starting processing of {total_files} files with {timeout_seconds//60}-minute timeout per file")
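
            # A single-worker executor lets each add_document call run under
            # future.result(timeout=...), so one stuck file cannot hang the
            # whole upload loop indefinitely.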
            with ThreadPoolExecutor(max_workers=1) as executor:
                for i, (file_path, original_name) in enumerate(saved_files):
                    progress((i + 1) / total_files * 0.8 + 0.1, desc=f"Processing {original_name}...")
                    self.logger.info(f"Processing file {i+1}/{total_files}: {original_name}")

                    try:
                        future = executor.submit(process_single_file, file_path, original_name, session_id)
                        result = future.result(timeout=timeout_seconds)

                        if result.get("success"):
                            processed_count += 1
                            chunks_created = result['data']['chunks_created']

                            self.logger.info(f"SUCCESS: {original_name} - {chunks_created} chunks created")

                            if 'sample_chunks' in result['data']:
                                sample_chunks = result['data']['sample_chunks']
                                self.logger.info(f"Sample chunks from {original_name}:")
                                for idx, chunk in enumerate(sample_chunks[:3]):
                                    chunk_preview = chunk['content'][:100] + "..." if len(chunk['content']) > 100 else chunk['content']
                                    self.logger.info(f"  Chunk {idx}: {chunk_preview}")

                            results.append(f"✅ {original_name}: {chunks_created} chunks created")
                        else:
                            error_msg = result.get("error", {}).get("message", "Unknown error")
                            self.logger.error(f"FAILED: {original_name} - {error_msg}")
                            results.append(f"❌ {original_name}: {error_msg}")

                    except FutureTimeoutError:
                        self.logger.error(f"TIMEOUT: {original_name} exceeded {timeout_seconds//60} minute limit")
                        results.append(f"⏰ {original_name}: Processing timed out after {timeout_seconds//60} minutes")
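                        # Note: cancel() is a no-op for a task that has already
                        # started; the worker thread keeps running until
                        # add_document returns, we just stop waiting on it.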
                        future.cancel()
                    except Exception as e:
                        self.logger.error(f"EXCEPTION: {original_name} - {str(e)}")
                        results.append(f"❌ {original_name}: {str(e)}")

            progress(1.0, desc="Cleaning up...")
            self.logger.info("Cleaning up temporary files...")

            cleanup_temp_files([fp for fp, _ in saved_files])

            total_processing_time = time.time() - batch_start_time
            self.logger.info("Document upload process completed:")
            self.logger.info(f"  - Total files: {total_files}")
            self.logger.info(f"  - Successfully processed: {processed_count}")
            self.logger.info(f"  - Failed: {total_files - processed_count}")
            self.logger.info(f"  - Total time: {total_processing_time:.2f}s")
            self.logger.info(f"  - Success rate: {(processed_count/total_files*100):.1f}%")

            if processed_count == total_files:
                self.logger.info(f"✅ ALL DOCUMENTS PROCESSED SUCCESSFULLY ({processed_count}/{total_files})")
                status_html = create_success_display(
                    f"Successfully processed {processed_count} documents:\n" + "\n".join(results)
                )
                upload_status = create_success_display(f"All {processed_count} documents processed successfully!")
            elif processed_count > 0:
                self.logger.warning(f"⚠️ PARTIAL SUCCESS ({processed_count}/{total_files} documents processed)")
                status_html = f"""
                <div style='background: #fef3c7; border: 1px solid #f59e0b; border-radius: 8px; padding: 1rem; margin: 1rem 0;'>
                    <div style='font-weight: 600; color: #92400e; margin-bottom: 0.5rem;'>
                        ⚠️ Partially successful ({processed_count}/{total_files} files processed)
                    </div>
                    <div style='color: #78350f; font-size: 0.9rem;'>{"<br>".join(results)}</div>
                </div>
                """
                upload_status = status_html
            else:
                self.logger.error(f"❌ NO DOCUMENTS PROCESSED SUCCESSFULLY (0/{total_files})")
                status_html = create_error_display(
                    "Failed to process any documents:\n" + "\n".join(results)
                )
                upload_status = create_error_display("Document processing failed")

            # The third output drives the search button's interactivity.
            return (
                status_html,
                upload_status,
                gr.update(interactive=True),
                self.get_document_list()
            )

        except Exception as e:
            # Best-effort cleanup of any files saved before the failure.
            try:
                if 'saved_files' in locals():
                    cleanup_temp_files([fp for fp, _ in saved_files])
            except Exception:
                pass

            error_message = f"Document processing failed: {str(e)}"
            error_html = create_error_display(error_message)
            return error_html, error_html, gr.update(interactive=False), self.get_document_list()

    def perform_search(
        self,
        query: str,
        search_mode: str,
        num_results: int,
        enable_reranking: bool,
        metadata_filters: str,
        session_id: str
    ) -> Tuple[str, str, str]:
        """Perform search and return results."""
        if not self.rag_system or not self._initialized:
            error_html = create_error_display("System not initialized")
            return error_html, "{}", ""

        if not query or not query.strip():
            error_html = create_error_display("Please enter a search query")
            return error_html, "{}", ""

        try:
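            # Filters arrive as free text from the UI; parse_search_filters is
            # expected to accept either JSON or "key:value,key2:value2" pairs
            # (per the textbox hint in the search tab).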
            filters = parse_search_filters(metadata_filters) if metadata_filters else None

            result = self.rag_system.search(
                query=query.strip(),
                k=num_results,
                search_mode=search_mode,
                enable_reranking=enable_reranking,
                metadata_filter=filters,
                user_session=session_id
            )

            if not result.get("success"):
                error_msg = result.get("error", {}).get("message", "Search failed")
                error_html = create_error_display(error_msg)
                return error_html, "{}", ""

            search_data = result["data"]
            results = search_data.get("results", [])
            search_time = search_data.get("search_time", 0)

            results_html, stats_html = format_search_results(results, search_time, query)

            json_data = {
                "query": query,
                "search_mode": search_mode,
                "results_count": len(results),
                "search_time": search_time,
                "results": results[:5],
                "query_suggestions": search_data.get("query_suggestions", [])
            }

            return results_html, json.dumps(json_data, indent=2), stats_html

        except Exception as e:
            error_html = create_error_display(f"Search failed: {str(e)}")
            return error_html, "{}", ""

    def get_document_list(self) -> str:
        """Get formatted document list."""
        if not self.rag_system or not self._initialized:
            return "<div style='text-align: center; color: #6b7280; padding: 1rem;'>System not initialized</div>"

        try:
            result = self.rag_system.get_document_list()
            if result.get("success"):
                documents = result["data"]["documents"]
                return format_document_list(documents)
            else:
                return create_error_display("Failed to load document list")
        except Exception as e:
            return create_error_display(f"Error loading documents: {str(e)}")

    def clear_documents(self) -> Tuple[str, str]:
        """Clear all documents."""
        if not self.rag_system or not self._initialized:
            error_html = create_error_display("System not initialized")
            return error_html, error_html

        try:
            result = self.rag_system.clear_all_documents()
            if result.get("success"):
                success_msg = f"Cleared {result['data']['documents_removed']} documents"
                success_html = create_success_display(success_msg)
                return success_html, self.get_document_list()
            else:
                error_msg = result.get("error", {}).get("message", "Failed to clear documents")
                error_html = create_error_display(error_msg)
                return error_html, self.get_document_list()
        except Exception as e:
            error_html = create_error_display(f"Error clearing documents: {str(e)}")
            return error_html, self.get_document_list()

    def get_analytics_data(self) -> Tuple[str, gr.Plot, gr.Plot, List[List[str]]]:
        """Get analytics dashboard data."""
        if not self.rag_system or not self._initialized:
            return (
                create_error_display("System not initialized"),
                gr.Plot(),
                gr.Plot(),
                []
            )

        try:
            result = self.rag_system.get_analytics_dashboard()
            if not result.get("success"):
                error_html = create_error_display("Failed to load analytics data")
                return error_html, gr.Plot(), gr.Plot(), []

            analytics_data = result["data"]

            overview_html = format_system_overview(analytics_data)

            query_chart, modes_chart = create_analytics_charts(analytics_data)

            activity_data = []
            system_data = analytics_data.get("system", {})

            activity_data.append([
                "System Started",
                "System Initialization",
                f"Uptime: {system_data.get('uptime_hours', 0):.1f} hours",
                "✅ Active"
            ])

            if system_data.get("total_queries", 0) > 0:
                activity_data.append([
                    "Recent",
                    "Search Queries",
                    f"{system_data.get('total_queries')} total queries",
                    "🔍 Active"
                ])

            if system_data.get("total_documents_processed", 0) > 0:
                activity_data.append([
                    "Recent",
                    "Document Processing",
                    f"{system_data.get('total_documents_processed')} documents processed",
                    "📄 Complete"
                ])

            return overview_html, query_chart, modes_chart, activity_data

        except Exception as e:
            error_html = create_error_display(f"Error loading analytics: {str(e)}")
            return error_html, gr.Plot(), gr.Plot(), []

    def create_interface(self) -> gr.Blocks:
        """Create the main Gradio interface."""
        theme = get_theme()
        css = get_custom_css()

        with gr.Blocks(
            theme=theme,
            css=css,
            title="Professional RAG Assistant",
            analytics_enabled=False
        ) as interface:
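
            # gr.State is stored per browser session, so each visitor gets an
            # isolated session id for the calls below.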
            session_id_state = gr.State(value=generate_session_id())

            create_header()

            system_status = create_system_status()

            with gr.Tabs() as main_tabs:

                with gr.Tab("📄 Document Upload", id="upload"):
                    gr.Markdown("Upload your documents to build the knowledge base. Supports PDF, DOCX, and TXT files.")

                    file_upload, upload_status, upload_button = create_file_upload_section()

                    with gr.Accordion("Upload Settings", open=False):
                        gr.Markdown("""
                        - **Supported formats:** PDF, DOCX, TXT
                        - **Maximum file size:** 50MB per file
                        - **Processing:** Documents are split into chunks and indexed for search
                        """)

                with gr.Tab("🔍 Search", id="search"):
                    gr.Markdown("Search your uploaded documents using advanced AI-powered retrieval.")

                    with gr.Row():
                        with gr.Column(scale=4):
                            search_components = create_search_interface()
                            search_query, search_controls, search_button = search_components[:3]
                            search_mode, num_results, enable_reranking = search_components[3:]

                        with gr.Column(scale=1):
                            with gr.Accordion("Advanced Options", open=False):
                                metadata_filters = gr.Textbox(
                                    label="Metadata Filters",
                                    placeholder='{"source": "document.pdf"}',
                                    lines=3,
                                    info="JSON or key:value,key2:value2 format"
                                )

                    results_html, results_json, search_stats = create_results_display()

                    with gr.Accordion("Detailed Results (JSON)", open=False):
                        # Place the JSON component inside the accordion; this
                        # assumes create_results_display builds it with
                        # render=False (a bare reference would be a no-op).
                        results_json.render()

                with gr.Tab("📚 Documents", id="documents"):
                    gr.Markdown("Manage your uploaded documents and view indexing status.")

                    document_list, refresh_docs_btn, clear_docs_btn = create_document_management()

                with gr.Tab("📊 Analytics", id="analytics"):
                    gr.Markdown("View system performance and usage analytics.")

                    analytics_components = create_analytics_dashboard()
                    system_overview, query_chart, search_modes_chart, activity_table = analytics_components

                    with gr.Row():
                        with gr.Column():
                            # Assumes the chart and table components come back
                            # unrendered (render=False) from create_analytics_dashboard.
                            query_chart.render()
                        with gr.Column():
                            search_modes_chart.render()

                    with gr.Accordion("Recent Activity", open=False):
                        activity_table.render()

                    refresh_analytics_btn = gr.Button("Refresh Analytics", variant="secondary")
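
            # --- Event handlers ---

            # Selecting files immediately enables the upload button and shows
            # a confirmation message.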
            file_upload.change(
                fn=lambda files: (
                    create_success_display(f"✅ {len(files)} file(s) selected! Click the green '📄 Process Documents' button below to continue.") if files and len(files) > 0 else create_loading_display("No files selected"),
                    gr.update(interactive=files is not None and len(files) > 0)
                ),
                inputs=[file_upload],
                outputs=[upload_status, upload_button],
                show_progress=False
            )

            upload_button.click(
                fn=self.process_documents,
                inputs=[file_upload, session_id_state],
                outputs=[upload_status, system_status, search_button, document_list],
                show_progress=True
            )

            search_query.change(
                fn=lambda query: gr.update(interactive=bool(query and query.strip())),
                inputs=[search_query],
                outputs=[search_button],
                show_progress=False
            )
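
            # Two-step chain: show a loading placeholder first, then run the
            # actual search and replace it with real results.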
            search_button.click(
                fn=lambda: create_loading_display("Searching..."),
                inputs=[],
                outputs=[results_html],
                show_progress=False
            ).then(
                fn=self.perform_search,
                inputs=[
                    search_query, search_mode, num_results,
                    enable_reranking, metadata_filters, session_id_state
                ],
                outputs=[results_html, results_json, search_stats],
                show_progress=True
            )

            refresh_docs_btn.click(
                fn=self.get_document_list,
                inputs=[],
                outputs=[document_list],
                show_progress=False
            )

            clear_docs_btn.click(
                fn=self.clear_documents,
                inputs=[],
                outputs=[system_status, document_list],
                show_progress=True
            )

            refresh_analytics_btn.click(
                fn=self.get_analytics_data,
                inputs=[],
                outputs=[system_overview, query_chart, search_modes_chart, activity_table],
                show_progress=True
            )
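
            # On first page load: show system status, disable both action
            # buttons until there is something to act on, and populate the
            # document list and analytics panels.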
            interface.load(
                fn=lambda: (
                    self.get_system_status(),
                    gr.update(interactive=False),
                    gr.update(interactive=False),
                    self.get_document_list(),
                    *self.get_analytics_data()
                ),
                inputs=[],
                outputs=[
                    system_status, upload_button, search_button, document_list,
                    system_overview, query_chart, search_modes_chart, activity_table
                ],
                show_progress=False
            )

        return interface


def create_interface(config_path: Optional[str] = None) -> gr.Blocks:
    """Create and return the RAG interface."""
    rag_interface = RAGInterface(config_path)
    return rag_interface.create_interface()
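
# Running this module directly starts a local development server; pass e.g.
# server_name="0.0.0.0" or share=True to launch() to expose it beyond localhost.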
if __name__ == "__main__":
    interface = create_interface()
    interface.launch(debug=True)