Spaces:
Runtime error
Runtime error
File size: 10,187 Bytes
63deadc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 |
"""Module contains code for a cache backed embedder.
The cache backed embedder is a wrapper around an embedder that caches
embeddings in a key-value store. The cache is used to avoid recomputing
embeddings for the same text.
The text is hashed and the hash is used as the key in the cache.
"""
from __future__ import annotations
import hashlib
import json
import uuid
from functools import partial
from typing import Callable, List, Optional, Sequence, Union, cast
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore, ByteStore
from langchain_core.utils.iter import batch_iterate
from langchain.storage.encoder_backed import EncoderBackedStore
NAMESPACE_UUID = uuid.UUID(int=1985)
def _hash_string_to_uuid(input_string: str) -> uuid.UUID:
"""Hash a string and returns the corresponding UUID."""
hash_value = hashlib.sha1(input_string.encode("utf-8")).hexdigest()
return uuid.uuid5(NAMESPACE_UUID, hash_value)
def _key_encoder(key: str, namespace: str) -> str:
    """Build the cache key for ``key`` under ``namespace``.

    The result is the namespace string followed directly (no separator) by
    the canonical string form of the UUID derived from ``key``.
    """
    key_uuid = _hash_string_to_uuid(key)
    return namespace + str(key_uuid)
def _create_key_encoder(namespace: str) -> Callable[[str], str]:
    """Return a one-argument key encoder bound to ``namespace``.

    The returned callable is ``_key_encoder`` with ``namespace`` fixed as a
    keyword argument via ``functools.partial``.
    """
    bound_encoder = partial(_key_encoder, namespace=namespace)
    return bound_encoder
