WeichenFan commited on
Commit
9bdc30c
·
1 Parent(s): a4f2074

Add application file

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -49,7 +49,7 @@ def load_model(model_name):
49
  if "wan-t2v" in model_name:
50
  vae = AutoencoderKLWan.from_pretrained(model_paths[model_name], subfolder="vae", torch_dtype=torch.float32)
51
  scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=8.0)
52
- current_model = WanPipeline.from_pretrained(model_paths[model_name], vae=vae, torch_dtype=torch.bfloat16).to("cuda")
53
  current_model.scheduler = scheduler
54
  else:
55
  current_model = StableDiffusion3Pipeline.from_pretrained(model_paths[model_name], torch_dtype=torch.bfloat16).to("cuda")
 
49
  if "wan-t2v" in model_name:
50
  vae = AutoencoderKLWan.from_pretrained(model_paths[model_name], subfolder="vae", torch_dtype=torch.float32)
51
  scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=8.0)
52
+ current_model = WanPipeline.from_pretrained(model_paths[model_name], vae=vae, torch_dtype=torch.float16).to("cuda")
53
  current_model.scheduler = scheduler
54
  else:
55
  current_model = StableDiffusion3Pipeline.from_pretrained(model_paths[model_name], torch_dtype=torch.bfloat16).to("cuda")