prasadnu committed on
Commit
ad41a02
·
1 Parent(s): 2e2dda5

RAG changes

Browse files
pages/Multimodal_Conversational_Search.py CHANGED
@@ -15,6 +15,7 @@ import random
15
  import string
16
  import rag_DocumentLoader
17
  import rag_DocumentSearcher
 
18
  import pandas as pd
19
  from PIL import Image
20
  import shutil
@@ -22,10 +23,9 @@ import base64
22
  import time
23
  import botocore
24
  #from langchain.callbacks.base import BaseCallbackHandler
25
- #import streamlit_nested_layout
26
  #from IPython.display import clear_output, display, display_markdown, Markdown
27
  from requests_aws4auth import AWS4Auth
28
- #import copali
29
  from requests.auth import HTTPBasicAuth
30
 
31
 
@@ -41,8 +41,8 @@ AI_ICON = "images/opensearch-twitter-card.png"
41
  REGENERATE_ICON = "images/regenerate.png"
42
  s3_bucket_ = "pdf-repo-uploads"
43
  #"pdf-repo-uploads"
44
- polly_client = boto3.client('polly',aws_access_key_id=st.secrets['user_access_key'],
45
- aws_secret_access_key=st.secrets['user_secret_key'], region_name = 'us-east-1')
46
 
47
  # Check if the user ID is already stored in the session state
48
  if 'user_id' in st.session_state:
@@ -69,6 +69,13 @@ if "chats" not in st.session_state:
69
 
70
  if "questions_" not in st.session_state:
71
  st.session_state.questions_ = []
 
 
 
 
 
 
 
72
 
73
  if "answers_" not in st.session_state:
74
  st.session_state.answers_ = []
@@ -78,6 +85,9 @@ if "input_index" not in st.session_state:
78
 
79
  if "input_is_rerank" not in st.session_state:
80
  st.session_state.input_is_rerank = True
 
 
 
81
 
82
  if "input_copali_rerank" not in st.session_state:
83
  st.session_state.input_copali_rerank = False
@@ -89,8 +99,8 @@ if "input_query" not in st.session_state:
89
  st.session_state.input_query="which city has the highest average housing price in UK ?"#"What is the projected energy percentage from renewable sources in future?"#"Which city in United Kingdom has the highest average housing price ?"#"How many aged above 85 years died due to covid ?"# What is the projected energy from renewable sources ?"
90
 
91
 
92
- # if "input_rag_searchType" not in st.session_state:
93
- # st.session_state.input_rag_searchType = ["Sparse Search"]
94
 
95
 
96
 
