# NOTE: removed web-scrape residue that preceded this file (Hugging Face
# Spaces status lines, file size, commit hash, and a run of line numbers) —
# it was not part of the source and broke Python syntax.
import os
import streamlit as st
import logging
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_community.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from ensemble import ensemble_retriever_from_docs
from full_chain import create_full_chain, ask_question
from local_loader import load_data_files, load_file
from vector_store import EmbeddingProxy
from memory import clean_session_history
from pathlib import Path
import gradio as gr
from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage
# Configure logging
# Root-logger setup; all logging.* calls in this module inherit this format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def show_ui(message, history, request: gr.Request):
    """
    Gradio ChatInterface callback: answer one user message.

    Args:
        message: The user's latest message text.
        history: Chat history maintained by gr.ChatInterface (unused here;
            conversation memory is keyed by session id on the server side).
        request: The Gradio request; its session hash identifies the session.

    Returns:
        The assistant's reply text (the `.content` of the chain response).
    """
    global chain
    # session_hash scopes the conversation memory to this browser session.
    session_id = request.session_hash
    response = ask_question(chain, message, session_id)
    # logging.info(f"Response: {response}")
    return response.content
def get_retriever(openai_api_key=None):
    """
    Create the ensemble document retriever over the local "data" directory.

    Args:
        openai_api_key: The OpenAI API key. Currently unused because local
            HuggingFace embeddings are used (see commented-out line below);
            kept for backward compatibility with callers.

    Returns:
        An ensemble document retriever; on failure the app is stopped.
    """
    try:
        docs = load_data_files(data_dir="data")
        # embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key, model="text-embedding-3-small")
        embeddings = HuggingFaceEmbeddings()
        return ensemble_retriever_from_docs(docs, embeddings=embeddings)
    except Exception as e:
        # logging.exception records the active traceback; the previous code
        # logged twice and passed f"message", which logged the literal word
        # "message" instead of any useful context.
        logging.exception("Error creating retriever: %s", e)
        # NOTE(review): st.error/st.stop are Streamlit calls left over in a
        # Gradio app — confirm they behave as intended outside Streamlit.
        st.error("Error initializing the application. Please check the logs.")
        st.stop()  # Stop execution if retriever creation fails
def get_chain(openai_api_key=None, huggingfacehub_api_token=None):
    """
    Create the document retriever and the question-answering chain.

    Args:
        openai_api_key: The OpenAI API key.
        huggingfacehub_api_token: The Hugging Face Hub API token (currently
            unused; accepted for backward compatibility).

    Returns:
        A tuple ``(ensemble_retriever, chain)`` — the previous docstring
        incorrectly described a single chain return value.
    """
    try:
        ensemble_retriever = get_retriever(openai_api_key=openai_api_key)
        chain = create_full_chain(
            ensemble_retriever,
            openai_api_key=openai_api_key,
        )
        return ensemble_retriever, chain
    except Exception as e:
        # Single logging.exception call replaces the old duplicate pair that
        # logged the literal word "message" via f"message".
        logging.exception("Error creating chain: %s", e)
        # NOTE(review): Streamlit leftovers in a Gradio app — confirm intent.
        st.error("Error initializing the application. Please check the logs.")
        st.stop()  # Stop execution if chain creation fails
def get_secret_or_input(secret_key, secret_name, info_link=None):
    """
    Retrieve a secret from Streamlit secrets or prompt the user for it.

    Args:
        secret_key: The key of the secret in Streamlit secrets / session state.
        secret_name: The user-friendly name of the secret.
        info_link: Optional URL with instructions for obtaining the secret.

    Returns:
        The secret value (may be empty if the user typed nothing).
    """
    if secret_key in st.secrets:
        # f-string replaces the dated %-formatting used previously, matching
        # the rest of the file's style; rendered text is identical.
        st.write(f"Found {secret_key} secret")
        secret_value = st.secrets[secret_key]
    else:
        st.write(f"Please provide your {secret_name}")
        secret_value = st.text_input(secret_name, key=f"input_{secret_key}", type="password")
        if secret_value:
            # Persist the typed value so reruns don't re-prompt.
            st.session_state[secret_key] = secret_value
        if info_link:
            st.markdown(f"[Get an {secret_name}]({info_link})")
    return secret_value
