""" | |
Knowledge Database Cache System | |
============================== | |
Persistent storage for processed documents, chunks, and embeddings to avoid | |
reprocessing on system restart. | |
""" | |
import logging | |
import pickle | |
import json | |
import hashlib | |
import time | |
from pathlib import Path | |
from typing import Dict, Any, List, Optional, Tuple | |
import numpy as np | |
from dataclasses import asdict | |
logger = logging.getLogger(__name__) | |
class KnowledgeCache:
    """Persistent cache for processed documents and embeddings"""

    def __init__(self, cache_dir: Path = Path("cache")):
        """
        Initialize knowledge cache

        Args:
            cache_dir: Directory to store cache files
        """
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(exist_ok=True)

        # Cache file paths
        self.metadata_file = self.cache_dir / "metadata.json"
        self.documents_file = self.cache_dir / "documents.pkl"
        self.embeddings_file = self.cache_dir / "embeddings.npy"
        self.index_file = self.cache_dir / "faiss_index.bin"

        # In-memory cache
        self.metadata = self._load_metadata()
        self.documents = None
        self.embeddings = None
    def _load_metadata(self) -> Dict[str, Any]:
        """Load cache metadata"""
        try:
            if self.metadata_file.exists():
                with open(self.metadata_file, 'r') as f:
                    return json.load(f)
            return self._create_empty_metadata()
        except Exception as e:
            logger.error(f"Error loading metadata: {e}")
            return self._create_empty_metadata()

    def _create_empty_metadata(self) -> Dict[str, Any]:
        """Create empty metadata structure"""
        return {
            "version": "1.0",
            "created": time.time(),
            "last_updated": time.time(),
            "document_count": 0,
            "chunk_count": 0,
            "file_hashes": {},
            "embedder_config": None
        }
    def _save_metadata(self):
        """Save metadata to file"""
        try:
            self.metadata["last_updated"] = time.time()
            with open(self.metadata_file, 'w') as f:
                json.dump(self.metadata, f, indent=2)
        except Exception as e:
            logger.error(f"Error saving metadata: {e}")

    def _get_file_hash(self, file_path: Path) -> str:
        """Get hash of file for change detection"""
        try:
            with open(file_path, 'rb') as f:
                content = f.read()
            return hashlib.md5(content).hexdigest()
        except Exception as e:
            logger.error(f"Error hashing file {file_path}: {e}")
            return ""

    def _get_corpus_hash(self, pdf_files: List[Path]) -> str:
        """Get combined hash of all files in corpus"""
        file_hashes = []
        for pdf_file in sorted(pdf_files):
            file_hash = self._get_file_hash(pdf_file)
            file_hashes.append(f"{pdf_file.name}:{file_hash}")
        combined = "|".join(file_hashes)
        return hashlib.md5(combined.encode()).hexdigest()
    def is_cache_valid(self, pdf_files: List[Path], embedder_config: Dict[str, Any]) -> bool:
        """
        Check if cache is valid for given files and embedder config

        Args:
            pdf_files: List of PDF files in corpus
            embedder_config: Current embedder configuration

        Returns:
            True if cache is valid and can be used
        """
        try:
            # Check if cache files exist
            if not all(f.exists() for f in [self.documents_file, self.embeddings_file]):
                logger.info("Cache files missing, cache invalid")
                return False

            # Check if metadata exists
            if not self.metadata or self.metadata.get("document_count", 0) == 0:
                logger.info("No metadata or empty cache, cache invalid")
                return False

            # Check embedder configuration hash
            current_config_hash = create_embedder_config_hash(embedder_config)
            cached_config_hash = self.metadata.get("embedder_config_hash")
            if current_config_hash != cached_config_hash:
                logger.info("Embedder configuration changed, cache invalid")
                return False

            # Check file count
            if len(pdf_files) != self.metadata.get("document_count", 0):
                logger.info(f"Document count changed: {len(pdf_files)} vs {self.metadata.get('document_count', 0)}")
                return False

            # Quick check: if no files have changed timestamps, cache is likely valid
            all_files_unchanged = True
            for pdf_file in pdf_files:
                if not pdf_file.exists():
                    logger.info(f"File missing: {pdf_file.name}")
                    return False

                # Check modification time first (faster than hashing)
                cached_mtime = self.metadata.get("file_mtimes", {}).get(pdf_file.name)
                current_mtime = pdf_file.stat().st_mtime
                if cached_mtime != current_mtime:
                    all_files_unchanged = False
                    break

            if all_files_unchanged:
                logger.info("Cache validation successful (no timestamp changes)")
                return True

            # If timestamps changed, check file hashes (slower but accurate)
            logger.info("Timestamps changed, checking file hashes...")
            changed_files = []
            for pdf_file in pdf_files:
                current_hash = self._get_file_hash(pdf_file)
                cached_hash = self.metadata.get("file_hashes", {}).get(pdf_file.name)
                if current_hash != cached_hash:
                    changed_files.append(pdf_file.name)

            if changed_files:
                logger.info(f"Files changed: {', '.join(changed_files)}")
                return False

            logger.info("Cache validation successful (hashes match)")
            return True

        except Exception as e:
            logger.error(f"Error validating cache: {e}")
            return False
    def load_documents(self) -> Optional[List[Any]]:
        """Load processed documents from cache"""
        try:
            if self.documents is None and self.documents_file.exists():
                with open(self.documents_file, 'rb') as f:
                    self.documents = pickle.load(f)
                logger.info(f"Loaded {len(self.documents)} documents from cache")
            return self.documents
        except Exception as e:
            logger.error(f"Error loading documents: {e}")
            return None

    def load_embeddings(self) -> Optional[np.ndarray]:
        """Load embeddings from cache"""
        try:
            if self.embeddings is None and self.embeddings_file.exists():
                self.embeddings = np.load(self.embeddings_file)
                logger.info(f"Loaded embeddings with shape {self.embeddings.shape}")
            return self.embeddings
        except Exception as e:
            logger.error(f"Error loading embeddings: {e}")
            return None

    def load_knowledge_base(self) -> Tuple[Optional[List[Any]], Optional[np.ndarray]]:
        """Load both documents and embeddings from cache"""
        try:
            documents = self.load_documents()
            embeddings = self.load_embeddings()

            if documents is not None and embeddings is not None:
                logger.info(f"Loaded knowledge base: {len(documents)} documents, embeddings shape {embeddings.shape}")
                return documents, embeddings
            else:
                logger.warning("Failed to load complete knowledge base from cache")
                return None, None
        except Exception as e:
            logger.error(f"Error loading knowledge base: {e}")
            return None, None
    def is_valid(self) -> bool:
        """Check if cache has valid data"""
        try:
            return (self.documents_file.exists() and
                    self.embeddings_file.exists() and
                    self.metadata.get("chunk_count", 0) > 0)
        except Exception:
            return False
    def save_knowledge_base(self, documents: List[Any], embeddings: np.ndarray,
                            pdf_files: List[Path], embedder_config: Dict[str, Any]):
        """
        Save processed documents and embeddings to cache

        Args:
            documents: List of processed document objects
            embeddings: Numpy array of embeddings
            pdf_files: List of source PDF files
            embedder_config: Embedder configuration used
        """
        try:
            logger.info(f"Saving knowledge base: {len(documents)} documents, {embeddings.shape} embeddings")

            # Save documents
            with open(self.documents_file, 'wb') as f:
                pickle.dump(documents, f)

            # Save embeddings
            np.save(self.embeddings_file, embeddings)

            # Collect file metadata
            file_hashes = {}
            file_mtimes = {}
            for pdf_file in pdf_files:
                file_hashes[pdf_file.name] = self._get_file_hash(pdf_file)
                file_mtimes[pdf_file.name] = pdf_file.stat().st_mtime

            # Update metadata
            self.metadata.update({
                "document_count": len(pdf_files),
                "chunk_count": len(documents),
                "embedder_config": embedder_config,
                "embedder_config_hash": create_embedder_config_hash(embedder_config),
                "file_hashes": file_hashes,
                "file_mtimes": file_mtimes
            })
            self._save_metadata()

            # Cache in memory
            self.documents = documents
            self.embeddings = embeddings

            logger.info("Knowledge base saved successfully")
        except Exception as e:
            logger.error(f"Error saving knowledge base: {e}")
            raise
    def get_cache_info(self) -> Dict[str, Any]:
        """Get information about cached data"""
        return {
            "cache_valid": self.documents_file.exists() and self.embeddings_file.exists(),
            "document_count": self.metadata.get("document_count", 0),
            "chunk_count": self.metadata.get("chunk_count", 0),
            "last_updated": self.metadata.get("last_updated", 0),
            "cache_size_mb": self._get_cache_size_mb(),
            "embedder_config": self.metadata.get("embedder_config")
        }

    def _get_cache_size_mb(self) -> float:
        """Get total cache size in MB"""
        try:
            total_size = 0
            for file_path in [self.metadata_file, self.documents_file, self.embeddings_file]:
                if file_path.exists():
                    total_size += file_path.stat().st_size
            return total_size / (1024 * 1024)
        except Exception:
            return 0.0
    def clear_cache(self):
        """Clear all cached data"""
        try:
            for file_path in [self.metadata_file, self.documents_file, self.embeddings_file, self.index_file]:
                if file_path.exists():
                    file_path.unlink()

            self.metadata = self._create_empty_metadata()
            self.documents = None
            self.embeddings = None
            logger.info("Cache cleared successfully")
        except Exception as e:
            logger.error(f"Error clearing cache: {e}")
            raise

    def save_faiss_index(self, index_data: bytes):
        """Save FAISS index to cache"""
        try:
            with open(self.index_file, 'wb') as f:
                f.write(index_data)
            logger.info("FAISS index saved to cache")
        except Exception as e:
            logger.error(f"Error saving FAISS index: {e}")

    def load_faiss_index(self) -> Optional[bytes]:
        """Load FAISS index from cache"""
        try:
            if self.index_file.exists():
                with open(self.index_file, 'rb') as f:
                    return f.read()
            return None
        except Exception as e:
            logger.error(f"Error loading FAISS index: {e}")
            return None
def create_embedder_config_hash(system_or_config) -> Dict[str, Any]:
    """
    Extract embedder configuration for cache validation.

    Despite the name, this returns a plain config dict; cache validation
    compares these dicts for equality rather than comparing digests.
    """
    try:
        # Handle both system object and dict inputs
        if isinstance(system_or_config, dict):
            # Already a config dict, return as-is
            return system_or_config

        # System object: extract config from its embedder component
        embedder = system_or_config.get_component('embedder')

        # Get key configuration parameters
        config = {
            "model_name": getattr(embedder, 'model_name', 'unknown'),
            "model_type": type(embedder).__name__,
            "device": getattr(embedder, 'device', 'unknown'),
            "normalize_embeddings": getattr(embedder, 'normalize_embeddings', True)
        }

        # Add batch processor config if available
        if hasattr(embedder, 'batch_processor'):
            config["batch_size"] = getattr(embedder.batch_processor, 'batch_size', 32)

        return config
    except Exception as e:
        logger.error(f"Error creating embedder config hash: {e}")
        return {"error": str(e)}