nft-search / app.py
sean1's picture
Update app.py
eedd04d
import streamlit as st
import numpy as np
import torch
from datasets import load_dataset
from html import escape
from transformers import RobertaModel, AutoTokenizer
@st.experimental_singleton
def _load_assets():
    """Load the search model and index once and share them across reruns.

    Streamlit re-executes this script on every interaction, so plain
    module-level loading would re-read the model weights and embeddings each
    time. The singleton decorator keeps one copy for the server process.

    Returns:
        Tuple of (tokenizer, text encoder in eval mode, image embedding tensor
        on CPU, array of image links).
    """
    token = st.secrets["access_token"]
    tok = AutoTokenizer.from_pretrained('volen/nft-text', use_auth_token=token)
    encoder = RobertaModel.from_pretrained('volen/nft-text', use_auth_token=token).eval()
    embeddings = torch.load('image_embeddings.pt', map_location=torch.device('cpu'))
    # allow_pickle=True lets NumPy unpickle an object array. This assumes
    # image_links.npy is a trusted file bundled with the app; never load
    # untrusted files this way.
    urls = np.load('image_links.npy', allow_pickle=True)
    return tok, encoder, embeddings, urls


# Keep the original module-level names so the rest of the file is unchanged.
tokenizer, text_encoder, image_embeddings, links = _load_assets()
@st.experimental_memo
def image_search(query, top_k=10):
    """Return the image links whose embeddings best match a text query.

    The query is encoded with the module-level ``text_encoder`` and compared
    against ``image_embeddings`` by cosine similarity. Results are memoized
    per (query, top_k).

    Args:
        query: Free-text search string.
        top_k: Maximum number of links to return.

    Returns:
        A list of up to ``top_k`` entries from ``links``, best match first.
    """
    with torch.no_grad():
        # Standard RoBERTa configs reserve two position slots for the padding
        # offset, so this is the longest input the encoder can embed. Without
        # truncation an over-long query fails inside the model.
        max_len = text_encoder.config.max_position_embeddings - 2
        inputs = tokenizer(query, return_tensors='pt', truncation=True, max_length=max_len)
        text_embedding = text_encoder(**inputs).pooler_output
        # Assumes image_embeddings is (num_images, dim) and text_embedding is
        # (1, dim), so this broadcasts to one score per image. TODO: confirm.
        scores = torch.cosine_similarity(image_embeddings, text_embedding)
    order = scores.argsort(descending=True)
    # .tolist() gives plain Python ints, so the NumPy links array is never
    # indexed with torch tensors.
    return [links[i] for i in order[:top_k].tolist()]
def get_html(url_list):
    """Render image URLs as a wrapping flexbox gallery of fixed-height images.

    Each URL is HTML-escaped before being placed in a single-quoted ``src``
    attribute.

    Args:
        url_list: Iterable of image URL strings.

    Returns:
        An HTML string: one ``<div>`` wrapping one ``<img>`` per URL.
    """
    opening = "<div style='margin-top: 50px; max-width: 1100px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    tiles = "".join(
        f"<img style='height: 180px; margin: 2px' src='{escape(link)}'>"
        for link in url_list
    )
    return f"{opening}{tiles}</div>"
# Markdown help text rendered in the sidebar by main().
# NOTE(review): the list of supported collections presumably mirrors what is
# indexed in image_embeddings.pt / image_links.npy. Update it when the index changes.
description = '''
# nft search
- Enter your search and hit enter
- Note: So far we only support BAYC, cool cats, doodles and MAYC
'''
def main():
    """Render the NFT search page.

    The page has custom CSS, sidebar help text, a centred search box, and the
    matching images for a non-blank query.
    """
    # Page-level CSS overrides for container width and padding. They also hide
    # the default #MainMenu and footer.
    st.markdown('''
<style>
.block-container{
max-width: 1200px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>''',
                unsafe_allow_html=True)
    st.sidebar.markdown(description)
    # Centre the search box in the middle column of a 1:3:1 layout.
    _, search_col, _ = st.columns((1, 3, 1))
    query = search_col.text_input('search box', value='cat beanie')
    search_col.text("It'll take a few secs to load new images")
    # Strip surrounding whitespace so blank input does not run the model, and
    # so "cat " and "cat" hit the same memoized image_search result.
    query = query.strip()
    if query:
        results = image_search(query)
        st.markdown(get_html(results), unsafe_allow_html=True)
# Render the page when this file is executed as the main script.
if __name__ == '__main__':
    main()