openfree commited on
Commit
66db587
ยท
verified ยท
1 Parent(s): 7ebafeb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -296
app.py CHANGED
@@ -1,331 +1,86 @@
1
  import time
2
  import gradio as gr
3
  import torch
4
- from einops import rearrange, repeat
5
  from PIL import Image
6
  import numpy as np
7
- import spaces # Hugging Face Spaces ์ž„ํฌํŠธ ์ถ”๊ฐ€
8
 
9
- from flux.sampling import denoise, get_noise, get_schedule, prepare, rf_denoise, rf_inversion, unpack
10
- from flux.util import (
11
- SamplingOptions,
12
- load_ae,
13
- load_clip,
14
- load_flow_model,
15
- load_t5,
16
- )
17
- from pulid.pipeline_flux import PuLIDPipeline
18
- from pulid.utils import resize_numpy_image_long, seed_everything
19
-
20
- # ๊ฐ„๋‹จํ•œ ์ธ์šฉ ์ •๋ณด ์ถ”๊ฐ€
21
- _CITE_ = """PuLID: Person-under-Language Image Diffusion Model"""
22
-
23
- # GPU ์‚ฌ์šฉ ๊ฐ€๋Šฅ ์—ฌ๋ถ€ ํ™•์ธ ๋ฐ ์žฅ์น˜ ์„ค์ •
24
- def get_device():
25
- if torch.cuda.is_available():
26
- return torch.device('cuda')
27
- else:
28
- print("CUDA GPU๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. CPU๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.")
29
- return torch.device('cpu')
30
-
31
- def get_models(name: str, device, offload: bool):
32
- print(f"๋ชจ๋ธ์„ {device}์— ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.")
33
- t5 = load_t5(device, max_length=128)
34
- clip_model = load_clip(device)
35
- model = load_flow_model(name, device="cpu" if offload else device)
36
- model.eval()
37
- ae = load_ae(name, device="cpu" if offload else device)
38
- return model, ae, t5, clip_model
39
-
40
-
41
- class FluxGenerator:
42
  def __init__(self):
