Spaces:
Running
on
Zero
Running
on
Zero
import torch, types | |
import numpy as np | |
from PIL import Image | |
from einops import repeat | |
from typing import Optional, Union | |
from einops import rearrange | |
import numpy as np | |
from tqdm import tqdm | |
from typing import Optional | |
from typing_extensions import Literal | |
import imageio | |
import os | |
from typing import List | |
import cv2 | |
from utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner | |
from models import ModelManager, load_state_dict | |
from models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d | |
from models.wan_video_text_encoder import ( | |
WanTextEncoder, | |
T5RelativeEmbedding, | |
T5LayerNorm, | |
) | |
from models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample | |
from models.wan_video_image_encoder import WanImageEncoder | |
from models.wan_video_vace import VaceWanModel | |
from models.wan_video_motion_controller import WanMotionControllerModel | |
from schedulers.flow_match import FlowMatchScheduler | |
from prompters import WanPrompter | |
from vram_management import ( | |
enable_vram_management, | |
AutoWrappedModule, | |
AutoWrappedLinear, | |
WanAutoCastLayerNorm, | |
) | |
from lora import GeneralLoRALoader | |
def load_video_as_list(video_path: str) -> List[Image.Image]:
    """Decode every frame of a video file into a list of PIL images.

    Args:
        video_path: Path to a video file that ``imageio.get_reader`` can open.

    Returns:
        The decoded frames, in file order, as ``PIL.Image.Image`` objects.

    Raises:
        FileNotFoundError: If ``video_path`` is not an existing regular file.
    """
    if not os.path.isfile(video_path):
        raise FileNotFoundError(video_path)
    reader = imageio.get_reader(video_path)
    try:
        return [Image.fromarray(frame_data) for frame_data in reader]
    finally:
        # Release the underlying decoder even if a frame fails to decode.
        reader.close()
