Spaces:
Sleeping
Sleeping
import asyncio
import json
import logging
from sys import exc_info

# WindowsSelectorEventLoopPolicy only exists on Windows builds of Python; where it
# is available, make asyncio use the selector-based event loop.
_windows_selector_policy = getattr(asyncio, "WindowsSelectorEventLoopPolicy", None)
if _windows_selector_policy is not None:
    asyncio.set_event_loop_policy(_windows_selector_policy())

# Module logger, set to DEBUG so this module's debug traces are not filtered here.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
import streamlit as st | |
from googleai import send_message as google_send_message, init_googleai, DEFAULT_INSTRUCTIONS as google_default_instructions | |
from langchain.chains import RetrievalQA | |
from langchain_community.embeddings import OpenAIEmbeddings | |
from langchain.prompts import PromptTemplate | |
from langchain.schema import AIMessage, HumanMessage, SystemMessage | |
import pandas as pd | |
from PIL import Image | |
from streamlit.runtime.state import session_state | |
import openai | |
from transformers import AutoTokenizer | |
from sentence_transformers import SentenceTransformer | |
import streamlit.components.v1 as components | |
# st.set_page_config( | |
# layout="wide", | |
# initial_sidebar_state="collapsed", | |
# page_title="RaizedAI Startup Discovery Assistant", | |
# #page_icon=":robot:" | |
# ) | |
import utils | |
import openai_utils as oai | |
from streamlit_extras.stylable_container import stylable_container | |
# OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"] # app.pinecone.io | |
#model_name = 'text-embedding-ada-002' | |
# embed = OpenAIEmbeddings( | |
# model=model_name, | |
# openai_api_key=OPENAI_API_KEY | |
# ) | |
#"🤖", | |
#st.image("resources/raized_logo.png") | |
# Avatar image shown next to the assistant's chat messages.
assistant_avatar = Image.open("resources/raized_logo.png")

# Column-oriented accumulator: card() appends one entry per company to every list,
# and run_query() turns it into the results dataframe (keys become column names).
carddict = {
    column: []
    for column in (
        "name",
        "company_id",
        "description",
        "country",
        "customer_problem",
        "target_customer",
        "business_model",
    )
}
@st.cache_resource
def init_models():
    """Load the sentence-embedding retriever and its tokenizer.

    The function is called at module level, i.e. on every script rerun, so it is
    wrapped in ``st.cache_resource``: the models are loaded once and the same
    objects are returned afterwards instead of being reloaded on each interaction.

    Returns:
        tuple: ``(retriever, tokenizer)`` — a ``SentenceTransformer`` and the
        matching ``AutoTokenizer``, both for msmarco-distilbert-base-v4.
    """
    logger.debug("init_models")
    retriever = SentenceTransformer("msmarco-distilbert-base-v4")
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/msmarco-distilbert-base-v4")
    return retriever, tokenizer
def init_openai(limit=20):
    """Create the session's OpenAI client and list its assistants.

    Stores the client returned by ``oai.get_client()`` in
    ``st.session_state.openai_client``.

    Args:
        limit: Maximum number of assistants to fetch. The API expects an integer,
            so it is passed as one rather than the string "20".

    Returns:
        The assistants page returned by ``client.beta.assistants.list`` with
        ``order="desc"``.
    """
    logger.debug("init_openai")
    client = oai.get_client()
    st.session_state.openai_client = client
    return client.beta.assistants.list(order="desc", limit=int(limit))
# Listing OpenAI assistants (init_openai()) is switched off, so the
# "OpenAI Assistant" selectbox is built from an empty list.
assistants: list = []