def process_uploaded_file(uploaded_file):
    """
    Index an uploaded file into the running knowledge base.

    Adds the new documents to the vector retriever, rebuilds the BM25
    retriever over the combined corpus, and recreates the QA chain so
    answers immediately reflect the new content.

    Args:
        uploaded_file: A file path string or a Gradio file object (anything
            with a ``.name`` attribute), or None.

    Returns:
        None, which clears the gr.File component bound as the event output.
    """
    # Guard clause replaces the previous whole-body `if`; dead commented-out
    # try/except scaffolding has been removed.
    if uploaded_file is None:
        return None
    logging.info(f'run upload {uploaded_file}')
    # Gradio may hand us either a temp-file path or a file-like wrapper.
    if isinstance(uploaded_file, str):
        filename = uploaded_file
    else:
        filename = str(uploaded_file.name)
    # Load the document using the saved file path
    docs = load_file(Path(filename))
    global ensemble_retriever
    global chain
    # retrievers[0] is the BM25 retriever, retrievers[1] the vector retriever
    # (assumed from usage here — confirm against ensemble construction).
    all_docs = ensemble_retriever.retrievers[0].docs
    all_docs.extend(docs)
    ensemble_retriever.retrievers[1].add_documents(docs)
    # BM25 cannot be updated incrementally; rebuild it from the full corpus.
    new_bm25 = BM25Retriever.from_texts([t.page_content for t in all_docs])
    ensemble_retriever.retrievers[0] = new_bm25
    # Recreate the chain so it picks up the rebuilt retriever.
    chain = create_full_chain(
        ensemble_retriever,
        openai_api_key=open_api_key,
    )
    logging.info("File uploaded and added to the knowledge base!")
    gr.Info('File uploaded and added to the knowledge base!', duration=3)
    return None
SUPPORTED_FORMATS = ['.txt', '.json', '.pdf']
def activate():
    """Return a Gradio update that enables the bound component."""
    enabled_state = gr.update(interactive=True)
    return enabled_state
def deactivate():
    """Return a Gradio update that disables the bound component."""
    disabled_state = gr.update(interactive=False)
    return disabled_state
def reset(z, request: gr.Request):
    """
    Clear the chat: wipe server-side session memory and empty the UI state.

    Args:
        z: Value of the component wired as the event input (the clean
            button); unused.
        request: Gradio request whose session hash identifies the session.

    Returns:
        Two empty lists, resetting the chatbot display and the chat state.
    """
    session_id = request.session_hash
    clean_session_history(session_id)
    return [], []
def main():
    """Build and launch the Gradio UI: a chat tab and a document-indexing tab."""
    with gr.Blocks() as demo:
        gr.Markdown(
            "# Equity Bank AI assistant \n"
            "Ask questions about Equity Bank's products and services:"
        )
        with gr.Tab('Chat'):
            # render=False: created here, rendered by gr.ChatInterface below
            # so they can be passed in as its components.
            clean_btn = gr.Button(value="Clean history", variant="secondary", size='sm', render=False)
            bot = gr.Chatbot(elem_id="chatbot", render=False)
            chat = gr.ChatInterface(
                show_ui,
                chatbot=bot,
                undo_btn=None,
                retry_btn=None,
                clear_btn=clean_btn,
            )
        with gr.Tab('Documents'):
            file_input = gr.File(
                label=f'{", ".join([str(f) for f in SUPPORTED_FORMATS])}',
                file_types=SUPPORTED_FORMATS,
            )
            # Disabled until a file is chosen; toggled by upload/clear below.
            submit_btn = gr.Button(value="Index file", variant="primary", interactive=False)
        # Clearing also wipes the per-session LLM memory, not just the UI.
        clean_btn.click(fn=reset, inputs=clean_btn, outputs=[bot, chat.chatbot_state])
        # Indexing returns None, which clears the file widget on success.
        submit_btn.click(
            fn=process_uploaded_file,
            inputs=file_input,
            outputs=file_input,
            api_name="Index file"
        )
        file_input.upload(fn=activate, outputs=[submit_btn])
        file_input.clear(fn=deactivate, outputs=[submit_btn])
    demo.launch(share=True)
# Module-level initialization: build the retriever and chain once at import
# time so the Gradio callbacks (show_ui, process_uploaded_file) can reach
# them as globals.
# NOTE(review): the env var is spelled 'OPEN_API_KEY' — confirm this is not
# a typo for 'OPENAI_API_KEY'.
open_api_key = os.getenv('OPEN_API_KEY')
ensemble_retriever, chain = get_chain(
    openai_api_key=open_api_key,
    huggingfacehub_api_token=None
)
# Script entry point. A stray trailing "|" (page-scrape residue) has been
# removed from the final line; it was not valid Python.
if __name__ == "__main__":
    main()