root committed on
Commit df122aa · 1 Parent(s): de8dacc
Files changed (1)
  1. app.py +84 -82
app.py CHANGED
@@ -13,12 +13,12 @@ from wan_pipeline import WanPipeline
 from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
 from PIL import Image
 from diffusers.utils import export_to_video
-
-
 from huggingface_hub import login
-login(token=os.getenv('HF_TOKEN'))
 
+# Authenticate with HF
+login(token=os.getenv('HF_TOKEN'))
 
+# Set seed
 def set_seed(seed):
     random.seed(seed)
     os.environ['PYTHONHASHSEED'] = str(seed)
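Note: the hunk above simply moves the `login` call below its import and adds a comment. As a minimal, hedged sketch of the same pattern (the token guard is illustrative and not part of this commit), authentication can tolerate a missing secret:

```python
import os
from huggingface_hub import login

# Illustrative variant: read the token from the environment (e.g. an HF Spaces
# secret) and only call login() when it is actually set.
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)
```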
@@ -33,20 +33,17 @@ model_paths = {
     "wan-t2v": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
 }
 
-# Global variable for current model
 current_model = None
-
-# Folder to save video outputs
 OUTPUT_DIR = "generated_videos"
 os.makedirs(OUTPUT_DIR, exist_ok=True)
 
 def load_model(model_name):
     global current_model
     if current_model is not None:
-        del current_model # Delete the old model
-        torch.cuda.empty_cache() # Free GPU memory
-        gc.collect() # Force garbage collection
-
+        del current_model
+        torch.cuda.empty_cache()
+        gc.collect()
+
     if "wan-t2v" in model_name:
         vae = AutoencoderKLWan.from_pretrained(model_paths[model_name], subfolder="vae", torch_dtype=torch.bfloat16)
         scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=8.0)
@@ -54,9 +51,8 @@ def load_model(model_name):
         current_model.scheduler = scheduler
     else:
         current_model = StableDiffusion3Pipeline.from_pretrained(model_paths[model_name], torch_dtype=torch.bfloat16).to("cuda")
-
-    return current_model.to('cuda')
 
+    return current_model.to("cuda")
 
 @spaces.GPU(duration=500)
 def generate_content(prompt, model_name, guidance_scale=7.5, num_inference_steps=50, use_cfg_zero_star=True, use_zero_init=True, zero_steps=0, seed=None, compare_mode=False):
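Together with the previous hunk, `load_model` keeps its free-then-reload behaviour: drop the old pipeline, release GPU memory, then load and return the new pipeline on CUDA. A minimal sketch of that pattern (the helper name is illustrative, not part of app.py):

```python
import gc
import torch

current_model = None  # module-level pipeline handle, as in app.py

def unload_current_model():
    """Illustrative helper: release the previous pipeline before swapping models."""
    global current_model
    if current_model is not None:
        current_model = None        # drop the global reference to the old pipeline
        gc.collect()                # collect the now-unreferenced pipeline object
        torch.cuda.empty_cache()    # return cached CUDA blocks to the allocator
```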
@@ -68,52 +64,26 @@ def generate_content(prompt, model_name, guidance_scale=7.5, num_inference_steps
     is_video_model = "wan-t2v" in model_name
 
     if is_video_model:
-        if True:
-            negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
-            set_seed(seed)
-            video1_frames = model(
-                prompt=prompt,
-                negative_prompt=negative_prompt,
-                height=480,
-                width=832,
-                num_frames=81,
-                num_inference_steps=num_inference_steps,
-                guidance_scale=guidance_scale,
-                use_cfg_zero_star=True,
-                use_zero_init=True,
-                zero_steps=0
-            ).frames[0]
-            video1_path = os.path.join(OUTPUT_DIR, f"{seed}_CFG-Zero-Star.mp4")
-            export_to_video(video1_frames, video1_path, fps=16)
-
-            return None, None, video1_path, seed
-
-
-            # set_seed(seed)
-            # video2_frames = model(
-            #     prompt=prompt,
-            #     guidance_scale=guidance_scale,
-            #     num_frames=81,
-            #     use_cfg_zero_star=False,
-            #     use_zero_init=use_zero_init,
-            #     zero_steps=zero_steps
-            # ).frames[0]
-            # video2_path = os.path.join(OUTPUT_DIR, f"{seed}_CFG.mp4")
-            # export_to_video(video2_frames, video2_path, fps=16)
-
-            # return None, None, video1_path, video2_path, seed
-        # else:
-        #     video_frames = model(
-        #         prompt=prompt,
-        #         guidance_scale=guidance_scale,
-        #         num_frames=81,
-        #         use_cfg_zero_star=use_cfg_zero_star,
-        #         use_zero_init=use_zero_init,
-        #         zero_steps=zero_steps
-        #     ).frames[0]
-        #     video_path = save_video(video_frames, f"{seed}.mp4")
-        #     return None, None, video_path, None, seed
-    print('prompt: ',prompt)
+        negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+        video1_frames = model(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=480,
+            width=832,
+            num_frames=81,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            use_cfg_zero_star=True,
+            use_zero_init=True,
+            zero_steps=0
+        ).frames[0]
+        video1_path = os.path.join(OUTPUT_DIR, f"{seed}_CFG-Zero-Star.mp4")
+        export_to_video(video1_frames, video1_path, fps=16)
+
+        return None, None, video1_path, seed
+
+    print("prompt:", prompt)
+
     if compare_mode:
         set_seed(seed)
         image1 = model(
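The video branch is flattened here: the `if True:` wrapper and the commented-out second (plain-CFG) pass are dropped, leaving only the CFG-Zero* render. A small hedged sketch of how those frames end up on disk (the helper name is illustrative; `export_to_video` is the diffusers utility already imported in app.py):

```python
import os
from diffusers.utils import export_to_video

def save_frames_as_mp4(frames, seed, output_dir="generated_videos"):
    """Illustrative helper: encode a list of frames to an mp4 named after the seed."""
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, f"{seed}_CFG-Zero-Star.mp4")
    export_to_video(frames, path, fps=16)  # 81 frames at 16 fps is roughly 5 s of video
    return path
```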
@@ -134,8 +104,8 @@ def generate_content(prompt, model_name, guidance_scale=7.5, num_inference_steps
             use_zero_init=use_zero_init,
             zero_steps=zero_steps
         ).images[0]
+
         return image1, image2, None, seed
-        #return image1, image2, None, None, seed
     else:
         image = model(
             prompt,
@@ -145,14 +115,11 @@ def generate_content(prompt, model_name, guidance_scale=7.5, num_inference_steps
             use_zero_init=use_zero_init,
             zero_steps=zero_steps
         ).images[0]
+
         if use_cfg_zero_star:
             return image, None, None, seed
         else:
             return None, image, None, seed
-        # if use_cfg_zero_star:
-        #     return image, None, None, None, seed
-        # else:
-        #     return None, image, None, None, seed
 
 # Gradio UI
 with gr.Blocks() as demo:
@@ -166,28 +133,63 @@ with gr.Blocks() as demo:
     </div>
     """)
 
-    gr.Interface(
+    with gr.Row():
+        prompt = gr.Textbox(value="A spooky haunted mansion on a hill silhouetted by a full moon.", label="Enter your prompt")
+        model_choice = gr.Dropdown(choices=list(model_paths.keys()), label="Choose Model")
+
+    with gr.Row():
+        guidance_scale = gr.Slider(1, 20, value=4.0, step=0.5, label="Guidance Scale")
+        inference_steps = gr.Slider(10, 100, value=28, step=5, label="Inference Steps")
+
+    with gr.Row():
+        use_opt_scale = gr.Checkbox(value=True, label="Use Optimized-Scale")
+        use_zero_init = gr.Checkbox(value=True, label="Use Zero Init")
+        zero_steps = gr.Slider(0, 20, value=0, step=1, label="Zero out steps")
+
+    with gr.Row():
+        seed = gr.Number(value=42, label="Seed (Leave blank for random)")
+        compare_mode = gr.Checkbox(value=True, label="Compare Mode")
+
+    with gr.Row():
+        out1 = gr.Image(type="pil", label="CFG-Zero* Image")
+        out2 = gr.Image(type="pil", label="CFG Image")
+        video = gr.Video(label="Video")
+        used_seed = gr.Textbox(label="Used Seed")
+
+    generate_btn = gr.Button("Generate")
+
+    # Change logic for when "wan-t2v" is selected
+    def update_params(model_name):
+        if model_name == "wan-t2v":
+            return (
+                gr.update(value=5),
+                gr.update(value=50),
+                gr.update(value=True),
+                gr.update(value=True),
+                gr.update(value=1)
+            )
+        else:
+            return (
+                gr.update(value=4.0),
+                gr.update(value=28),
+                gr.update(value=True),
+                gr.update(value=True),
+                gr.update(value=0)
+            )
+
+    model_choice.change(
+        fn=update_params,
+        inputs=[model_choice],
+        outputs=[guidance_scale, inference_steps, use_opt_scale, use_zero_init, zero_steps]
+    )
+
+    generate_btn.click(
         fn=generate_content,
         inputs=[
-            gr.Textbox(value="A spooky haunted mansion on a hill silhouetted by a full moon.", label="Enter your prompt"),
-            gr.Dropdown(choices=list(model_paths.keys()), label="Choose Model"),
-            gr.Slider(1, 20, value=4.0, step=0.5, label="Guidance Scale"),
-            gr.Slider(10, 100, value=28, step=5, label="Inference Steps"),
-            gr.Checkbox(value=True, label="Use Optimized-Scale"),
-            gr.Checkbox(value=True, label="Use Zero Init"),
-            gr.Slider(0, 20, value=0, step=1, label="Zero out steps"),
-            gr.Number(value=42, label="Seed (Leave blank for random)"),
-            gr.Checkbox(value=True, label="Compare Mode")
+            prompt, model_choice, guidance_scale, inference_steps,
+            use_opt_scale, use_zero_init, zero_steps, seed, compare_mode
         ],
-        outputs=[
-            gr.Image(type="pil", label="CFG-Zero* Image"),
-            gr.Image(type="pil", label="CFG Image"),
-            gr.Video(label="Video"),
-            gr.Textbox(label="Used Seed")
-        ],
-        #title="CFG-Zero*: Improved Classifier-Free Guidance for Flow Matching Models",
-        live=False # optional
+        outputs=[out1, out2, video, used_seed]
     )
 
 demo.launch(ssr_mode=False)
-
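The UI moves from a single `gr.Interface(...)` call to explicit Blocks components wired to a Generate button, which also makes the `model_choice.change` handler possible: it returns one `gr.update(...)` per output component, so selecting a model swaps slider defaults without a reload. A minimal, self-contained sketch of that pattern (the dropdown choices here are illustrative; app.py uses `list(model_paths.keys())`):

```python
import gradio as gr

with gr.Blocks() as demo:
    model_choice = gr.Dropdown(choices=["wan-t2v", "other-model"], label="Choose Model")
    guidance_scale = gr.Slider(1, 20, value=4.0, step=0.5, label="Guidance Scale")
    inference_steps = gr.Slider(10, 100, value=28, step=5, label="Inference Steps")

    def update_params(model_name):
        # One gr.update(...) per output component; only the value changes here.
        if model_name == "wan-t2v":
            return gr.update(value=5), gr.update(value=50)
        return gr.update(value=4.0), gr.update(value=28)

    model_choice.change(fn=update_params, inputs=[model_choice],
                        outputs=[guidance_scale, inference_steps])

demo.launch()
```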
 