@@ -129,7 +139,7 @@ st.markdown("""
129
 
130
 
131
  credentials = boto3.Session().get_credentials()
132
- awsauth = HTTPBasicAuth('prasadnu',st.secrets['rag_shopping_assistant_os_api_access'])
133
  service = 'es'
134
 
135
 
@@ -156,7 +166,7 @@ service = 'es'
156
  def write_logo():
157
  col1, col2, col3 = st.columns([5, 1, 5])
158
  with col2:
159
- st.image(AI_ICON, use_container_width='always')
160
 
161
  def write_top_bar():
162
  col1, col2 = st.columns([77,23])
@@ -164,7 +174,7 @@ def write_top_bar():
164
  st.write("")
165
  st.header("Chat with your data",divider='rainbow')
166
 
167
- #st.image(AI_ICON, use_container_width='always')
168
 
169
  with col2:
170
  st.write("")
@@ -188,6 +198,8 @@ if clear:
188
 
189
 
190
  def handle_input():
 
 
191
  print("Question: "+st.session_state.input_query)
192
  print("-----------")
193
  print("\n\n")
@@ -208,7 +220,11 @@ def handle_input():
208
  'id': len(st.session_state.questions_)
209
  }
210
  st.session_state.questions_.append(question_with_id)
211
- out_ = rag_DocumentSearcher.query_(awsauth, inputs, st.session_state['session_id'],st.session_state.input_rag_searchType)
 
 
 
 
212
  st.session_state.answers_.append({
213
  'answer': out_['text'],
214
  'source':out_['source'],
@@ -248,7 +264,7 @@ def write_user_message(md):
248
  col1, col2 = st.columns([3,97])
249
 
250
  with col1:
251
- st.image(USER_ICON, use_container_width='always')
252
  with col2:
253
  #st.warning(md['question'])
254
 
@@ -261,7 +277,7 @@ def render_answer(question,answer,index,res_img):
261
 
262
  col1, col2, col_3 = st.columns([4,74,22])
263
  with col1:
264
- st.image(AI_ICON, use_container_width='always')
265
  with col2:
266
  ans_ = answer['answer']
267
  st.write(ans_)
@@ -317,38 +333,60 @@ def render_answer(question,answer,index,res_img):
317
  #st.write("")
318
  colu1,colu2,colu3 = st.columns([4,82,20])
319
  with colu2:
320
- #with st.expander("Relevant Sources:"):
321
- with st.container():
322
- if(len(res_img)>0):
323
- with st.expander("Relevant Sources:"):
324
- #with st.expander("Images:"):
325
- st.write("Images:")
326
- col3,col4,col5 = st.columns([33,33,33])
327
- cols = [col3,col4]
328
- idx = 0
329
- #print(res_img)
330
- for img_ in res_img:
331
- if(img_['file'].lower()!='none' and idx < 2):
332
- img = img_['file'].split(".")[0]
333
- caption = img_['caption']
334
 
335
- with cols[idx]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
- st.image(parent_dirname+"/figures/"+st.session_state.input_index+"/"+img+".jpg")
 
 
338
  #st.write(caption)
339
- idx = idx+1
 
 
340
  #st.markdown("<div style='color:#e28743';padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'><b>Sources from the document:</b></div>", unsafe_allow_html = True)
341
- if(len(answer["table"] )>0):
342
- #with st.expander("Table:"):
343
- st.write("Table:")
344
  df = pd.read_csv(answer["table"][0]['name'],skipinitialspace = True, on_bad_lines='skip',delimiter='`')
345
  df.fillna(method='pad', inplace=True)
346
  st.table(df)
347
- #with st.expander("Raw sources:"):
348
- st.write("Raw sources:")
349
  st.write(answer["source"])
350
-
351
-
352
 
353
  with col_3:
354
 
@@ -360,22 +398,25 @@ def render_answer(question,answer,index,res_img):
360
 
361
  rdn_key = ''.join([random.choice(string.ascii_letters)
362
  for _ in range(10)])
 
 
363
  currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
364
  oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
365
  #print("changing values-----------------")
366
  def on_button_click():
367
- # print("button clicked---------------")
368
- # print(currentValue)
369
- # print(oldValue)
370
  if(currentValue!=oldValue or 1==1):
371
- #print("----------regenerate----------------")
372
  st.session_state.input_query = st.session_state.questions_[-1]["question"]
373
  st.session_state.answers_.pop()
374
  st.session_state.questions_.pop()
375
 
376
- handle_input()
377
- with placeholder.container():
378
- render_all()
 
 
 
 
 
379
 
380
  if("currentValue" in st.session_state):
381
  del st.session_state["currentValue"]
@@ -385,16 +426,18 @@ def render_answer(question,answer,index,res_img):
385
  except:
386
  pass
387
 
388
- #print("------------------------")
389
- #print(st.session_state)
390
-
391
  placeholder__ = st.empty()
392
-
393
  placeholder__.button("🔄",key=rdn_key,on_click=on_button_click)
 
 
 
394
 
395
  #Each answer will have context of the question asked in order to associate the provided feedback with the respective question
396
  def write_chat_message(md, q,index):
397
- res_img = md['image']
 
 
 
398
  #st.session_state['session_id'] = res['session_id'] to be added in memory
399
  chat = st.container()
400
  with chat:
@@ -425,7 +468,7 @@ with col_2:
425
  input = st.text_input( "Ask here",label_visibility = "collapsed",key="input_query")
426
  with col_3:
427
  #hidden = st.button("RUN",disabled=True,key = "hidden")
428
- play = st.button("GO",on_click=handle_input,key = "play")
429
  with st.sidebar:
430
  st.page_link("app.py", label=":orange[Home]", icon="🏠")
431
  st.subheader(":blue[Sample Data]")
@@ -454,16 +497,19 @@ with st.sidebar:
454
  }
455
  </style>
456
  """,unsafe_allow_html=True)
 
 
 
 
 
457
  # Initialize boto3 to use the S3 client.
458
- s3_client = boto3.resource('s3',aws_access_key_id=st.secrets['user_access_key'],
459
- aws_secret_access_key=st.secrets['user_secret_key'], region_name = 'us-east-1')
460
  bucket=s3_client.Bucket(s3_bucket_)
