Spaces:
Sleeping
Sleeping
File size: 22,961 Bytes
77b927d b505cc3 eab6925 77b927d eab6925 77b927d 56ce28d 09df805 e54b3e0 eab6925 5bf3195 6a2ae7a 77b927d 45a7d81 aac3522 eab6925 437c715 0c14e18 eab6925 0c14e18 eab6925 09df805 437c715 5bf3195 75a550e 5bf3195 b505cc3 77b927d eab6925 d631df4 09df805 eab6925 5bf3195 09df805 b505cc3 5046d92 f4483df 09df805 3b9d5e1 4c4a1d7 77b927d eab6925 5c9ea55 437c715 0c14e18 6a2ae7a 5bf3195 3a3acc2 a30e3b1 3a3acc2 a30e3b1 52b23da a280e4d 0b0b1a7 3a3acc2 6a2ae7a 3a3acc2 6a2ae7a a30e3b1 6a2ae7a 52b23da 437c715 3a3acc2 52b23da 6a2ae7a 52b23da adb5688 6a2ae7a 5bf3195 b505cc3 77b927d eab6925 0c14e18 b505cc3 eab6925 b505cc3 77b927d b505cc3 77b927d b505cc3 77b927d b505cc3 3a3acc2 b505cc3 56ce28d b505cc3 3a3acc2 b505cc3 56ce28d b505cc3 56ce28d b505cc3 437c715 b505cc3 9f89884 b505cc3 ba1f3e2 3a3acc2 5bf3195 09df805 0d3609a 6a2ae7a aac3522 6a2ae7a aac3522 6a2ae7a aac3522 6a2ae7a cd6bb4b 9f89884 ba1f3e2 6a2ae7a aac3522 7c5594b aac3522 09df805 aac3522 437c715 19e6802 aac3522 cd6bb4b 19e6802 52b23da aac3522 0a198c1 19e6802 3a3acc2 19e6802 46a784b adb5688 46a784b 6a2ae7a 09df805 3a3acc2 46a784b 3a3acc2 aac3522 437c715 9f89884 5498ada 6a2ae7a 3a3acc2 cd6bb4b 09df805 b505cc3 09df805 77b927d 09df805 ba1f3e2 09df805 ba1f3e2 09df805 b505cc3 528a449 09df805 77b927d 5c7c2df 437c715 d54eee9 5c9ea55 b505cc3 43ed833 b505cc3 09df805 5bf3195 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 |
import asyncio

# Force the selector event loop policy where it exists (i.e. on Windows);
# presumably the default Proactor loop breaks an async dependency — TODO confirm.
if hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import json
import logging
from sys import exc_info  # NOTE(review): appears unused in this module