class WanVideoPipeline_FaceSwap(BasePipeline):
    """Wan video diffusion pipeline with identity-image and face-mask conditioning.

    Model slots (``text_encoder``, ``image_encoder``, ``dit``, ``dit2``,
    ``vae``, ``motion_controller``, ``vace``) start out as ``None`` and are
    filled in by :meth:`from_pretrained`. Before denoising, :meth:`__call__`
    runs every ``PipelineUnit`` in ``self.units`` to turn user arguments into
    model inputs.
    """

    def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
        super().__init__(
            device=device,
            torch_dtype=torch_dtype,
            height_division_factor=16,
            width_division_factor=16,
            time_division_factor=4,
            time_division_remainder=1,
        )
        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
        self.text_encoder: WanTextEncoder = None
        self.image_encoder: WanImageEncoder = None
        self.dit: WanModel = None
        # Optional second DiT, used for timesteps below `switch_DiT_boundary` in __call__.
        self.dit2: WanModel = None
        self.vae: WanVideoVAE = None
        self.motion_controller: WanMotionControllerModel = None
        self.vace: VaceWanModel = None
        # Models kept on the compute device during the denoising loop; the "_2"
        # tuple is loaded once the loop switches over to `dit2`.
        self.in_iteration_models = ("dit", "motion_controller", "vace")
        self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
        self.unit_runner = PipelineUnitRunner()
        # Input-preparation units, executed in this order by __call__.
        self.units = [
            WanVideoUnit_ShapeChecker(),
            WanVideoUnit_NoiseInitializer(),
            WanVideoUnit_InputVideoEmbedder(),
            WanVideoUnit_PromptEmbedder(),
            WanVideoUnit_ImageEmbedderVAE(),
            WanVideoUnit_ImageEmbedderCLIP(),
            WanVideoUnit_ImageEmbedderFused(),
            WanVideoUnit_FunControl(),
            WanVideoUnit_FunReference(),
            WanVideoUnit_FunCameraControl(),
            WanVideoUnit_SpeedControl(),
            WanVideoUnit_VACE(),
            WanVideoUnit_UnifiedSequenceParallel(),
            WanVideoUnit_TeaCache(),
            WanVideoUnit_CfgMerger(),
        ]
        self.model_fn = model_fn_wan_video

    def encode_ip_image(self, ip_image):
        """Encode the identity (IP) image into VAE latents.

        Args:
            ip_image: Image that ``np.array`` converts to an (H, W, 3) array.
                Values are divided by 255, so 8-bit RGB input is assumed
                (TODO confirm callers never pass RGBA or float images).

        Returns:
            ``self.vae.encode`` applied (untiled) to a single-frame clip of
            shape [1, 3, 1, H, W] mapped from [0, 1] to [-1, 1].
        """
        self.load_models_to_device(["vae"])
        pixels = torch.tensor(np.array(ip_image)).permute(2, 0, 1).float() / 255.0  # [3, H, W]
        pixels = pixels.unsqueeze(1).unsqueeze(0).to(dtype=self.torch_dtype)  # [1, 3, 1, H, W]
        pixels = pixels * 2 - 1  # [0, 1] -> [-1, 1]
        return self.vae.encode(pixels, device=self.device, tiled=False)

    def load_lora(self, module, path, alpha=1):
        """Merge the LoRA weights stored at ``path`` into ``module``.

        Args:
            module: Model to patch (e.g. ``self.dit``).
            path: File read with ``load_state_dict`` in the pipeline dtype/device.
            alpha: Scale forwarded to ``GeneralLoRALoader.load``.
        """
        loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
        lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
        loader.load(module, lora, alpha=alpha)

    def training_loss(self, **inputs):
        """Compute the weighted flow-matching MSE loss for one random timestep.

        ``inputs`` must contain ``input_latents`` and ``noise`` plus whatever
        ``self.model_fn`` consumes. The optional ``min_timestep_boundary`` and
        ``max_timestep_boundary`` (fractions of ``num_train_timesteps``,
        default 0 and 1) restrict which timestep indices are sampled.
        ``inputs["latents"]`` is overwritten with the noised latents.
        """
        num_train = self.scheduler.num_train_timesteps
        max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * num_train)
        min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * num_train)
        timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
        timestep = self.scheduler.timesteps[timestep_id].to(
            dtype=self.torch_dtype, device=self.device
        )
        inputs["latents"] = self.scheduler.add_noise(
            inputs["input_latents"], inputs["noise"], timestep
        )
        training_target = self.scheduler.training_target(
            inputs["input_latents"], inputs["noise"], timestep
        )
        noise_pred = self.model_fn(**inputs, timestep=timestep)
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        return loss * self.scheduler.training_weight(timestep)

    def _enable_dit_vram_management(
        self, dit, onload_device, num_persistent_param_in_dit, vram_limit
    ):
        """Wrap one DiT (``dit`` or ``dit2``); both share identical settings."""
        dtype = next(iter(dit.parameters())).dtype
        enable_vram_management(
            dit,
            module_map={
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv3d: AutoWrappedModule,
                torch.nn.LayerNorm: WanAutoCastLayerNorm,
                RMSNorm: AutoWrappedModule,
                torch.nn.Conv2d: AutoWrappedModule,
            },
            module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=onload_device,
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
            max_num_param=num_persistent_param_in_dit,
            overflow_module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
            vram_limit=vram_limit,
        )

    def enable_vram_management(
        self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5
    ):
        """Wrap each loaded model's layers for on-demand CPU<->device offloading.

        Args:
            num_persistent_param_in_dit: Forwarded as ``max_num_param`` for the
                DiT(s). When given, ``vram_limit`` is ignored (forced to None).
            vram_limit: Budget forwarded to ``enable_vram_management``. If None
                (and ``num_persistent_param_in_dit`` is None) it is taken from
                ``self.get_vram()``; ``vram_buffer`` is then subtracted.
            vram_buffer: Headroom subtracted from ``vram_limit`` (same units as
                ``self.get_vram()``).
        """
        self.vram_management_enabled = True
        if num_persistent_param_in_dit is not None:
            vram_limit = None
        else:
            if vram_limit is None:
                vram_limit = self.get_vram()
            vram_limit = vram_limit - vram_buffer
        # DiT/VACE weights stay on CPU when a VRAM budget applies, otherwise on the compute device.
        large_model_onload_device = "cpu" if vram_limit is not None else self.device

        if self.text_encoder is not None:
            dtype = next(iter(self.text_encoder.parameters())).dtype
            enable_vram_management(
                self.text_encoder,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Embedding: AutoWrappedModule,
                    T5RelativeEmbedding: AutoWrappedModule,
                    T5LayerNorm: AutoWrappedModule,
                },
                module_config=dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
                vram_limit=vram_limit,
            )
        for dit in (self.dit, self.dit2):
            if dit is not None:
                self._enable_dit_vram_management(
                    dit, large_model_onload_device, num_persistent_param_in_dit, vram_limit
                )
        if self.vae is not None:
            dtype = next(iter(self.vae.parameters())).dtype
            enable_vram_management(
                self.vae,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv2d: AutoWrappedModule,
                    RMS_norm: AutoWrappedModule,
                    CausalConv3d: AutoWrappedModule,
                    Upsample: AutoWrappedModule,
                    torch.nn.SiLU: AutoWrappedModule,
                    torch.nn.Dropout: AutoWrappedModule,
                },
                module_config=dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device=self.device,
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
            )
        if self.image_encoder is not None:
            dtype = next(iter(self.image_encoder.parameters())).dtype
            enable_vram_management(
                self.image_encoder,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv2d: AutoWrappedModule,
                    torch.nn.LayerNorm: AutoWrappedModule,
                },
                module_config=dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    computation_dtype=dtype,
                    computation_device=self.device,
                ),
            )
        if self.motion_controller is not None:
            dtype = next(iter(self.motion_controller.parameters())).dtype
            enable_vram_management(
                self.motion_controller,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                },
                module_config=dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    computation_dtype=dtype,
                    computation_device=self.device,
                ),
            )
        if self.vace is not None:
            # Use VACE's own parameter dtype. Previously this reused whatever `dtype`
            # the last wrapped model left behind, or raised NameError if none was wrapped.
            dtype = next(iter(self.vace.parameters())).dtype
            enable_vram_management(
                self.vace,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv3d: AutoWrappedModule,
                    torch.nn.LayerNorm: AutoWrappedModule,
                    RMSNorm: AutoWrappedModule,
                },
                module_config=dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device=large_model_onload_device,
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
                vram_limit=vram_limit,
            )

    def initialize_usp(self):
        """Start torch.distributed (NCCL, ``env://``) and xfuser sequence parallelism.

        The sequence-parallel degree and the Ulysses degree are both set to the
        world size (ring degree 1). The CUDA device is then selected by rank.
        """
        import torch.distributed as dist
        from xfuser.core.distributed import (
            initialize_model_parallel,
            init_distributed_environment,
        )

        dist.init_process_group(backend="nccl", init_method="env://")
        init_distributed_environment(
            rank=dist.get_rank(), world_size=dist.get_world_size()
        )
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=1,
            ulysses_degree=dist.get_world_size(),
        )
        torch.cuda.set_device(dist.get_rank())

    def enable_usp(self):
        """Patch the DiT(s) to use xDiT's unified-sequence-parallel forwards.

        Each block's ``self_attn.forward`` and the model's ``forward`` are rebound
        to ``usp_attn_forward`` and ``usp_dit_forward``. Afterwards ``sp_size`` and
        ``use_unified_sequence_parallel`` are recorded on the pipeline.
        """
        from xfuser.core.distributed import get_sequence_parallel_world_size
        from distributed.xdit_context_parallel import (
            usp_attn_forward,
            usp_dit_forward,
        )

        for dit in (self.dit, self.dit2):
            if dit is None:
                continue
            for block in dit.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn
                )
            dit.forward = types.MethodType(usp_dit_forward, dit)
        self.sp_size = get_sequence_parallel_world_size()
        self.use_unified_sequence_parallel = True

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(
            model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"
        ),
        redirect_common_files: bool = True,
        use_usp=False,
    ):
        """Download (if needed) and load models, then return a ready pipeline.

        Args:
            torch_dtype: Default dtype for model loading and computation.
            device: Compute device; also the load device unless a config sets
                ``offload_device``.
            model_configs: Model files to load. The list itself is not mutated,
                but each config's ``model_id`` may be rewritten by redirection.
            tokenizer_config: Where to fetch the text tokenizer from.
            redirect_common_files: Point the shared T5/VAE/CLIP checkpoints at
                one canonical repo to avoid repeated downloads.
            use_usp: Initialise and enable unified sequence parallelism.

        Returns:
            A configured ``WanVideoPipeline_FaceSwap``.
        """
        # Redirect model path
        if redirect_common_files:
            redirect_dict = {
                "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
                "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
                "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
            }
            for model_config in model_configs:
                if (
                    model_config.origin_file_pattern is None
                    or model_config.model_id is None
                ):
                    continue
                if (
                    model_config.origin_file_pattern in redirect_dict
                    and model_config.model_id
                    != redirect_dict[model_config.origin_file_pattern]
                ):
                    print(
                        f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection."
                    )
                    model_config.model_id = redirect_dict[
                        model_config.origin_file_pattern
                    ]
        # Initialize pipeline
        pipe = WanVideoPipeline_FaceSwap(device=device, torch_dtype=torch_dtype)
        if use_usp:
            pipe.initialize_usp()
        # Download and load models
        model_manager = ModelManager()
        for model_config in model_configs:
            model_config.download_if_necessary(use_usp=use_usp)
            model_manager.load_model(
                model_config.path,
                device=model_config.offload_device or device,
                torch_dtype=model_config.offload_dtype or torch_dtype,
            )
        # Load models
        pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
        # Up to two DiTs may be present (e.g. high-noise + low-noise experts).
        dit = model_manager.fetch_model("wan_video_dit", index=2)
        if isinstance(dit, list):
            pipe.dit, pipe.dit2 = dit
        else:
            pipe.dit = dit
        pipe.vae = model_manager.fetch_model("wan_video_vae")
        pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
        pipe.motion_controller = model_manager.fetch_model(
            "wan_video_motion_controller"
        )
        pipe.vace = model_manager.fetch_model("wan_video_vace")
        # Size division factor
        if pipe.vae is not None:
            pipe.height_division_factor = pipe.vae.upsampling_factor * 2
            pipe.width_division_factor = pipe.vae.upsampling_factor * 2
        # Initialize tokenizer
        tokenizer_config.download_if_necessary(use_usp=use_usp)
        pipe.prompter.fetch_models(pipe.text_encoder)
        pipe.prompter.fetch_tokenizer(tokenizer_config.path)
        # Unified Sequence Parallel
        if use_usp:
            pipe.enable_usp()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: Optional[str] = "",
        # Image-to-video
        input_image: Optional[Image.Image] = None,
        # First-last-frame-to-video
        end_image: Optional[Image.Image] = None,
        # Video-to-video
        input_video: Optional[list[Image.Image]] = None,
        denoising_strength: Optional[float] = 1,
        # ControlNet
        control_video: Optional[list[Image.Image]] = None,
        reference_image: Optional[Image.Image] = None,
        # Camera control
        camera_control_direction: Optional[
            Literal[
                "Left",
                "Right",
                "Up",
                "Down",
                "LeftUp",
                "LeftDown",
                "RightUp",
                "RightDown",
            ]
        ] = None,
        camera_control_speed: Optional[float] = 1 / 54,
        camera_control_origin: Optional[tuple] = (
            0,
            0.532139961,
            0.946026558,
            0.5,
            0.5,
            0,
            0,
            1,
            0,
            0,
            0,
            0,
            1,
            0,
            0,
            0,
            0,
            1,
            0,
        ),
        # VACE
        vace_video: Optional[list[Image.Image]] = None,
        vace_video_mask: Optional[Image.Image] = None,
        vace_reference_image: Optional[Image.Image] = None,
        vace_scale: Optional[float] = 1.0,
        # Randomness
        seed: Optional[int] = None,
        rand_device: Optional[str] = "cpu",
        # Shape
        height: Optional[int] = 480,
        width: Optional[int] = 832,
        num_frames=81,
        # Classifier-free guidance
        cfg_scale: Optional[float] = 5.0,
        cfg_merge: Optional[bool] = False,
        # Boundary
        switch_DiT_boundary: Optional[float] = 0.875,
        # Scheduler
        num_inference_steps: Optional[int] = 50,
        sigma_shift: Optional[float] = 5.0,
        # Speed control
        motion_bucket_id: Optional[int] = None,
        # VAE tiling
        tiled: Optional[bool] = True,
        tile_size: Optional[tuple[int, int]] = (30, 52),
        tile_stride: Optional[tuple[int, int]] = (15, 26),
        # Sliding window
        sliding_window_size: Optional[int] = None,
        sliding_window_stride: Optional[int] = None,
        # Teacache
        tea_cache_l1_thresh: Optional[float] = None,
        tea_cache_model_id: Optional[str] = "",
        # progress_bar
        progress_bar_cmd=tqdm,
        # Stand-In
        face_mask=None,
        ip_image=None,
        force_background_consistency=False,
    ):
        """Generate a video. Runs without autograd.

        Most arguments are forwarded to the preparation units in ``self.units``.
        Pipeline-specific arguments:

        Args:
            switch_DiT_boundary: Once the current timestep is below
                ``switch_DiT_boundary * num_train_timesteps``, denoising switches
                to ``self.dit2`` (if loaded).
            face_mask: Frames preprocessed with ``self.preprocess_video``. Only
                the first channel is used, resized to the latent grid with
                nearest-exact interpolation.
            ip_image: Identity image, encoded with :meth:`encode_ip_image`. It
                is passed to the model only on the first positive forward
                pass; afterwards ``inputs_shared["ip_image"]`` is set to None.
            force_background_consistency: After every scheduler step, latents
                where the resized mask is <= 0.5 are replaced by the input-video
                latents re-noised to the next timestep (un-noised after the last
                step). Skipped when there is no ``face_mask`` or ``input_video``.

        Returns:
            The decoded video, as returned by ``self.vae_output_to_video``.
        """
        if ip_image is not None:
            ip_image = self.encode_ip_image(ip_image)
        # Scheduler
        self.scheduler.set_timesteps(
            num_inference_steps,
            denoising_strength=denoising_strength,
            shift=sigma_shift,
        )
        # Inputs
        inputs_posi = {
            "prompt": prompt,
            "tea_cache_l1_thresh": tea_cache_l1_thresh,
            "tea_cache_model_id": tea_cache_model_id,
            "num_inference_steps": num_inference_steps,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
            "tea_cache_l1_thresh": tea_cache_l1_thresh,
            "tea_cache_model_id": tea_cache_model_id,
            "num_inference_steps": num_inference_steps,
        }
        inputs_shared = {
            "input_image": input_image,
            "end_image": end_image,
            "input_video": input_video,
            "denoising_strength": denoising_strength,
            "control_video": control_video,
            "reference_image": reference_image,
            "camera_control_direction": camera_control_direction,
            "camera_control_speed": camera_control_speed,
            "camera_control_origin": camera_control_origin,
            "vace_video": vace_video,
            "vace_video_mask": vace_video_mask,
            "vace_reference_image": vace_reference_image,
            "vace_scale": vace_scale,
            "seed": seed,
            "rand_device": rand_device,
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "cfg_scale": cfg_scale,
            "cfg_merge": cfg_merge,
            "sigma_shift": sigma_shift,
            "motion_bucket_id": motion_bucket_id,
            "tiled": tiled,
            "tile_size": tile_size,
            "tile_stride": tile_stride,
            "sliding_window_size": sliding_window_size,
            "sliding_window_stride": sliding_window_stride,
            "ip_image": ip_image,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
                unit, self, inputs_shared, inputs_posi, inputs_nega
            )
        # Face mask resized to the latent (T, H, W) grid. It stays None when no
        # mask is given, so the background-consistency step can skip cleanly.
        latent_mask = None
        if face_mask is not None:
            mask_processed = self.preprocess_video(face_mask)[:, 0:1, ...]
            latent_mask = torch.nn.functional.interpolate(
                mask_processed,
                size=inputs_shared["latents"].shape[2:],
                mode="nearest-exact",
            )
        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(
            progress_bar_cmd(self.scheduler.timesteps)
        ):
            # Switch DiT if necessary
            if (
                timestep.item()
                < switch_DiT_boundary * self.scheduler.num_train_timesteps
                and self.dit2 is not None
                and models["dit"] is not self.dit2
            ):
                self.load_models_to_device(self.in_iteration_models_2)
                models["dit"] = self.dit2
            # Timestep
            timestep = timestep.unsqueeze(0).to(
                dtype=self.torch_dtype, device=self.device
            )
            # Inference
            noise_pred_posi = self.model_fn(
                **models, **inputs_shared, **inputs_posi, timestep=timestep
            )
            # The identity image is only fed on the very first positive pass.
            inputs_shared["ip_image"] = None
            if cfg_scale != 1.0:
                if cfg_merge:
                    noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
                else:
                    noise_pred_nega = self.model_fn(
                        **models, **inputs_shared, **inputs_nega, timestep=timestep
                    )
                noise_pred = noise_pred_nega + cfg_scale * (
                    noise_pred_posi - noise_pred_nega
                )
            else:
                noise_pred = noise_pred_posi
            # Scheduler
            inputs_shared["latents"] = self.scheduler.step(
                noise_pred,
                self.scheduler.timesteps[progress_id],
                inputs_shared["latents"],
            )
            if force_background_consistency:
                # `input_latents` only exists when an input video was encoded.
                input_latents = inputs_shared.get("input_latents")
                if input_latents is not None and latent_mask is not None:
                    if progress_id == len(self.scheduler.timesteps) - 1:
                        noised_original_latents = input_latents
                    else:
                        next_timestep = self.scheduler.timesteps[progress_id + 1]
                        noised_original_latents = self.scheduler.add_noise(
                            input_latents,
                            inputs_shared["noise"],
                            timestep=next_timestep,
                        )
                    hard_mask = (latent_mask > 0.5).to(
                        dtype=inputs_shared["latents"].dtype
                    )
                    inputs_shared["latents"] = (
                        1 - hard_mask
                    ) * noised_original_latents + hard_mask * inputs_shared["latents"]
            if "first_frame_latents" in inputs_shared:
                inputs_shared["latents"][:, :, 0:1] = inputs_shared[
                    "first_frame_latents"
                ]
        if vace_reference_image is not None:
            # Drop the prepended reference-image latent frame before decoding.
            inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
        # Decode
        self.load_models_to_device(["vae"])
        video = self.vae.decode(
            inputs_shared["latents"],
            device=self.device,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        )
        video = self.vae_output_to_video(video)
        self.load_models_to_device([])
        return video
