# semsearch / app.py
# Hugging Face file-view header: author hanch, commit e661aac ("bug fix", verified), 23 kB.
import asyncio
# On Windows, switch asyncio to the selector-based event loop (the default there has been
# the Proactor loop since Python 3.8). Presumably needed by an async dependency that does
# not support the Proactor loop — confirm. Other platforms lack this class and are untouched.
_selector_policy_cls = getattr(asyncio, "WindowsSelectorEventLoopPolicy", None)
if _selector_policy_cls is not None:
    asyncio.set_event_loop_policy(_selector_policy_cls())
import json
import logging
from sys import exc_info
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# With no handler anywhere in the hierarchy, logging falls back to its last-resort handler,
# which only emits WARNING and above, so the logger.debug(...) calls in this module would be
# silently dropped. Attach one stream handler. The guard keeps Streamlit reruns, which
# re-execute this module and get back the same logger object, from stacking duplicates.
if not logger.handlers:
    _log_handler = logging.StreamHandler()
    _log_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
    logger.addHandler(_log_handler)
import streamlit as st
from googleai import send_message as google_send_message, init_googleai, DEFAULT_INSTRUCTIONS as google_default_instructions
from langchain.chains import RetrievalQA
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.schema import AIMessage, HumanMessage, SystemMessage
import pandas as pd
from PIL import Image
from streamlit.runtime.state import session_state
import openai
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import streamlit.components.v1 as components
# st.set_page_config(
# layout="wide",
# initial_sidebar_state="collapsed",
# page_title="RaizedAI Startup Discovery Assistant",
# #page_icon=":robot:"
# )
import utils
import openai_utils as oai
from streamlit_extras.stylable_container import stylable_container
# OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"] # app.pinecone.io
#model_name = 'text-embedding-ada-002'
# embed = OpenAIEmbeddings(
# model=model_name,
# openai_api_key=OPENAI_API_KEY
# )
#"🤖",
#st.image("resources/raized_logo.png")
# Avatar image used for assistant chat bubbles.
assistant_avatar = Image.open('resources/raized_logo.png')

# Column names of the results table, in display order.
_CARD_COLUMNS = (
    "name",
    "company_id",
    "description",
    "country",
    "customer_problem",
    "target_customer",
    "business_model",
)
# Column-oriented accumulator: card() appends one value per column for every company,
# and run_query() turns it into a pandas DataFrame.
carddict = {column: [] for column in _CARD_COLUMNS}
@st.cache_resource
def init_models():
    """Load the query encoder and its tokenizer.

    Wrapped in ``st.cache_resource``, so the models are loaded once and the same
    objects are returned on later calls and reruns.

    Returns:
        tuple: ``(retriever, tokenizer)`` — a ``SentenceTransformer`` for
        ``msmarco-distilbert-base-v4`` and the matching Hugging Face tokenizer.
    """
    logger.debug("init_models")
    hub_name = "sentence-transformers/msmarco-distilbert-base-v4"
    encoder = SentenceTransformer("msmarco-distilbert-base-v4")
    tokenizer = AutoTokenizer.from_pretrained(hub_name)
    return encoder, tokenizer
@st.cache_resource
def _load_openai_client_and_assistants():
    """Create the OpenAI client and request its assistant list (cached by Streamlit)."""
    logger.debug("init_openai")
    client = oai.get_client()
    # `limit` is an integer page size in the OpenAI API; it was previously passed as "20".
    assistants = client.beta.assistants.list(
        order="desc",
        limit=20,
    )
    return client, assistants


def init_openai():
    """Return the OpenAI assistant listing and publish the client in session state.

    The client and listing are created once by the cached loader above. The
    session-state assignment is done here, outside the cached function, because a
    cached function's body does not run on a cache hit. Side effects inside it would
    otherwise only reach the first session that called it.

    Returns:
        The object returned by ``client.beta.assistants.list(order="desc", limit=20)``.
    """
    client, assistants = _load_openai_client_and_assistants()
    st.session_state.openai_client = client
    return assistants
