rynmurdock committed on
Commit
de9a113
·
1 Parent(s): fc2fdc7
app.py CHANGED
@@ -8,7 +8,9 @@ import glob
8
  import config
9
  from model import get_model_and_tokenizer
10
 
11
- model, model.prior_pipe.image_encoder = get_model_and_tokenizer(config.model_path,
 
 
12
  'cuda', torch.bfloat16)
13
 
14
  # TODO unify/merge origin and this
@@ -16,6 +18,7 @@ model, model.prior_pipe.image_encoder = get_model_and_tokenizer(config.model_pat
16
 
17
  device = "cuda"
18
 
 
19
 
20
  import spaces
21
  import matplotlib.pyplot as plt
@@ -51,14 +54,14 @@ def generate_gpu(in_im_embs, prompt='the scene'):
51
  with torch.no_grad():
52
  in_im_embs = in_im_embs.to('cuda')
53
 
54
- negative_image_embeds = in_im_embs[0] # model.prior_pipe.get_zero_embed()
55
  positive_image_embeds = in_im_embs[1]
56
 
57
  images = model.kandinsky_pipe(
58
  num_inference_steps=50,
59
  image_embeds=positive_image_embeds,
60
  negative_image_embeds=negative_image_embeds,
61
- guidance_scale=15,
62
  ).images[0]
63
  cond = (
64
  model.prior_pipe.image_processor(images, return_tensors="pt")
@@ -91,11 +94,10 @@ def generate(in_im_embs, ):
91
  @spaces.GPU()
92
  def sample_embs(prompt_embeds):
93
  latent = torch.randn(prompt_embeds.shape[0], 1, prompt_embeds.shape[-1])
94
- if prompt_embeds.shape[1] < 8: # TODO grab as `k` arg from config
95
- prompt_embeds = torch.nn.functional.pad(prompt_embeds, [0, 0, 0, 8-prompt_embeds.shape[1]])
96
- assert prompt_embeds.shape[1] == 8, f"The model is set to take `k`` cond image embeds but is shape {prompt_embeds.shape}"
97
  image_embeds = model(latent.to('cuda'), prompt_embeds.to('cuda')).predicted_image_embedding
98
-
99
  return image_embeds
100
 
101
  @spaces.GPU()
@@ -113,6 +115,8 @@ def get_user_emb(embs, ys):
113
  else:
114
  negative_embs = random.sample(negs, min(4, len(negs))) + negs[-4:]
115
  negatives = torch.stack(negative_embs, 1)
 
 
116
 
117
  image_embeds = torch.stack([sample_embs(negatives), sample_embs(positives)])
118
 
@@ -175,6 +179,7 @@ def background_next_image():
175
  prevs_df = pd.concat((prevs_df.iloc[:6], prevs_df.iloc[7:]))
176
 
177
  def pluck_img(user_id):
 
178
  rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) is not None for i in prevs_df.iterrows()]]
179
  ems = rated_rows['embeddings'].to_list()
180
  ys = [i[user_id][0] for i in rated_rows['user:rating'].to_list()]
 
8
  import config
9
  from model import get_model_and_tokenizer
10
 
11
+ torch.set_float32_matmul_precision('high')
12
+
13
+ model, model.prior_pipe.image_encoder = get_model_and_tokenizer(config.model_path,
14
  'cuda', torch.bfloat16)
15
 
16
  # TODO unify/merge origin and this
 
18
 
19
  device = "cuda"
20
 
21
+ k = config.k
22
 
23
  import spaces
24
  import matplotlib.pyplot as plt
 
54
  with torch.no_grad():
55
  in_im_embs = in_im_embs.to('cuda')
56
 
57
+ negative_image_embeds = in_im_embs[0]# if random.random() < .3 else model.prior_pipe.get_zero_embed()
58
  positive_image_embeds = in_im_embs[1]
59
 
60
  images = model.kandinsky_pipe(
61
  num_inference_steps=50,
62
  image_embeds=positive_image_embeds,
63
  negative_image_embeds=negative_image_embeds,
64
+ guidance_scale=8,
65
  ).images[0]