class WanVideoUnit_ShapeChecker(PipelineUnit):
    """Snap the requested height, width and frame count to accepted sizes."""

    def __init__(self):
        super().__init__(input_params=("height", "width", "num_frames"))

    def process(self, pipe: WanVideoPipeline_FaceSwap, height, width, num_frames):
        """Return the sizes as adjusted by ``pipe.check_resize_height_width``."""
        checked = pipe.check_resize_height_width(height, width, num_frames)
        new_height, new_width, new_num_frames = checked
        return dict(height=new_height, width=new_width, num_frames=new_num_frames)
class WanVideoUnit_NoiseInitializer(PipelineUnit):
    """Draw the initial noise latent for the whole clip."""

    def __init__(self):
        wanted = (
            "height",
            "width",
            "num_frames",
            "seed",
            "rand_device",
            "vace_reference_image",
        )
        super().__init__(input_params=wanted)

    def process(
        self,
        pipe: WanVideoPipeline_FaceSwap,
        height,
        width,
        num_frames,
        seed,
        rand_device,
        vace_reference_image,
    ):
        """Return ``{"noise": tensor}`` shaped (1, z_dim, T_latent, H/f, W/f).

        ``T_latent`` is ``(num_frames - 1) // 4 + 1``, plus one extra slot when a
        VACE reference image is supplied.
        """
        has_reference = vace_reference_image is not None
        latent_frames = (num_frames - 1) // 4 + 1 + (1 if has_reference else 0)
        factor = pipe.vae.upsampling_factor
        noise_shape = (
            1,
            pipe.vae.model.z_dim,
            latent_frames,
            height // factor,
            width // factor,
        )
        noise = pipe.generate_noise(noise_shape, seed=seed, rand_device=rand_device)
        if not has_reference:
            return {"noise": noise}
        # Rotate the extra (last) frame slot to the front, where the reference
        # latent is prepended by WanVideoUnit_InputVideoEmbedder.
        rotated = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
        return {"noise": rotated}
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
    """VAE-encode the input video (video-to-video) and derive starting latents."""

    def __init__(self):
        super().__init__(
            input_params=(
                "input_video",
                "noise",
                "tiled",
                "tile_size",
                "tile_stride",
                "vace_reference_image",
            ),
            onload_model_names=("vae",),
        )

    def process(
        self,
        pipe: WanVideoPipeline_FaceSwap,
        input_video,
        noise,
        tiled,
        tile_size,
        tile_stride,
        vace_reference_image,
    ):
        """Return ``latents`` and, when a video is given, ``input_latents``.

        Without an input video the latents are the raw noise. In training mode
        the noise is returned untouched; otherwise the encoded video is noised to
        the first scheduled timestep.
        """
        if input_video is None:
            return {"latents": noise}
        pipe.load_models_to_device(["vae"])
        frames = pipe.preprocess_video(input_video)
        encoded = pipe.vae.encode(
            frames,
            device=pipe.device,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        )
        input_latents = encoded.to(dtype=pipe.torch_dtype, device=pipe.device)
        if vace_reference_image is not None:
            ref_frames = pipe.preprocess_video([vace_reference_image])
            ref_latents = pipe.vae.encode(ref_frames, device=pipe.device)
            ref_latents = ref_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
            # Reference latent goes first along the frame axis.
            input_latents = torch.concat([ref_latents, input_latents], dim=2)
        if pipe.scheduler.training:
            start_latents = noise
        else:
            start_latents = pipe.scheduler.add_noise(
                input_latents, noise, timestep=pipe.scheduler.timesteps[0]
            )
        return {"latents": start_latents, "input_latents": input_latents}