461
 
462
  objects = bucket.objects.filter(Prefix="sample_pdfs/")
463
  urls = []
464
 
465
- client = boto3.client('s3',aws_access_key_id=st.secrets['user_access_key'],
466
- aws_secret_access_key=st.secrets['user_secret_key'], region_name = 'us-east-1')
467
 
468
  for obj in objects:
469
  if obj.key.endswith('.pdf'):
@@ -510,18 +556,29 @@ with st.sidebar:
510
  # print('lambda done')
511
  # st.success('you can start searching on your PDF')
512
 
513
-
514
- # if(pdf_doc_ is None or pdf_doc_ == ""):
515
- # if(index_select == "Global Warming stats"):
516
- # st.session_state.input_index = "globalwarmingnew"
517
- # if(index_select == "Covid19 impacts on Ireland"):
518
- # st.session_state.input_index = "covid19ie"#"choosetheknnalgorithmforyourbillionscaleusecasewithopensearchawsbigdatablog"
519
- # if(index_select == "BEIR"):
520
- # st.session_state.input_index = "2104"
521
- # if(index_select == "UK Housing"):
522
- # st.session_state.input_index = "ukhousingstats"
523
-
524
-
 
 
 
 
 
 
 
 
 
 
 
525
  # custom_index = st.text_input("If uploaded the file already, enter the original file name", value = "")
526
  # if(custom_index!=""):
527
  # st.session_state.input_index = re.sub('[^A-Za-z0-9]+', '', (custom_index.lower().replace(".pdf","").split("/")[-1].split(".")[0]).lower())
@@ -534,18 +591,33 @@ with st.sidebar:
534
  'Vector Search',
535
  'Sparse Search',
536
  ],
537
- ['Sparse Search'],
538
 
539
  key = 'input_rag_searchType',
540
  help = "Select the type of Search, adding more than one search type will activate hybrid search"#\n1. Conversational Search (Recommended) - This will include both the OpenSearch and LLM in the retrieval pipeline \n (note: This will put opensearch response as context to LLM to answer) \n2. OpenSearch vector search - This will put only OpenSearch's vector search in the pipeline, \n(Warning: this will lead to unformatted results )\n3. LLM Text Generation - This will include only LLM in the pipeline, \n(Warning: This will give hallucinated and out of context answers)"
541
  )
542
 
543
  re_rank = st.checkbox('Re-rank results', key = 'input_re_rank', disabled = False, value = True, help = "Checking this box will re-rank the results using a cross-encoder model")
544
-
545
  if(re_rank):
546
  st.session_state.input_is_rerank = True
547
  else:
548
  st.session_state.input_is_rerank = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
 
550
  # copali_rerank = st.checkbox("Search and Re-rank with Token level vectors",key = 'copali_rerank',help = "Enabling this option uses 'Copali' model's page level image embeddings to retrieve documents and MaxSim to re-rank the pages.\n\n Hugging Face Model: https://huggingface.co/vidore/colpali")
551
 
@@ -553,5 +625,8 @@ with st.sidebar:
553
  # st.session_state.input_copali_rerank = True
554
  # else:
555
  # st.session_state.input_copali_rerank = False
 
 
 
556
 
557
 
 
15
  import string
16
  import rag_DocumentLoader
17
  import rag_DocumentSearcher
18
+ #import colpali
19
  import pandas as pd
20
  from PIL import Image
21
  import shutil
 
23
  import time
24
  import botocore
25
  #from langchain.callbacks.base import BaseCallbackHandler
 
26
  #from IPython.display import clear_output, display, display_markdown, Markdown
27
  from requests_aws4auth import AWS4Auth
28
+ import colpali
29
  from requests.auth import HTTPBasicAuth
30
 
31
 
 
41
  REGENERATE_ICON = "images/regenerate.png"
42
  s3_bucket_ = "pdf-repo-uploads"
43
  #"pdf-repo-uploads"
44
+ polly_client = boto3.Session(
45
+ region_name='us-east-1').client('polly')
46
 
47
  # Check if the user ID is already stored in the session state
48
  if 'user_id' in st.session_state:
 
69
 
70
  if "questions_" not in st.session_state:
71
  st.session_state.questions_ = []
