File size: 2,201 Bytes
3301b3c
 
6c61722
3301b3c
6c61722
33f4e34
3301b3c
 
6c61722
3301b3c
6c61722
 
 
 
 
 
c613bb1
6c61722
33f4e34
6c61722
33f4e34
6c61722
 
 
 
 
33f4e34
3301b3c
33f4e34
6c61722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3301b3c
6c61722
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
from typing import Any, Dict, List, Optional

import numpy as np

from src import RetrieverConfig, get_chroma_client, get_embedder, logger

class Retriever:
    """
    Retrieves documents from a ChromaDB collection.

    The collection is looked up (and created if missing) once, at construction
    time. Each :meth:`retrieve` call embeds the query with the shared embedder
    and runs a similarity query against that collection.
    """

    def __init__(self, collection_name: str, config: RetrieverConfig):
        """
        Connect to ChromaDB and open (or create) the named collection.

        Args:
            collection_name: Name of the Chroma collection to query; created
                via ``get_or_create_collection`` if it does not exist yet.
            config: Retriever settings. ``config.TOP_K`` supplies the default
                result count for :meth:`retrieve`.
        """
        self.collection_name = collection_name
        self.config = config
        self.client = get_chroma_client()
        self.embedder = get_embedder()
        self.collection = self.client.get_or_create_collection(name=self.collection_name)

    def retrieve(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Embed ``query`` and return the ``top_k`` most similar chunks from ChromaDB.

        Args:
            query: Text to embed and search for.
            top_k: Maximum number of chunks to return. ``None`` means use
                ``self.config.TOP_K``; a value <= 0 returns ``[]`` immediately.

        Returns:
            A list of chunk dicts in the order ChromaDB returned them. Each has
            ``'id'`` (the Chroma id) and ``'narration'`` (the stored document
            text), followed by every key of the chunk's metadata. An empty list
            is returned when the collection is empty, nothing matched, or any
            ChromaDB/embedding call fails (failures are logged, not raised).
        """
        if top_k is None:
            top_k = self.config.TOP_K
        if top_k <= 0:
            # Chroma does not accept a non-positive n_results, and there is
            # nothing to fetch anyway.
            return []

        try:
            # count() is a ChromaDB call as well, so it lives inside the try to
            # keep the "log and return [] on failure" contract for every DB call.
            doc_count = self.collection.count()
            if doc_count == 0:
                logger.warning(f"Chroma collection '{self.collection_name}' is empty. Cannot retrieve.")
                return []

            # 1. Embed the query: embed() receives a one-element batch and we
            # take its first (only) vector.
            query_embedding = self.embedder.embed([query])[0]

            # 2. Query ChromaDB. Never request more results than the collection
            # holds (depending on the chromadb version that raises or warns).
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=min(top_k, doc_count),
                include=["metadatas", "documents"],
            )

            # 3. Format results into chunks.
            # Chroma returns one inner list per query embedding, so take [0].
            if not results or not results.get('ids', [[]])[0]:
                return []

            ids = results['ids'][0]
            documents = results['documents'][0]
            metadatas = results['metadatas'][0]

            # Documents added without metadata come back with metadata None;
            # unpacking `**None` would raise TypeError and discard every hit.
            # NOTE(review): metadata keys named 'id' or 'narration' override the
            # values set here because they are unpacked last — confirm intended.
            return [
                {'id': doc_id, 'narration': document, **(metadata or {})}
                for doc_id, document, metadata in zip(ids, documents, metadatas)
            ]

        except Exception as e:
            logger.error(f"ChromaDB retrieval failed for collection '{self.collection_name}': {e}")
            return []