# RAG/src/rag_system.py
# Jialun He
# add log
# 9fb62ac
"""
Main RAG system orchestrator that coordinates all components.
"""
import os
import time
import yaml
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import threading
from dataclasses import dataclass
from .error_handler import (
ErrorHandler, RAGError, DocumentProcessingError,
SearchError, ConfigurationError, validate_config,
create_success_response, create_error_response
)
from .document_processor import DocumentProcessor, DocumentChunk
from .embedding_manager import EmbeddingManager
from .vector_store import VectorStore
from .search_engine import HybridSearchEngine, SearchResult
from .reranker import RerankingPipeline
from .cache_manager import CacheManager
from .analytics import AnalyticsManager
@dataclass
class RAGSystemStatus:
    """Represents the current status of the RAG system.

    A plain mutable record; ``RAGSystem`` updates the fields in place as it
    is constructed, warmed up, modified and shut down.
    """

    # Set to True at the end of RAGSystem.__init__.
    initialized: bool = False
    # Set to True by RAGSystem.warmup(); reset by RAGSystem.shutdown().
    ready: bool = False
    # Set to True by RAGSystem.warmup(); reset by RAGSystem.shutdown().
    models_loaded: bool = False
    # Number of filenames currently tracked in the document index.
    documents_indexed: int = 0
    # Number of vectors currently held by the vector store.
    total_chunks: int = 0
    # time.time() value of the most recent status change, if any.
    last_updated: Optional[float] = None
    # Message recorded when warmup fails, otherwise None.
    error_message: Optional[str] = None