def _value_serializer(value: Sequence[float]) -> bytes:
"""Serialize a value."""
return json.dumps(value).encode()
def _value_deserializer(serialized_value: bytes) -> List[float]:
"""Deserialize a value."""
return cast(List[float], json.loads(serialized_value.decode()))
class CacheBackedEmbeddings(Embeddings):
    """Interface for caching results from embedding models.

    The interface works with any store that implements the abstract store
    interface accepting keys of type ``str`` and values of type list of floats.

    If need be, the interface can be extended to accept other implementations
    of the value serializer and deserializer, as well as the key encoder.

    Note that by default only document embeddings are cached. To cache query
    embeddings too, pass a ``query_embedding_store`` to the constructor (or a
    ``query_embedding_cache`` to :meth:`from_bytes_store`).

    Examples:
        .. code-block:: python

            from langchain.embeddings import CacheBackedEmbeddings
            from langchain.storage import LocalFileStore
            from langchain_community.embeddings import OpenAIEmbeddings

            store = LocalFileStore('./my_cache')

            underlying_embedder = OpenAIEmbeddings()
            embedder = CacheBackedEmbeddings.from_bytes_store(
                underlying_embedder, store, namespace=underlying_embedder.model
            )

            # Embedding is computed and cached
            embeddings = embedder.embed_documents(["hello", "goodbye"])

            # Embeddings are retrieved from the cache, no computation is done
            embeddings = embedder.embed_documents(["hello", "goodbye"])
    """

    def __init__(
        self,
        underlying_embeddings: Embeddings,
        document_embedding_store: BaseStore[str, List[float]],
        *,
        batch_size: Optional[int] = None,
        query_embedding_store: Optional[BaseStore[str, List[float]]] = None,
    ) -> None:
        """Initialize the embedder.

        Args:
            underlying_embeddings: The embedder to use for computing embeddings.
            document_embedding_store: The store to use for caching document
                embeddings.
            batch_size: The number of documents to embed between store updates.
                If None, all missing documents are embedded in a single batch.
            query_embedding_store: The store to use for caching query embeddings.
                If None, query embeddings are not cached.
        """
        super().__init__()
        self.document_embedding_store = document_embedding_store
        self.query_embedding_store = query_embedding_store
        self.underlying_embeddings = underlying_embeddings
        self.batch_size = batch_size

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts.

        The method first checks the cache for the embeddings.
        If the embeddings are not found, the method uses the underlying embedder
        to embed the documents and stores the results in the cache.

        Args:
            texts: A list of texts to embed.

        Returns:
            A list of embeddings for the given texts, in the same order.
        """
        vectors: List[Union[List[float], None]] = self.document_embedding_store.mget(
            texts
        )
        all_missing_indices: List[int] = [
            i for i, vector in enumerate(vectors) if vector is None
        ]

        # batch_iterate supports None batch_size which returns all elements at
        # once as a single batch. The store is updated after every batch, so
        # already-computed batches stay cached even if a later batch fails.
        for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
            missing_texts = [texts[i] for i in missing_indices]
            missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
            self.document_embedding_store.mset(
                list(zip(missing_texts, missing_vectors))
            )
            for index, updated_vector in zip(missing_indices, missing_vectors):
                vectors[index] = updated_vector

        return cast(
            List[List[float]], vectors
        )  # Nones should have been resolved by now

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed a list of texts.

        The method first checks the cache for the embeddings.
        If the embeddings are not found, the method uses the underlying embedder
        to embed the documents and stores the results in the cache.

        Args:
            texts: A list of texts to embed.

        Returns:
            A list of embeddings for the given texts, in the same order.
        """
        vectors: List[
            Union[List[float], None]
        ] = await self.document_embedding_store.amget(texts)
        all_missing_indices: List[int] = [
            i for i, vector in enumerate(vectors) if vector is None
        ]

        # batch_iterate supports None batch_size which returns all elements at
        # once as a single batch. The store is updated after every batch.
        for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
            missing_texts = [texts[i] for i in missing_indices]
            missing_vectors = await self.underlying_embeddings.aembed_documents(
                missing_texts
            )
            await self.document_embedding_store.amset(
                list(zip(missing_texts, missing_vectors))
            )
            for index, updated_vector in zip(missing_indices, missing_vectors):
                vectors[index] = updated_vector

        return cast(
            List[List[float]], vectors
        )  # Nones should have been resolved by now

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        By default, this method does not cache queries. To enable caching, pass
        a ``query_embedding_store`` to the constructor (or a
        ``query_embedding_cache`` to ``from_bytes_store``).

        Args:
            text: The text to embed.

        Returns:
            The embedding for the given text.
        """
        # Compare against None explicitly: a store object that happens to be
        # falsy (e.g. defines __len__ and is empty) must still be used.
        if self.query_embedding_store is None:
            return self.underlying_embeddings.embed_query(text)

        (cached,) = self.query_embedding_store.mget([text])
        if cached is not None:
            return cached

        vector = self.underlying_embeddings.embed_query(text)
        self.query_embedding_store.mset([(text, vector)])
        return vector

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text.

        By default, this method does not cache queries. To enable caching, pass
        a ``query_embedding_store`` to the constructor (or a
        ``query_embedding_cache`` to ``from_bytes_store``).

        Args:
            text: The text to embed.

        Returns:
            The embedding for the given text.
        """
        # Compare against None explicitly: a store object that happens to be
        # falsy (e.g. defines __len__ and is empty) must still be used.
        if self.query_embedding_store is None:
            return await self.underlying_embeddings.aembed_query(text)

        (cached,) = await self.query_embedding_store.amget([text])
        if cached is not None:
            return cached

        vector = await self.underlying_embeddings.aembed_query(text)
        await self.query_embedding_store.amset([(text, vector)])
        return vector

    @classmethod
    def from_bytes_store(
        cls,
        underlying_embeddings: Embeddings,
        document_embedding_cache: ByteStore,
        *,
        namespace: str = "",
        batch_size: Optional[int] = None,
        query_embedding_cache: Union[bool, ByteStore] = False,
    ) -> CacheBackedEmbeddings:
        """On-ramp that adds the necessary serialization and encoding to the store.

        Args:
            underlying_embeddings: The embedder to use for embedding.
            document_embedding_cache: The cache to use for storing document
                embeddings.
            namespace: The namespace prefixed to every cache key.
                This namespace is used to avoid collisions with other caches.
                For example, set it to the name of the embedding model used.
                The same namespace is also applied to the query cache.
            batch_size: The number of documents to embed between store updates.
            query_embedding_cache: The cache to use for storing query embeddings.
                True to use the same cache as document embeddings.
                False to not cache query embeddings.

        Returns:
            A ``CacheBackedEmbeddings`` wrapping ``underlying_embeddings``.
        """
        key_encoder = _create_key_encoder(namespace)
        document_embedding_store = EncoderBackedStore[str, List[float]](
            document_embedding_cache,
            key_encoder,
            _value_serializer,
            _value_deserializer,
        )

        query_embedding_store: Optional[BaseStore[str, List[float]]]
        if query_embedding_cache is True:
            # NOTE: queries and documents then share both the store and the key
            # encoder, so a query and a document with identical text resolve to
            # the same cached vector.
            query_embedding_store = document_embedding_store
        elif query_embedding_cache is False:
            query_embedding_store = None
        else:
            query_embedding_store = EncoderBackedStore[str, List[float]](
                query_embedding_cache,
                key_encoder,
                _value_serializer,
                _value_deserializer,
            )

        return cls(
            underlying_embeddings,
            document_embedding_store,
            batch_size=batch_size,
            query_embedding_store=query_embedding_store,
        )
|