File size: 5,055 Bytes
af8db98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import gradio as gr
from typing import Any, Dict, List, Tuple
from langchain_chroma import Chroma
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
import pandas as pd
class CustomHandler(BaseCallbackHandler):
    """Callback that records the most recent prompt sent to the LLM.

    After a chain run, ``self.prompt`` holds the fully rendered prompt
    text the model actually received — handy for debugging chains.
    """

    def __init__(self) -> None:
        # Empty until the first LLM call fires.
        self.prompt = ""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        # One call may carry several prompts; keep them all, newline-joined.
        self.prompt = "\n".join(prompts)
class CustomRetriever(BaseRetriever):
    """Retriever that pairs vector-store hits with their stored comments.

    Each similarity-search hit is expected to carry a ``source`` metadata
    field holding a post id (stringified int — TODO confirm against the
    indexing pipeline). The matching ``Comment_content`` row from
    ``comments`` is appended so the returned document reads as a short
    User/Assistant exchange, suitable as a few-shot example in a prompt.
    """

    vectorstore: Chroma  # embedded posts, searched by similarity
    comments: pd.DataFrame  # must contain Post_ID and Comment_content columns

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        matching_documents: List[Document] = []
        for doc in self.vectorstore.similarity_search(query):
            post_id = int(doc.metadata['source'])
            comment = self.comments.loc[
                self.comments['Post_ID'] == post_id, 'Comment_content'
            ].values
            # Skip hits with no stored comment instead of raising
            # IndexError on comment[0] (original would crash here).
            if len(comment) == 0:
                continue
            # Fresh name: the original reassigned `query`, shadowing the
            # method parameter inside the loop.
            user_turn = doc.page_content.replace("Content: ", "User: ")
            matching_documents.append(
                Document(
                    page_content=f"{user_turn}\nAssistant: {comment[0]}",
                    metadata=doc.metadata,
                )
            )
        # Debug print of the full document list removed.
        return matching_documents
class ChatBot:
    """Retrieval-augmented mental-therapy chatbot.

    Wires an OpenAI chat model to a Chroma vector store of past posts
    (persisted under ``chroma/``) and a CSV of matching comments
    (``data/comments.csv``), then answers queries through a
    history-aware retrieval chain.
    """

    def __init__(self, is_debug: bool = False):
        # When True, generate_response returns the rendered LLM prompt as
        # the "references" value instead of a placeholder string.
        self.is_debug = is_debug
        self.model = ChatOpenAI()
        self.handler = CustomHandler()
        self.embedding_function = OpenAIEmbeddings()
        self.vectorstore = Chroma(
            embedding_function=self.embedding_function,
            collection_name="documents",
            persist_directory="chroma",
        )
        self.comments = pd.read_csv("data/comments.csv")
        self.retriever = CustomRetriever(vectorstore=self.vectorstore, comments=self.comments)

    def create_chain(self):
        """Build and return the conversational retrieval chain.

        The chain first condenses the chat history plus the new input
        into a standalone search query (history-aware retriever), then
        stuffs the retrieved example exchanges into the system prompt's
        ``{context}`` slot for the answering model.
        """
        qa_system_prompt = """
        You are a helpful and joyous mental therapy assistant. Always answer as helpfully and cheerfully as possible, while being safe.
        Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
        Please ensure that your responses are socially unbiased and positive in nature.
        If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
        If you don't know the answer to a question, please don't share false information.
        Here are a few examples of answers:
        {context}
        """
        prompt = ChatPromptTemplate.from_messages([
            ("system", qa_system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}")
        ])
        chain = create_stuff_documents_chain(
            llm=self.model,
            prompt=prompt
        )
        retriever_prompt = ChatPromptTemplate.from_messages([
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
            ("human", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation")
        ])
        history_aware_retriever = create_history_aware_retriever(
            llm=self.model,
            retriever=self.retriever,
            prompt=retriever_prompt
        )
        retrieval_chain = create_retrieval_chain(
            history_aware_retriever,
            chain
        )
        return retrieval_chain

    def process_chat_history(self, chat_history):
        """Convert Gradio-style (user, bot) tuples into LangChain messages."""
        history = []
        for (query, response) in chat_history:
            history.append(HumanMessage(content=query))
            history.append(AIMessage(content=response))
        return history

    def generate_response(self, query, chat_history):
        """Answer `query` given `chat_history`; returns (answer, references).

        Raises gr.Error when the query is empty.
        """
        # BUG FIX: the original tested `if not input:`, i.e. the builtin
        # `input` function, which is always truthy — the empty-question
        # guard could never fire. Test the actual argument instead.
        if not query:
            raise gr.Error("Please enter a question.")
        history = self.process_chat_history(chat_history)
        # NOTE(review): the chain is rebuilt on every call; it could be
        # cached in __init__, but that is left unchanged here.
        conversational_chain = self.create_chain()
        response = conversational_chain.invoke(
            {
                "input": query,
                "chat_history": history,
            },
            config={"callbacks": [self.handler]}
        )["answer"]
        references = self.handler.prompt if self.is_debug else "This is for debugging purposes only."
        return response, references