File size: 22,961 Bytes
77b927d
 
 
 
 
 
 
 
 
 
 
 
b505cc3
eab6925
 
77b927d
eab6925
77b927d
56ce28d
09df805
e54b3e0
eab6925
 
5bf3195
 
6a2ae7a
 
77b927d
 
 
 
 
 
45a7d81
 
aac3522
 
eab6925
437c715
 
 
0c14e18
 
eab6925
0c14e18
 
 
 
 
 
eab6925
09df805
 
 
437c715
 
 
 
 
 
 
 
 
5bf3195
75a550e
5bf3195
b505cc3
77b927d
 
 
 
 
eab6925
d631df4
09df805
eab6925
5bf3195
09df805
 
b505cc3
5046d92
f4483df
09df805
 
 
 
 
3b9d5e1
4c4a1d7
77b927d
eab6925
5c9ea55
437c715
0c14e18
 
 
6a2ae7a
5bf3195
 
3a3acc2
 
 
 
 
 
a30e3b1
3a3acc2
 
 
 
 
 
 
a30e3b1
52b23da
a280e4d
0b0b1a7
3a3acc2
6a2ae7a
3a3acc2
6a2ae7a
 
a30e3b1
6a2ae7a
52b23da
437c715
 
 
 
 
 
 
 
 
 
3a3acc2
52b23da
 
6a2ae7a
 
 
 
52b23da
 
 
 
 
adb5688
6a2ae7a
 
5bf3195
b505cc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77b927d
eab6925
 
 
 
 
 
 
0c14e18
b505cc3
 
 
 
 
 
 
eab6925
b505cc3
 
77b927d
b505cc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77b927d
b505cc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77b927d
b505cc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a3acc2
b505cc3
 
 
 
 
56ce28d
b505cc3
 
3a3acc2
b505cc3
56ce28d
b505cc3
56ce28d
b505cc3
 
 
 
437c715
b505cc3
9f89884
b505cc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba1f3e2
3a3acc2
5bf3195
09df805
 
 
0d3609a
 
 
 
 
 
 
6a2ae7a
 
 
 
aac3522
 
6a2ae7a
 
 
aac3522
6a2ae7a
 
 
 
 
 
 
 
 
aac3522
6a2ae7a
cd6bb4b
9f89884
 
 
ba1f3e2
 
 
 
 
6a2ae7a
 
 
 
aac3522
7c5594b
aac3522
09df805
aac3522
437c715
19e6802
aac3522
cd6bb4b
 
 
 
19e6802
 
52b23da
aac3522
0a198c1
 
19e6802
 
 
3a3acc2
 
19e6802
46a784b
adb5688
 
 
 
 
 
 
 
 
46a784b
 
 
 
 
 
 
 
6a2ae7a
09df805
 
3a3acc2
46a784b
 
3a3acc2
aac3522
437c715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f89884
 
5498ada
6a2ae7a
3a3acc2
cd6bb4b
09df805
 
b505cc3
 
09df805
 
 
77b927d
09df805
 
 
 
ba1f3e2
09df805
 
 
 
 
ba1f3e2
09df805
b505cc3
528a449
 
09df805
77b927d
5c7c2df
 
 
 
 
 
 
 
437c715
d54eee9
5c9ea55
 
 
 
b505cc3
 
 
43ed833
 
b505cc3
 
 
 
 
09df805
 
5bf3195
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
import asyncio
# WindowsSelectorEventLoopPolicy only exists in asyncio on Windows, so this
# branch is skipped on other platforms.
# NOTE(review): presumably a dependency needs the selector loop rather than the
# default Windows loop — confirm which one before removing this.
if hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

import json
import logging
from sys import exc_info
# Module-level logger; its own level is set to DEBUG here.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

import streamlit as st

from googleai import send_message as google_send_message, init_googleai, DEFAULT_INSTRUCTIONS as google_default_instructions

from langchain.chains import RetrievalQA
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.schema import AIMessage, HumanMessage, SystemMessage
import pandas as pd
from PIL import Image
from streamlit.runtime.state import session_state
import openai
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer

import streamlit.components.v1 as components

# st.set_page_config(
#     layout="wide", 
#     initial_sidebar_state="collapsed",
#     page_title="RaizedAI Startup Discovery Assistant",
#     #page_icon=":robot:"
#     )


import utils
import openai_utils as oai 

from streamlit_extras.stylable_container import stylable_container


# OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]  # app.pinecone.io
#model_name = 'text-embedding-ada-002'

# embed = OpenAIEmbeddings(
#     model=model_name,
#     openai_api_key=OPENAI_API_KEY
# )

#"🤖", 

#st.image("resources/raized_logo.png")
# Logo used as the avatar for assistant/agent chat messages.
assistant_avatar = Image.open('resources/raized_logo.png')

# Column-oriented accumulator filled by card(): every key maps to its own list,
# with one entry appended per rendered company.
carddict = {
    column: []
    for column in (
        "name",
        "company_id",
        "description",
        "country",
        "customer_problem",
        "target_customer",
        "business_model",
    )
}
    
@st.cache_resource
def init_models():
    """Load the sentence-embedding retriever and the matching tokenizer.

    Decorated with ``st.cache_resource``.

    Returns:
        A ``(retriever, tokenizer)`` tuple.
    """
    logger.debug("init_models")
    tokenizer_checkpoint = "sentence-transformers/msmarco-distilbert-base-v4"
    encoder = SentenceTransformer("msmarco-distilbert-base-v4")
    hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
    return encoder, hf_tokenizer

@st.cache_resource
def _load_openai_resources():
    """Create the OpenAI client and fetch the assistant list (cached)."""
    client = oai.get_client()
    assistants = client.beta.assistants.list(
        order="desc",
        limit=20,
    )
    return client, assistants


def init_openai():
    """Make the OpenAI client available to this session and list assistants.

    The expensive work (client creation and the assistants listing) is cached
    by ``_load_openai_resources``. The session-state write is done here, outside
    the cached function, because a cached function body does not run again on a
    cache hit, so a write inside it would only ever reach the first session.

    Returns:
        The assistants listing returned by the OpenAI client.
    """
    logger.debug("init_openai")
    client, assistants = _load_openai_resources()
    st.session_state.openai_client = client
    return assistants

# Assistant discovery is switched off here (init_openai() is not called), so
# the assistant picker is built from an empty list.
assistants = list()

# Load the retriever/tokenizer pair and expose the retriever via session state.
retriever, tokenizer = init_models()
st.session_state["retriever"] = retriever

# AVATAR_PATHS = {"assistant": st.image("resources/raized_logo.png"),
#                 "user": "👩‍⚖️"}

#st.session_state.messages = [{"role":"system", "content":"You are an assistant who helps users find startups to invest in."}]


