|
""" |
|
In-memory vector store with efficient similarity search and metadata filtering. |
|
""" |
|
|
|
import pickle |
|
from pathlib import Path |
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
import numpy as np |
|
from dataclasses import dataclass, asdict |
|
import json |
|
import time |
|
|
|
from .error_handler import ResourceError |
|
from .document_processor import DocumentChunk |
|
|
|
|
|
@dataclass
class VectorEntry:
    """A single stored vector together with its id, metadata and creation time."""

    id: str
    vector: np.ndarray
    metadata: Dict[str, Any]
    # Seconds since the epoch; filled in with time.time() when not supplied.
    timestamp: Optional[float] = None

    def __post_init__(self):
        """Default ``timestamp`` to the current time if the caller left it unset."""
        if self.timestamp is None:
            self.timestamp = time.time()

    def to_dict(self) -> Dict[str, Any]:
        """Return a lightweight description of the entry.

        The vector data itself is omitted; only its shape (a tuple) and its
        dtype (as a string) are included alongside id, metadata and timestamp.
        """
        return {
            "id": self.id,
            "metadata": self.metadata,
            "timestamp": self.timestamp,
            "vector_shape": self.vector.shape,
            "vector_dtype": str(self.vector.dtype),
        }
|
|
|
|
|
class VectorStore:
    """In-memory vector store with cosine-similarity search and metadata filtering.

    Entries are kept in insertion order in a list, with a dict mapping each id
    to its list position. Search is brute force against a row-normalised
    matrix of all vectors; that matrix is rebuilt lazily, only when the
    contents have changed since the last build.
    """

    def __init__(self, config: Dict[str, Any], embedding_dim: Optional[int] = None):
        """Create an empty store.

        Args:
            config: Application configuration. Reads
                ``config["cache"]["cache_dir"]`` (default ``"./cache"``) and
                ``config["vector_store"]["auto_save"]`` (default ``True``).
            embedding_dim: Expected vector dimension. When ``None`` it is taken
                from the first batch passed to :meth:`add_documents`.
        """
        self.config = config
        self.embedding_dim = embedding_dim

        # Entries in insertion order plus an id -> list-position index.
        self._vectors: List[VectorEntry] = []
        self._id_to_index: Dict[str, int] = {}
        # Cached row-normalised matrix of all vectors; rebuilt when dirty.
        self._vector_matrix: Optional[np.ndarray] = None
        self._matrix_dirty = True

        self.cache_dir = Path(config.get("cache", {}).get("cache_dir", "./cache"))
        # NOTE(review): read here but not consulted by any method of this class.
        self.auto_save = config.get("vector_store", {}).get("auto_save", True)

        self.stats = self._default_stats()

    @staticmethod
    def _default_stats() -> Dict[str, Any]:
        """Return a fresh statistics dict containing every key the store updates."""
        return {
            "total_vectors": 0,
            "searches_performed": 0,
            "total_search_time": 0,
            "last_update": None,
            "memory_usage_mb": 0,
        }

    def add_documents(self, chunks: List[DocumentChunk], embeddings: np.ndarray) -> List[str]:
        """
        Add document chunks with their embeddings to the vector store.

        Each chunk's metadata is copied and its ``content`` is stored under
        the ``"content"`` key. A chunk whose ``chunk_id`` is already present
        replaces the existing entry in place.

        Args:
            chunks: List of document chunks
            embeddings: Array of embeddings corresponding to chunks, shape
                ``(len(chunks), dim)``

        Returns:
            List of vector IDs that were added

        Raises:
            ValueError: If the counts differ, the embeddings are not 2-D, or
                the dimension does not match ``embedding_dim``.
        """
        embeddings = np.asarray(embeddings)
        if len(chunks) != len(embeddings):
            raise ValueError("Number of chunks must match number of embeddings")

        if embeddings.size == 0:
            return []

        if embeddings.ndim != 2:
            raise ValueError(
                f"Embeddings must be a 2-D array of shape (n_chunks, dim), got shape {embeddings.shape}"
            )

        # The first batch fixes the dimension when none was configured.
        if self.embedding_dim is None:
            self.embedding_dim = embeddings.shape[1]
        elif embeddings.shape[1] != self.embedding_dim:
            raise ValueError(f"Embedding dimension {embeddings.shape[1]} doesn't match expected {self.embedding_dim}")

        added_ids = []
        for chunk, embedding in zip(chunks, embeddings):
            # Copy so the caller's metadata dict is not mutated.
            metadata = chunk.metadata.copy()
            metadata["content"] = chunk.content

            entry = VectorEntry(id=chunk.chunk_id, vector=embedding.copy(), metadata=metadata)

            existing = self._id_to_index.get(entry.id)
            if existing is None:
                self._id_to_index[entry.id] = len(self._vectors)
                self._vectors.append(entry)
            else:
                self._vectors[existing] = entry

            added_ids.append(entry.id)

        self._matrix_dirty = True
        self._update_stats()

        return added_ids

    def search(
        self,
        query_embedding: np.ndarray,
        k: int = 10,
        metadata_filter: Optional[Dict[str, Any]] = None,
        similarity_threshold: float = 0.0
    ) -> List[Tuple[str, float, Dict[str, Any]]]:
        """
        Search for similar vectors by cosine similarity.

        The metadata filter is applied before ranking, so up to ``k`` matching
        results are returned even when the overall best hits do not match.

        Args:
            query_embedding: Query vector (any shape with ``embedding_dim``
                elements, e.g. ``(dim,)`` or ``(1, dim)``)
            k: Number of results to return; ``k <= 0`` yields no results
            metadata_filter: Optional metadata filter (see ``_matches_filter``)
            similarity_threshold: Minimum similarity score (inclusive)

        Returns:
            List of (vector_id, similarity_score, metadata) tuples, best first.
            Metadata dicts are copies. A zero or non-finite query vector
            yields an empty list.

        Raises:
            ValueError: If the query size differs from ``embedding_dim``.
        """
        start_time = time.time()
        results = self._run_search(query_embedding, k, metadata_filter, similarity_threshold)

        # Every completed search is recorded, including ones with no results.
        self.stats["searches_performed"] += 1
        self.stats["total_search_time"] += time.time() - start_time

        return results

    def _run_search(
        self,
        query_embedding: np.ndarray,
        k: int,
        metadata_filter: Optional[Dict[str, Any]],
        similarity_threshold: float,
    ) -> List[Tuple[str, float, Dict[str, Any]]]:
        """Compute search results for :meth:`search` without touching statistics."""
        if not self._vectors or k <= 0:
            return []

        query = np.asarray(query_embedding).ravel()
        if self.embedding_dim is not None and query.shape[0] != self.embedding_dim:
            raise ValueError(
                f"Query dimension {query.shape[0]} doesn't match expected {self.embedding_dim}"
            )

        query_length = np.linalg.norm(query)
        if not np.isfinite(query_length) or query_length == 0:
            # Cosine similarity is undefined for a zero (or non-finite) query.
            return []

        self._build_vector_matrix()

        # Rows are unit-normalised, so the dot product is cosine similarity.
        similarities = self._vector_matrix @ (query / query_length)

        candidates = np.flatnonzero(similarities >= similarity_threshold)
        if metadata_filter:
            candidates = np.array(
                [
                    idx for idx in candidates
                    if self._matches_filter(self._vectors[idx].metadata, metadata_filter)
                ],
                dtype=np.intp,
            )

        if candidates.size == 0:
            return []

        return [
            (
                self._vectors[idx].id,
                float(similarities[idx]),
                self._vectors[idx].metadata.copy(),
            )
            for idx in self._top_k(candidates, similarities, k)
        ]

    @staticmethod
    def _top_k(candidates: np.ndarray, similarities: np.ndarray, k: int) -> np.ndarray:
        """Return up to ``k`` of ``candidates`` ordered by descending similarity."""
        scores = similarities[candidates]
        if k < scores.size:
            # argpartition isolates the k largest scores before the final sort.
            keep = np.argpartition(scores, -k)[-k:]
            candidates, scores = candidates[keep], scores[keep]
        return candidates[np.argsort(scores)[::-1]]

    def _build_vector_matrix(self) -> None:
        """Rebuild the row-normalised vector matrix if the contents changed."""
        if not self._matrix_dirty:
            return

        if not self._vectors:
            self._vector_matrix = None
            return

        matrix = np.vstack([entry.vector for entry in self._vectors])

        # Zero vectors keep a norm of 1 so they stay all-zero (similarity 0)
        # instead of producing NaNs.
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        norms[norms == 0] = 1
        self._vector_matrix = matrix / norms

        self._matrix_dirty = False

    def _matches_filter(self, metadata: Dict[str, Any], filter_dict: Dict[str, Any]) -> bool:
        """Check whether ``metadata`` satisfies every condition in ``filter_dict``.

        For each ``key: value`` in the filter:
          * the key must be present in ``metadata``;
          * a dict value may use ``"$gte"``, ``"$lte"`` and ``"$in"``
            (other operator keys are ignored);
          * a list value requires the metadata value to be a member;
          * any other value requires equality.
        """
        for key, value in filter_dict.items():
            if key not in metadata:
                return False

            metadata_value = metadata[key]

            if isinstance(value, dict):
                if "$gte" in value and metadata_value < value["$gte"]:
                    return False
                if "$lte" in value and metadata_value > value["$lte"]:
                    return False
                if "$in" in value and metadata_value not in value["$in"]:
                    return False
            elif isinstance(value, list):
                if metadata_value not in value:
                    return False
            elif metadata_value != value:
                return False

        return True

    def get_by_id(self, vector_id: str) -> Optional[Tuple[np.ndarray, Dict[str, Any]]]:
        """Return copies of the vector and metadata for ``vector_id``, or ``None``."""
        index = self._id_to_index.get(vector_id)
        if index is None:
            return None

        entry = self._vectors[index]
        return entry.vector.copy(), entry.metadata.copy()

    def delete_by_id(self, vector_id: str) -> bool:
        """Delete the vector with ``vector_id``; return whether it existed."""
        index = self._id_to_index.pop(vector_id, None)
        if index is None:
            return False

        del self._vectors[index]

        # Only the entries after the removed one shifted left by one position.
        for position in range(index, len(self._vectors)):
            self._id_to_index[self._vectors[position].id] = position

        self._matrix_dirty = True
        self._update_stats()

        return True

    def delete_by_metadata(self, metadata_filter: Dict[str, Any]) -> int:
        """Delete every vector whose metadata matches ``metadata_filter``.

        Returns:
            Number of vectors deleted. An empty filter matches (and deletes)
            everything.
        """
        kept = [
            entry for entry in self._vectors
            if not self._matches_filter(entry.metadata, metadata_filter)
        ]
        deleted_count = len(self._vectors) - len(kept)

        if deleted_count:
            # Rebuild in one pass instead of deleting (and reindexing) per id.
            self._vectors[:] = kept
            self._id_to_index = {entry.id: position for position, entry in enumerate(kept)}
            self._matrix_dirty = True
            self._update_stats()

        return deleted_count

    def clear(self) -> None:
        """Remove all vectors from the store."""
        self._vectors.clear()
        self._id_to_index.clear()
        self._vector_matrix = None
        self._matrix_dirty = True
        self._update_stats()

    def get_stats(self) -> Dict[str, Any]:
        """Return a copy of the statistics plus derived values.

        Adds ``avg_search_time``, ``embedding_dimension`` and an approximate
        ``memory_usage_mb``. The memory figure counts the cached matrix and
        every stored vector's bytes, plus a rough 4 bytes per character of
        each metadata dict's ``str()`` form.
        """
        stats = self.stats.copy()

        searches = stats["searches_performed"]
        stats["avg_search_time"] = stats["total_search_time"] / searches if searches > 0 else 0

        memory_usage = 0
        if self._vector_matrix is not None:
            memory_usage += self._vector_matrix.nbytes

        for entry in self._vectors:
            memory_usage += entry.vector.nbytes
            memory_usage += len(str(entry.metadata)) * 4

        stats["memory_usage_mb"] = memory_usage / (1024 * 1024)
        stats["embedding_dimension"] = self.embedding_dim

        return stats

    def _update_stats(self) -> None:
        """Refresh the vector count and last-update timestamp."""
        self.stats["total_vectors"] = len(self._vectors)
        self.stats["last_update"] = time.time()

    def save_to_disk(self, filepath: Optional[str] = None) -> str:
        """Pickle the store's vectors, dimension and statistics to ``filepath``.

        Args:
            filepath: Destination path. Defaults to ``<cache_dir>/vector_store.pkl``,
                creating ``cache_dir`` if needed.

        Returns:
            The path written to.

        Raises:
            ResourceError: If the directory cannot be created or the file
                cannot be written.
        """
        try:
            if filepath is None:
                self.cache_dir.mkdir(parents=True, exist_ok=True)
                filepath = str(self.cache_dir / "vector_store.pkl")

            data = {
                "embedding_dim": self.embedding_dim,
                "vectors": [
                    {
                        "id": entry.id,
                        "vector": entry.vector,
                        "metadata": entry.metadata,
                        "timestamp": entry.timestamp,
                    }
                    for entry in self._vectors
                ],
                "stats": self.stats,
            }

            with open(filepath, "wb") as f:
                pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

        except Exception as e:
            raise ResourceError(f"Failed to save vector store: {str(e)}") from e

        print(f"Vector store saved to {filepath}")
        return filepath

    def load_from_disk(self, filepath: str) -> None:
        """Replace the store's contents with a file written by :meth:`save_to_disk`.

        The file is fully parsed before any state is replaced, so a failed
        load leaves the current contents untouched. Statistics missing from
        the file fall back to their defaults.

        SECURITY: this uses ``pickle.load``, which can execute arbitrary code
        embedded in the file. Only load files this application wrote itself.

        Raises:
            ResourceError: If the file cannot be read or has an unexpected layout.
        """
        try:
            with open(filepath, "rb") as f:
                data = pickle.load(f)

            embedding_dim = data.get("embedding_dim")
            vectors = [
                VectorEntry(
                    id=vector_data["id"],
                    vector=vector_data["vector"],
                    metadata=vector_data["metadata"],
                    timestamp=vector_data.get("timestamp", time.time()),
                )
                for vector_data in data.get("vectors", [])
            ]

            # Merge over defaults so later updates never hit a missing key.
            stats = self._default_stats()
            stats.update(data.get("stats") or {})

        except Exception as e:
            raise ResourceError(f"Failed to load vector store: {str(e)}") from e

        self._vectors = vectors
        self._id_to_index = {entry.id: position for position, entry in enumerate(vectors)}
        self._vector_matrix = None
        self._matrix_dirty = True
        self.embedding_dim = embedding_dim
        stats["total_vectors"] = len(vectors)
        self.stats = stats

        print(f"Vector store loaded from {filepath} with {len(self._vectors)} vectors")

    def get_document_chunks(self, source_filter: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return ``{"id", "content", "metadata"}`` dicts for stored chunks.

        Args:
            source_filter: If given, only entries whose ``metadata["source"]``
                equals it are returned.

        Metadata dicts are copies, matching :meth:`search` and :meth:`get_by_id`.
        """
        return [
            {
                "id": entry.id,
                "content": entry.metadata.get("content", ""),
                "metadata": entry.metadata.copy(),
            }
            for entry in self._vectors
            if source_filter is None or entry.metadata.get("source") == source_filter
        ]

    def optimize(self) -> Dict[str, Any]:
        """Ensure the search matrix is built and report what was done.

        Returns:
            ``optimization_time`` (seconds), ``total_vectors`` and
            ``matrix_rebuilt`` (whether the matrix actually needed rebuilding).
        """
        start_time = time.time()

        needed_rebuild = self._matrix_dirty
        self._build_vector_matrix()

        return {
            "optimization_time": time.time() - start_time,
            "total_vectors": len(self._vectors),
            "matrix_rebuilt": needed_rebuild,
        }