Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
import logging | |
from langchain_community.chat_message_histories import StreamlitChatMessageHistory | |
from langchain_community.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings | |
from langchain_community.retrievers import BM25Retriever | |
from ensemble import ensemble_retriever_from_docs | |
from full_chain import create_full_chain, ask_question | |
from local_loader import load_data_files, load_file | |
from vector_store import EmbeddingProxy | |
from memory import clean_session_history | |
from pathlib import Path | |
import gradio as gr | |
from langchain.chat_models import ChatOpenAI | |
from langchain.schema import AIMessage, HumanMessage | |
# Root logger setup, executed once at import: INFO and above, each record
# prefixed with its timestamp and level name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
def show_ui(message, history, request: gr.Request):
    """
    Answers a single chat message using the module-level chain.

    Args:
        message: The text submitted by the user.
        history: The conversation history passed in by the caller; not read here.
        request: The Gradio request; its ``session_hash`` is handed to
            ``ask_question`` as the session id.

    Returns:
        The ``content`` attribute of the object returned by ``ask_question``.
    """
    # ``chain`` is module-level state; this function only reads it.
    global chain
    answer = ask_question(chain, message, request.session_hash)
    return answer.content
def get_retriever(openai_api_key=None):
    """
    Builds the ensemble document retriever over the files in the "data" directory.

    Args:
        openai_api_key: The OpenAI API key. Not used at the moment: embeddings
            come from ``HuggingFaceEmbeddings``. Kept so existing callers
            keep working.

    Returns:
        The retriever returned by ``ensemble_retriever_from_docs``.

    Raises:
        Exception: The original loading/indexing error, re-raised after it has
            been logged, whenever ``st.stop()`` does not itself halt execution.
    """
    try:
        docs = load_data_files(data_dir="data")
        # Alternative (needs the key):
        # OpenAIEmbeddings(openai_api_key=openai_api_key, model="text-embedding-3-small")
        embeddings = HuggingFaceEmbeddings()
        return ensemble_retriever_from_docs(docs, embeddings=embeddings)
    except Exception as e:
        # A single record that carries both the error text and the traceback
        # (previously the traceback was logged under the placeholder "message").
        logging.exception("Error creating retriever: %s", e)
        st.error("Error initializing the application. Please check the logs.")
        st.stop()  # Stop execution if retriever creation fails
        # NOTE(review): st.stop() appears to return normally when no Streamlit
        # script is running (this app is launched through Gradio) — confirm.
        # Re-raise so a failure can never turn into an implicit ``None`` return.
        raise
def get_chain(openai_api_key=None, huggingfacehub_api_token=None):
    """
    Builds the document retriever and the question answering chain on top of it.

    Args:
        openai_api_key: The OpenAI API key; forwarded to ``get_retriever`` and
            ``create_full_chain``.
        huggingfacehub_api_token: The Hugging Face Hub API token. Accepted but
            not used by this function.

    Returns:
        A tuple ``(ensemble_retriever, chain)``.

    Raises:
        Exception: The original construction error, re-raised after it has been
            logged, whenever ``st.stop()`` does not itself halt execution.
    """
    try:
        ensemble_retriever = get_retriever(openai_api_key=openai_api_key)
        chain = create_full_chain(
            ensemble_retriever,
            openai_api_key=openai_api_key,
        )
        return ensemble_retriever, chain
    except Exception as e:
        # A single record that carries both the error text and the traceback
        # (previously the traceback was logged under the placeholder "message").
        logging.exception("Error creating chain: %s", e)
        st.error("Error initializing the application. Please check the logs.")
        st.stop()  # Stop execution if chain creation fails
        # NOTE(review): st.stop() appears to return normally when no Streamlit
        # script is running (this app is launched through Gradio) — confirm.
        # Callers unpack a 2-tuple, so an implicit ``None`` return would only
        # resurface later as an unrelated TypeError; re-raise the real error.
        raise
def get_secret_or_input(secret_key, secret_name, info_link=None):
    """
    Looks a secret up in Streamlit secrets, falling back to a password input.

    Args:
        secret_key: The key of the secret in Streamlit secrets.
        secret_name: The user-friendly name shown in the prompt and input label.
        info_link: An optional link rendered next to the input field.

    Returns:
        The configured secret, or whatever the input field currently holds.
    """
    # Configured secret: announce it and return it straight away.
    if secret_key in st.secrets:
        st.write("Found %s secret" % secret_key)
        return st.secrets[secret_key]

    # No configured secret: ask the user for it.
    st.write(f"Please provide your {secret_name}")
    entered = st.text_input(secret_name, key=f"input_{secret_key}", type="password")
    if entered:
        # Keep the typed value in the session state under the secret's key.
        st.session_state[secret_key] = entered
    if info_link:
        st.markdown(f"[Get an {secret_name}]({info_link})")
    return entered