def card(company_id, name, description, score, data_type, region, country, metadata, is_debug):
    """Build the HTML row for one company and record it in ``carddict``.

    Args:
        company_id: Company identifier; it is used as the host part of the
            company's ``https://`` link.
        name: Company name.
        description: Fallback description, used when ``metadata`` has no
            ``'Summary'`` entry.
        score: Relevance score (shown only in debug mode).
        data_type: Source/data type label (shown only in debug mode).
        region: Unused; kept so existing callers keep working.
        country: Country label.
        metadata: Mapping that may contain ``'Summary'``, ``'Customer problem'``,
            ``'Target customer'`` and ``'Business model'``.
        is_debug: When true, append the like/dislike buttons and score block.

    Returns:
        The HTML markup for the row.

    Side effects:
        Appends one value to every list in the module-level ``carddict``.
    """
    description = metadata.get('Summary', description)
    customer_problem = metadata.get('Customer problem', "")
    target_customer = metadata.get('Target customer', "")

    # The business model may be a single string or a collection of labels.
    # Joining a plain string would split it into characters ("B, 2, B"), so
    # only collections are joined.
    business_model = metadata.get('Business model', "")
    if business_model is None:
        business_model_str = ""
    elif isinstance(business_model, str):
        business_model_str = business_model
    else:
        business_model_str = ", ".join(str(item) for item in business_model)

    company_id_url = "https://" + company_id

    markdown = f"""
        <div class="row align-items-start" style="padding-bottom:10px;">
             <div  class="col-md-8 col-sm-8">
                 <b>{name} (<a href='{company_id_url}'>website</a>).</b>
                 <p style="">{description}</p>
             </div>
             <div  class="col-md-1 col-sm-1"><span>{country}</span></div>
             <div  class="col-md-1 col-sm-1"><span>{customer_problem}</span></div>
             <div  class="col-md-1 col-sm-1"><span>{target_customer}</span></div>
             <div  class="col-md-1 col-sm-1"><span>{business_model_str}</span></div>
    """

    carddict["name"].append(name)
    carddict["company_id"].append(company_id_url)
    carddict["description"].append(description)
    carddict["country"].append(country)
    carddict["customer_problem"].append(customer_problem)
    carddict["target_customer"].append(target_customer)
    carddict["business_model"].append(business_model_str)

    if is_debug:
        # The id is passed to the JS handlers as a quoted string literal; an
        # unquoted id would be parsed as a JavaScript expression.
        markdown = markdown + f"""
        <div  class="col-md-1 col-sm-1" style="display:none;">
            <button type='button' onclick="like_company('{company_id}');">Like</button>
            <button type='button' onclick="dislike_company('{company_id}');">DisLike</button>
        </div>
        <div  class="col-md-1 col-sm-1">
            <span>{data_type}</span>
            <span>[Score: {score}]</span>
        </div>
        """
    # Close the outer "row" div opened above.
    markdown = markdown + "</div>"
    return markdown

def run_googleai(query, prompt):
    """Send ``query`` to the Google AI backend and render the exchange.

    Args:
        query: The user's message.
        prompt: Instructions passed to ``google_send_message`` with the query.

    Side effects:
        Renders the user and agent messages, appends both to
        ``st.session_state.messages``, re-renders the history panel, and always
        stores ``query`` in ``st.session_state.last_user_query``. On failure
        the error is logged and an error box is shown instead of failing
        silently.
    """
    try:
        logger.debug(f"User: {query}")
        response = google_send_message(query, prompt)['output']
        logger.debug(f"Agent: {response}")
        with st.container():
            with st.chat_message(name='User'):
                st.write(query)
            with st.chat_message(name='Agent', avatar=assistant_avatar):
                st.write(response)
        st.session_state.messages.append({"role": "user", "content": query})
        st.session_state.messages.append({"role": "system", "content": response})
        render_history()
    except Exception:
        # logger.exception already attaches the active traceback.
        logger.exception("Error processing user message")
        st.error("Sorry, something went wrong while processing your message. Please try again.")
    st.session_state.last_user_query = query