retriever, tokenizer = init_models()
# Also expose the retriever through session state.
st.session_state["retriever"] = retriever
def card(company_id, name, description, score, data_type, region, country, metadata, is_debug, table=None):
    """Build the HTML row for one company and record the company in a results table.

    Args:
        company_id: Company identifier, used as the website host (``https://{company_id}``).
        name: Company display name.
        description: Fallback description; replaced by ``metadata['Summary']`` when present.
        score: Relevance score, shown only in debug output.
        data_type: Result type label, shown only in debug output.
        region: Unused; kept for interface compatibility.
        country: Country shown in the row and stored in the table.
        metadata: Company attributes; 'Summary', 'Customer problem', 'Target customer'
            and 'Business model' are read when present.
        is_debug: When true, add a column with the type, score and like/dislike buttons.
        table: Column-name -> list mapping the company is appended to. Defaults to
            the module-level ``carddict`` that backs the results dataframe.

    Returns:
        HTML markup for the row. Every interpolated value is HTML-escaped, since these
        are data values rather than markup.
    """
    import html

    if table is None:
        table = carddict

    if 'Summary' in metadata:
        description = metadata['Summary']
    customer_problem = metadata.get('Customer problem', "")
    target_customer = metadata.get('Target customer', "")
    # 'Business model' may be a list of tags or a single string. Joining a plain
    # string with ", " would split it into characters ("B2B" -> "B, 2, B").
    raw_business_model = metadata.get('Business model') or ""
    if isinstance(raw_business_model, (list, tuple)):
        business_model = ", ".join(str(item) for item in raw_business_model)
    else:
        business_model = str(raw_business_model)
    company_url = "https://" + company_id

    def esc(value):
        return html.escape(str(value))

    markdown = f"""
<div class="row align-items-start" style="padding-bottom:10px;">
<div class="col-md-8 col-sm-8">
<b>{esc(name)} (<a href='{esc(company_url)}'>website</a>).</b>
<p style="">{esc(description)}</p>
</div>
<div class="col-md-1 col-sm-1"><span>{esc(country)}</span></div>
<div class="col-md-1 col-sm-1"><span>{esc(customer_problem)}</span></div>
<div class="col-md-1 col-sm-1"><span>{esc(target_customer)}</span></div>
<div class="col-md-1 col-sm-1"><span>{esc(business_model)}</span></div>
"""
    if is_debug:
        # JSON-encode the id so the onclick handlers receive a JS string literal
        # (a bare domain such as acme.io is not a valid JS expression).
        js_company_id = esc(json.dumps(str(company_id)))
        markdown += f"""
<div class="col-md-1 col-sm-1" style="display:none;">
<button type='button' onclick="like_company({js_company_id});">Like</button>
<button type='button' onclick="dislike_company({js_company_id});">DisLike</button>
</div>
<div class="col-md-1 col-sm-1">
<span>{esc(data_type)}</span>
<span>[Score: {esc(score)}]</span>
</div>
"""
    markdown += "</div>"

    # The table keeps raw (unescaped) values; the dataframe renders them as text.
    row = {
        "name": name,
        "company_id": company_url,
        "description": description,
        "country": country,
        "customer_problem": customer_problem,
        "target_customer": target_customer,
        "business_model": business_model,
    }
    for column, value in row.items():
        table[column].append(value)
    return markdown
def run_googleai(query, prompt):
    """Send `query` to the Gemini agent and show the exchange in the chat UI.

    Args:
        query: The user's message.
        prompt: Instructions passed through to ``google_send_message``.

    The reply (the ``'output'`` field of the response) is shown as a pair of chat
    messages, both turns are appended to ``st.session_state.messages`` and the
    history panel is redrawn. A failure is logged with its traceback and reported
    to the user with ``st.error`` instead of silently producing no output.
    ``st.session_state.last_user_query`` is updated in either case.
    """
    try:
        logger.debug("User: %s", query)
        response = google_send_message(query, prompt)['output']
        logger.debug("Agent: %s", response)
        with st.container():
            with st.chat_message(name='User'):
                st.write(query)
            with st.chat_message(name='Agent', avatar=assistant_avatar):
                st.write(response)
        st.session_state.messages.append({"role": "user", "content": query})
        st.session_state.messages.append({"role": "system", "content": response})
        render_history()
    except Exception:
        # UI boundary: keep the app alive, but tell the user the request failed.
        logger.exception("Error processing user message")
        st.error("Sorry, the Gemini request failed. Please try again.")
    st.session_state.last_user_query = query
def _company_descriptions(results, limit=20):
    """Join the 'Summary' of the first `limit` search results into prompt text.

    Results whose ``data`` mapping has no 'Summary' key are skipped.
    """
    return "\n".join(
        f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n"
        for res in results[:limit]
        if 'Summary' in res['data']
    )


