prasadnu commited on
Commit
49a0db0
ยท
1 Parent(s): 3ee5d85

rerank model

Browse files
pages/Multimodal_Conversational_Search.py CHANGED
@@ -34,9 +34,6 @@ st.set_page_config(
34
  layout="wide",
35
  page_icon="images/opensearch_mark_default.png"
36
  )
37
- if "trigger_search" not in st.session_state:
38
- st.session_state.trigger_search = False
39
-
40
 
41
  parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
42
  USER_ICON = "images/user.png"
@@ -45,11 +42,22 @@ REGENERATE_ICON = "images/regenerate.png"
45
  s3_bucket_ = "pdf-repo-uploads"
46
  #"pdf-repo-uploads"
47
 
48
- polly_client = boto3.client('polly',aws_access_key_id=st.secrets['user_access_key'],
49
- aws_secret_access_key=st.secrets['user_secret_key'], region_name = 'us-east-1')
 
 
 
 
 
 
 
50
 
51
 
52
  # Check if the user ID is already stored in the session state
 
 
 
 
53
  if 'user_id' in st.session_state:
54
  user_id = st.session_state['user_id']
55
  #print(f"User ID: {user_id}")
@@ -103,12 +111,6 @@ if "input_rag_searchType" not in st.session_state:
103
  st.session_state.input_rag_searchType = ["Vector Search"]
104
 
105
 