# The OpenAI assistant listing is switched off: init_openai() is not called, so the
# "OpenAI Assistant" selector in the settings panel is built from an empty list.
assistants = []
retriever, tokenizer = init_models()
# Also expose the cached encoder through session state.
st.session_state["retriever"] = retriever
def card(company_id, name, description, score, data_type, region, country, metadata, is_debug):
    """Build an HTML row for one company and record its fields in the global ``carddict``.

    Args:
        company_id: Company domain; the link and the table URL are ``"https://" + company_id``.
        name: Company name.
        description: Fallback description, replaced by ``metadata['Summary']`` when present.
        score: Relevance score, shown only in debug output.
        data_type: Source/type label, shown only in debug output.
        region: Accepted for interface compatibility; not used.
        country: Country shown in the row and the table.
        metadata: Dict that may hold ``'Summary'``, ``'Customer problem'``,
            ``'Target customer'`` and ``'Business model'`` (a string or a list of labels).
        is_debug: When true, append hidden like/dislike buttons and the type/score cell.

    Returns:
        str: The HTML for the row (a closed ``<div class="row ...">``).

    Side effects:
        Appends one value to every column list of ``carddict``.
    """
    description = metadata.get('Summary', description)
    customer_problem = metadata.get('Customer problem', "")
    target_customer = metadata.get('Target customer', "")

    business_model = metadata.get('Business model')
    if business_model is None:
        business_model_str = ""
    elif isinstance(business_model, str):
        # A plain string must not go through ", ".join, which would split it per character.
        business_model_str = business_model
    else:
        business_model_str = ", ".join(str(item) for item in business_model)

    markdown = f"""
<div class="row align-items-start" style="padding-bottom:10px;">
<div class="col-md-8 col-sm-8">
<b>{name} (<a href='https://{company_id}'>website</a>).</b>
<p style="">{description}</p>
</div>
<div class="col-md-1 col-sm-1"><span>{country}</span></div>
<div class="col-md-1 col-sm-1"><span>{customer_problem}</span></div>
<div class="col-md-1 col-sm-1"><span>{target_customer}</span></div>
<div class="col-md-1 col-sm-1"><span>{business_model_str}</span></div>
"""

    carddict["name"].append(name)
    carddict["company_id"].append("https://" + company_id)
    carddict["description"].append(description)
    carddict["country"].append(country)
    carddict["customer_problem"].append(customer_problem)
    carddict["target_customer"].append(target_customer)
    carddict["business_model"].append(business_model_str)

    if is_debug:
        # company_id is a domain (e.g. "acme.io"), so it must be quoted to be a valid JS argument.
        markdown += f"""
<div class="col-md-1 col-sm-1" style="display:none;">
<button type='button' onclick="like_company('{company_id}');">Like</button>
<button type='button' onclick="dislike_company('{company_id}');">DisLike</button>
</div>
<div class="col-md-1 col-sm-1">
<span>{data_type}</span>
<span>[Score: {score}]</span>
</div>
"""
    return markdown + "</div>"
def run_googleai(query, prompt):
    """Answer *query* with Gemini and show the exchange in the chat UI.

    Args:
        query: The user's free-text question.
        prompt: Instructions forwarded to ``google_send_message`` along with the query.

    Behavior:
        - The reply is the ``'output'`` entry of the Gemini response. It is shown as a
          chat bubble under the query.
        - Both turns are appended to ``st.session_state.messages``, the reply with role
          ``"system"``, and the history panel is redrawn.
        - Any exception is logged with ``logger.exception`` instead of propagating.
        - ``st.session_state.last_user_query`` is set to *query* either way.
    """
    try:
        logger.debug(f"User: {query}")
        reply = google_send_message(query, prompt)['output']
        logger.debug(f"Agent: {reply}")
        with st.container():
            with st.chat_message(name='User'):
                st.write(query)
            with st.chat_message(name='Agent', avatar=assistant_avatar):
                st.write(reply)
        history = st.session_state.messages
        history.append({"role": "user", "content": query})
        history.append({"role": "system", "content": reply})
        render_history()
    except Exception as err:
        logger.exception("Error processing user message", exc_info=err)
    st.session_state.last_user_query = query
def _describe_companies(results, limit=20):
    """Join a "Description of company ..." line for each of the first *limit* results that has a Summary."""
    return "\n".join(
        f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n"
        for res in results[:limit]
        if 'Summary' in res['data']
    )