def _summarize(instructions, query, descriptions, openai_model):
    """Ask the OpenAI model to answer `query` from `descriptions`, following `instructions`.

    `instructions` becomes the head of a ``PromptTemplate`` (f-string style), so
    literal braces in it would need doubling.
    """
    # Whitespace-separated word count, used as a rough size indicator.
    ntokens = len(descriptions.split(" "))
    print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
    prompt_txt = instructions + """
User query: {query}
Company descriptions: {descriptions}
"""
    prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
    prompt = prompt_template.format(descriptions=descriptions, query=query)
    print(f"==============================\nPrompt:\n{prompt[:1000]}\n==============================\n")
    return oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)


def _render_assistant_messages(messages):
    """Draw OpenAI assistant-thread messages as chat bubbles.

    Messages are shown in the reverse of the order `messages` yields them; items
    without a ``role`` attribute are skipped. Assistant turns get the app avatar.
    """
    for message in reversed(list(messages)):
        if not hasattr(message, 'role'):
            continue
        avatar = assistant_avatar if message.role == "assistant" else None
        with st.chat_message(name=message.role, avatar=avatar):
            st.write(message.content[0].text.value)


def _render_company_table(results, is_debug, content_container):
    """Show the search results as a numbered dataframe inside `content_container`.

    Results are sorted by descending score and de-duplicated by company name (a
    name is marked as seen before its description is checked). Results whose
    description is missing or shorter than 10 characters are dropped. Each kept
    result goes through ``card``, which appends it to ``carddict``, the
    dataframe's source.

    Returns:
        The concatenated HTML rows produced by ``card`` (currently not rendered).
    """
    seen_names = set()
    list_html = "<div class='container-fluid'>"
    for r in sorted(results, key=lambda x: x['score'], reverse=True):
        company_name = r["name"]
        if company_name in seen_names:
            continue
        seen_names.add(company_name)
        description = r["description"]
        if description is None or len(description.strip()) < 10:
            continue
        metadata = r["metadata"]
        list_html += card(
            metadata["company_id"], company_name, description, round(r["score"], 4),
            metadata.get("type", ""), metadata["region"], metadata["country"],
            r['data'], is_debug,
        )
    list_html += '</div>'

    df = pd.DataFrame.from_dict(carddict, orient="columns")
    if not df.empty:
        df.index += 1  # 1-based row numbers for display
        with content_container:
            st.dataframe(df,
                         hide_index=False,
                         column_config={
                             "name": st.column_config.TextColumn("Name"),
                             "company_id": st.column_config.LinkColumn("Link"),
                             "description": st.column_config.TextColumn("Description"),
                             "country": st.column_config.TextColumn("Country", width="small"),
                             "customer_problem": st.column_config.TextColumn("Customer problem"),
                             "target_customer": st.column_config.TextColumn(label="Target customer", width="small"),
                             "business_model": st.column_config.TextColumn(label="Business model")
                         },
                         use_container_width=True)
    return list_html


