""" | |
Main RAG system orchestrator that coordinates all components. | |
""" | |
import os | |
import time | |
import yaml | |
from pathlib import Path | |
from typing import Any, Dict, List, Optional, Tuple, Union | |
import threading | |
from dataclasses import dataclass | |
from .error_handler import ( | |
ErrorHandler, RAGError, DocumentProcessingError, | |
SearchError, ConfigurationError, validate_config, | |
create_success_response, create_error_response | |
) | |
from .document_processor import DocumentProcessor, DocumentChunk | |
from .embedding_manager import EmbeddingManager | |
from .vector_store import VectorStore | |
from .search_engine import HybridSearchEngine, SearchResult | |
from .reranker import RerankingPipeline | |
from .cache_manager import CacheManager | |
from .analytics import AnalyticsManager | |
@dataclass
class RAGSystemStatus:
    """Represents the current status of the RAG system."""
    initialized: bool = False
    ready: bool = False
    models_loaded: bool = False
    documents_indexed: int = 0
    total_chunks: int = 0
    last_updated: Optional[float] = None
    error_message: Optional[str] = None
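
# Example (illustrative sketch): a freshly constructed status reports the
# defaults above until warmup() and add_document() update it:
#
#   status = RAGSystemStatus()
#   assert not status.ready and status.documents_indexed == 0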

class RAGSystem:
    """Main RAG system that orchestrates all components."""

    def __init__(self, config_path: Optional[str] = None, config_dict: Optional[Dict[str, Any]] = None):
        """
        Initialize the RAG system.

        Args:
            config_path: Path to a YAML configuration file
            config_dict: Dictionary configuration (takes precedence over config_path)
        """
        # Full logging is set up once the error handler exists; until then
        # self.logger is None, and helpers called from here must allow that.
        self.logger = None
        try:
            # Load configuration
            if config_dict:
                self.config = config_dict
            elif config_path:
                self.config = self._load_config(config_path)
            else:
                # Try default config paths
                for default_path in ["config.yaml", "config-local.yaml"]:
                    if Path(default_path).exists():
                        self.config = self._load_config(default_path)
                        break
                else:
                    # Fall back to built-in defaults if no config file is found
                    self.config = self._get_default_config()

            # Validate configuration
            validate_config(self.config)

            # Initialize error handling
            self.error_handler = ErrorHandler(self.config)
            self.logger = self.error_handler.logger
        except Exception as e:
            # If config loading fails, fall back to basic logging and defaults
            import logging
            logging.basicConfig(level=logging.INFO)
            self.logger = logging.getLogger(__name__)
            self.logger.error(f"Failed to load configuration: {e}")
            self.config = self._get_default_config()
            self.error_handler = ErrorHandler(self.config)
            self.logger = self.error_handler.logger

        # Initialize components
        self.cache_manager = CacheManager(self.config)
        self.document_processor = DocumentProcessor(self.config)
        self.embedding_manager = EmbeddingManager(self.config, self.cache_manager)
        self.vector_store = VectorStore(self.config)
        self.search_engine = HybridSearchEngine(self.config, self.vector_store)
        self.reranking_pipeline = RerankingPipeline(self.config)
        self.analytics_manager = AnalyticsManager(self.config)

        # System state
        self.status = RAGSystemStatus()
        self._lock = threading.RLock()
        self._document_index: Dict[str, List[str]] = {}  # filename -> chunk_ids

        # Connect components
        self.search_engine.set_embedding_manager(self.embedding_manager)

        self.logger.info("RAG system initialized successfully")
        self.status.initialized = True
    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from a YAML file."""
        config_path = Path(config_path)
        if not config_path.exists():
            raise ConfigurationError(f"Configuration file not found: {config_path}")

        try:
            with open(config_path, 'r') as f:
                config = yaml.safe_load(f)
            # self.logger may still be None when this runs during __init__
            if self.logger:
                self.logger.info(f"Configuration loaded from {config_path}")
            return config
        except yaml.YAMLError as e:
            raise ConfigurationError(f"Failed to parse YAML configuration: {str(e)}") from e
        except Exception as e:
            raise ConfigurationError(f"Failed to load configuration: {str(e)}") from e
    def _get_default_config(self) -> Dict[str, Any]:
        """Get the default configuration used when no config file is found."""
        return {
            "app": {
                "name": "Professional RAG Document Assistant",
                "version": "1.0.0",
                "debug": False,
                "max_upload_size": 50,
                "max_concurrent_uploads": 3
            },
            "models": {
                "embedding": {
                    "name": "sentence-transformers/all-MiniLM-L6-v2",
                    "max_seq_length": 256,
                    "batch_size": 32,
                    "device": "auto"
                },
                "reranker": {
                    "name": "cross-encoder/ms-marco-MiniLM-L-6-v2",
                    "max_seq_length": 512,
                    "batch_size": 16,
                    "enabled": True
                }
            },
            "processing": {
                "chunk_size": 512,
                "chunk_overlap": 50,
                "min_chunk_size": 100,
                "max_chunks_per_doc": 1000,
                "supported_formats": ["pdf", "docx", "txt"]
            },
            "search": {
                "default_k": 10,
                "max_k": 20,
                "vector_weight": 0.7,
                "bm25_weight": 0.3,
                "rerank_top_k": 50,
                "final_top_k": 10
            },
            "cache": {
                "embedding_cache_size": 10000,
                "query_cache_size": 1000,
                "cache_ttl": 3600,
                "enable_disk_cache": True,
                "cache_dir": "./cache"
            },
            "ui": {
                "theme": "soft",
                "title": "Professional RAG Assistant",
                "description": "Upload documents and ask questions with AI-powered retrieval",
                "max_file_size": "50MB",
                "allowed_extensions": [".pdf", ".docx", ".txt"],
                "show_progress": True,
                "show_analytics": True
            },
            "logging": {
                "level": "INFO",
                "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                "file": "logs/rag_system.log",
                "max_size": "10MB",
                "backup_count": 5
            }
        }
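
    # For reference, a minimal config.yaml mirroring a subset of the defaults
    # above might look like this (illustrative sketch; note that these
    # built-in defaults apply only when no config file is found at all, so a
    # real file should specify every key its components need):
    #
    #   models:
    #     embedding:
    #       name: sentence-transformers/all-MiniLM-L6-v2
    #       batch_size: 32
    #   processing:
    #     chunk_size: 512
    #     chunk_overlap: 50
    #   search:
    #     default_k: 10
    #     vector_weight: 0.7
    #     bm25_weight: 0.3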
    def is_ready(self) -> bool:
        """Check if the system is ready for operations."""
        return self.status.ready and self.status.initialized

    def warmup(self) -> Dict[str, Any]:
        """Warm up the system by loading models and initializing components."""
        try:
            self.logger.info("Starting system warmup...")
            start_time = time.time()

            # Warm up the embedding manager
            self.embedding_manager.warmup()

            # Warm up the re-ranker if enabled
            self.reranking_pipeline.warmup()

            # Update status
            self.status.models_loaded = True
            self.status.ready = True
            self.status.last_updated = time.time()

            warmup_time = time.time() - start_time
            self.logger.info(f"System warmup completed in {warmup_time:.2f}s")

            return create_success_response({
                "warmup_time": warmup_time,
                "models_loaded": True,
                "system_ready": True
            })
        except Exception as e:
            error_msg = self.error_handler.log_error(e, {"operation": "warmup"})
            self.status.error_message = error_msg
            return create_error_response(RAGError(error_msg))
    def add_document(
        self,
        file_path: str,
        filename: Optional[str] = None,
        user_session: Optional[str] = None,
        progress_callback: Optional[Callable[[str, float], None]] = None
    ) -> Dict[str, Any]:
        """
        Add a document to the RAG system.

        Args:
            file_path: Path to the document file
            filename: Optional original filename
            user_session: Optional user session ID
            progress_callback: Optional callback receiving (message, fraction) progress updates

        Returns:
            Response dictionary with operation results
        """
        start_time = time.time()
        chunk_count = 0

        try:
            with self._lock:
                filename = filename or Path(file_path).name

                if progress_callback:
                    progress_callback("Processing document...", 0.1)

                # Process document
                self.logger.info(f"Processing document: {filename}")
                chunks = self.document_processor.process_document(file_path, filename)
                chunk_count = len(chunks)

                if progress_callback:
                    progress_callback("Generating embeddings...", 0.3)

                # Generate embeddings
                texts = [chunk.content for chunk in chunks]
                embeddings = self.embedding_manager.generate_embeddings(texts)

                if progress_callback:
                    progress_callback("Adding to vector store...", 0.7)

                # Add to vector store
                chunk_ids = self.vector_store.add_documents(chunks, embeddings)

                if progress_callback:
                    progress_callback("Building search index...", 0.9)

                # Rebuild the BM25 index over every stored chunk so keyword
                # search stays consistent with the vector store
                all_stored_chunks = []
                for stored_chunk_id in self.vector_store._id_to_index.keys():
                    stored_data = self.vector_store.get_by_id(stored_chunk_id)
                    if stored_data:
                        _, stored_metadata = stored_data
                        stored_chunk = DocumentChunk(
                            content=stored_metadata.get("content", ""),
                            metadata=stored_metadata,
                            chunk_id=stored_chunk_id
                        )
                        all_stored_chunks.append(stored_chunk)
                self.search_engine.build_bm25_index(all_stored_chunks)

                # Update document index
                self._document_index[filename] = chunk_ids

                # Update system status
                self.status.documents_indexed = len(self._document_index)
                self.status.total_chunks = len(self.vector_store._vectors)
                self.status.last_updated = time.time()

                processing_time = time.time() - start_time

                if progress_callback:
                    progress_callback("Document processing completed!", 1.0)

                # Get document stats
                doc_stats = self.document_processor.get_document_stats(chunks)

                # Collect sample chunk data for logging (first 5 chunks)
                sample_chunks = []
                for i, chunk in enumerate(chunks[:5]):
                    sample_chunks.append({
                        "chunk_index": i,
                        "chunk_id": chunk.chunk_id,
                        "content": chunk.content,
                        "metadata": chunk.metadata,
                        "content_length": len(chunk.content)
                    })

                self.logger.info(
                    f"Document processed successfully: {filename} "
                    f"({chunk_count} chunks, {processing_time:.2f}s)"
                )

                # Log a preview of the first 3 chunks
                self.logger.info(f"Sample chunks from {filename}:")
                for i, chunk in enumerate(chunks[:3]):
                    chunk_preview = chunk.content[:150] + "..." if len(chunk.content) > 150 else chunk.content
                    self.logger.info(f"  Chunk {i} (ID: {chunk.chunk_id}): {chunk_preview}")
                    if chunk.metadata.get('page'):
                        self.logger.info(f"    - From page {chunk.metadata['page']}")

                # Track analytics
                file_stats = Path(file_path).stat()
                self.analytics_manager.track_document_processing(
                    filename=filename,
                    file_size=file_stats.st_size,
                    file_type=Path(filename).suffix.lower(),
                    processing_time=processing_time,
                    chunk_count=chunk_count,
                    success=True,
                    user_session=user_session
                )

                return create_success_response({
                    "filename": filename,
                    "chunks_created": chunk_count,
                    "processing_time": processing_time,
                    "document_stats": doc_stats,
                    "total_documents": self.status.documents_indexed,
                    "total_chunks": self.status.total_chunks,
                    "sample_chunks": sample_chunks  # For detailed logging downstream
                })
        except Exception as e:
            error_message = self.error_handler.log_error(e, {
                "operation": "add_document",
                "filename": filename,
                "file_path": file_path
            })
            processing_time = time.time() - start_time

            # Track the failed processing; analytics errors must not mask the
            # original failure
            try:
                file_stats = Path(file_path).stat()
                self.analytics_manager.track_document_processing(
                    filename=filename or "unknown",
                    file_size=file_stats.st_size,
                    file_type=Path(filename or file_path).suffix.lower(),
                    processing_time=processing_time,
                    chunk_count=0,
                    success=False,
                    error_message=str(e),
                    user_session=user_session
                )
            except Exception:
                pass

            return create_error_response(RAGError(error_message))
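
    # Illustrative call (hypothetical file path; the returned dict is whatever
    # create_success_response builds from the payload above):
    #
    #   result = rag.add_document(
    #       "uploads/report.pdf",
    #       progress_callback=lambda msg, frac: print(f"[{frac:.0%}] {msg}"),
    #   )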
    def search(
        self,
        query: str,
        k: Optional[int] = None,
        search_mode: str = "hybrid",
        enable_reranking: bool = True,
        metadata_filter: Optional[Dict[str, Any]] = None,
        user_session: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Search the document collection.

        Args:
            query: Search query
            k: Number of results to return
            search_mode: Search mode ("vector", "bm25", "hybrid")
            enable_reranking: Whether to apply re-ranking
            metadata_filter: Optional metadata filter
            user_session: Optional user session ID

        Returns:
            Response dictionary with search results
        """
        start_time = time.time()

        try:
            # is_ready is a method: it must be called here, since testing the
            # bound method object itself would always be truthy
            if not self.is_ready():
                raise SearchError("System not ready. Please run warmup first.")

            if not query or not query.strip():
                raise SearchError("Query cannot be empty")

            query = query.strip()
            k = k or self.config.get("search", {}).get("default_k", 10)

            self.logger.info(f"Searching: '{query}' (mode: {search_mode}, k: {k})")

            # Perform search, over-fetching to give the re-ranker candidates
            search_results = self.search_engine.search(
                query=query,
                k=k * 2,
                search_mode=search_mode,
                metadata_filter=metadata_filter
            )

            # Apply re-ranking
            final_results = self.reranking_pipeline.process(
                query=query,
                results=search_results,
                apply_reranking=enable_reranking
            )

            # Limit to the requested number of results
            final_results = final_results[:k]
            search_time = time.time() - start_time

            # Convert results to a serializable format
            results_data = [result.to_dict() for result in final_results]

            # Get query suggestions if results are available
            suggestions = []
            if final_results:
                suggestions = self.search_engine.suggest_query_expansion(query, final_results[:3])

            self.logger.info(f"Search completed: {len(final_results)} results in {search_time:.2f}s")

            # Track analytics
            self.analytics_manager.track_query(
                query=query,
                search_mode=search_mode,
                results_count=len(final_results),
                search_time=search_time,
                user_session=user_session,
                metadata_filters=metadata_filter
            )

            return create_success_response({
                "query": query,
                "results": results_data,
                "total_results": len(final_results),
                "search_time": search_time,
                "search_mode": search_mode,
                "reranking_applied": enable_reranking,
                "query_suggestions": suggestions
            })
        except Exception as e:
            error_message = self.error_handler.log_error(e, {
                "operation": "search",
                "query": query,
                "search_mode": search_mode,
                "k": k
            })
            return create_error_response(RAGError(error_message))
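
    # Illustrative call (hypothetical query and filter; the "file_type" key is
    # assumed to match the chunk metadata written at indexing time):
    #
    #   response = rag.search(
    #       "quarterly revenue",
    #       k=5,
    #       search_mode="hybrid",
    #       metadata_filter={"file_type": ".pdf"},
    #   )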
    def get_document_list(self) -> Dict[str, Any]:
        """Get the list of indexed documents."""
        try:
            with self._lock:
                documents = []
                for filename, chunk_ids in self._document_index.items():
                    if chunk_ids:
                        # Get metadata from the first chunk
                        first_chunk_data = self.vector_store.get_by_id(chunk_ids[0])
                        if first_chunk_data:
                            _, metadata = first_chunk_data
                            documents.append({
                                "filename": filename,
                                "chunk_count": len(chunk_ids),
                                "file_type": metadata.get("file_type", "unknown"),
                                "file_size": metadata.get("file_size", 0),
                                "source": metadata.get("source", ""),
                                "indexed_at": metadata.get("timestamp")
                            })

                return create_success_response({
                    "documents": documents,
                    "total_documents": len(documents),
                    "total_chunks": self.status.total_chunks
                })
        except Exception as e:
            error_message = self.error_handler.log_error(e, {"operation": "get_document_list"})
            return create_error_response(RAGError(error_message))

    def remove_document(self, filename: str) -> Dict[str, Any]:
        """Remove a document from the index."""
        try:
            with self._lock:
                if filename not in self._document_index:
                    raise DocumentProcessingError(f"Document not found: {filename}")

                chunk_ids = self._document_index[filename]

                # Remove chunks from the vector store
                removed_count = 0
                for chunk_id in chunk_ids:
                    if self.vector_store.delete_by_id(chunk_id):
                        removed_count += 1

                # Remove from the document index
                del self._document_index[filename]

                # Rebuild the BM25 index from the remaining chunks
                all_chunks = []
                for remaining_chunk_id in self.vector_store._id_to_index.keys():
                    chunk_data = self.vector_store.get_by_id(remaining_chunk_id)
                    if chunk_data:
                        _, metadata = chunk_data
                        chunk = DocumentChunk(
                            content=metadata.get("content", ""),
                            metadata=metadata,
                            chunk_id=remaining_chunk_id
                        )
                        all_chunks.append(chunk)
                self.search_engine.build_bm25_index(all_chunks)

                # Update status
                self.status.documents_indexed = len(self._document_index)
                self.status.total_chunks = len(self.vector_store._vectors)
                self.status.last_updated = time.time()

                self.logger.info(f"Document removed: {filename} ({removed_count} chunks)")

                return create_success_response({
                    "filename": filename,
                    "chunks_removed": removed_count,
                    "total_documents": self.status.documents_indexed,
                    "total_chunks": self.status.total_chunks
                })
        except Exception as e:
            error_message = self.error_handler.log_error(e, {
                "operation": "remove_document",
                "filename": filename
            })
            return create_error_response(RAGError(error_message))

    def clear_all_documents(self) -> Dict[str, Any]:
        """Clear all documents from the index."""
        try:
            with self._lock:
                # Clear the vector store
                self.vector_store.clear()

                # Clear the search index
                self.search_engine.build_bm25_index([])

                # Clear the document index
                total_docs = len(self._document_index)
                self._document_index.clear()

                # Update status
                self.status.documents_indexed = 0
                self.status.total_chunks = 0
                self.status.last_updated = time.time()

                self.logger.info(f"All documents cleared ({total_docs} documents)")

                return create_success_response({
                    "documents_removed": total_docs,
                    "total_documents": 0,
                    "total_chunks": 0
                })
        except Exception as e:
            error_message = self.error_handler.log_error(e, {"operation": "clear_all_documents"})
            return create_error_response(RAGError(error_message))

    def get_system_stats(self) -> Dict[str, Any]:
        """Get comprehensive system statistics."""
        try:
            stats = {
                "status": {
                    "initialized": self.status.initialized,
                    "ready": self.status.ready,
                    "models_loaded": self.status.models_loaded,
                    "documents_indexed": self.status.documents_indexed,
                    "total_chunks": self.status.total_chunks,
                    "last_updated": self.status.last_updated,
                    "error_message": self.status.error_message
                },
                "embedding_manager": self.embedding_manager.get_stats(),
                "vector_store": self.vector_store.get_stats(),
                "search_engine": self.search_engine.get_stats(),
                "reranking_pipeline": self.reranking_pipeline.get_stats(),
                "cache_manager": self.cache_manager.get_stats(),
                "analytics": self.analytics_manager.get_system_analytics()
            }
            return create_success_response(stats)
        except Exception as e:
            error_message = self.error_handler.log_error(e, {"operation": "get_system_stats"})
            return create_error_response(RAGError(error_message))

    def get_analytics_dashboard(self) -> Dict[str, Any]:
        """Get analytics dashboard data."""
        try:
            dashboard_data = self.analytics_manager.get_dashboard_data()
            return create_success_response(dashboard_data)
        except Exception as e:
            error_message = self.error_handler.log_error(e, {"operation": "get_analytics_dashboard"})
            return create_error_response(RAGError(error_message))

    def optimize_system(self) -> Dict[str, Any]:
        """Optimize system performance."""
        try:
            self.logger.info("Starting system optimization...")
            start_time = time.time()
            optimization_results = {}

            # Optimize the cache
            optimization_results["cache"] = self.cache_manager.optimize()

            # Optimize the vector store
            optimization_results["vector_store"] = self.vector_store.optimize()

            # Optimize the search engine
            optimization_results["search_engine"] = self.search_engine.optimize_index()

            optimization_time = time.time() - start_time
            self.logger.info(f"System optimization completed in {optimization_time:.2f}s")

            return create_success_response({
                "optimization_time": optimization_time,
                "components_optimized": optimization_results
            })
        except Exception as e:
            error_message = self.error_handler.log_error(e, {"operation": "optimize_system"})
            return create_error_response(RAGError(error_message))

    def save_state(self, filepath: Optional[str] = None) -> Dict[str, Any]:
        """Save system state to disk."""
        try:
            saved_files = []

            # Save the vector store
            vector_store_path = self.vector_store.save_to_disk(filepath)
            saved_files.append(vector_store_path)

            # Export analytics
            analytics_path = self.analytics_manager.export_data()
            saved_files.append(analytics_path)

            self.logger.info(f"System state saved to {len(saved_files)} files")

            return create_success_response({
                "saved_files": saved_files,
                "total_files": len(saved_files)
            })
        except Exception as e:
            error_message = self.error_handler.log_error(e, {"operation": "save_state"})
            return create_error_response(RAGError(error_message))

    def shutdown(self) -> None:
        """Shut down the RAG system gracefully."""
        try:
            self.logger.info("Shutting down RAG system...")

            # Save analytics data
            self.analytics_manager.shutdown()

            # Unload models to free memory
            self.embedding_manager.unload_model()
            self.reranking_pipeline.unload_models()

            # Clear status
            self.status.ready = False
            self.status.models_loaded = False

            self.logger.info("RAG system shutdown completed")
        except Exception as e:
            self.logger.error(f"Error during shutdown: {e}")

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit; shuts the system down."""
        self.shutdown()
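

# Minimal end-to-end usage sketch (illustrative only: the config path and
# document path below are hypothetical, and warmup should complete before
# searching):
if __name__ == "__main__":
    with RAGSystem(config_path="config.yaml") as rag:
        rag.warmup()
        rag.add_document("sample.pdf")
        response = rag.search("What does the document say about pricing?")
        print(response)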