File size: 7,111 Bytes
820f884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import os
import streamlit as st
import logging
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_community.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever

from ensemble import ensemble_retriever_from_docs
from full_chain import create_full_chain, ask_question
from local_loader import load_data_files, load_file
from vector_store import EmbeddingProxy 
from memory import clean_session_history
from pathlib import Path

import gradio as gr
from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage


# Configure root logging: INFO level, "<time> - <level> - <message>" format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def show_ui(message, history, request: gr.Request):
    """Answer one chat turn using the module-level QA chain.

    Invoked by ``gr.ChatInterface`` for every user message.

    Args:
        message: The user's message text.
        history: Gradio-managed chat history (unused; conversation memory
            is kept server-side, keyed by session instead).
        request: The Gradio request, used to derive a per-session id.

    Returns:
        The assistant's reply text.
    """
    global chain
    # Keep each browser session's conversation memory separate.
    session_id = request.session_hash
    response = ask_question(chain, message, session_id)
    return response.content


def get_retriever(openai_api_key=None):
    """Create the ensemble document retriever over the local ``data/`` files.

    Args:
        openai_api_key: The OpenAI API key. Currently unused because local
            HuggingFace embeddings are used instead of OpenAI embeddings.

    Returns:
        An ensemble document retriever built from the loaded documents.

    Raises:
        Exception: Re-raised after logging if retriever creation fails.
    """
    try:
        docs = load_data_files(data_dir="data")
        # OpenAI embeddings kept for reference; local HuggingFace embeddings
        # avoid needing an API key for embedding generation:
        # embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key, model="text-embedding-3-small")
        embeddings = HuggingFaceEmbeddings()
        return ensemble_retriever_from_docs(docs, embeddings=embeddings)
    except Exception:
        # logging.exception records the message plus the full traceback.
        # (This is a Gradio app, so the previous st.error/st.stop Streamlit
        # calls were inappropriate here — fail loudly instead.)
        logging.exception("Error creating retriever")
        raise


def get_chain(openai_api_key=None, huggingfacehub_api_token=None):
    """Create the document retriever and the question-answering chain.

    Args:
        openai_api_key: The OpenAI API key passed through to the chain.
        huggingfacehub_api_token: The Hugging Face Hub API token
            (currently unused; kept for interface compatibility).

    Returns:
        A tuple ``(ensemble_retriever, chain)``.

    Raises:
        Exception: Re-raised after logging if chain creation fails.
    """
    try:
        ensemble_retriever = get_retriever(openai_api_key=openai_api_key)
        chain = create_full_chain(
            ensemble_retriever,
            openai_api_key=openai_api_key,
        )
        return ensemble_retriever, chain
    except Exception:
        # logging.exception records the message plus the full traceback.
        # (Replaces the previous st.error/st.stop Streamlit calls, which do
        # not belong in this Gradio app.)
        logging.exception("Error creating chain")
        raise

def get_secret_or_input(secret_key, secret_name, info_link=None):
    """Retrieve a secret from Streamlit secrets or prompt the user for it.

    NOTE(review): this helper uses the Streamlit API and appears unused by
    the Gradio UI in this module — confirm whether it is still needed.

    Args:
        secret_key: The key of the secret in Streamlit secrets.
        secret_name: The user-friendly name of the secret.
        info_link: Optional URL with information on obtaining the secret.

    Returns:
        The secret value (may be an empty string if the user enters nothing).
    """
    if secret_key in st.secrets:
        st.write(f"Found {secret_key} secret")
        secret_value = st.secrets[secret_key]
    else:
        st.write(f"Please provide your {secret_name}")
        secret_value = st.text_input(secret_name, key=f"input_{secret_key}", type="password")
        if secret_value:
            # Persist the typed value so it survives Streamlit reruns.
            st.session_state[secret_key] = secret_value
        if info_link:
            st.markdown(f"[Get an {secret_name}]({info_link})")
    return secret_value