class WanVideoUnit_PromptEmbedder(PipelineUnit):
    """Encode the positive and negative prompts into text-encoder context."""

    def __init__(self):
        posi_params = {"prompt": "prompt", "positive": "positive"}
        nega_params = {"prompt": "negative_prompt", "positive": "positive"}
        super().__init__(
            # Keyword spelling kept as-is; it must match PipelineUnit's signature.
            seperate_cfg=True,
            input_params_posi=posi_params,
            input_params_nega=nega_params,
            onload_model_names=("text_encoder",),
        )

    def process(self, pipe: WanVideoPipeline_FaceSwap, prompt, positive) -> dict:
        """Return ``{"context": ...}`` from ``pipe.prompter.encode_prompt``."""
        pipe.load_models_to_device(self.onload_model_names)
        context = pipe.prompter.encode_prompt(
            prompt, positive=positive, device=pipe.device
        )
        return {"context": context}
class WanVideoUnit_ImageEmbedder(PipelineUnit):
    """Deprecated combined CLIP + VAE image conditioning.

    Produces the same outputs as WanVideoUnit_ImageEmbedderCLIP and
    WanVideoUnit_ImageEmbedderVAE together.
    """

    def __init__(self):
        super().__init__(
            input_params=(
                "input_image",
                "end_image",
                "num_frames",
                "height",
                "width",
                "tiled",
                "tile_size",
                "tile_stride",
            ),
            onload_model_names=("image_encoder", "vae"),
        )

    def process(
        self,
        pipe: WanVideoPipeline_FaceSwap,
        input_image,
        end_image,
        num_frames,
        height,
        width,
        tiled,
        tile_size,
        tile_stride,
    ):
        """Return ``clip_feature`` and ``y`` (frame mask + VAE latents)."""
        if input_image is None or pipe.image_encoder is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        mask_h, mask_w = height // 8, width // 8
        first = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
        clip_context = pipe.image_encoder.encode_image([first])
        frame_mask = torch.ones(1, num_frames, mask_h, mask_w, device=pipe.device)
        frame_mask[:, 1:] = 0
        pieces = [first.transpose(0, 1)]
        if end_image is None:
            pieces.append(torch.zeros(3, num_frames - 1, height, width).to(first.device))
        else:
            last = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
            pieces.append(torch.zeros(3, num_frames - 2, height, width).to(first.device))
            pieces.append(last.transpose(0, 1))
            if pipe.dit.has_image_pos_emb:
                last_context = pipe.image_encoder.encode_image([last])
                clip_context = torch.concat([clip_context, last_context], dim=1)
            frame_mask[:, -1:] = 1
        vae_input = torch.concat(pieces, dim=1)
        # Repeat the first frame's flag 4x so the frame axis splits evenly into
        # groups of 4, then fold it to (4, latent_frames, h, w).
        frame_mask = torch.concat(
            [torch.repeat_interleave(frame_mask[:, 0:1], repeats=4, dim=1), frame_mask[:, 1:]],
            dim=1,
        )
        frame_mask = frame_mask.view(
            1, frame_mask.shape[1] // 4, 4, mask_h, mask_w
        ).transpose(1, 2)[0]
        encoded = pipe.vae.encode(
            [vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)],
            device=pipe.device,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        )[0]
        encoded = encoded.to(dtype=pipe.torch_dtype, device=pipe.device)
        y = torch.concat([frame_mask, encoded]).unsqueeze(0)
        return {
            "clip_feature": clip_context.to(dtype=pipe.torch_dtype, device=pipe.device),
            "y": y.to(dtype=pipe.torch_dtype, device=pipe.device),
        }
class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):
    """Compute CLIP image features for the first (and optionally last) frame."""

    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "height", "width"),
            onload_model_names=("image_encoder",),
        )

    def process(
        self, pipe: WanVideoPipeline_FaceSwap, input_image, end_image, height, width
    ):
        """Return ``{"clip_feature": ...}`` or ``{}`` when CLIP is not needed."""
        if (
            input_image is None
            or pipe.image_encoder is None
            or not pipe.dit.require_clip_embedding
        ):
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        size = (width, height)
        first = pipe.preprocess_image(input_image.resize(size)).to(pipe.device)
        clip_context = pipe.image_encoder.encode_image([first])
        if end_image is not None:
            last = pipe.preprocess_image(end_image.resize(size)).to(pipe.device)
            if pipe.dit.has_image_pos_emb:
                # Append the last frame's features along the token axis.
                last_context = pipe.image_encoder.encode_image([last])
                clip_context = torch.concat([clip_context, last_context], dim=1)
        return {
            "clip_feature": clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
        }