# Module-level logger at DEBUG so the agent request/response traces below show up.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
import streamlit as st
from googleai import send_message as google_send_message, init_googleai, DEFAULT_INSTRUCTIONS as google_default_instructions
from langchain.chains import RetrievalQA
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.schema import AIMessage, HumanMessage, SystemMessage
import pandas as pd
from PIL import Image
from streamlit.runtime.state import session_state
import openai
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import streamlit.components.v1 as components
# st.set_page_config(
#     layout="wide",
#     initial_sidebar_state="collapsed",
#     page_title="RaizedAI Startup Discovery Assistant",
#     #page_icon=":robot:"
# )
import utils
import openai_utils as oai
from streamlit_extras.stylable_container import stylable_container
# OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"] # app.pinecone.io
#model_name = 'text-embedding-ada-002'
# embed = OpenAIEmbeddings(
#     model=model_name,
#     openai_api_key=OPENAI_API_KEY
# )
#"🤖",
#st.image("resources/raized_logo.png")
# Logo used as the avatar for assistant-authored chat messages;
# raises FileNotFoundError if the resource is missing.
assistant_avatar = Image.open('resources/raized_logo.png')
# Column accumulator for the results table: card() appends one value per
# column for every company it renders; run_query() turns it into a DataFrame.
carddict = {
    column: []
    for column in (
        "name",
        "company_id",
        "description",
        "country",
        "customer_problem",
        "target_customer",
        "business_model",
    )
}
@st.cache_resource
def init_models():
    """Load and cache (via st.cache_resource) the embedding retriever and its tokenizer.

    Returns:
        (retriever, tokenizer): a SentenceTransformer encoder and the matching
        HuggingFace tokenizer, both for msmarco-distilbert-base-v4.
    """
    logger.debug("init_models")
    # Single source of truth for the model id: the original loaded the
    # SentenceTransformer from the bare name and the tokenizer from the
    # namespaced one — both resolve to the same hub model.
    model_name = "sentence-transformers/msmarco-distilbert-base-v4"
    retriever = SentenceTransformer(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return retriever, tokenizer
@st.cache_resource
def init_openai():
    """Create the OpenAI client, store it in session state, and return the
    20 most recently created assistants."""
    logger.debug("init_openai")
    client = oai.get_client()
    st.session_state.openai_client = client
    return client.beta.assistants.list(
        order="desc",
        limit="20",
    )
assistants = []#init_openai()
# NOTE(review): init_openai() is disabled above, so st.session_state.openai_client
# is never set; the "assistant" flow and the New Conversation button read it and
# would fail — confirm before enabling those paths.
retriever, tokenizer = init_models()
st.session_state.retriever = retriever
# AVATAR_PATHS = {"assistant": st.image("resources/raized_logo.png"),
#                 "user": "👩⚖️"}
#st.session_state.messages = [{"role":"system", "content":"You are an assistant who helps users find startups to invest in."}]
def card(company_id, name, description, score, data_type, region, country, metadata, is_debug):
    """Build the Bootstrap-row HTML for one company and record its fields in carddict.

    Args:
        company_id: company domain, used to build the https:// website link.
        name: company display name.
        description: fallback description; overridden by metadata['Summary'] when present.
        score: retrieval relevance score (shown only in debug mode).
        data_type: source/index type label (shown only in debug mode).
        region: company region — currently unused in the markup, kept for interface compatibility.
        country: company country.
        metadata: enrichment dict ('Summary', 'Customer problem', 'Target customer', 'Business model').
        is_debug: when True, append hidden like/dislike buttons and score details.

    Returns:
        An HTML fragment (one Bootstrap row) for the company.
    """
    if 'Summary' in metadata:
        description = metadata['Summary']
    customer_problem = metadata.get('Customer problem', "")
    target_customer = metadata.get('Target customer', "")
    business_model = metadata.get('Business model', "")
    markdown = f"""
    <div class="row align-items-start" style="padding-bottom:10px;">
        <div class="col-md-8 col-sm-8">
            <b>{name} (<a href='https://{company_id}'>website</a>).</b>
            <p style="">{description}</p>
        </div>
        <div class="col-md-1 col-sm-1"><span>{country}</span></div>
        <div class="col-md-1 col-sm-1"><span>{customer_problem}</span></div>
        <div class="col-md-1 col-sm-1"><span>{target_customer}</span></div>
        <div class="col-md-1 col-sm-1"><span>{business_model}</span></div>
    """
    # BUGFIX: ", ".join(s) on a plain string splices commas between every
    # character; only join when the field is an actual list of models.
    business_model_str = business_model if isinstance(business_model, str) else ", ".join(business_model)
    company_id_url = "https://" + company_id
    carddict["name"].append(name)
    carddict["company_id"].append(company_id_url)
    carddict["description"].append(description)
    carddict["country"].append(country)
    carddict["customer_problem"].append(customer_problem)
    carddict["target_customer"].append(target_customer)
    carddict["business_model"].append(business_model_str)
    if is_debug:
        # Like/dislike hooks call the JS functions injected by the main page.
        markdown = markdown + f"""
        <div class="col-md-1 col-sm-1" style="display:none;">
            <button type='button' onclick="like_company({company_id});">Like</button>
            <button type='button' onclick="dislike_company({company_id});">DisLike</button>
        </div>
        <div class="col-md-1 col-sm-1">
            <span>{data_type}</span>
            <span>[Score: {score}</span>
        </div>
        """
    markdown = markdown + "</div>"
    return markdown
def run_googleai(query, prompt):
    """Send `query` to the Google GenAI agent using `prompt` as instructions,
    render the exchange, and append it to the session chat history.

    Errors are logged and swallowed; the last query is always recorded.
    """
    try:
        logger.debug(f"User: {query}")
        reply = google_send_message(query, prompt)['output']
        logger.debug(f"Agent: {reply}")
        panel = st.container()
        with panel:
            with st.chat_message(name='User'):
                st.write(query)
            with st.chat_message(name='Agent', avatar=assistant_avatar):
                st.write(reply)
        st.session_state.messages.extend([
            {"role": "user", "content": query},
            {"role": "system", "content": reply},
        ])
        render_history()
    except Exception as err:
        logger.exception(f"Error processing user message", exc_info=err)
    st.session_state.last_user_query = query
def run_query(query, report_type, top_k, regions, countries, is_debug,
              index_namespace, openai_model, default_prompt, gemini_prompt=""):
    """Run one discovery search for `query` and render the results.

    Args:
        query: free-text user query.
        report_type: "guided" | "company_list" | "assistant" | anything else
            ("standard"/"clustered") selects the default summarize-and-list flow.
        top_k: number of index hits to retrieve.
        regions, countries: filter lists from the sidebar.
        is_debug: include score/type debug columns in the card HTML.
        index_namespace: vector index namespace to search.
        openai_model: engine name passed to oai.call_openai / call_assistant.
        default_prompt: editable "standard" prompt from the Settings panel.
        gemini_prompt: accepted for call-site compatibility; unused here
            (gemini queries are routed to run_googleai instead).
    """
    content_container = st.container()
    if report_type == "guided":
        # Two-step flow: distill the query into keywords, search with them,
        # then summarize the retrieved descriptions.
        prompt_txt = utils.query_finetune_prompt + """
User query: {query}
"""
        prompt_template = PromptTemplate(template=prompt_txt, input_variables=["query"])
        prompt = prompt_template.format(query=query)
        m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0, max_tokens=20, log_message=False)
        print(f"Keywords: {m_text}")
        results = utils.search_index(m_text, top_k, regions, countries, retriever, index_namespace)
        descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
        ntokens = len(descriptions.split(" "))  # rough whitespace token count
        print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
        prompt_txt = utils.summarization_prompt + """
User query: {query}
Company descriptions: {descriptions}
"""
        prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
        prompt = prompt_template.format(descriptions=descriptions, query=query)
        print(f"==============================\nPrompt:\n{prompt}\n==============================\n")
        m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
        m_text  # Streamlit "magic": a bare expression renders its value
    elif report_type == "company_list":
        # Raw hit list only; the table at the bottom does the rendering.
        results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
        descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
    elif report_type == "assistant":
        messages = oai.call_assistant(query, engine=openai_model)
        st.session_state.messages = messages
        # The assistant's tool call stores its DB hits in session state.
        results = st.session_state.db_search_results
        if messages is not None:
            with content_container:
                for message in list(messages)[::-1]:  # API returns newest first; show oldest first
                    if hasattr(message, 'role'):
                        if message.role == "assistant":
                            with st.chat_message(name=message.role, avatar=assistant_avatar):
                                st.write(message.content[0].text.value)
                        else:
                            with st.chat_message(name=message.role):
                                st.write(message.content[0].text.value)
    else:
        # "standard" / "clustered": search, then summarize with the chosen prompt.
        st.session_state.new_conversation = False
        results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
        descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
        ntokens = len(descriptions.split(" "))
        print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
        # BUGFIX: previously read utils.default_prompt, silently ignoring the
        # editable default_prompt the Settings panel passes in.
        prompt = utils.clustering_prompt if report_type == "clustered" else default_prompt
        prompt_txt = prompt + """
User query: {query}
Company descriptions: {descriptions}
"""
        prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
        prompt = prompt_template.format(descriptions=descriptions, query=query)
        print(f"==============================\nPrompt:\n{prompt[:1000]}\n==============================\n")
        m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
        m_text  # Streamlit "magic": render the report
        st.session_state.messages.append({"role": "user", "content": query})
        # Keep only the narrative part, dropping anything after the "-----" rule.
        # BUGFIX: when no delimiter is present, keep the whole reply
        # (previously i was forced to 0, storing an empty message).
        i = m_text.find("-----")
        i = len(m_text) if i < 0 else i
        st.session_state.messages.append({"role": "system", "content": m_text[:i]})
    # Build the de-duplicated, score-sorted company table shared by all flows.
    sorted_results = sorted(results, key=lambda x: x['score'], reverse=True)
    names = []
    list_html = "<div class='container-fluid'>"
    locations = set()
    for r in sorted_results:
        company_name = r["name"]
        if company_name in names:
            continue  # de-duplicate by company name
        names.append(company_name)
        description = r["description"]
        if description is None or len(description.strip()) < 10:
            continue  # skip hits without a usable description
        score = round(r["score"], 4)
        data_type = r["metadata"].get("type", "")
        region = r["metadata"]["region"]
        country = r["metadata"]["country"]
        company_id = r["metadata"]["company_id"]
        locations.add(country)
        list_html = list_html + card(company_id, company_name, description, score, data_type, region, country, r['data'], is_debug)
    list_html = list_html + '</div>'  # NOTE(review): HTML is built but no longer rendered — confirm intent
    df = pd.DataFrame.from_dict(carddict, orient="columns")
    if len(df) > 0:
        df.index += 1  # 1-based row numbers for display
        with content_container:
            st.dataframe(df,
                         hide_index=False,
                         column_config={
                             "name": st.column_config.TextColumn("Name"),
                             "company_id": st.column_config.LinkColumn("Link"),
                             "description": st.column_config.TextColumn("Description"),
                             "country": st.column_config.TextColumn("Country", width="small"),
                             "customer_problem": st.column_config.TextColumn("Customer problem"),
                             "target_customer": st.column_config.TextColumn(label="Target customer", width="small"),
                             "business_model": st.column_config.TextColumn(label="Business model")
                         },
                         use_container_width=True)
    st.session_state.last_user_query = query
