from __future__ import annotations
import json
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import Field
DEFAULT_PERSISTENCE_DIRECTORY = "./bm25s_index"
CORPUS_PERSISTENCE_FILE = "corpus.jsonl"
class BM25SRetriever(BaseRetriever):
"""`BM25` retriever with `bm25s` backend"""
vectorizer: Any
""" BM25S vectorizer."""
docs: List[Document] = Field(repr=False)
"""List of documents to retrieve from."""
k: int = 4
"""Number of top results to return"""
activate_numba: bool = False
"""Accelerate backend"""
class Config:
arbitrary_types_allowed = True
@classmethod
def from_texts(
cls,
texts: Iterable[str],
metadatas: Optional[Iterable[dict]] = None,
bm25_params: Optional[Dict[str, Any]] = None,
stopwords: Union[str, List[str]] = "en",
stemmer: Optional[Callable[[List[str]], List[str]]] = None,
persist_directory: Optional[str] = None,
**kwargs: Any,
) -> BM25SRetriever:
"""
        Create a BM25SRetriever from a list of texts.
Args:
texts:
A list of texts to vectorize.
metadatas:
A list of metadata dicts to associate with each text.
bm25_params:
Parameters to pass to the BM25s vectorizer.
stopwords:
The list of stopwords to remove from the text. Defaults to "en".
            stemmer:
                The stemmer to use for stemming the tokens. It is recommended to
                use the PyStemmer library for stemming, but you can also use any
                callable that takes a list of strings and returns a list of strings.
persist_directory:
The directory to save the BM25 index to.
**kwargs: Any other arguments to pass to the retriever.
Returns:
A BM25SRetriever instance.
"""
try:
from bm25s import BM25
from bm25s import tokenize as bm25s_tokenize
except ImportError:
            raise ImportError(
                "Could not import bm25s, please install with `pip install bm25s`."
            )
        bm25_params = bm25_params or {}
        # materialize `texts` so the iterable can be consumed twice
        # (once for tokenization, once when zipping with metadatas below)
        texts = list(texts)
texts_processed = bm25s_tokenize(
texts=texts,
stopwords=stopwords,
stemmer=stemmer,
return_ids=False,
show_progress=False,
)
vectorizer = BM25(**bm25_params)
vectorizer.index(texts_processed)
metadatas = metadatas or ({} for _ in texts)
docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
persist_directory = persist_directory or DEFAULT_PERSISTENCE_DIRECTORY
# persist the vectorizer
vectorizer.save(persist_directory)
# additionally persist the corpus and the metadata
with open(f"{persist_directory}/{CORPUS_PERSISTENCE_FILE}", "w") as f:
for i, d in enumerate(docs):
entry = {"id": i, "text": d.page_content, "metadata": d.metadata}
doc_str = json.dumps(entry)
f.write(doc_str + "\n")
return cls(vectorizer=vectorizer, docs=docs, **kwargs)
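
    # Example (a minimal sketch, not part of the original module): building an
    # index from raw texts. PyStemmer (`import Stemmer`) is an assumed optional
    # dependency, and the sample texts are illustrative only.
    #
    #     import Stemmer
    #     stemmer = Stemmer.Stemmer("english")
    #     retriever = BM25SRetriever.from_texts(
    #         texts=["a cat likes to purr", "a dog loves to play"],
    #         stemmer=stemmer.stemWords,
    #         persist_directory="./bm25s_index",
    #         k=1,
    #     )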
@classmethod
def from_documents(
cls,
documents: Iterable[Document],
*,
bm25_params: Optional[Dict[str, Any]] = None,
stopwords: Union[str, List[str]] = "en",
stemmer: Optional[Callable[[List[str]], List[str]]] = None,
persist_directory: Optional[str] = None,
**kwargs: Any,
) -> BM25SRetriever:
"""
        Create a BM25SRetriever from a list of Documents.
Args:
documents:
A list of Documents to vectorize.
bm25_params:
Parameters to pass to the BM25 vectorizer.
stopwords:
The list of stopwords to remove from the text. Defaults to "en".
            stemmer:
                The stemmer to use for stemming the tokens. It is recommended to
                use the PyStemmer library for stemming, but you can also use any
                callable that takes a list of strings and returns a list of strings.
persist_directory:
The directory to save the BM25 index to.
**kwargs: Any other arguments to pass to the retriever.
Returns:
            A BM25SRetriever instance.
"""
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
return cls.from_texts(
texts=texts,
metadatas=metadatas,
bm25_params=bm25_params,
stopwords=stopwords,
stemmer=stemmer,
persist_directory=persist_directory,
**kwargs,
)
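
    # Example (sketch; the Document below is illustrative):
    #
    #     from langchain_core.documents import Document
    #
    #     retriever = BM25SRetriever.from_documents(
    #         [Document(page_content="a cat likes to purr", metadata={"source": "pets"})],
    #         k=1,
    #     )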
@classmethod
    def from_persisted_directory(cls, path: str, **kwargs: Any) -> BM25SRetriever:
        """Load a BM25SRetriever from a directory written by `from_texts`
        or `from_documents`."""
        from bm25s import BM25

        vectorizer = BM25.load(path)
        # restore the corpus and metadata persisted alongside the index
        with open(f"{path}/{CORPUS_PERSISTENCE_FILE}", "r") as f:
            corpus = [json.loads(line) for line in f]
        docs = [
            Document(page_content=d["text"], metadata=d["metadata"]) for d in corpus
        ]
        return cls(vectorizer=vectorizer, docs=docs, **kwargs)
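
    # Example (sketch): reloading an index that `from_texts` persisted earlier;
    # the directory name is illustrative.
    #
    #     retriever = BM25SRetriever.from_persisted_directory("./bm25s_index", k=4)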
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
        try:
            # prefer the locally patched tokenizer shipped with this Space;
            # fall back to the stock bm25s tokenizer when it is unavailable
            from mods.bm25s_tokenization import tokenize as bm25s_tokenize
        except ImportError:
            from bm25s import tokenize as bm25s_tokenize
processed_query = bm25s_tokenize(query, return_ids=False)
        if self.activate_numba:
            # the numba scorer must be activated before a numba-backed
            # retrieve; note this is re-invoked on every query
            self.vectorizer.activate_numba_scorer()
            return_docs = self.vectorizer.retrieve(
                processed_query,
                k=self.k,
                backend_selection="numba",
                show_progress=False,
            )
            # the numba backend returns document indices; map them back
            return [self.docs[i] for i in return_docs.documents[0]]
        else:
            # passing the corpus makes bm25s return the Documents directly
            return_docs, scores = self.vectorizer.retrieve(
                processed_query, self.docs, k=self.k, show_progress=False
            )
            return [return_docs[0, i] for i in range(return_docs.shape[1])]
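

if __name__ == "__main__":
    # Minimal end-to-end sketch (an assumed usage, not part of the original
    # module): requires `bm25s`; the sample corpus and query are illustrative.
    retriever = BM25SRetriever.from_texts(
        texts=[
            "a cat is a feline and likes to purr",
            "a dog is the human's best friend and loves to play",
            "a bird is a beautiful animal that can fly",
        ],
        k=2,
    )
    for doc in retriever.invoke("does the fish purr like a cat?"):
        print(doc.page_content)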