import os
import time
import random
import functools
from typing import List, Optional, Tuple, Union
from pathlib import Path
from einops import rearrange
import torch
import torch.distributed as dist
from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V
from hyvideo.vae import load_vae
from hyvideo.modules import load_model
from hyvideo.text_encoder import TextEncoder
from hyvideo.utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
from hyvideo.diffusion.pipelines import HunyuanVideoAudioPipeline
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import cv2
from wan.utils.utils import resize_lanczos, calculate_new_dimensions
from hyvideo.data_kits.audio_preprocessor import encode_audio, get_facemask
from transformers import WhisperModel
from transformers import AutoFeatureExtractor
from hyvideo.data_kits.face_align import AlignImage
import librosa
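
# Audio is loaded at 16 kHz and fed to the Whisper feature extractor in windows of
# 750*640 samples (30 s at 16 kHz); 640 samples per frame corresponds to 25 fps, so
# len(audio_input) // 640 is the number of 25-fps frames covered by the audio.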
def get_audio_feature(feature_extractor, audio_path, duration):
    audio_input, sampling_rate = librosa.load(audio_path, duration=duration, sr=16000)
    assert sampling_rate == 16000

    audio_features = []
    window = 750 * 640
    for i in range(0, len(audio_input), window):
        audio_feature = feature_extractor(audio_input[i:i + window],
                                          sampling_rate=sampling_rate,
                                          return_tensors="pt",
                                          device="cuda"
                                          ).input_features
        audio_features.append(audio_feature)

    audio_features = torch.cat(audio_features, dim=-1)
    return audio_features, len(audio_input) // 640
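
# Letterbox `crop_img` into `size` (width, height): scale it to fit while keeping the
# aspect ratio, then pad the remaining borders with `color`.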
def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
    crop_h, crop_w = crop_img.shape[:2]
    target_w, target_h = size
    scale_h, scale_w = target_h / crop_h, target_w / crop_w
    if scale_w > scale_h:
        resize_h = int(target_h * resize_ratio)
        resize_w = int(crop_w / crop_h * resize_h)
    else:
        resize_w = int(target_w * resize_ratio)
        resize_h = int(crop_h / crop_w * resize_w)
    crop_img = cv2.resize(crop_img, (resize_w, resize_h))
    pad_left = (target_w - resize_w) // 2
    pad_top = (target_h - resize_h) // 2
    pad_right = target_w - resize_w - pad_left
    pad_bottom = target_h - resize_h - pad_top
    crop_img = cv2.copyMakeBorder(crop_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color)
    return crop_img
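
# Legacy LLaVA token-merge routine: every <image> placeholder in `input_ids` is expanded
# into its `num_image_patches` vision embeddings, and the attention mask, labels and
# position ids are rebuilt around the expanded sequence.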
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
    num_images, num_image_patches, embed_dim = image_features.shape
    batch_size, sequence_length = input_ids.shape
    left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
    # 1. Create a mask to know where special image tokens are
    special_image_token_mask = input_ids == self.config.image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    # Compute the maximum embed dimension
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
    batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

    # 2. Compute the positions where text should be written
    # Calculate new positions for text tokens in merged image-text sequence.
    # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
    # `torch.cumsum` computes how each image token shifts subsequent text token positions.
    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    if left_padding:
        new_token_positions += nb_image_pad[:, None]  # offset for left padding
    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

    # 3. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    if labels is not None:
        final_labels = torch.full(
            (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
        )
    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
    # set the corresponding tensors into their correct target device.
    target_device = inputs_embeds.device
    batch_indices, non_image_indices, text_to_overwrite = (
        batch_indices.to(target_device),
        non_image_indices.to(target_device),
        text_to_overwrite.to(target_device),
    )
    attention_mask = attention_mask.to(target_device)

    # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
    if labels is not None:
        final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

    # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
    image_to_overwrite = torch.full(
        (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
    )
    image_to_overwrite[batch_indices, text_to_overwrite] = False
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
        raise ValueError(
            f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
            f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
        )

    final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
    final_attention_mask |= image_to_overwrite
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
    batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
    indices_to_mask = new_token_positions[batch_indices, pad_indices]
    final_embedding[batch_indices, indices_to_mask] = 0

    if labels is None:
        final_labels = None

    return final_embedding, final_attention_mask, final_labels, position_ids
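
# Patched LlavaForConditionalGeneration.forward that keeps the legacy path: image features
# are merged into the text embeddings before calling the language model, instead of relying
# on the expanded-inputs behaviour of newer transformers releases.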
def patched_llava_forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: torch.FloatTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[int] = None,
    vision_feature_select_strategy: Optional[str] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
):
    from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    vision_feature_layer = (
        vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
    )
    vision_feature_select_strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else self.config.vision_feature_select_strategy
    )

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    image_features = None
    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
        )
        inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
            image_features, inputs_embeds, input_ids, attention_mask, labels
        )
        cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)

    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        num_logits_to_keep=num_logits_to_keep,
    )

    logits = outputs[0]
    loss = None

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return LlavaCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )
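
# Graft the audio adapter modules onto the transformer: `model.double_stream_map` maps a
# double-stream block index to the adapter that should run alongside it; the standalone
# adapter container is removed afterwards.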
def adapt_model(model, audio_block_name):
    modules_dict = {k: m for k, m in model.named_modules()}
    for model_layer, avatar_layer in model.double_stream_map.items():
        module = modules_dict[f"{audio_block_name}.{avatar_layer}"]
        target = modules_dict[f"double_blocks.{model_layer}"]
        setattr(target, "audio_adapter", module)
    delattr(model, audio_block_name)
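
# Turns a reference image into the three tensors the pipeline expects: a CLIP-normalized
# 336x336 image for LLaVA, an all-white unconditional counterpart, and the raw reference
# image scaled to [0, 1].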
class DataPreprocess(object):
    def __init__(self):
        self.llava_size = (336, 336)
        self.llava_transform = transforms.Compose(
            [
                transforms.Resize(self.llava_size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )

    def get_batch(self, image, size, pad=False):
        image = np.asarray(image)
        if pad:
            llava_item_image = pad_image(image.copy(), self.llava_size)
        else:
            llava_item_image = image.copy()
        uncond_llava_item_image = np.ones_like(llava_item_image) * 255

        if pad:
            cat_item_image = pad_image(image.copy(), size)
        else:
            cat_item_image = image.copy()
        llava_item_tensor = self.llava_transform(Image.fromarray(llava_item_image.astype(np.uint8)))
        uncond_llava_item_tensor = self.llava_transform(Image.fromarray(uncond_llava_item_image))
        cat_item_tensor = torch.from_numpy(cat_item_image.copy()).permute((2, 0, 1)) / 255.0
        # batch = {
        #     "pixel_value_llava": llava_item_tensor.unsqueeze(0),
        #     "uncond_pixel_value_llava": uncond_llava_item_tensor.unsqueeze(0),
        #     'pixel_value_ref': cat_item_tensor.unsqueeze(0),
        # }
        return llava_item_tensor.unsqueeze(0), uncond_llava_item_tensor.unsqueeze(0), cat_item_tensor.unsqueeze(0)
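
# Container for the sub-models (transformer, VAE, text encoders, optional audio stack)
# assembled by `from_pretrained`; `HunyuanVideoSampler` builds the diffusion pipeline on top.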
class Inference(object):
    def __init__(
        self,
        i2v,
        custom,
        avatar,
        enable_cfg,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        feature_extractor=None,
        wav2vec=None,
        align_instance=None,
        device=None,
    ):
        self.i2v = i2v
        self.custom = custom
        self.avatar = avatar
        self.enable_cfg = enable_cfg
        self.vae = vae
        self.vae_kwargs = vae_kwargs
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.model = model
        self.pipeline = pipeline
        self.feature_extractor = feature_extractor
        self.wav2vec = wav2vec
        self.align_instance = align_instance
        self.device = "cuda"
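
    # Builds the transformer, VAE, text encoders and (for audio variants) the Whisper /
    # face-alignment stack according to `base_model_type`, then returns an instance of the
    # calling class.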
    @classmethod
    def from_pretrained(cls, model_filepath, model_type, base_model_type, text_encoder_filepath, dtype=torch.bfloat16, VAE_dtype=torch.float16, mixed_precision_transformer=torch.bfloat16, quantizeTransformer=False, save_quantized=False, **kwargs):
        device = "cuda"

        import transformers
        transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.forward = patched_llava_forward  # force legacy behaviour so that transformers > 4.47 can be used
        transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features = _merge_input_ids_with_image_features
        torch.set_grad_enabled(False)
        text_len = 512
        latent_channels = 16
        precision = "bf16"
        vae_precision = "fp32" if VAE_dtype == torch.float32 else "bf16"
        embedded_cfg_scale = 6
        filepath = model_filepath[0]
        i2v_condition_type = None
        i2v_mode = False
        custom = False
        custom_audio = False
        avatar = False
        if base_model_type == "hunyuan_i2v":
            model_id = "HYVideo-T/2"
            i2v_condition_type = "token_replace"
            i2v_mode = True
        elif base_model_type == "hunyuan_custom":
            model_id = "HYVideo-T/2-custom"
            custom = True
        elif base_model_type == "hunyuan_custom_audio":
            model_id = "HYVideo-T/2-custom-audio"
            custom_audio = True
            custom = True
        elif base_model_type == "hunyuan_custom_edit":
            model_id = "HYVideo-T/2-custom-edit"
            custom = True
        elif base_model_type == "hunyuan_avatar":
            model_id = "HYVideo-T/2-avatar"
            text_len = 256
            avatar = True
        else:
            model_id = "HYVideo-T/2-cfgdistill"

        if i2v_mode and i2v_condition_type == "latent_concat":
            in_channels = latent_channels * 2 + 1
            image_embed_interleave = 2
        elif i2v_mode and i2v_condition_type == "token_replace":
            in_channels = latent_channels
            image_embed_interleave = 4
        else:
            in_channels = latent_channels
            image_embed_interleave = 1
        out_channels = latent_channels
        pinToMemory = kwargs.pop("pinToMemory", False)
        partialPinning = kwargs.pop("partialPinning", False)
        factor_kwargs = kwargs | {"device": "meta", "dtype": PRECISION_TO_TYPE[precision]}

        if embedded_cfg_scale and i2v_mode:
            factor_kwargs["guidance_embed"] = True

        model = load_model(
            model=model_id,
            i2v_condition_type=i2v_condition_type,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
        )

        from mmgp import offload
        # model = Inference.load_state_dict(args, model, model_filepath)
        # model_filepath = "c:/temp/hc/mp_rank_00_model_states_video.pt"
        offload.load_model_data(model, model_filepath, do_quantize=quantizeTransformer and not save_quantized, pinToMemory=pinToMemory, partialPinning=partialPinning)
        # offload.save_model(model, "hunyuan_video_avatar_edit_720_bf16.safetensors")
        # offload.save_model(model, "hunyuan_video_avatar_edit_720_quanto_bf16_int8.safetensors", do_quantize= True)
        if save_quantized:
            from wgp import save_quantized_model
            save_quantized_model(model, model_type, filepath, dtype, None)

        model.mixed_precision = mixed_precision_transformer
        if model.mixed_precision:
            model._lock_dtype = torch.float32
            model.lock_layers_dtypes(torch.float32)
        model.eval()

        # ============================= Build extra models ========================
        # VAE
        if custom or avatar:
            vae_configpath = "ckpts/hunyuan_video_custom_VAE_config.json"
            vae_filepath = "ckpts/hunyuan_video_custom_VAE_fp32.safetensors"
        # elif avatar:
        #     vae_configpath = "ckpts/config_vae_avatar.json"
        #     vae_filepath = "ckpts/vae_avatar.pt"
        else:
            vae_configpath = "ckpts/hunyuan_video_VAE_config.json"
            vae_filepath = "ckpts/hunyuan_video_VAE_fp32.safetensors"

        # config = AutoencoderKLCausal3D.load_config("ckpts/hunyuan_video_VAE_config.json")
        # config = AutoencoderKLCausal3D.load_config("c:/temp/hvae/config_vae.json")
        vae, _, s_ratio, t_ratio = load_vae("884-16c-hy", vae_path=vae_filepath, vae_config_path=vae_configpath, vae_precision=vae_precision, device="cpu")
        # vae._model_dtype = torch.float32 if VAE_dtype == torch.float32 else (torch.float16 if avatar else torch.bfloat16)
        vae._model_dtype = torch.float32 if VAE_dtype == torch.float32 else torch.bfloat16
        vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
        enable_cfg = False

        # Text encoder
        if i2v_mode:
            text_encoder = "llm-i2v"
            tokenizer = "llm-i2v"
            prompt_template = "dit-llm-encode-i2v"
            prompt_template_video = "dit-llm-encode-video-i2v"
        elif custom or avatar:
            text_encoder = "llm-i2v"
            tokenizer = "llm-i2v"
            prompt_template = "dit-llm-encode"
            prompt_template_video = "dit-llm-encode-video"
            enable_cfg = True
        else:
            text_encoder = "llm"
            tokenizer = "llm"
            prompt_template = "dit-llm-encode"
            prompt_template_video = "dit-llm-encode-video"

        if prompt_template_video is not None:
            crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
        elif prompt_template is not None:
            crop_start = PROMPT_TEMPLATE[prompt_template].get("crop_start", 0)
        else:
            crop_start = 0
        max_length = text_len + crop_start

        # prompt_template
        prompt_template = PROMPT_TEMPLATE[prompt_template] if prompt_template is not None else None

        # prompt_template_video
        prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] if prompt_template_video is not None else None

        text_encoder = TextEncoder(
            text_encoder_type=text_encoder,
            max_length=max_length,
            text_encoder_precision="fp16",
            tokenizer_type=tokenizer,
            i2v_mode=i2v_mode,
            prompt_template=prompt_template,
            prompt_template_video=prompt_template_video,
            hidden_state_skip_layer=2,
            apply_final_norm=False,
            reproduce=True,
            device="cpu",
            image_embed_interleave=image_embed_interleave,
            text_encoder_path=text_encoder_filepath,
        )

        text_encoder_2 = TextEncoder(
            text_encoder_type="clipL",
            max_length=77,
            text_encoder_precision="fp16",
            tokenizer_type="clipL",
            reproduce=True,
            device="cpu",
        )

        feature_extractor = None
        wav2vec = None
        align_instance = None

        if avatar or custom_audio:
            feature_extractor = AutoFeatureExtractor.from_pretrained("ckpts/whisper-tiny/")
            wav2vec = WhisperModel.from_pretrained("ckpts/whisper-tiny/").to(device="cpu", dtype=torch.float32)
            wav2vec._model_dtype = torch.float32
            wav2vec.requires_grad_(False)
        if avatar:
            align_instance = AlignImage("cuda", det_path="ckpts/det_align/detface.pt")
            align_instance.facedet.model.to("cpu")
            adapt_model(model, "audio_adapter_blocks")
        elif custom_audio:
            adapt_model(model, "audio_models")

        return cls(
            i2v=i2v_mode,
            custom=custom,
            avatar=avatar,
            enable_cfg=enable_cfg,
            vae=vae,
            vae_kwargs=vae_kwargs,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            model=model,
            feature_extractor=feature_extractor,
            wav2vec=wav2vec,
            align_instance=align_instance,
            device=device,
        )
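
# Sampler used at inference time: builds the HunyuanVideo (audio) pipeline on top of the
# loaded sub-models, exposes the pipeline's interrupt flag and implements `generate`.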
class HunyuanVideoSampler(Inference):
    def __init__(
        self,
        i2v,
        custom,
        avatar,
        enable_cfg,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        feature_extractor=None,
        wav2vec=None,
        align_instance=None,
        device=0,
    ):
        super().__init__(
            i2v,
            custom,
            avatar,
            enable_cfg,
            vae,
            vae_kwargs,
            text_encoder,
            model,
            text_encoder_2=text_encoder_2,
            pipeline=pipeline,
            feature_extractor=feature_extractor,
            wav2vec=wav2vec,
            align_instance=align_instance,
            device=device,
        )

        self.i2v_mode = i2v
        self.enable_cfg = enable_cfg
        self.pipeline = self.load_diffusion_pipeline(
            avatar=self.avatar,
            vae=self.vae,
            text_encoder=self.text_encoder,
            text_encoder_2=self.text_encoder_2,
            model=self.model,
            device=self.device,
        )

        if self.i2v_mode:
            self.default_negative_prompt = NEGATIVE_PROMPT_I2V
        else:
            self.default_negative_prompt = NEGATIVE_PROMPT

    @property
    def _interrupt(self):
        return self.pipeline._interrupt

    @_interrupt.setter
    def _interrupt(self, value):
        self.pipeline._interrupt = value

    def load_diffusion_pipeline(
        self,
        avatar,
        vae,
        text_encoder,
        text_encoder_2,
        model,
        scheduler=None,
        device=None,
        progress_bar_config=None,
        # data_type="video",
    ):
        """Load the denoising scheduler for inference."""
        if scheduler is None:
            scheduler = FlowMatchDiscreteScheduler(
                shift=6.0,
                reverse=True,
                solver="euler",
            )

        if avatar:
            pipeline = HunyuanVideoAudioPipeline(
                vae=vae,
                text_encoder=text_encoder,
                text_encoder_2=text_encoder_2,
                transformer=model,
                scheduler=scheduler,
                progress_bar_config=progress_bar_config,
            )
        else:
            pipeline = HunyuanVideoPipeline(
                vae=vae,
                text_encoder=text_encoder,
                text_encoder_2=text_encoder_2,
                transformer=model,
                scheduler=scheduler,
                progress_bar_config=progress_bar_config,
            )
        return pipeline
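
    # RoPE grids are computed on the latent volume: time is compressed 4x ((T-1)//4 + 1)
    # and space 8x by the VAE, then each axis is divided by the transformer patch size.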
    def get_rotary_pos_embed_new(self, video_length, height, width, concat_dict={}, enable_riflex=False):
        target_ndim = 3
        ndim = 5 - 2
        latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), \
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

        freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
                                                           rope_sizes,
                                                           theta=256,
                                                           use_real=True,
                                                           theta_rescale_factor=1,
                                                           concat_dict=concat_dict,
                                                           L_test=(video_length - 1) // 4 + 1,
                                                           enable_riflex=enable_riflex
                                                           )
        return freqs_cos, freqs_sin

    def get_rotary_pos_embed(self, video_length, height, width, enable_riflex=False):
        target_ndim = 3
        ndim = 5 - 2
        # 884
        vae = "884-16c-hy"
        if "884" in vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        elif "888" in vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(
                s % self.model.patch_size[idx] == 0
                for idx, s in enumerate(latents_size)
            ), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [
                s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)
            ]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert (
            sum(rope_dim_list) == head_dim
        ), "sum(rope_dim_list) should equal to head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list,
            rope_sizes,
            theta=256,
            use_real=True,
            theta_rescale_factor=1,
            L_test=(video_length - 1) // 4 + 1,
            enable_riflex=enable_riflex
        )
        return freqs_cos, freqs_sin
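
    # Single entry point for all Hunyuan variants (t2v, i2v, custom, custom audio/edit, avatar):
    # prepares seeds, resolution, RoPE frequencies and reference/background/audio conditioning,
    # then delegates to the diffusion pipeline and returns the decoded video tensor.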
    def generate(
        self,
        input_prompt,
        input_ref_images=None,
        audio_guide=None,
        input_frames=None,
        input_masks=None,
        input_video=None,
        fps=24,
        height=192,
        width=336,
        frame_num=129,
        seed=None,
        n_prompt=None,
        sampling_steps=50,
        guide_scale=1.0,
        shift=5.0,
        embedded_guidance_scale=6.0,
        batch_size=1,
        num_videos_per_prompt=1,
        i2v_resolution="720p",
        image_start=None,
        enable_RIFLEx=False,
        i2v_condition_type: str = "token_replace",
        i2v_stability=True,
        VAE_tile_size=None,
        joint_pass=False,
        cfg_star_switch=False,
        fit_into_canvas=True,
        conditioning_latents_size=0,
        **kwargs,
    ):

        if VAE_tile_size is not None:
            self.vae.tile_sample_min_tsize = VAE_tile_size["tile_sample_min_tsize"]
            self.vae.tile_latent_min_tsize = VAE_tile_size["tile_latent_min_tsize"]
            self.vae.tile_sample_min_size = VAE_tile_size["tile_sample_min_size"]
            self.vae.tile_latent_min_size = VAE_tile_size["tile_latent_min_size"]
            self.vae.tile_overlap_factor = VAE_tile_size["tile_overlap_factor"]
            self.vae.enable_tiling()

        i2v_mode = self.i2v_mode
        if not self.enable_cfg:
            guide_scale = 1.0

        # ========================================================================
        # Arguments: seed
        # ========================================================================
        if isinstance(seed, torch.Tensor):
            seed = seed.tolist()
        if seed is None:
            seeds = [
                random.randint(0, 1_000_000)
                for _ in range(batch_size * num_videos_per_prompt)
            ]
        elif isinstance(seed, int):
            seeds = [
                seed + i
                for _ in range(batch_size)
                for i in range(num_videos_per_prompt)
            ]
        elif isinstance(seed, (list, tuple)):
            if len(seed) == batch_size:
                seeds = [
                    int(seed[i]) + j
                    for i in range(batch_size)
                    for j in range(num_videos_per_prompt)
                ]
            elif len(seed) == batch_size * num_videos_per_prompt:
                seeds = [int(s) for s in seed]
            else:
                raise ValueError(
                    f"Length of seed must be equal to number of prompt(batch_size) or "
                    f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
                )
        else:
            raise ValueError(
                f"Seed must be an integer, a list of integers, or None, got {seed}."
            )
        from wan.utils.utils import seed_everything
        seed_everything(seed)
        generator = [torch.Generator("cuda").manual_seed(seed) for seed in seeds]
        # generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]

        # ========================================================================
        # Arguments: target_width, target_height, target_frame_num
        # ========================================================================
        if width <= 0 or height <= 0 or frame_num <= 0:
            raise ValueError(
                f"`height` and `width` and `frame_num` must be positive integers, got height={height}, width={width}, frame_num={frame_num}"
            )
        if (frame_num - 1) % 4 != 0:
            raise ValueError(
                f"`frame_num-1` must be a multiple of 4, got {frame_num}"
            )
        target_height = align_to(height, 16)
        target_width = align_to(width, 16)
        target_frame_num = frame_num

        audio_strength = 1
        if input_ref_images is not None:
            # ip_cfg_scale = 3.0
            ip_cfg_scale = 0
            denoise_strength = 1
            # guide_scale=7.5
            # shift=13
            name = "person"
            input_ref_images = input_ref_images[0]

        # ========================================================================
        # Arguments: prompt, new_prompt, negative_prompt
        # ========================================================================
        if not isinstance(input_prompt, str):
            raise TypeError(f"`prompt` must be a string, but got {type(input_prompt)}")
        input_prompt = [input_prompt.strip()]

        # negative prompt
        if n_prompt is None or n_prompt == "":
            n_prompt = self.default_negative_prompt
        if guide_scale == 1.0:
            n_prompt = ""
        if not isinstance(n_prompt, str):
            raise TypeError(
                f"`negative_prompt` must be a string, but got {type(n_prompt)}"
            )
        n_prompt = [n_prompt.strip()]

        # ========================================================================
        # Scheduler
        # ========================================================================
        scheduler = FlowMatchDiscreteScheduler(
            shift=shift,
            reverse=True,
            solver="euler"
        )
        self.pipeline.scheduler = scheduler

        # ---------------------------------
        # Reference condition
        # ---------------------------------
        img_latents = None
        semantic_images = None
        denoise_strength = 0
        ip_cfg_scale = 0
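
        # i2v: encode the start image into image latents that the pipeline uses for
        # token-replace / latent conditioning.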
        if i2v_mode:
            if i2v_resolution == "720p":
                bucket_hw_base_size = 960
            elif i2v_resolution == "540p":
                bucket_hw_base_size = 720
            elif i2v_resolution == "360p":
                bucket_hw_base_size = 480
            else:
                raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")

            # semantic_images = [Image.open(i2v_image_path).convert('RGB')]
            semantic_images = [image_start.convert('RGB')]
            origin_size = semantic_images[0].size
            h, w = origin_size
            h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
            closest_size = (w, h)
            # crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
            # aspect_ratios = np.array([round(float(h)/float(w), 5) for h, w in crop_size_list])
            # closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
            ref_image_transform = transforms.Compose([
                transforms.Resize(closest_size),
                transforms.CenterCrop(closest_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5])
            ])

            semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
            semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)

            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
                img_latents = self.pipeline.vae.encode(semantic_image_pixel_values).latent_dist.mode()  # B, C, F, H, W
                img_latents.mul_(self.pipeline.vae.config.scaling_factor)

            target_height, target_width = closest_size

        # ========================================================================
        # Build Rope freqs
        # ========================================================================
        if input_ref_images is None:
            freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_RIFLEx)
        else:
            if self.avatar:
                w, h = input_ref_images.size
                target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas)
                if target_width != w or target_height != h:
                    input_ref_images = input_ref_images.resize((target_width, target_height), resample=Image.Resampling.LANCZOS)
                concat_dict = {'mode': 'timecat', 'bias': -1}
                freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
            else:
                if input_frames is not None:
                    target_height, target_width = input_frames.shape[-3:-1]
                elif input_video is not None:
                    target_height, target_width = input_video.shape[-2:]
                concat_dict = {'mode': 'timecat-w', 'bias': -1}
                freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(target_frame_num, target_height, target_width, concat_dict, enable_RIFLEx)
        n_tokens = freqs_cos.shape[0]

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        # ========================================================================
        # Pipeline inference
        # ========================================================================
        pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = None, None, None
        if input_ref_images is None:
            name = None
        else:
            pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = DataPreprocess().get_batch(input_ref_images, (target_width, target_height), pad=self.custom)

        ref_latents, uncond_audio_prompts, audio_prompts, face_masks, motion_exp, motion_pose = None, None, None, None, None, None
        bg_latents = None
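
        # Optional video/mask conditioning: encode the background video and its mask into
        # latents, padding with blank frames up to `frame_num` when the clip is shorter.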
        if input_video is not None:
            pixel_value_bg = input_video.unsqueeze(0)
            pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0)
        if input_frames is not None:
            pixel_value_video_bg = input_frames.permute(-1, 0, 1, 2).unsqueeze(0).float()
            pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1, 1, 1, 3).permute(-1, 0, 1, 2).unsqueeze(0).float()
            pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
            if input_video is not None:
                pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2)
                pixel_value_mask = torch.cat([pixel_value_mask, pixel_value_video_mask], dim=2)
            else:
                pixel_value_bg = pixel_value_video_bg
                pixel_value_mask = pixel_value_video_mask
            pixel_value_video_mask, pixel_value_video_bg = None, None

        if input_video is not None or input_frames is not None:
            if pixel_value_bg.shape[2] < frame_num:
                padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num - pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:])
                pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device=pixel_value_bg.device)], dim=2)
                pixel_value_mask = torch.cat([pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device=pixel_value_mask.device)], dim=2)

            bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample()
            pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.)
            mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample()
            bg_latents = torch.cat([bg_latents, mask_latents], dim=1)
            bg_latents.mul_(self.vae.config.scaling_factor)
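
        # Avatar: build the face mask and the reference latents. The reference image is repeated
        # to 9 frames (1 + 4*2), which the 4x temporal VAE encodes into 3 latent frames; the middle
        # latent frame is then tiled 31x so the result spans the 33 latent frames of a 129-frame segment.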
        if self.avatar:
            if n_prompt is None or len(n_prompt) == 0:
                n_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, Lens changes"

            uncond_pixel_value_llava = pixel_value_llava.clone()

            pixel_value_ref = pixel_value_ref.unsqueeze(0)
            self.align_instance.facedet.model.to("cuda")
            face_masks = get_facemask(pixel_value_ref.to("cuda") * 255, self.align_instance, area=3.0)
            # iii = (face_masks.squeeze(0).squeeze(0).permute(1,2,0).repeat(1,1,3)*255).cpu().numpy().astype(np.uint8)
            # image = Image.fromarray(iii)
            # image.save("mask.png")
            # jjj = (pixel_value_ref.squeeze(0).squeeze(0).permute(1,2,0)*255).cpu().numpy().astype(np.uint8)
            self.align_instance.facedet.model.to("cpu")

            # pixel_value_ref = pixel_value_ref.clone().repeat(1,129,1,1,1)
            pixel_value_ref = pixel_value_ref.repeat(1, 1 + 4 * 2, 1, 1, 1)
            pixel_value_ref = pixel_value_ref * 2 - 1
            pixel_value_ref_for_vae = rearrange(pixel_value_ref, "b f c h w -> b c f h w")

            vae_dtype = self.vae.dtype
            with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32):
                ref_latents = self.vae.encode(pixel_value_ref_for_vae).latent_dist.sample()
                ref_latents = torch.cat([ref_latents[:, :, :1], ref_latents[:, :, 1:2].repeat(1, 1, 31, 1, 1), ref_latents[:, :, -1:]], dim=2)
            pixel_value_ref, pixel_value_ref_for_vae = None, None

            if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
                ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor)
            else:
                ref_latents.mul_(self.vae.config.scaling_factor)

            # out_latents = ref_latents / self.vae.config.scaling_factor
            # image = self.vae.decode(out_latents, return_dict=False, generator=generator)[0]
            # image = image.clamp(-1, 1)
            # from wan.utils.utils import cache_video
            # cache_video(tensor=image, save_file="decode.mp4", fps=25, nrow=1, normalize=True, value_range=(-1, 1))

            motion_pose = np.array([25] * 4)
            motion_exp = np.array([30] * 4)
            motion_pose = torch.from_numpy(motion_pose).unsqueeze(0)
            motion_exp = torch.from_numpy(motion_exp).unsqueeze(0)

            face_masks = torch.nn.functional.interpolate(face_masks.float().squeeze(2),
                                                         (ref_latents.shape[-2],
                                                          ref_latents.shape[-1]),
                                                         mode="bilinear").unsqueeze(2).to(dtype=ref_latents.dtype)
        if audio_guide is not None:
            audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_guide, duration=frame_num / fps)
            audio_prompts = audio_input[0]
            weight_dtype = audio_prompts.dtype
            if self.custom:
                audio_len = min(audio_len, frame_num)
                audio_input = audio_input[:, :audio_len]
            audio_prompts = encode_audio(self.wav2vec, audio_prompts.to(dtype=self.wav2vec.dtype), fps, num_frames=audio_len)
            audio_prompts = audio_prompts.to(self.model.dtype)
            segment_size = 129 if self.avatar else frame_num
            if audio_prompts.shape[1] <= segment_size:
                audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1, segment_size - audio_prompts.shape[1], 1, 1, 1)], dim=1)
            else:
                audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1, 5, 1, 1, 1)], dim=1)
            uncond_audio_prompts = torch.zeros_like(audio_prompts[:, :129])

        samples = self.pipeline(
            prompt=input_prompt,
            height=target_height,
            width=target_width,
            video_length=target_frame_num,
            num_inference_steps=sampling_steps,
            guidance_scale=guide_scale,
            negative_prompt=n_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            generator=generator,
            output_type="pil",
            name=name,
            pixel_value_ref=pixel_value_ref,
            ref_latents=ref_latents,  # [1, 16, 1, h//8, w//8]
            pixel_value_llava=pixel_value_llava,  # [1, 3, 336, 336]
            uncond_pixel_value_llava=uncond_pixel_value_llava,
            face_masks=face_masks,  # [b f h w]
            audio_prompts=audio_prompts,
            uncond_audio_prompts=uncond_audio_prompts,
            motion_exp=motion_exp,
            motion_pose=motion_pose,
            fps=torch.from_numpy(np.array(fps)),
            bg_latents=bg_latents,
            audio_strength=audio_strength,
            denoise_strength=denoise_strength,
            ip_cfg_scale=ip_cfg_scale,
            freqs_cis=(freqs_cos, freqs_sin),
            n_tokens=n_tokens,
            embedded_guidance_scale=embedded_guidance_scale,
            data_type="video" if target_frame_num > 1 else "image",
            is_progress_bar=True,
            vae_ver="884-16c-hy",
            enable_tiling=True,
            i2v_mode=i2v_mode,
            i2v_condition_type=i2v_condition_type,
            i2v_stability=i2v_stability,
            img_latents=img_latents,
            semantic_images=semantic_images,
            joint_pass=joint_pass,
            cfg_star_rescale=cfg_star_switch,
            callback=callback,
            callback_steps=callback_steps,
        )[0]

        if samples is None:
            return None
        samples = samples.squeeze(0)
        return samples