def _summarize_companies(base_prompt, query, descriptions, openai_model):
    """Append the user query and company descriptions to *base_prompt* and return OpenAI's answer."""
    # Whitespace-separated word count, printed as a rough token estimate.
    ntokens = len(descriptions.split(" "))
    print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
    prompt_template = PromptTemplate(
        template=base_prompt + "\nUser query: {query}\nCompany descriptions: {descriptions}\n",
        input_variables=["descriptions", "query"],
    )
    prompt = prompt_template.format(descriptions=descriptions, query=query)
    print(f"==============================\nPrompt:\n{prompt[:1000]}\n==============================\n")
    return oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)


def _show_assistant_messages(messages, container):
    """Write every assistant-thread message that has a ``role`` into *container* as a chat bubble."""
    with container:
        # Shown in reverse of the order oai.call_assistant returns them
        # (presumably newest-first from the OpenAI API, i.e. displayed oldest-first — confirm).
        for message in reversed(list(messages)):
            if not hasattr(message, 'role'):
                continue
            avatar = assistant_avatar if message.role == "assistant" else None
            with st.chat_message(name=message.role, avatar=avatar):
                st.write(message.content[0].text.value)


def _show_company_table(results, is_debug, container):
    """Pass each unique, described result through card() and display ``carddict`` as a table.

    Results are handled in descending ``score`` order. A result is skipped when its
    ``name`` was already seen, or when its ``description`` is None or shorter than 10
    characters after stripping.
    """
    seen_names = set()
    for r in sorted(results, key=lambda x: x['score'], reverse=True):
        company_name = r["name"]
        if company_name in seen_names:
            continue
        # Marked as seen before the description check, so later duplicates are skipped
        # even when this copy is rejected for its description.
        seen_names.add(company_name)
        description = r["description"]
        if description is None or len(description.strip()) < 10:
            continue
        metadata = r["metadata"]
        # card() is called for its side effect of appending a row to carddict;
        # its HTML return value is not displayed.
        card(
            metadata["company_id"], company_name, description, round(r["score"], 4),
            metadata.get("type", ""), metadata["region"], metadata["country"],
            r['data'], is_debug,
        )

    df = pd.DataFrame.from_dict(carddict, orient="columns")
    if df.empty:
        return
    df.index += 1  # number rows from 1 in the displayed table
    with container:
        st.dataframe(
            df,
            hide_index=False,
            column_config={
                "name": st.column_config.TextColumn("Name"),
                "company_id": st.column_config.LinkColumn("Link"),
                "description": st.column_config.TextColumn("Description"),
                "country": st.column_config.TextColumn("Country", width="small"),
                "customer_problem": st.column_config.TextColumn("Customer problem"),
                "target_customer": st.column_config.TextColumn(label="Target customer", width="small"),
                "business_model": st.column_config.TextColumn(label="Business model"),
            },
            use_container_width=True,
        )