class RAGSystem:
    """Main RAG system that orchestrates all components.

    Wires together the document processor, embedding manager, vector store,
    hybrid search engine, re-ranking pipeline, cache manager and analytics
    manager. Public operations return the dictionaries produced by
    ``create_success_response`` / ``create_error_response`` rather than
    raising.
    """

    def __init__(self, config_path: Optional[str] = None, config_dict: Optional[Dict[str, Any]] = None):
        """
        Initialize the RAG system.

        Configuration is taken from ``config_dict`` if given, else from
        ``config_path``, else from the first existing file among
        ``config.yaml`` / ``config-local.yaml``, else from the built-in
        defaults. If loading or validation fails, the error is logged and the
        built-in defaults are used instead.

        Args:
            config_path: Path to YAML configuration file
            config_dict: Dictionary configuration (overrides config_path)
        """
        # The real logger only exists once ErrorHandler has been built, so
        # anything that logs before then must cope with None.
        self.logger = None
        try:
            if config_dict:
                self.config = config_dict
            elif config_path:
                self.config = self._load_config(config_path)
            else:
                # Try default config paths
                for default_path in ["config.yaml", "config-local.yaml"]:
                    if Path(default_path).exists():
                        self.config = self._load_config(default_path)
                        break
                else:
                    # Use default configuration if no config file found
                    self.config = self._get_default_config()

            validate_config(self.config)

            self.error_handler = ErrorHandler(self.config)
            self.logger = self.error_handler.logger
        except Exception as e:
            # If config loading fails, report through basic logging and
            # continue with the default configuration.
            import logging
            logging.basicConfig(level=logging.INFO)
            self.logger = logging.getLogger(__name__)
            self.logger.error(f"Failed to load configuration: {e}")
            self.config = self._get_default_config()
            self.error_handler = ErrorHandler(self.config)
            self.logger = self.error_handler.logger

        # Initialize components
        self.cache_manager = CacheManager(self.config)
        self.document_processor = DocumentProcessor(self.config)
        self.embedding_manager = EmbeddingManager(self.config, self.cache_manager)
        self.vector_store = VectorStore(self.config)
        self.search_engine = HybridSearchEngine(self.config, self.vector_store)
        self.reranking_pipeline = RerankingPipeline(self.config)
        self.analytics_manager = AnalyticsManager(self.config)

        # System state
        self.status = RAGSystemStatus()
        self._lock = threading.RLock()
        self._document_index: Dict[str, List[str]] = {}  # filename -> chunk_ids

        # Connect components
        self.search_engine.set_embedding_manager(self.embedding_manager)

        self.logger.info("RAG system initialized successfully")
        self.status.initialized = True

    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from YAML file.

        Args:
            config_path: Path to the YAML file.

        Returns:
            The parsed configuration mapping.

        Raises:
            ConfigurationError: If the file is missing, cannot be read or
                parsed, or does not contain a mapping at the top level.
        """
        config_path = Path(config_path)
        if not config_path.exists():
            raise ConfigurationError(f"Configuration file not found: {config_path}")

        try:
            with open(config_path, 'r', encoding="utf-8") as f:
                config = yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise ConfigurationError(f"Failed to parse YAML configuration: {str(e)}") from e
        except Exception as e:
            raise ConfigurationError(f"Failed to load configuration: {str(e)}") from e

        # An empty file parses to None and a scalar/list file to a non-dict;
        # neither is usable as a configuration.
        if not isinstance(config, dict):
            raise ConfigurationError(
                f"Configuration file must contain a YAML mapping: {config_path}"
            )

        # This runs from __init__ before self.logger is assigned. Logging
        # through None used to raise here, which made every file-based
        # configuration fall back to the defaults.
        logger = getattr(self, "logger", None)
        if logger is None:
            import logging
            logger = logging.getLogger(__name__)
        logger.info(f"Configuration loaded from {config_path}")
        return config

    def _get_default_config(self) -> Dict[str, Any]:
        """Get default configuration when no config file is found.

        Returns:
            A freshly built dictionary on every call, so callers may mutate it.
        """
        return {
            "app": {
                "name": "Professional RAG Document Assistant",
                "version": "1.0.0",
                "debug": False,
                "max_upload_size": 50,
                "max_concurrent_uploads": 3,
            },
            "models": {
                "embedding": {
                    "name": "sentence-transformers/all-MiniLM-L6-v2",
                    "max_seq_length": 256,
                    "batch_size": 32,
                    "device": "auto",
                },
                "reranker": {
                    "name": "cross-encoder/ms-marco-MiniLM-L-6-v2",
                    "max_seq_length": 512,
                    "batch_size": 16,
                    "enabled": True,
                },
            },
            "processing": {
                "chunk_size": 512,
                "chunk_overlap": 50,
                "min_chunk_size": 100,
                "max_chunks_per_doc": 1000,
                "supported_formats": ["pdf", "docx", "txt"],
            },
            "search": {
                "default_k": 10,
                "max_k": 20,
                "vector_weight": 0.7,
                "bm25_weight": 0.3,
                "rerank_top_k": 50,
                "final_top_k": 10,
            },
            "cache": {
                "embedding_cache_size": 10000,
                "query_cache_size": 1000,
                "cache_ttl": 3600,
                "enable_disk_cache": True,
                "cache_dir": "./cache",
            },
            "ui": {
                "theme": "soft",
                "title": "Professional RAG Assistant",
                "description": "Upload documents and ask questions with AI-powered retrieval",
                "max_file_size": "50MB",
                "allowed_extensions": [".pdf", ".docx", ".txt"],
                "show_progress": True,
                "show_analytics": True,
            },
            "logging": {
                "level": "INFO",
                "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                "file": "logs/rag_system.log",
                "max_size": "10MB",
                "backup_count": 5,
            },
        }

    @property
    def is_ready(self) -> bool:
        """Check if system is ready for operations."""
        return self.status.ready and self.status.initialized

    def _rebuild_bm25_index(self) -> None:
        """Rebuild the BM25 index from every chunk held by the vector store."""
        stored_chunks = []
        # list() so the id mapping is not iterated live while being queried.
        for stored_id in list(self.vector_store._id_to_index.keys()):
            stored = self.vector_store.get_by_id(stored_id)
            if stored:
                _, stored_metadata = stored
                stored_chunks.append(
                    DocumentChunk(
                        content=stored_metadata.get("content", ""),
                        metadata=stored_metadata,
                        chunk_id=stored_id,
                    )
                )
        self.search_engine.build_bm25_index(stored_chunks)

    def _refresh_counts(self) -> None:
        """Refresh document/chunk counters and the last-updated timestamp."""
        self.status.documents_indexed = len(self._document_index)
        self.status.total_chunks = len(self.vector_store._vectors)
        self.status.last_updated = time.time()

    def _track_document_processing(self, file_path: str, **fields: Any) -> None:
        """Record a document-processing event without ever raising.

        Analytics is best-effort: a failure (for example the file no longer
        existing when its size is read) is logged as a warning and ignored.
        """
        try:
            self.analytics_manager.track_document_processing(
                file_size=Path(file_path).stat().st_size,
                **fields
            )
        except Exception as tracking_error:
            self.logger.warning(f"Analytics tracking failed: {tracking_error}")

    def warmup(self) -> Dict[str, Any]:
        """Warm up the system by loading models and initializing components.

        Returns:
            Success response with ``warmup_time``, ``models_loaded`` and
            ``system_ready``; an error response if any warmup step raises.
        """
        try:
            self.logger.info("Starting system warmup...")
            start_time = time.time()

            self.embedding_manager.warmup()
            # Called unconditionally; presumably the pipeline itself decides
            # whether the re-ranker is enabled -- TODO confirm.
            self.reranking_pipeline.warmup()

            self.status.models_loaded = True
            self.status.ready = True
            # A successful warmup supersedes any error from an earlier attempt.
            self.status.error_message = None
            self.status.last_updated = time.time()

            warmup_time = time.time() - start_time
            self.logger.info(f"System warmup completed in {warmup_time:.2f}s")
            return create_success_response({
                "warmup_time": warmup_time,
                "models_loaded": True,
                "system_ready": True
            })
        except Exception as e:
            error_msg = self.error_handler.log_error(e, {"operation": "warmup"})
            self.status.error_message = error_msg
            return create_error_response(RAGError(error_msg))

    def add_document(
        self,
        file_path: str,
        filename: Optional[str] = None,
        user_session: Optional[str] = None,
        progress_callback: Optional[callable] = None
    ) -> Dict[str, Any]:
        """
        Add a document to the RAG system.

        The document is chunked, embedded, stored in the vector store and the
        BM25 index is rebuilt over all stored chunks. Adding a filename that
        is already indexed replaces the earlier version.

        Args:
            file_path: Path to the document file
            filename: Optional original filename (defaults to the path's name)
            user_session: Optional user session ID
            progress_callback: Optional callback for progress updates, called
                as ``progress_callback(message, fraction)``

        Returns:
            Response dictionary with operation results
        """
        start_time = time.time()

        def report(message: str, fraction: float) -> None:
            if progress_callback:
                progress_callback(message, fraction)

        try:
            with self._lock:
                filename = filename or Path(file_path).name
                report("Processing document...", 0.1)

                self.logger.info(f"Processing document: {filename}")
                chunks = self.document_processor.process_document(file_path, filename)
                chunk_count = len(chunks)
                report("Generating embeddings...", 0.3)

                texts = [chunk.content for chunk in chunks]
                embeddings = self.embedding_manager.generate_embeddings(texts)
                report("Adding to vector store...", 0.7)

                # Drop the chunks of a previously indexed version only once
                # the new one has been processed; otherwise they would stay in
                # the store with no index entry left to remove them through.
                stale_ids = self._document_index.pop(filename, None)
                if stale_ids:
                    for stale_id in stale_ids:
                        self.vector_store.delete_by_id(stale_id)
                    self.logger.info(
                        f"Replacing previously indexed document: {filename} "
                        f"({len(stale_ids)} chunks)"
                    )

                chunk_ids = self.vector_store.add_documents(chunks, embeddings)
                report("Building search index...", 0.9)

                # Rebuild BM25 index with all documents
                self._rebuild_bm25_index()

                self._document_index[filename] = chunk_ids
                self._refresh_counts()

                processing_time = time.time() - start_time
                report("Document processing completed!", 1.0)

                doc_stats = self.document_processor.get_document_stats(chunks)

                # First 5 chunks are returned as samples for detailed logging.
                sample_chunks = [
                    {
                        "chunk_index": i,
                        "chunk_id": chunk.chunk_id,
                        "content": chunk.content,
                        "metadata": chunk.metadata,
                        "content_length": len(chunk.content)
                    }
                    for i, chunk in enumerate(chunks[:5])
                ]

                self.logger.info(
                    f"Document processed successfully: {filename} "
                    f"({chunk_count} chunks, {processing_time:.2f}s)"
                )

                # Log a preview of the first 3 chunks
                self.logger.info(f"Sample chunks from {filename}:")
                for i, chunk in enumerate(chunks[:3]):
                    chunk_preview = chunk.content[:150] + "..." if len(chunk.content) > 150 else chunk.content
                    self.logger.info(f" Chunk {i} (ID: {chunk.chunk_id}): {chunk_preview}")
                    if chunk.metadata.get('page'):
                        self.logger.info(f" - From page {chunk.metadata['page']}")

                # Best-effort, matching the failure path: an analytics problem
                # must not turn an indexed document into a reported failure.
                self._track_document_processing(
                    file_path,
                    filename=filename,
                    file_type=Path(filename).suffix.lower(),
                    processing_time=processing_time,
                    chunk_count=chunk_count,
                    success=True,
                    user_session=user_session
                )

                return create_success_response({
                    "filename": filename,
                    "chunks_created": chunk_count,
                    "processing_time": processing_time,
                    "document_stats": doc_stats,
                    "total_documents": self.status.documents_indexed,
                    "total_chunks": self.status.total_chunks,
                    "sample_chunks": sample_chunks  # Include sample chunks for detailed logging
                })
        except Exception as e:
            error_message = self.error_handler.log_error(e, {
                "operation": "add_document",
                "filename": filename,
                "file_path": file_path
            })
            processing_time = time.time() - start_time

            # Track failed processing
            self._track_document_processing(
                file_path,
                filename=filename or "unknown",
                file_type=Path(filename or file_path).suffix.lower(),
                processing_time=processing_time,
                chunk_count=0,
                success=False,
                error_message=str(e),
                user_session=user_session
            )
            return create_error_response(RAGError(error_message))

    def search(
        self,
        query: str,
        k: Optional[int] = None,
        search_mode: str = "hybrid",
        enable_reranking: bool = True,
        metadata_filter: Optional[Dict[str, Any]] = None,
        user_session: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Search the document collection.

        Args:
            query: Search query
            k: Number of results to return; falls back to the configured
                ``search.default_k`` (10 if unset) when omitted or 0
            search_mode: Search mode ("vector", "bm25", "hybrid")
            enable_reranking: Whether to apply re-ranking
            metadata_filter: Optional metadata filter
            user_session: Optional user session ID

        Returns:
            Response dictionary with search results, or an error response if
            the system is not ready, the query is empty, ``k`` is negative or
            any step raises.
        """
        start_time = time.time()
        try:
            if not self.is_ready:
                raise SearchError("System not ready. Please run warmup first.")
            if not query or not query.strip():
                raise SearchError("Query cannot be empty")

            query = query.strip()
            k = k or self.config.get("search", {}).get("default_k", 10)
            # A negative k would make the final slice drop trailing results
            # instead of limiting the result count.
            if k < 1:
                raise SearchError(f"k must be a positive integer, got {k}")

            self.logger.info(f"Searching: '{query}' (mode: {search_mode}, k: {k})")

            search_results = self.search_engine.search(
                query=query,
                k=k * 2,  # Get more results for re-ranking
                search_mode=search_mode,
                metadata_filter=metadata_filter
            )

            final_results = self.reranking_pipeline.process(
                query=query,
                results=search_results,
                apply_reranking=enable_reranking
            )

            # Limit to requested number of results
            final_results = final_results[:k]
            search_time = time.time() - start_time

            # Convert results to serializable format
            results_data = [result.to_dict() for result in final_results]

            # Get query suggestions if results are available
            suggestions = []
            if final_results:
                suggestions = self.search_engine.suggest_query_expansion(query, final_results[:3])

            self.logger.info(f"Search completed: {len(final_results)} results in {search_time:.2f}s")

            self.analytics_manager.track_query(
                query=query,
                search_mode=search_mode,
                results_count=len(final_results),
                search_time=search_time,
                user_session=user_session,
                metadata_filters=metadata_filter
            )

            return create_success_response({
                "query": query,
                "results": results_data,
                "total_results": len(final_results),
                "search_time": search_time,
                "search_mode": search_mode,
                "reranking_applied": enable_reranking,
                "query_suggestions": suggestions
            })
        except Exception as e:
            error_message = self.error_handler.log_error(e, {
                "operation": "search",
                "query": query,
                "search_mode": search_mode,
                "k": k
            })
            return create_error_response(RAGError(error_message))

    def get_document_list(self) -> Dict[str, Any]:
        """Get list of indexed documents.

        Per-document details are read from the metadata of the document's
        first chunk; documents whose first chunk cannot be fetched are
        omitted.
        """
        try:
            with self._lock:
                documents = []
                for filename, chunk_ids in self._document_index.items():
                    if not chunk_ids:
                        continue
                    first_chunk_data = self.vector_store.get_by_id(chunk_ids[0])
                    if not first_chunk_data:
                        continue
                    _, metadata = first_chunk_data
                    documents.append({
                        "filename": filename,
                        "chunk_count": len(chunk_ids),
                        "file_type": metadata.get("file_type", "unknown"),
                        "file_size": metadata.get("file_size", 0),
                        "source": metadata.get("source", ""),
                        "indexed_at": metadata.get("timestamp")
                    })

                return create_success_response({
                    "documents": documents,
                    "total_documents": len(documents),
                    "total_chunks": self.status.total_chunks
                })
        except Exception as e:
            error_message = self.error_handler.log_error(e, {"operation": "get_document_list"})
            return create_error_response(RAGError(error_message))

    def remove_document(self, filename: str) -> Dict[str, Any]:
        """Remove a document from the index.

        Args:
            filename: Name under which the document was indexed.

        Returns:
            Success response with the number of chunks removed and the new
            totals; an error response if the filename is not indexed.
        """
        try:
            with self._lock:
                if filename not in self._document_index:
                    raise DocumentProcessingError(f"Document not found: {filename}")

                chunk_ids = self._document_index[filename]

                # Count only the chunks the vector store reports as deleted.
                removed_count = 0
                for chunk_id in chunk_ids:
                    if self.vector_store.delete_by_id(chunk_id):
                        removed_count += 1

                del self._document_index[filename]

                self._rebuild_bm25_index()
                self._refresh_counts()

                self.logger.info(f"Document removed: {filename} ({removed_count} chunks)")
                return create_success_response({
                    "filename": filename,
                    "chunks_removed": removed_count,
                    "total_documents": self.status.documents_indexed,
                    "total_chunks": self.status.total_chunks
                })
        except Exception as e:
            error_message = self.error_handler.log_error(e, {
                "operation": "remove_document",
                "filename": filename
            })
            return create_error_response(RAGError(error_message))

    def clear_all_documents(self) -> Dict[str, Any]:
        """Clear all documents from the index."""
        try:
            with self._lock:
                self.vector_store.clear()
                # An empty chunk list resets the BM25 index.
                self.search_engine.build_bm25_index([])

                total_docs = len(self._document_index)
                self._document_index.clear()

                self.status.documents_indexed = 0
                self.status.total_chunks = 0
                self.status.last_updated = time.time()

                self.logger.info(f"All documents cleared ({total_docs} documents)")
                return create_success_response({
                    "documents_removed": total_docs,
                    "total_documents": 0,
                    "total_chunks": 0
                })
        except Exception as e:
            error_message = self.error_handler.log_error(e, {"operation": "clear_all_documents"})
            return create_error_response(RAGError(error_message))

    def get_system_stats(self) -> Dict[str, Any]:
        """Get comprehensive system statistics.

        Combines the current status fields with the statistics reported by
        each component.
        """
        try:
            stats = {
                "status": {
                    "initialized": self.status.initialized,
                    "ready": self.status.ready,
                    "models_loaded": self.status.models_loaded,
                    "documents_indexed": self.status.documents_indexed,
                    "total_chunks": self.status.total_chunks,
                    "last_updated": self.status.last_updated,
                    "error_message": self.status.error_message
                },
                "embedding_manager": self.embedding_manager.get_stats(),
                "vector_store": self.vector_store.get_stats(),
                "search_engine": self.search_engine.get_stats(),
                "reranking_pipeline": self.reranking_pipeline.get_stats(),
                "cache_manager": self.cache_manager.get_stats(),
                "analytics": self.analytics_manager.get_system_analytics()
            }
            return create_success_response(stats)
        except Exception as e:
            error_message = self.error_handler.log_error(e, {"operation": "get_system_stats"})
            return create_error_response(RAGError(error_message))

    def get_analytics_dashboard(self) -> Dict[str, Any]:
        """Get analytics dashboard data."""
        try:
            dashboard_data = self.analytics_manager.get_dashboard_data()
            return create_success_response(dashboard_data)
        except Exception as e:
            error_message = self.error_handler.log_error(e, {"operation": "get_analytics_dashboard"})
            return create_error_response(RAGError(error_message))

    def optimize_system(self) -> Dict[str, Any]:
        """Optimize system performance.

        Runs the cache, vector store and search engine optimizers in that
        order and returns their individual results with the total time taken.
        """
        try:
            self.logger.info("Starting system optimization...")
            start_time = time.time()

            optimization_results = {
                "cache": self.cache_manager.optimize(),
                "vector_store": self.vector_store.optimize(),
                "search_engine": self.search_engine.optimize_index(),
            }

            optimization_time = time.time() - start_time
            self.logger.info(f"System optimization completed in {optimization_time:.2f}s")
            return create_success_response({
                "optimization_time": optimization_time,
                "components_optimized": optimization_results
            })
        except Exception as e:
            error_message = self.error_handler.log_error(e, {"operation": "optimize_system"})
            return create_error_response(RAGError(error_message))

    def save_state(self, filepath: Optional[str] = None) -> Dict[str, Any]:
        """Save system state to disk.

        Args:
            filepath: Passed through to the vector store's ``save_to_disk``.

        Returns:
            Success response listing the vector store file and the analytics
            export file.
        """
        try:
            saved_files = [
                self.vector_store.save_to_disk(filepath),
                self.analytics_manager.export_data(),
            ]
            self.logger.info(f"System state saved to {len(saved_files)} files")
            return create_success_response({
                "saved_files": saved_files,
                "total_files": len(saved_files)
            })
        except Exception as e:
            error_message = self.error_handler.log_error(e, {"operation": "save_state"})
            return create_error_response(RAGError(error_message))

    def shutdown(self) -> None:
        """Shutdown the RAG system gracefully.

        Every cleanup step is attempted even if an earlier one fails, and the
        system is always marked not ready afterwards. ``models_loaded`` is
        cleared only when both model unloads succeeded.
        """
        self.logger.info("Shutting down RAG system...")

        # Lambdas defer the attribute lookups so a missing component is
        # handled like any other step failure.
        steps = (
            ("analytics", lambda: self.analytics_manager.shutdown()),
            ("embedding_model", lambda: self.embedding_manager.unload_model()),
            ("reranking_models", lambda: self.reranking_pipeline.unload_models()),
        )
        failed = set()
        for label, step in steps:
            try:
                step()
            except Exception as e:
                failed.add(label)
                self.logger.error(f"Error during shutdown: {e}")

        self.status.ready = False
        if not failed & {"embedding_model", "reranking_models"}:
            self.status.models_loaded = False

        if failed:
            self.logger.warning("RAG system shutdown completed with errors")
        else:
            self.logger.info("RAG system shutdown completed")

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.shutdown()