"""Module contains code for a cache backed embedder.

The cache backed embedder is a wrapper around an embedder that caches
embeddings in a key-value store. The cache is used to avoid recomputing
embeddings for the same text.

The text is hashed and the hash is used as the key in the cache.
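
For illustration, a cache key is derived roughly as follows (a sketch of the
helper functions defined below; "my-namespace" stands in for a caller-supplied
namespace):

.. code-block:: python

    import hashlib
    import uuid

    NAMESPACE_UUID = uuid.UUID(int=1985)
    digest = hashlib.sha1("some text".encode("utf-8")).hexdigest()
    key = "my-namespace" + str(uuid.uuid5(NAMESPACE_UUID, digest))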
"""

from __future__ import annotations

import hashlib
import json
import uuid
from functools import partial
from typing import Callable, List, Optional, Sequence, Union, cast

from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore, ByteStore
from langchain_core.utils.iter import batch_iterate

from langchain.storage.encoder_backed import EncoderBackedStore

NAMESPACE_UUID = uuid.UUID(int=1985)


def _hash_string_to_uuid(input_string: str) -> uuid.UUID:
    """Hash a string and returns the corresponding UUID."""
    hash_value = hashlib.sha1(input_string.encode("utf-8")).hexdigest()
    return uuid.uuid5(NAMESPACE_UUID, hash_value)


def _key_encoder(key: str, namespace: str) -> str:
    """Encode a key."""
    return namespace + str(_hash_string_to_uuid(key))


def _create_key_encoder(namespace: str) -> Callable[[str], str]:
    """Create an encoder for a key."""
    return partial(_key_encoder, namespace=namespace)


def _value_serializer(value: Sequence[float]) -> bytes:
    """Serialize a value."""
    return json.dumps(value).encode()


def _value_deserializer(serialized_value: bytes) -> List[float]:
    """Deserialize a value."""
    return cast(List[float], json.loads(serialized_value.decode()))


class CacheBackedEmbeddings(Embeddings):
    """Interface for caching results from embedding models.

    The interface works with any store that implements the abstract store
    interface, accepting keys of type str and values of type List[float].

    If need be, the interface can be extended to accept other implementations
    of the value serializer and deserializer, as well as the key encoder.

    Note that by default only document embeddings are cached. To cache query
    embeddings too, pass a query_embedding_store to the constructor.

    Examples:

        .. code-block:: python

            from langchain.embeddings import CacheBackedEmbeddings
            from langchain.storage import LocalFileStore
            from langchain_community.embeddings import OpenAIEmbeddings

            store = LocalFileStore('./my_cache')

            underlying_embedder = OpenAIEmbeddings()
            embedder = CacheBackedEmbeddings.from_bytes_store(
                underlying_embedder, store, namespace=underlying_embedder.model
            )

            # Embedding is computed and cached
            embeddings = embedder.embed_documents(["hello", "goodbye"])

            # Embeddings are retrieved from the cache, no computation is done
            embeddings = embedder.embed_documents(["hello", "goodbye"])
    """

    def __init__(
        self,
        underlying_embeddings: Embeddings,
        document_embedding_store: BaseStore[str, List[float]],
        *,
        batch_size: Optional[int] = None,
        query_embedding_store: Optional[BaseStore[str, List[float]]] = None,
    ) -> None:
        """Initialize the embedder.

        Args:
            underlying_embeddings: the embedder to use for computing embeddings.
            document_embedding_store: The store to use for caching document embeddings.
            batch_size: The number of documents to embed between store updates.
            query_embedding_store: The store to use for caching query embeddings.
                If None, query embeddings are not cached.
        """
        super().__init__()
        self.document_embedding_store = document_embedding_store
        self.query_embedding_store = query_embedding_store
        self.underlying_embeddings = underlying_embeddings
        self.batch_size = batch_size

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts.

        The method first checks the cache for the embeddings.
        If the embeddings are not found, the method uses the underlying embedder
        to embed the documents and stores the results in the cache.

        Args:
            texts: A list of texts to embed.

        Returns:
            A list of embeddings for the given texts.
        """
        vectors: List[Union[List[float], None]] = self.document_embedding_store.mget(
            texts
        )
        all_missing_indices: List[int] = [
            i for i, vector in enumerate(vectors) if vector is None
        ]

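        # batch_iterate supports None batch_size, which returns all elements
        # at once as a single batch.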
        for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
            missing_texts = [texts[i] for i in missing_indices]
            missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
            self.document_embedding_store.mset(
                list(zip(missing_texts, missing_vectors))
            )
            for index, updated_vector in zip(missing_indices, missing_vectors):
                vectors[index] = updated_vector

        return cast(
            List[List[float]], vectors
        )  # Nones should have been resolved by now

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts.

        The method first checks the cache for the embeddings.
        If the embeddings are not found, the method uses the underlying embedder
        to embed the documents and stores the results in the cache.

        Args:
            texts: A list of texts to embed.

        Returns:
            A list of embeddings for the given texts.
        """
        vectors: List[
            Union[List[float], None]
        ] = await self.document_embedding_store.amget(texts)
        all_missing_indices: List[int] = [
            i for i, vector in enumerate(vectors) if vector is None
        ]

        # batch_iterate supports None batch_size which returns all elements at once
        # as a single batch.
        for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
            missing_texts = [texts[i] for i in missing_indices]
            missing_vectors = await self.underlying_embeddings.aembed_documents(
                missing_texts
            )
            await self.document_embedding_store.amset(
                list(zip(missing_texts, missing_vectors))
            )
            for index, updated_vector in zip(missing_indices, missing_vectors):
                vectors[index] = updated_vector

        return cast(
            List[List[float]], vectors
        )  # Nones should have been resolved by now

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        By default, query embeddings are not cached. To enable caching, pass a
        query_embedding_store to the constructor, or set query_embedding_cache
        when building the embedder with from_bytes_store.

        Args:
            text: The text to embed.

        Returns:
            The embedding for the given text.
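
        Example:

            A sketch assuming the embedder was built with query caching
            enabled (e.g. via from_bytes_store with query_embedding_cache=True):

            .. code-block:: python

                vector = embedder.embed_query("hello")  # computed and cached
                vector = embedder.embed_query("hello")  # served from the cache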
        """
        if not self.query_embedding_store:
            return self.underlying_embeddings.embed_query(text)

        (cached,) = self.query_embedding_store.mget([text])
        if cached is not None:
            return cached

        vector = self.underlying_embeddings.embed_query(text)
        self.query_embedding_store.mset([(text, vector)])
        return vector

    async def aembed_query(self, text: str) -> List[float]:
        """Embed query text.

        By default, query embeddings are not cached. To enable caching, pass a
        query_embedding_store to the constructor, or set query_embedding_cache
        when building the embedder with from_bytes_store.

        Args:
            text: The text to embed.

        Returns:
            The embedding for the given text.
        """
        if not self.query_embedding_store:
            return await self.underlying_embeddings.aembed_query(text)

        (cached,) = await self.query_embedding_store.amget([text])
        if cached is not None:
            return cached

        vector = await self.underlying_embeddings.aembed_query(text)
        await self.query_embedding_store.amset([(text, vector)])
        return vector

    @classmethod
    def from_bytes_store(
        cls,
        underlying_embeddings: Embeddings,
        document_embedding_cache: ByteStore,
        *,
        namespace: str = "",
        batch_size: Optional[int] = None,
        query_embedding_cache: Union[bool, ByteStore] = False,
    ) -> CacheBackedEmbeddings:
        """On-ramp that adds the necessary serialization and encoding to the store.

        Args:
            underlying_embeddings: The embedder to use for embedding.
            document_embedding_cache: The cache to use for storing document embeddings.
            namespace: The namespace to use for the document cache.
                This namespace is used to avoid collisions with other caches.
                For example, set it to the name of the embedding model used.
            batch_size: The number of documents to embed between store updates.
            query_embedding_cache: The cache to use for storing query embeddings.
                True to use the same cache as document embeddings.
                False to not cache query embeddings.
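
        Returns:
            A CacheBackedEmbeddings instance that caches document embeddings
            in the given store, and query embeddings if enabled.

        Example:

            A minimal sketch mirroring the class-level example, with query
            caching enabled as well:

            .. code-block:: python

                store = LocalFileStore("./my_cache")
                underlying_embedder = OpenAIEmbeddings()
                embedder = CacheBackedEmbeddings.from_bytes_store(
                    underlying_embedder,
                    store,
                    namespace=underlying_embedder.model,
                    query_embedding_cache=True,  # also cache query embeddings
                )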
        """
        key_encoder = _create_key_encoder(namespace)
        document_embedding_store = EncoderBackedStore[str, List[float]](
            document_embedding_cache,
            key_encoder,
            _value_serializer,
            _value_deserializer,
        )
        if query_embedding_cache is True:
            query_embedding_store = document_embedding_store
        elif query_embedding_cache is False:
            query_embedding_store = None
        else:
            query_embedding_store = EncoderBackedStore[str, List[float]](
                query_embedding_cache,
                key_encoder,
                _value_serializer,
                _value_deserializer,
            )

        return cls(
            underlying_embeddings,
            document_embedding_store,
            batch_size=batch_size,
            query_embedding_store=query_embedding_store,
        )