rynmurdock commited on
Commit
ef6f6bd
·
1 Parent(s): d83af99

random calibration & they"re kept in the df

Browse files
Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -102,11 +102,17 @@ def sample_embs(prompt_embeds):
102
  def get_user_emb(embs, ys):
103
  positives = [e for e, ys in zip(embs, ys) if ys == 1]
104
  embs = random.sample(positives, min(8, len(positives)))
105
- positives = torch.stack(embs, 1)
 
 
 
106
 
107
  negs = [e for e, ys in zip(embs, ys) if ys == 0]
108
  negative_embs = random.sample(negs, min(8, len(negs)))
109
- negatives = torch.stack(negative_embs, 1)
 
 
 
110
 
111
  image_embeds = torch.stack([sample_embs(negatives), sample_embs(positives)])
112
 
@@ -202,10 +208,12 @@ def pluck_img(user_id):
202
 
203
  def next_image(calibrate_prompts, user_id):
204
  with torch.no_grad():
205
- if len(calibrate_prompts) > 0:
206
- cal_video = calibrate_prompts.pop(0)
 
207
  image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
208
  return image, calibrate_prompts
 
209
  else:
210
  image = pluck_img(user_id)
211
  return image, calibrate_prompts
@@ -330,7 +338,7 @@ Explore the latent space without text prompts based on your preferences. [rynmur
330
  ''', elem_id="description")
331
  user_id = gr.State()
332
  # calibration videos -- this is a misnomer now :D
333
- calibrate_prompts = gr.State( [l for l in random.sample(glob.glob('image_init/*'), k=8)] )
334
  def l():
335
  return None
336
 
@@ -428,8 +436,8 @@ for im in m_calibrate:
428
  tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
429
  tmp_df['user:rating'] = [{' ': ' '}]
430
  tmp_df['text'] = ['']
431
- # tmp_df['from_user_id'] = [0]
432
- # tmp_df['latest_user_to_rate'] = [0]
433
  prevs_df = pd.concat((prevs_df, tmp_df))
434
 
435
  glob_idx = 0
 
102
  def get_user_emb(embs, ys):
103
  positives = [e for e, ys in zip(embs, ys) if ys == 1]
104
  embs = random.sample(positives, min(8, len(positives)))
105
+ if len(embs) == 0:
106
+ positives = torch.zeros_like(im_emb)[None]
107
+ else:
108
+ positives = torch.stack(embs, 1)
109
 
110
  negs = [e for e, ys in zip(embs, ys) if ys == 0]
111
  negative_embs = random.sample(negs, min(8, len(negs)))
112
+ if len(negative_embs) == 0:
113
+ negatives = torch.zeros_like(im_emb)[None]
114
+ else:
115
+ negatives = torch.stack(negative_embs, 1)
116
 
117
  image_embeds = torch.stack([sample_embs(negatives), sample_embs(positives)])
118
 
 
208
 
209
  def next_image(calibrate_prompts, user_id):
210
  with torch.no_grad():
211
+ # once we've done so many random calibration prompts out of the full media
212
+ if len(m_calibrate) - len(calibrate_prompts) < 5:
213
+ cal_video = calibrate_prompts.pop(random.randint(0, len(calibrate_prompts)-1))
214
  image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
215
  return image, calibrate_prompts
216
+ # we switch to just getting media by similarity.
217
  else:
218
  image = pluck_img(user_id)
219
  return image, calibrate_prompts
 
338
  ''', elem_id="description")
339
  user_id = gr.State()
340
  # calibration videos -- this is a misnomer now :D
341
+ calibrate_prompts = gr.State( glob.glob('image_init/*') )
342
  def l():
343
  return None
344
 
 
436
  tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
437
  tmp_df['user:rating'] = [{' ': ' '}]
438
  tmp_df['text'] = ['']
439
+ tmp_df['from_user_id'] = [0]
440
+ tmp_df['latest_user_to_rate'] = [0]
441
  prevs_df = pd.concat((prevs_df, tmp_df))
442
 
443
  glob_idx = 0