43
- # GPU ์‚ฌ์šฉ ๊ฐ€๋Šฅ ์—ฌ๋ถ€์— ๋”ฐ๋ผ ์žฅ์น˜ ์„ค์ •
44
- self.device = get_device()
45
- self.offload = False
46
- self.model_name = 'flux-dev'
47
-
48
- # ๋ชจ๋ธ ๋กœ๋“œ ์‹œ๋„
49
- try:
50
- self.model, self.ae, self.t5, self.clip_model = get_models(
51
- self.model_name,
52
- device=self.device,
53
- offload=self.offload,
54
- )
55
- self.pulid_model = PuLIDPipeline(
56
- self.model,
57
- 'cuda' if torch.cuda.is_available() else 'cpu',
58
- weight_dtype=torch.bfloat16 if self.device.type == 'cuda' else torch.float32
59
- )
60
- self.pulid_model.load_pretrain()
61
- self.initialized = True
62
- except Exception as e:
63
- print(f"๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
64
- self.initialized = False
65
-
66
- # ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์‹œ๋„
67
- try:
68
- flux_generator = FluxGenerator()
69
- model_initialized = flux_generator.initialized
70
- except Exception as e:
71
- print(f"FluxGenerator ์ดˆ๊ธฐํ™” ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
72
- model_initialized = False
73
-
74
-
75
- # Spaces GPU ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ ์ถ”๊ฐ€ (120์ดˆ GPU ์‚ฌ์šฉ)
76
- @spaces.GPU(duration=120)
77
- @torch.inference_mode()
78
- def generate_image(
79
- prompt: str,
80
- id_image,
81
- num_steps: int,
82
- guidance: float,
83
- seed,
84
- id_weight: float,
85
- neg_prompt: str,
86
- true_cfg: float,
87
- gamma: float,
88
- eta: float,
89
- ):
90
- # ๋ชจ๋ธ์ด ์ดˆ๊ธฐํ™”๋˜์ง€ ์•Š์•˜์œผ๋ฉด ์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€ ๋ฐ˜ํ™˜
91
- if not model_initialized:
92
- return None, "GPU ์˜ค๋ฅ˜: CUDA GPU๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์–ด ๋ชจ๋ธ์„ ์ดˆ๊ธฐํ™”ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
93
 
94
- # ID ์ด๋ฏธ์ง€๊ฐ€ ์—†์œผ๋ฉด ์‹คํ–‰ ๋ถˆ๊ฐ€
95
- if id_image is None:
96
- return None, "์˜ค๋ฅ˜: ID ์ด๋ฏธ์ง€๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค."
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  try:
99
- # ๊ณ ์ • ๋งค๊ฐœ๋ณ€์ˆ˜
100
- width = 512
101
- height = 512
102
- start_step = 0
103
- timestep_to_start_cfg = 1
104
- max_sequence_length = 128
105
- s = 0
106
- tau = 5
107
-
108
- flux_generator.t5.max_length = max_sequence_length
109
-
110
- # ์‹œ๋“œ ์„ค์ •
111
- try:
112
- seed = int(seed)
113
- except:
114
- seed = -1
115
-
116
- if seed == -1:
117
- seed = None
118
-
119
- opts = SamplingOptions(
120
- prompt=prompt,
121
- width=width,
122
- height=height,
123
- num_steps=num_steps,
124
- guidance=guidance,
125
- seed=seed,
126
- )
127
-
128
- if opts.seed is None:
129
- opts.seed = torch.Generator(device="cpu").seed()
130
-
131
- seed_everything(opts.seed)
132
- print(f"Generating prompt: '{opts.prompt}' (seed={opts.seed})...")
133
- t0 = time.perf_counter()
134
-
135
- use_true_cfg = abs(true_cfg - 1.0) > 1e-6
136
-
137
- # 1) ์ž…๋ ฅ ๋…ธ์ด์ฆˆ ์ค€๋น„
138
- noise = get_noise(
139
- num_samples=1,
140
- height=opts.height,
141
- width=opts.width,
142
- device=flux_generator.device,
143
- dtype=torch.bfloat16 if flux_generator.device.type == 'cuda' else torch.float32,
144
- seed=opts.seed,
145
- )
146
- bs, c, h, w = noise.shape
147
- noise = rearrange(noise, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
148
- if noise.shape[0] == 1 and bs > 1:
149
- noise = repeat(noise, "1 ... -> bs ...", bs=bs)
150
-
151
- # ID ์ด๋ฏธ์ง€ ์ธ์ฝ”๋”ฉ
152
- encode_t0 = time.perf_counter()
153
- id_image = id_image.resize((opts.width, opts.height), resample=Image.LANCZOS)
154
- x = torch.from_numpy(np.array(id_image).astype(np.float32))
155
- x = (x / 127.5) - 1.0
156
- x = rearrange(x, "h w c -> 1 c h w")
157
- x = x.to(flux_generator.device)
158
-
159
- dtype = torch.bfloat16 if flux_generator.device.type == 'cuda' else torch.float32
160
- with torch.autocast(device_type=flux_generator.device.type, dtype=dtype):
161
- x = flux_generator.ae.encode(x)
162
- x = x.to(dtype)
163
-
164
- encode_t1 = time.perf_counter()
165
- print(f"Encoded in {encode_t1 - encode_t0:.2f} seconds.")
166
-
167
- timesteps = get_schedule(opts.num_steps, x.shape[-1] * x.shape[-2] // 4, shift=False)
168
-
169
- # 2) ํ…์ŠคํŠธ ์ž„๋ฒ ๋”ฉ ์ค€๋น„
170
- inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt=opts.prompt)
171
- inp_inversion = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt="")
172
- inp_neg = None
173
- if use_true_cfg:
174
- inp_neg = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt=neg_prompt)
175
-
176
- # 3) ID ์ž„๋ฒ ๋”ฉ ์ƒ์„ฑ
177
- id_embeddings = None
178
- uncond_id_embeddings = None
179
- if id_image is not None:
180
- id_image = np.array(id_image)
181
- id_image = resize_numpy_image_long(id_image, 1024)
182
- id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
183
-
184
- y_0 = inp["img"].clone().detach()
185
-
186
- # ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ๊ณผ์ •
187
- inverted = rf_inversion(
188
- flux_generator.model,
189
- **inp_inversion,
190
- timesteps=timesteps,
191
- guidance=opts.guidance,
192
- id=id_embeddings,
193
- id_weight=id_weight,
194
- start_step=start_step,
195
- uncond_id=uncond_id_embeddings,
196
- true_cfg=true_cfg,
197
- timestep_to_start_cfg=timestep_to_start_cfg,
198
- neg_txt=inp_neg["txt"] if use_true_cfg else None,
199
- neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
200
- neg_vec=inp_neg["vec"] if use_true_cfg else None,
201
- aggressive_offload=False,
202
- y_1=noise,
203
- gamma=gamma
204
- )
205
-
206
- inp["img"] = inverted
207
- inp_inversion["img"] = inverted
208
-
209
- edited = rf_denoise(
210
- flux_generator.model,
211
- **inp,
212
- timesteps=timesteps,
213
- guidance=opts.guidance,
214
- id=id_embeddings,
215
- id_weight=id_weight,
216
- start_step=start_step,
217
- uncond_id=uncond_id_embeddings,
218
- true_cfg=true_cfg,
219
- timestep_to_start_cfg=timestep_to_start_cfg,
220
- neg_txt=inp_neg["txt"] if use_true_cfg else None,
221
- neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
222
- neg_vec=inp_neg["vec"] if use_true_cfg else None,
223
- aggressive_offload=False,
224
- y_0=y_0,
225
- eta=eta,
226
- s=s,
227
- tau=tau,
228
- )
229
-
230
- # ๊ฒฐ๊ณผ ์ด๋ฏธ์ง€ ๋””์ฝ”๋”ฉ
231
- edited = unpack(edited.float(), opts.height, opts.width)
232
- with torch.autocast(device_type=flux_generator.device.type, dtype=dtype):
233
- edited = flux_generator.ae.decode(edited)
234
-
235
- t1 = time.perf_counter()
236
- print(f"Done in {t1 - t0:.2f} seconds.")
237
-
238
- # PIL ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜
239
- edited = edited.clamp(-1, 1)
240
- edited = rearrange(edited[0], "c h w -> h w c")
241
- edited = Image.fromarray((127.5 * (edited + 1.0)).cpu().byte().numpy())
242
-
243
- return edited, str(opts.seed)
244
-
245
  except Exception as e:
246
  import traceback
247
- error_msg = f"์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}\n{traceback.format_exc()}"
248
  print(error_msg)
249
  return None, error_msg
250
 
251
-
252
  def create_demo():
253
  with gr.Blocks() as demo:
254
- gr.Markdown("# PuLID: ์ธ๋ฌผ ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ ๋„๊ตฌ")
255
 
256
- if not model_initialized:
257
- gr.Markdown("## โš ๏ธ ์˜ค๋ฅ˜: CUDA GPU๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค")
258
- gr.Markdown("์ด ์‘์šฉ ํ”„๋กœ๊ทธ๋žจ์€ CUDA ์ง€์› GPU๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. CPU์—์„œ๋Š” ์‹คํ–‰ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
259
- return demo
260
-
261
  with gr.Row():
262
  with gr.Column():
263
- prompt = gr.Textbox(label="ํ”„๋กฌํ”„ํŠธ", value="portrait, color, cinematic")
264
- id_image = gr.Image(label="ID ์ด๋ฏธ์ง€", type="pil")
265
- id_weight = gr.Slider(0.0, 1.0, 0.4, step=0.05, label="ID ๊ฐ€์ค‘์น˜")
266
- num_steps = gr.Slider(1, 24, 16, step=1, label="๋‹จ๊ณ„ ์ˆ˜")
267
- guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="๊ฐ€์ด๋˜์Šค")
268
-
269
- with gr.Accordion("๊ณ ๊ธ‰ ์˜ต์…˜", open=False):
270
- neg_prompt = gr.Textbox(label="๋„ค๊ฑฐํ‹ฐ๋ธŒ ํ”„๋กฌํ”„ํŠธ", value="")
271
- true_cfg = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="CFG ์Šค์ผ€์ผ")
272
- seed = gr.Textbox(value="-1", label="์‹œ๋“œ (-1: ๋žœ๋ค)")
273
- gr.Markdown("### ๊ธฐํƒ€ ์˜ต์…˜")
274
- gamma = gr.Slider(0.0, 1.0, 0.5, step=0.1, label="๊ฐ๋งˆ")
275
- eta = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="์—ํƒ€")
276
-
277
- generate_btn = gr.Button("์ด๋ฏธ์ง€ ์ƒ์„ฑ")
278
 
