# Streamlit front end for the "Raized.AI – Startups discovery demo".
#
# The page offers several response types: "gemini" sends the query to
# googleai.send_message, "assistant" goes through oai.call_assistant, and the
# remaining types search the company index via utils.search_index and
# (except "company_list") summarise the hits with oai.call_openai.
import asyncio

if hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

import html
import json
import logging
from sys import exc_info

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

import streamlit as st
from googleai import send_message as google_send_message, init_googleai, DEFAULT_INSTRUCTIONS as google_default_instructions
from langchain.chains import RetrievalQA
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.schema import AIMessage, HumanMessage, SystemMessage
import pandas as pd
from PIL import Image
from streamlit.runtime.state import session_state
import openai
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import streamlit.components.v1 as components

import utils
import openai_utils as oai
from streamlit_extras.stylable_container import stylable_container

assistant_avatar = Image.open('resources/raized_logo.png')

# Column store for the results table. card() appends one value per column
# for every company it renders; run_query() turns it into a DataFrame.
carddict = {
    "name": [],
    "company_id": [],
    "description": [],
    "country": [],
    "customer_problem": [],
    "target_customer": [],
    "business_model": [],
}

# Suffixes appended to the base prompts before they are formatted.
QUERY_SUFFIX = " User query: {query} "
QUERY_WITH_DESCRIPTIONS_SUFFIX = (
    " User query: {query} Company descriptions: {descriptions} "
)

# Style for the query input panel (value kept identical to the original).
QUERY_PANEL_CSS = (
    " .stTextInput { position: fixed; bottom: 0px; background: white;"
    " z-index: 1000; padding-bottom: 2rem; padding-left: 1rem;"
    " padding-right: 1rem; padding-top: 1rem;"
    " border-top: 1px solid whitesmoke; height: 8rem;"
    " border-radius: 8px 8px 8px 8px;"
    " box-shadow: 0 -3px 3px whitesmoke; } "
)


@st.cache_resource
def init_models():
    """Load the sentence-embedding model and its tokenizer.

    Returns:
        A ``(retriever, tokenizer)`` tuple: the ``SentenceTransformer`` for
        "msmarco-distilbert-base-v4" and the matching ``AutoTokenizer``.
    """
    logger.debug("init_models")
    retriever = SentenceTransformer("msmarco-distilbert-base-v4")
    model_name = "sentence-transformers/msmarco-distilbert-base-v4"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return retriever, tokenizer


def _get_openai_client():
    """Return the OpenAI client stored in session state, creating it on first use."""
    if "openai_client" not in st.session_state:
        st.session_state.openai_client = oai.get_client()
    return st.session_state.openai_client


@st.cache_resource
def init_openai():
    """List the available OpenAI assistants (newest first, at most 20).

    Also makes sure ``st.session_state.openai_client`` is populated, as the
    original implementation did.
    """
    logger.debug("init_openai")
    client = _get_openai_client()
    return client.beta.assistants.list(
        order="desc",
        limit=20,
    )


# Listing assistants is currently disabled; call init_openai() to enable it.
assistants = []
retriever, tokenizer = init_models()
st.session_state.retriever = retriever


