openfree committed
Commit b982501 · verified · 1 Parent(s): 985e1a3

Update app.py

Files changed (1)
  app.py  +69 -153
app.py CHANGED
@@ -1,8 +1,5 @@
 import time
-
 import gradio as gr
-import spaces
-import numpy as np
 import torch
 from einops import rearrange, repeat
 from PIL import Image
@@ -13,12 +10,13 @@ from flux.util import (
     load_ae,
     load_clip,
     load_flow_model,
-    load_flow_model_quintized,
     load_t5,
 )
 from pulid.pipeline_flux import PuLIDPipeline
 from pulid.utils import resize_numpy_image_long, seed_everything
 
+# Add brief citation info
+_CITE_ = """PuLID: Person-under-Language Image Diffusion Model"""
 
 def get_models(name: str, device: torch.device, offload: bool):
     t5 = load_t5(device, max_length=128)
@@ -34,7 +32,7 @@ class FluxGenerator:
         self.device = torch.device('cuda')
         self.offload = False
         self.model_name = 'flux-dev'
-        self.model, self.ae, self.t5, self.clip_model = get_models(
+        self.model, self.ae, self.t5, self.clip_model = get_models(
             self.model_name,
             device=self.device,
             offload=self.offload,
@@ -46,7 +44,6 @@ class FluxGenerator:
 flux_generator = FluxGenerator()
 
 
-@spaces.GPU(duration=120)
 @torch.inference_mode()
 def generate_image(
     prompt: str,
@@ -71,17 +68,9 @@ def generate_image(
     perform_editing: bool = True,
     inversion_true_cfg: float = 1.0,
 ):
-    """
-    Core function that performs the image generation.
-    """
-    # self.t5.to(self.device)
-    # self.clip_model.to(self.device)
-    # self.ae.to(self.device)
-    # self.model.to(self.device)
-
     flux_generator.t5.max_length = max_sequence_length
 
-    # If seed == -1, random
+    # Set the seed
     seed = int(seed)
     if seed == -1:
         seed = None
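Note: the seed convention kept here (-1 meaning "draw a random seed") is resolved just below via torch.Generator. A minimal self-contained sketch of that behavior, assuming only PyTorch:

    import torch

    def resolve_seed(seed: int) -> int:
        # -1 is the sentinel for "random": torch.Generator(device="cpu").seed()
        # reseeds the generator from system entropy and returns the seed used,
        # which is what the app stores in opts.seed.
        if seed == -1:
            return torch.Generator(device="cpu").seed()
        return seed

    print(resolve_seed(42))   # 42
    print(resolve_seed(-1))   # a fresh random 64-bit seed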
@@ -99,14 +88,12 @@ def generate_image(
         opts.seed = torch.Generator(device="cpu").seed()
 
     seed_everything(opts.seed)
-
     print(f"Generating prompt: '{opts.prompt}' (seed={opts.seed})...")
     t0 = time.perf_counter()
 
     use_true_cfg = abs(true_cfg - 1.0) > 1e-6
 
-
-    # 1) Prepare input noise
+    # 1) Prepare the input noise
     noise = get_noise(
         num_samples=1,
         height=opts.height,
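Note: get_noise draws the latent, and the next hunk then packs it into FLUX's 2x2 patch token layout with rearrange; the token count is also what get_schedule receives as the sequence length (x.shape[-1] * x.shape[-2] // 4). A sketch of the shape bookkeeping, assuming FLUX's usual 16-channel, 8x-downsampled latent:

    import torch
    from einops import rearrange

    b, c, h, w = 1, 16, 128, 128                 # latent for a 1024x1024 image
    noise = torch.randn(b, c, h, w)

    # Pack 2x2 latent patches into tokens: (1, 16, 128, 128) -> (1, 4096, 64)
    tokens = rearrange(noise, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    print(tokens.shape)                          # torch.Size([1, 4096, 64])

    # The mapping is a pure permutation, so it inverts exactly
    restored = rearrange(tokens, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                         h=h // 2, w=w // 2, ph=2, pw=2)
    assert torch.equal(noise, restored)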
@@ -119,65 +106,42 @@ def generate_image(
     noise = rearrange(noise, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
     if noise.shape[0] == 1 and bs > 1:
         noise = repeat(noise, "1 ... -> bs ...", bs=bs)
-    # Encode id_image directly here
-    encode_t0 = time.perf_counter()
 
-    # Resize image
+    # Encode the ID image
+    encode_t0 = time.perf_counter()
     id_image = id_image.resize((opts.width, opts.height), resample=Image.LANCZOS)
-
-    # Convert image to torch.Tensor and scale to [-1, 1]
-    x = np.array(id_image).astype(np.float32)
-    x = torch.from_numpy(x)  # shape: (H, W, C)
-    x = (x / 127.5) - 1.0  # now in [-1, 1]
-    x = rearrange(x, "h w c -> 1 c h w")  # shape: (1, C, H, W)
+    x = torch.from_numpy(np.array(id_image).astype(np.float32))
+    x = (x / 127.5) - 1.0
+    x = rearrange(x, "h w c -> 1 c h w")
     x = x.to(flux_generator.device)
-    # Encode with autocast
+
     with torch.autocast(device_type=flux_generator.device.type, dtype=torch.bfloat16):
         x = flux_generator.ae.encode(x)
-
     x = x.to(torch.bfloat16)
 
-    # Offload if needed
-    if flux_generator.offload:
-        flux_generator.ae.encoder.to("cpu")
-        torch.cuda.empty_cache()
-
     encode_t1 = time.perf_counter()
     print(f"Encoded in {encode_t1 - encode_t0:.2f} seconds.")
 
     timesteps = get_schedule(opts.num_steps, x.shape[-1] * x.shape[-2] // 4, shift=False)
 
-    # 2) Prepare text embeddings
-    if flux_generator.offload:
-        flux_generator.t5 = flux_generator.t5.to(flux_generator.device)
-        flux_generator.clip_model = flux_generator.clip_model.to(flux_generator.device)
-
+    # 2) Prepare the text embeddings
     inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt=opts.prompt)
     inp_inversion = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt="")
     inp_neg = None
     if use_true_cfg:
         inp_neg = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt=neg_prompt)
 
-    # Offload text encoders, load ID detection to GPU
-    if flux_generator.offload:
-        flux_generator.t5 = flux_generator.t5.cpu()
-        flux_generator.clip_model = flux_generator.clip_model.cpu()
-        torch.cuda.empty_cache()
-        flux_generator.pulid_model.components_to_device(torch.device("cuda"))
-
-    # 3) ID Embeddings (optional)
+    # 3) Generate the ID embeddings
     id_embeddings = None
     uncond_id_embeddings = None
     if id_image is not None:
         id_image = np.array(id_image)
         id_image = resize_numpy_image_long(id_image, 1024)
         id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
-    else:
-        id_embeddings = None
-        uncond_id_embeddings = None
 
     y_0 = inp["img"].clone().detach()
 
+    # Image processing
     inverted = None
    if perform_inversion:
        inverted = rf_inversion(
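Note: the rewrite collapses the PIL-to-tensor conversion into three lines, but the normalization itself (uint8 [0, 255] -> float [-1, 1]) is unchanged. A standalone sketch of just that step:

    import numpy as np
    import torch
    from einops import rearrange
    from PIL import Image

    img = Image.new("RGB", (64, 64), color=(255, 127, 0))    # stand-in ID image

    x = torch.from_numpy(np.array(img).astype(np.float32))   # (H, W, C), [0, 255]
    x = (x / 127.5) - 1.0                                     # [-1, 1]
    x = rearrange(x, "h w c -> 1 c h w")                      # (1, C, H, W) for the AE

    print(x.shape)                          # torch.Size([1, 3, 64, 64])
    print(x.min().item(), x.max().item())   # -1.0 1.0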
@@ -198,10 +162,10 @@ def generate_image(
             y_1=noise,
             gamma=gamma
         )
-
         img = inverted
     else:
         img = noise
+
     inp["img"] = img
     inp_inversion["img"] = img
 
@@ -251,13 +215,7 @@ def generate_image(
         tau=tau,
     )
 
-    # Offload flux model, load auto-decoder
-    if flux_generator.offload:
-        flux_generator.model.cpu()
-        torch.cuda.empty_cache()
-        flux_generator.ae.decoder.to(x.device)
-
-    # 5) Decode latents
+    # Decode the result images
     if edited is not None:
         edited = unpack(edited.float(), opts.height, opts.width)
         with torch.autocast(device_type=flux_generator.device.type, dtype=torch.bfloat16):
@@ -273,14 +231,10 @@ def generate_image(
         with torch.autocast(device_type=flux_generator.device.type, dtype=torch.bfloat16):
             recon = flux_generator.ae.decode(recon)
 
-    if flux_generator.offload:
-        flux_generator.ae.decoder.cpu()
-        torch.cuda.empty_cache()
-
     t1 = time.perf_counter()
     print(f"Done in {t1 - t0:.2f} seconds.")
 
-    # Convert to PIL
+    # Convert to a PIL image
     if edited is not None:
         edited = edited.clamp(-1, 1)
         edited = rearrange(edited[0], "c h w -> h w c")
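Note: decoding mirrors the encoder-side normalization: clamp to [-1, 1], move channels last, rescale to [0, 255]. Only the clamp and rearrange appear in the hunk; a sketch of the full tensor-to-PIL step, assuming a decoded (1, C, H, W) tensor and the symmetric rescale:

    import torch
    from einops import rearrange
    from PIL import Image

    decoded = torch.tanh(torch.randn(1, 3, 64, 64))       # stand-in for ae.decode()

    out = decoded.clamp(-1, 1)
    out = rearrange(out[0], "c h w -> h w c")              # (H, W, C)
    arr = ((out + 1.0) * 127.5).to(torch.uint8).numpy()    # back to [0, 255]
    pil_img = Image.fromarray(arr)
    print(pil_img.size)                                    # (64, 64)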
@@ -299,84 +253,58 @@ def generate_image(
     return edited, str(opts.seed), flux_generator.pulid_model.debug_img_list
 
 
-
-
-def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
-                offload: bool = False, aggressive_offload: bool = False):
-
-    with gr.Blocks(theme = "apriel") as demo:
-
+def create_demo(args):
+    with gr.Blocks(theme="apriel") as demo:
         with gr.Row():
             with gr.Column():
-                prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
-                id_image = gr.Image(label="ID Image", type="pil")
-                id_weight = gr.Slider(0.0, 1.0, 0.4, step=0.05, label="id weight")
-
-                width = gr.Slider(256, 1536, 1024, step=16, label="Width", visible=args.dev)
-                height = gr.Slider(256, 1536, 1024, step=16, label="Height", visible=args.dev)
-                num_steps = gr.Slider(1, 24, 16, step=1, label="Number of steps")
-                guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance")
-
-                with gr.Accordion("Advanced Options (True CFG, true_cfg_scale=1 means use fake CFG, >1 means use true CFG", open=False):  # noqa E501
-                    neg_prompt = gr.Textbox(
-                        label="Negative Prompt",
-                        value="")
-                    true_cfg = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="true CFG scale")
-                    timestep_to_start_cfg = gr.Slider(0, 20, 1, step=1, label="timestep to start cfg", visible=args.dev)
-                    start_step = gr.Slider(0, 10, 0, step=1, label="timestep to start inserting ID")
-                    seed = gr.Textbox(-1, label="Seed (-1 for random)")
-                    max_sequence_length = gr.Slider(128, 512, 128, step=128,
-                                                    label="max_sequence_length for prompt (T5), small will be faster")
-                    gr.Markdown("### RF Inversion Options")
-                    gamma = gr.Slider(0.0, 1.0, 0.5, step=0.1, label="gamma")
-                    eta = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="eta")
-                    s = gr.Slider(0.0, 1.0, 0.0, step=0.1, label="s")
-                    tau = gr.Slider(0, 20, 2, step=1, label="tau")
-
-                generate_btn = gr.Button("Generate")
+                prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
+                id_image = gr.Image(label="ID Image", type="pil")
+                id_weight = gr.Slider(0.0, 1.0, 0.4, step=0.05, label="ID weight")
+                num_steps = gr.Slider(1, 24, 16, step=1, label="Number of steps")
+                guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance")
+
+                with gr.Accordion("Advanced options", open=False):
+                    neg_prompt = gr.Textbox(label="Negative prompt", value="")
+                    true_cfg = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="CFG scale")
+                    seed = gr.Textbox(-1, label="Seed (-1 for random)")
+                    gr.Markdown("### Other options")
+                    gamma = gr.Slider(0.0, 1.0, 0.5, step=0.1, label="gamma")
+                    eta = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="eta")
+
+                generate_btn = gr.Button("Generate image")
 
             with gr.Column():
-                output_image = gr.Image(label="Generated Image")
-                seed_output = gr.Textbox(label="Used Seed")
-                intermediate_output = gr.Gallery(label='Output', elem_id="gallery", visible=args.dev)
-
-
-        with gr.Row(), gr.Column():
-            gr.Markdown("## Examples")
-            example_inps = [
-                [
-                    'a portrait of a clown',
-                    'example_inputs/unsplash/lhon-karwan-11tbHtK5STE-unsplash.jpg',
-                    0.5, 3.5, 42, 5.0, 0.7
-                ],
-                [
-                    'a portrait of a zombie',
-                    'example_inputs/unsplash/baruk-granda-cfLL_jHQ-Iw-unsplash.jpg',
-                    0.4, 3.5, 42, 5.0, 0.7
-                ],
-                [
-                    'a portrait of an elf',
-                    'example_inputs/unsplash/masoud-razeghi--qsrZhXPius-unsplash.jpg',
-                    0.5, 3.5, 42, 5.0, 0.7
-                ],
-                [
-                    'a portrait of a demon',
-                    'example_inputs/unsplash/marcin-sajur-nZdMgqvYPBY-unsplash.jpg',
-                    0.3, 3.5, 42, 5.0, 0.7
-                ],
-                [
-                    'a portrait of a superhero',
-                    'example_inputs/unsplash/gus-tu-njana-Mf4MN7MZqcE-unsplash.jpg',
-                    0.2, 3.5, 42, 5.0, 0.8
-                ],
+                output_image = gr.Image(label="Generated image")
+                seed_output = gr.Textbox(label="Used seed")
+                gr.Markdown(_CITE_)
+
+        # Example inputs
+        with gr.Row():
+            gr.Markdown("## Examples")
+            example_inps = [
+                [
+                    'a portrait of a clown',
+                    'example_inputs/unsplash/lhon-karwan-11tbHtK5STE-unsplash.jpg',
+                    0.5, 3.5, 42
+                ],
+                [
+                    'a portrait of a zombie',
+                    'example_inputs/unsplash/baruk-granda-cfLL_jHQ-Iw-unsplash.jpg',
+                    0.4, 3.5, 42
+                ],
+                [
+                    'a portrait of an elf',
+                    'example_inputs/unsplash/masoud-razeghi--qsrZhXPius-unsplash.jpg',
+                    0.5, 3.5, 42
             ]
-        gr.Examples(examples=example_inps, inputs=[prompt, id_image, id_weight, guidance, seed, true_cfg, eta])
+            ]
+            gr.Examples(examples=example_inps, inputs=[prompt, id_image, id_weight, guidance, seed])
 
         generate_btn.click(
             fn=generate_image,
-            inputs=[prompt, id_image, width, height, num_steps, start_step, guidance, seed, id_weight, neg_prompt,
-                    true_cfg, timestep_to_start_cfg, max_sequence_length, gamma, eta, s, tau],
-            outputs=[output_image, seed_output, intermediate_output],
+            inputs=[prompt, id_image, 512, 512, num_steps, 0, guidance, seed, id_weight, neg_prompt,
+                    true_cfg, 1, 128, gamma, eta, 0, 5],
+            outputs=[output_image, seed_output],
         )
 
     return demo
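Note: the rewritten generate_btn.click passes bare literals (512, 512, 0, 1, 128, 0, 5) inside inputs=, but Gradio requires every entry there to be a component, and generate_image returns three values while only two outputs are wired. A hedged fix is to bake the constants into a wrapper; the wrapper below is illustrative, not part of the commit:

    # Hypothetical wrapper: bind the dropped constants so that `inputs`
    # holds only Gradio components, as the click() API requires.
    def generate_image_fixed(prompt, id_image, num_steps, guidance, seed,
                             id_weight, neg_prompt, true_cfg, gamma, eta):
        edited, used_seed, _debug = generate_image(
            prompt, id_image, 512, 512, num_steps, 0, guidance, seed,
            id_weight, neg_prompt, true_cfg, 1, 128, gamma, eta, 0, 5)
        # Drop the third return value, since only two outputs are wired up.
        return edited, used_seed

    generate_btn.click(
        fn=generate_image_fixed,
        inputs=[prompt, id_image, num_steps, guidance, seed,
                id_weight, neg_prompt, true_cfg, gamma, eta],
        outputs=[output_image, seed_output],
    )

The shortened gr.Examples wiring, by contrast, stays consistent: five input components, five values per example row.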
@@ -384,30 +312,18 @@ def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
 
 if __name__ == "__main__":
     import argparse
+    import numpy as np
 
     parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
-    parser.add_argument('--version', type=str, default='v0.9.1', help='version of the model', choices=['v0.9.0', 'v0.9.1'])
-    parser.add_argument("--name", type=str, default="flux-dev", choices=list('flux-dev'),
-                        help="currently only support flux-dev")
-    parser.add_argument("--device", type=str, default="cuda", help="Device to use")
-    parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
-    parser.add_argument("--aggressive_offload", action="store_true", help="Offload model more aggressively to CPU when not in use, for 24G GPUs")
-    parser.add_argument("--fp8", action="store_true", help="use flux-dev-fp8 model")
-    parser.add_argument("--onnx_provider", type=str, default="gpu", choices=["gpu", "cpu"],
-                        help="set onnx_provider to cpu (default gpu) can help reduce RAM usage, and when combined with"
-                             "fp8 option, the peak RAM is under 15GB")
-    parser.add_argument("--port", type=int, default=8080, help="Port to use")
-    parser.add_argument("--dev", action='store_true', help="Development mode")
-    parser.add_argument("--pretrained_model", type=str, help='for development')
+    parser.add_argument('--version', type=str, default='v0.9.1')
+    parser.add_argument("--name", type=str, default="flux-dev")
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--offload", action="store_true")
+    parser.add_argument("--port", type=int, default=8080)
     args = parser.parse_args()
 
-    # args.fp8 = True
-    if args.aggressive_offload:
-        args.offload = True
-
     print(f"Using device: {args.device}")
-    print(f"fp8: {args.fp8}")
     print(f"Offload: {args.offload}")
 
-    demo = create_demo(args, args.name, args.device, args.offload, args.aggressive_offload)
-    demo.launch(ssr_mode=False)
+    demo = create_demo(args)
+    demo.launch(ssr_mode=False)