279
  with gr.Column():
280
- output_image = gr.Image(label="์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€")
281
- seed_output = gr.Textbox(label="๊ฒฐ๊ณผ/์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€")
282
- gr.Markdown(_CITE_)
283
-
284
- # ์˜ˆ์ œ ์ถ”๊ฐ€
285
- with gr.Row():
286
- gr.Markdown("## ์˜ˆ์ œ")
287
- example_inps = [
288
- [
289
- 'a portrait of a clown',
290
- 'example_inputs/unsplash/lhon-karwan-11tbHtK5STE-unsplash.jpg',
291
- 16, 3.5, "-1", 0.4, "", 3.5, 0.5, 0.8
292
- ],
293
- [
294
- 'a portrait of a zombie',
295
- 'example_inputs/unsplash/baruk-granda-cfLL_jHQ-Iw-unsplash.jpg',
296
- 16, 3.5, "42", 0.4, "", 3.5, 0.5, 0.8
297
- ]
298
- ]
299
- gr.Examples(
300
- examples=example_inps,
301
- inputs=[prompt, id_image, num_steps, guidance, seed,
302
- id_weight, neg_prompt, true_cfg, gamma, eta]
303
- )
304
 
305
- # Gradio ์ด๋ฒคํŠธ ์—ฐ๊ฒฐ
306
  generate_btn.click(
307
  fn=generate_image,
308
- inputs=[
309
- prompt, id_image, num_steps, guidance, seed,
310
- id_weight, neg_prompt, true_cfg, gamma, eta
311
- ],
312
- outputs=[output_image, seed_output],
313
  )