class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):
    """Build ``y``: a frame mask concatenated with VAE latents of the key frames."""

    def __init__(self):
        super().__init__(
            input_params=(
                "input_image",
                "end_image",
                "num_frames",
                "height",
                "width",
                "tiled",
                "tile_size",
                "tile_stride",
            ),
            onload_model_names=("vae",),
        )

    def process(
        self,
        pipe: WanVideoPipeline_FaceSwap,
        input_image,
        end_image,
        num_frames,
        height,
        width,
        tiled,
        tile_size,
        tile_stride,
    ):
        """Return ``{"y": ...}`` or ``{}`` when the DiT takes no VAE embedding.

        The pixel clip fed to the VAE has the first image, zeros for the middle
        frames, and, if ``end_image`` is given, that image as the last frame.
        """
        if input_image is None or not pipe.dit.require_vae_embedding:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        mask_h, mask_w = height // 8, width // 8
        first = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
        frame_mask = torch.ones(1, num_frames, mask_h, mask_w, device=pipe.device)
        frame_mask[:, 1:] = 0
        pieces = [first.transpose(0, 1)]
        if end_image is None:
            pieces.append(torch.zeros(3, num_frames - 1, height, width).to(first.device))
        else:
            last = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
            pieces.append(torch.zeros(3, num_frames - 2, height, width).to(first.device))
            pieces.append(last.transpose(0, 1))
            frame_mask[:, -1:] = 1
        vae_input = torch.concat(pieces, dim=1)
        # Repeat the first frame's flag 4x so the frame axis splits evenly into
        # groups of 4, then fold it to (4, latent_frames, h, w).
        frame_mask = torch.concat(
            [torch.repeat_interleave(frame_mask[:, 0:1], repeats=4, dim=1), frame_mask[:, 1:]],
            dim=1,
        )
        frame_mask = frame_mask.view(
            1, frame_mask.shape[1] // 4, 4, mask_h, mask_w
        ).transpose(1, 2)[0]
        encoded = pipe.vae.encode(
            [vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)],
            device=pipe.device,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        )[0]
        encoded = encoded.to(dtype=pipe.torch_dtype, device=pipe.device)
        y = torch.concat([frame_mask, encoded]).unsqueeze(0)
        return {"y": y.to(dtype=pipe.torch_dtype, device=pipe.device)}
class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
    """Write the VAE latent of the input image into the first latent frame.

    Intended for Wan-AI/Wan2.2-TI2V-5B, whose DiT sets
    ``fuse_vae_embedding_in_latents``.
    """

    def __init__(self):
        super().__init__(
            input_params=(
                "input_image",
                "latents",
                "height",
                "width",
                "tiled",
                "tile_size",
                "tile_stride",
            ),
            onload_model_names=("vae",),
        )

    def process(
        self,
        pipe: WanVideoPipeline_FaceSwap,
        input_image,
        latents,
        height,
        width,
        tiled,
        tile_size,
        tile_stride,
    ):
        """Overwrite ``latents[:, :, 0:1]`` in place and report the fused frame."""
        if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        resized = input_image.resize((width, height))
        pixels = pipe.preprocess_image(resized).transpose(0, 1)
        first_frame_latents = pipe.vae.encode(
            [pixels],
            device=pipe.device,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        )
        latents[:, :, 0:1] = first_frame_latents
        result = {"latents": latents, "fuse_vae_embedding_in_latents": True}
        result["first_frame_latents"] = first_frame_latents
        return result
class WanVideoUnit_FunControl(PipelineUnit):
    """Prepend VAE latents of a control video to the ``y`` conditioning."""

    def __init__(self):
        super().__init__(
            input_params=(
                "control_video",
                "num_frames",
                "height",
                "width",
                "tiled",
                "tile_size",
                "tile_stride",
                "clip_feature",
                "y",
            ),
            onload_model_names=("vae",),
        )

    def process(
        self,
        pipe: WanVideoPipeline_FaceSwap,
        control_video,
        num_frames,
        height,
        width,
        tiled,
        tile_size,
        tile_stride,
        clip_feature,
        y,
    ):
        """Return ``clip_feature`` and ``y`` with control latents on the channel axis.

        When no image conditioning exists yet, zero placeholders are used for
        both the CLIP features and the image latents.
        """
        if control_video is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        frames = pipe.preprocess_video(control_video)
        control_latents = pipe.vae.encode(
            frames,
            device=pipe.device,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        ).to(dtype=pipe.torch_dtype, device=pipe.device)
        if clip_feature is not None and y is not None:
            # Keep only the trailing 16 channels (the VAE latents, assuming a
            # 16-channel VAE; the leading channels hold the frame mask).
            y = y[:, -16:]
        else:
            clip_feature = torch.zeros(
                (1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device
            )
            y = torch.zeros(
                (1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8),
                dtype=pipe.torch_dtype,
                device=pipe.device,
            )
        return {"clip_feature": clip_feature, "y": torch.concat([control_latents, y], dim=1)}
class WanVideoUnit_FunReference(PipelineUnit):
    """Encode a Fun-model reference image into VAE latents and CLIP features."""

    def __init__(self):
        super().__init__(
            # "reference_image" was previously listed twice; once is sufficient.
            input_params=("reference_image", "height", "width"),
            # The CLIP image encoder is used below too, so onload it with the VAE.
            onload_model_names=("vae", "image_encoder"),
        )

    def process(self, pipe: WanVideoPipeline_FaceSwap, reference_image, height, width):
        """Return ``reference_latents`` and ``clip_feature`` for the reference image.

        The image is resized to (width, height) first. ``clip_feature`` is cast
        to the pipeline dtype/device, matching WanVideoUnit_ImageEmbedderCLIP.
        """
        if reference_image is None:
            return {}
        # Loading only "vae" could leave the image encoder off the compute device
        # under VRAM management while it is still used below.
        pipe.load_models_to_device(self.onload_model_names)
        reference_image = reference_image.resize((width, height))
        reference_latents = pipe.preprocess_video([reference_image])
        reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
        clip_input = pipe.preprocess_image(reference_image).to(pipe.device)
        clip_feature = pipe.image_encoder.encode_image([clip_input])
        clip_feature = clip_feature.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"reference_latents": reference_latents, "clip_feature": clip_feature}
class WanVideoUnit_FunCameraControl(PipelineUnit):
    """Build camera-control latents plus a first-frame VAE condition.

    Produces ``control_camera_latents_input`` from the camera trajectory and a
    ``y`` tensor shaped like ``latents`` that holds the VAE latents of
    ``input_image`` in its first latent frame and zeros elsewhere.
    """

    def __init__(self):
        super().__init__(
            input_params=(
                "height",
                "width",
                "num_frames",
                "camera_control_direction",
                "camera_control_speed",
                "camera_control_origin",
                "latents",
                "input_image",
            ),
            onload_model_names=("vae",),
        )

    def process(
        self,
        pipe: WanVideoPipeline_FaceSwap,
        height,
        width,
        num_frames,
        camera_control_direction,
        camera_control_speed,
        camera_control_origin,
        latents,
        input_image,
    ):
        if camera_control_direction is None:
            return {}
        adapter = pipe.dit.control_adapter
        plucker_embedding = adapter.process_camera_coordinates(
            camera_control_direction,
            num_frames,
            height,
            width,
            camera_control_speed,
            camera_control_origin,
        )
        # Move the last axis to the front and add a batch axis:
        # (N, A, B, C) -> (1, C, N, A, B); presumably N = frames and
        # C = embedding channels (TODO confirm against process_camera_coordinates).
        camera_video = plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
        # Repeat the first frame 4x, then fold every 4 consecutive frames into
        # the channel axis (c * 4 channels per group), presumably mirroring the
        # VAE's 4x temporal compression.
        head = torch.repeat_interleave(camera_video[:, :, 0:1], repeats=4, dim=2)
        frames = torch.concat([head, camera_video[:, :, 1:]], dim=2).transpose(1, 2)
        b, f, c, h, w = frames.shape
        grouped = frames.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
        folded = grouped.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
        control_camera_latents_input = folded.to(device=pipe.device, dtype=pipe.torch_dtype)

        first_frame = input_image.resize((width, height))
        first_frame_video = pipe.preprocess_video([first_frame])
        pipe.load_models_to_device(self.onload_model_names)
        first_frame_latents = pipe.vae.encode(first_frame_video, device=pipe.device)
        y = torch.zeros_like(latents).to(pipe.device)
        y[:, :, :1] = first_frame_latents
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"control_camera_latents_input": control_camera_latents_input, "y": y}
class WanVideoUnit_SpeedControl(PipelineUnit):
    """Wrap ``motion_bucket_id`` in a one-element tensor on the pipeline's device/dtype."""

    def __init__(self):
        super().__init__(input_params=("motion_bucket_id",))

    def process(self, pipe: WanVideoPipeline_FaceSwap, motion_bucket_id):
        if motion_bucket_id is not None:
            # torch.Tensor builds the tensor in the default dtype first, then it is cast.
            bucket = torch.Tensor((motion_bucket_id,))
            return {
                "motion_bucket_id": bucket.to(dtype=pipe.torch_dtype, device=pipe.device)
            }
        return {}
