Spaces:
Sleeping
Sleeping
from qdrant_client import QdrantClient # main component to provide the access to db | |
from qdrant_client.http.models import ScoredPoint | |
from qdrant_client.models import VectorParams, Distance, \ | |
PointStruct # VectorParams -> config of vectors that will be used as primary keys | |
from app.models import Embedder # Distance -> defines the metric | |
from app.chunks import Chunk # PointStruct -> instance that will be stored in db | |
import numpy as np | |
from uuid import UUID | |
from app.settings import qdrant_client_config, max_delta | |
import time | |
# TODO: all documents are currently saved to a single collection; if a user wants references only from their own documents, temporary per-user storage will be needed.
class VectorDatabase:
    """Qdrant-backed store for document chunks with near-duplicate filtering.

    Chunks are embedded with ``embedder`` and upserted into the
    ``document_chunks`` collection, which this class creates with cosine
    distance if it does not exist. A chunk is skipped when its cosine distance
    to the most similar stored vector is below ``max_delta``. It is also
    skipped when it is that close to another vector accepted earlier in the
    same ``store`` call. ``max_delta`` comes from ``app.settings``.

    NOTE(review): the connection is configured entirely by
    ``qdrant_client_config``. ``host`` and ``port`` are only recorded on the
    instance and are not used to connect.
    """

    def __init__(self, embedder: Embedder, host: str = "qdrant", port: int = 6333):
        self.host: str = host
        self.port: int = port  # recorded only; see the class NOTE
        self.client: QdrantClient = self._initialize_qdrant_client()
        self.collection_name: str = "document_chunks"
        # Embeds chunk texts in store() and user queries in search().
        self.embedder: Embedder = embedder
        # Empty (0, dim) matrix meant to hold already-normalized vectors.
        # Kept for backward compatibility; nothing in this class reads it.
        self.already_stored: np.ndarray = np.array([]).reshape(
            0, embedder.get_vector_dimensionality()
        )
        if not self._check_collection_exists():
            self._create_collection()

    def store(self, chunks: list[Chunk], batch_size: int = 1000, wait: bool = False) -> None:
        """Embed ``chunks`` and upsert the non-duplicate ones in batches.

        Args:
            chunks: Chunks to store. An empty list is a no-op.
            batch_size: Maximum number of points per upsert request; must be >= 1.
            wait: Forwarded to ``QdrantClient.upsert``. With the default
                ``False`` the request is not awaited, so an immediately
                following ``store``/``search`` may not see these points yet.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if not chunks:
            return

        texts = [chunk.get_raw_text() for chunk in chunks]
        vectors = self.embedder.encode(texts)

        points: list[PointStruct] = []
        # Vectors accepted during this call. The database cannot see them yet
        # because nothing is upserted until every chunk has been checked.
        accepted: list[np.ndarray] = []
        for vector, chunk, text in zip(vectors, chunks, texts):
            vector = np.asarray(vector, dtype=float)
            # Cheap in-memory check first, then one query against the database.
            if any(self._is_near_duplicate(vector, other) for other in accepted):
                continue
            if not self.accept_vector(vector):
                continue
            accepted.append(vector)
            points.append(PointStruct(
                id=str(chunk.id),
                vector=vector.tolist(),  # plain list of floats
                payload={"metadata": chunk.get_metadata(), "text": text},
            ))

        for start in range(0, len(points), batch_size):
            self.client.upsert(
                collection_name=self.collection_name,
                points=points[start:start + batch_size],
                wait=wait,
            )

    def cosine_similarity(self, vec1, vec2):
        """Return the cosine of the angle between two vectors.

        Inputs may be NumPy arrays or plain sequences; Qdrant returns stored
        vectors as lists. The result is NaN if either vector has zero norm.
        """
        a = np.asarray(vec1, dtype=float)
        b = np.asarray(vec2, dtype=float)
        return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

    def _is_near_duplicate(self, vec1, vec2) -> bool:
        """Return True when the cosine distance between the vectors is below ``max_delta``.

        A NaN similarity (zero-norm input) compares False, so such vectors are
        never treated as duplicates.
        """
        return bool(1 - self.cosine_similarity(vec1, vec2) < max_delta)

    def accept_vector(self, vector: np.ndarray) -> bool:
        """Return whether ``vector`` should be stored.

        Queries the collection for the single most similar stored point and
        rejects ``vector`` when its cosine distance to that point is below
        ``max_delta``. When the collection returns no points, the vector is
        accepted.
        """
        hits = self.client.query_points(
            collection_name=self.collection_name,
            query=vector,
            limit=1,
            with_vectors=True,
        ).points
        if not hits:
            return True
        return not self._is_near_duplicate(vector, hits[0].vector)

    def search(self, query: str, top_k: int = 5) -> list[Chunk]:
        """Return up to ``top_k`` chunks most similar to ``query``.

        According to tests, the re-ranker needs ~7-10 chunks to produce the
        most accurate hit. Callers feeding a re-ranker may therefore want
        ``top_k`` above the default of 5.

        TODO: implement hybrid search.
        """
        query_embedded: np.ndarray = self.embedder.encode(query)
        points: list[ScoredPoint] = self.client.query_points(
            collection_name=self.collection_name,
            query=query_embedded,
            limit=top_k,
        ).points
        return [self._point_to_chunk(point) for point in points]

    @staticmethod
    def _point_to_chunk(point: ScoredPoint) -> Chunk:
        """Rebuild a Chunk from a point's ``metadata``/``text`` payload.

        The chunk id comes from ``metadata["id"]``. If that is missing, it
        falls back to the point id, which ``store`` sets to ``str(chunk.id)``.
        """
        payload = point.payload or {}
        metadata = payload.get("metadata") or {}
        return Chunk(
            id=UUID(str(metadata.get("id") or point.id)),
            filename=metadata.get("filename", ""),
            page_number=metadata.get("page_number", 0),
            start_index=metadata.get("start_index", 0),
            start_line=metadata.get("start_line", 0),
            end_line=metadata.get("end_line", 0),
            text=payload.get("text", ""),
        )

    def _initialize_qdrant_client(self, max_retries=5, delay=2) -> QdrantClient:
        """Connect to Qdrant using ``qdrant_client_config``, retrying with backoff.

        Each attempt creates a client and probes it with ``get_collections()``.
        After each failed attempt the wait doubles (``delay``, ``2*delay``, ...).

        Raises:
            ValueError: If ``max_retries`` is smaller than 1.
            ConnectionError: If every attempt fails; chained to the last error.
        """
        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")
        for attempt in range(1, max_retries + 1):
            client = None
            try:
                client = QdrantClient(**qdrant_client_config)
                client.get_collections()
                return client
            except Exception as e:
                if client is not None:
                    # Don't leak a client whose connectivity probe failed.
                    try:
                        client.close()
                    except Exception:
                        pass  # best effort; the original error is what matters
                if attempt == max_retries:
                    raise ConnectionError(
                        f"Failed to connect to Qdrant server after {max_retries} attempts. "
                        f"Last error: {str(e)}"
                    ) from e
                print(f"Connection attempt {attempt} out of {max_retries} failed. "
                      f"Retrying in {delay} seconds...")
                time.sleep(delay)
                delay *= 2

    def _check_collection_exists(self) -> bool:
        """Return whether the collection exists; wrap any client error in ConnectionError."""
        try:
            return self.client.collection_exists(self.collection_name)
        except Exception as e:
            raise ConnectionError(
                f"Failed to check collection {self.collection_name} exists. Last error: {str(e)}"
            ) from e

    def _create_collection(self) -> None:
        """Create the collection with cosine distance and the embedder's dimensionality.

        Raises:
            RuntimeError: If the client fails to create the collection.
        """
        try:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.embedder.get_vector_dimensionality(),
                    distance=Distance.COSINE,
                ),
            )
        except Exception as e:
            raise RuntimeError(f"Failed to create collection {self.collection_name}: {str(e)}") from e

    def __del__(self):
        """Close the Qdrant client if one was created (best effort)."""
        client = getattr(self, "client", None)
        if client is not None:
            try:
                client.close()
            except Exception:
                # Finalizers may run during interpreter shutdown, when the
                # client's dependencies may already be torn down.
                pass