72
+
73
+
74
+ if "show_columns" not in st.session_state:
75
+ st.session_state.show_columns = False
76
+
77
+ if "answer_ready" not in st.session_state:
78
+ st.session_state.answer_ready = False
79
 
80
  if "answers_" not in st.session_state:
81
  st.session_state.answers_ = []
 
85
 
86
  if "input_is_rerank" not in st.session_state:
87
  st.session_state.input_is_rerank = True
88
+
89
+ if "input_is_colpali" not in st.session_state:
90
+ st.session_state.input_is_colpali = False
91
 
92
  if "input_copali_rerank" not in st.session_state:
93
  st.session_state.input_copali_rerank = False
 
99
  st.session_state.input_query="which city has the highest average housing price in UK ?"#"What is the projected energy percentage from renewable sources in future?"#"Which city in United Kingdom has the highest average housing price ?"#"How many aged above 85 years died due to covid ?"# What is the projected energy from renewable sources ?"
100
 
101
 
102
+ if "input_rag_searchType" not in st.session_state:
103
+ st.session_state.input_rag_searchType = ["Vector Search"]
104
 
105
 
106
 
 
139
 
140
 
141
  credentials = boto3.Session().get_credentials()
142
+ awsauth = AWS4Auth(credentials.access_key, credentials.secret_key, 'us-west-2', service, session_token=credentials.token)
143
  service = 'es'
144
 
145
 
 
166
  def write_logo():
167
  col1, col2, col3 = st.columns([5, 1, 5])
168
  with col2:
169
+ st.image(AI_ICON, use_column_width='always')
170
 
171
  def write_top_bar():
172
  col1, col2 = st.columns([77,23])
 
174
  st.write("")
175
  st.header("Chat with your data",divider='rainbow')
176
 
177
+ #st.image(AI_ICON, use_column_width='always')
178
 
179
  with col2:
180
  st.write("")
 
198
 
199
 
200
  def handle_input():
201
+ # st.session_state.answer_ready = True
202
+ # st.session_state.show_columns = False # reset column display
203
  print("Question: "+st.session_state.input_query)
204
  print("-----------")
205
  print("\n\n")
 
220
  'id': len(st.session_state.questions_)
221
  }
222
  st.session_state.questions_.append(question_with_id)