def run_query(query, report_type, top_k, regions, countries, is_debug, index_namespace, openai_model, default_prompt):
    """Answer `query` with the selected report type and show the matching companies.

    Args:
        query: The user's free-text request.
        report_type: "guided" (LLM-extracted keywords, then summary), "company_list"
            (search only), "assistant" (OpenAI assistant thread), "clustered", or any
            other value such as "standard" (search + summary with `default_prompt`).
        top_k, regions, countries, index_namespace: Passed to ``utils.search_index``.
        is_debug: Forwarded to ``card`` to add debug columns.
        openai_model: Model/engine name for the OpenAI calls.
        default_prompt: Instructions for the standard report (the editable
            "Default Prompt" setting); ``utils.default_prompt`` is used when empty.
    """
    # Created first so the table and assistant messages render above the summary.
    content_container = st.container()
    if report_type == "guided":
        # Let the model condense the request into search keywords first.
        keyword_template = PromptTemplate(
            template=utils.query_finetune_prompt + """
User query: {query}
""",
            input_variables=["query"],
        )
        keywords = oai.call_openai(keyword_template.format(query=query), engine=openai_model,
                                   temp=0, top_p=1.0, max_tokens=20, log_message=False)
        print(f"Keywords: {keywords}")
        results = utils.search_index(keywords, top_k, regions, countries, retriever, index_namespace)
        answer = _summarize(utils.summarization_prompt, query, _company_descriptions(results), openai_model)
        st.write(answer)
    elif report_type == "company_list":
        results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
    elif report_type == "assistant":
        messages = oai.call_assistant(query, engine=openai_model)
        st.session_state.messages = messages
        # NOTE(review): presumably populated by oai.call_assistant — confirm.
        results = st.session_state.get("db_search_results") or []
        if messages is not None:
            with content_container:
                _render_assistant_messages(messages)
    else:
        st.session_state.new_conversation = False
        results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
        if report_type == "clustered":
            instructions = utils.clustering_prompt
        else:
            # Honour the user-editable prompt from the Settings panel.
            instructions = default_prompt or utils.default_prompt
        answer = _summarize(instructions, query, _company_descriptions(results), openai_model)
        st.write(answer)
        st.session_state.messages.append({"role": "user", "content": query})
        # Keep only the narrative before the "-----" separator in the history;
        # when there is no separator, keep the whole answer.
        st.session_state.messages.append({"role": "system", "content": answer.partition("-----")[0]})

    _render_company_table(results, is_debug, content_container)
    st.session_state.last_user_query = query
def query_sent():
    """Clear the query text box by resetting its session-state key.

    Looks intended as a widget callback; it is not wired to any widget in this file.
    """
    st.session_state["user_query"] = ""
def find_default_assistant_idx(assistants, default_assistant_id='asst_8aSvGL075pmE1r8GAjjymu85'):
    """Return the position of the preferred assistant in `assistants`, or 0 if absent.

    Args:
        assistants: Iterable of objects with an ``id`` attribute (the options of the
            "OpenAI Assistant" selectbox).
        default_assistant_id: Id of the assistant to preselect; defaults to the
            "startup discovery 3 steps" assistant.

    Returns:
        The index of the first assistant whose ``id`` matches, otherwise 0.
    """
    return next(
        (idx for idx, assistant in enumerate(assistants) if assistant.id == default_assistant_id),
        0,
    )
def render_history():
    """Render the chat transcript as a scrollable HTML panel in the history container.

    Reads ``st.session_state.messages`` (dicts with 'role' and 'content') and draws
    them inside ``st.session_state.history_container`` with ``components.html``. The
    inline script scrolls the panel to the newest message. Roles and contents are
    HTML-escaped because they hold user queries and model output, not markup.
    """
    import html

    with st.session_state.history_container:
        rows = "".join(
            f"<div class='chat_message'><b>{html.escape(str(m['role']))}</b>: "
            f"{html.escape(str(m['content']))}</div>"
            for m in st.session_state.messages
        )
        s = f"""
<div style='overflow: hidden; padding:10px 0px;'>
<div id="chat_history" style='overflow-y: scroll;height: 200px;'>
{rows}</div>
</div>
<script>
var el = document.getElementById("chat_history");
el.scrollTop = el.scrollHeight;
</script>
"""
        components.html(s, height=220)
def _openai_client():
    """Return this session's OpenAI client, creating it with ``oai.get_client()`` on first use.

    ``init_openai()`` is not called at start-up, so nothing else guarantees that
    ``st.session_state.openai_client`` exists before it is needed.
    """
    if "openai_client" not in st.session_state:
        st.session_state.openai_client = oai.get_client()
    return st.session_state.openai_client


# Initialise per-session state on the first run.
if 'submitted_query' not in st.session_state:
    st.session_state.submitted_query = ''
if 'messages' not in st.session_state:
    st.session_state.messages = []
if 'last_user_query' not in st.session_state:
    st.session_state.last_user_query = ''

