# core/retrieval/retriever.py
import os
from utils.logger import logger
from config.settings import settings
from typing import List, Dict, Any, Tuple, Union
from qdrant_client import QdrantClient
from core.embeddings.text_embedding_model import TextEmbeddingModel
from core.embeddings.image_embedding_model import ImageEmbeddingModel
from core.embeddings.audio_embedding_model import AudioEmbeddingModel
from core.retrieval.vector_db_manager import VectorDBManager


class Retriever:
    """Embed a query with a modality-specific model and search its Qdrant collection.

    A single injected ``QdrantClient`` backs three collections:
    ``text_collection``, ``image_collection`` and ``audio_collection``, each
    wrapped by its own ``VectorDBManager``.
    """

    _SUPPORTED_QUERY_TYPES = ("text", "image", "audio")

    def __init__(
        self,
        client: QdrantClient,
        *,
        image_dim: int = 512,
        audio_dim: int = 512,
    ):
        """Create the embedding models and bind one VectorDBManager per modality.

        Args:
            client: An already-constructed Qdrant client shared by all three
                collections.
            image_dim: Vector size passed to the image collection's manager.
                Defaults to 512, the value this class previously hard-coded.
                NOTE(review): must match ImageEmbeddingModel's output size.
            audio_dim: Vector size passed to the audio collection's manager.
                Defaults to 512. NOTE(review): must match AudioEmbeddingModel's
                output size.
        """
        logger.info("Initializing the Retriever...")

        # Initialize embedding models
        self.text_embedder = TextEmbeddingModel()
        self.image_embedder = ImageEmbeddingModel()
        self.audio_embedder = AudioEmbeddingModel()
        logger.info("Embedding models initialized.")

        # The client is injected, so this path is informational only: it is the
        # location derived from settings, not something read back from `client`.
        qdrant_db_path = os.path.join(settings.DATA_DIR, "qdrant_data")
        self.client = client
        logger.info(f"Using shared Qdrant client; configured data path: {qdrant_db_path}")

        # The text dimension comes from the model; image/audio use the configured sizes.
        text_dim = self.text_embedder.model.get_sentence_embedding_dimension()
        self.text_db_manager = VectorDBManager(
            collection_name="text_collection", embedding_dim=text_dim, client=self.client
        )
        self.image_db_manager = VectorDBManager(
            collection_name="image_collection", embedding_dim=image_dim, client=self.client
        )
        self.audio_db_manager = VectorDBManager(
            collection_name="audio_collection", embedding_dim=audio_dim, client=self.client
        )
        logger.info("VectorDB Managers connected to Qdrant collections.")

        for label, manager in (
            ("Text", self.text_db_manager),
            ("Image", self.image_db_manager),
            ("Audio", self.audio_db_manager),
        ):
            logger.info(
                f"{label} collection ('{manager.collection_name}') contains "
                f"{manager.get_total_vectors()} vectors."
            )

    def _db_managers(self) -> Tuple[VectorDBManager, VectorDBManager, VectorDBManager]:
        """Return the text, image and audio managers, in that order."""
        return self.text_db_manager, self.image_db_manager, self.audio_db_manager

    def _embed_query(self, query: Union[str, bytes], query_type: str) -> Tuple[Any, VectorDBManager]:
        """Validate *query* for *query_type* and return ``(embedding, manager)``.

        Raises:
            TypeError: *query* is not a ``str``.
            FileNotFoundError: an image/audio *query* does not name an existing file.
            ValueError: *query_type* is not one of the supported types.
        """
        if query_type == "text":
            if not isinstance(query, str):
                raise TypeError("Text query must be a string.")
            return self.text_embedder.get_embeddings(query), self.text_db_manager

        if query_type == "image":
            embedder, manager, label = self.image_embedder, self.image_db_manager, "Image"
        elif query_type == "audio":
            embedder, manager, label = self.audio_embedder, self.audio_db_manager, "Audio"
        else:
            raise ValueError(f"Unsupported query type: {query_type}")

        # Image and audio queries are paths to local files. isfile (not exists)
        # so that a directory path is rejected up front.
        if not isinstance(query, str):
            raise TypeError(f"{label} query must be a valid file path.")
        if not os.path.isfile(query):
            raise FileNotFoundError(f"{label} query must be a valid file path.")
        # These embedders take a batch (list) of paths; keep the single result.
        return embedder.get_embeddings([query])[0], manager

    def retrieve(self, query: Union[str, bytes], query_type: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Embed *query* and return up to *top_k* hits from the matching collection.

        Args:
            query: For ``"text"``, the query string. For ``"image"``/``"audio"``,
                a path to an existing file. ``bytes`` appear in the annotation
                but every branch requires ``str``, so a bytes query is rejected.
            query_type: One of ``"text"``, ``"image"``, ``"audio"``.
            top_k: Maximum number of hits to request; must be >= 1.

        Returns:
            A list of ``{"score", "metadata", "content"}`` dicts in the order
            returned by ``VectorDBManager.search_vectors``. Every failure
            (unsupported type, invalid query, embedding or search error) is
            logged and yields ``[]`` rather than raising. Hits whose payload
            lacks ``metadata``/``content`` are skipped with a warning.
        """
        logger.info(f"Received retrieval request. Query type: '{query_type}', Top K: {top_k}")

        if query_type not in self._SUPPORTED_QUERY_TYPES:
            logger.error(f"Unsupported query type: {query_type}")
            return []
        if top_k < 1:
            # Rejected before the (comparatively expensive) embedding step.
            logger.warning(f"top_k must be a positive integer; got {top_k}.")
            return []

        # create embeddings
        try:
            embedding, db_manager_to_use = self._embed_query(query, query_type)
        except Exception as e:
            logger.error(f"Error generating embedding for query: {e}")
            return []

        if embedding is None:
            logger.warning("Could not generate embedding for the query.")
            return []

        # searching vectors
        try:
            search_results = db_manager_to_use.search_vectors(embedding, k=top_k)
        except Exception as e:
            logger.error(f"Error searching in vector database: {e}")
            return []

        formatted_results: List[Dict[str, Any]] = []
        for score, payload in search_results:
            try:
                metadata, content = payload["metadata"], payload["content"]
            except (KeyError, TypeError):
                # Missing key, or a None payload: skip this hit instead of
                # crashing, consistent with the log-and-degrade handling above.
                logger.warning(f"Skipping search hit with malformed payload (score={score}).")
                continue
            formatted_results.append({"score": score, "metadata": metadata, "content": content})

        logger.info(f"Retrieval complete. Found {len(formatted_results)} results.")
        return formatted_results

    def is_database_empty(self) -> bool:
        """Return True when the text, image and audio collections hold zero vectors in total."""
        return sum(manager.get_total_vectors() for manager in self._db_managers()) == 0