import logging
from typing import Any, Dict, List, Tuple

import gradio as gr
import pandas as pd
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain_chroma import Chroma
from langchain_core.callbacks import BaseCallbackHandler, CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.retrievers import BaseRetriever
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

logger = logging.getLogger(__name__)


class CustomHandler(BaseCallbackHandler):
    """Callback handler that remembers the most recent prompt sent to an LLM."""

    def __init__(self):
        super().__init__()
        # Prompt text of the latest ``on_llm_start`` event; "" until one fires.
        self.prompt = ""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        """Record the prompts of the LLM call that is starting, joined by newlines.

        Every call overwrites the previous value, so after a chain run
        ``self.prompt`` holds the prompt of the last LLM invocation in that run.
        """
        # NOTE(review): ChatOpenAI is a chat model; this relies on langchain_core
        # falling back to on_llm_start when on_chat_model_start is not
        # implemented — confirm for the installed langchain_core version.
        self.prompt = "\n".join(prompts)


class CustomRetriever(BaseRetriever):
    """Retriever that turns similar posts into example User/Assistant exchanges.

    Posts are looked up in ``vectorstore`` by similarity to the query. Each hit
    is paired with the comment in ``comments`` whose ``Post_ID`` equals the
    integer value of the hit's ``source`` metadata. The hit is then returned as
    a ``"User: ...\\nAssistant: <comment>"`` document.
    """

    vectorstore: Chroma
    comments: pd.DataFrame

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return example exchanges built from posts similar to ``query``.

        Posts with no matching comment are skipped (with a warning) instead of
        raising ``IndexError``.
        """
        matching_documents = []
        for doc in self.vectorstore.similarity_search(query):
            post_id = int(doc.metadata["source"])
            comments = self.comments.loc[
                self.comments["Post_ID"] == post_id, "Comment_content"
            ].values
            if len(comments) == 0:
                # Without a comment there is no "Assistant" half for the example.
                logger.warning("No comment found for post %s; skipping it.", post_id)
                continue
            # Re-label the post body as the user's turn of the example dialogue.
            user_turn = doc.page_content.replace("Content: ", "User: ")
            matching_documents.append(
                Document(
                    page_content=f"{user_turn}\nAssistant: {comments[0]}",
                    metadata=doc.metadata,
                )
            )
        logger.debug("Retrieved documents: %s", matching_documents)
        return matching_documents


class ChatBot:
    """Retrieval-augmented chat assistant backed by OpenAI and a Chroma store.

    Args:
        is_debug: When true, ``generate_response`` returns the last LLM prompt
            as its second value instead of a placeholder message.
    """

    def __init__(self, is_debug=False):
        self.is_debug = is_debug
        self.model = ChatOpenAI()
        self.handler = CustomHandler()
        self.embedding_function = OpenAIEmbeddings()
        self.vectorstore = Chroma(
            embedding_function=self.embedding_function,
            collection_name="documents",
            persist_directory="chroma",
        )
        self.comments = pd.read_csv("data/comments.csv")
        self.retriever = CustomRetriever(
            vectorstore=self.vectorstore, comments=self.comments
        )
        # The chain depends only on the model and retriever above, so it is
        # built lazily once and reused instead of being rebuilt per message.
        self._chain = None

    def create_chain(self):
        """Build a history-aware retrieval chain.

        The chain first rewrites the conversation into a search query, then
        retrieves example exchanges and answers with them stuffed into the
        system prompt's ``{context}`` slot. Its output dict contains ``"answer"``.
        """
        qa_system_prompt = """
You are a helpful and joyous mental therapy assistant. Always answer as helpfully and cheerfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.

Here are a few examples of answers:
{context}
"""
        prompt = ChatPromptTemplate.from_messages([
            ("system", qa_system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
        ])
        chain = create_stuff_documents_chain(llm=self.model, prompt=prompt)

        retriever_prompt = ChatPromptTemplate.from_messages([
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
            ("human", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation"),
        ])
        history_aware_retriever = create_history_aware_retriever(
            llm=self.model,
            retriever=self.retriever,
            prompt=retriever_prompt,
        )

        return create_retrieval_chain(history_aware_retriever, chain)

    def process_chat_history(self, chat_history):
        """Convert ``(user, assistant)`` pairs into LangChain messages.

        A ``None`` or empty history yields an empty list.
        """
        history = []
        for query, response in chat_history or []:
            history.append(HumanMessage(content=query))
            history.append(AIMessage(content=response))
        return history

    def generate_response(self, query, chat_history):
        """Answer ``query`` given the prior ``chat_history``.

        Returns:
            ``(answer, references)``. ``references`` is the last LLM prompt when
            ``is_debug`` is set, otherwise a fixed placeholder message.

        Raises:
            gr.Error: If ``query`` is empty or only whitespace.
        """
        # The original checked the builtin ``input`` (always truthy) by mistake.
        if not query or not query.strip():
            raise gr.Error("Please enter a question.")

        if self._chain is None:
            self._chain = self.create_chain()

        history = self.process_chat_history(chat_history)
        response = self._chain.invoke(
            {
                "input": query,
                "chat_history": history,
            },
            config={"callbacks": [self.handler]},
        )["answer"]

        references = (
            self.handler.prompt
            if self.is_debug
            else "This is for debugging purposes only."
        )
        return response, references