"""Chain for chatting with a vector database.""" | |
from __future__ import annotations | |
import inspect | |
import warnings | |
from abc import abstractmethod | |
from pathlib import Path | |
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union | |
from langchain_core._api import deprecated | |
from langchain_core.callbacks import ( | |
AsyncCallbackManagerForChainRun, | |
CallbackManagerForChainRun, | |
Callbacks, | |
) | |
from langchain_core.documents import Document | |
from langchain_core.language_models import BaseLanguageModel | |
from langchain_core.messages import BaseMessage | |
from langchain_core.prompts import BasePromptTemplate | |
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator | |
from langchain_core.retrievers import BaseRetriever | |
from langchain_core.runnables import RunnableConfig | |
from langchain_core.vectorstores import VectorStore | |
from langchain.chains.base import Chain | |
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain | |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain | |
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT | |
from langchain.chains.llm import LLMChain | |
from langchain.chains.question_answering import load_qa_chain | |
# Depending on the memory type and configuration, the chat history format may differ. | |
# This needs to be consolidated. | |
CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage] | |

_ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "}


def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
    buffer = ""
    for dialogue_turn in chat_history:
        if isinstance(dialogue_turn, BaseMessage):
            role_prefix = _ROLE_MAP.get(dialogue_turn.type, f"{dialogue_turn.type}: ")
            buffer += f"\n{role_prefix}{dialogue_turn.content}"
        elif isinstance(dialogue_turn, tuple):
            human = "Human: " + dialogue_turn[0]
            ai = "Assistant: " + dialogue_turn[1]
            buffer += "\n" + "\n".join([human, ai])
        else:
            raise ValueError(
                f"Unsupported chat history format: {type(dialogue_turn)}."
                f" Full chat history: {chat_history} "
            )
    return buffer
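

# Illustrative sketch of the function above (hypothetical values): a mixed
# history such as
#     [("What is LangChain?", "A framework for LLM apps."),
#      HumanMessage(content="Who maintains it?")]
# is rendered by _get_chat_history as the string
#     "\nHuman: What is LangChain?\nAssistant: A framework for LLM apps."
#     "\nHuman: Who maintains it?"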


class InputType(BaseModel):
    """Input type for ConversationalRetrievalChain."""

    question: str
    """The question to answer."""
    chat_history: List[CHAT_TURN_TYPE] = Field(default_factory=list)
    """The chat history to use for retrieval."""


class BaseConversationalRetrievalChain(Chain):
    """Chain for chatting with an index."""

    combine_docs_chain: BaseCombineDocumentsChain
    """The chain used to combine any retrieved documents."""
    question_generator: LLMChain
    """The chain used to generate a new question for the sake of retrieval.
    This chain will take in the current question (with variable `question`)
    and any chat history (with variable `chat_history`) and will produce
    a new standalone question to be used later on."""
    output_key: str = "answer"
    """The output key to return the final answer of this chain in."""
    rephrase_question: bool = True
    """Whether or not to pass the newly generated question to the combine_docs_chain.
    If True, will pass the newly generated question along.
    If False, will only use the newly generated question for retrieval and pass the
    original question along to the combine_docs_chain."""
    return_source_documents: bool = False
    """Return the retrieved source documents as part of the final result."""
    return_generated_question: bool = False
    """Return the generated question as part of the final result."""
    get_chat_history: Optional[Callable[[List[CHAT_TURN_TYPE]], str]] = None
    """An optional function to get a string of the chat history.
    If None is provided, will use a default."""
    response_if_no_docs_found: Optional[str]
    """If specified, the chain will return a fixed response if no docs
    are found for the question."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @property
    def input_keys(self) -> List[str]:
        """Input keys."""
        return ["question", "chat_history"]

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        return InputType

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        if self.return_source_documents:
            _output_keys = _output_keys + ["source_documents"]
        if self.return_generated_question:
            _output_keys = _output_keys + ["generated_question"]
        return _output_keys

    @abstractmethod
    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self.get_chat_history or _get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])
        if chat_history_str:
            callbacks = _run_manager.get_child()
            new_question = self.question_generator.run(
                question=question, chat_history=chat_history_str, callbacks=callbacks
            )
        else:
            new_question = question
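        # Backwards compatibility: some _get_docs overrides predate the
        # run_manager keyword, so only pass it when the signature accepts it.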
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]

        output: Dict[str, Any] = {}
        if self.response_if_no_docs_found is not None and len(docs) == 0:
            output[self.output_key] = self.response_if_no_docs_found
        else:
            new_inputs = inputs.copy()
            if self.rephrase_question:
                new_inputs["question"] = new_question
            new_inputs["chat_history"] = chat_history_str
            answer = self.combine_docs_chain.run(
                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
            )
            output[self.output_key] = answer
        if self.return_source_documents:
            output["source_documents"] = docs
        if self.return_generated_question:
            output["generated_question"] = new_question
        return output

    @abstractmethod
    async def _aget_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self.get_chat_history or _get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])
        if chat_history_str:
            callbacks = _run_manager.get_child()
            new_question = await self.question_generator.arun(
                question=question, chat_history=chat_history_str, callbacks=callbacks
            )
        else:
            new_question = question
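        # As in _call, only pass run_manager when the override's signature accepts it.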
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._aget_docs).parameters
        )
        if accepts_run_manager:
            docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = await self._aget_docs(new_question, inputs)  # type: ignore[call-arg]

        output: Dict[str, Any] = {}
        if self.response_if_no_docs_found is not None and len(docs) == 0:
            output[self.output_key] = self.response_if_no_docs_found
        else:
            new_inputs = inputs.copy()
            if self.rephrase_question:
                new_inputs["question"] = new_question
            new_inputs["chat_history"] = chat_history_str
            answer = await self.combine_docs_chain.arun(
                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
            )
            output[self.output_key] = answer
        if self.return_source_documents:
            output["source_documents"] = docs
        if self.return_generated_question:
            output["generated_question"] = new_question
        return output

    def save(self, file_path: Union[Path, str]) -> None:
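        # A user-supplied get_chat_history callable cannot be serialized, so
        # the chain refuses to save itself in that case.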
        if self.get_chat_history:
            raise ValueError("Chain not saveable when `get_chat_history` is not None.")
        super().save(file_path)


@deprecated(
    since="0.1.17",
    alternative=(
        "create_history_aware_retriever together with create_retrieval_chain "
        "(see example in docstring)"
    ),
    removal="1.0",
)
class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
"""Chain for having a conversation based on retrieved documents. | |
This class is deprecated. See below for an example implementation using | |
`create_retrieval_chain`. Additional walkthroughs can be found at | |
https://python.langchain.com/docs/use_cases/question_answering/chat_history | |
.. code-block:: python | |
from langchain.chains import ( | |
create_history_aware_retriever, | |
create_retrieval_chain, | |
) | |
from langchain.chains.combine_documents import create_stuff_documents_chain | |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from langchain_openai import ChatOpenAI | |
retriever = ... # Your retriever | |
llm = ChatOpenAI() | |
# Contextualize question | |
contextualize_q_system_prompt = ( | |
"Given a chat history and the latest user question " | |
"which might reference context in the chat history, " | |
"formulate a standalone question which can be understood " | |
"without the chat history. Do NOT answer the question, just " | |
"reformulate it if needed and otherwise return it as is." | |
) | |
contextualize_q_prompt = ChatPromptTemplate.from_messages( | |
[ | |
("system", contextualize_q_system_prompt), | |
MessagesPlaceholder("chat_history"), | |
("human", "{input}"), | |
] | |
) | |
history_aware_retriever = create_history_aware_retriever( | |
llm, retriever, contextualize_q_prompt | |
) | |
# Answer question | |
qa_system_prompt = ( | |
"You are an assistant for question-answering tasks. Use " | |
"the following pieces of retrieved context to answer the " | |
"question. If you don't know the answer, just say that you " | |
"don't know. Use three sentences maximum and keep the answer " | |
"concise." | |
"\n\n" | |
"{context}" | |
) | |
qa_prompt = ChatPromptTemplate.from_messages( | |
[ | |
("system", qa_system_prompt), | |
MessagesPlaceholder("chat_history"), | |
("human", "{input}"), | |
] | |
) | |
# Below we use create_stuff_documents_chain to feed all retrieved context | |
# into the LLM. Note that we can also use StuffDocumentsChain and other | |
# instances of BaseCombineDocumentsChain. | |
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt) | |
rag_chain = create_retrieval_chain( | |
history_aware_retriever, question_answer_chain | |
) | |
# Usage: | |
chat_history = [] # Collect chat history here (a sequence of messages) | |
rag_chain.invoke({"input": query, "chat_history": chat_history}) | |
This chain takes in chat history (a list of messages) and new questions, | |
and then returns an answer to that question. | |
The algorithm for this chain consists of three parts: | |
1. Use the chat history and the new question to create a "standalone question". | |
This is done so that this question can be passed into the retrieval step to fetch | |
relevant documents. If only the new question was passed in, then relevant context | |
may be lacking. If the whole conversation was passed into retrieval, there may | |
be unnecessary information there that would distract from retrieval. | |
2. This new question is passed to the retriever and relevant documents are | |
returned. | |
3. The retrieved documents are passed to an LLM along with either the new question | |
(default behavior) or the original question and chat history to generate a final | |
response. | |
Example: | |
.. code-block:: python | |
from langchain.chains import ( | |
StuffDocumentsChain, LLMChain, ConversationalRetrievalChain | |
) | |
from langchain_core.prompts import PromptTemplate | |
from langchain_community.llms import OpenAI | |
combine_docs_chain = StuffDocumentsChain(...) | |
vectorstore = ... | |
retriever = vectorstore.as_retriever() | |
# This controls how the standalone question is generated. | |
# Should take `chat_history` and `question` as input variables. | |
template = ( | |
"Combine the chat history and follow up question into " | |
"a standalone question. Chat History: {chat_history}" | |
"Follow up question: {question}" | |
) | |
prompt = PromptTemplate.from_template(template) | |
llm = OpenAI() | |
question_generator_chain = LLMChain(llm=llm, prompt=prompt) | |
chain = ConversationalRetrievalChain( | |
combine_docs_chain=combine_docs_chain, | |
retriever=retriever, | |
question_generator=question_generator_chain, | |
) | |
""" | |

    retriever: BaseRetriever
    """Retriever to use to fetch documents."""
    max_tokens_limit: Optional[int] = None
    """If set, enforces that the documents returned are less than this limit.
    This is only enforced if `combine_docs_chain` is of type StuffDocumentsChain."""

    def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
        num_docs = len(docs)

        if self.max_tokens_limit and isinstance(
            self.combine_docs_chain, StuffDocumentsChain
        ):
            tokens = [
                self.combine_docs_chain.llm_chain._get_num_tokens(doc.page_content)
                for doc in docs
            ]
            token_count = sum(tokens[:num_docs])
            while token_count > self.max_tokens_limit:
                num_docs -= 1
                token_count -= tokens[num_docs]

        return docs[:num_docs]
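
    # Worked sketch of the method above (hypothetical numbers): with
    # max_tokens_limit=100 and three docs of 60, 50, and 40 tokens, the while
    # loop first drops the 40-token doc (150 -> 110), then the 50-token doc
    # (110 -> 60), so only the first doc is returned.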

    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""
        docs = self.retriever.invoke(
            question, config={"callbacks": run_manager.get_child()}
        )
        return self._reduce_tokens_below_limit(docs)

    async def _aget_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""
        docs = await self.retriever.ainvoke(
            question, config={"callbacks": run_manager.get_child()}
        )
        return self._reduce_tokens_below_limit(docs)

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        retriever: BaseRetriever,
        condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
        chain_type: str = "stuff",
        verbose: bool = False,
        condense_question_llm: Optional[BaseLanguageModel] = None,
        combine_docs_chain_kwargs: Optional[Dict] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> BaseConversationalRetrievalChain:
        """Convenience method to load chain from LLM and retriever.

        This provides some logic to create the `question_generator` chain
        as well as the combine_docs_chain.

        Args:
            llm: The default language model to use at every part of this chain
                (e.g., in both the question generation and the answering).
            retriever: The retriever to use to fetch relevant documents from.
            condense_question_prompt: The prompt to use to condense the chat history
                and new question into a standalone question.
            chain_type: The chain type to use to create the combine_docs_chain, will
                be sent to `load_qa_chain`.
            verbose: Verbosity flag for logging to stdout.
            condense_question_llm: The language model to use for condensing the chat
                history and new question into a standalone question. If none is
                provided, will default to `llm`.
            combine_docs_chain_kwargs: Parameters to pass as kwargs to `load_qa_chain`
                when constructing the combine_docs_chain.
            callbacks: Callbacks to pass to all subchains.
            **kwargs: Additional parameters to pass when initializing
                ConversationalRetrievalChain.
        """
        combine_docs_chain_kwargs = combine_docs_chain_kwargs or {}
        doc_chain = load_qa_chain(
            llm,
            chain_type=chain_type,
            verbose=verbose,
            callbacks=callbacks,
            **combine_docs_chain_kwargs,
        )

        _llm = condense_question_llm or llm
        condense_question_chain = LLMChain(
            llm=_llm,
            prompt=condense_question_prompt,
            verbose=verbose,
            callbacks=callbacks,
        )
        return cls(
            retriever=retriever,
            combine_docs_chain=doc_chain,
            question_generator=condense_question_chain,
            callbacks=callbacks,
            **kwargs,
        )
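
    # Minimal usage sketch of from_llm (illustrative only; assumes an existing
    # `vectorstore` and a configured OpenAI API key):
    #
    #     from langchain_openai import ChatOpenAI
    #
    #     chain = ConversationalRetrievalChain.from_llm(
    #         llm=ChatOpenAI(),
    #         retriever=vectorstore.as_retriever(),
    #     )
    #     result = chain.invoke({"question": "What does the doc say?", "chat_history": []})
    #     print(result["answer"])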


class ChatVectorDBChain(BaseConversationalRetrievalChain):
    """Chain for chatting with a vector database."""

    vectorstore: VectorStore = Field(alias="vectorstore")
    top_k_docs_for_context: int = 4
    search_kwargs: dict = Field(default_factory=dict)

    @property
    def _chain_type(self) -> str:
        return "chat-vector-db"

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        warnings.warn(
            "`ChatVectorDBChain` is deprecated - "
            "please use `from langchain.chains import ConversationalRetrievalChain`"
        )
        return values

    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""
        vectordbkwargs = inputs.get("vectordbkwargs", {})
        full_kwargs = {**self.search_kwargs, **vectordbkwargs}
        return self.vectorstore.similarity_search(
            question, k=self.top_k_docs_for_context, **full_kwargs
        )

    async def _aget_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""
        raise NotImplementedError("ChatVectorDBChain does not support async")

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        vectorstore: VectorStore,
        condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
        chain_type: str = "stuff",
        combine_docs_chain_kwargs: Optional[Dict] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> BaseConversationalRetrievalChain:
        """Load chain from LLM."""
        combine_docs_chain_kwargs = combine_docs_chain_kwargs or {}
        doc_chain = load_qa_chain(
            llm,
            chain_type=chain_type,
            callbacks=callbacks,
            **combine_docs_chain_kwargs,
        )
        condense_question_chain = LLMChain(
            llm=llm, prompt=condense_question_prompt, callbacks=callbacks
        )
        return cls(
            vectorstore=vectorstore,
            combine_docs_chain=doc_chain,
            question_generator=condense_question_chain,
            callbacks=callbacks,
            **kwargs,
        )