from __future__ import annotations

import json
import os
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import ConfigDict, Field

DEFAULT_PERSISTENCE_DIRECTORY = "./bm25s_index"
CORPUS_PERSISTENCE_FILE = "corpus.jsonl"


class BM25SRetriever(BaseRetriever):
    """`BM25` retriever backed by the `bm25s` library, with index persistence."""

    vectorizer: Any
    """BM25S vectorizer instance."""
    docs: List[Document] = Field(repr=False)
    """List of documents to retrieve from."""
    k: int = 4
    """Number of top results to return."""
    activate_numba: bool = False
    """Whether to use bm25s's numba backend to accelerate retrieval."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        bm25_params: Optional[Dict[str, Any]] = None,
        stopwords: Union[str, List[str]] = "en",
        stemmer: Optional[Callable[[List[str]], List[str]]] = None,
        persist_directory: Optional[str] = None,
        **kwargs: Any,
    ) -> BM25SRetriever:
        """
        Create a BM25SRetriever from a list of texts.

        Args:
            texts:
                A list of texts to vectorize.
            metadatas:
                A list of metadata dicts to associate with each text.
            bm25_params:
                Parameters to pass to the BM25s vectorizer.
            stopwords:
                The list of stopwords to remove from the text. Defaults to "en".
            stemmer:
                The stemmer to use for stemming the tokens. The PyStemmer
                library is recommended, but any callable that takes a list of
                strings and returns a list of strings will work.
            persist_directory:
                The directory to save the BM25 index to. Defaults to
                "./bm25s_index".
            **kwargs: Any other arguments to pass to the retriever.

        Returns:
            A BM25SRetriever instance.
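
        Example (illustrative; the import path is hypothetical):
            .. code-block:: python

                from bm25s_retriever import BM25SRetriever

                retriever = BM25SRetriever.from_texts(
                    texts=["hello world", "foo bar baz"],
                    metadatas=[{"source": "a"}, {"source": "b"}],
                    persist_directory="./bm25s_index",
                    k=2,
                )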
        """
        try:
            from bm25s import BM25
            from bm25s import tokenize as bm25s_tokenize
        except ImportError:
            raise ImportError(
                "Could not import bm25s, please install with `pip install bm25s`."
            )

        bm25_params = bm25_params or {}
        # Materialize the iterable: it is consumed twice, once for tokenization
        # here and once more below when building the Document objects.
        texts = list(texts)
        texts_processed = bm25s_tokenize(
            texts=texts,
            stopwords=stopwords,
            stemmer=stemmer,
            return_ids=False,
            show_progress=False,
        )
        vectorizer = BM25(**bm25_params)
        vectorizer.index(texts_processed)

        metadatas = metadatas or ({} for _ in texts)
        docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]

        persist_directory = persist_directory or DEFAULT_PERSISTENCE_DIRECTORY
        os.makedirs(persist_directory, exist_ok=True)
        # Persist the vectorizer.
        vectorizer.save(persist_directory)
        # Additionally persist the corpus and the metadata as JSONL.
        with open(
            os.path.join(persist_directory, CORPUS_PERSISTENCE_FILE),
            "w",
            encoding="utf-8",
        ) as f:
            for i, d in enumerate(docs):
                entry = {"id": i, "text": d.page_content, "metadata": d.metadata}
                f.write(json.dumps(entry) + "\n")

        return cls(vectorizer=vectorizer, docs=docs, **kwargs)

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        *,
        bm25_params: Optional[Dict[str, Any]] = None,
        stopwords: Union[str, List[str]] = "en",
        stemmer: Optional[Callable[[List[str]], List[str]]] = None,
        persist_directory: Optional[str] = None,
        **kwargs: Any,
    ) -> BM25SRetriever:
        """
        Create a BM25SRetriever from a list of Documents.

        Args:
            documents:
                A list of Documents to vectorize.
            bm25_params:
                Parameters to pass to the BM25 vectorizer.
            stopwords:
                The list of stopwords to remove from the text. Defaults to "en".
            stemmer:
                The stemmer to use for stemming the tokens. The PyStemmer
                library is recommended, but any callable that takes a list of
                strings and returns a list of strings will work.
            persist_directory:
                The directory to save the BM25 index to. Defaults to
                "./bm25s_index".
            **kwargs: Any other arguments to pass to the retriever.

        Returns:
            A BM25SRetriever instance.
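
        Example (illustrative; the retriever's import path is hypothetical):
            .. code-block:: python

                from langchain_core.documents import Document

                from bm25s_retriever import BM25SRetriever

                docs = [
                    Document(page_content="hello world", metadata={"source": "a"}),
                    Document(page_content="foo bar baz", metadata={"source": "b"}),
                ]
                retriever = BM25SRetriever.from_documents(docs, k=2)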
        """
        pairs = [(d.page_content, d.metadata) for d in documents]
        texts, metadatas = zip(*pairs) if pairs else ((), ())
        return cls.from_texts(
            texts=texts,
            metadatas=metadatas,
            bm25_params=bm25_params,
            stopwords=stopwords,
            stemmer=stemmer,
            persist_directory=persist_directory,
            **kwargs,
        )

    @classmethod
    def from_persisted_directory(cls, path: str, **kwargs: Any) -> BM25SRetriever:
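        """
        Load a previously persisted BM25SRetriever from ``path``.

        Expects the layout written by ``from_texts``/``from_documents``: the
        bm25s index files plus a ``corpus.jsonl`` holding the original
        documents and their metadata.
        """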
        try:
            from bm25s import BM25
        except ImportError:
            raise ImportError(
                "Could not import bm25s, please install with `pip install bm25s`."
            )

        vectorizer = BM25.load(path)
        with open(
            os.path.join(path, CORPUS_PERSISTENCE_FILE), encoding="utf-8"
        ) as f:
            corpus = [json.loads(line) for line in f]

        docs = [
            Document(page_content=d["text"], metadata=d["metadata"]) for d in corpus
        ]
        return cls(vectorizer=vectorizer, docs=docs, **kwargs)

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
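        """Tokenize ``query`` and return the top-``k`` matching documents."""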
        # Local drop-in replacement for the stock tokenizer; switch back to
        # `from bm25s import tokenize as bm25s_tokenize` to use the upstream one.
        from mods.bm25s_tokenization import tokenize as bm25s_tokenize

        processed_query = bm25s_tokenize(query, return_ids=False)
        if self.activate_numba:
            # The numba scorer must be activated before retrieving with
            # backend_selection="numba".
            self.vectorizer.activate_numba_scorer()
            results = self.vectorizer.retrieve(
                processed_query,
                k=self.k,
                backend_selection="numba",
                show_progress=False,
            )
            # Without a corpus argument, bm25s returns row indices into the
            # indexed corpus, so map them back onto `self.docs`.
            return [self.docs[i] for i in results.documents[0]]
        # With the corpus passed in, bm25s returns the Document objects directly.
        return_docs, _scores = self.vectorizer.retrieve(
            processed_query, self.docs, k=self.k, show_progress=False
        )
        return [return_docs[0, i] for i in range(return_docs.shape[1])]
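

# Usage sketch (assumptions: bm25s is installed and the local
# `mods.bm25s_tokenization` module used by `_get_relevant_documents` is on the
# path; file paths are illustrative). Demonstrates the index / persist / reload
# round trip; not part of the retriever's API.
if __name__ == "__main__":
    sample_texts = [
        "The quick brown fox jumps over the lazy dog",
        "BM25 is a ranking function used by search engines",
        "LangChain retrievers return Document objects",
    ]
    retriever = BM25SRetriever.from_texts(
        texts=sample_texts,
        metadatas=[{"id": i} for i in range(len(sample_texts))],
        persist_directory="./bm25s_index",
        k=2,
    )
    print(retriever.invoke("ranking function for search"))

    # Reload the same index from disk, e.g. in a fresh process.
    restored = BM25SRetriever.from_persisted_directory("./bm25s_index", k=2)
    print(restored.invoke("ranking function for search"))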