"""
In-memory vector store with efficient similarity search and metadata filtering.
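
Example (illustrative sketch; the config values and input arrays below are
placeholders, not part of this module):

    store = VectorStore(config={"cache": {"cache_dir": "./cache"}}, embedding_dim=384)
    ids = store.add_documents(chunks, embeddings)   # produced by upstream pipeline stages
    hits = store.search(query_embedding, k=5, metadata_filter={"source": "notes.md"})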
"""
import pickle
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from dataclasses import dataclass, asdict
import json
import time
from .error_handler import ResourceError
from .document_processor import DocumentChunk
@dataclass
class VectorEntry:
"""Represents a vector entry with metadata."""
id: str
vector: np.ndarray
metadata: Dict[str, Any]
    timestamp: Optional[float] = None
def __post_init__(self):
if self.timestamp is None:
self.timestamp = time.time()
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary (excluding vector for serialization)."""
return {
"id": self.id,
"metadata": self.metadata,
"timestamp": self.timestamp,
"vector_shape": self.vector.shape,
"vector_dtype": str(self.vector.dtype)
}
class VectorStore:
"""In-memory vector store with efficient similarity search."""
    def __init__(self, config: Dict[str, Any], embedding_dim: Optional[int] = None):
self.config = config
self.embedding_dim = embedding_dim
# Storage
self._vectors: List[VectorEntry] = []
self._id_to_index: Dict[str, int] = {}
self._vector_matrix: Optional[np.ndarray] = None
self._matrix_dirty = True
# Configuration
self.cache_dir = Path(config.get("cache", {}).get("cache_dir", "./cache"))
self.auto_save = config.get("vector_store", {}).get("auto_save", True)
# Statistics
self.stats = {
"total_vectors": 0,
"searches_performed": 0,
"total_search_time": 0,
"last_update": None,
"memory_usage_mb": 0
}
def add_documents(self, chunks: List[DocumentChunk], embeddings: np.ndarray) -> List[str]:
"""
Add document chunks with their embeddings to the vector store.
Args:
chunks: List of document chunks
embeddings: Array of embeddings corresponding to chunks
Returns:
List of vector IDs that were added
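
        Example (illustrative; assumes ``chunks`` and ``embeddings`` were produced
        upstream and are index-aligned):
            ids = store.add_documents(chunks, embeddings)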
"""
if len(chunks) != len(embeddings):
raise ValueError("Number of chunks must match number of embeddings")
if embeddings.size == 0:
return []
# Set embedding dimension if not set
if self.embedding_dim is None:
self.embedding_dim = embeddings.shape[1]
elif embeddings.shape[1] != self.embedding_dim:
raise ValueError(f"Embedding dimension {embeddings.shape[1]} doesn't match expected {self.embedding_dim}")
added_ids = []
for chunk, embedding in zip(chunks, embeddings):
# Create vector entry with content included in metadata
metadata_with_content = chunk.metadata.copy()
metadata_with_content['content'] = chunk.content # Add content to metadata
vector_entry = VectorEntry(
id=chunk.chunk_id,
vector=embedding.copy(),
metadata=metadata_with_content
)
# Add to store
if vector_entry.id in self._id_to_index:
# Update existing entry
index = self._id_to_index[vector_entry.id]
self._vectors[index] = vector_entry
else:
# Add new entry
self._id_to_index[vector_entry.id] = len(self._vectors)
self._vectors.append(vector_entry)
added_ids.append(vector_entry.id)
# Mark matrix as dirty for rebuild
self._matrix_dirty = True
# Update statistics
self._update_stats()
return added_ids
def search(
self,
query_embedding: np.ndarray,
k: int = 10,
metadata_filter: Optional[Dict[str, Any]] = None,
similarity_threshold: float = 0.0
) -> List[Tuple[str, float, Dict[str, Any]]]:
"""
Search for similar vectors.
Args:
query_embedding: Query vector
k: Number of results to return
metadata_filter: Optional metadata filter
similarity_threshold: Minimum similarity score
Returns:
List of (vector_id, similarity_score, metadata) tuples
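
        Example (illustrative; the filter keys shown are assumptions about the
        chunk metadata, not fields this class requires):
            results = store.search(
                query_embedding,
                k=5,
                metadata_filter={"source": "report.pdf", "page": {"$gte": 2}},
            )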
"""
start_time = time.time()
if not self._vectors:
return []
# Ensure vector matrix is built
self._build_vector_matrix()
        # Normalize the query vector for cosine similarity (guard against a zero vector)
        query_norm_value = np.linalg.norm(query_embedding)
        if query_norm_value == 0:
            return []
        query_norm = query_embedding / query_norm_value
# Compute similarities
similarities = np.dot(self._vector_matrix, query_norm)
# Apply similarity threshold
valid_indices = np.where(similarities >= similarity_threshold)[0]
if len(valid_indices) == 0:
return []
        # Rank candidates by similarity; when a metadata filter is supplied,
        # over-fetch so filtered-out entries can be skipped and k results can
        # still be returned (matches beyond this candidate pool are not seen)
        candidate_k = min(len(valid_indices), k * 3 if metadata_filter else k)
        top_candidate_indices = valid_indices[np.argpartition(similarities[valid_indices], -candidate_k)[-candidate_k:]]
        top_candidate_indices = top_candidate_indices[np.argsort(similarities[top_candidate_indices])[::-1]]
# Apply metadata filtering and collect results
results = []
for idx in top_candidate_indices:
if len(results) >= k:
break
vector_entry = self._vectors[idx]
# Apply metadata filter
if metadata_filter and not self._matches_filter(vector_entry.metadata, metadata_filter):
continue
results.append((
vector_entry.id,
float(similarities[idx]),
vector_entry.metadata.copy()
))
# Update statistics
search_time = time.time() - start_time
self.stats["searches_performed"] += 1
self.stats["total_search_time"] += search_time
return results
def _build_vector_matrix(self) -> None:
"""Build or rebuild the vector matrix for efficient search."""
if not self._matrix_dirty:
return
        if not self._vectors:
            self._vector_matrix = None
            self._matrix_dirty = False
            return
# Stack all vectors into a matrix
vectors = [entry.vector for entry in self._vectors]
self._vector_matrix = np.vstack(vectors)
# Normalize for cosine similarity
norms = np.linalg.norm(self._vector_matrix, axis=1, keepdims=True)
norms[norms == 0] = 1 # Avoid division by zero
self._vector_matrix = self._vector_matrix / norms
self._matrix_dirty = False
def _matches_filter(self, metadata: Dict[str, Any], filter_dict: Dict[str, Any]) -> bool:
"""Check if metadata matches the filter."""
for key, value in filter_dict.items():
if key not in metadata:
return False
metadata_value = metadata[key]
if isinstance(value, dict):
# Support for range filters, etc.
if "$gte" in value and metadata_value < value["$gte"]:
return False
if "$lte" in value and metadata_value > value["$lte"]:
return False
if "$in" in value and metadata_value not in value["$in"]:
return False
elif isinstance(value, list):
if metadata_value not in value:
return False
else:
if metadata_value != value:
return False
return True
def get_by_id(self, vector_id: str) -> Optional[Tuple[np.ndarray, Dict[str, Any]]]:
"""Get vector and metadata by ID."""
if vector_id not in self._id_to_index:
return None
index = self._id_to_index[vector_id]
entry = self._vectors[index]
return entry.vector.copy(), entry.metadata.copy()
def delete_by_id(self, vector_id: str) -> bool:
"""Delete vector by ID."""
if vector_id not in self._id_to_index:
return False
index = self._id_to_index[vector_id]
# Remove from vectors list
del self._vectors[index]
# Update index mapping
del self._id_to_index[vector_id]
for vid, idx in self._id_to_index.items():
if idx > index:
self._id_to_index[vid] = idx - 1
# Mark matrix as dirty
self._matrix_dirty = True
# Update statistics
self._update_stats()
return True
def delete_by_metadata(self, metadata_filter: Dict[str, Any]) -> int:
"""Delete vectors matching metadata filter."""
to_delete = []
for entry in self._vectors:
if self._matches_filter(entry.metadata, metadata_filter):
to_delete.append(entry.id)
deleted_count = 0
for vector_id in to_delete:
if self.delete_by_id(vector_id):
deleted_count += 1
return deleted_count
def clear(self) -> None:
"""Clear all vectors from the store."""
self._vectors.clear()
self._id_to_index.clear()
self._vector_matrix = None
self._matrix_dirty = True
self._update_stats()
def get_stats(self) -> Dict[str, Any]:
"""Get vector store statistics."""
stats = self.stats.copy()
if stats["searches_performed"] > 0:
stats["avg_search_time"] = stats["total_search_time"] / stats["searches_performed"]
else:
stats["avg_search_time"] = 0
# Memory usage estimation
memory_usage = 0
if self._vector_matrix is not None:
memory_usage += self._vector_matrix.nbytes
for entry in self._vectors:
memory_usage += entry.vector.nbytes
memory_usage += len(str(entry.metadata)) * 4 # Rough estimate
stats["memory_usage_mb"] = memory_usage / (1024 * 1024)
stats["embedding_dimension"] = self.embedding_dim
return stats
def _update_stats(self) -> None:
"""Update internal statistics."""
self.stats["total_vectors"] = len(self._vectors)
self.stats["last_update"] = time.time()
def save_to_disk(self, filepath: Optional[str] = None) -> str:
"""Save vector store to disk."""
if filepath is None:
self.cache_dir.mkdir(parents=True, exist_ok=True)
filepath = str(self.cache_dir / "vector_store.pkl")
# Prepare data for serialization
data = {
"embedding_dim": self.embedding_dim,
"vectors": [],
"stats": self.stats
}
for entry in self._vectors:
data["vectors"].append({
"id": entry.id,
"vector": entry.vector,
"metadata": entry.metadata,
"timestamp": entry.timestamp
})
try:
with open(filepath, "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
print(f"Vector store saved to {filepath}")
return filepath
except Exception as e:
raise ResourceError(f"Failed to save vector store: {str(e)}") from e
def load_from_disk(self, filepath: str) -> None:
"""Load vector store from disk."""
try:
with open(filepath, "rb") as f:
data = pickle.load(f)
# Clear current data
self.clear()
# Restore data
            self.embedding_dim = data.get("embedding_dim")
            # Merge saved stats over the defaults so expected keys are always present
            self.stats.update(data.get("stats", {}))
for vector_data in data.get("vectors", []):
entry = VectorEntry(
id=vector_data["id"],
vector=vector_data["vector"],
metadata=vector_data["metadata"],
timestamp=vector_data.get("timestamp", time.time())
)
self._id_to_index[entry.id] = len(self._vectors)
self._vectors.append(entry)
            # Mark matrix as dirty for rebuild and refresh derived counters
            self._matrix_dirty = True
            self._update_stats()
print(f"Vector store loaded from {filepath} with {len(self._vectors)} vectors")
except Exception as e:
raise ResourceError(f"Failed to load vector store: {str(e)}") from e
def get_document_chunks(self, source_filter: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get all document chunks, optionally filtered by source."""
chunks = []
for entry in self._vectors:
if source_filter is None or entry.metadata.get("source") == source_filter:
chunks.append({
"id": entry.id,
"content": entry.metadata.get("content", ""),
"metadata": entry.metadata
})
return chunks
def optimize(self) -> Dict[str, Any]:
"""Optimize the vector store."""
        start_time = time.time()
        # Rebuild the vector matrix only if it is out of date
        was_dirty = self._matrix_dirty
        self._build_vector_matrix()
# Could add more optimizations like:
# - Removing duplicate vectors
# - Compacting memory layout
# - Building additional indexes
optimization_time = time.time() - start_time
return {
"optimization_time": optimization_time,
"total_vectors": len(self._vectors),
"matrix_rebuilt": True
}
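
if __name__ == "__main__":
    # Minimal smoke-test sketch (run as a module, e.g. ``python -m src.vector_store``,
    # since the relative imports above require package context). DocumentChunk's
    # constructor is not defined in this file, so a stand-in object exposing the
    # attributes add_documents() actually reads (chunk_id, content, metadata) is used.
    from types import SimpleNamespace

    rng = np.random.default_rng(0)
    store = VectorStore(config={}, embedding_dim=8)
    demo_chunks = [
        SimpleNamespace(
            chunk_id=f"chunk-{i}",
            content=f"example text {i}",
            metadata={"source": "demo.txt", "position": i},
        )
        for i in range(4)
    ]
    demo_embeddings = rng.normal(size=(4, 8)).astype(np.float32)
    store.add_documents(demo_chunks, demo_embeddings)
    for vector_id, score, metadata in store.search(demo_embeddings[0], k=2):
        print(vector_id, round(score, 3), metadata["source"])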