File size: 3,568 Bytes
ea8b3bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
from langchain_community.chat_models import ChatOpenAI

from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain

import streamlit as st
from streamlit_chat import message


@st.cache_data()
def load_docs(docs_dir: str = 'docs'):
    """Load every supported document (.pdf, .docx/.doc, .txt) from *docs_dir*.

    Results are memoised by Streamlit's cache so files are only parsed once
    per session. Files with unrecognised extensions are silently skipped.

    Args:
        docs_dir: Directory to scan for documents (default: 'docs',
            matching the original hard-coded behavior).

    Returns:
        A flat list of LangChain Document objects from all loaded files.
    """
    documents = []
    for file in os.listdir(docs_dir):
        # Build the path once for all branches. BUG FIX: the original built
        # the .txt path as '.docs/<file>' (missing slash), so plain-text
        # files could never be loaded.
        path = os.path.join(docs_dir, file)
        if file.endswith('.pdf'):
            loader = PyPDFLoader(path)
        elif file.endswith(('.docx', '.doc')):
            loader = Docx2txtLoader(path)
        elif file.endswith('.txt'):
            loader = TextLoader(path)
        else:
            continue
        documents.extend(loader.load())

    return documents

# SECURITY FIX: the original hardcoded a live OpenAI API key in source code.
# A key committed to a repository must be treated as compromised — revoke it.
# Supply the key via the environment (or Streamlit secrets) instead.
if "OPENAI_API_KEY" not in os.environ:
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Export it before running the app, "
        "e.g. `export OPENAI_API_KEY=sk-...`."
    )

llm_model = "gpt-3.5-turbo"
# temperature=0.7 keeps answers conversational; lower it for more
# deterministic, citation-style responses.
llm = ChatOpenAI(temperature=.7, model=llm_model)
#======================================================================================================================
# Load all documents from ./docs once at startup (memoised by st.cache_data).
documents = load_docs()
# NOTE(review): this module-level list is re-created empty on every Streamlit
# rerun, so any history appended here never survives a turn — consider
# keeping conversation history in st.session_state instead.
chat_history = []

# 1. Text splitter
# Small chunks (100 chars with 20 overlap) keep retrieval granular; tune
# these values for longer documents.
text_splitter = CharacterTextSplitter(
    chunk_size = 100,
    chunk_overlap = 20,
    length_function = len
)

# 2. Embedding
# OpenAIEmbeddings picks up the API key from the environment set above.
embeddings = OpenAIEmbeddings()

# Split the loaded documents into retrieval-sized chunks.
docs = text_splitter.split_documents(documents)

#=====================================================================================================================
# 3. Storage
# Embed the chunks and build a Chroma index persisted under ./data so it can
# be reloaded later without re-embedding.
vector_store = Chroma.from_documents(
    documents=docs,
    embedding=embeddings,
    persist_directory='./data'
)
# Explicitly flush the collection to persist_directory on disk.
vector_store.persist()
# ====================================================================================================================
# 4. Retrieve
# A single retriever returning the 6 most similar chunks per query.
retriever = vector_store.as_retriever(search_kwargs={"k": 6})

# Make a chain to answer questions.
# FIX: the original built a second, identical retriever inline and left the
# `retriever` variable above unused — reuse the configured one instead.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm,
    retriever,
    return_source_documents=True,
    verbose=False
)
 

# cite sources - helper to pretty-print a chain response with its sources
def process_llm_response(llm_response):
    """Print the chain's answer followed by the paths of its source documents.

    Args:
        llm_response: Mapping returned by the QA chain, containing the answer
            text and a ``'source_documents'`` list whose items expose a
            ``metadata['source']`` path.

    FIX: the original read ``llm_response['result']``, but the
    ConversationalRetrievalChain used in this file returns its text under
    the ``'answer'`` key (see the ``result['answer']`` access in the UI
    code), so this helper would raise KeyError. Prefer ``'answer'`` and fall
    back to ``'result'`` for RetrievalQA-style responses.
    """
    print(llm_response.get('answer', llm_response.get('result')))
    print('\n\nSources:')
    for source in llm_response['source_documents']:
        print(source.metadata['source'])
        
#==============================FRONTEND=======================================
st.title("ViTo chatbot👠")
st.header("Ask anything about ViTo company...")

# Make sure both chat transcripts exist in session state before first use:
# 'generated' holds bot answers, 'past' holds the user's questions.
for _key in ('generated', 'past'):
    if _key not in st.session_state:
        st.session_state[_key] = []
    
def get_query():
    """Return the user's question from the chat input box (None when empty)."""
    return st.chat_input("Ask a question about your documents...")

# retrieve the user input
user_input = get_query()
if user_input:
    # FIX: the original passed the module-level `chat_history`, which is
    # reset to [] on every Streamlit rerun, so the chain never saw earlier
    # turns. Rebuild the history from the per-session transcripts instead,
    # as (question, answer) pairs.
    history = list(zip(st.session_state['past'], st.session_state['generated']))
    result = qa_chain({'question': user_input, 'chat_history': history})
    st.session_state.past.append(user_input)
    st.session_state.generated.append(result['answer'])

# Replay the whole conversation so it survives Streamlit's rerun model.
if st.session_state['generated']:
    for i in range(len(st.session_state['generated'])):
        message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
        message(st.session_state['generated'][i], key=str(i))