File size: 5,055 Bytes
af8db98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
from typing import Any, Dict, List, Tuple
from langchain_chroma import Chroma
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
import pandas as pd

class CustomHandler(BaseCallbackHandler):
    """Callback handler that remembers the prompt text of the latest LLM call."""

    def __init__(self):
        # Newline-joined prompts from the most recent `on_llm_start` event;
        # empty until the first LLM call is observed.
        self.prompt = ""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        """Record the prompts of the LLM call that is about to start.

        Args:
            serialized: Serialized description of the LLM (unused).
            prompts: Prompt strings passed to the LLM.
            **kwargs: Additional callback metadata (unused).
        """
        self.prompt = "\n".join(prompts)

class CustomRetriever(BaseRetriever):
    """Retriever that turns similar posts into "User / Assistant" examples.

    Each post returned by the vector store is paired with the comment stored
    for the same post id, producing a short example exchange.
    """

    # Vector store queried with `similarity_search`.
    vectorstore: Chroma
    # Table read through its `Post_ID` and `Comment_content` columns.
    comments: pd.DataFrame

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return example exchanges for the posts most similar to `query`.

        Args:
            query: Search text passed to the vector store.
            run_manager: Callback manager for this retriever run (unused).

        Returns:
            One document per similar post that has a matching comment. The
            page content is the post text (with ``"Content: "`` rewritten to
            ``"User: "``) followed by ``"Assistant: <comment>"``; metadata is
            copied from the original document. Posts without a matching
            comment are skipped.
        """
        # Function-scope import keeps this block self-contained.
        import logging

        logger = logging.getLogger(__name__)

        matching_documents = []
        for doc in self.vectorstore.similarity_search(query):
            # The `source` metadata holds the post id, compared as an int
            # against the `Post_ID` column.
            post_id = int(doc.metadata['source'])
            replies = self.comments.loc[
                self.comments['Post_ID'] == post_id, 'Comment_content'
            ].values
            if len(replies) == 0:
                # Without a comment there is no assistant turn to show, so
                # skip the example rather than crash on `replies[0]`.
                logger.warning(
                    "No comment found for post %s; skipping example", post_id
                )
                continue

            # Use a separate local so the `query` parameter is not rebound.
            user_turn = doc.page_content.replace("Content: ", "User: ")
            matching_documents.append(
                Document(
                    page_content=f"{user_turn}\nAssistant: {replies[0]}",
                    metadata=doc.metadata,
                )
            )

        logger.debug("Retrieved example documents: %s", matching_documents)
        return matching_documents

class ChatBot:
    """Retrieval-augmented chat assistant built on OpenAI and a Chroma store.

    Retrieved post/comment pairs are injected into the system prompt as
    example answers, and the running conversation is used to rewrite the
    user's message into a search query.
    """

    def __init__(self, is_debug=False):
        """Set up the model, vector store, comment table and retriever.

        Args:
            is_debug: When true, `generate_response` returns the prompt sent
                to the LLM as its second value.
        """
        self.is_debug = is_debug
        self.model = ChatOpenAI()
        self.handler = CustomHandler()
        self.embedding_function = OpenAIEmbeddings()
        self.vectorstore = Chroma(
            embedding_function=self.embedding_function,
            collection_name="documents",
            persist_directory="chroma",
        )
        self.comments = pd.read_csv("data/comments.csv")
        self.retriever = CustomRetriever(vectorstore=self.vectorstore, comments=self.comments)

    def create_chain(self):
        """Build the history-aware retrieval chain.

        Returns:
            A chain that takes ``input`` and ``chat_history`` and whose
            result contains the model reply under ``"answer"``.
        """
        qa_system_prompt = """
        You are a helpful and joyous mental therapy assistant. Always answer as helpfully and cheerfully as possible, while being safe.
        Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
        Please ensure that your responses are socially unbiased and positive in nature.
        If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. 
        If you don't know the answer to a question, please don't share false information.

        Here are a few examples of answers:   
        {context}
        
        """
        prompt = ChatPromptTemplate.from_messages([
            ("system", qa_system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}")
        ])

        # Answering step: retrieved documents fill `{context}` in the prompt.
        chain = create_stuff_documents_chain(
            llm=self.model,
            prompt=prompt
        )

        # Query-rewriting step: the LLM condenses the conversation into a
        # search query before the retriever is called.
        retriever_prompt = ChatPromptTemplate.from_messages([
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
            ("human", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation")
        ])
        history_aware_retriever = create_history_aware_retriever(
            llm=self.model,
            retriever=self.retriever,
            prompt=retriever_prompt
        )

        # The history-aware retriever is used in place of the plain retriever.
        retrieval_chain = create_retrieval_chain(
            history_aware_retriever,
            chain
        )

        return retrieval_chain

    def process_chat_history(self, chat_history):
        """Convert (user, assistant) pairs into a flat list of messages.

        Args:
            chat_history: Iterable of ``(query, response)`` pairs.

        Returns:
            A list alternating `HumanMessage` and `AIMessage`, in order.
        """
        history = []
        for (query, response) in chat_history:
            history.append(HumanMessage(content=query))
            history.append(AIMessage(content=response))
        return history

    def generate_response(self, query, chat_history):
        """Answer `query` in the context of `chat_history`.

        Args:
            query: The user's latest message.
            chat_history: Iterable of earlier ``(query, response)`` pairs.

        Returns:
            A ``(response, references)`` tuple. ``references`` is the prompt
            captured by the callback handler when `is_debug` is set, and a
            fixed placeholder string otherwise.

        Raises:
            gr.Error: If `query` is empty or only whitespace.
        """
        # Validate the user's text. The builtin `input` is always truthy, so
        # checking it (as before) never rejected an empty question.
        if not query or not query.strip():
            raise gr.Error("Please enter a question.")

        history = self.process_chat_history(chat_history)
        conversational_chain = self.create_chain()
        response = conversational_chain.invoke(
            {
                "input": query,
                "chat_history": history,
            },
            config={"callbacks": [self.handler]}
        )["answer"]

        references = self.handler.prompt if self.is_debug else "This is for debugging purposes only."
        return response, references