def run_query(query, report_type, top_k, regions, countries, is_debug, index_namespace, openai_model, default_prompt):
    """Answer ``query`` in the selected report mode and list matching companies.

    Args:
        query: Free-text user query.
        report_type: ``"guided"``, ``"company_list"``, ``"assistant"`` or
            ``"clustered"``; any other value produces the standard report.
        top_k: Number of results requested from ``utils.search_index``.
        regions: Region filter passed to ``utils.search_index``.
        countries: Country filter passed to ``utils.search_index``.
        is_debug: Forwarded to ``card``.
        index_namespace: Index namespace passed to ``utils.search_index``.
        openai_model: Model name passed to the ``oai`` helpers.
        default_prompt: Prompt for the standard report; ``utils.default_prompt``
            is used when this is empty.

    Side effects:
        Writes the model reply and a results table to the page, updates
        ``st.session_state`` (``messages``, ``new_conversation``,
        ``last_user_query``) and fills ``carddict`` through ``card``.
    """
    # Suffixes appended to the base prompts. They are spelled with explicit
    # escapes so the exact whitespace sent to the model is visible.
    query_suffix = "        \n        User query: {query} \n        "
    described_suffix = (
        "        \n        User query: {query} \n"
        "        Company descriptions: {descriptions}\n        "
    )

    content_container = st.container()

    if report_type == "guided":
        # Step 1: turn the user query into search keywords.
        keyword_template = PromptTemplate(
            template=utils.query_finetune_prompt + query_suffix,
            input_variables=["query"],
        )
        prompt = keyword_template.format(query=query)
        m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0, max_tokens=20, log_message=False)
        print(f"Keywords: {m_text}")

        # Step 2: search with those keywords and summarise the hits.
        results = utils.search_index(m_text, top_k, regions, countries, retriever, index_namespace)
        descriptions = _describe_results(results)
        ntokens = len(descriptions.split(" "))
        print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")

        summary_template = PromptTemplate(
            template=utils.summarization_prompt + described_suffix,
            input_variables=["descriptions", "query"],
        )
        prompt = summary_template.format(descriptions=descriptions, query=query)
        print(f"==============================\nPrompt:\n{prompt}\n==============================\n")

        m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
        # Explicit write instead of a bare ``m_text`` expression, which is only
        # displayed when Streamlit's "magic" feature is enabled.
        st.write(m_text)
    elif report_type == "company_list":
        results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
    elif report_type == "assistant":
        messages = oai.call_assistant(query, engine=openai_model)
        st.session_state.messages = messages
        # The key may be absent or empty when no search has been stored yet.
        results = st.session_state.get("db_search_results") or []
        if messages is not None:
            with content_container:
                # Show the messages in the reverse of the returned order.
                for message in reversed(list(messages)):
                    if not hasattr(message, 'role'):
                        continue
                    avatar = assistant_avatar if message.role == "assistant" else None
                    with st.chat_message(name=message.role, avatar=avatar):
                        st.write(message.content[0].text.value)
    else:
        st.session_state.new_conversation = False

        results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
        descriptions = _describe_results(results)
        ntokens = len(descriptions.split(" "))
        print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")

        if report_type == "clustered":
            base_prompt = utils.clustering_prompt
        else:
            # Honour the caller-supplied prompt so edits made in the settings
            # panel take effect.
            base_prompt = default_prompt or utils.default_prompt
        report_template = PromptTemplate(
            template=base_prompt + described_suffix,
            input_variables=["descriptions", "query"],
        )
        prompt = report_template.format(descriptions=descriptions, query=query)
        print(f"==============================\nPrompt:\n{prompt[:1000]}\n==============================\n")

        m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
        st.write(m_text)
        st.session_state.messages.append({"role": "user", "content": query})
        # Only the part before the "-----" separator goes into the history.
        # NOTE(review): when the separator is missing an empty string is
        # stored; confirm whether the whole reply should be kept instead.
        cut = m_text.find("-----")
        cut = 0 if cut < 0 else cut
        st.session_state.messages.append({"role": "system", "content": m_text[:cut]})

    sorted_results = sorted(results, key=lambda x: x['score'], reverse=True)

    seen_names = set()
    for r in sorted_results:
        company_name = r["name"]
        if company_name in seen_names:
            continue
        seen_names.add(company_name)

        description = r["description"]
        if description is None or len(description.strip()) < 10:
            continue

        score = round(r["score"], 4)
        data_type = r["metadata"]["type"] if "type" in r["metadata"] else ""
        region = r["metadata"]["region"]
        country = r["metadata"]["country"]
        company_id = r["metadata"]["company_id"]

        # card() is called for its side effect of filling ``carddict``; the
        # HTML it returns is not rendered by this function.
        card(company_id, company_name, description, score, data_type, region, country, r['data'], is_debug)

    df = pd.DataFrame.from_dict(carddict, orient="columns")

    if not df.empty:
        df.index += 1  # show 1-based row numbers
        with content_container:
            st.dataframe(
                df,
                hide_index=False,
                column_config={
                    "name": st.column_config.TextColumn("Name"),
                    "company_id": st.column_config.LinkColumn("Link"),
                    "description": st.column_config.TextColumn("Description"),
                    "country": st.column_config.TextColumn("Country", width="small"),
                    "customer_problem": st.column_config.TextColumn("Customer problem"),
                    "target_customer": st.column_config.TextColumn(label="Target customer", width="small"),
                    "business_model": st.column_config.TextColumn(label="Business model"),
                },
                use_container_width=True,
            )
    st.session_state.last_user_query = query


