forum-rag / src /chatbot.py
Chu Thi Thanh
Upload files
af8db98
import gradio as gr
from typing import Any, Dict, List, Tuple
from langchain_chroma import Chroma
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
import pandas as pd
class CustomHandler(BaseCallbackHandler):
    """Callback handler that remembers the latest prompt sent to the LLM.

    Attributes:
        prompt: Newline-joined prompts from the most recent
            ``on_llm_start`` call; an empty string until the first call.
    """

    def __init__(self):
        self.prompt = ""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        """Record the prompts of the LLM run that is about to start.

        Args:
            serialized: Serialized description of the model; unused here.
            prompts: Prompt strings about to be sent to the model.
            **kwargs: Additional callback arguments; ignored.
        """
        # Each call overwrites the previous value, so only the prompts of
        # the latest LLM run are retained.
        self.prompt = "\n".join(prompts)
class CustomRetriever(BaseRetriever):
    """Retriever that pairs similar posts with their stored comments.

    Every hit from the vector store is rewritten as a two-line exchange
    (``User: <post>`` followed by ``Assistant: <comment>``) so the result
    can be dropped into a prompt as an example answer.

    Attributes:
        vectorstore: Chroma collection searched for posts similar to the
            query. Each stored document's ``metadata['source']`` must be
            convertible to an integer post ID.
        comments: Table with a ``Post_ID`` column and a ``Comment_content``
            column holding the reply text for each post.
    """

    vectorstore: Chroma
    comments: pd.DataFrame

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return example exchanges built from posts similar to ``query``.

        Args:
            query: Text used for the similarity search.
            run_manager: Callback manager for this retriever run; unused.

        Returns:
            One document per retrieved post that has at least one comment.
            The page content is the post (with ``"Content: "`` replaced by
            ``"User: "``) followed by ``"Assistant: "`` and the first
            matching comment; the metadata is that of the retrieved post.
        """
        matching_documents = []
        for doc in self.vectorstore.similarity_search(query):
            post_id = int(doc.metadata["source"])
            replies = self.comments.loc[
                self.comments["Post_ID"] == post_id, "Comment_content"
            ].values
            if replies.size == 0:
                # A post without a stored comment cannot serve as an
                # example answer, so skip it instead of failing the lookup.
                continue
            # Use a separate local name so the ``query`` argument is not
            # rebound while iterating.
            user_turn = doc.page_content.replace("Content: ", "User: ")
            matching_documents.append(
                Document(
                    page_content=f"{user_turn}\nAssistant: {replies[0]}",
                    metadata=doc.metadata,
                )
            )
        return matching_documents
class ChatBot:
    """Retrieval-augmented chatbot that answers using example exchanges.

    Args:
        is_debug: When true, ``generate_response`` returns the last prompt
            sent to the model as its second value; otherwise it returns a
            fixed placeholder string.
    """

    def __init__(self, is_debug=False):
        self.is_debug = is_debug
        # Model and embeddings are created with library defaults.
        # NOTE(review): credentials are presumably read from the
        # environment -- confirm against the deployment setup.
        self.model = ChatOpenAI()
        self.handler = CustomHandler()
        self.embedding_function = OpenAIEmbeddings()
        # Opens the "documents" collection persisted under ./chroma.
        self.vectorstore = Chroma(
            embedding_function=self.embedding_function,
            collection_name="documents",
            persist_directory="chroma",
        )
        self.comments = pd.read_csv("data/comments.csv")
        self.retriever = CustomRetriever(
            vectorstore=self.vectorstore, comments=self.comments
        )

    def create_chain(self):
        """Build the history-aware retrieval chain.

        Returns:
            A chain that is invoked with ``"input"`` and ``"chat_history"``
            keys and whose result contains the reply under ``"answer"``.
        """
        # The prompt text is kept flush-left so that the string passed to
        # the model carries no leading indentation.
        qa_system_prompt = """
You are a helpful and joyous mental therapy assistant. Always answer as helpfully and cheerfully as possible, while being safe.
Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
If you don't know the answer to a question, please don't share false information.
Here are a few examples of answers:
{context}
"""
        prompt = ChatPromptTemplate.from_messages([
            ("system", qa_system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
        ])
        # Answering step: retrieved documents fill the {context} slot.
        chain = create_stuff_documents_chain(
            llm=self.model,
            prompt=prompt,
        )
        # Retrieval step: the model is asked to turn the conversation into
        # a search query for the retriever.
        retriever_prompt = ChatPromptTemplate.from_messages([
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
            ("human", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation"),
        ])
        history_aware_retriever = create_history_aware_retriever(
            llm=self.model,
            retriever=self.retriever,
            prompt=retriever_prompt,
        )
        # The history-aware retriever is used in place of the plain one so
        # that earlier turns influence what gets retrieved.
        retrieval_chain = create_retrieval_chain(
            history_aware_retriever,
            chain,
        )
        return retrieval_chain

    def process_chat_history(self, chat_history):
        """Convert (user, assistant) pairs into a flat list of messages.

        Args:
            chat_history: Iterable of two-item pairs, user text first and
                assistant text second. Either item may be ``None``.

        Returns:
            A list alternating ``HumanMessage`` and ``AIMessage`` objects
            in conversation order, with ``None`` entries left out.
        """
        history = []
        for query, response in chat_history:
            # A turn can be one-sided (for example a reply that has not
            # been produced yet); a message cannot be built from None.
            if query is not None:
                history.append(HumanMessage(content=query))
            if response is not None:
                history.append(AIMessage(content=response))
        return history

    def generate_response(self, query, chat_history):
        """Answer ``query`` in the context of the previous conversation.

        Args:
            query: The user's new question.
            chat_history: Earlier turns as (user, assistant) pairs.

        Returns:
            A ``(response, references)`` tuple. ``references`` is the last
            prompt captured by the callback handler when ``is_debug`` is
            set, otherwise a fixed placeholder string.

        Raises:
            gr.Error: If ``query`` is empty or contains only whitespace.
        """
        # Validate the user's text itself; blank input would otherwise be
        # sent to the model.
        if not query or not query.strip():
            raise gr.Error("Please enter a question.")
        history = self.process_chat_history(chat_history)
        conversational_chain = self.create_chain()
        result = conversational_chain.invoke(
            {
                "input": query,
                "chat_history": history,
            },
            # The handler captures the prompt sent to the model.
            config={"callbacks": [self.handler]},
        )
        response = result["answer"]
        references = (
            self.handler.prompt
            if self.is_debug
            else "This is for debugging purposes only."
        )
        return response, references