def process_uploaded_file(uploaded_file):
    """Index an uploaded file into the live retrievers and rebuild the chain.

    New documents are added incrementally to the vector retriever and
    appended to the BM25 corpus, which is rebuilt from scratch (BM25 has no
    incremental add). The QA chain is then recreated over the updated
    ensemble retriever.

    Args:
        uploaded_file: A file path string, or a Gradio file object exposing
            the temp-file path via its ``name`` attribute. May be None.

    Returns:
        None — returned to the ``gr.File`` output to clear the widget.
    """
    global ensemble_retriever
    global chain

    if uploaded_file is not None:
        logging.info('run upload %s', uploaded_file)

        if isinstance(uploaded_file, str):
            filename = uploaded_file
        else:
            filename = str(uploaded_file.name)

        # Load the document using the saved file path.
        docs = load_file(Path(filename))

        # Retriever layout (set up in get_chain): [0] = BM25, [1] = vector.
        all_docs = ensemble_retriever.retrievers[0].docs
        all_docs.extend(docs)

        # The vector retriever supports incremental additions.
        ensemble_retriever.retrievers[1].add_documents(docs)

        # Rebuild BM25 over the full corpus. from_documents (rather than
        # from_texts over page_content) preserves document metadata.
        ensemble_retriever.retrievers[0] = BM25Retriever.from_documents(all_docs)

        chain = create_full_chain(
            ensemble_retriever,
            openai_api_key=open_api_key,
        )

        logging.info("File uploaded and added to the knowledge base!")
        gr.Info('File uploaded and added to the knowledge base!', duration=3)

    return None

# File extensions accepted by the upload widget (and passed to the loaders).
SUPPORTED_FORMATS = ['.txt', '.json', '.pdf']

def activate():
    """Return a Gradio update that enables the target component."""
    return gr.update(interactive=True)

def deactivate():
    """Return a Gradio update that disables the target component."""
    return gr.update(interactive=False)

def reset(z, request: gr.Request):
    """Wipe this session's server-side chat memory and clear the chat UI.

    Returns empty lists for the chatbot display and its state.
    """
    clean_session_history(request.session_hash)
    return [], []

def main():
    """Build and launch the Gradio UI: a chat tab plus a document-upload tab."""
    with gr.Blocks() as demo:
        gr.Markdown(
            "# Equity Bank AI assistant \n"
            "Ask questions about Equity Bank's products and services:"
        )
        with gr.Tab('Chat'):
            # render=False defers rendering so ChatInterface can place these
            # components itself.
            clean_btn = gr.Button(value="Clean history", variant="secondary", size='sm', render=False)
            bot = gr.Chatbot(elem_id="chatbot", render=False)

            chat = gr.ChatInterface(
                show_ui,
                chatbot=bot,
                undo_btn=None,
                retry_btn=None,
                clear_btn=clean_btn,
            )
        with gr.Tab('Documents'):
            file_input = gr.File(
                label=f'{", ".join([str(f) for f in SUPPORTED_FORMATS])}',
                file_types=SUPPORTED_FORMATS,
            )
            # Disabled until a file is chosen; toggled by activate/deactivate.
            submit_btn = gr.Button(value="Index file", variant="primary", interactive=False)

        # Clearing the chat also wipes the server-side session memory.
        clean_btn.click(fn=reset, inputs=clean_btn, outputs=[bot, chat.chatbot_state])

        # Returning None from the handler clears the file widget afterwards.
        submit_btn.click(
                fn=process_uploaded_file,
                inputs=file_input,
                outputs=file_input,
                api_name="Index file"
            )
        
        file_input.upload(fn=activate, outputs=[submit_btn])
        file_input.clear(fn=deactivate, outputs=[submit_btn])

    demo.launch(share=True)


# NOTE(review): env var is named 'OPEN_API_KEY' here, not the conventional
# 'OPENAI_API_KEY' — confirm this matches the deployment configuration.
open_api_key = os.getenv('OPEN_API_KEY')

# Build the retriever and chain once at import time; show_ui and
# process_uploaded_file read/mutate these module-level globals.
ensemble_retriever, chain = get_chain(
    openai_api_key=open_api_key,
    huggingfacehub_api_token=None
)



if __name__ == "__main__":
    main()