106
-
107
- region = 'us-east-1'
108
- bedrock_runtime_client = boto3.client('bedrock-runtime',region_name=region)
109
- output = []
110
- service = 'es'
111
-
112
  st.markdown("""
113
  <style>
114
  [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
@@ -154,7 +156,6 @@ if clear:
154
 
155
 
156
  def handle_input():
157
- st.session_state.trigger_search = True
158
  print("Question: "+st.session_state.input_query)
159
  print("-----------")
160
  print("\n\n")
@@ -207,29 +208,33 @@ def render_answer(question,answer,index,res_img):
207
  ans_ = answer['answer']
208
  st.write(ans_)
209
 
210
- polly_response = polly_client.synthesize_speech(VoiceId='Joanna',
211
- OutputFormat='ogg_vorbis',
212
- Text = ans_,
213
- Engine = 'neural')
214
 
215
- audio_col1, audio_col2 = st.columns([50,50])
216
- with audio_col1:
217
- st.audio(polly_response['AudioStream'].read(), format="audio/ogg")
218
- rdn_key_1 = ''.join([random.choice(string.ascii_letters)
219
- for _ in range(10)])
220
  # def show_maxsim():
221
  # st.session_state.show_columns = True
222
  # st.session_state.maxSimImages = colpali.img_highlight(st.session_state.top_img, st.session_state.query_token_vectors, st.session_state.query_tokens)
 
223
  # with placeholder.container():
224
- # if st.session_state.trigger_search:
225
- # handle_input()
226
- # render_all()
227
- # #render_all()
228
  # if(st.session_state.input_is_colpali):
229
  # st.button("Show similarity map",key=rdn_key_1,on_click = show_maxsim)
230
 
231
  colu1,colu2,colu3 = st.columns([4,82,20])
232
  with colu2:
 
 
 
 
 
 
233
  with st.expander("Relevant Sources:"):
234
  with st.container():
235
  if(len(res_img)>0):
@@ -265,45 +270,42 @@ def render_answer(question,answer,index,res_img):
265
  with cols[idx]:
266
 
267
  st.image(parent_dirname+"/figures/"+st.session_state.input_index+"/"+img+".jpg")
268
- #st.write(caption)
269
  idx = idx+1
270
  if(st.session_state.show_columns == True):
271
  st.session_state.show_columns = False
272
- #st.markdown("<div style='color:#e28743';padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'><b>Sources from the document:</b></div>", unsafe_allow_html = True)
273
  if(len(answer["table"] )>0):
274
  #with st.expander("Table:"):
275
- df = pd.read_csv(answer["table"][0]['name'],skipinitialspace = True, on_bad_lines='skip',delimiter='`')
276
- df.fillna(method='pad', inplace=True)
277
  st.table(df)
278
  #with st.expander("Raw sources:"):
279
  st.write(answer["source"])
280
 
281
 
282
- with col_3:
283
- if(index == len(st.session_state.questions_)):
284
-
285
- rdn_key = ''.join([random.choice(string.ascii_letters)
286
- for _ in range(10)])
287
- currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
288
- oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
289
- # def on_button_click():
290
- # if(currentValue!=oldValue or 1==1):
291
- # st.session_state.input_query = st.session_state.questions_[-1]["question"]
292
- # st.session_state.answers_.pop()
293
- # st.session_state.questions_.pop()
294
 
295
- # handle_input()
296
- # with placeholder.container():
297
- # render_all()
298
- # if("currentValue" in st.session_state):
299
- # del st.session_state["currentValue"]
300
-
301
- # try:
302
- # del regenerate
303
- # except:
304
- # pass
305
- # placeholder__ = st.empty()
306
- # placeholder__.button("๐Ÿ”„",key=rdn_key,on_click=on_button_click)
307
 
308
 
309
  #Each answer will have context of the question asked in order to associate the provided feedback with the respective question
@@ -326,19 +328,21 @@ def render_all():
326
 
327
  placeholder = st.empty()
328
  with placeholder.container():
329
- if st.session_state.trigger_search:
330
- handle_input()
 
331
  render_all()
332
-
333
 
334
  st.markdown("")
335
  col_2, col_3 = st.columns([75,20])
336
  with col_2:
337
- #st.markdown("")
338
  input = st.text_input( "Ask here",label_visibility = "collapsed",key="input_query")
339
  with col_3:
340
- #hidden = st.button("RUN",disabled=True,key = "hidden")
341
- play = st.button("Go",on_click=handle_input,key = "play")
 
 
342
  with st.sidebar:
343
  st.page_link("app.py", label=":orange[Home]", icon="๐Ÿ ")
344
  st.subheader(":blue[Sample Data]")
@@ -411,10 +415,5 @@ with st.sidebar:
411
  with st.expander("Sample questions for Colpali retriever:"):
412
  st.write("1. Proportion of female new hires 2021-2023? \n\n 2. First-half 2021 return on unlisted real estate investments? \n\n 3. Trend of the fund's expected absolute volatility between January 2014 and January 2016? \n\n 4. Fund return percentage in 2017? \n\n 5. Annualized gross return of the fund from 1997 to 2008?")
413
 
414
- run = st.sidebar.button("๐Ÿ” Run Search")
415
-
416
- if run:
417
- st.session_state.trigger_search = True
418
 
419
-
420
 
 
34
  layout="wide",
35
  page_icon="images/opensearch_mark_default.png"
36
  )
 
 
 
37
 
38
  parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
39
  USER_ICON = "images/user.png"
 
42
  s3_bucket_ = "pdf-repo-uploads"
43
  #"pdf-repo-uploads"
44
 
45
+ # @st.cache_resource
46
+ # def get_polly_client():
47
+ # return boto3.client('polly',
48
+ # aws_access_key_id=st.secrets['user_access_key'],
49
+ # aws_secret_access_key=st.secrets['user_secret_key'],
50
+ # region_name='us-east-1'
51
+ # )
52
+
53
+ # polly_client = get_polly_client()
54
 
55
 
56
  # Check if the user ID is already stored in the session state
57
+
58
+ if "trigger_search" not in st.session_state:
59
+ st.session_state.trigger_search = False
60
+
61
  if 'user_id' in st.session_state:
62
  user_id = st.session_state['user_id']
63
  #print(f"User ID: {user_id}")
 
111
  st.session_state.input_rag_searchType = ["Vector Search"]
112
 
113
 
 
 
 
 
 
 
114
  st.markdown("""
115
  <style>
116
  [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
 
156
 
157
 
158
  def handle_input():
 
159
  print("Question: "+st.session_state.input_query)
160
  print("-----------")
161
  print("\n\n")
 
208
  ans_ = answer['answer']
209
  st.write(ans_)
210
 
211
+ # polly_response = polly_client.synthesize_speech(VoiceId='Joanna',
212
+ # OutputFormat='ogg_vorbis',
213
+ # Text = ans_,
214
+ # Engine = 'neural')
215
 
216
+ # audio_col1, audio_col2 = st.columns([50,50])
217
+ # with audio_col1:
218
+ # st.audio(polly_response['AudioStream'].read(), format="audio/ogg")
219
+ # rdn_key_1 = ''.join([random.choice(string.ascii_letters)
220
+ # for _ in range(10)])
221
  # def show_maxsim():
222
  # st.session_state.show_columns = True
223
  # st.session_state.maxSimImages = colpali.img_highlight(st.session_state.top_img, st.session_state.query_token_vectors, st.session_state.query_tokens)
224
+ # handle_input()
225
  # with placeholder.container():
226
+ # render_all()
 
 
 
227
  # if(st.session_state.input_is_colpali):
228
  # st.button("Show similarity map",key=rdn_key_1,on_click = show_maxsim)
229
 
230
  colu1,colu2,colu3 = st.columns([4,82,20])
231
  with colu2:
232
+ @st.cache_data
233
+ def load_table_from_file(filepath):
234
+ df = pd.read_csv(filepath, skipinitialspace=True, on_bad_lines='skip', delimiter='`')
235
+ df.fillna(method='pad', inplace=True)
236
+ return df
237
+
238
  with st.expander("Relevant Sources:"):
239
  with st.container():
240
  if(len(res_img)>0):
 
270
  with cols[idx]:
271
 
272
  st.image(parent_dirname+"/figures/"+st.session_state.input_index+"/"+img+".jpg")
 
273
  idx = idx+1
274
  if(st.session_state.show_columns == True):
275
  st.session_state.show_columns = False
 
276
  if(len(answer["table"] )>0):
277
  #with st.expander("Table:"):
278
+ df = load_table_from_file(answer["table"][0]['name'])
 
279
  st.table(df)
280
  #with st.expander("Raw sources:"):
281
  st.write(answer["source"])
282
 
283
 
284
+ # with col_3:
285
+ # if(index == len(st.session_state.questions_)):
286
+
287
+ # rdn_key = ''.join([random.choice(string.ascii_letters)
288
+ # for _ in range(10)])
289
+ # currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
290
+ # oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
291
+ # def on_button_click():
292
+ # if(currentValue!=oldValue or 1==1):
293
+ # st.session_state.input_query = st.session_state.questions_[-1]["question"]
294
+ # st.session_state.answers_.pop()
295
+ # st.session_state.questions_.pop()
296
 
297
+ # handle_input()
298
+ # with placeholder.container():
299
+ # render_all()
300
+ # if("currentValue" in st.session_state):
301
+ # del st.session_state["currentValue"]
302
+
303
+ # try:
304
+ # del regenerate
305
+ # except:
306
+ # pass
307
+ # placeholder__ = st.empty()
308
+ # placeholder__.button("๐Ÿ”„",key=rdn_key,on_click=on_button_click)
309
 
310
 
311
  #Each answer will have context of the question asked in order to associate the provided feedback with the respective question
 
328
 
329
  placeholder = st.empty()
330
  with placeholder.container():
331
+ if st.session_state.trigger_search:
332
+ with st.spinner("Running search..."):
333
+ handle_input()
334
  render_all()
335
+ st.session_state.trigger_search = False # reset
336
 
337
  st.markdown("")
338
  col_2, col_3 = st.columns([75,20])
339
  with col_2:
 
340
  input = st.text_input( "Ask here",label_visibility = "collapsed",key="input_query")
341
  with col_3:
342
+ #play = st.button("Go",on_click=handle_input,key = "play")
343
+ play = st.button("Go", key="play")
344
+ if play:
345
+ st.session_state.trigger_search = True
346
  with st.sidebar:
347
  st.page_link("app.py", label=":orange[Home]", icon="๐Ÿ ")
348
  st.subheader(":blue[Sample Data]")
 
415
  with st.expander("Sample questions for Colpali retriever:"):
416
  st.write("1. Proportion of female new hires 2021-2023? \n\n 2. First-half 2021 return on unlisted real estate investments? \n\n 3. Trend of the fund's expected absolute volatility between January 2014 and January 2016? \n\n 4. Fund return percentage in 2017? \n\n 5. Annualized gross return of the fund from 1997 to 2008?")
417
 
 
 
 
 
418
 
 
419