def run_query(query, report_type, top_k, regions, countries, is_debug, index_namespace,
              openai_model, default_prompt, gemini_prompt=None):
    """Search for startups matching *query* and render the answer for *report_type*.

    Args:
        query: The user's free-text request.
        report_type: Selects the flow.
            - ``"guided"``: LLM keyword extraction, then search and summary.
            - ``"company_list"``: search only.
            - ``"assistant"``: OpenAI assistant via ``oai.call_assistant``.
            - Any other value: search and summary, using ``utils.clustering_prompt``
              for ``"clustered"`` and *default_prompt* otherwise.
        top_k, regions, countries, index_namespace: Forwarded to ``utils.search_index``.
        is_debug: Forwarded to card().
        openai_model: Model name passed as ``engine`` to the OpenAI helpers.
        default_prompt: Prompt used for the standard summary. If it is empty,
            ``utils.default_prompt`` is used instead.
        gemini_prompt: Accepted because the page script passes it; not used here.

    Side effects:
        - Writes the summary and a results table to the page.
        - Appends the user/system turns to ``st.session_state.messages`` whenever a
          summary was produced.
        - Sets ``st.session_state.last_user_query``.
    """
    content_container = st.container()
    summary = None

    if report_type == "guided":
        # Step 1: have the model rewrite the request as search keywords (max 20 tokens).
        keyword_prompt = PromptTemplate(
            template=utils.query_finetune_prompt + "\nUser query: {query}\n",
            input_variables=["query"],
        ).format(query=query)
        keywords = oai.call_openai(keyword_prompt, engine=openai_model, temp=0, top_p=1.0,
                                   max_tokens=20, log_message=False)
        print(f"Keywords: {keywords}")
        results = utils.search_index(keywords, top_k, regions, countries, retriever, index_namespace)
        # Step 2: summarize the hits against the original request.
        summary = _summarize_companies(utils.summarization_prompt, query,
                                       _describe_companies(results), openai_model)
    elif report_type == "company_list":
        results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
    elif report_type == "assistant":
        messages = oai.call_assistant(query, engine=openai_model)
        st.session_state.messages = messages
        # NOTE(review): db_search_results is presumably filled by oai.call_assistant when the
        # assistant searched the index; fall back to no results when it is absent.
        results = st.session_state.get("db_search_results") or []
        if messages is not None:
            _show_assistant_messages(messages, content_container)
    else:
        st.session_state.new_conversation = False
        results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
        if report_type == "clustered":
            base_prompt = utils.clustering_prompt
        else:
            # Use the caller-supplied prompt (the editable "Default Prompt" setting).
            base_prompt = default_prompt or utils.default_prompt
        summary = _summarize_companies(base_prompt, query, _describe_companies(results), openai_model)

    if summary is not None:
        # Explicit write, equivalent to the Streamlit "magic" bare-expression display.
        st.write(summary)
        st.session_state.messages.append({"role": "user", "content": query})
        # Keep only the text before the first "-----" separator; without a separator
        # the whole reply is kept.
        st.session_state.messages.append({"role": "system", "content": summary.partition("-----")[0]})

    _show_company_table(results, is_debug, content_container)
    st.session_state.last_user_query = query
def query_sent():
    """Clear the query text box by resetting its ``user_query`` session-state entry."""
    st.session_state["user_query"] = ""
def find_default_assistant_idx(assistants, default_assistant_id='asst_8aSvGL075pmE1r8GAjjymu85'):
    """Return the position of the default assistant in *assistants*, or 0 if it is absent.

    Args:
        assistants: Iterable of objects with an ``id`` attribute.
        default_assistant_id: Id to look for. The default is the
            "startup discovery 3 steps" assistant.

    Returns:
        int: Index of the first assistant whose ``id`` matches, else 0 (also for an
        empty iterable). Suitable as a ``st.selectbox`` ``index``.
    """
    return next(
        (idx for idx, assistant in enumerate(assistants) if assistant.id == default_assistant_id),
        0,
    )
def render_history():
    """Redraw the scrolling chat-history panel in ``st.session_state.history_container``.

    Each entry of ``st.session_state.messages`` is read as ``m['role']`` and
    ``m['content']`` and rendered as one line. Both values are HTML-escaped, because
    they come from user input and model output. Without escaping, markup inside them
    could break or inject into the panel. A small script scrolls the panel to the
    newest message.
    """
    import html  # local import: only this function builds HTML from free text

    rows = "".join(
        f"<div class='chat_message'><b>{html.escape(str(m['role']))}</b>: "
        f"{html.escape(str(m['content']))}</div>"
        for m in st.session_state.messages
    )
    page = f"""
<div style='overflow: hidden; padding:10px 0px;'>
<div id="chat_history" style='overflow-y: scroll;height: 200px;'>
{rows}</div>
</div>
<script>
var el = document.getElementById("chat_history");
el.scrollTop = el.scrollHeight;
</script>
"""
    with st.session_state.history_container:
        components.html(page, height=220)
