|
import os |
|
from typing import List, Dict, Any |
|
import numpy as np |
|
|
|
from src import RetrieverConfig, logger, get_chroma_client, get_embedder |
|
|
|
class Retriever:
    """Retrieves documents from a ChromaDB collection.

    The collection is looked up (or created) by name when the retriever is
    constructed, and queries are embedded with the object returned by
    ``get_embedder()`` before being sent to the collection.
    """

    def __init__(self, collection_name: str, config: "RetrieverConfig"):
        """Open the named collection and prepare the embedder.

        Args:
            collection_name: Name of the Chroma collection to query. The
                collection is created if it does not exist yet.
            config: Retriever settings; ``config.TOP_K`` supplies the default
                result count used by :meth:`retrieve`.
        """
        self.collection_name = collection_name
        self.config = config
        self.client = get_chroma_client()
        self.embedder = get_embedder()
        self.collection = self.client.get_or_create_collection(name=self.collection_name)

    def retrieve(self, query: str, top_k: int = None) -> List[Dict[str, Any]]:
        """Embed a query and return the most similar documents from ChromaDB.

        Args:
            query: Text to embed and search for.
            top_k: Maximum number of documents to return. Defaults to
                ``self.config.TOP_K`` when ``None``.

        Returns:
            One dict per hit with the keys ``'id'`` and ``'narration'`` (the
            stored document text) plus every key of the hit's metadata. A
            metadata key named ``'id'`` or ``'narration'`` overrides the
            built-in value. Returns an empty list when the collection is
            empty, when nothing matches, or when embedding/querying fails
            (the failure is logged, not raised).
        """
        if top_k is None:
            top_k = self.config.TOP_K

        available = self.collection.count()
        if available == 0:
            logger.warning(f"Chroma collection '{self.collection_name}' is empty. Cannot retrieve.")
            return []

        try:
            query_embedding = self.embedder.embed([query])[0]

            results = self.collection.query(
                query_embeddings=[query_embedding],
                # Never request more neighbours than the collection holds;
                # depending on the Chroma version that either raises or warns.
                n_results=min(top_k, available),
                include=["metadatas", "documents"],
            )
            if not results:
                return []

            # Results are batched per query; only one query was sent, so the
            # first (and only) batch is the one of interest.
            id_batches = results.get('ids') or [[]]
            ids = id_batches[0]
            if not ids:
                return []

            documents = results['documents'][0]
            metadatas = results['metadatas'][0]

            return [
                # A record stored without metadata comes back as None, which
                # cannot be unpacked with ``**``; treat it as an empty mapping
                # so one such record does not discard the whole result set.
                {'id': doc_id, 'narration': document, **(metadata or {})}
                for doc_id, document, metadata in zip(ids, documents, metadatas)
            ]

        except Exception as e:
            # Best-effort retrieval: callers get an empty result instead of
            # an exception when the embedder or the vector store fails.
            logger.error(f"ChromaDB retrieval failed for collection '{self.collection_name}': {e}")
            return []