if utils.check_password():
    st.markdown("<script language='javascript'>console.log('scrolling');</script>", unsafe_allow_html=True)
    # 'messages' is always initialised above, so only the button starts a new conversation.
    if st.sidebar.button("New Conversation"):
        st.session_state.assistant_thread = _openai_client().beta.threads.create()
        st.session_state.new_conversation = True
        st.session_state.messages = []
    st.markdown("<h1 style='text-align: center; color: red; position: relative; top: -3rem;'>Raized.AI – Startups discovery demo</h1>", unsafe_allow_html=True)
    st.markdown("""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
""", unsafe_allow_html=True)

    with open("data/countries.json", "r", encoding="utf-8") as f:
        countries = json.load(f)['countries']

    # Sidebar filters.
    st.sidebar.markdown("Filters")
    report_type = st.sidebar.selectbox(label="Response Type", options=["gemini", "assistant", "standard", "guided", "company_list", "clustered"], index=0)
    countries_selectbox = st.sidebar.multiselect("Country", countries, default=[])
    all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America')
    region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions)
    all_bizmodels = ('B2B', 'B2C', 'eCommerce & Marketplace', 'Manufacturing', 'SaaS', 'Advertising', 'Commission', 'Subscription')
    # NOTE(review): the business-model filter is displayed but not passed to the search.
    bizmodel_selectbox = st.sidebar.multiselect("Business Model", all_bizmodels, default=all_bizmodels)
    # Plain (non-f) string: CSS braces must be single, not doubled.
    st.markdown(
        '''
<script>
function like_company(company_id) {
console.log("Company " + company_id + " Liked!");
}
function dislike_company(company_id) {
console.log("Company " + company_id + " Disliked!");
}
</script>
<style>
.sidebar .sidebar-content {
width: 375px;
}
</style>
''',
        unsafe_allow_html=True
    )

    tab_search = st.container()
    with tab_search:
        st.session_state.history_container = st.container()
        # Query box pinned to the bottom of the page.
        with stylable_container(
            key="query_panel",
            css_styles="""
.stTextInput {
position: fixed;
bottom: 0px;
background: white;
z-index: 1000;
padding-bottom: 2rem;
padding-left: 1rem;
padding-right: 1rem;
padding-top: 1rem;
border-top: 1px solid whitesmoke;
height: 8rem;
border-radius: 8px 8px 8px 8px;
box-shadow: 0 -3px 3px whitesmoke;
}
""",
        ):
            query = st.text_input(key="user_query",
                                  label="Enter your query",
                                  placeholder="Tell me what startups you are looking for", label_visibility="collapsed")

    tab_advanced = st.sidebar.expander("Settings")
    with tab_advanced:
        gemini_prompt = st.text_area("Gemini Prompt", value=google_default_instructions, height=400, key="advanced_gemini_prompt_content")
        default_prompt = st.text_area("Default Prompt", value=utils.default_prompt, height=400, key="advanced_default_prompt_content")
        assistant_id = st.selectbox(label="OpenAI Assistant", options=[f"{a.id}|||{a.name}" for a in assistants], index=find_default_assistant_idx(assistants))
        top_k = st.number_input('# Top Results', value=30)
        is_debug = st.checkbox("Debug output", value=False, key="debug")
        openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
        index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
        liked_companies = st.text_input(label="liked companies", key='liked_companies')
        disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
        clustering_prompt = st.text_area("Clustering Prompt", value=utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")

    if report_type == "assistant" and "assistant_thread" not in st.session_state:
        st.session_state.assistant_thread = _openai_client().beta.threads.create()

    if query != "" and not st.session_state.get("new_conversation", False):
        st.session_state.report_type = report_type
        st.session_state.top_k = top_k
        st.session_state.index_namespace = index_namespace
        st.session_state.region = region_selectbox
        st.session_state.country = countries_selectbox
        if report_type == "gemini":
            run_googleai(query, gemini_prompt)
        else:
            # The selectbox returns None when there are no assistants to choose from
            # (the module-level `assistants` list is empty while init_openai() is off).
            if assistant_id:
                st.session_state.assistant_id = assistant_id.split("|||", 1)[0]
            # run_query takes nine parameters; the Gemini prompt is not one of them.
            run_query(query, report_type, top_k,
                      region_selectbox, countries_selectbox, is_debug,
                      index_namespace, openai_model,
                      default_prompt)
    else:
        st.session_state.new_conversation = False