314
 
315
- return demo
 
 
 
 
 
 
 
 
 
316
 
 
317
 
318
  if __name__ == "__main__":
319
  import argparse
320
-
321
- parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
322
- parser.add_argument('--version', type=str, default='v0.9.1')
323
- parser.add_argument("--name", type=str, default="flux-dev")
324
- parser.add_argument("--port", type=int, default=8080)
325
  args = parser.parse_args()
326
 
327
- print("Hugging Face Spaces ํ™˜๊ฒฝ์—์„œ ์‹คํ–‰ ์ค‘์ž…๋‹ˆ๋‹ค. GPU ํ• ๋‹น์„ ์š”์ฒญํ•ฉ๋‹ˆ๋‹ค.")
328
 
 
329
  demo = create_demo()
330
- # ํ ์„ค์ • ์ˆ˜์ •
331
- demo.queue().launch(server_name="0.0.0.0", server_port=args.port)
 
1
  import time
2
  import gradio as gr
3
  import torch
4
+ import spaces
5
  from PIL import Image
6
  import numpy as np
 
7
 
8
+ # ์ถ•์†Œ๋œ ๋ชจ๋ธ ๋กœ๋“œ - ๋ฌดํ•œ ์Šคํƒ€ํŒ… ๋ฌธ์ œ ํ•ด๊ฒฐ์„ ์œ„ํ•œ ๊ฐ„์†Œํ™”
9
+ class SimpleModel:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def __init__(self):
11
+ self.initialized = True
12
+ print("๊ฐ„์†Œํ™”๋œ ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์™„๋ฃŒ")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ def process_image(self, image, prompt, strength):
15
+ print(f"์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ์ค‘: {prompt}, ๊ฐ•๋„: {strength}")
16
+ # ์›๋ณธ ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ - ์‹ค์ œ ๋ชจ๋ธ ์—†์ด ๊ฐ„๋‹จํ•œ ์˜ˆ์‹œ ๊ตฌํ˜„
17
+ img_array = np.array(image).astype(np.float32)
18
+ # ๊ฐ„๋‹จํ•œ ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ (์ƒ‰์ƒ ๋ฐ˜์ „)
19
+ modified = 255 - img_array
20
+ return Image.fromarray(modified.astype('uint8'))
21
+
22
+ # ๊ฐ„์†Œํ™”๋œ ๋ชจ๋ธ ์ธ์Šคํ„ด์Šค ์ƒ์„ฑ
23
+ model = SimpleModel()
24
+
25
+ # Spaces GPU ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ
26
+ @spaces.GPU(duration=60)
27
+ def generate_image(prompt, image, strength=0.5):
28
+ if image is None:
29
+ return None, "์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”."
30
 
