""" | |
In-memory vector store with efficient similarity search and metadata filtering. | |
""" | |
import pickle | |
from pathlib import Path | |
from typing import Any, Dict, List, Optional, Tuple, Union | |
import numpy as np | |
from dataclasses import dataclass, asdict | |
import json | |
import time | |
from .error_handler import ResourceError | |
from .document_processor import DocumentChunk | |


@dataclass
class VectorEntry:
    """Represents a vector entry with metadata."""
    id: str
    vector: np.ndarray
    metadata: Dict[str, Any]
    timestamp: Optional[float] = None

    def __post_init__(self):
        # Stamp creation time if the caller did not supply one.
        if self.timestamp is None:
            self.timestamp = time.time()

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary (excluding the vector itself for serialization)."""
        return {
            "id": self.id,
            "metadata": self.metadata,
            "timestamp": self.timestamp,
            "vector_shape": self.vector.shape,
            "vector_dtype": str(self.vector.dtype),
        }
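
# Example (illustrative) of an entry and its serializable summary:
#
#     entry = VectorEntry(
#         id="chunk-001",
#         vector=np.array([0.1, 0.2, 0.3], dtype=np.float32),
#         metadata={"source": "report.pdf", "content": "..."},
#     )
#     entry.to_dict()
#     # -> {"id": "chunk-001", "metadata": {...}, "timestamp": 1700000000.0,
#     #     "vector_shape": (3,), "vector_dtype": "float32"}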


class VectorStore:
    """In-memory vector store with efficient similarity search."""

    def __init__(self, config: Dict[str, Any], embedding_dim: Optional[int] = None):
        self.config = config
        self.embedding_dim = embedding_dim

        # Storage
        self._vectors: List[VectorEntry] = []
        self._id_to_index: Dict[str, int] = {}
        self._vector_matrix: Optional[np.ndarray] = None
        self._matrix_dirty = True

        # Configuration
        self.cache_dir = Path(config.get("cache", {}).get("cache_dir", "./cache"))
        self.auto_save = config.get("vector_store", {}).get("auto_save", True)

        # Statistics
        self.stats = {
            "total_vectors": 0,
            "searches_performed": 0,
            "total_search_time": 0.0,
            "last_update": None,
            "memory_usage_mb": 0,
        }

    def add_documents(self, chunks: List[DocumentChunk], embeddings: np.ndarray) -> List[str]:
        """
        Add document chunks with their embeddings to the vector store.

        Args:
            chunks: List of document chunks
            embeddings: 2D array of embeddings, one row per chunk

        Returns:
            List of vector IDs that were added
        """
        if len(chunks) != len(embeddings):
            raise ValueError("Number of chunks must match number of embeddings")
        if embeddings.size == 0:
            return []

        # Set the embedding dimension on first use, then enforce it.
        if self.embedding_dim is None:
            self.embedding_dim = embeddings.shape[1]
        elif embeddings.shape[1] != self.embedding_dim:
            raise ValueError(
                f"Embedding dimension {embeddings.shape[1]} doesn't match expected {self.embedding_dim}"
            )

        added_ids = []
        for chunk, embedding in zip(chunks, embeddings):
            # Include the chunk text in the metadata so search results are self-contained.
            metadata_with_content = chunk.metadata.copy()
            metadata_with_content["content"] = chunk.content
            vector_entry = VectorEntry(
                id=chunk.chunk_id,
                vector=embedding.copy(),
                metadata=metadata_with_content,
            )

            if vector_entry.id in self._id_to_index:
                # Update the existing entry in place.
                index = self._id_to_index[vector_entry.id]
                self._vectors[index] = vector_entry
            else:
                # Append a new entry.
                self._id_to_index[vector_entry.id] = len(self._vectors)
                self._vectors.append(vector_entry)
            added_ids.append(vector_entry.id)

        # The similarity matrix must be rebuilt before the next search.
        self._matrix_dirty = True
        self._update_stats()
        return added_ids

    def search(
        self,
        query_embedding: np.ndarray,
        k: int = 10,
        metadata_filter: Optional[Dict[str, Any]] = None,
        similarity_threshold: float = 0.0,
    ) -> List[Tuple[str, float, Dict[str, Any]]]:
        """
        Search for similar vectors.

        Args:
            query_embedding: Query vector
            k: Number of results to return
            metadata_filter: Optional metadata filter
            similarity_threshold: Minimum cosine similarity score

        Returns:
            List of (vector_id, similarity_score, metadata) tuples,
            sorted by descending similarity
        """
        start_time = time.time()
        if not self._vectors:
            return []

        # Ensure the normalized vector matrix is up to date.
        self._build_vector_matrix()

        # Normalize the query so the dot product below is cosine similarity.
        query_norm_value = np.linalg.norm(query_embedding)
        if query_norm_value == 0:
            return []
        query_norm = query_embedding / query_norm_value

        similarities = np.dot(self._vector_matrix, query_norm)

        # Apply the similarity threshold.
        valid_indices = np.where(similarities >= similarity_threshold)[0]
        if len(valid_indices) == 0:
            return []

        # Take extra candidates so metadata filtering can discard some and
        # still leave k results.
        candidate_k = min(len(valid_indices), k * 3)
        top_candidate_indices = valid_indices[
            np.argpartition(similarities[valid_indices], -candidate_k)[-candidate_k:]
        ]
        top_candidate_indices = top_candidate_indices[
            np.argsort(similarities[top_candidate_indices])[::-1]
        ]

        # Apply metadata filtering and collect up to k results.
        results = []
        for idx in top_candidate_indices:
            if len(results) >= k:
                break
            vector_entry = self._vectors[idx]
            if metadata_filter and not self._matches_filter(vector_entry.metadata, metadata_filter):
                continue
            results.append((
                vector_entry.id,
                float(similarities[idx]),
                vector_entry.metadata.copy(),
            ))

        # Update statistics.
        search_time = time.time() - start_time
        self.stats["searches_performed"] += 1
        self.stats["total_search_time"] += search_time
        return results
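
    # Usage sketch (illustrative; assumes query and document embeddings come
    # from the same model, e.g. a 384-dimensional sentence encoder):
    #
    #     store = VectorStore(config={}, embedding_dim=384)
    #     store.add_documents(chunks, embeddings)   # chunks: List[DocumentChunk]
    #     hits = store.search(query_vec, k=5,
    #                         metadata_filter={"source": "report.pdf"},
    #                         similarity_threshold=0.2)
    #     for vector_id, score, metadata in hits:
    #         print(vector_id, score, metadata.get("source"))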

    def _build_vector_matrix(self) -> None:
        """Build or rebuild the normalized vector matrix for efficient search."""
        if not self._matrix_dirty:
            return
        if not self._vectors:
            self._vector_matrix = None
            self._matrix_dirty = False
            return

        # Stack all vectors into a single (n, dim) matrix.
        vectors = [entry.vector for entry in self._vectors]
        self._vector_matrix = np.vstack(vectors)

        # Row-normalize so dot products with a normalized query give cosine similarity.
        norms = np.linalg.norm(self._vector_matrix, axis=1, keepdims=True)
        norms[norms == 0] = 1  # Avoid division by zero for all-zero vectors
        self._vector_matrix = self._vector_matrix / norms
        self._matrix_dirty = False

    def _matches_filter(self, metadata: Dict[str, Any], filter_dict: Dict[str, Any]) -> bool:
        """Check if metadata matches the filter."""
        for key, value in filter_dict.items():
            if key not in metadata:
                return False
            metadata_value = metadata[key]
            if isinstance(value, dict):
                # Operator filters: $gte / $lte ranges and $in membership.
                if "$gte" in value and metadata_value < value["$gte"]:
                    return False
                if "$lte" in value and metadata_value > value["$lte"]:
                    return False
                if "$in" in value and metadata_value not in value["$in"]:
                    return False
            elif isinstance(value, list):
                # A bare list means membership.
                if metadata_value not in value:
                    return False
            else:
                # Anything else is an exact-match test.
                if metadata_value != value:
                    return False
        return True
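
    # Filter shapes handled above (illustrative):
    #
    #     {"source": "report.pdf"}                  # exact match
    #     {"source": ["a.pdf", "b.pdf"]}            # bare list -> membership
    #     {"page": {"$gte": 2, "$lte": 10}}         # inclusive range
    #     {"category": {"$in": ["news", "blog"]}}   # explicit $in operator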

    def get_by_id(self, vector_id: str) -> Optional[Tuple[np.ndarray, Dict[str, Any]]]:
        """Get vector and metadata by ID."""
        if vector_id not in self._id_to_index:
            return None
        index = self._id_to_index[vector_id]
        entry = self._vectors[index]
        return entry.vector.copy(), entry.metadata.copy()

    def delete_by_id(self, vector_id: str) -> bool:
        """Delete a vector by ID."""
        if vector_id not in self._id_to_index:
            return False
        index = self._id_to_index[vector_id]

        # Remove the entry and its ID mapping.
        del self._vectors[index]
        del self._id_to_index[vector_id]

        # Shift the indices of all entries stored after the deleted one.
        for vid, idx in self._id_to_index.items():
            if idx > index:
                self._id_to_index[vid] = idx - 1

        self._matrix_dirty = True
        self._update_stats()
        return True

    def delete_by_metadata(self, metadata_filter: Dict[str, Any]) -> int:
        """Delete all vectors matching a metadata filter."""
        to_delete = [
            entry.id for entry in self._vectors
            if self._matches_filter(entry.metadata, metadata_filter)
        ]
        deleted_count = 0
        for vector_id in to_delete:
            if self.delete_by_id(vector_id):
                deleted_count += 1
        return deleted_count

    def clear(self) -> None:
        """Clear all vectors from the store."""
        self._vectors.clear()
        self._id_to_index.clear()
        self._vector_matrix = None
        self._matrix_dirty = True
        self._update_stats()

    def get_stats(self) -> Dict[str, Any]:
        """Get vector store statistics."""
        stats = self.stats.copy()
        if stats["searches_performed"] > 0:
            stats["avg_search_time"] = stats["total_search_time"] / stats["searches_performed"]
        else:
            stats["avg_search_time"] = 0

        # Estimate memory usage: the normalized matrix, the raw vectors, and a
        # rough allowance for metadata.
        memory_usage = 0
        if self._vector_matrix is not None:
            memory_usage += self._vector_matrix.nbytes
        for entry in self._vectors:
            memory_usage += entry.vector.nbytes
            memory_usage += len(str(entry.metadata)) * 4  # Rough estimate
        stats["memory_usage_mb"] = memory_usage / (1024 * 1024)
        stats["embedding_dimension"] = self.embedding_dim
        return stats
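
    # get_stats() returns a plain dict; values here are illustrative:
    #
    #     {"total_vectors": 1200, "searches_performed": 42,
    #      "total_search_time": 0.8, "avg_search_time": 0.019,
    #      "last_update": 1700000000.0, "memory_usage_mb": 7.3,
    #      "embedding_dimension": 384}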

    def _update_stats(self) -> None:
        """Update internal statistics."""
        self.stats["total_vectors"] = len(self._vectors)
        self.stats["last_update"] = time.time()

    def save_to_disk(self, filepath: Optional[str] = None) -> str:
        """Save the vector store to disk and return the path written."""
        if filepath is None:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            filepath = str(self.cache_dir / "vector_store.pkl")

        # Collect everything needed to reconstruct the store.
        data = {
            "embedding_dim": self.embedding_dim,
            "vectors": [],
            "stats": self.stats,
        }
        for entry in self._vectors:
            data["vectors"].append({
                "id": entry.id,
                "vector": entry.vector,
                "metadata": entry.metadata,
                "timestamp": entry.timestamp,
            })

        try:
            with open(filepath, "wb") as f:
                pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
            print(f"Vector store saved to {filepath}")
            return filepath
        except Exception as e:
            raise ResourceError(f"Failed to save vector store: {str(e)}") from e

    def load_from_disk(self, filepath: str) -> None:
        """Load the vector store from disk, replacing current contents."""
        try:
            with open(filepath, "rb") as f:
                data = pickle.load(f)

            # Clear current data before restoring.
            self.clear()
            self.embedding_dim = data.get("embedding_dim")
            self.stats = data.get("stats", {})
            for vector_data in data.get("vectors", []):
                entry = VectorEntry(
                    id=vector_data["id"],
                    vector=vector_data["vector"],
                    metadata=vector_data["metadata"],
                    timestamp=vector_data.get("timestamp", time.time()),
                )
                self._id_to_index[entry.id] = len(self._vectors)
                self._vectors.append(entry)

            # The similarity matrix must be rebuilt before the next search.
            self._matrix_dirty = True
            print(f"Vector store loaded from {filepath} with {len(self._vectors)} vectors")
        except Exception as e:
            raise ResourceError(f"Failed to load vector store: {str(e)}") from e

    def get_document_chunks(self, source_filter: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get all document chunks, optionally filtered by source."""
        chunks = []
        for entry in self._vectors:
            if source_filter is None or entry.metadata.get("source") == source_filter:
                chunks.append({
                    "id": entry.id,
                    "content": entry.metadata.get("content", ""),
                    "metadata": entry.metadata,
                })
        return chunks

    def optimize(self) -> Dict[str, Any]:
        """Optimize the vector store."""
        start_time = time.time()

        # Rebuild the normalized vector matrix.
        self._build_vector_matrix()

        # Could add more optimizations, such as:
        # - Removing duplicate vectors
        # - Compacting the memory layout
        # - Building additional indexes
        optimization_time = time.time() - start_time
        return {
            "optimization_time": optimization_time,
            "total_vectors": len(self._vectors),
            "matrix_rebuilt": True,
        }
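

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the public API). Because this
    # module uses relative imports, run it as a package module, e.g.
    # `python -m <package>.vector_store`. DocumentChunk only needs .chunk_id,
    # .content and .metadata here, so a SimpleNamespace stand-in suffices.
    from types import SimpleNamespace

    rng = np.random.default_rng(0)
    demo_chunks = [
        SimpleNamespace(
            chunk_id=f"chunk-{i}",
            content=f"demo text {i}",
            metadata={"source": "demo.txt", "position": i},
        )
        for i in range(4)
    ]
    demo_embeddings = rng.normal(size=(4, 8)).astype(np.float32)

    store = VectorStore(config={}, embedding_dim=8)
    store.add_documents(demo_chunks, demo_embeddings)
    # The query is the first stored vector, so chunk-0 should score ~1.0.
    for vid, score, meta in store.search(demo_embeddings[0], k=2):
        print(vid, round(score, 3), meta["source"])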