Spaces:
Running
Running
import os
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import uuid4

from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, Filter, PointStruct, UpdateStatus, VectorParams

from utils.logger import logger
from config.settings import settings
class VectorDBManager:
    """Manage a single Qdrant collection: creation, upserts, search and counting.

    The collection is created (cosine distance) on construction if it does not
    already exist.
    """

    def __init__(self, collection_name: str, embedding_dim: int, client: Optional[QdrantClient] = None):
        """Bind the manager to ``collection_name`` and make sure the collection exists.

        Args:
            collection_name: Name of the Qdrant collection to manage.
            embedding_dim: Vector size used when the collection has to be created.
                It is not compared against an already-existing collection.
            client: Optional shared client. When omitted, a new local client is
                created at ``os.path.join(settings.DATA_DIR, "qdrant_data")``.

        Raises:
            Exception: Any error raised by the client while listing or creating
                collections (logged, then re-raised).
        """
        logger.info(f"Initializing Qdrant VectorDBManager for collection: '{collection_name}'")
        if client is not None:
            self.client = client
            logger.info("Using shared Qdrant client instance.")
        else:
            logger.warning("No shared Qdrant client provided. Creating a new local instance.")
            qdrant_db_path = os.path.join(settings.DATA_DIR, "qdrant_data")
            # NOTE(review): an on-disk client likely holds this path exclusively, so
            # several managers should share one client — confirm with qdrant-client docs.
            self.client = QdrantClient(path=qdrant_db_path)
        self.collection_name = collection_name
        self.embedding_dim = embedding_dim
        self.create_collection_if_not_exists()

    def create_collection_if_not_exists(self):
        """Create the managed collection with cosine distance if it is missing.

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            existing_names = {collection.name for collection in self.client.get_collections().collections}
            if self.collection_name not in existing_names:
                logger.info(f"Collection '{self.collection_name}' not found. Creating a new one...")
                # Use create_collection rather than recreate_collection: the latter first
                # deletes any collection with this name, so if one appeared between the
                # check above and this call its data would be wiped. recreate_collection
                # is also deprecated in recent qdrant-client releases.
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(
                        size=self.embedding_dim,
                        distance=Distance.COSINE,
                    ),
                )
                logger.success(f"Collection '{self.collection_name}' created successfully.")
            else:
                logger.info(f"Collection '{self.collection_name}' already exists.")
        except Exception as e:
            logger.error(f"Error checking or creating collection '{self.collection_name}': {e}")
            raise

    def add_vectors(self, embeddings: List[List[float]], metadatas: List[Dict[str, Any]]):
        """Upsert each embedding as a new point carrying its metadata as payload.

        Every point gets a fresh random UUID, so re-adding the same data stores
        duplicates rather than overwriting existing points.

        Args:
            embeddings: One vector per point. An empty list is a logged no-op.
            metadatas: Payload for each vector, aligned index-by-index with
                ``embeddings``.

        Raises:
            ValueError: If ``embeddings`` and ``metadatas`` differ in length.

        Errors raised by the client during the upsert are logged, not propagated.
        """
        if not embeddings:
            logger.warning("No embeddings to add. Skipping.")
            return
        if len(embeddings) != len(metadatas):
            logger.error("Number of embeddings and metadatas must match.")
            raise ValueError("Embeddings and metadatas count mismatch.")
        points_to_add = [
            PointStruct(id=str(uuid4()), vector=embedding, payload=metadata)
            for embedding, metadata in zip(embeddings, metadatas)
        ]
        try:
            # wait=True: the call returns only after the upsert has been applied.
            operation_info = self.client.upsert(
                collection_name=self.collection_name,
                wait=True,
                points=points_to_add,
            )
            if operation_info.status == UpdateStatus.COMPLETED:
                logger.debug(f"Successfully upserted {len(points_to_add)} points to collection '{self.collection_name}'.")
            else:
                logger.warning(f"Upsert operation finished with status: {operation_info.status}")
        except Exception as e:
            logger.error(f"Error upserting points to collection '{self.collection_name}': {e}")

    def search_vectors(
        self,
        query_embedding: List[float],
        k: int = 5,
        filter_payload: Optional[Union[Dict[str, Any], Filter]] = None,
    ) -> List[Tuple[float, Dict[str, Any]]]:
        """Return up to ``k`` nearest points as ``(score, payload)`` tuples.

        Args:
            query_embedding: Query vector.
            k: Maximum number of results.
            filter_payload: Optional Qdrant filter, either a ``Filter`` model or a
                dict in Qdrant's filter schema (e.g. ``{"must": [...]}``).

        Returns:
            A list of ``(score, payload)`` tuples in the order returned by the
            client. Stored vectors are not included. On any error the problem is
            logged and an empty list is returned.
        """
        try:
            query_filter = filter_payload
            if isinstance(filter_payload, dict):
                # Validate plain dicts into a Filter model. The local (path=...) client
                # expects model attributes such as `.must` and does not accept a raw
                # dict. A Filter instance behaves the same on local and remote clients.
                query_filter = Filter(**filter_payload)
            search_results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                query_filter=query_filter,
                limit=k,
                with_payload=True,  # include payload in return
                with_vectors=False,  # exclude vectors in return
            )
            formatted_results = [(scored_point.score, scored_point.payload) for scored_point in search_results]
            logger.debug(f"Searched for top {k} neighbors. Found {len(formatted_results)} results.")
            return formatted_results
        except Exception as e:
            logger.error(f"Error searching in collection '{self.collection_name}': {e}")
            return []

    def get_total_vectors(self) -> int:
        """Return the exact number of points in the collection, or 0 if counting fails."""
        try:
            count_result = self.client.count(
                collection_name=self.collection_name,
                exact=True,
            )
            return count_result.count
        except Exception as e:
            logger.error(f"Error counting vectors in collection '{self.collection_name}': {e}")
            return 0