def query_sent():
    # Widget callback: clear the query input box once the query has been consumed.
    st.session_state.user_query = ""
def find_default_assistant_idx(assistants, default_assistant_id='asst_8aSvGL075pmE1r8GAjjymu85'):
    """Return the position of the default assistant within `assistants`.

    Args:
        assistants: iterable of objects exposing an `.id` attribute.
        default_assistant_id: assistant id to look for; defaults to the
            "startup discovery 3 steps" assistant (generalized from the
            previously hard-coded constant).

    Returns:
        Index of the first matching assistant, or 0 when the list is empty
        or no match is found (so the selectbox falls back to its first entry).
    """
    for idx, assistant in enumerate(assistants):
        if assistant.id == default_assistant_id:
            return idx
    return 0
def render_history():
    """Render the accumulated chat transcript into the history container as a
    fixed-height HTML panel that auto-scrolls to the newest message."""
    with st.session_state.history_container:
        # Open the outer wrapper and the scrollable inner div.
        s = f"""
<div style='overflow: hidden; padding:10px 0px;'>
<div id="chat_history" style='overflow-y: scroll;height: 200px;'>
"""
        for m in st.session_state.messages:
            #print(f"Printing message\t {m['role']}: {m['content']}")
            s = s + f"<div class='chat_message'><b>{m['role']}</b>: {m['content']}</div>"
        # Close the divs, then scroll the panel to the bottom via inline JS.
        s = s + f"""</div>
</div>
<script>
var el = document.getElementById("chat_history");
el.scrollTop = el.scrollHeight;
</script>
"""
        components.html(s, height=220)
        #st.markdown(s, unsafe_allow_html=True)