class WanVideoUnit_VACE(PipelineUnit):
    """Assemble the VACE conditioning tensor from an optional video, mask and reference image.

    ``vace_context`` concatenates, on dim 1: the VAE latents of the unmasked
    ("inactive") and masked ("reactive") video, and the mask rearranged into
    64 channels (8x8 spatial blocks) and resampled to (T + 3) // 4 frames.
    A reference image, when given, is prepended as an extra leading frame
    (zero-padded reactive half, zero mask). With no VACE inputs at all,
    ``vace_context`` is None.
    """

    def __init__(self):
        super().__init__(
            input_params=(
                "vace_video",
                "vace_video_mask",
                "vace_reference_image",
                "vace_scale",
                "height",
                "width",
                "num_frames",
                "tiled",
                "tile_size",
                "tile_stride",
            ),
            onload_model_names=("vae",),
        )

    def process(
        self,
        pipe: WanVideoPipeline_FaceSwap,
        vace_video,
        vace_video_mask,
        vace_reference_image,
        vace_scale,
        height,
        width,
        num_frames,
        tiled,
        tile_size,
        tile_stride,
    ):
        nothing_to_encode = (
            vace_video is None
            and vace_video_mask is None
            and vace_reference_image is None
        )
        if nothing_to_encode:
            return {"vace_context": None, "vace_scale": vace_scale}

        pipe.load_models_to_device(["vae"])

        def encode(video):
            # VAE-encode with the caller's tiling settings and cast to pipeline dtype/device.
            return pipe.vae.encode(
                video,
                device=pipe.device,
                tiled=tiled,
                tile_size=tile_size,
                tile_stride=tile_stride,
            ).to(dtype=pipe.torch_dtype, device=pipe.device)

        if vace_video is not None:
            vace_video = pipe.preprocess_video(vace_video)
        else:
            vace_video = torch.zeros(
                (1, 3, num_frames, height, width),
                dtype=pipe.torch_dtype,
                device=pipe.device,
            )
        if vace_video_mask is not None:
            vace_video_mask = pipe.preprocess_video(
                vace_video_mask, min_value=0, max_value=1
            )
        else:
            vace_video_mask = torch.ones_like(vace_video)

        inactive_video = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask
        reactive_video = vace_video * vace_video_mask + 0 * (1 - vace_video_mask)
        inactive_latents = encode(inactive_video)
        reactive_latents = encode(reactive_video)
        video_latents = torch.concat((inactive_latents, reactive_latents), dim=1)

        mask_latents = rearrange(
            vace_video_mask[0, 0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8
        )
        _, _, mask_t, mask_h, mask_w = mask_latents.shape
        mask_latents = torch.nn.functional.interpolate(
            mask_latents,
            size=((mask_t + 3) // 4, mask_h, mask_w),
            mode="nearest-exact",
        )

        if vace_reference_image is not None:
            reference_video = pipe.preprocess_video([vace_reference_image])
            reference_latents = encode(reference_video)
            reference_latents = torch.concat(
                (reference_latents, torch.zeros_like(reference_latents)), dim=1
            )
            video_latents = torch.concat((reference_latents, video_latents), dim=2)
            mask_latents = torch.concat(
                (torch.zeros_like(mask_latents[:, :, :1]), mask_latents), dim=2
            )

        vace_context = torch.concat((video_latents, mask_latents), dim=1)
        return {"vace_context": vace_context, "vace_scale": vace_scale}
class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
    """Expose the pipeline's ``use_unified_sequence_parallel`` flag as a shared input when it is truthy."""

    def __init__(self):
        super().__init__(input_params=())

    def process(self, pipe: WanVideoPipeline_FaceSwap):
        # A missing attribute is treated the same as a falsy one.
        enabled = getattr(pipe, "use_unified_sequence_parallel", False)
        return {"use_unified_sequence_parallel": True} if enabled else {}
class WanVideoUnit_TeaCache(PipelineUnit):
    """Create a ``TeaCache`` for each CFG branch when ``tea_cache_l1_thresh`` is set."""

    def __init__(self):
        # Each pipeline input maps to the ``process`` argument of the same name;
        # the positive and negative branches get separate (equal) mappings.
        names = ("num_inference_steps", "tea_cache_l1_thresh", "tea_cache_model_id")
        super().__init__(
            seperate_cfg=True,
            input_params_posi={name: name for name in names},
            input_params_nega={name: name for name in names},
        )

    def process(
        self,
        pipe: WanVideoPipeline_FaceSwap,
        num_inference_steps,
        tea_cache_l1_thresh,
        tea_cache_model_id,
    ):
        if tea_cache_l1_thresh is None:
            return {}
        cache = TeaCache(
            num_inference_steps,
            rel_l1_thresh=tea_cache_l1_thresh,
            model_id=tea_cache_model_id,
        )
        return {"tea_cache": cache}
class WanVideoUnit_CfgMerger(PipelineUnit):
    """Merge positive/negative CFG inputs into one doubled batch when ``cfg_merge`` is set.

    For each name in ``concat_tensor_names``: if both branches provide a tensor,
    the shared input becomes ``concat(posi, nega)`` on dim 0. Otherwise a
    shared tensor, if any, is duplicated along dim 0. The per-branch dicts are
    then emptied.
    """

    def __init__(self):
        super().__init__(take_over=True)
        self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"]

    def process(
        self, pipe: WanVideoPipeline_FaceSwap, inputs_shared, inputs_posi, inputs_nega
    ):
        if inputs_shared["cfg_merge"]:
            for name in self.concat_tensor_names:
                positive, negative = inputs_posi.get(name), inputs_nega.get(name)
                shared = inputs_shared.get(name)
                if positive is not None and negative is not None:
                    inputs_shared[name] = torch.concat((positive, negative), dim=0)
                elif shared is not None:
                    inputs_shared[name] = torch.concat((shared, shared), dim=0)
            inputs_posi.clear()
            inputs_nega.clear()
        return inputs_shared, inputs_posi, inputs_nega
class TeaCache:
    """Step-skipping cache for the DiT blocks.

    ``check`` is called once per denoising step with the modulation tensor
    ``t_mod``. It adds a polynomial-rescaled relative L1 change of ``t_mod``
    (versus the previous call) to a running total. While that total stays
    below ``rel_l1_thresh``, the step is reported as skippable, and the caller
    applies the last stored residual via ``update`` instead of running the
    blocks. The first and last step of every ``num_inference_steps`` cycle are
    always computed.
    """

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        self.num_inference_steps = num_inference_steps
        self.rel_l1_thresh = rel_l1_thresh
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
        self.previous_hidden_states = None
        # Per-model rescaling polynomials, highest degree first (np.poly1d order).
        self.coefficients_dict = {
            "Wan2.1-T2V-1.3B": [
                -5.21862437e04,
                9.23041404e03,
                -5.28275948e02,
                1.36987616e01,
                -4.99875664e-02,
            ],
            "Wan2.1-T2V-14B": [
                -3.03318725e05,
                4.90537029e04,
                -2.65530556e03,
                5.87365115e01,
                -3.15583525e-01,
            ],
            "Wan2.1-I2V-14B-480P": [
                2.57151496e05,
                -3.54229917e04,
                1.40286849e03,
                -1.35890334e01,
                1.32517977e-01,
            ],
            "Wan2.1-I2V-14B-720P": [
                8.10705460e03,
                2.13393892e03,
                -3.72934672e02,
                1.66203073e01,
                -4.17769401e-02,
            ],
        }
        coefficients = self.coefficients_dict.get(model_id)
        if coefficients is None:
            supported_model_ids = ", ".join(self.coefficients_dict)
            raise ValueError(
                f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids})."
            )
        self.coefficients = coefficients

    def check(self, dit: WanModel, x, t_mod):
        """Return True when this step may reuse the cached residual instead of running the blocks.

        ``dit`` is not used here. When the step is computed, ``x`` is snapshotted
        so that ``store`` can derive the residual afterwards.
        """
        current_input = t_mod.clone()
        if self.step in (0, self.num_inference_steps - 1):
            can_skip = False
            self.accumulated_rel_l1_distance = 0
        else:
            previous = self.previous_modulated_input
            relative_change = (
                ((current_input - previous).abs().mean() / previous.abs().mean())
                .cpu()
                .item()
            )
            rescale = np.poly1d(self.coefficients)
            self.accumulated_rel_l1_distance += rescale(relative_change)
            can_skip = bool(self.accumulated_rel_l1_distance < self.rel_l1_thresh)
            if not can_skip:
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = current_input
        self.step += 1
        if self.step == self.num_inference_steps:
            self.step = 0
        if not can_skip:
            self.previous_hidden_states = x.clone()
        return can_skip

    def store(self, hidden_states):
        """Record the residual added by the blocks since the last computed ``check``."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Apply the cached residual in place of running the blocks."""
        return hidden_states + self.previous_residual
