File size: 4,942 Bytes
fcb8b13
 
be398ac
 
fcb8b13
 
be398ac
 
fcb8b13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be398ac
fcb8b13
 
 
 
be398ac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os

from utils.logger import logger
from config.settings import settings
from uuid import uuid4

from typing import List, Tuple, Dict, Any
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, PointStruct, UpdateStatus

class VectorDBManager:
    def __init__(self, collection_name: str, embedding_dim: int, client: QdrantClient = None):
        logger.info(f"Initializing Qdrant VectorDBManager for collection: '{collection_name}'")
        
        if client:
            self.client = client
            logger.info("Using shared Qdrant client instance.")
        else:
            logger.warning("No shared Qdrant client provided. Creating a new local instance.")
            qdrant_db_path = os.path.join(settings.DATA_DIR, "qdrant_data")
            self.client = QdrantClient(path=qdrant_db_path)
        
        self.collection_name = collection_name
        self.embedding_dim = embedding_dim
        
        self.create_collection_if_not_exists()
        
    def create_collection_if_not_exists(self):
        try:
            collections = self.client.get_collections().collections
            collection_names = [collection.name for collection in collections]
            
            if self.collection_name not in collection_names:
                logger.info(f"Collection '{self.collection_name}' not found. Creating a new one...")
                
                self.client.recreate_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(
                        size=self.embedding_dim,
                        distance=Distance.COSINE
                    )
                )
                logger.success(f"Collection '{self.collection_name}' created successfully.")
            else:
                logger.info(f"Collection '{self.collection_name}' already exists.")
                
        except Exception as e:
            logger.error(f"Error checking or creating collection '{self.collection_name}': {e}")
            raise
        
    def add_vectors(self, embeddings: List[List[float]], metadatas: List[Dict[str, Any]]):
        if not embeddings:
            logger.warning("No embeddings to add. Skipping.")
            return
        
        if len(embeddings) != len(metadatas):
            logger.error("Number of embeddings and metadatas must match.")
            raise ValueError("Embeddings and metadatas count mismatch.")
        
        points_to_add = []
        for i, (embedding, metadata) in enumerate(zip(embeddings, metadatas)):
            point_id = str(uuid4())
                
            points_to_add.append(
                PointStruct(
                    id=point_id,
                    vector=embedding,
                    payload=metadata
                )
            )
            
        try:
            operation_info = self.client.upsert(
                collection_name=self.collection_name,
                wait=True,
                points=points_to_add
            )
            if operation_info.status == UpdateStatus.COMPLETED:
                logger.debug(f"Successfully upserted {len(points_to_add)} points to collection '{self.collection_name}'.")
            else:
                logger.warning(f"Upsert operation finished with status: {operation_info.status}")
        except Exception as e:
            logger.error(f"Error upserting points to collection '{self.collection_name}': {e}")
            
    def search_vectors(self, query_embedding: List[float], k: int = 5, filter_payload: Dict = None) -> List[Tuple[float, Dict[str, Any]]]:
        try:
            search_results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                query_filter=filter_payload,
                limit=k,
                with_payload=True, # include payload in return
                with_vectors=False # exclude vectors in return
            )
            
            formatted_results = []
            for scored_point in search_results:
                score = scored_point.score
                payload = scored_point.payload
                formatted_results.append((score, payload))
            
            logger.debug(f"Searched for top {k} neighbors. Found {len(formatted_results)} results.")
            return formatted_results
        except Exception as e:
            logger.error(f"Error searching in collection '{self.collection_name}': {e}")
            return []
        
    def get_total_vectors(self) -> int:
        try:
            count_result = self.client.count(
                collection_name=self.collection_name,
                exact=True
            )
            return count_result.count
        except Exception as e:
            logger.error(f"Error counting vectors in collection '{self.collection_name}': {e}")
            return 0