import streamlit as st

## a dockerized streamlit app reads keys from os.getenv(); otherwise fall back to st.secrets
import os
api_key = os.getenv("LITELLM_KEY")
if api_key is None:
    api_key = st.secrets["LITELLM_KEY"]
cirrus_key = os.getenv("CIRRUS_KEY")
if cirrus_key is None:
    cirrus_key = st.secrets["CIRRUS_KEY"]        

st.title("HWC LLM Testing")


'''
(The demo will take a while to load the first time, while all the data is processed! This will be pre-processed in the future...)
'''

import requests
import zipfile
def download_and_unzip(url, output_dir):
    # skip the download if the data has already been extracted (e.g. on a streamlit rerun)
    if os.path.exists(output_dir):
        return
    response = requests.get(url)
    response.raise_for_status()
    zip_file_path = os.path.basename(url)
    with open(zip_file_path, 'wb') as f:
        f.write(response.content)
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(output_dir)
    os.remove(zip_file_path)


import pathlib
from langchain_community.document_loaders import PyPDFLoader
@st.cache_data
def pdf_loader(path):
    all_documents = []
    docs_dir = pathlib.Path(path)
    for file in docs_dir.iterdir():
        if file.suffix.lower() != ".pdf":
            continue  # skip any non-PDF files in the directory
        loader = PyPDFLoader(str(file))
        documents = loader.load()
        all_documents.extend(documents)
    return all_documents

download_and_unzip("https://minio.carlboettiger.info/public-data/hwc.zip", "hwc")
docs = pdf_loader('hwc/')


from langchain_openai import OpenAIEmbeddings
embedding = OpenAIEmbeddings(
             model = "cirrus",
             api_key = cirrus_key, 
             base_url = "https://llm.cirrus.carlboettiger.info/v1",
)


# Build a retrieval agent
from langchain_text_splitters import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=500)
splits = text_splitter.split_documents(docs)

from langchain_core.vectorstores import InMemoryVectorStore
@st.cache_resource
def vector_store(_splits):
    # the leading underscore tells st.cache_resource not to hash the unhashable splits
    vectorstore = InMemoryVectorStore.from_documents(documents=_splits, embedding=embedding)
    retriever = vectorstore.as_retriever()
    return retriever

# here we go, the slow part: embedding every chunk into the vector store
retriever = vector_store(splits)

# Set up the language model
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model = "llama3", api_key = api_key, base_url = "https://llm.nrp-nautilus.io",  temperature=0)
## To use the Cirrus endpoint instead, see the sketch below.
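## A sketch of one alternative, reusing the Cirrus endpoint configured for the
## embeddings above; the chat model name served there is an assumption:
# llm = ChatOpenAI(model = "cirrus", api_key = cirrus_key,
#                  base_url = "https://llm.cirrus.carlboettiger.info/v1", temperature=0)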
system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following scientific articles as the retrieved context to answer "
    "the question. Appropriately cite the articles from the context on which "
    "your answer is based. Do not attempt to cite articles that are not in "
    "the context. If you don't know the answer, say that you don't know. "
    "Use at most five sentences and keep the answer concise."
    "\n\n"
    "{context}"
)

from langchain_core.prompts import ChatPromptTemplate
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)
from langchain.chains.combine_documents import create_stuff_documents_chain
question_answer_chain = create_stuff_documents_chain(llm, prompt)
from langchain.chains import create_retrieval_chain
rag_chain = create_retrieval_chain(retriever, question_answer_chain)



# Place agent inside a streamlit application
# (named user_query rather than prompt, to avoid shadowing the ChatPromptTemplate above):
if user_query := st.chat_input("What are the most cost-effective prevention methods for elephants raiding my crops?"):
    with st.chat_message("user"):
        st.markdown(user_query)

    with st.chat_message("assistant"):
        results = rag_chain.invoke({"input": user_query})
        st.write(results['answer'])

        with st.expander("See context matched"):
            # FIXME parse results dict and display in pretty format
            st.write(results['context'])
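            # One possible approach to the FIXME: results['context'] should be a
            # list of langchain Documents, so something like this may work:
            # for doc in results['context']:
            #     st.markdown(f"**{doc.metadata.get('source', 'unknown')}**, "
            #                 f"page {doc.metadata.get('page', '?')}")
            #     st.write(doc.page_content)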


# adapt for memory / multi-question interaction with:
# https://python.langchain.com/docs/tutorials/qa_chat_history/
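# A minimal sketch of that pattern, assuming create_history_aware_retriever is
# available in the installed langchain version (untested here):
# from langchain.chains import create_history_aware_retriever
# from langchain_core.prompts import MessagesPlaceholder
# contextualize_prompt = ChatPromptTemplate.from_messages([
#     ("system", "Given the chat history and the latest question, rephrase the "
#                "question so it can be understood without the history."),
#     MessagesPlaceholder("chat_history"),
#     ("human", "{input}"),
# ])
# history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_prompt)
# rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)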

# Also see structured outputs.
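# A hedged sketch of that idea: most langchain chat models support
# .with_structured_output() with a pydantic schema; CitedAnswer here is hypothetical:
# from pydantic import BaseModel
# class CitedAnswer(BaseModel):
#     answer: str
#     citations: list[str]  # e.g. titles of the articles the answer draws on
# structured_llm = llm.with_structured_output(CitedAnswer)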