class TemporalTiler_BCTHW:
    """Run a model over temporal windows of (B, C, T, H, W) tensors and blend the outputs.

    Windows of ``sliding_window_size`` frames start every
    ``sliding_window_stride`` frames along dim 2. Each window's output is
    weighted by a linear ramp over the ``size - stride`` overlapping frames at
    every edge that is not a sequence boundary. The weighted outputs are
    summed and normalized by the summed weights.
    """

    def __init__(self):
        pass

    def build_1d_mask(self, length, left_bound, right_bound, border_width):
        """Return a length-``length`` weight vector with linear ramps at non-boundary edges.

        Each ramp rises over ``border_width`` entries, from 1/border_width to 1.
        With ``border_width <= 0`` (no overlap) the mask is all ones.
        """
        x = torch.ones((length,))
        # Guard required: for border_width == 0, ``x[-0:]`` is the whole tensor,
        # and assigning an empty ramp to it raises a shape-mismatch error.
        if border_width <= 0:
            return x
        ramp = (torch.arange(border_width) + 1) / border_width
        if not left_bound:
            x[:border_width] = ramp
        if not right_bound:
            x[-border_width:] = torch.flip(ramp, dims=(0,))
        return x

    def build_mask(self, data, is_bound, border_width):
        """Return a (1, 1, T, 1, 1) weight mask for ``data`` of shape (B, C, T, H, W)."""
        _, _, T, _, _ = data.shape
        t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
        return t.view(1, 1, T, 1, 1)

    def run(
        self,
        model_fn,
        sliding_window_size,
        sliding_window_stride,
        computation_device,
        computation_dtype,
        model_kwargs,
        tensor_names,
        batch_size=None,
    ):
        """Call ``model_fn(**model_kwargs)`` per temporal window and blend the results.

        The non-None tensors named in ``tensor_names`` are sliced along dim 2 for
        each window. The output batch is the first tensor's batch times
        ``batch_size`` (if given). ``model_kwargs`` is mutated during the call
        and restored to the full-length tensors afterwards, even on error.

        Raises:
            ValueError: if the stride is not in (0, sliding_window_size]. A larger
                stride would leave frames with zero total weight (NaN output).
        """
        if not 0 < sliding_window_stride <= sliding_window_size:
            raise ValueError(
                f"sliding_window_stride ({sliding_window_stride}) must be positive and "
                f"not exceed sliding_window_size ({sliding_window_size})"
            )
        tensor_names = [
            name for name in tensor_names if model_kwargs.get(name) is not None
        ]
        tensor_dict = {name: model_kwargs[name] for name in tensor_names}
        first = tensor_dict[tensor_names[0]]
        B, C, T, H, W = first.shape
        if batch_size is not None:
            B *= batch_size
        data_device, data_dtype = first.device, first.dtype
        value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype)
        weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype)
        border_width = sliding_window_size - sliding_window_stride
        try:
            for t in range(0, T, sliding_window_stride):
                # Skip windows once the previous window already reached the end.
                if (
                    t - sliding_window_stride >= 0
                    and t - sliding_window_stride + sliding_window_size >= T
                ):
                    continue
                t_ = min(t + sliding_window_size, T)
                model_kwargs.update(
                    {
                        name: tensor_dict[name][:, :, t:t_].to(
                            device=computation_device, dtype=computation_dtype
                        )
                        for name in tensor_names
                    }
                )
                model_output = model_fn(**model_kwargs).to(
                    device=data_device, dtype=data_dtype
                )
                mask = self.build_mask(
                    model_output,
                    is_bound=(t == 0, t_ == T),
                    border_width=(border_width,),
                ).to(device=data_device, dtype=data_dtype)
                value[:, :, t:t_] += model_output * mask
                weight[:, :, t:t_] += mask
        finally:
            # Give the caller back its full-length tensors.
            model_kwargs.update(tensor_dict)
        value /= weight
        return value
