Spaces:
Running
Running
# core/retrieval/retriever.py
import os | |
from utils.logger import logger | |
from config.settings import settings | |
from typing import List, Dict, Any, Union | |
from qdrant_client import QdrantClient | |
from core.embeddings.text_embedding_model import TextEmbeddingModel | |
from core.embeddings.image_embedding_model import ImageEmbeddingModel | |
from core.embeddings.audio_embedding_model import AudioEmbeddingModel | |
from core.retrieval.vector_db_manager import VectorDBManager | |
class Retriever:
    """Route a text, image, or audio query to its embedding model and vector collection.

    One embedding model and one ``VectorDBManager`` are held per modality; all
    three managers share the Qdrant client supplied by the caller.
    """

    # Vector sizes used when creating/connecting the image and audio collections.
    # NOTE(review): presumably these match the output size of the image/audio
    # embedding models -- confirm before changing either model.
    IMAGE_EMBEDDING_DIM = 512
    AUDIO_EMBEDDING_DIM = 512

    def __init__(self, client: "QdrantClient"):
        """Build the embedding models and connect one collection per modality.

        Args:
            client: Qdrant client shared by the text, image and audio
                ``VectorDBManager`` instances.
        """
        logger.info("Initializing the Retriever...")

        # Initialize embedding models
        self.text_embedder = TextEmbeddingModel()
        self.image_embedder = ImageEmbeddingModel()
        self.audio_embedder = AudioEmbeddingModel()
        logger.info("Embedding models initialized.")

        # The path is only used for the log line below; the client itself is
        # created by the caller and injected.
        qdrant_db_path = os.path.join(settings.DATA_DIR, "qdrant_data")
        self.client = client
        logger.info(f"Single Qdrant client initialized, connected to: {qdrant_db_path}")

        # Initialize vector database managers. The text dimension is read from
        # the loaded text model; image/audio use the class-level constants.
        text_dim = self.text_embedder.model.get_sentence_embedding_dimension()
        self.text_db_manager = VectorDBManager(
            collection_name="text_collection",
            embedding_dim=text_dim,
            client=self.client,
        )
        self.image_db_manager = VectorDBManager(
            collection_name="image_collection",
            embedding_dim=self.IMAGE_EMBEDDING_DIM,
            client=self.client,
        )
        self.audio_db_manager = VectorDBManager(
            collection_name="audio_collection",
            embedding_dim=self.AUDIO_EMBEDDING_DIM,
            client=self.client,
        )
        logger.info("VectorDB Managers connected to Qdrant collections.")
        logger.info(f"Text collection ('{self.text_db_manager.collection_name}') contains {self.text_db_manager.get_total_vectors()} vectors.")
        logger.info(f"Image collection ('{self.image_db_manager.collection_name}') contains {self.image_db_manager.get_total_vectors()} vectors.")
        logger.info(f"Audio collection ('{self.audio_db_manager.collection_name}') contains {self.audio_db_manager.get_total_vectors()} vectors.")

    def retrieve(self, query: Union[str, bytes], query_type: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Embed ``query`` and return the ``top_k`` nearest stored items.

        Args:
            query: The query text when ``query_type`` is ``"text"``; otherwise
                a path to an existing image/audio file.
            query_type: One of ``"text"``, ``"image"`` or ``"audio"``.
            top_k: Number of neighbours requested from the vector search.

        Returns:
            A list of dicts with keys ``"score"``, ``"metadata"`` and
            ``"content"``. Every failure (unsupported type, invalid query,
            embedding or search error) is logged and yields an empty list;
            this method does not raise for those cases.
        """
        logger.info(f"Received retrieval request. Query type: '{query_type}', Top K: {top_k}")
        embedding = None
        db_manager_to_use = None

        # create embeddings
        try:
            if query_type == "text":
                if not isinstance(query, str):
                    raise TypeError("Text query must be a string.")
                # The text embedder is given the bare string, unlike the
                # image/audio embedders below which take a list of paths.
                embedding = self.text_embedder.get_embeddings(query)
                db_manager_to_use = self.text_db_manager
            elif query_type == "image":
                if not isinstance(query, str) or not os.path.exists(query):
                    raise TypeError("Image query must be a valid file path.")
                embedding = self.image_embedder.get_embeddings([query])[0]
                db_manager_to_use = self.image_db_manager
            elif query_type == "audio":
                if not isinstance(query, str) or not os.path.exists(query):
                    raise TypeError("Audio query must be a valid file path.")
                embedding = self.audio_embedder.get_embeddings([query])[0]
                db_manager_to_use = self.audio_db_manager
            else:
                logger.error(f"Unsupported query type: {query_type}")
                return []
        except Exception as e:
            # Boundary handler: callers get an empty result instead of an exception.
            logger.error(f"Error generating embedding for query: {e}")
            return []

        if embedding is None:
            logger.warning("Could not generate embedding for the query.")
            return []

        # searching vectors
        try:
            search_results = db_manager_to_use.search_vectors(embedding, k=top_k)
        except Exception as e:
            logger.error(f"Error searching in vector database: {e}")
            return []

        formatted_results = self._format_results(search_results)
        logger.info(f"Retrieval complete. Found {len(formatted_results)} results.")
        return formatted_results

    @staticmethod
    def _format_results(search_results) -> List[Dict[str, Any]]:
        """Convert ``(score, payload)`` pairs into the public result dicts.

        A missing payload, or a payload lacking ``"metadata"``/``"content"``,
        no longer raises: metadata falls back to ``{}`` and content to
        ``None`` so one malformed point cannot abort the whole retrieval.
        """
        formatted = []
        for score, payload in search_results or []:
            payload = payload or {}
            formatted.append({
                "score": score,
                "metadata": payload.get("metadata", {}),
                "content": payload.get("content"),
            })
        return formatted

    def is_database_empty(self) -> bool:
        """Return ``True`` when the text, image and audio collections hold no vectors."""
        managers = (self.text_db_manager, self.image_db_manager, self.audio_db_manager)
        total_vectors = sum(manager.get_total_vectors() for manager in managers)
        return total_vectors == 0