def process_uploaded_file(uploaded_file):
    """
    Adds an uploaded file to both retrievers and rebuilds the chain.

    Updates the module-level ``ensemble_retriever`` in place and rebinds the
    module-level ``chain``.

    Args:
        uploaded_file: Either a path string, or an object whose ``name``
            attribute is the path of the uploaded file. ``None`` is a no-op.

    Returns:
        Always ``None``.
    """
    global ensemble_retriever
    global chain

    if uploaded_file is None:
        return None

    logging.info(f'run upload {uploaded_file}')
    if isinstance(uploaded_file, str):
        filename = uploaded_file
    else:
        filename = str(uploaded_file.name)

    # Load the document using the saved file path
    docs = load_file(Path(filename))

    # retrievers[0] is the BM25 retriever, retrievers[1] the other retriever,
    # which accepts new documents through add_documents.
    all_docs = ensemble_retriever.retrievers[0].docs
    all_docs.extend(docs)
    ensemble_retriever.retrievers[1].add_documents(docs)

    # The BM25 index is rebuilt from scratch over the full document list.
    # from_documents (rather than from_texts on page_content alone) keeps each
    # document's metadata attached to what the BM25 retriever returns.
    ensemble_retriever.retrievers[0] = BM25Retriever.from_documents(all_docs)

    # Rebuild the chain over the updated retriever.
    chain = create_full_chain(
        ensemble_retriever,
        openai_api_key=open_api_key,
    )
    logging.info("File uploaded and added to the knowledge base!")
    gr.Info('File uploaded and added to the knowledge base!', duration=3)
    return None
# File extensions accepted by the upload widget; also joined into its label.
SUPPORTED_FORMATS = [
    ".txt",
    ".json",
    ".pdf",
]
def activate():
    """
    Returns a Gradio update that makes the output component interactive.
    """
    enabled = gr.update(interactive=True)
    return enabled
def deactivate():
    """
    Returns a Gradio update that makes the output component non-interactive.
    """
    disabled = gr.update(interactive=False)
    return disabled
def reset(z, request: gr.Request):
    """
    Calls ``clean_session_history`` for the requesting session.

    Args:
        z: The value of the input component wired to this handler; not read.
        request: The Gradio request; its ``session_hash`` identifies the session.

    Returns:
        Two empty lists, one for each output component.
    """
    clean_session_history(request.session_hash)
    return [], []
def main():
    """
    Builds the Gradio interface (a chat tab and a documents tab) and launches it.
    """
    heading = (
        "# Equity Bank AI assistant \n"
        "Ask questions about Equity Bank's products and services:"
    )
    formats_label = ", ".join(str(ext) for ext in SUPPORTED_FORMATS)

    with gr.Blocks() as demo:
        gr.Markdown(heading)

        with gr.Tab('Chat'):
            # render=False: these are not rendered here; they are handed to
            # the ChatInterface below.
            clean_btn = gr.Button(
                value="Clean history",
                variant="secondary",
                size='sm',
                render=False,
            )
            bot = gr.Chatbot(elem_id="chatbot", render=False)
            chat = gr.ChatInterface(
                show_ui,
                chatbot=bot,
                undo_btn=None,
                retry_btn=None,
                clear_btn=clean_btn,
            )

        with gr.Tab('Documents'):
            file_input = gr.File(label=formats_label, file_types=SUPPORTED_FORMATS)
            # Starts disabled; the upload/clear handlers below toggle it.
            submit_btn = gr.Button(value="Index file", variant="primary", interactive=False)

        # Event wiring.
        clean_btn.click(
            fn=reset,
            inputs=clean_btn,
            outputs=[bot, chat.chatbot_state],
        )
        submit_btn.click(
            fn=process_uploaded_file,
            inputs=file_input,
            outputs=file_input,
            api_name="Index file",
        )
        file_input.upload(fn=activate, outputs=[submit_btn])
        file_input.clear(fn=deactivate, outputs=[submit_btn])

    demo.launch(share=True)
# Module-level state read by show_ui and process_uploaded_file. It is built at
# import time, before the entry-point guard below runs.
open_api_key = os.environ.get('OPEN_API_KEY')
ensemble_retriever, chain = get_chain(openai_api_key=open_api_key, huggingfacehub_api_token=None)

if __name__ == "__main__":
    main()