Spaces:
Runtime error
Runtime error
from enum import Enum | |
from typing import Dict, List, Optional | |
from langchain_core.callbacks import ( | |
AsyncCallbackManagerForRetrieverRun, | |
CallbackManagerForRetrieverRun, | |
) | |
from langchain_core.documents import Document | |
from langchain_core.pydantic_v1 import Field, root_validator | |
from langchain_core.retrievers import BaseRetriever | |
from langchain_core.stores import BaseStore, ByteStore | |
from langchain_core.vectorstores import VectorStore | |
from langchain.storage._lc_store import create_kv_docstore | |
class SearchType(str, Enum):
    """Search strategies the retriever can ask the vector store to run.

    Members mix in ``str``, so each one compares equal to its string value.
    """

    # Plain similarity search.
    similarity = "similarity"

    # Similarity search reranked with Maximal Marginal Relevance.
    mmr = "mmr"
class MultiVectorRetriever(BaseRetriever):
    """Retrieve from a set of multiple embeddings for the same document.

    A query is run against ``vectorstore``; each hit's metadata is inspected
    for ``id_key``, and the distinct ids found (in hit order) are looked up
    in ``docstore``. The documents returned are the ones fetched from
    ``docstore``, not the vector store hits themselves.
    """

    vectorstore: VectorStore
    """The underlying vectorstore to use to store small chunks
    and their embedding vectors"""
    byte_store: Optional[ByteStore] = None
    """The lower-level backing storage layer for the parent documents"""
    docstore: BaseStore[str, Document]
    """The storage interface for the parent documents"""
    id_key: str = "doc_id"
    """Metadata key on vector store hits that holds the id to fetch
    from ``docstore``."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the search function."""
    search_type: SearchType = SearchType.similarity
    """Type of search to perform (similarity / mmr)"""

    # ``pre=True`` runs this before field validation, so a retriever built
    # with only ``byte_store`` gets its required ``docstore`` filled in
    # instead of failing with a missing-field error.
    @root_validator(pre=True)
    def shim_docstore(cls, values: Dict) -> Dict:
        """Populate ``docstore`` from ``byte_store`` when one is supplied.

        Args:
            values: Raw constructor values, before field validation.

        Returns:
            The same mapping with ``values["docstore"]`` set.

        Raises:
            Exception: If neither ``byte_store`` nor ``docstore`` was given.
        """
        byte_store = values.get("byte_store")
        docstore = values.get("docstore")
        if byte_store is not None:
            # A supplied byte_store takes precedence over any docstore passed.
            docstore = create_kv_docstore(byte_store)
        elif docstore is None:
            raise Exception("You must pass a `byte_store` parameter.")
        values["docstore"] = docstore
        return values

    def _ordered_parent_ids(self, sub_docs: List[Document]) -> List:
        """Return the distinct ``id_key`` values of *sub_docs* in hit order.

        Hits whose metadata lacks ``id_key`` are skipped. Shared by the sync
        and async retrieval paths so both deduplicate identically.
        """
        # We do this to maintain the order of the ids that are returned
        ids: List = []
        for sub_doc in sub_docs:
            if self.id_key not in sub_doc.metadata:
                continue
            parent_id = sub_doc.metadata[self.id_key]
            if parent_id not in ids:
                ids.append(parent_id)
        return ids

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """
        if self.search_type == SearchType.mmr:
            sub_docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)

        docs = self.docstore.mget(self._ordered_parent_ids(sub_docs))
        # Ids with no stored document come back as None; drop them.
        return [d for d in docs if d is not None]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """
        if self.search_type == SearchType.mmr:
            sub_docs = await self.vectorstore.amax_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            sub_docs = await self.vectorstore.asimilarity_search(
                query, **self.search_kwargs
            )

        docs = await self.docstore.amget(self._ordered_parent_ids(sub_docs))
        # Ids with no stored document come back as None; drop them.
        return [d for d in docs if d is not None]