""" In-memory vector store with efficient similarity search and metadata filtering. """ import pickle from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np from dataclasses import dataclass, asdict import json import time from .error_handler import ResourceError from .document_processor import DocumentChunk @dataclass class VectorEntry: """Represents a vector entry with metadata.""" id: str vector: np.ndarray metadata: Dict[str, Any] timestamp: float = None def __post_init__(self): if self.timestamp is None: self.timestamp = time.time() def to_dict(self) -> Dict[str, Any]: """Convert to dictionary (excluding vector for serialization).""" return { "id": self.id, "metadata": self.metadata, "timestamp": self.timestamp, "vector_shape": self.vector.shape, "vector_dtype": str(self.vector.dtype) } class VectorStore: """In-memory vector store with efficient similarity search.""" def __init__(self, config: Dict[str, Any], embedding_dim: int = None): self.config = config self.embedding_dim = embedding_dim # Storage self._vectors: List[VectorEntry] = [] self._id_to_index: Dict[str, int] = {} self._vector_matrix: Optional[np.ndarray] = None self._matrix_dirty = True # Configuration self.cache_dir = Path(config.get("cache", {}).get("cache_dir", "./cache")) self.auto_save = config.get("vector_store", {}).get("auto_save", True) # Statistics self.stats = { "total_vectors": 0, "searches_performed": 0, "total_search_time": 0, "last_update": None, "memory_usage_mb": 0 } def add_documents(self, chunks: List[DocumentChunk], embeddings: np.ndarray) -> List[str]: """ Add document chunks with their embeddings to the vector store. Args: chunks: List of document chunks embeddings: Array of embeddings corresponding to chunks Returns: List of vector IDs that were added """ if len(chunks) != len(embeddings): raise ValueError("Number of chunks must match number of embeddings") if embeddings.size == 0: return [] # Set embedding dimension if not set if self.embedding_dim is None: self.embedding_dim = embeddings.shape[1] elif embeddings.shape[1] != self.embedding_dim: raise ValueError(f"Embedding dimension {embeddings.shape[1]} doesn't match expected {self.embedding_dim}") added_ids = [] for chunk, embedding in zip(chunks, embeddings): # Create vector entry with content included in metadata metadata_with_content = chunk.metadata.copy() metadata_with_content['content'] = chunk.content # Add content to metadata vector_entry = VectorEntry( id=chunk.chunk_id, vector=embedding.copy(), metadata=metadata_with_content ) # Add to store if vector_entry.id in self._id_to_index: # Update existing entry index = self._id_to_index[vector_entry.id] self._vectors[index] = vector_entry else: # Add new entry self._id_to_index[vector_entry.id] = len(self._vectors) self._vectors.append(vector_entry) added_ids.append(vector_entry.id) # Mark matrix as dirty for rebuild self._matrix_dirty = True # Update statistics self._update_stats() return added_ids def search( self, query_embedding: np.ndarray, k: int = 10, metadata_filter: Optional[Dict[str, Any]] = None, similarity_threshold: float = 0.0 ) -> List[Tuple[str, float, Dict[str, Any]]]: """ Search for similar vectors. Args: query_embedding: Query vector k: Number of results to return metadata_filter: Optional metadata filter similarity_threshold: Minimum similarity score Returns: List of (vector_id, similarity_score, metadata) tuples """ start_time = time.time() if not self._vectors: return [] # Ensure vector matrix is built self._build_vector_matrix() # Normalize query vector query_norm = query_embedding / np.linalg.norm(query_embedding) # Compute similarities similarities = np.dot(self._vector_matrix, query_norm) # Apply similarity threshold valid_indices = np.where(similarities >= similarity_threshold)[0] if len(valid_indices) == 0: return [] # Get top k candidates (before metadata filtering) candidate_k = min(len(valid_indices), k * 3) # Get more candidates for filtering top_candidate_indices = valid_indices[np.argpartition(similarities[valid_indices], -candidate_k)[-candidate_k:]] top_candidate_indices = top_candidate_indices[np.argsort(similarities[top_candidate_indices])[::-1]] # Apply metadata filtering and collect results results = [] for idx in top_candidate_indices: if len(results) >= k: break vector_entry = self._vectors[idx] # Apply metadata filter if metadata_filter and not self._matches_filter(vector_entry.metadata, metadata_filter): continue results.append(( vector_entry.id, float(similarities[idx]), vector_entry.metadata.copy() )) # Update statistics search_time = time.time() - start_time self.stats["searches_performed"] += 1 self.stats["total_search_time"] += search_time return results def _build_vector_matrix(self) -> None: """Build or rebuild the vector matrix for efficient search.""" if not self._matrix_dirty: return if not self._vectors: self._vector_matrix = None return # Stack all vectors into a matrix vectors = [entry.vector for entry in self._vectors] self._vector_matrix = np.vstack(vectors) # Normalize for cosine similarity norms = np.linalg.norm(self._vector_matrix, axis=1, keepdims=True) norms[norms == 0] = 1 # Avoid division by zero self._vector_matrix = self._vector_matrix / norms self._matrix_dirty = False def _matches_filter(self, metadata: Dict[str, Any], filter_dict: Dict[str, Any]) -> bool: """Check if metadata matches the filter.""" for key, value in filter_dict.items(): if key not in metadata: return False metadata_value = metadata[key] if isinstance(value, dict): # Support for range filters, etc. if "$gte" in value and metadata_value < value["$gte"]: return False if "$lte" in value and metadata_value > value["$lte"]: return False if "$in" in value and metadata_value not in value["$in"]: return False elif isinstance(value, list): if metadata_value not in value: return False else: if metadata_value != value: return False return True def get_by_id(self, vector_id: str) -> Optional[Tuple[np.ndarray, Dict[str, Any]]]: """Get vector and metadata by ID.""" if vector_id not in self._id_to_index: return None index = self._id_to_index[vector_id] entry = self._vectors[index] return entry.vector.copy(), entry.metadata.copy() def delete_by_id(self, vector_id: str) -> bool: """Delete vector by ID.""" if vector_id not in self._id_to_index: return False index = self._id_to_index[vector_id] # Remove from vectors list del self._vectors[index] # Update index mapping del self._id_to_index[vector_id] for vid, idx in self._id_to_index.items(): if idx > index: self._id_to_index[vid] = idx - 1 # Mark matrix as dirty self._matrix_dirty = True # Update statistics self._update_stats() return True def delete_by_metadata(self, metadata_filter: Dict[str, Any]) -> int: """Delete vectors matching metadata filter.""" to_delete = [] for entry in self._vectors: if self._matches_filter(entry.metadata, metadata_filter): to_delete.append(entry.id) deleted_count = 0 for vector_id in to_delete: if self.delete_by_id(vector_id): deleted_count += 1 return deleted_count def clear(self) -> None: """Clear all vectors from the store.""" self._vectors.clear() self._id_to_index.clear() self._vector_matrix = None self._matrix_dirty = True self._update_stats() def get_stats(self) -> Dict[str, Any]: """Get vector store statistics.""" stats = self.stats.copy() if stats["searches_performed"] > 0: stats["avg_search_time"] = stats["total_search_time"] / stats["searches_performed"] else: stats["avg_search_time"] = 0 # Memory usage estimation memory_usage = 0 if self._vector_matrix is not None: memory_usage += self._vector_matrix.nbytes for entry in self._vectors: memory_usage += entry.vector.nbytes memory_usage += len(str(entry.metadata)) * 4 # Rough estimate stats["memory_usage_mb"] = memory_usage / (1024 * 1024) stats["embedding_dimension"] = self.embedding_dim return stats def _update_stats(self) -> None: """Update internal statistics.""" self.stats["total_vectors"] = len(self._vectors) self.stats["last_update"] = time.time() def save_to_disk(self, filepath: Optional[str] = None) -> str: """Save vector store to disk.""" if filepath is None: self.cache_dir.mkdir(parents=True, exist_ok=True) filepath = str(self.cache_dir / "vector_store.pkl") # Prepare data for serialization data = { "embedding_dim": self.embedding_dim, "vectors": [], "stats": self.stats } for entry in self._vectors: data["vectors"].append({ "id": entry.id, "vector": entry.vector, "metadata": entry.metadata, "timestamp": entry.timestamp }) try: with open(filepath, "wb") as f: pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) print(f"Vector store saved to {filepath}") return filepath except Exception as e: raise ResourceError(f"Failed to save vector store: {str(e)}") from e def load_from_disk(self, filepath: str) -> None: """Load vector store from disk.""" try: with open(filepath, "rb") as f: data = pickle.load(f) # Clear current data self.clear() # Restore data self.embedding_dim = data.get("embedding_dim") self.stats = data.get("stats", {}) for vector_data in data.get("vectors", []): entry = VectorEntry( id=vector_data["id"], vector=vector_data["vector"], metadata=vector_data["metadata"], timestamp=vector_data.get("timestamp", time.time()) ) self._id_to_index[entry.id] = len(self._vectors) self._vectors.append(entry) # Mark matrix as dirty for rebuild self._matrix_dirty = True print(f"Vector store loaded from {filepath} with {len(self._vectors)} vectors") except Exception as e: raise ResourceError(f"Failed to load vector store: {str(e)}") from e def get_document_chunks(self, source_filter: Optional[str] = None) -> List[Dict[str, Any]]: """Get all document chunks, optionally filtered by source.""" chunks = [] for entry in self._vectors: if source_filter is None or entry.metadata.get("source") == source_filter: chunks.append({ "id": entry.id, "content": entry.metadata.get("content", ""), "metadata": entry.metadata }) return chunks def optimize(self) -> Dict[str, Any]: """Optimize the vector store.""" start_time = time.time() # Rebuild vector matrix self._build_vector_matrix() # Could add more optimizations like: # - Removing duplicate vectors # - Compacting memory layout # - Building additional indexes optimization_time = time.time() - start_time return { "optimization_time": optimization_time, "total_vectors": len(self._vectors), "matrix_rebuilt": True } class ConversationalVectorStore(VectorStore): """Enhanced vector store with conversation context awareness.""" def __init__(self, config: Dict[str, Any], embedding_dim: int = None): """ Initialize conversational vector store. Args: config: Configuration dictionary embedding_dim: Embedding dimension """ super().__init__(config, embedding_dim) # Conversation-specific configuration self.conversation_config = config.get("conversation", {}) self.context_config = self.conversation_config.get("vector_store", {}) # Context-aware retrieval settings self.conversation_context_weight = self.context_config.get("context_weight", 0.1) self.entity_embedding_cache = {} # Cache entity embeddings self.topic_embedding_cache = {} # Cache topic embeddings # Enhanced statistics self.conversation_stats = { "contextual_searches": 0, "entity_enhanced_searches": 0, "topic_enhanced_searches": 0, "context_cache_hits": 0, "context_cache_misses": 0 } def retrieve_with_context( self, query_embedding: np.ndarray, conversation_embeddings: List[np.ndarray] = None, mentioned_entities: List[str] = None, active_topics: List[str] = None, conversation_history: List[Dict[str, Any]] = None, k: int = 10, metadata_filter: Optional[Dict[str, Any]] = None, similarity_threshold: float = 0.0 ) -> List[Tuple[str, float, Dict[str, Any]]]: """ Retrieve vectors with conversation context enhancement. Args: query_embedding: Query vector conversation_embeddings: Embeddings from conversation history mentioned_entities: Entities mentioned in conversation active_topics: Active conversation topics conversation_history: Recent conversation messages k: Number of results to return metadata_filter: Optional metadata filter similarity_threshold: Minimum similarity score Returns: List of (vector_id, similarity_score, metadata) tuples enhanced with context """ start_time = time.time() try: # Enhance query embedding with conversation context enhanced_query_embedding = self._enhance_query_embedding_with_context( query_embedding, conversation_embeddings, mentioned_entities, active_topics, conversation_history ) # Perform base retrieval with enhanced embedding base_results = super().search( query_embedding=enhanced_query_embedding, k=k * 2, # Get more results for context re-ranking metadata_filter=metadata_filter, similarity_threshold=similarity_threshold ) # Apply conversation context scoring context_enhanced_results = self._apply_conversation_context_scoring( base_results, conversation_embeddings, mentioned_entities, active_topics, conversation_history ) # Re-rank and limit results final_results = self._rerank_with_conversation_context( context_enhanced_results, k ) # Update conversation statistics search_time = time.time() - start_time self.conversation_stats["contextual_searches"] += 1 if mentioned_entities: self.conversation_stats["entity_enhanced_searches"] += 1 if active_topics: self.conversation_stats["topic_enhanced_searches"] += 1 return final_results except Exception as e: # Fallback to regular search on error return super().search(query_embedding, k, metadata_filter, similarity_threshold) def _enhance_query_embedding_with_context( self, query_embedding: np.ndarray, conversation_embeddings: List[np.ndarray] = None, mentioned_entities: List[str] = None, active_topics: List[str] = None, conversation_history: List[Dict[str, Any]] = None ) -> np.ndarray: """Enhance query embedding with conversation context.""" enhanced_embedding = query_embedding.copy() # Add conversation history context if conversation_embeddings: # Weight recent conversation embeddings context_vector = np.zeros_like(query_embedding) for i, conv_embedding in enumerate(conversation_embeddings[-3:]): # Last 3 weight = self.conversation_context_weight * (0.8 ** i) # Decay factor context_vector += weight * conv_embedding # Blend with query embedding enhanced_embedding = 0.9 * enhanced_embedding + 0.1 * context_vector # Add entity context if mentioned_entities: entity_context = self._get_entity_context_vector(mentioned_entities) if entity_context is not None: enhanced_embedding = 0.95 * enhanced_embedding + 0.05 * entity_context # Add topic context if active_topics: topic_context = self._get_topic_context_vector(active_topics) if topic_context is not None: enhanced_embedding = 0.95 * enhanced_embedding + 0.05 * topic_context # Normalize the enhanced embedding norm = np.linalg.norm(enhanced_embedding) if norm > 0: enhanced_embedding = enhanced_embedding / norm return enhanced_embedding def _get_entity_context_vector(self, entities: List[str]) -> Optional[np.ndarray]: """Get aggregated context vector for entities.""" if not entities or not self.embedding_dim: return None # Check cache first entities_key = "|".join(sorted(entities)) if entities_key in self.entity_embedding_cache: self.conversation_stats["context_cache_hits"] += 1 return self.entity_embedding_cache[entities_key] self.conversation_stats["context_cache_misses"] += 1 # Find vectors that mention these entities entity_vectors = [] for vector_entry in self._vectors: content = vector_entry.metadata.get("content", "").lower() # Check if any entity is mentioned in this content entity_mentions = sum(1 for entity in entities if entity.lower() in content) if entity_mentions > 0: # Weight by number of entity mentions weighted_vector = vector_entry.vector * entity_mentions entity_vectors.append(weighted_vector) if not entity_vectors: return None # Average the entity-related vectors context_vector = np.mean(entity_vectors, axis=0) # Cache the result self.entity_embedding_cache[entities_key] = context_vector return context_vector def _get_topic_context_vector(self, topics: List[str]) -> Optional[np.ndarray]: """Get aggregated context vector for topics.""" if not topics or not self.embedding_dim: return None # Check cache first topics_key = "|".join(sorted(topics)) if topics_key in self.topic_embedding_cache: self.conversation_stats["context_cache_hits"] += 1 return self.topic_embedding_cache[topics_key] self.conversation_stats["context_cache_misses"] += 1 # Find vectors that relate to these topics topic_vectors = [] for vector_entry in self._vectors: content = vector_entry.metadata.get("content", "").lower() # Check if any topic is mentioned in this content topic_mentions = sum(1 for topic in topics if topic.lower() in content) if topic_mentions > 0: # Weight by number of topic mentions weighted_vector = vector_entry.vector * topic_mentions topic_vectors.append(weighted_vector) if not topic_vectors: return None # Average the topic-related vectors context_vector = np.mean(topic_vectors, axis=0) # Cache the result self.topic_embedding_cache[topics_key] = context_vector return context_vector def _apply_conversation_context_scoring( self, results: List[Tuple[str, float, Dict[str, Any]]], conversation_embeddings: List[np.ndarray] = None, mentioned_entities: List[str] = None, active_topics: List[str] = None, conversation_history: List[Dict[str, Any]] = None ) -> List[Tuple[str, float, Dict[str, Any]]]: """Apply conversation context to boost relevant results.""" enhanced_results = [] for vector_id, similarity_score, metadata in results: # Get the vector entry vector_entry = None if vector_id in self._id_to_index: vector_entry = self._vectors[self._id_to_index[vector_id]] if not vector_entry: enhanced_results.append((vector_id, similarity_score, metadata)) continue # Calculate context boost context_boost = 1.0 content = metadata.get("content", "").lower() # Entity context boost if mentioned_entities: entity_matches = sum( 1 for entity in mentioned_entities if entity.lower() in content ) if entity_matches > 0: context_boost *= (1.1 ** entity_matches) # 10% boost per entity match # Topic context boost if active_topics: topic_matches = sum( 1 for topic in active_topics if topic.lower() in content ) if topic_matches > 0: context_boost *= (1.15 ** topic_matches) # 15% boost per topic match # Conversation history similarity boost if conversation_embeddings: history_boost = self._calculate_conversation_similarity_boost( vector_entry.vector, conversation_embeddings ) context_boost *= history_boost # Document continuity boost if conversation_history: continuity_boost = self._calculate_document_continuity_boost( metadata, conversation_history ) context_boost *= continuity_boost # Apply context boost to similarity score enhanced_score = similarity_score * context_boost # Add context information to metadata enhanced_metadata = metadata.copy() enhanced_metadata["conversation_context"] = { "context_boost": context_boost, "entity_matches": sum(1 for entity in (mentioned_entities or []) if entity.lower() in content), "topic_matches": sum(1 for topic in (active_topics or []) if topic.lower() in content), "original_score": similarity_score, "enhanced_score": enhanced_score } enhanced_results.append((vector_id, enhanced_score, enhanced_metadata)) return enhanced_results def _calculate_conversation_similarity_boost( self, vector: np.ndarray, conversation_embeddings: List[np.ndarray] ) -> float: """Calculate boost based on similarity to conversation history.""" if not conversation_embeddings: return 1.0 # Calculate similarity to recent conversation embeddings similarities = [] for conv_embedding in conversation_embeddings[-3:]: # Last 3 # Normalize vectors vector_norm = vector / (np.linalg.norm(vector) + 1e-8) conv_norm = conv_embedding / (np.linalg.norm(conv_embedding) + 1e-8) # Calculate cosine similarity similarity = np.dot(vector_norm, conv_norm) similarities.append(similarity) if similarities: # Use max similarity with decay for older embeddings max_similarity = max(similarities) boost = 1.0 + (0.2 * max_similarity) # Up to 20% boost return min(boost, 1.3) # Cap at 30% boost return 1.0 def _calculate_document_continuity_boost( self, metadata: Dict[str, Any], conversation_history: List[Dict[str, Any]] ) -> float: """Calculate boost for document continuity in conversation.""" current_source = metadata.get("source", "") if not current_source or not conversation_history: return 1.0 # Check if recent messages referenced the same document recent_sources = [] for message in reversed(conversation_history[-5:]): # Last 5 messages if message.get("role") == "assistant": sources = message.get("sources", []) for source in sources: if isinstance(source, dict): source_name = source.get("title", source.get("document_id", "")) if source_name: recent_sources.append(source_name) # Check for document continuity if current_source in recent_sources: return 1.1 # 10% boost for document continuity return 1.0 def _rerank_with_conversation_context( self, results: List[Tuple[str, float, Dict[str, Any]]], k: int ) -> List[Tuple[str, float, Dict[str, Any]]]: """Re-rank results based on context-enhanced scores.""" # Sort by enhanced similarity score sorted_results = sorted( results, key=lambda x: x[1], # Sort by similarity score reverse=True ) return sorted_results[:k] def search_similar_in_conversation_context( self, vector_id: str, conversation_embeddings: List[np.ndarray] = None, k: int = 5 ) -> List[Tuple[str, float, Dict[str, Any]]]: """Find similar vectors within conversation context.""" if vector_id not in self._id_to_index: return [] # Get the reference vector vector_entry = self._vectors[self._id_to_index[vector_id]] reference_embedding = vector_entry.vector # Use the reference embedding as query with conversation context return self.retrieve_with_context( query_embedding=reference_embedding, conversation_embeddings=conversation_embeddings, k=k ) def get_conversation_stats(self) -> Dict[str, Any]: """Get conversation-specific vector store statistics.""" base_stats = super().get_stats() base_stats.update(self.conversation_stats) # Add derived metrics if self.conversation_stats["contextual_searches"] > 0: base_stats["entity_enhancement_rate"] = ( self.conversation_stats["entity_enhanced_searches"] / self.conversation_stats["contextual_searches"] ) * 100 base_stats["topic_enhancement_rate"] = ( self.conversation_stats["topic_enhanced_searches"] / self.conversation_stats["contextual_searches"] ) * 100 # Cache efficiency total_cache_requests = ( self.conversation_stats["context_cache_hits"] + self.conversation_stats["context_cache_misses"] ) if total_cache_requests > 0: base_stats["context_cache_hit_rate"] = ( self.conversation_stats["context_cache_hits"] / total_cache_requests ) * 100 base_stats["entity_cache_size"] = len(self.entity_embedding_cache) base_stats["topic_cache_size"] = len(self.topic_embedding_cache) return base_stats def clear_conversation_cache(self) -> None: """Clear conversation-specific caches.""" self.entity_embedding_cache.clear() self.topic_embedding_cache.clear() def add_conversation_aware_documents( self, chunks: List[DocumentChunk], embeddings: np.ndarray, conversation_context: Dict[str, Any] = None ) -> List[str]: """ Add documents with conversation context awareness. Args: chunks: Document chunks to add embeddings: Corresponding embeddings conversation_context: Context from current conversation Returns: List of vector IDs that were added """ # Enhance metadata with conversation context if conversation_context: for chunk in chunks: chunk.metadata["conversation_context"] = conversation_context.copy() # Add conversation session info if "session_id" in conversation_context: chunk.metadata["session_id"] = conversation_context["session_id"] # Add user context if "user_id" in conversation_context: chunk.metadata["user_id"] = conversation_context["user_id"] # Use parent method to add documents return super().add_documents(chunks, embeddings)