Commit 0b449a5 · 1 Parent(s): b73f599
Fangrui Liu committed: add selective db / feat / lang

app.py CHANGED

@@ -3,7 +3,7 @@ import numpy as np
 import base64
 from io import BytesIO
 from multilingual_clip import pt_multilingual_clip
-from transformers import CLIPTokenizerFast, AutoTokenizer
+from transformers import CLIPTokenizerFast, AutoTokenizer, CLIPModel
 import torch
 import logging
 from os import environ

@@ -12,30 +12,22 @@ environ['TOKENIZERS_PARALLELISM'] = 'true'
 
 
 db_name_map = {
-    "Unsplash Photos 25K": "mqdb_demo.
-    "RSICD: Remote Sensing Images 11K": "mqdb_demo.
+    "Unsplash Photos 25K": lambda feat: f"mqdb_demo.unsplash_25k_{feat}_indexer",
+    "RSICD: Remote Sensing Images 11K": lambda feat: f"mqdb_demo.rsicd_{feat}_b_32",
 }
+feat_name_map = {
+    'Vanilla CLIP': "clip",
+    'CLIP finetuned on RSICD': "cliprsicd"
+}
+
 
 DB_NAME = "mqdb_demo.unsplash_25k_clip_indexer"
-MODEL_ID = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
 DIMS = 512
 # Ignore some bad links (broken in the dataset already)
 BAD_IDS = {'9_9hzZVjV8s', 'RDs0THr4lGs', 'vigsqYux_-8',
            'rsJtMXn3p_c', 'AcG-unN00gw', 'r1R_0ZNUcx0'}
 
 
-@st.experimental_singleton(show_spinner=False)
-def init_clip():
-    """ Initialize CLIP Model
-
-    Returns:
-        Tokenizer: CLIPTokenizerFast (which convert words into embeddings)
-    """
-    clip = pt_multilingual_clip.MultilingualCLIP.from_pretrained(MODEL_ID)
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-    return tokenizer, clip
-
-
 @st.experimental_singleton(show_spinner=False)
 def init_db():
     """ Initialize the Database Connection

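The two maps now compose the table to query from the dataset choice and the selected image feature. A minimal standalone sketch of that lookup, with the maps copied from the hunk above and the two assignments standing in for what the selectboxes later store in `st.session_state`:

```python
# Sketch: resolve the fully qualified table name from the UI selections.
db_name_map = {
    "Unsplash Photos 25K": lambda feat: f"mqdb_demo.unsplash_25k_{feat}_indexer",
    "RSICD: Remote Sensing Images 11K": lambda feat: f"mqdb_demo.rsicd_{feat}_b_32",
}
feat_name_map = {
    'Vanilla CLIP': "clip",
    'CLIP finetuned on RSICD': "cliprsicd",
}

db_name_ref = "RSICD: Remote Sensing Images 11K"   # stand-in for st.session_state.db_name_ref
feat_name = "CLIP finetuned on RSICD"              # stand-in for st.session_state.feat_name

table = db_name_map[db_name_ref](feat_name_map[feat_name])
print(table)  # mqdb_demo.rsicd_cliprsicd_b_32
```
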
@@ -82,15 +74,15 @@ def query(xq, top_k=10):
         # Using PREWHERE allows you to do column filter before vector search
         xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
             distance('topK={top_k}')(vector, {xq_s}) AS dist\
-            FROM {db_name_map[st.session_state.db_name_ref]} \
+            FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])} \
             PREWHERE id NOT IN ({exclude_list})")
     else:
         xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
             distance('topK={top_k}')(vector, {xq_s}) AS dist\
-            FROM {db_name_map[st.session_state.db_name_ref]}")
+            FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])}")
     real_xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
         distance('topK={top_k}')(vector, {xq_s}) AS dist\
-        FROM {db_name_map[st.session_state.db_name_ref]}")
+        FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])}")
     top_k = real_xc
     xc = [xi for xi in xc if xi['id'] not in st.session_state.meta or
           st.session_state.meta[xi['id']] < 1]

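For reference, a sketch of roughly what the first f-string renders to once the table lookup is applied; `top_k`, `xq_s`, and `exclude_list` are illustrative stand-ins for the values `query()` receives at runtime:

```python
# Sketch: the rendered SQL for the PREWHERE branch (all values are placeholders).
top_k = 10
xq_s = "[0.12, -0.03, 0.97, 0.44]"              # stands in for the 512-d query vector literal
exclude_list = "'9_9hzZVjV8s', 'RDs0THr4lGs'"   # stands in for ids to filter out
table = "mqdb_demo.unsplash_25k_clip_indexer"   # db_name_map[...](feat_name_map[...])

sql = (f"SELECT id, url, vector, "
       f"distance('topK={top_k}')(vector, {xq_s}) AS dist "
       f"FROM {table} "
       f"PREWHERE id NOT IN ({exclude_list})")
print(sql)
```
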
@@ -166,38 +158,6 @@ class NormalizingLayer(torch.nn.Module):
         return x / torch.norm(x, dim=-1, keepdim=True)
 
 
-def prompt2vec(prompt: str):
-    """ Convert prompt into a computational vector
-
-    Args:
-        prompt (str): Text to be tokenized
-
-    Returns:
-        xq: vector from the tokenizer, representing the original prompt
-    """
-    # inputs = tokenizer(prompt, return_tensors='pt')
-    # out = clip.get_text_features(**inputs)
-    out = clip.forward(prompt, tokenizer)
-    xq = out.squeeze(0).cpu().detach().numpy().tolist()
-    return xq
-
-
-def pil_to_bytes(img):
-    """ Convert a Pillow image into base64
-
-    Args:
-        img (PIL.Image): Pillow (PIL) Image
-
-    Returns:
-        img_bin: image in base64 format
-    """
-    with BytesIO() as buf:
-        img.save(buf, format='jpeg')
-        img_bin = buf.getvalue()
-        img_bin = base64.b64encode(img_bin).decode('utf-8')
-    return img_bin
-
-
 def card(i, url):
     return f'<img id="img{i}" src="{url}" width="200px;">'
 

@@ -286,6 +246,63 @@ def delete_element(element):
     del element
 
 
+@st.experimental_singleton(show_spinner=False)
+def init_clip_mlang():
+    """ Initialize CLIP Model
+
+    Returns:
+        Tokenizer: CLIPTokenizerFast (which convert words into embeddings)
+    """
+    MODEL_ID = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
+    clip = pt_multilingual_clip.MultilingualCLIP.from_pretrained(MODEL_ID)
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+    return tokenizer, clip
+
+@st.experimental_singleton(show_spinner=False)
+def init_clip_vanilla():
+    """ Initialize CLIP Model
+
+    Returns:
+        Tokenizer: CLIPTokenizerFast (which convert words into embeddings)
+    """
+    MODEL_ID = "openai/clip-vit-base-patch32"
+    tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
+    clip = CLIPModel.from_pretrained(MODEL_ID)
+    return tokenizer, clip
+
+
+@st.experimental_singleton(show_spinner=False)
+def init_clip_rsicd():
+    """ Initialize CLIP Model
+
+    Returns:
+        Tokenizer: CLIPTokenizerFast (which convert words into embeddings)
+    """
+    MODEL_ID = "flax-community/clip-rsicd"
+    tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
+    clip = CLIPModel.from_pretrained(MODEL_ID)
+    return tokenizer, clip
+
+
+def prompt2vec_mlang(prompt: str, tokenizer, clip):
+    """ Convert prompt into a computational vector
+
+    Args:
+        prompt (str): Text to be tokenized
+
+    Returns:
+        xq: vector from the tokenizer, representing the original prompt
+    """
+    out = clip.forward(prompt, tokenizer)
+    xq = out.squeeze(0).cpu().detach().numpy().tolist()
+    return xq
+
+def prompt2vec_vanilla(prompt: str, tokenizer, clip):
+    inputs = tokenizer(prompt, return_tensors='pt')
+    out = clip.get_text_features(**inputs)
+    xq = out.squeeze(0).cpu().detach().numpy().tolist()
+    return xq
+
 st.markdown("""
 <link
     rel="stylesheet"

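Each loader returns a `(tokenizer, model)` pair and the matching `prompt2vec_*` turns a prompt into the query vector. A standalone sketch of the multilingual path, using the same calls as `init_clip_mlang()` and `prompt2vec_mlang()` but without the Streamlit singleton cache (the checkpoint is downloaded from the Hugging Face Hub on first run):

```python
from multilingual_clip import pt_multilingual_clip
from transformers import AutoTokenizer

# init_clip_mlang() without the @st.experimental_singleton cache
MODEL_ID = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
clip = pt_multilingual_clip.MultilingualCLIP.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# prompt2vec_mlang(): encode a non-English prompt into the shared CLIP space
out = clip.forward("女人举着雨伞", tokenizer)  # "a woman holding an umbrella"
xq = out.squeeze(0).cpu().detach().numpy().tolist()
print(len(xq))  # 512, matching DIMS
```
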
@@ -323,13 +340,23 @@ messages = [
     """
 ]
 
+text_model_map = {
+    'Multi Lingual': {'Vanilla CLIP': [prompt2vec_mlang, ]},
+    'English': {'Vanilla CLIP': [prompt2vec_vanilla, ],
+                'CLIP finetuned on RSICD': [prompt2vec_vanilla, ],
+                }
+}
+
+
 with st.spinner("Connecting DB..."):
     st.session_state.meta, st.session_state.index = init_db()
 
 with st.spinner("Loading Models..."):
     # Initialize CLIP model
     if 'xq' not in st.session_state:
-
+        text_model_map['Multi Lingual']['Vanilla CLIP'].append(init_clip_mlang())
+        text_model_map['English']['Vanilla CLIP'].append(init_clip_vanilla())
+        text_model_map['English']['CLIP finetuned on RSICD'].append(init_clip_rsicd())
         st.session_state.query_num = 0
 
 if 'xq' not in st.session_state:

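After the three `append(...)` calls above, every leaf of `text_model_map` is a two-element list: the encoding function followed by its `(tokenizer, model)` pair, which is exactly what the dispatch at the end of the diff unpacks. A small sketch of that shape with lightweight stand-ins instead of the real models:

```python
# Sketch: the shape of text_model_map after loading (stand-ins replace real objects).
def prompt2vec_stub(prompt, tokenizer, clip):
    return [0.0] * 512  # placeholder for a real 512-d embedding

text_model_map = {
    'Multi Lingual': {'Vanilla CLIP': [prompt2vec_stub]},
    'English': {'Vanilla CLIP': [prompt2vec_stub],
                'CLIP finetuned on RSICD': [prompt2vec_stub]},
}
for feats in text_model_map.values():
    for entry in feats.values():
        entry.append(("tokenizer-stub", "model-stub"))  # what init_clip_*() would return

p2v_func, args = text_model_map['English']['Vanilla CLIP']
print(len(p2v_func("playing corgi", *args)))  # 512
```
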
@@ -347,8 +374,15 @@ if 'xq' not in st.session_state:
     start = [st.empty(), st.empty(), st.empty(), st.empty(),
              st.empty(), st.empty(), st.empty()]
     start[0].info(msg)
-
-
+    start_col = start[1].columns(3)
+    st.session_state.db_name_ref = start_col[0].selectbox("Select Database:", list(db_name_map.keys()))
+    st.session_state.lang = start_col[1].selectbox("Select Language:", list(text_model_map.keys()))
+    st.session_state.feat_name = start_col[2].selectbox("Select Image Feature:",
+                                                        list(text_model_map[st.session_state.lang].keys()))
+    if st.session_state.db_name_ref == "RSICD: Remote Sensing Images 11K":
+        st.warning('If you are searching for Remote Sensing Images, \
+            try to use prompt "An aerial photograph of <your-real-query>" \
+            to obtain best search experience!')
     prompt = start[2].text_input(
         "Prompt:", value="", placeholder="Examples: playing corgi, 女人举着雨伞, mouette volant au-dessus de la mer, ガラスの花瓶の花 ...")
     if len(prompt) > 0:

@@ -388,7 +422,8 @@ if 'xq' not in st.session_state:
     else:
         print(f"Input prompt is {prompt}")
         # Tokenize the vectors
-
+        p2v_func, args = text_model_map[st.session_state.lang][st.session_state.feat_name]
+        xq = p2v_func(prompt, *args)
         st.session_state.xq = xq
         st.session_state.orig_xq = xq
         _ = [elem.empty() for elem in start]

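End to end, the prompt path is now selection-driven: the language and feature choices pick the encoder pair, and the resulting vector is what `query()` formats into the SQL above. A hedged sketch outside Streamlit using the English / vanilla CLIP pair (plain variables replace `st.session_state`; the model is downloaded from the Hub as in the loaders):

```python
from transformers import CLIPTokenizerFast, CLIPModel

def prompt2vec_vanilla(prompt, tokenizer, clip):
    inputs = tokenizer(prompt, return_tensors='pt')
    out = clip.get_text_features(**inputs)
    return out.squeeze(0).cpu().detach().numpy().tolist()

# init_clip_vanilla() without the Streamlit singleton cache
MODEL_ID = "openai/clip-vit-base-patch32"
pair = (CLIPTokenizerFast.from_pretrained(MODEL_ID), CLIPModel.from_pretrained(MODEL_ID))

text_model_map = {'English': {'Vanilla CLIP': [prompt2vec_vanilla, pair]}}

lang, feat_name = 'English', 'Vanilla CLIP'  # what the two selectboxes would set
p2v_func, args = text_model_map[lang][feat_name]
xq = p2v_func("playing corgi", *args)        # 512-d list, ready for query(xq)
print(len(xq))
```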