def _describe_results(results, limit=20):
    """Join the summaries of the first ``limit`` results into one text block."""
    return "\n".join(
        f"Description of company \"{res['name']}\":  {res['data']['Summary']}.\n"
        for res in results[:limit]
        if 'Summary' in res['data']
    )


def query_sent():
    """Reset the ``user_query`` session entry to an empty string."""
    st.session_state["user_query"] = ""

def find_default_assistant_idx(assistants, default_assistant_id='asst_8aSvGL075pmE1r8GAjjymu85'):
    """Return the position of the default assistant in ``assistants``.

    Args:
        assistants: Iterable of objects exposing an ``id`` attribute.
        default_assistant_id: Id to look for. The default is the
            "startup discovery 3 steps" assistant.

    Returns:
        The index of the first assistant whose ``id`` matches, or ``0`` when
        none matches (including when ``assistants`` is empty).
    """
    return next(
        (idx for idx, assistant in enumerate(assistants) if assistant.id == default_assistant_id),
        0,
    )

def render_history():
    """Render ``st.session_state.messages`` as a scrollable HTML panel.

    Each message is expected to be a mapping with ``'role'`` and ``'content'``
    entries. The panel is written into ``st.session_state.history_container``
    and scrolled to the bottom by an inline script.
    """
    # Local import keeps this function self-contained.
    from html import escape

    # Message text comes from the user and from the model, so it is escaped
    # before being placed into the markup.
    rows = "".join(
        f"<div class='chat_message'><b>{escape(str(m['role']))}</b>: {escape(str(m['content']))}</div>"
        for m in st.session_state.messages
    )

    header = """
            <div style='overflow: hidden; padding:10px 0px;'>
                <div id="chat_history" style='overflow-y: scroll;height: 200px;'>
        """
    footer = """</div>
            </div>
            <script>
                var el = document.getElementById("chat_history");
                el.scrollTop = el.scrollHeight;
            </script>
        """

    with st.session_state.history_container:
        components.html(header + rows + footer, height=220)

def _ensure_openai_client():
    """Return this session's OpenAI client, creating it on first use."""
    if "openai_client" not in st.session_state:
        st.session_state.openai_client = oai.get_client()
    return st.session_state.openai_client


# Seed the per-session defaults that the rest of the script reads.
if 'submitted_query' not in st.session_state:
    st.session_state.submitted_query = ''
if 'messages' not in st.session_state:
    st.session_state.messages = []
if 'last_user_query' not in st.session_state:
    st.session_state.last_user_query = ''

