import os

from typing import Any, Iterable, List

from dotenv import load_dotenv
from langchain.memory import ChatMessageHistory
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables.history import RunnableWithMessageHistory

from basic_chain import get_model
from rag_chain import make_rag_chain

# In-memory store of per-session chat histories, keyed by session id.
store = {}


def create_memory_chain(llm, base_chain):
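    """Wrap `base_chain` with conversational memory: the latest question is
    rewritten into a standalone question using the chat history, passed on to
    `base_chain`, and the exchange is recorded in the per-session `store`."""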
    contextualize_q_system_prompt = """Given a chat history and the latest user question \
    which might reference context in the chat history, formulate a standalone question \
    which can be understood without the chat history. Do NOT answer the question, \
    just reformulate it if needed and otherwise return it as is."""
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{question}"),
        ]
    )
    runnable = contextualize_q_prompt | llm | base_chain
    def get_session_history(session_id: str) -> BaseChatMessageHistory:
        # Lazily create a history the first time a session id is seen.
        if session_id not in store:
            store[session_id] = ChatMessageHistory()
        return store[session_id]

    with_message_history = RunnableWithMessageHistory(
        runnable,
        get_session_history,
        input_messages_key="question",
        history_messages_key="chat_history",
    )
    return with_message_history
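

# Usage sketch (mirrors main() below): the session_id in the config selects
# which entry in `store` holds the running chat history.
#
#   chain = create_memory_chain(model, rag_chain) | StrOutputParser()
#   chain.invoke(
#       {"question": "..."},
#       config={"configurable": {"session_id": "foo"}},
#   )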


def clean_session_history(session_id):
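    """Reset the stored chat history for the given session id."""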
    store[session_id] = ChatMessageHistory()


class SimpleTextRetriever(BaseRetriever):
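    """Minimal retriever that returns all of its documents for every query.

    Fine for small texts that fit in the context window; for anything larger,
    swap in a real vector-store retriever."""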
    docs: List[Document]
    """Documents."""

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        **kwargs: Any,
    ) -> "SimpleTextRetriever":
        docs = [Document(page_content=t) for t in texts]
        return cls(docs=docs, **kwargs)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        # No scoring or filtering: every document counts as relevant.
        return self.docs


def main():
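    """Demo: ask two related questions against a local text file; the second
    question relies on the chat history to resolve "these items"."""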
    load_dotenv()
    model = get_model("ChatGPT")
    system_prompt = "You are a helpful AI assistant for busy professionals trying to improve their health."
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{question}"),
        ]
    )

    text_path = "examples/equity_faq.txt"  # Updated path
    with open(text_path, "r") as f:
        text = f.read()
    retriever = SimpleTextRetriever.from_texts([text])

    rag_chain = make_rag_chain(model, retriever, rag_prompt=None)
    # create_memory_chain takes two arguments; histories live in `store`.
    chain = create_memory_chain(model, rag_chain) | StrOutputParser()
    queries = [
        "What do I need to get from the grocery store besides milk?",
        "Which of these items can I find at a farmer's market?",
    ]
    for query in queries:
        print(f"\nQuestion: {query}")
        response = chain.invoke(
            {"question": query},
            config={"configurable": {"session_id": "foo"}},
        )
        print(f"Answer: {response}")


if __name__ == "__main__":
    # Silence the HuggingFace tokenizers parallelism warning.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    main()