from typing import Any, Dict, List

import gradio as gr
import pandas as pd

from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain_chroma import Chroma
from langchain_core.callbacks import (
    BaseCallbackHandler,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.retrievers import BaseRetriever
from langchain_openai import ChatOpenAI, OpenAIEmbeddings


class CustomHandler(BaseCallbackHandler):
    """Capture the fully formatted prompt that is sent to the LLM."""

    def __init__(self):
        self.prompt = ""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        # Keep the most recent prompt(s) so they can be surfaced for debugging.
        self.prompt = "\n".join(prompts)


class CustomRetriever(BaseRetriever):
    """Pair each retrieved post with its matching comment from the CSV."""

    # Note: if pydantic rejects these non-model field types (Chroma, DataFrame),
    # enabling arbitrary types may be required:
    # model_config = {"arbitrary_types_allowed": True}
    vectorstore: Chroma
    comments: pd.DataFrame

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = self.vectorstore.similarity_search(query)
        matching_documents = []
        for doc in docs:
            # Each indexed post carries its originating id in metadata["source"].
            post_id = int(doc.metadata["source"])
            comment = self.comments.loc[
                self.comments["Post_ID"] == post_id, "Comment_content"
            ].values
            if len(comment) == 0:
                # Skip posts that have no matching comment in the CSV.
                continue
            # Recast the stored post/comment pair as a user/assistant exchange.
            user_turn = doc.page_content.replace("Content: ", "User: ")
            content = f"{user_turn}\nAssistant: {comment[0]}"
            matching_documents.append(
                Document(page_content=content, metadata=doc.metadata)
            )

        print(matching_documents)  # noisy; useful only while debugging
        return matching_documents
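
# A sketch of the data/comments.csv layout the retriever above assumes; the
# column names come from the lookups in _get_relevant_documents, while the
# sample rows are hypothetical:
#
#   Post_ID,Comment_content
#   101,"Have you tried keeping a journal before bed?"
#   102,"Slow breathing exercises can help with acute anxiety."
#
# Each vectorstore entry is likewise expected to store the originating post id
# in metadata["source"] so it can be joined back to its comment.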


class ChatBot:
    def __init__(self, is_debug=False):
        self.is_debug = is_debug
        self.model = ChatOpenAI()
        self.handler = CustomHandler()
        self.embedding_function = OpenAIEmbeddings()
        # Reopen the persisted Chroma index built at ingestion time.
        self.vectorstore = Chroma(
            embedding_function=self.embedding_function,
            collection_name="documents",
            persist_directory="chroma",
        )
        self.comments = pd.read_csv("data/comments.csv")
        self.retriever = CustomRetriever(
            vectorstore=self.vectorstore, comments=self.comments
        )

    def create_chain(self):
        qa_system_prompt = """
        You are a helpful and joyous mental therapy assistant. Always answer as helpfully and cheerfully as possible, while being safe.
        Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
        Please ensure that your responses are socially unbiased and positive in nature.
        If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
        If you don't know the answer to a question, please don't share false information.

        Here are a few examples of answers:
        {context}
        """
        prompt = ChatPromptTemplate.from_messages([
            ("system", qa_system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
        ])

        # "Stuff" all retrieved example exchanges into the system prompt.
        chain = create_stuff_documents_chain(llm=self.model, prompt=prompt)

        # Rewrite the latest question into a standalone search query using the
        # conversation so far, then retrieve with that query.
        retriever_prompt = ChatPromptTemplate.from_messages([
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
            ("human", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation"),
        ])
        history_aware_retriever = create_history_aware_retriever(
            llm=self.model,
            retriever=self.retriever,
            prompt=retriever_prompt,
        )

        retrieval_chain = create_retrieval_chain(history_aware_retriever, chain)
        return retrieval_chain
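
    # The composed chain's contract (per create_retrieval_chain): the input is
    # a dict with "input" and "chat_history"; the result dict includes
    # "context" (the retrieved Documents) and "answer" (the generated reply):
    #
    #   result = self.create_chain().invoke({"input": "...", "chat_history": []})
    #   result["answer"]   # -> str reply
    #   result["context"]  # -> List[Document] used as few-shot examples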

    def process_chat_history(self, chat_history):
        # Convert Gradio's (user, assistant) tuples into LangChain messages.
        history = []
        for query, response in chat_history:
            history.append(HumanMessage(content=query))
            history.append(AIMessage(content=response))
        return history

    def generate_response(self, query, chat_history):
        # Guard against empty submissions from the UI.
        if not query:
            raise gr.Error("Please enter a question.")

        history = self.process_chat_history(chat_history)
        conversational_chain = self.create_chain()
        response = conversational_chain.invoke(
            {
                "input": query,
                "chat_history": history,
            },
            config={"callbacks": [self.handler]},
        )["answer"]

        # In debug mode, surface the exact prompt captured by the callback.
        references = self.handler.prompt if self.is_debug else "This is for debugging purposes only."
        return response, references
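

# A minimal sketch of wiring ChatBot into a Gradio UI. The original interface
# is not part of this file, so the component names and layout below are
# assumptions, not the actual app.
if __name__ == "__main__":
    bot = ChatBot(is_debug=True)

    with gr.Blocks() as demo:
        chat = gr.Chatbot()
        refs = gr.Textbox(label="Captured prompt (debug)")
        msg = gr.Textbox(label="Your question")

        def respond(message, history):
            answer, references = bot.generate_response(message, history)
            history.append((message, answer))
            # Clear the textbox, update the chat pane, and show the prompt.
            return "", history, references

        msg.submit(respond, [msg, chat], [msg, chat, refs])

    demo.launch()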