def card(company_id, name, description, score, data_type, region, country, metadata, is_debug):
    """Build the markdown card for one company and record it in ``carddict``.

    Args:
        company_id: Company identifier; "https://" is prepended to form the link.
        name: Company name.
        description: Fallback description, replaced by ``metadata['Summary']``
            when that key is present.
        score: Relevance score, shown only when ``is_debug`` is true.
        data_type: Data type label, shown only when ``is_debug`` is true.
        region: Accepted for interface compatibility; not used here.
        country: Country shown on the card.
        metadata: Mapping that may contain 'Summary', 'Customer problem',
            'Target customer' and 'Business model'.
        is_debug: When true, append the data type and score to the card.

    Returns:
        The markdown text of the card.
    """
    if 'Summary' in metadata:
        description = metadata['Summary']
    customer_problem = metadata.get('Customer problem', "")
    target_customer = metadata.get('Target customer', "")
    business_model = metadata.get('Business model', "")

    markdown = f"""
{name} (website).

{description}

{country}
{customer_problem}
{target_customer}
{business_model}
"""
    # Joining a plain string would insert ", " between every character, so
    # only join when the value is a real collection of entries.
    if isinstance(business_model, str):
        business_model_str = business_model
    else:
        business_model_str = ", ".join(str(item) for item in business_model)

    carddict["name"].append(name)
    carddict["company_id"].append("https://" + company_id)
    carddict["description"].append(description)
    carddict["country"].append(country)
    carddict["customer_problem"].append(customer_problem)
    carddict["target_customer"].append(target_customer)
    carddict["business_model"].append(business_model_str)

    if is_debug:
        markdown = markdown + f"""
{data_type} [Score: {score}
"""
    markdown = markdown + "\n"
    return markdown


def run_googleai(query, prompt):
    """Send ``query`` to the Google AI backend and render the exchange.

    The user query and the agent response are shown as chat messages,
    appended to ``st.session_state.messages`` and the history is re-rendered.
    Any exception is logged rather than propagated.
    """
    try:
        logger.debug("User: %s", query)
        response = google_send_message(query, prompt)
        response = response['output']
        logger.debug("Agent: %s", response)

        content_container = st.container()
        with content_container:
            with st.chat_message(name='User'):
                st.write(query)
            with st.chat_message(name='Agent', avatar=assistant_avatar):
                st.write(response)
        st.session_state.messages.append({"role": "user", "content": query})
        st.session_state.messages.append({"role": "system", "content": response})
        render_history()
    except Exception:
        # Best-effort UI action: keep the page alive and record the traceback.
        logger.exception("Error processing user message")
    st.session_state.last_user_query = query


def _summaries_text(results, limit=20):
    """Join the 'Summary' of the first ``limit`` results into one prompt string."""
    return "\n".join(
        f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n"
        for res in results[:limit]
        if 'Summary' in res['data']
    )


def _format_prompt(template_text, **variables):
    """Fill ``template_text`` with ``variables`` using a PromptTemplate."""
    template = PromptTemplate(template=template_text, input_variables=list(variables))
    return template.format(**variables)


def _render_assistant_messages(messages, container):
    """Render assistant-thread messages inside ``container``, in reverse order."""
    with container:
        for message in reversed(list(messages)):
            if not hasattr(message, 'role'):
                continue
            avatar = assistant_avatar if message.role == "assistant" else None
            with st.chat_message(name=message.role, avatar=avatar):
                st.write(message.content[0].text.value)


def _render_results_table(results, container, is_debug):
    """Show the search results as a table inside ``container``.

    Results are ordered by descending score; repeated company names and
    entries whose description is missing or shorter than 10 characters are
    skipped.
    """
    seen_names = set()
    for r in sorted(results, key=lambda x: x['score'], reverse=True):
        company_name = r["name"]
        if company_name in seen_names:
            continue
        seen_names.add(company_name)

        description = r["description"]
        if description is None or len(description.strip()) < 10:
            continue

        meta = r["metadata"]
        # card() is called for its side effect of filling carddict; the
        # markdown it returns is not displayed at the moment.
        card(
            meta["company_id"],
            company_name,
            description,
            round(r["score"], 4),
            meta["type"] if "type" in meta else "",
            meta["region"],
            meta["country"],
            r['data'],
            is_debug,
        )

    df = pd.DataFrame.from_dict(carddict, orient="columns")
    if len(df) > 0:
        df.index += 1  # number the rows from 1 for display
        with container:
            st.dataframe(
                df,
                hide_index=False,
                column_config={
                    "name": st.column_config.TextColumn("Name"),
                    "company_id": st.column_config.LinkColumn("Link"),
                    "description": st.column_config.TextColumn("Description"),
                    "country": st.column_config.TextColumn("Country", width="small"),
                    "customer_problem": st.column_config.TextColumn("Customer problem"),
                    "target_customer": st.column_config.TextColumn(label="Target customer", width="small"),
                    "business_model": st.column_config.TextColumn(label="Business model"),
                },
                use_container_width=True,
            )