66
  cond = (
67
  model.prior_pipe.image_processor(images, return_tensors="pt")
 
94
  @spaces.GPU()
95
  def sample_embs(prompt_embeds):
96
  latent = torch.randn(prompt_embeds.shape[0], 1, prompt_embeds.shape[-1])
97
+ if prompt_embeds.shape[1] < k:
98
+ prompt_embeds = torch.nn.functional.pad(prompt_embeds, [0, 0, 0, k-prompt_embeds.shape[1]])
99
+ assert prompt_embeds.shape[1] == k, f"The model is set to take `k`` cond image embeds but is shape {prompt_embeds.shape}"
100
  image_embeds = model(latent.to('cuda'), prompt_embeds.to('cuda')).predicted_image_embedding
 
101
  return image_embeds
102
 
103
  @spaces.GPU()
 
115
  else:
116
  negative_embs = random.sample(negs, min(4, len(negs))) + negs[-4:]
117
  negatives = torch.stack(negative_embs, 1)
118
+ # if random.random() < .5:
119
+ # negatives = torch.zeros_like(negatives)
120
 
121
  image_embeds = torch.stack([sample_embs(negatives), sample_embs(positives)])
122
 
 
179
  prevs_df = pd.concat((prevs_df.iloc[:6], prevs_df.iloc[7:]))
180
 
181
  def pluck_img(user_id):
182
+ # TODO pluck images based on similarity but also based on diversity by cluster every few times.
183
  rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) is not None for i in prevs_df.iterrows()]]
184
  ems = rated_rows['embeddings'].to_list()
185
  ys = [i[user_id][0] for i in rated_rows['user:rating'].to_list()]
config.py CHANGED
@@ -12,5 +12,5 @@ batch_size = 16
12
  number_k_clip_embed = 16 # divide by this to determine bundling together of sequences -> CLIP
13
  num_workers = 32
14
  seed = 107
15
-
16
  # TODO config option to swap to diffusion?
 
12
  number_k_clip_embed = 16 # divide by this to determine bundling together of sequences -> CLIP
13
  num_workers = 32
14
  seed = 107
15
+ k = 8
16
  # TODO config option to swap to diffusion?
last_epoch_ckpt/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d65a902c101345526b244420a5e6f495a947909db28015840afa9bacd557936b
3
  size 136790920
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae34b5c319b9c804e1e82c93f78821b880553d2ac60ff628003175334ee9066d
3
  size 136790920
prior/pipeline_kandinsky_prior.py CHANGED
@@ -498,14 +498,13 @@ class KandinskyPriorPipeline(DiffusionPipeline):
498
  if negative_prompt is None:
499
  # zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
500
 
501
- # using the same hidden states or different hidden states?
502
-
503
- hidden_states = torch.randn(
504
- (batch_size, prompt_embeds.shape[-1]),
505
- device=prompt_embeds.device,
506
- dtype=prompt_embeds.dtype,
507
- generator=generator,
508
- )
509
 
510
  latents = self.prior(
511
  hidden_states,
@@ -541,7 +540,7 @@ class KandinskyPriorPipeline(DiffusionPipeline):
541
 
542
  if not return_dict:
543
  return (image_embeddings, zero_embeds)
544
-
545
  return KandinskyPriorPipelineOutput(
546
  image_embeds=image_embeddings, negative_image_embeds=zero_embeds
547
  )
 
498
  if negative_prompt is None:
499
  # zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
500
 
501
+ # use the same hidden states or different hidden states?
502
+ # hidden_states = torch.randn(
503
+ # (batch_size, prompt_embeds.shape[-1]),
504
+ # device=prompt_embeds.device,
505
+ # dtype=prompt_embeds.dtype,
506
+ # generator=generator,
507
+ # )
 
508
 
509
  latents = self.prior(
510
  hidden_states,
 
540
 
541
  if not return_dict:
542
  return (image_embeddings, zero_embeds)
543
+
544
  return KandinskyPriorPipelineOutput(
545
  image_embeds=image_embeddings, negative_image_embeds=zero_embeds
546
  )