if utils.check_password():

    st.markdown("<script language='javascript'>console.log('scrolling');</script>", unsafe_allow_html=True)

    if st.sidebar.button("New Conversation") or "messages" not in st.session_state:
        # Start a fresh assistant thread and clear the visible history. The
        # client is created on demand because nothing above guarantees that
        # ``openai_client`` has been stored in the session.
        st.session_state.assistant_thread = _ensure_openai_client().beta.threads.create()
        st.session_state.new_conversation = True
        st.session_state.messages = []

    st.markdown("<h1 style='text-align: center; color: red; position: relative; top: -3rem;'>Raized.AI – Startups discovery demo</h1>", unsafe_allow_html=True)

    st.markdown("""
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
    """, unsafe_allow_html=True)

    # Explicit encoding so the country list loads the same on every platform.
    with open("data/countries.json", "r", encoding="utf-8") as f:
        countries = json.load(f)['countries']

    # --- Sidebar filters -------------------------------------------------
    header = st.sidebar.markdown("Filters")
    report_type = st.sidebar.selectbox(label="Response Type", options=["gemini", "assistant", "standard", "guided", "company_list", "clustered"], index=0)

    countries_selectbox = st.sidebar.multiselect("Country", countries, default=[])
    all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America')
    region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions)
    all_bizmodels = ('B2B', 'B2C', 'eCommerce & Marketplace', 'Manufacturing', 'SaaS', 'Advertising', 'Commission', 'Subscription')
    bizmodel_selectbox = st.sidebar.multiselect("Business Model", all_bizmodels, default=all_bizmodels)

    # This is a plain (non-f) string, so CSS braces must be single; doubled
    # braces would be emitted literally and make the rule invalid.
    st.markdown(
        '''
            <script>
                function like_company(company_id) {
                    console.log("Company " + company_id + " Liked!");
                    }
                function dislike_company(company_id) {
                    console.log("Company " + company_id + " Disliked!");
                    }
            </script>
            <style>
                .sidebar .sidebar-content {
                    width: 375px;
                }
            </style>
        ''',
        unsafe_allow_html=True
    )

    # --- Main panel: history above, query box pinned to the bottom --------
    tab_search = st.container()

    with tab_search:
        st.session_state.history_container = st.container()

        with stylable_container(
            key="query_panel",
            css_styles="""
               .stTextInput {
                    position: fixed;
                    bottom: 0px;
                    background: white;
                    z-index: 1000;
                    padding-bottom: 2rem;
                    padding-left: 1rem;
                    padding-right: 1rem;
                    padding-top: 1rem;
                    border-top: 1px solid whitesmoke;
                    height: 8rem;
                    border-radius: 8px 8px 8px 8px;
                    box-shadow: 0 -3px 3px whitesmoke;
                }
            """,
        ):
            query = st.text_input(key="user_query",
                                  label="Enter your query",
                                  placeholder="Tell me what startups you are looking for", label_visibility="collapsed")

    # --- Sidebar settings --------------------------------------------------
    tab_advanced = st.sidebar.expander("Settings")
    with tab_advanced:
        gemini_prompt = st.text_area("Gemini Prompt", value=google_default_instructions, height=400, key="advanced_gemini_prompt_content")
        default_prompt = st.text_area("Default Prompt", value=utils.default_prompt, height=400, key="advanced_default_prompt_content")
        # Options are encoded as "<id>|||<name>"; the id is split off below.
        assistant_id = st.selectbox(label="OpenAI Assistant", options=[f"{a.id}|||{a.name}" for a in assistants], index=find_default_assistant_idx(assistants))
        top_k = st.number_input('# Top Results', value=30)
        is_debug = st.checkbox("Debug output", value=False, key="debug")
        openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
        index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
        liked_companies = st.text_input(label="liked companies", key='liked_companies')
        disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
        clustering_prompt = st.text_area("Clustering Prompt", value=utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")

    if report_type == "assistant" and "assistant_thread" not in st.session_state:
        st.session_state.assistant_thread = _ensure_openai_client().beta.threads.create()

    # A query is not executed on the run in which "New Conversation" was
    # pressed; that run only clears the flag again.
    if query != "" and not st.session_state.get("new_conversation", False):
        st.session_state.report_type = report_type
        st.session_state.top_k = top_k
        st.session_state.index_namespace = index_namespace
        st.session_state.region = region_selectbox
        st.session_state.country = countries_selectbox
        if report_type == "gemini":
            run_googleai(query, gemini_prompt)
        else:
            # The select box returns None when no assistants are listed.
            if assistant_id:
                st.session_state.assistant_id = assistant_id.partition("|||")[0]
            # run_query takes nine positional arguments; the Gemini prompt is
            # not one of them.
            run_query(query, report_type, top_k,
                      region_selectbox, countries_selectbox, is_debug,
                      index_namespace, openai_model,
                      default_prompt)
    else:
        st.session_state.new_conversation = False