# Seed per-session defaults only when missing, so existing values survive reruns.
for _state_key, _state_default in (
    ("submitted_query", ''),
    ("messages", []),
    ("last_user_query", ''),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
if utils.check_password():
    st.markdown("<script language='javascript'>console.log('scrolling');</script>", unsafe_allow_html=True)

    if st.sidebar.button("New Conversation") or "messages" not in st.session_state:
        # Start over: clear the history and skip the query still in the text box on this rerun.
        # The OpenAI thread is dropped rather than created here. st.session_state.openai_client
        # is only set by init_openai(), which is not called at startup, so creating the thread
        # here crashed. The assistant branch below recreates a thread when it is needed.
        st.session_state.pop("assistant_thread", None)
        st.session_state.new_conversation = True
        st.session_state.messages = []

    st.markdown("<h1 style='text-align: center; color: red; position: relative; top: -3rem;'>Raized.AI – Startups discovery demo</h1>", unsafe_allow_html=True)
    st.markdown("""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
""", unsafe_allow_html=True)

    with open("data/countries.json", "r", encoding="utf-8") as f:
        countries = json.load(f)['countries']

    # --- Sidebar filters ---
    header = st.sidebar.markdown("Filters")
    report_type = st.sidebar.selectbox(label="Response Type", options=["gemini", "assistant", "standard", "guided", "company_list", "clustered"], index=0)
    countries_selectbox = st.sidebar.multiselect("Country", countries, default=[])
    all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America')
    region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions)
    all_bizmodels = ('B2B', 'B2C', 'eCommerce & Marketplace', 'Manufacturing', 'SaaS', 'Advertising', 'Commission', 'Subscription')
    bizmodel_selectbox = st.sidebar.multiselect("Business Model", all_bizmodels, default=all_bizmodels)

    # Plain (non-f) string: single braces are required for valid CSS.
    st.markdown(
        '''
<script>
function like_company(company_id) {
console.log("Company " + company_id + " Liked!");
}
function dislike_company(company_id) {
console.log("Company " + company_id + " Disliked!");
}
</script>
<style>
.sidebar .sidebar-content {
width: 375px;
}
</style>
''',
        unsafe_allow_html=True,
    )

    # --- Main area: chat history and the fixed query box ---
    tab_search = st.container()
    with tab_search:
        st.session_state.history_container = st.container()
        with stylable_container(
            key="query_panel",
            css_styles="""
            .stTextInput {
                position: fixed;
                bottom: 0px;
                background: white;
                z-index: 1000;
                padding-bottom: 2rem;
                padding-left: 1rem;
                padding-right: 1rem;
                padding-top: 1rem;
                border-top: 1px solid whitesmoke;
                height: 8rem;
                border-radius: 8px 8px 8px 8px;
                box-shadow: 0 -3px 3px whitesmoke;
            }
            """,
        ):
            query = st.text_input(key="user_query",
                                  label="Enter your query",
                                  placeholder="Tell me what startups you are looking for", label_visibility="collapsed")

    # --- Sidebar settings ---
    tab_advanced = st.sidebar.expander("Settings")
    with tab_advanced:
        gemini_prompt = st.text_area("Gemini Prompt", value=google_default_instructions, height=400, key="advanced_gemini_prompt_content")
        default_prompt = st.text_area("Default Prompt", value=utils.default_prompt, height=400, key="advanced_default_prompt_content")
        # Returns None when `assistants` is empty (the listing is currently disabled).
        assistant_id = st.selectbox(label="OpenAI Assistant", options=[f"{a.id}|||{a.name}" for a in assistants], index=find_default_assistant_idx(assistants))
        top_k = st.number_input('# Top Results', value=30)
        is_debug = st.checkbox("Debug output", value=False, key="debug")
        openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
        index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
        liked_companies = st.text_input(label="liked companies", key='liked_companies')
        disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
        clustering_prompt = st.text_area("Clustering Prompt", value=utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")

    if report_type == "assistant" and "assistant_thread" not in st.session_state:
        # Create the OpenAI client lazily: nothing else sets it while init_openai() is disabled.
        if "openai_client" not in st.session_state:
            st.session_state.openai_client = oai.get_client()
        st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()

    if query != "" and not st.session_state.get("new_conversation", False):
        st.session_state.report_type = report_type
        st.session_state.top_k = top_k
        st.session_state.index_namespace = index_namespace
        st.session_state.region = region_selectbox
        st.session_state.country = countries_selectbox
        if report_type == "gemini":
            run_googleai(query, gemini_prompt)
        else:
            # Options are "<id>|||<name>"; keep the id. Skip when no assistant is available.
            if assistant_id:
                st.session_state.assistant_id = assistant_id.partition("|||")[0]
            run_query(query, report_type, top_k,
                      region_selectbox, countries_selectbox, is_debug,
                      index_namespace, openai_model,
                      default_prompt)
    else:
        st.session_state.new_conversation = False