def run_query(query, report_type, top_k, regions, countries, is_debug, index_namespace, openai_model, default_prompt):
    """Run a non-Gemini query and render the answer plus the results table.

    Args:
        query: Free-text user query.
        report_type: "guided", "company_list", "assistant", "clustered" or
            anything else (treated as the standard report).
        top_k, regions, countries, index_namespace: Passed through to
            ``utils.search_index``.
        is_debug: Forwarded to ``card`` to include score details.
        openai_model: Model/engine name passed to the ``oai`` helpers.
        default_prompt: Base prompt for the standard report; falls back to
            ``utils.default_prompt`` when empty.
    """
    content_container = st.container()

    if report_type == "guided":
        # Step 1: let the model turn the query into search keywords.
        keyword_prompt = _format_prompt(utils.query_finetune_prompt + QUERY_SUFFIX, query=query)
        keywords = oai.call_openai(keyword_prompt, engine=openai_model, temp=0, top_p=1.0, max_tokens=20, log_message=False)
        print(f"Keywords: {keywords}")

        # Step 2: search with the keywords and summarise the hits.
        results = utils.search_index(keywords, top_k, regions, countries, retriever, index_namespace)
        descriptions = _summaries_text(results)
        ntokens = len(descriptions.split(" "))
        print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")

        prompt = _format_prompt(
            utils.summarization_prompt + QUERY_WITH_DESCRIPTIONS_SUFFIX,
            descriptions=descriptions,
            query=query,
        )
        print(f"==============================\nPrompt:\n{prompt}\n==============================\n")
        m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
        st.write(m_text)
    elif report_type == "company_list":
        results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
    elif report_type == "assistant":
        messages = oai.call_assistant(query, engine=openai_model)
        # Never store None: other code appends to and iterates over this value.
        # NOTE(review): these are assistant message objects, not the dicts
        # render_history() expects -- confirm before mixing response types.
        st.session_state.messages = messages if messages is not None else []
        # NOTE(review): db_search_results is presumably filled in by
        # oai.call_assistant -- confirm; fall back to no results if absent.
        results = st.session_state.get("db_search_results") or []
        if messages is not None:
            _render_assistant_messages(messages, content_container)
    else:
        st.session_state.new_conversation = False
        results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
        descriptions = _summaries_text(results)
        ntokens = len(descriptions.split(" "))
        print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")

        if report_type == "clustered":
            base_prompt = utils.clustering_prompt
        else:
            # Honour the prompt edited in the Settings panel.
            base_prompt = default_prompt or utils.default_prompt
        prompt = _format_prompt(
            base_prompt + QUERY_WITH_DESCRIPTIONS_SUFFIX,
            descriptions=descriptions,
            query=query,
        )
        print(f"==============================\nPrompt:\n{prompt[:1000]}\n==============================\n")
        m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
        st.write(m_text)

        st.session_state.messages.append({"role": "user", "content": query})
        # Keep only the text before the "-----" separator; when the separator
        # is missing the whole answer is kept (previously an empty string).
        summary = m_text.partition("-----")[0]
        st.session_state.messages.append({"role": "system", "content": summary})

    _render_results_table(results or [], content_container, is_debug)
    st.session_state.last_user_query = query


def query_sent():
    """Clear the query text box value held in session state."""
    st.session_state.user_query = ""


def find_default_assistant_idx(assistants):
    """Return the index of the default assistant in ``assistants``, or 0."""
    default_assistant_id = 'asst_8aSvGL075pmE1r8GAjjymu85'  # startup discovery 3 steps
    for idx, assistant in enumerate(assistants):
        if assistant.id == default_assistant_id:
            return idx
    return 0


