boryasbora committed on
Commit 5e1e7c7 · verified · 1 Parent(s): d02ff36

Update app.py

Files changed (1)
  1. app.py +11 -49
app.py CHANGED
@@ -8,14 +8,11 @@ from langchain_community.llms import HuggingFacePipeline
 from langchain.retrievers import ParentDocumentRetriever
 from langchain.storage import InMemoryStore
 from langchain_chroma import Chroma
-from langchain.llms import LlamaCpp
 from langchain_openai import ChatOpenAI
 from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnableLambda
 from datetime import date
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-import threading
 import time
 llm_list = ['Mistral-7B-Instruct-v0.2','Mixtral-8x7B-Instruct-v0.1','LLAMA3']
 blablador_base = "https://helmholtz-blablador.fz-juelich.de:8000/v1"
@@ -24,13 +21,6 @@ os.environ['LANGCHAIN_TRACING_V2'] = 'true'
 os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
 os.environ['LANGCHAIN_API_KEY'] = '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'
 
-
-@st.cache_resource
-def load_model():
-    model_name = "EleutherAI/gpt-neo-125M"
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(model_name)
-    return model, tokenizer
 def load_from_pickle(filename):
     with open(filename, "rb") as file:
         return pickle.load(file)
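
The deleted load_model above was one half of a local-model fallback; the other half, the commented-out pipeline wiring, goes in the next hunk. For reference only, a minimal sketch of how such a cached local pipeline would plug into LangChain, reusing the generation settings from the deleted code. load_local_llm is a hypothetical name and nothing here is part of the committed app:

import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_community.llms import HuggingFacePipeline

@st.cache_resource  # download and load the weights once per server process
def load_local_llm(temperature=0.7):
    model_name = "EleutherAI/gpt-neo-125M"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=200,        # the deleted code also set max_length=1800, which conflicts
        temperature=temperature,
        top_p=0.95,
        repetition_penalty=1.15,
    )
    return HuggingFacePipeline(pipeline=pipe)
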
@@ -85,20 +75,7 @@ def get_chain(temperature,selected_model):
         openai_api_key=llm_api,
         openai_api_base=blablador_base,
         streaming=True)
-    # model, tokenizer = load_model()
 
-    # pipe = pipeline(
-    #     "text-generation",
-    #     model=model,
-    #     tokenizer=tokenizer,
-    #     max_length=1800,
-    #     max_new_tokens = 200,
-    #     temperature=temperature,
-    #     top_p=0.95,
-    #     repetition_penalty=1.15
-    # )
-
-    # llm = HuggingFacePipeline(pipeline=pipe)
 
 
     today = date.today()
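
What remains in get_chain is the remote-only path: a ChatOpenAI client pointed at the Blablador OpenAI-compatible endpoint, with streaming=True so the UI change further down can consume partial output. A minimal sketch of that client in isolation; the model id and key are illustrative placeholders, since get_chain derives the real values from selected_model and llm_api outside this hunk:

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    model="Mistral-7B-Instruct-v0.2",   # one entry of llm_list
    temperature=0.7,
    openai_api_key="YOUR_BLABLADOR_KEY",  # hypothetical placeholder
    openai_api_base="https://helmholtz-blablador.fz-juelich.de:8000/v1",
    streaming=True,  # required for chain.stream() to yield chunks incrementally
)
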
@@ -134,24 +111,7 @@ def clear_chat_history():
     st.session_state.messages = []
     st.session_state.context_sources = []
     st.session_state.key = 0
-def run_with_timeout(func, args, timeout):
-    result = [None]
-    def worker():
-        result[0] = func(*args)
-    thread = threading.Thread(target=worker)
-    thread.start()
-    thread.join(timeout)
-    if thread.is_alive():
-        return None
-    return result[0]
-# In your Streamlit app
-def generate_response(chain, query, context):
-    timeout_seconds = 180
-    result = chain.invoke, ({"question": query, "chat_history": st.session_state.messages},)
-    if result is None:
-        return result
-    # return "I apologize, but I couldn't generate a response in time. The query might be too complex for me to process quickly. Could you try simplifying your question?"
-    return result
+
 # Sidebar
 with st.sidebar:
     st.title("OHW Assistant")
@@ -206,17 +166,18 @@ if prompt := st.chat_input("How may I assist you today?"):
         st.markdown(prompt)
 
     with st.chat_message("assistant"):
-        query = st.session_state.messages[-1]['content']
+        query=st.session_state.messages[-1]['content']
         tab1, tab2 = st.tabs(["Answer", "Sources"])
         with tab1:
-            with st.spinner("Generating answer..."):
-                start_time = time.time()
-                full_answer = chain.invoke({"question": query, "chat_history":st.session_state.messages})# Context is handled within the chain
-                end_time = time.time()
-
-                st.markdown(full_answer,unsafe_allow_html=True)
+            start_time = time.time()
+            placeholder = st.empty() # Create a placeholder in Streamlit
+            full_answer = ""
+            for chunk in chain.stream({"question": query, "chat_history":st.session_state.messages}):
+
+                full_answer += chunk
+                placeholder.markdown(full_answer,unsafe_allow_html=True)
+            end_time = time.time()
             st.caption(f"Response time: {end_time - start_time:.2f} seconds")
-
         with tab2:
             if st.session_state.context_sources:
                 for i, source in enumerate(st.session_state.context_sources):
@@ -226,6 +187,7 @@ if prompt := st.chat_input("How may I assist you today?"):
             else:
                 st.write("No sources available for this query.")
 
+
     st.session_state.messages.append({"role": "assistant", "content": full_answer})
     st.session_state.messages[-1]['sources'] = st.session_state.context_sources
     st.session_state.messages[-1]['context'] = st.session_state.context_content
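
The last two hunks replace the blocking chain.invoke call (wrapped in st.spinner) with chain.stream: chunks are accumulated into full_answer and re-rendered into an st.empty placeholder, so the answer appears incrementally while the response-time caption still measures the full generation. Assuming Streamlit >= 1.31 is available, the same pattern could be written more compactly with st.write_stream, which renders a generator as it arrives and returns the concatenated text. A sketch using the same chain, query, and session-state names as the diff above, not what this commit does:

full_answer = st.write_stream(
    chain.stream({"question": query, "chat_history": st.session_state.messages})
)
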
 