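"""BM25 retriever for LangChain backed by the `bm25s` library, with simple
on-disk persistence of the index and corpus."""
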
from __future__ import annotations

import json
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import Field

DEFAULT_PERSISTENCE_DIRECTORY = "./bm25s_index"
CORPUS_PERSISTENCE_FILE = "corpus.jsonl"

class BM25SRetriever(BaseRetriever):
    """`BM25` retriever with a `bm25s` backend."""

    vectorizer: Any
    """BM25S vectorizer."""
    docs: List[Document] = Field(repr=False)
    """List of documents to retrieve from."""
    k: int = 4
    """Number of top results to return."""
    activate_numba: bool = False
    """Whether to use the numba-accelerated scoring backend."""

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        bm25_params: Optional[Dict[str, Any]] = None,
        stopwords: Union[str, List[str]] = "en",
        stemmer: Optional[Callable[[List[str]], List[str]]] = None,
        persist_directory: Optional[str] = None,
        **kwargs: Any,
    ) -> BM25SRetriever:
""" | |
Create a BM25Retriever from a list of texts. | |
Args: | |
texts: | |
A list of texts to vectorize. | |
metadatas: | |
A list of metadata dicts to associate with each text. | |
bm25_params: | |
Parameters to pass to the BM25s vectorizer. | |
stopwords: | |
The list of stopwords to remove from the text. Defaults to "en". | |
stemmer: | |
The stemmer to use for stemming the tokens. It is recommended to | |
use the PyStemmer library for stemming, but you can also any | |
callable that takes a list of strings and returns a list of strings. | |
persist_directory: | |
The directory to save the BM25 index to. | |
**kwargs: Any other arguments to pass to the retriever. | |
Returns: | |
A BM25SRetriever instance. | |
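
        Example:
            A minimal usage sketch (assumes ``bm25s`` is installed; the sample
            texts are illustrative, and querying additionally relies on the
            local ``mods.bm25s_tokenization`` module used at retrieval time):

            .. code-block:: python

                retriever = BM25SRetriever.from_texts(
                    texts=["hello world", "foo bar"],
                    k=1,
                )
                docs = retriever.invoke("hello")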
""" | |
        try:
            from bm25s import BM25
            from bm25s import tokenize as bm25s_tokenize
        except ImportError:
            raise ImportError(
                "Could not import bm25s, please install with `pip install bm25s`."
            )
        bm25_params = bm25_params or {}
        # Materialize the iterable up front: it is consumed twice (for
        # tokenization here and for Document construction below).
        texts = list(texts)
        texts_processed = bm25s_tokenize(
            texts=texts,
            stopwords=stopwords,
            stemmer=stemmer,
            return_ids=False,
            show_progress=False,
        )
        vectorizer = BM25(**bm25_params)
        vectorizer.index(texts_processed)
        metadatas = metadatas or ({} for _ in texts)
        docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
        persist_directory = persist_directory or DEFAULT_PERSISTENCE_DIRECTORY
        # Persist the vectorizer.
        vectorizer.save(persist_directory)
        # Additionally persist the corpus and the metadata as JSON lines.
        with open(f"{persist_directory}/{CORPUS_PERSISTENCE_FILE}", "w") as f:
            for i, d in enumerate(docs):
                entry = {"id": i, "text": d.page_content, "metadata": d.metadata}
                doc_str = json.dumps(entry)
                f.write(doc_str + "\n")
        return cls(vectorizer=vectorizer, docs=docs, **kwargs)

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        *,
        bm25_params: Optional[Dict[str, Any]] = None,
        stopwords: Union[str, List[str]] = "en",
        stemmer: Optional[Callable[[List[str]], List[str]]] = None,
        persist_directory: Optional[str] = None,
        **kwargs: Any,
    ) -> BM25SRetriever:
""" | |
Create a BM25Retriever from a list of Documents. | |
Args: | |
documents: | |
A list of Documents to vectorize. | |
bm25_params: | |
Parameters to pass to the BM25 vectorizer. | |
stopwords: | |
The list of stopwords to remove from the text. Defaults to "en". | |
stemmer: | |
The stemmer to use for stemming the tokens. It is recommended to | |
use the PyStemmer library for stemming, but you can also any | |
callable that takes a list of strings and returns a list of strings. | |
persist_directory: | |
The directory to save the BM25 index to. | |
**kwargs: Any other arguments to pass to the retriever. | |
Returns: | |
A BM25Retriever instance. | |
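
        Example:
            A minimal usage sketch (assumes ``bm25s`` is installed; the sample
            documents are illustrative only):

            .. code-block:: python

                from langchain_core.documents import Document

                retriever = BM25SRetriever.from_documents(
                    documents=[
                        Document(page_content="hello world", metadata={"id": 1}),
                        Document(page_content="foo bar", metadata={"id": 2}),
                    ],
                    k=1,
                )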
""" | |
        texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
        return cls.from_texts(
            texts=texts,
            metadatas=metadatas,
            bm25_params=bm25_params,
            stopwords=stopwords,
            stemmer=stemmer,
            persist_directory=persist_directory,
            **kwargs,
        )

    @classmethod
    def from_persisted_directory(cls, path: str, **kwargs: Any) -> BM25SRetriever:
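        """
        Load a BM25SRetriever from a directory previously written by
        `from_texts` / `from_documents`: the saved `bm25s` index plus the
        JSONL corpus file.

        Args:
            path: The directory the index was persisted to.
            **kwargs: Any other arguments to pass to the retriever.

        Returns:
            A BM25SRetriever instance.
        """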
        from bm25s import BM25

        vectorizer = BM25.load(path)
        # Rehydrate the corpus that was saved alongside the index.
        with open(f"{path}/{CORPUS_PERSISTENCE_FILE}", "r") as f:
            corpus = [json.loads(line) for line in f]
        docs = [
            Document(page_content=d["text"], metadata=d["metadata"]) for d in corpus
        ]
        return cls(vectorizer=vectorizer, docs=docs, **kwargs)

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
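        """Return the top-k documents for `query` by BM25 score."""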
        # Local, modified copy of the bm25s tokenizer; a drop-in replacement
        # for `from bm25s import tokenize as bm25s_tokenize`.
        from mods.bm25s_tokenization import tokenize as bm25s_tokenize

        processed_query = bm25s_tokenize(query, return_ids=False)
        if self.activate_numba:
            self.vectorizer.activate_numba_scorer()
            # Without a corpus argument, `retrieve` returns indices into the
            # indexed corpus, which are mapped back to Documents below.
            return_docs = self.vectorizer.retrieve(
                processed_query,
                k=self.k,
                backend_selection="numba",
                show_progress=False,
            )
            return [self.docs[i] for i in return_docs.documents[0]]
        else:
            # With the corpus passed in, `retrieve` returns a (1, k) array of
            # Documents along with their scores.
            return_docs, scores = self.vectorizer.retrieve(
                processed_query, self.docs, k=self.k, show_progress=False
            )
            return [return_docs[0, i] for i in range(return_docs.shape[1])]
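

# A minimal end-to-end sketch (hypothetical usage; assumes `bm25s` and
# PyStemmer are installed, and that the local `mods.bm25s_tokenization`
# module is importable at query time):
#
#     import Stemmer
#
#     retriever = BM25SRetriever.from_texts(
#         texts=["hello world", "foo bar"],
#         stemmer=Stemmer.Stemmer("english").stemWords,
#         persist_directory="./bm25s_index",
#         k=1,
#     )
#     docs = retriever.invoke("hello")
#
#     # Later, reload the same index and corpus from disk:
#     retriever = BM25SRetriever.from_persisted_directory("./bm25s_index")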