223
+ if(st.session_state.input_is_colpali):
224
+ out_ = colpali.colpali_search_rerank(st.session_state.input_query)
225
+ #print(out_)
226
+ else:
227
+ out_ = rag_DocumentSearcher.query_(awsauth, inputs, st.session_state['session_id'],st.session_state.input_rag_searchType)
228
  st.session_state.answers_.append({
229
  'answer': out_['text'],
230
  'source':out_['source'],
 
264
  col1, col2 = st.columns([3,97])
265
 
266
  with col1:
267
+ st.image(USER_ICON, use_column_width='always')
268
  with col2:
269
  #st.warning(md['question'])
270
 
 
277
 
278
  col1, col2, col_3 = st.columns([4,74,22])
279
  with col1:
280
+ st.image(AI_ICON, use_column_width='always')
281
  with col2:
282
  ans_ = answer['answer']
283
  st.write(ans_)
 
333
  #st.write("")
334
  colu1,colu2,colu3 = st.columns([4,82,20])
335
  with colu2:
336
+ with st.expander("Relevant Sources:"):
337
+ with st.container():
338
+ if(len(res_img)>0):
339
+ with st.expander("Images:"):
340
+
341
+ idx = 0
342
+ print(res_img)
343
+ for i in range(0,len(res_img)):
 
 
 
 
 
 
344
 
345
+ if(st.session_state.input_is_colpali):
346
+ if(st.session_state.show_columns == True):
347
+ cols_per_row = 3
348
+ st.session_state.image_placeholder=st.empty()
349
+ with st.session_state.image_placeholder.container():
350
+ row = st.columns(cols_per_row)
351
+ for j, item in enumerate(res_img[i:i+cols_per_row]):
352
+ with row[j]:
353
+ st.image(item['file'])
354
+
355
+ else:
356
+ st.session_state.image_placeholder = st.empty()
357
+ with st.session_state.image_placeholder.container():
358
+ col3_,col4_,col5_ = st.columns([33,33,33])
359
+ with col3_:
360
+ st.image(res_img[i]['file'])
361
+
362
+
363
+
364
+
365
+
366
+ else:
367
+ if(res_img[i]['file'].lower()!='none' and idx < 2):
368
+ col3,col4,col5 = st.columns([33,33,33])
369
+ cols = [col3,col4]
370
+ img = res_img[i]['file'].split(".")[0]
371
+ caption = res_img[i]['caption']
372
 
373
+ with cols[idx]:
374
+
375
+ st.image(parent_dirname+"/figures/"+st.session_state.input_index+"/"+img+".jpg")
376
  #st.write(caption)
377
+ idx = idx+1
378
+ if(st.session_state.show_columns == True):
379
+ st.session_state.show_columns = False
380
  #st.markdown("<div style='color:#e28743';padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'><b>Sources from the document:</b></div>", unsafe_allow_html = True)
381
+ if(len(answer["table"] )>0):
382
+ with st.expander("Table:"):
 
383
  df = pd.read_csv(answer["table"][0]['name'],skipinitialspace = True, on_bad_lines='skip',delimiter='`')
384
  df.fillna(method='pad', inplace=True)
385
  st.table(df)
386
+ with st.expander("Raw sources:"):
 
387
  st.write(answer["source"])
388
+
389
+
390
 
391
  with col_3:
392
 
 
398
 
399
  rdn_key = ''.join([random.choice(string.ascii_letters)
400
  for _ in range(10)])
401
+ rdn_key_1 = ''.join([random.choice(string.ascii_letters)
402
+ for _ in range(10)])
403
  currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
404
  oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
405
  #print("changing values-----------------")
406
  def on_button_click():
 
 
 
407
  if(currentValue!=oldValue or 1==1):
 
408
  st.session_state.input_query = st.session_state.questions_[-1]["question"]
409
  st.session_state.answers_.pop()
410
  st.session_state.questions_.pop()
411
 
412
+
413
+ def show_maxsim():
414
+ st.session_state.show_columns = True
415
+ st.session_state.maxSimImages = colpali.img_highlight(st.session_state.top_img, st.session_state.query_token_vectors, st.session_state.query_tokens)
416
+ handle_input()
417
+ with placeholder.container():
418
+ render_all()
419
+
420
 
421
  if("currentValue" in st.session_state):
422
  del st.session_state["currentValue"]
 
426
  except:
427
  pass
428
 
 
 
 
429
  placeholder__ = st.empty()
 
430
  placeholder__.button("🔄",key=rdn_key,on_click=on_button_click)
431
+ placeholder__.button("Show similarity map",key=rdn_key_1,on_click = show_maxsim)
432
+
433
+
434
 
435
  #Each answer will have context of the question asked in order to associate the provided feedback with the respective question
436
  def write_chat_message(md, q,index):
437
+ if(st.session_state.show_columns):
438
+ res_img = st.session_state.maxSimImages
439
+ else:
440
+ res_img = md['image']
441
  #st.session_state['session_id'] = res['session_id'] to be added in memory
442
  chat = st.container()
443
  with chat:
 
468
  input = st.text_input( "Ask here",label_visibility = "collapsed",key="input_query")
469
  with col_3:
470
  #hidden = st.button("RUN",disabled=True,key = "hidden")
471
+ play = st.button("Go",on_click=handle_input,key = "play")
472
  with st.sidebar:
473
  st.page_link("app.py", label=":orange[Home]", icon="🏠")
474
  st.subheader(":blue[Sample Data]")
 
497
  }
498
  </style>
499
  """,unsafe_allow_html=True)
500
+ with st.expander("Sample questions:"):
501
+ st.markdown("<span style = 'color:#FF9900;'>UK Housing</span> - which city has the highest average housing price in UK ?",unsafe_allow_html=True)
502
+ st.markdown("<span style = 'color:#FF9900;'>Global Warming stats</span> - What is the projected energy percentage from renewable sources in future?",unsafe_allow_html=True)
503
+ st.markdown("<span style = 'color:#FF9900;'>Covid19 impacts</span> - How many aged above 85 years died due to covid ?",unsafe_allow_html=True)
504
+
505
  # Initialize boto3 to use the S3 client.
506
+ s3_client = boto3.resource('s3')
 
507
  bucket=s3_client.Bucket(s3_bucket_)
508
 
509
  objects = bucket.objects.filter(Prefix="sample_pdfs/")
510
  urls = []
511
 
512
+ client = boto3.client('s3')
 
513
 
514
  for obj in objects:
515
  if obj.key.endswith('.pdf'):
 
556
  # print('lambda done')
557
  # st.success('you can start searching on your PDF')
558
 
559
+ ############## haystack demo temporary addition ############
560
+ # st.subheader(":blue[Multimodality]")
561
+ # colu1,colu2 = st.columns([50,50])
562
+ # with colu1:
563
+ # in_images = st.toggle('Images', key = 'in_images', disabled = False)
564
+ # with colu2:
565
+ # in_tables = st.toggle('Tables', key = 'in_tables', disabled = False)
566
+ # if(in_tables):
567
+ # st.session_state.input_table_with_sql = True
568
+ # else:
569
+ # st.session_state.input_table_with_sql = False
570
+
571
+ ############## haystack demo temporary addition ############
572
+ #if(pdf_doc_ is None or pdf_doc_ == ""):
573
+ if(index_select == "Global Warming stats"):
574
+ st.session_state.input_index = "globalwarmingnew"
575
+ if(index_select == "Covid19 impacts on Ireland"):
576
+ st.session_state.input_index = "covid19ie"#"choosetheknnalgorithmforyourbillionscaleusecasewithopensearchawsbigdatablog"
577
+ if(index_select == "BEIR"):
578
+ st.session_state.input_index = "2104"
579
+ if(index_select == "UK Housing"):
580
+ st.session_state.input_index = "hpijan2024hometrack"
581
+
582
  # custom_index = st.text_input("If uploaded the file already, enter the original file name", value = "")
583
  # if(custom_index!=""):
584
  # st.session_state.input_index = re.sub('[^A-Za-z0-9]+', '', (custom_index.lower().replace(".pdf","").split("/")[-1].split(".")[0]).lower())
 
591
  'Vector Search',
592
  'Sparse Search',
593
  ],
594
+ ['Vector Search'],
595
 
596
  key = 'input_rag_searchType',
597
  help = "Select the type of Search, adding more than one search type will activate hybrid search"#\n1. Conversational Search (Recommended) - This will include both the OpenSearch and LLM in the retrieval pipeline \n (note: This will put opensearch response as context to LLM to answer) \n2. OpenSearch vector search - This will put only OpenSearch's vector search in the pipeline, \n(Warning: this will lead to unformatted results )\n3. LLM Text Generation - This will include only LLM in the pipeline, \n(Warning: This will give hallucinated and out of context answers)"
598
  )
599
 
600
  re_rank = st.checkbox('Re-rank results', key = 'input_re_rank', disabled = False, value = True, help = "Checking this box will re-rank the results using a cross-encoder model")
601
+
602
  if(re_rank):
603
  st.session_state.input_is_rerank = True
604
  else:
605
  st.session_state.input_is_rerank = False
606
+
607
+ st.subheader(":blue[Multi-vector retrieval]")
608
+
609
+ #st.write("Dataset indexed: https://huggingface.co/datasets/vespa-engine/gpfg-QA")
610
+ colpali_search_rerank = st.checkbox('Try Colpali multi-vector retrieval on the [sample dataset](https://huggingface.co/datasets/vespa-engine/gpfg-QA)', key = 'input_colpali', disabled = False, value = False, help = "Checking this box will use colpali as the embedding model and retrieval is performed using multi-vectors followed by re-ranking using MaxSim")
611
+
612
+ if(colpali_search_rerank):
613
+ st.session_state.input_is_colpali = True
614
+ #st.session_state.input_query = ""
615
+ else:
616
+ st.session_state.input_is_colpali = False
617
+
618
+ with st.expander("Sample questions for Colpali retriever:"):
619
+ st.write("1. Proportion of female new hires 2021-2023? \n\n 2. First-half 2021 return on unlisted real estate investments? \n\n 3. Trend of the fund's expected absolute volatility between January 2014 and January 2016? \n\n 4. Fund return percentage in 2017? \n\n 5. Annualized gross return of the fund from 1997 to 2008?")
620
+
621
 
622
  # copali_rerank = st.checkbox("Search and Re-rank with Token level vectors",key = 'copali_rerank',help = "Enabling this option uses 'Copali' model's page level image embeddings to retrieve documents and MaxSim to re-rank the pages.\n\n Hugging Face Model: https://huggingface.co/vidore/colpali")
623
 
 
625
  # st.session_state.input_copali_rerank = True
626
  # else:
627
  # st.session_state.input_copali_rerank = False
628
+
629
+
630
+
631
 
632