import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
from tqdm import tqdm
from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D


class DownsampleCausal3D(nn.Module):
    def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2):
        super().__init__()
        self.conv = CausalConv3d(channels, out_channels, kernel_size, stride=stride, bias=bias)

    def forward(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        return hidden_states


class DownEncoderBlockCausal3D(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        dropout=0.0,
        num_layers=1,
        eps=1e-6,
        num_groups=32,
        add_downsample=True,
        downsample_stride=2,
    ):
        super().__init__()
        resnets = []
        for i in range(num_layers):
            cur_in_channel = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlockCausal3D(
                    in_channels=cur_in_channel,
                    out_channels=out_channels,
                    groups=num_groups,
                    dropout=dropout,
                    eps=eps,
                ))
        self.resnets = nn.ModuleList(resnets)
        self.downsamplers = None
        if add_downsample:
            self.downsamplers = nn.ModuleList([DownsampleCausal3D(
                out_channels,
                out_channels,
                stride=downsample_stride,
            )])

    def forward(self, hidden_states):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)
        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
        return hidden_states


class EncoderCausal3D(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        eps=1e-6,
        dropout=0.0,
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        num_groups=32,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
        gradient_checkpointing=False,
    ):
        super().__init__()
        self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i in range(len(block_out_channels)):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_downsample_layers = int(np.log2(time_compression_ratio))
            add_spatial_downsample = bool(i < num_spatial_downsample_layers)
            add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
            downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
            downsample_stride_T = (2,) if add_time_downsample else (1,)
            downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
            down_block = DownEncoderBlockCausal3D(
                in_channels=input_channel,
                out_channels=output_channel,
                dropout=dropout,
                num_layers=layers_per_block,
                eps=eps,
                num_groups=num_groups,
                add_downsample=bool(add_spatial_downsample or add_time_downsample),
                downsample_stride=downsample_stride,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            dropout=dropout,
            eps=eps,
            num_groups=num_groups,
            attention_head_dim=block_out_channels[-1],
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3)
        self.gradient_checkpointing = gradient_checkpointing
    def forward(self, hidden_states):
        hidden_states = self.conv_in(hidden_states)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward
            # down (assign the checkpointed output, otherwise the down blocks are skipped)
            for down_block in self.down_blocks:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block),
                    hidden_states,
                    use_reentrant=False,
                )
            # middle
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block),
                hidden_states,
                use_reentrant=False,
            )
        else:
            # down
            for down_block in self.down_blocks:
                hidden_states = down_block(hidden_states)
            # middle
            hidden_states = self.mid_block(hidden_states)

        # post-process
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states


class HunyuanVideoVAEEncoder(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=16,
        eps=1e-6,
        dropout=0.0,
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        num_groups=32,
        time_compression_ratio=4,
        spatial_compression_ratio=8,
        gradient_checkpointing=False,
    ):
        super().__init__()
        self.encoder = EncoderCausal3D(
            in_channels=in_channels,
            out_channels=out_channels,
            eps=eps,
            dropout=dropout,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            num_groups=num_groups,
            time_compression_ratio=time_compression_ratio,
            spatial_compression_ratio=spatial_compression_ratio,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
        self.scaling_factor = 0.476986
    def forward(self, images):
        latents = self.encoder(images)
        latents = self.quant_conv(latents)
        # The encoder outputs mean and log-variance channels; keep only the mean and scale it.
        latents = latents[:, :16]
        latents = latents * self.scaling_factor
        return latents
    def build_1d_mask(self, length, left_bound, right_bound, border_width):
        # Linearly ramp the weight up over `border_width` positions at any edge that is not
        # a boundary of the full tensor, so overlapping tiles blend smoothly.
        x = torch.ones((length,))
        if not left_bound:
            x[:border_width] = (torch.arange(border_width) + 1) / border_width
        if not right_bound:
            x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
        return x

    def build_mask(self, data, is_bound, border_width):
        # Combine the per-axis ramps into a (1, 1, T, H, W) blending weight by taking the
        # minimum over the temporal, height, and width masks.
        _, _, T, H, W = data.shape
        t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
        h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
        w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])

        t = repeat(t, "T -> T H W", T=T, H=H, W=W)
        h = repeat(h, "H -> T H W", T=T, H=H, W=W)
        w = repeat(w, "W -> T H W", T=T, H=H, W=W)

        mask = torch.stack([t, h, w]).min(dim=0).values
        mask = rearrange(mask, "T H W -> 1 1 T H W")
        return mask
    def tile_forward(self, hidden_states, tile_size, tile_stride):
        B, C, T, H, W = hidden_states.shape
        size_t, size_h, size_w = tile_size
        stride_t, stride_h, stride_w = tile_stride

        # Split tasks: enumerate overlapping (T, H, W) tiles covering the whole video.
        tasks = []
        for t in range(0, T, stride_t):
            if (t - stride_t >= 0 and t - stride_t + size_t >= T): continue
            for h in range(0, H, stride_h):
                if (h - stride_h >= 0 and h - stride_h + size_h >= H): continue
                for w in range(0, W, stride_w):
                    if (w - stride_w >= 0 and w - stride_w + size_w >= W): continue
                    t_, h_, w_ = t + size_t, h + size_h, w + size_w
                    tasks.append((t, t_, h, h_, w, w_))

        # Run
        torch_dtype = self.quant_conv.weight.dtype
        data_device = hidden_states.device
        computation_device = self.quant_conv.weight.device

        # Accumulators in latent space: 4x temporal and 8x spatial compression.
        weight = torch.zeros((1, 1, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
        values = torch.zeros((B, 16, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)

        for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
            hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
            hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
            if t > 0:
                # Non-initial temporal tiles encode their first frame as a standalone causal
                # frame; that latent does not align with the full-sequence grid, so drop it.
                hidden_states_batch = hidden_states_batch[:, :, 1:]

            mask = self.build_mask(
                hidden_states_batch,
                is_bound=(t == 0, t_ >= T, h == 0, h_ >= H, w == 0, w_ >= W),
                border_width=((size_t - stride_t) // 4, (size_h - stride_h) // 8, (size_w - stride_w) // 8)
            ).to(dtype=torch_dtype, device=data_device)

            # Map the tile's pixel-space offsets to latent-space offsets.
            target_t = 0 if t == 0 else t // 4 + 1
            target_h = h // 8
            target_w = w // 8
            values[
                :,
                :,
                target_t: target_t + hidden_states_batch.shape[2],
                target_h: target_h + hidden_states_batch.shape[3],
                target_w: target_w + hidden_states_batch.shape[4],
            ] += hidden_states_batch * mask
            weight[
                :,
                :,
                target_t: target_t + hidden_states_batch.shape[2],
                target_h: target_h + hidden_states_batch.shape[3],
                target_w: target_w + hidden_states_batch.shape[4],
            ] += mask
        return values / weight
    def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)):
        latents = latents.to(self.quant_conv.weight.dtype)
        return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)

    @staticmethod
    def state_dict_converter():
        return HunyuanVideoVAEEncoderStateDictConverter()


class HunyuanVideoVAEEncoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {}
        for name in state_dict:
            if name.startswith('encoder.') or name.startswith('quant_conv.'):
                state_dict_[name] = state_dict[name]
        return state_dict_
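

# Minimal usage sketch (illustrative, not part of the original module). Because of the
# relative import above it should be run as a module, e.g. `python -m <package>.hunyuan_video_vae_encoder`.
# In practice, weights would first be loaded with load_state_dict using a dict converted by
# from_diffusers; with the randomly initialized weights below only the shapes are meaningful.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    encoder = HunyuanVideoVAEEncoder().to(dtype=dtype, device=device).eval()
    # A fake (B, C, T, H, W) video in [-1, 1]: 33 frames at 256x256, which fits in a single tile.
    video = torch.rand((1, 3, 33, 256, 256), dtype=dtype, device=device) * 2 - 1
    with torch.no_grad():
        latents = encoder.encode_video(video)
    # 4x temporal and 8x spatial compression: (1, 16, (33 - 1) // 4 + 1, 32, 32) = (1, 16, 9, 32, 32)
    print(latents.shape)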