# Seed session-state defaults on first load so later code can read them safely.
for _key, _value in (('submitted_query', ''), ('messages', []), ('last_user_query', '')):
    if _key not in st.session_state:
        st.session_state[_key] = _value
# Main page: everything below is gated behind the shared-password check.
if utils.check_password():
    st.markdown("<script language='javascript'>console.log('scrolling');</script>", unsafe_allow_html=True)
    # NOTE(review): st.session_state.openai_client is only set inside init_openai(),
    # which is currently disabled (assistants = []) — this button would raise; confirm.
    # Also, the "messages" fallback is dead: the defaults block above always seeds it.
    if st.sidebar.button("New Conversation") or "messages" not in st.session_state:
        st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
        st.session_state.new_conversation = True
        st.session_state.messages = []
    st.markdown("<h1 style='text-align: center; color: red; position: relative; top: -3rem;'>Raized.AI – Startups discovery demo</h1>", unsafe_allow_html=True)
    # Bootstrap grid CSS used by the card() markup.
    st.markdown("""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
""", unsafe_allow_html=True)
    with open("data/countries.json", "r") as f:
        countries = json.load(f)['countries']
    # Sidebar filters.
    header = st.sidebar.markdown("Filters")
    report_type = st.sidebar.selectbox(label="Response Type", options=["gemini", "assistant", "standard", "guided", "company_list", "clustered"], index=0)
    countries_selectbox = st.sidebar.multiselect("Country", countries, default=[])
    all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America')
    region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions)
    all_bizmodels = ('B2B', 'B2C', 'eCommerce & Marketplace', 'Manufacturing', 'SaaS', 'Advertising', 'Commission', 'Subscription')
    bizmodel_selectbox = st.sidebar.multiselect("Business Model", all_bizmodels, default=all_bizmodels)
    # JS hooks for the (hidden) like/dislike buttons plus a sidebar width tweak.
    st.markdown(
        '''
<script>
function like_company(company_id) {
    console.log("Company " + company_id + " Liked!");
}
function dislike_company(company_id) {
    console.log("Company " + company_id + " Disliked!");
}
</script>
<style>
.sidebar .sidebar-content {{
    width: 375px;
}}
</style>
''',
        unsafe_allow_html=True
    )
    tab_search = st.container()
    with tab_search:
        st.session_state.history_container = st.container()
        # Pin the query box to the bottom of the viewport, chat-style.
        with stylable_container(
            key="query_panel",
            css_styles="""
.stTextInput {
    position: fixed;
    bottom: 0px;
    background: white;
    z-index: 1000;
    padding-bottom: 2rem;
    padding-left: 1rem;
    padding-right: 1rem;
    padding-top: 1rem;
    border-top: 1px solid whitesmoke;
    height: 8rem;
    border-radius: 8px 8px 8px 8px;
    box-shadow: 0 -3px 3px whitesmoke;
}
""",
        ):
            query = st.text_input(key="user_query",
                                  label="Enter your query",
                                  placeholder="Tell me what startups you are looking for", label_visibility="collapsed")
    # Advanced settings live in a sidebar expander.
    tab_advanced = st.sidebar.expander("Settings")
    with tab_advanced:
        gemini_prompt = st.text_area("Gemini Prompt", value=google_default_instructions, height=400, key="advanced_gemini_prompt_content")
        default_prompt = st.text_area("Default Prompt", value=utils.default_prompt, height=400, key="advanced_default_prompt_content")
        # NOTE(review): assistants is currently [] — the selectbox returns None,
        # which would break the non-gemini flow below; confirm before relying on it.
        assistant_id = st.selectbox(label="OpenAI Assistant", options=[f"{a.id}|||{a.name}" for a in assistants], index=find_default_assistant_idx(assistants))
        top_k = st.number_input('# Top Results', value=30)
        is_debug = st.checkbox("Debug output", value=False, key="debug")
        openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
        index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
        liked_companies = st.text_input(label="liked companies", key='liked_companies')
        disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
        clustering_prompt = st.text_area("Clustering Prompt", value=utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")
    if report_type == "assistant" and "assistant_thread" not in st.session_state:
        st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
    # Dispatch the query unless this rerun is the start of a fresh conversation.
    if query != "" and ('new_conversation' not in st.session_state or not st.session_state.new_conversation):
        st.session_state.report_type = report_type
        st.session_state.top_k = top_k
        st.session_state.index_namespace = index_namespace
        st.session_state.region = region_selectbox
        st.session_state.country = countries_selectbox
        if report_type == "gemini":
            run_googleai(query, gemini_prompt)
        else:
            # assistant_id is "<id>|||<name>"; keep only the id part.
            i = assistant_id.index("|||")
            st.session_state.assistant_id = assistant_id[:i]
            # BUGFIX: run_query was called with an extra trailing gemini_prompt
            # argument its signature does not accept (TypeError at runtime).
            run_query(query, report_type, top_k,
                      region_selectbox, countries_selectbox, is_debug,
                      index_namespace, openai_model,
                      default_prompt)
    else:
        st.session_state.new_conversation = False