def render_history():
    """Render the stored conversation inside the history container.

    Roles and contents are HTML-escaped because they come from the user and
    from model output and are handed to ``components.html``.
    """
    with st.session_state.history_container:
        entries = [
            f"\n{html.escape(str(m['role']))}: {html.escape(str(m['content']))}\n"
            for m in st.session_state.messages
        ]
        components.html("\n" + "".join(entries) + "\n", height=220)


if 'submitted_query' not in st.session_state:
    st.session_state.submitted_query = ''
if 'messages' not in st.session_state:
    st.session_state.messages = []
if 'last_user_query' not in st.session_state:
    st.session_state.last_user_query = ''

if utils.check_password():
    st.markdown("", unsafe_allow_html=True)

    if st.sidebar.button("New Conversation") or "messages" not in st.session_state:
        # Forget the previous assistant thread; a fresh one is created further
        # down, and only when the "assistant" response type is selected.
        if "assistant_thread" in st.session_state:
            del st.session_state["assistant_thread"]
        st.session_state.new_conversation = True
        st.session_state.messages = []

    st.markdown("\n\nRaized.AI – Startups discovery demo\n\n", unsafe_allow_html=True)
    st.markdown(" ", unsafe_allow_html=True)

    with open("data/countries.json", "r", encoding="utf-8") as f:
        countries = json.load(f)['countries']

    header = st.sidebar.markdown("Filters")
    report_type = st.sidebar.selectbox(
        label="Response Type",
        options=["gemini", "assistant", "standard", "guided", "company_list", "clustered"],
        index=0,
    )
    countries_selectbox = st.sidebar.multiselect("Country", countries, default=[])
    all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America')
    region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions)
    all_bizmodels = ('B2B', 'B2C', 'eCommerce & Marketplace', 'Manufacturing', 'SaaS', 'Advertising', 'Commission', 'Subscription')
    bizmodel_selectbox = st.sidebar.multiselect("Business Model", all_bizmodels, default=all_bizmodels)

    st.markdown(" ", unsafe_allow_html=True)

    tab_search = st.container()
    with tab_search:
        st.session_state.history_container = st.container()
        with stylable_container(key="query_panel", css_styles=QUERY_PANEL_CSS):
            query = st.text_input(
                key="user_query",
                label="Enter your query",
                placeholder="Tell me what startups you are looking for",
                label_visibility="collapsed",
            )

    tab_advanced = st.sidebar.expander("Settings")
    with tab_advanced:
        gemini_prompt = st.text_area("Gemini Prompt", value=google_default_instructions, height=400, key="advanced_gemini_prompt_content")
        default_prompt = st.text_area("Default Prompt", value=utils.default_prompt, height=400, key="advanced_default_prompt_content")
        assistant_id = st.selectbox(
            label="OpenAI Assistant",
            options=[f"{a.id}|||{a.name}" for a in assistants],
            index=find_default_assistant_idx(assistants),
        )
        top_k = st.number_input('# Top Results', value=30)
        is_debug = st.checkbox("Debug output", value=False, key="debug")
        openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
        index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
        liked_companies = st.text_input(label="liked companies", key='liked_companies')
        disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
        # NOTE(review): the edited clustering prompt is not passed to
        # run_query, which uses utils.clustering_prompt directly.
        clustering_prompt = st.text_area("Clustering Prompt", value=utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")

    if report_type == "assistant" and "assistant_thread" not in st.session_state:
        st.session_state.assistant_thread = _get_openai_client().beta.threads.create()

    if query != "" and not st.session_state.get("new_conversation", False):
        st.session_state.report_type = report_type
        st.session_state.top_k = top_k
        st.session_state.index_namespace = index_namespace
        st.session_state.region = region_selectbox
        st.session_state.country = countries_selectbox

        if report_type == "gemini":
            run_googleai(query, gemini_prompt)
        else:
            # The selectbox yields None while no assistants are listed.
            if assistant_id:
                st.session_state.assistant_id = assistant_id.partition("|||")[0]
            run_query(query, report_type, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace, openai_model, default_prompt)
    else:
        st.session_state.new_conversation = False