31
  try:
32
+ # ๊ธฐ๋ณธ ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ
33
+ result = model.process_image(image, prompt, strength)
34
+ return result, f"์ƒ์„ฑ ์™„๋ฃŒ: {prompt}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  except Exception as e:
36
  import traceback
37
+ error_msg = f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}\n{traceback.format_exc()}"
38
  print(error_msg)
39
  return None, error_msg
40
 
 
41
  def create_demo():
42
  with gr.Blocks() as demo:
43
+ gr.Markdown("# ๊ฐ„์†Œํ™”๋œ ์ด๋ฏธ์ง€ ๋ณ€ํ™˜ ๋„๊ตฌ")
44
 
 
 
 
 
 
45
  with gr.Row():
46
  with gr.Column():
47
+ prompt = gr.Textbox(label="๋ณ€ํ™˜ ํ”„๋กฌํ”„ํŠธ", value="artistic portrait")
48
+ image = gr.Image(label="์›๋ณธ ์ด๋ฏธ์ง€", type="pil")
49
+ strength = gr.Slider(0.0, 1.0, 0.5, step=0.1, label="๋ณ€ํ™˜ ๊ฐ•๋„")
50
+ generate_btn = gr.Button("๋ณ€ํ™˜ ์‹œ์ž‘")
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  with gr.Column():
53
+ output_image = gr.Image(label="๋ณ€ํ™˜๋œ ์ด๋ฏธ์ง€")
54
+ output_text = gr.Textbox(label="๊ฒฐ๊ณผ ๋ฉ”์‹œ์ง€")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ # ๋ฒ„ํŠผ ํด๋ฆญ ์ด๋ฒคํŠธ
57
  generate_btn.click(
58
  fn=generate_image,
59
+ inputs=[prompt, image, strength],
60
+ outputs=[output_image, output_text],
 
 
 
61
  )
62
 
63
+ # ์˜ˆ์ œ ์ด๋ฏธ์ง€ (Hugging Face Spaces์˜ ์˜ˆ์ œ ํด๋”์— ์ด๋ฏธ์ง€๊ฐ€ ์žˆ๋‹ค๊ณ  ๊ฐ€์ •)
64
+ example_inputs = [
65
+ ["portrait in the style of van gogh", "examples/face.jpg", 0.7],
66
+ ["cyberpunk character", "examples/face.jpg", 0.9]
67
+ ]
68
+
69
+ gr.Examples(
70
+ examples=example_inputs,
71
+ inputs=[prompt, image, strength]
72
+ )
73
 
74
+ return demo
75
 
76
  if __name__ == "__main__":
77
  import argparse
78
+ parser = argparse.ArgumentParser(description="Simple Image Transformer")
79
+ parser.add_argument("--port", type=int, default=7860)
 
 
 
80
  args = parser.parse_args()
81
 
82
+ print("๊ฐ„์†Œํ™”๋œ ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜ ์‹œ์ž‘ ์ค‘...")
83
 
84
+ # ๋ฐ๋ชจ ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ ๋ฐ ์‹คํ–‰
85
  demo = create_demo()
86
+ demo.launch(debug=True) # ๋””๋ฒ„๊ทธ ๋ชจ๋“œ ํ™œ์„ฑํ™”