Commit ef6f6bd · 1 parent: d83af99
random calibration & they're kept in the df
app.py CHANGED
@@ -102,11 +102,17 @@ def sample_embs(prompt_embeds):
 def get_user_emb(embs, ys):
     positives = [e for e, ys in zip(embs, ys) if ys == 1]
     embs = random.sample(positives, min(8, len(positives)))
-
+    if len(embs) == 0:
+        positives = torch.zeros_like(im_emb)[None]
+    else:
+        positives = torch.stack(embs, 1)
 
     negs = [e for e, ys in zip(embs, ys) if ys == 0]
     negative_embs = random.sample(negs, min(8, len(negs)))
-
+    if len(negative_embs) == 0:
+        negatives = torch.zeros_like(im_emb)[None]
+    else:
+        negatives = torch.stack(negative_embs, 1)
 
     image_embeds = torch.stack([sample_embs(negatives), sample_embs(positives)])
 
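The guard added in this hunk prevents a crash when a user has rated nothing of one polarity yet: torch.stack raises on an empty list, so the code substitutes a zero embedding shaped like im_emb instead. A minimal sketch of the same pattern, assuming per-image embeddings of shape (1, emb_dim); the real app takes the fallback shape from its module-level im_emb tensor:

    import torch

    def stack_or_zeros(embs, emb_dim=768):
        # torch.stack() raises a RuntimeError on an empty list, so fall
        # back to one zero embedding when there are no ratings yet.
        if len(embs) == 0:
            return torch.zeros(1, 1, emb_dim)
        return torch.stack(embs, dim=1)

    pos = stack_or_zeros([torch.randn(1, 768), torch.randn(1, 768)])  # (1, 2, 768)
    neg = stack_or_zeros([])                                          # (1, 1, 768)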
@@ -202,10 +208,12 @@ def pluck_img(user_id):
 
 def next_image(calibrate_prompts, user_id):
     with torch.no_grad():
-
-
+        # once we've done so many random calibration prompts out of the full media
+        if len(m_calibrate) - len(calibrate_prompts) < 5:
+            cal_video = calibrate_prompts.pop(random.randint(0, len(calibrate_prompts)-1))
             image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
             return image, calibrate_prompts
+        # we switch to just getting media by similarity.
         else:
             image = pluck_img(user_id)
             return image, calibrate_prompts
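In words: while fewer than five calibration items from m_calibrate have been consumed, next_image keeps popping a random calibration path from the session's list; after that it defers to pluck_img, which picks media by preference similarity. A toy version of the flow, with remaining, media, and recommend as hypothetical stand-ins for calibrate_prompts, m_calibrate, and pluck_img:

    import random

    def next_item(remaining, media, recommend, n_calibration=5):
        # Serve random calibration media until n_calibration of them have
        # been consumed, then hand off to the preference-based recommender.
        if len(media) - len(remaining) < n_calibration:
            # randrange(n) picks the same way as randint(0, n - 1)
            return remaining.pop(random.randrange(len(remaining))), remaining
        return recommend(), remaining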
@@ -330,7 +338,7 @@ Explore the latent space without text prompts based on your preferences. [rynmur
     ''', elem_id="description")
     user_id = gr.State()
     # calibration videos -- this is a misnomer now :D
-    calibrate_prompts = gr.State(
+    calibrate_prompts = gr.State( glob.glob('image_init/*') )
     def l():
         return None
 
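This one-line change seeds calibrate_prompts with every file under image_init/ instead of the previous initializer. Because gr.State deep-copies its default value per session, each visitor pops paths from their own copy. A minimal, self-contained sketch of that behaviour; the button and textbox are illustrative, not part of the app:

    import glob
    import gradio as gr

    def pop_one(remaining):
        # State is copied per session, so this pop only affects
        # the current visitor's list.
        path = remaining.pop() if remaining else None
        return path, remaining

    with gr.Blocks() as demo:
        calibrate_prompts = gr.State(glob.glob('image_init/*'))
        nxt = gr.Button("Next")
        shown = gr.Textbox()
        nxt.click(pop_one, [calibrate_prompts], [shown, calibrate_prompts])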
@@ -428,8 +436,8 @@ for im in m_calibrate:
     tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
     tmp_df['user:rating'] = [{' ': ' '}]
     tmp_df['text'] = ['']
-
-
+    tmp_df['from_user_id'] = [0]
+    tmp_df['latest_user_to_rate'] = [0]
     prevs_df = pd.concat((prevs_df, tmp_df))
 
 glob_idx = 0
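These two columns are what keeps the seeded calibration media in prevs_df alongside user-generated rows, per the commit message; 0 appears to act as the sentinel for 'not generated or rated by a real user'. A sketch of the resulting rows, with hypothetical file paths and the embeddings column omitted for brevity:

    import pandas as pd

    prevs_df = pd.DataFrame()
    for path in ['image_init/a.png', 'image_init/b.png']:  # hypothetical files
        tmp_df = pd.DataFrame({
            'paths': [path],
            'user:rating': [{' ': ' '}],   # placeholder rating map, as in the diff
            'text': [''],
            'from_user_id': [0],           # 0: seeded calibration media
            'latest_user_to_rate': [0],
        })
        prevs_df = pd.concat((prevs_df, tmp_df))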