Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
from typing import Any, Dict, Union | |
from langchain_core.retrievers import ( | |
BaseRetriever, | |
RetrieverOutput, | |
) | |
from langchain_core.runnables import Runnable, RunnablePassthrough | |
def create_retrieval_chain( | |
retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]], | |
combine_docs_chain: Runnable[Dict[str, Any], str], | |
) -> Runnable: | |
"""Create retrieval chain that retrieves documents and then passes them on. | |
Args: | |
retriever: Retriever-like object that returns list of documents. Should | |
either be a subclass of BaseRetriever or a Runnable that returns | |
a list of documents. If a subclass of BaseRetriever, then it | |
is expected that an `input` key be passed in - this is what | |
is will be used to pass into the retriever. If this is NOT a | |
subclass of BaseRetriever, then all the inputs will be passed | |
into this runnable, meaning that runnable should take a dictionary | |
as input. | |
combine_docs_chain: Runnable that takes inputs and produces a string output. | |
The inputs to this will be any original inputs to this chain, a new | |
context key with the retrieved documents, and chat_history (if not present | |
in the inputs) with a value of `[]` (to easily enable conversational | |
retrieval. | |
Returns: | |
An LCEL Runnable. The Runnable return is a dictionary containing at the very | |
least a `context` and `answer` key. | |
Example: | |
.. code-block:: python | |
# pip install -U langchain langchain-community | |
from langchain_community.chat_models import ChatOpenAI | |
from langchain.chains.combine_documents import create_stuff_documents_chain | |
from langchain.chains import create_retrieval_chain | |
from langchain import hub | |
retrieval_qa_chat_prompt = hub.pull("langchain-ai/retrieval-qa-chat") | |
llm = ChatOpenAI() | |
retriever = ... | |
combine_docs_chain = create_stuff_documents_chain( | |
llm, retrieval_qa_chat_prompt | |
) | |
retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain) | |
chain.invoke({"input": "..."}) | |
""" | |
if not isinstance(retriever, BaseRetriever): | |
retrieval_docs: Runnable[dict, RetrieverOutput] = retriever | |
else: | |
retrieval_docs = (lambda x: x["input"]) | retriever | |
retrieval_chain = ( | |
RunnablePassthrough.assign( | |
context=retrieval_docs.with_config(run_name="retrieve_documents"), | |
).assign(answer=combine_docs_chain) | |
).with_config(run_name="retrieval_chain") | |
return retrieval_chain | |