from qdrant_client import QdrantClient  # main entry point for talking to the Qdrant server
from qdrant_client.http.models import (
    ScoredPoint,
    Filter,
    FieldCondition,
    MatchText,
)
from qdrant_client.models import (
    VectorParams,     # per-collection vector configuration (size, metric)
    Distance,         # similarity metric enum
    PointStruct,      # record (id + vector + payload) stored in the db
    TextIndexParams,
    TokenizerType,
)

import numpy as np
import re
import time
from uuid import UUID

from fastapi import HTTPException

from app.core.models import Embedder
from app.core.chunks import Chunk
from app.settings import settings


class VectorDatabase:
    """Wrapper around a Qdrant instance: collection management, near-duplicate-aware
    ingestion of :class:`Chunk` objects, and a combined dense + keyword-filtered search.
    """

    def __init__(self, embedder: Embedder, host: str = "qdrant", port: int = 6333):
        """Connect to Qdrant (with retries) and keep the embedder used for queries.

        NOTE(review): ``host``/``port`` are stored/ignored respectively — the client is
        actually configured from ``settings.qdrant``; confirm the parameters are still needed.
        """
        self.host: str = host
        self.client: QdrantClient = self._initialize_qdrant_client()
        self.embedder: Embedder = embedder  # converts user queries / chunk text to vectors
        # Empty (0, dim) matrix; kept for backward compatibility with external users.
        self.already_stored: np.ndarray = np.array([]).reshape(
            0, embedder.get_vector_dimensionality()
        )

    def store(
        self, collection_name: str, chunks: list[Chunk], batch_size: int = 1000
    ) -> None:
        """Embed ``chunks`` and upsert them in batches, skipping near-duplicates.

        Each vector is checked against the collection via :meth:`accept_vector`
        (one query per chunk) before being queued for upsert.
        """
        points: list[PointStruct] = []
        vectors = self.embedder.encode([chunk.get_raw_text() for chunk in chunks])
        for vector, chunk in zip(vectors, chunks):
            if self.accept_vector(collection_name, vector):
                points.append(
                    PointStruct(
                        id=str(chunk.id),
                        vector=vector,
                        payload={
                            "metadata": chunk.get_metadata(),
                            "text": chunk.get_raw_text(),
                        },
                    )
                )
        if points:
            for start in range(0, len(points), batch_size):
                self.client.upsert(
                    collection_name=collection_name,
                    points=points[start : start + batch_size],
                    wait=False,  # fire-and-forget: don't block on indexing
                )

    def cosine_similarity(self, vec1, vec2):
        """Return the cosine of the angle between two vectors.

        NOTE: a zero-norm input yields ``nan`` (division by zero), which makes the
        caller's distance check fail and the vector be accepted.
        """
        vec1_np = np.array(vec1)
        vec2_np = np.array(vec2)
        return vec1_np @ vec2_np / (np.linalg.norm(vec1_np) * np.linalg.norm(vec2_np))

    def accept_vector(self, collection_name: str, vector: np.ndarray) -> bool:
        """Decide whether ``vector`` should be stored.

        Rejects the vector when its cosine distance to the nearest stored vector is
        below ``settings.max_delta`` (i.e. it is a near-duplicate).
        """
        most_similar = self.client.query_points(
            collection_name=collection_name, query=vector, limit=1, with_vectors=True
        ).points
        if not most_similar:
            return True  # empty collection: nothing to compare against
        nearest = most_similar[0]
        if 1 - self.cosine_similarity(vector, nearest.vector) < settings.max_delta:
            return False
        return True

    def construct_keywords_list(self, query: str) -> list[FieldCondition]:
        """Extract ALL-CAPS tokens (acronyms etc.) from ``query`` as text-match filters.

        Token length is capped at 30 to match the text index configuration in
        :meth:`_create_collection`.
        """
        keywords = re.findall(r"\b[A-Z]{2,}\b", query)
        # The regex already guarantees len >= 2; only the upper bound needs checking.
        return [
            FieldCondition(key="text", match=MatchText(text=word))
            for word in keywords
            if len(word) <= 30
        ]

    def search(self, collection_name: str, query: str, top_k: int = 5) -> list[Chunk]:
        """Run a dense search plus a keyword-filtered search and merge the results.

        According to tests, the re-ranker needs ~7-10 chunks to generate the most
        accurate hit. NOTE: int truncation of the 70/30 split can return fewer than
        ``top_k`` results (e.g. top_k=5 -> 3 + 1).
        TODO: implement true hybrid (sparse-vector) search.
        """
        query_embedded: np.ndarray = self.embedder.encode(query)
        if isinstance(query_embedded, list):
            query_embedded = query_embedded[0]
        keywords = self.construct_keywords_list(query)
        dense_result: list[ScoredPoint] = self.client.query_points(
            collection_name=collection_name,
            query=query_embedded,
            limit=int(top_k * 0.7),
        ).points
        sparse_result: list[ScoredPoint] = self.client.query_points(
            collection_name=collection_name,
            query=query_embedded,
            limit=int(top_k * 0.3),
            query_filter=Filter(should=keywords),
        ).points
        # The two result sets can overlap; keep the first occurrence of each point.
        seen_ids: set = set()
        combined: list[ScoredPoint] = []
        for point in [*dense_result, *sparse_result]:
            if point.id not in seen_ids:
                seen_ids.add(point.id)
                combined.append(point)
        return [
            Chunk(
                id=UUID(point.payload.get("metadata", {}).get("id", "")),
                filename=point.payload.get("metadata", {}).get("filename", ""),
                page_number=point.payload.get("metadata", {}).get("page_number", 0),
                start_index=point.payload.get("metadata", {}).get("start_index", 0),
                start_line=point.payload.get("metadata", {}).get("start_line", 0),
                end_line=point.payload.get("metadata", {}).get("end_line", 0),
                text=point.payload.get("text", ""),
            )
            for point in combined
        ]

    def _initialize_qdrant_client(self, max_retries=5, delay=2) -> QdrantClient:
        """Connect to Qdrant with exponential backoff.

        Raises:
            HTTPException: 500 after ``max_retries`` failed attempts.
        """
        for attempt in range(max_retries):
            try:
                client = QdrantClient(**settings.qdrant.model_dump())
                client.get_collections()  # cheap liveness probe
                return client
            except Exception as e:
                if attempt == max_retries - 1:
                    raise HTTPException(
                        500,
                        f"Failed to connect to Qdrant server after {max_retries} attempts. "
                        f"Last error: {str(e)}",
                    )
                print(
                    f"Connection attempt {attempt + 1} out of {max_retries} failed. "
                    f"Retrying in {delay} seconds..."
                )
                time.sleep(delay)
                delay *= 2  # exponential backoff

    def _check_collection_exists(self, collection_name: str) -> bool:
        """Return whether ``collection_name`` exists; wrap transport errors in a 500."""
        try:
            return self.client.collection_exists(collection_name)
        except Exception as e:
            raise HTTPException(
                500,
                f"Failed to check collection {collection_name} exists. Last error: {str(e)}",
            )

    def _create_collection(self, collection_name: str) -> None:
        """Create a cosine-metric collection plus a word-tokenized text index on 'text'."""
        try:
            self.client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=self.embedder.get_vector_dimensionality(),
                    distance=Distance.COSINE,
                ),
            )
            self.client.create_payload_index(
                collection_name=collection_name,
                field_name="text",
                field_schema=TextIndexParams(
                    type="text",
                    tokenizer=TokenizerType.WORD,
                    min_token_len=2,
                    max_token_len=30,
                    lowercase=True,
                ),
            )
        except Exception as e:
            # BUG FIX: was `self.collection_name`, an attribute that never exists,
            # so the error path itself crashed with AttributeError.
            raise HTTPException(
                500, f"Failed to create collection {collection_name}: {str(e)}"
            )

    def create_collection(self, collection_name: str) -> None:
        """Idempotently create ``collection_name`` (no-op if it already exists)."""
        try:
            if self._check_collection_exists(collection_name):
                return
            self._create_collection(collection_name)
        except HTTPException:
            # Helpers already raise a well-formed HTTPException; don't double-wrap it.
            raise
        except Exception as e:
            print(e)
            # Detail must be a string, not the exception object.
            raise HTTPException(500, str(e))

    def __del__(self):
        # Guard: __init__ may have failed before `client` was assigned.
        if hasattr(self, "client"):
            self.client.close()

    def get_collections(self) -> list[str]:
        """Return the names of all collections, honoring the declared return type."""
        try:
            # BUG FIX: previously returned the raw CollectionsResponse object,
            # contradicting the `list[str]` annotation.
            return [c.name for c in self.client.get_collections().collections]
        except Exception as e:
            print(e)
            raise HTTPException(500, "Failed to get collection names")