def model_fn_wan_video(
    dit: WanModel,
    motion_controller: WanMotionControllerModel = None,
    vace: VaceWanModel = None,
    latents: torch.Tensor = None,
    timestep: torch.Tensor = None,
    context: torch.Tensor = None,
    clip_feature: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
    reference_latents=None,
    vace_context=None,
    vace_scale=1.0,
    tea_cache: TeaCache = None,
    use_unified_sequence_parallel: bool = False,
    motion_bucket_id: Optional[torch.Tensor] = None,
    sliding_window_size: Optional[int] = None,
    sliding_window_stride: Optional[int] = None,
    cfg_merge: bool = False,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
    control_camera_latents_input=None,
    fuse_vae_embedding_in_latents: bool = False,
    ip_image=None,
    **kwargs,
):
    """Run one forward pass of the Wan DiT for the face-swap pipeline.

    Steps:
    - Embed the timestep and text ``context``.
    - Optionally concatenate the VAE condition ``y`` and CLIP image features.
    - Patchify the latents and prepend reference-image tokens.
    - Run the transformer blocks, optionally with VACE hints, TeaCache step
      skipping, sequence parallelism and gradient checkpointing.

    When ``ip_image`` is given, it is patchified into separate IP tokens
    ``x_ip``. These are passed to every block next to ``x``, with a
    timestep-0 modulation ``t_mod_ip`` and their own RoPE positions.

    When both ``sliding_window_size`` and ``sliding_window_stride`` are set,
    the call is re-dispatched through ``TemporalTiler_BCTHW`` over temporal
    windows of ``latents`` and ``y``. ``**kwargs`` absorbs unused inputs.

    Returns:
        ``dit.unpatchify`` of the head output, with reference tokens removed.
        Presumably the model prediction in the layout of ``latents``
        (TODO confirm).
    """
    if sliding_window_size is not None and sliding_window_stride is not None:
        model_kwargs = dict(
            dit=dit,
            motion_controller=motion_controller,
            vace=vace,
            latents=latents,
            timestep=timestep,
            context=context,
            clip_feature=clip_feature,
            y=y,
            reference_latents=reference_latents,
            vace_context=vace_context,
            vace_scale=vace_scale,
            tea_cache=tea_cache,
            use_unified_sequence_parallel=use_unified_sequence_parallel,
            motion_bucket_id=motion_bucket_id,
            # Every window needs the identity image. It is not in tensor_names,
            # so it is passed whole rather than sliced; without it the face
            # swap is silently disabled in sliding-window mode.
            ip_image=ip_image,
        )
        return TemporalTiler_BCTHW().run(
            model_fn_wan_video,
            sliding_window_size,
            sliding_window_stride,
            latents.device,
            latents.dtype,
            model_kwargs=model_kwargs,
            tensor_names=["latents", "y"],
            batch_size=2 if cfg_merge else 1,
        )
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import (
            get_sequence_parallel_rank,
            get_sequence_parallel_world_size,
            get_sp_group,
        )

    # IP (identity-image) tokens and their modulation; both stay None without ip_image.
    x_ip = None
    t_mod_ip = None

    # Timestep
    if dit.seperated_timestep and fuse_vae_embedding_in_latents:
        # Per-token timesteps: 0 for every token of the first latent frame and
        # ``timestep`` for all later frames. Tokens per frame = H * W // 4
        # (presumably a 2x2 spatial patch — TODO confirm).
        tokens_per_frame = latents.shape[3] * latents.shape[4] // 4
        timestep = torch.concat(
            [
                torch.zeros(
                    (1, tokens_per_frame),
                    dtype=latents.dtype,
                    device=latents.device,
                ),
                torch.ones(
                    (latents.shape[2] - 1, tokens_per_frame),
                    dtype=latents.dtype,
                    device=latents.device,
                )
                * timestep,
            ]
        ).flatten()
        t = dit.time_embedding(
            sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)
        )
        t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
    else:
        t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
        t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
    if ip_image is not None:
        # IP tokens are always modulated with timestep 0.
        timestep_ip = torch.zeros_like(timestep)
        t_ip = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep_ip))
        t_mod_ip = dit.time_projection(t_ip).unflatten(1, (6, dit.dim))

    # Motion Controller
    if motion_bucket_id is not None and motion_controller is not None:
        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
    context = dit.text_embedding(context)

    x = latents
    # Merged cfg: repeat latents/timestep to match a batch-merged context.
    if x.shape[0] != context.shape[0]:
        x = torch.concat([x] * context.shape[0], dim=0)
    if timestep.shape[0] != context.shape[0]:
        timestep = torch.concat([timestep] * context.shape[0], dim=0)

    # Image Embedding
    if y is not None and dit.require_vae_embedding:
        x = torch.cat([x, y], dim=1)
    if clip_feature is not None and dit.require_clip_embedding:
        clip_embedding = dit.img_emb(clip_feature)
        context = torch.cat([clip_embedding, context], dim=1)

    # Patchify; camera-control latents (if any) are handed to dit.patchify.
    x, (f, h, w) = dit.patchify(x, control_camera_latents_input)

    # Reference image: its tokens are prepended as one extra leading frame.
    if reference_latents is not None:
        if len(reference_latents.shape) == 5:
            reference_latents = reference_latents[:, :, 0]
        reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2)
        x = torch.concat([reference_latents, x], dim=1)
        f += 1

    # RoPE table for the video tokens, shifted by ``offset`` on every axis. This
    # leaves temporal index 0 to the IP tokens below.
    offset = 1
    freqs = (
        torch.cat(
            [
                dit.freqs[0][offset : f + offset].view(f, 1, 1, -1).expand(f, h, w, -1),
                dit.freqs[1][offset : h + offset].view(1, h, 1, -1).expand(f, h, w, -1),
                dit.freqs[2][offset : w + offset].view(1, 1, w, -1).expand(f, h, w, -1),
            ],
            dim=-1,
        )
        .reshape(f * h * w, 1, -1)
        .to(x.device)
    )
    if ip_image is not None:
        # IP tokens: temporal position 0, spatial positions just past the video
        # grid. ``dit.freqs[0][0].view(f_ip, ...)`` only works for f_ip == 1
        # (a single image).
        x_ip, (f_ip, h_ip, w_ip) = dit.patchify(ip_image)
        freqs_ip = (
            torch.cat(
                [
                    dit.freqs[0][0].view(f_ip, 1, 1, -1).expand(f_ip, h_ip, w_ip, -1),
                    dit.freqs[1][h + offset : h + offset + h_ip]
                    .view(1, h_ip, 1, -1)
                    .expand(f_ip, h_ip, w_ip, -1),
                    dit.freqs[2][w + offset : w + offset + w_ip]
                    .view(1, 1, w_ip, -1)
                    .expand(f_ip, h_ip, w_ip, -1),
                ],
                dim=-1,
            )
            .reshape(f_ip * h_ip * w_ip, 1, -1)
            .to(x_ip.device)
        )
        freqs_original = freqs
        freqs = torch.cat([freqs, freqs_ip], dim=0)
    else:
        freqs_original = freqs

    # TeaCache
    if tea_cache is not None:
        tea_cache_update = tea_cache.check(dit, x, t_mod)
    else:
        tea_cache_update = False

    if vace_context is not None:
        # VACE only receives the video tokens ``x`` (the IP tokens travel
        # separately as ``x_ip``), so it needs the video-only RoPE table.
        vace_hints = vace(x, vace_context, context, t_mod, freqs_original)

    # blocks
    if use_unified_sequence_parallel:
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[
                get_sequence_parallel_rank()
            ]
    if tea_cache_update:
        x = tea_cache.update(x)
    else:

        def create_custom_forward(module):
            # Must forward keyword arguments: with use_reentrant=False,
            # torch.utils.checkpoint passes x_ip / t_mod_ip through as kwargs.
            def custom_forward(*inputs, **inputs_kwargs):
                return module(*inputs, **inputs_kwargs)

            return custom_forward

        for block_id, block in enumerate(dit.blocks):
            if use_gradient_checkpointing_offload:
                with torch.autograd.graph.save_on_cpu():
                    x, x_ip = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x,
                        context,
                        t_mod,
                        freqs,
                        x_ip=x_ip,
                        t_mod_ip=t_mod_ip,
                        use_reentrant=False,
                    )
            elif use_gradient_checkpointing:
                x, x_ip = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x,
                    context,
                    t_mod,
                    freqs,
                    x_ip=x_ip,
                    t_mod_ip=t_mod_ip,
                    use_reentrant=False,
                )
            else:
                x, x_ip = block(x, context, t_mod, freqs, x_ip=x_ip, t_mod_ip=t_mod_ip)
            if vace_context is not None and block_id in vace.vace_layers_mapping:
                current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
                if (
                    use_unified_sequence_parallel
                    and dist.is_initialized()
                    and dist.get_world_size() > 1
                ):
                    current_vace_hint = torch.chunk(
                        current_vace_hint, get_sequence_parallel_world_size(), dim=1
                    )[get_sequence_parallel_rank()]
                x = x + current_vace_hint * vace_scale
        if tea_cache is not None:
            tea_cache.store(x)

    x = dit.head(x, t)
    if use_unified_sequence_parallel:
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = get_sp_group().all_gather(x, dim=1)
    # Remove reference latents
    if reference_latents is not None:
        x = x[:, reference_latents.shape[1] :]
        f -= 1
    x = dit.unpatchify(x, (f, h, w))
    return x