# Copyright (c) 2023 HuggingFace Team # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates. # SPDX-License-Identifier: Apache License, Version 2.0 (the "License") # # This file has been modified by ByteDance Ltd. and/or its affiliates. on 1st June 2025 # # Original file was released under Apache License, Version 2.0 (the "License"), with the full license text # available at http://www.apache.org/licenses/LICENSE-2.0. # # This modified file is released under the same license. from contextlib import nullcontext from typing import Literal, Optional, Tuple, Union import diffusers import torch import torch.nn as nn import torch.nn.functional as F from diffusers.models.attention_processor import Attention, SpatialNorm from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution from diffusers.models.downsampling import Downsample2D from diffusers.models.lora import LoRACompatibleConv from diffusers.models.modeling_outputs import AutoencoderKLOutput from diffusers.models.resnet import ResnetBlock2D from diffusers.models.unets.unet_2d_blocks import DownEncoderBlock2D, UpDecoderBlock2D from diffusers.models.upsampling import Upsample2D from diffusers.utils import is_torch_version from diffusers.utils.accelerate_utils import apply_forward_hook from einops import rearrange from common.distributed.advanced import get_sequence_parallel_world_size from common.logger import get_logger from models.video_vae_v3.modules.causal_inflation_lib import ( InflatedCausalConv3d, causal_norm_wrapper, init_causal_conv3d, remove_head, ) from models.video_vae_v3.modules.context_parallel_lib import ( causal_conv_gather_outputs, causal_conv_slice_inputs, ) from models.video_vae_v3.modules.global_config import set_norm_limit from models.video_vae_v3.modules.types import ( CausalAutoencoderOutput, CausalDecoderOutput, CausalEncoderOutput, MemoryState, _inflation_mode_t, _memory_device_t, _receptive_field_t, ) logger = get_logger(__name__) # pylint: disable=invalid-name class Upsample3D(Upsample2D): """A 3D upsampling layer with an optional convolution.""" def __init__( self, *args, inflation_mode: _inflation_mode_t = "tail", temporal_up: bool = False, spatial_up: bool = True, slicing: bool = False, **kwargs, ): super().__init__(*args, **kwargs) conv = self.conv if self.name == "conv" else self.Conv2d_0 assert type(conv) is not nn.ConvTranspose2d # Note: lora_layer is not passed into constructor in the original implementation. # So we make a simplification. 
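        # The 2D conv inherited from Upsample2D is swapped below for a causal 3D conv built by
        # init_causal_conv3d, so temporal padding never reads ahead of the current frame;
        # `inflation_mode` selects how the pretrained 2D kernel is inflated to 3D
        # (see causal_inflation_lib).
        # With `interpolate` disabled, upsampling follows the MAGViT-v2 recipe: `upscale_conv` is a
        # 1x1x1 conv that expands channels by spatial_ratio**2 * temporal_ratio, and a rearrange in
        # forward() performs the depth-to-space(/time) shuffle. Because that conv is
        # identity-initialized with zero bias, it starts out as plain nearest-neighbor upsampling
        # and only later learns a better kernel.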
conv = init_causal_conv3d( self.channels, self.out_channels, 3, padding=1, inflation_mode=inflation_mode, ) self.temporal_up = temporal_up self.spatial_up = spatial_up self.temporal_ratio = 2 if temporal_up else 1 self.spatial_ratio = 2 if spatial_up else 1 self.slicing = slicing assert not self.interpolate # [Override] MAGViT v2 implementation if not self.interpolate: upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio self.upscale_conv = nn.Conv3d( self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 ) identity = ( torch.eye(self.channels) .repeat(upscale_ratio, 1) .reshape_as(self.upscale_conv.weight) ) self.upscale_conv.weight.data.copy_(identity) nn.init.zeros_(self.upscale_conv.bias) if self.name == "conv": self.conv = conv else: self.Conv2d_0 = conv def forward( self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, memory_state: MemoryState = MemoryState.DISABLED, **kwargs, ) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels if hasattr(self, "norm") and self.norm is not None: # [Overridden] change to causal norm. hidden_states = causal_norm_wrapper(self.norm, hidden_states) if self.use_conv_transpose: return self.conv(hidden_states) if self.slicing: split_size = hidden_states.size(2) // 2 hidden_states = list( hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2) ) else: hidden_states = [hidden_states] for i in range(len(hidden_states)): hidden_states[i] = self.upscale_conv(hidden_states[i]) hidden_states[i] = rearrange( hidden_states[i], "b (x y z c) f h w -> b c (f z) (h x) (w y)", x=self.spatial_ratio, y=self.spatial_ratio, z=self.temporal_ratio, ) # [Overridden] For causal temporal conv if self.temporal_up and memory_state != MemoryState.ACTIVE: hidden_states[0] = remove_head(hidden_states[0]) if not self.slicing: hidden_states = hidden_states[0] if self.use_conv: if self.name == "conv": hidden_states = self.conv(hidden_states, memory_state=memory_state) else: hidden_states = self.Conv2d_0(hidden_states, memory_state=memory_state) if not self.slicing: return hidden_states else: return torch.cat(hidden_states, dim=2) class Downsample3D(Downsample2D): """A 3D downsampling layer with an optional convolution.""" def __init__( self, *args, inflation_mode: _inflation_mode_t = "tail", spatial_down: bool = False, temporal_down: bool = False, **kwargs, ): super().__init__(*args, **kwargs) conv = self.conv self.temporal_down = temporal_down self.spatial_down = spatial_down self.temporal_ratio = 2 if temporal_down else 1 self.spatial_ratio = 2 if spatial_down else 1 self.temporal_kernel = 3 if temporal_down else 1 self.spatial_kernel = 3 if spatial_down else 1 if type(conv) in [nn.Conv2d, LoRACompatibleConv]: # Note: lora_layer is not passed into constructor in the original implementation. # So we make a simplification. 
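            # The inherited 2D conv (or LoRACompatibleConv) is replaced by a causal 3D conv whose
            # kernel/stride are chosen per axis: a downsampled axis gets kernel 3 and stride 2,
            # an untouched axis gets kernel 1 and stride 1. Temporal padding is handled causally
            # inside the inflated conv, while spatial padding keeps the Downsample2D behaviour
            # (including the asymmetric F.pad in forward() when padding == 0).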
conv = init_causal_conv3d( self.channels, self.out_channels, kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel), stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), padding=( 1 if self.temporal_down else 0, self.padding if self.spatial_down else 0, self.padding if self.spatial_down else 0, ), inflation_mode=inflation_mode, ) elif type(conv) is nn.AvgPool2d: assert self.channels == self.out_channels conv = nn.AvgPool3d( kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), ) else: raise NotImplementedError if self.name == "conv": self.Conv2d_0 = conv self.conv = conv else: self.conv = conv def forward( self, hidden_states: torch.FloatTensor, memory_state: MemoryState = MemoryState.DISABLED, **kwargs, ) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels if hasattr(self, "norm") and self.norm is not None: # [Overridden] change to causal norm. hidden_states = causal_norm_wrapper(self.norm, hidden_states) if self.use_conv and self.padding == 0 and self.spatial_down: pad = (0, 1, 0, 1) hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) assert hidden_states.shape[1] == self.channels hidden_states = self.conv(hidden_states, memory_state=memory_state) return hidden_states class ResnetBlock3D(ResnetBlock2D): def __init__( self, *args, inflation_mode: _inflation_mode_t = "tail", time_receptive_field: _receptive_field_t = "half", slicing: bool = False, **kwargs, ): super().__init__(*args, **kwargs) self.conv1 = init_causal_conv3d( self.in_channels, self.out_channels, kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3), stride=1, padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1), inflation_mode=inflation_mode, ) self.conv2 = init_causal_conv3d( self.out_channels, self.conv2.out_channels, kernel_size=3, stride=1, padding=1, inflation_mode=inflation_mode, ) if self.up: assert type(self.upsample) is Upsample2D self.upsample = Upsample3D( self.in_channels, use_conv=False, inflation_mode=inflation_mode, slicing=slicing, ) elif self.down: assert type(self.downsample) is Downsample2D self.downsample = Downsample3D( self.in_channels, use_conv=False, padding=1, name="op", inflation_mode=inflation_mode, ) if self.use_in_shortcut: self.conv_shortcut = init_causal_conv3d( self.in_channels, self.conv_shortcut.out_channels, kernel_size=1, stride=1, padding=0, bias=(self.conv_shortcut.bias is not None), inflation_mode=inflation_mode, ) def forward( self, input_tensor, temb, memory_state: MemoryState = MemoryState.DISABLED, **kwargs ): hidden_states = input_tensor hidden_states = causal_norm_wrapper(self.norm1, hidden_states) hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: # upsample_nearest_nhwc fails with large batch sizes. 
# see https://github.com/huggingface/diffusers/issues/984 if hidden_states.shape[0] >= 64: input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() input_tensor = self.upsample(input_tensor, memory_state=memory_state) hidden_states = self.upsample(hidden_states, memory_state=memory_state) elif self.downsample is not None: input_tensor = self.downsample(input_tensor, memory_state=memory_state) hidden_states = self.downsample(hidden_states, memory_state=memory_state) hidden_states = self.conv1(hidden_states, memory_state=memory_state) if self.time_emb_proj is not None: if not self.skip_time_act: temb = self.nonlinearity(temb) temb = self.time_emb_proj(temb)[:, :, None, None] if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb hidden_states = causal_norm_wrapper(self.norm2, hidden_states) if temb is not None and self.time_embedding_norm == "scale_shift": scale, shift = torch.chunk(temb, 2, dim=1) hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states, memory_state=memory_state) if self.conv_shortcut is not None: input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor return output_tensor class DownEncoderBlock3D(DownEncoderBlock2D): def __init__( self, in_channels: int, out_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor: float = 1.0, add_downsample: bool = True, downsample_padding: int = 1, inflation_mode: _inflation_mode_t = "tail", time_receptive_field: _receptive_field_t = "half", temporal_down: bool = True, spatial_down: bool = True, ): super().__init__( in_channels=in_channels, out_channels=out_channels, dropout=dropout, num_layers=num_layers, resnet_eps=resnet_eps, resnet_time_scale_shift=resnet_time_scale_shift, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_pre_norm=resnet_pre_norm, output_scale_factor=output_scale_factor, add_downsample=add_downsample, downsample_padding=downsample_padding, ) resnets = [] temporal_modules = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( # [Override] Replace module. ResnetBlock3D( in_channels=in_channels, out_channels=out_channels, temb_channels=None, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) ) temporal_modules.append(nn.Identity()) self.resnets = nn.ModuleList(resnets) self.temporal_modules = nn.ModuleList(temporal_modules) if add_downsample: self.downsamplers = nn.ModuleList( [ # [Override] Replace module. 
Downsample3D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op", temporal_down=temporal_down, spatial_down=spatial_down, inflation_mode=inflation_mode, ) ] ) else: self.downsamplers = None def forward( self, hidden_states: torch.FloatTensor, memory_state: MemoryState = MemoryState.DISABLED, **kwargs, ) -> torch.FloatTensor: for resnet, temporal in zip(self.resnets, self.temporal_modules): hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) hidden_states = temporal(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states, memory_state=memory_state) return hidden_states class UpDecoderBlock3D(UpDecoderBlock2D): def __init__( self, in_channels: int, out_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", # default, spatial resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor: float = 1.0, add_upsample: bool = True, temb_channels: Optional[int] = None, inflation_mode: _inflation_mode_t = "tail", time_receptive_field: _receptive_field_t = "half", temporal_up: bool = True, spatial_up: bool = True, slicing: bool = False, ): super().__init__( in_channels=in_channels, out_channels=out_channels, dropout=dropout, num_layers=num_layers, resnet_eps=resnet_eps, resnet_time_scale_shift=resnet_time_scale_shift, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_pre_norm=resnet_pre_norm, output_scale_factor=output_scale_factor, add_upsample=add_upsample, temb_channels=temb_channels, ) resnets = [] temporal_modules = [] for i in range(num_layers): input_channels = in_channels if i == 0 else out_channels resnets.append( # [Override] Replace module. 
                ResnetBlock3D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                    slicing=slicing,
                )
            )
            temporal_modules.append(nn.Identity())

        self.resnets = nn.ModuleList(resnets)
        self.temporal_modules = nn.ModuleList(temporal_modules)

        if add_upsample:
            # [Override] Replace module & use learnable upsample
            self.upsamplers = nn.ModuleList(
                [
                    Upsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        temporal_up=temporal_up,
                        spatial_up=spatial_up,
                        interpolate=False,
                        inflation_mode=inflation_mode,
                        slicing=slicing,
                    )
                ]
            )
        else:
            self.upsamplers = None

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        memory_state: MemoryState = MemoryState.DISABLED,
    ) -> torch.FloatTensor:
        for resnet, temporal in zip(self.resnets, self.temporal_modules):
            hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state)
            hidden_states = temporal(hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, memory_state=memory_state)

        return hidden_states


class UNetMidBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        # there is always at least one resnet
        resnets = [
            # [Override] Replace module.
            ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
        ]
        attentions = []

        if attention_head_dim is None:
            logger.warning(
                f"It is not recommended to pass `attention_head_dim=None`. "
                f"Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
) attention_head_dim = in_channels for _ in range(num_layers): if self.add_attention: attentions.append( Attention( in_channels, heads=in_channels // attention_head_dim, dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=( resnet_groups if resnet_time_scale_shift == "default" else None ), spatial_norm_dim=( temb_channels if resnet_time_scale_shift == "spatial" else None ), residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, ) ) else: attentions.append(None) resnets.append( ResnetBlock3D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) def forward(self, hidden_states, temb=None, memory_state: MemoryState = MemoryState.DISABLED): video_length, frame_height, frame_width = hidden_states.size()[-3:] hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") hidden_states = attn(hidden_states, temb=temb) hidden_states = rearrange( hidden_states, "(b f) c h w -> b c f h w", f=video_length ) hidden_states = resnet(hidden_states, temb, memory_state=memory_state) return hidden_states class Encoder3D(nn.Module): r""" [Override] override most logics to support extra condition input and causal conv The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. Args: in_channels (`int`, *optional*, defaults to 3): The number of input channels. out_channels (`int`, *optional*, defaults to 3): The number of output channels. down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available options. block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): The number of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups for normalization. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. See `~diffusers.models.activations.get_activation` for available options. double_z (`bool`, *optional*, defaults to `True`): Whether to double the number of output channels for the last block. """ def __init__( self, in_channels: int = 3, out_channels: int = 3, down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",), block_out_channels: Tuple[int, ...] 
= (64,), layers_per_block: int = 2, norm_num_groups: int = 32, act_fn: str = "silu", double_z: bool = True, mid_block_add_attention=True, # [Override] add extra_cond_dim, temporal down num temporal_down_num: int = 2, extra_cond_dim: int = None, gradient_checkpoint: bool = False, inflation_mode: _inflation_mode_t = "tail", time_receptive_field: _receptive_field_t = "half", ): super().__init__() self.layers_per_block = layers_per_block self.temporal_down_num = temporal_down_num self.conv_in = init_causal_conv3d( in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1, inflation_mode=inflation_mode, ) self.mid_block = None self.down_blocks = nn.ModuleList([]) self.extra_cond_dim = extra_cond_dim self.conv_extra_cond = nn.ModuleList([]) # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 # [Override] to support temporal down block design is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 # Note: take the last ones assert down_block_type == "DownEncoderBlock3D" down_block = DownEncoderBlock3D( num_layers=self.layers_per_block, in_channels=input_channel, out_channels=output_channel, add_downsample=not is_final_block, resnet_eps=1e-6, downsample_padding=0, # Note: Don't know why set it as 0 resnet_act_fn=act_fn, resnet_groups=norm_num_groups, temporal_down=is_temporal_down_block, spatial_down=True, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) self.down_blocks.append(down_block) def zero_module(module): # Zero out the parameters of a module and return it. for p in module.parameters(): p.detach().zero_() return module self.conv_extra_cond.append( zero_module( nn.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0) ) if self.extra_cond_dim is not None and self.extra_cond_dim > 0 else None ) # mid self.mid_block = UNetMidBlock3D( in_channels=block_out_channels[-1], resnet_eps=1e-6, resnet_act_fn=act_fn, output_scale_factor=1, resnet_time_scale_shift="default", attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=None, add_attention=mid_block_add_attention, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) # out self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 ) self.conv_act = nn.SiLU() conv_out_channels = 2 * out_channels if double_z else out_channels self.conv_out = init_causal_conv3d( block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode ) self.gradient_checkpointing = gradient_checkpoint def forward( self, sample: torch.FloatTensor, extra_cond=None, memory_state: MemoryState = MemoryState.DISABLED, ) -> torch.FloatTensor: r"""The forward method of the `Encoder` class.""" sample = self.conv_in(sample, memory_state=memory_state) if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward # down # [Override] add extra block and extra cond for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): sample = torch.utils.checkpoint.checkpoint( create_custom_forward(down_block), sample, memory_state, use_reentrant=False ) if extra_block is not None: sample = sample + F.interpolate(extra_block(extra_cond), size=sample.shape[2:]) # middle sample = self.mid_block(sample, 
                    memory_state=memory_state)
            # sample = torch.utils.checkpoint.checkpoint(
            #     create_custom_forward(self.mid_block), sample, use_reentrant=False
            # )

        else:
            # down
            # [Override] add extra block and extra cond
            for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond):
                sample = down_block(sample, memory_state=memory_state)
                if extra_block is not None:
                    sample = sample + F.interpolate(extra_block(extra_cond), size=sample.shape[2:])

            # middle
            sample = self.mid_block(sample, memory_state=memory_state)

        # post-process
        sample = causal_norm_wrapper(self.conv_norm_out, sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, memory_state=memory_state)

        return sample


class Decoder3D(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an
    output sample.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock3D",)`):
            The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for
            available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for
            available options.
        norm_type (`str`, *optional*, defaults to `"group"`):
            The normalization type to use. Can be either `"group"` or `"spatial"`.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",),
        block_out_channels: Tuple[int, ...]
= (64,), layers_per_block: int = 2, norm_num_groups: int = 32, act_fn: str = "silu", norm_type: str = "group", # group, spatial mid_block_add_attention=True, # [Override] add temporal up block inflation_mode: _inflation_mode_t = "tail", time_receptive_field: _receptive_field_t = "half", temporal_up_num: int = 2, slicing_up_num: int = 0, gradient_checkpoint: bool = False, ): super().__init__() self.layers_per_block = layers_per_block self.temporal_up_num = temporal_up_num self.conv_in = init_causal_conv3d( in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1, inflation_mode=inflation_mode, ) self.mid_block = None self.up_blocks = nn.ModuleList([]) temb_channels = in_channels if norm_type == "spatial" else None # mid self.mid_block = UNetMidBlock3D( in_channels=block_out_channels[-1], resnet_eps=1e-6, resnet_act_fn=act_fn, output_scale_factor=1, resnet_time_scale_shift="default" if norm_type == "group" else norm_type, attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=temb_channels, add_attention=mid_block_add_attention, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) # up reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] print(f"slicing_up_num: {slicing_up_num}") for i, up_block_type in enumerate(up_block_types): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 is_temporal_up_block = i < self.temporal_up_num is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num # Note: Keep symmetric assert up_block_type == "UpDecoderBlock3D" up_block = UpDecoderBlock3D( num_layers=self.layers_per_block + 1, in_channels=prev_output_channel, out_channels=output_channel, add_upsample=not is_final_block, resnet_eps=1e-6, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, resnet_time_scale_shift=norm_type, temb_channels=temb_channels, temporal_up=is_temporal_up_block, slicing=is_slicing_up_block, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out if norm_type == "spatial": self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) else: self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 ) self.conv_act = nn.SiLU() self.conv_out = init_causal_conv3d( block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode ) self.gradient_checkpointing = gradient_checkpoint # Note: Just copy from Decoder. 
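    # forward() mirrors diffusers' Decoder.forward, including the gradient-checkpointing branches
    # (use_reentrant=False on torch >= 1.11), but `memory_state` is threaded through the mid/up
    # blocks, presumably so the causal convs can carry state across temporal slices, and the
    # output GroupNorm is applied through causal_norm_wrapper.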
def forward( self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None, memory_state: MemoryState = MemoryState.DISABLED, ) -> torch.FloatTensor: r"""The forward method of the `Decoder` class.""" sample = self.conv_in(sample, memory_state=memory_state) upscale_dtype = next(iter(self.up_blocks.parameters())).dtype if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward if is_torch_version(">=", "1.11.0"): sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: sample = torch.utils.checkpoint.checkpoint( create_custom_forward(up_block), sample, latent_embeds, memory_state, use_reentrant=False, ) else: # middle sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: sample = torch.utils.checkpoint.checkpoint( create_custom_forward(up_block), sample, latent_embeds, memory_state ) else: # middle sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: sample = up_block(sample, latent_embeds, memory_state=memory_state) # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) sample = self.conv_out(sample, memory_state=memory_state) return sample class AutoencoderKL(diffusers.AutoencoderKL): """ We simply inherit the model code from diffusers """ def __init__(self, attention: bool = True, *args, **kwargs): super().__init__(*args, **kwargs) # A hacky way to remove attention. if not attention: self.encoder.mid_block.attentions = torch.nn.ModuleList([None]) self.decoder.mid_block.attentions = torch.nn.ModuleList([None]) def load_state_dict(self, state_dict, strict=True): # Newer version of diffusers changed the model keys, # causing incompatibility with old checkpoints. # They provided a method for conversion. We call conversion before loading state_dict. 
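        # `_convert_deprecated_attention_blocks` is looked up with getattr because it is a private
        # diffusers helper that may be renamed or removed in future releases; when available, it
        # rewrites legacy attention keys (e.g. `query`/`key`/`value`) into the current
        # `to_q`/`to_k`/`to_v` layout before the regular load.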
convert_deprecated_attention_blocks = getattr( self, "_convert_deprecated_attention_blocks", None ) if callable(convert_deprecated_attention_blocks): convert_deprecated_attention_blocks(state_dict) return super().load_state_dict(state_dict, strict) class VideoAutoencoderKL(diffusers.AutoencoderKL): """ We simply inherit the model code from diffusers """ def __init__( self, in_channels: int = 3, out_channels: int = 3, down_block_types: Tuple[str] = ("DownEncoderBlock3D",), up_block_types: Tuple[str] = ("UpDecoderBlock3D",), block_out_channels: Tuple[int] = (64,), layers_per_block: int = 1, act_fn: str = "silu", latent_channels: int = 4, norm_num_groups: int = 32, sample_size: int = 32, scaling_factor: float = 0.18215, force_upcast: float = True, attention: bool = True, temporal_scale_num: int = 2, slicing_up_num: int = 0, gradient_checkpoint: bool = False, inflation_mode: _inflation_mode_t = "tail", time_receptive_field: _receptive_field_t = "full", slicing_sample_min_size: int = 32, use_quant_conv: bool = True, use_post_quant_conv: bool = True, *args, **kwargs, ): extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None self.slicing_sample_min_size = slicing_sample_min_size self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) super().__init__( in_channels=in_channels, out_channels=out_channels, # [Override] make sure it can be normally initialized down_block_types=tuple( [down_block_type.replace("3D", "2D") for down_block_type in down_block_types] ), up_block_types=tuple( [up_block_type.replace("3D", "2D") for up_block_type in up_block_types] ), block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, latent_channels=latent_channels, norm_num_groups=norm_num_groups, sample_size=sample_size, scaling_factor=scaling_factor, force_upcast=force_upcast, *args, **kwargs, ) # pass init params to Encoder self.encoder = Encoder3D( in_channels=in_channels, out_channels=latent_channels, down_block_types=down_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=True, extra_cond_dim=extra_cond_dim, # [Override] add temporal_down_num parameter temporal_down_num=temporal_scale_num, gradient_checkpoint=gradient_checkpoint, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) # pass init params to Decoder self.decoder = Decoder3D( in_channels=latent_channels, out_channels=out_channels, up_block_types=up_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, norm_num_groups=norm_num_groups, act_fn=act_fn, # [Override] add temporal_up_num parameter temporal_up_num=temporal_scale_num, slicing_up_num=slicing_up_num, gradient_checkpoint=gradient_checkpoint, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) self.quant_conv = ( init_causal_conv3d( in_channels=2 * latent_channels, out_channels=2 * latent_channels, kernel_size=1, inflation_mode=inflation_mode, ) if use_quant_conv else None ) self.post_quant_conv = ( init_causal_conv3d( in_channels=latent_channels, out_channels=latent_channels, kernel_size=1, inflation_mode=inflation_mode, ) if use_post_quant_conv else None ) # A hacky way to remove attention. 
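        # Setting the entries to `None` (rather than emptying the list) keeps one placeholder per
        # layer, so the attentions/resnets zip in the mid-block forward stays aligned;
        # UNetMidBlock3D.forward explicitly skips `None` attention entries.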
if not attention: self.encoder.mid_block.attentions = torch.nn.ModuleList([None]) self.decoder.mid_block.attentions = torch.nn.ModuleList([None]) @apply_forward_hook def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: h = self.slicing_encode(x) posterior = DiagonalGaussianDistribution(h) if not return_dict: return (posterior,) return AutoencoderKLOutput(latent_dist=posterior) @apply_forward_hook def decode( self, z: torch.Tensor, return_dict: bool = True ) -> Union[DecoderOutput, torch.Tensor]: decoded = self.slicing_decode(z) if not return_dict: return (decoded,) return DecoderOutput(sample=decoded) def _encode( self, x: torch.Tensor, memory_state: MemoryState = MemoryState.DISABLED ) -> torch.Tensor: _x = x.to(self.device) _x = causal_conv_slice_inputs(_x, self.slicing_sample_min_size, memory_state=memory_state) h = self.encoder(_x, memory_state=memory_state) if self.quant_conv is not None: output = self.quant_conv(h, memory_state=memory_state) else: output = h output = causal_conv_gather_outputs(output) return output.to(x.device) def _decode( self, z: torch.Tensor, memory_state: MemoryState = MemoryState.DISABLED ) -> torch.Tensor: _z = z.to(self.device) _z = causal_conv_slice_inputs(_z, self.slicing_latent_min_size, memory_state=memory_state) if self.post_quant_conv is not None: _z = self.post_quant_conv(_z, memory_state=memory_state) output = self.decoder(_z, memory_state=memory_state) output = causal_conv_gather_outputs(output) return output.to(z.device) def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: sp_size = get_sequence_parallel_world_size() if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2) encoded_slices = [ self._encode( torch.cat((x[:, :, :1], x_slices[0]), dim=2), memory_state=MemoryState.INITIALIZING, ) ] for x_idx in range(1, len(x_slices)): encoded_slices.append( self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE) ) return torch.cat(encoded_slices, dim=2) else: return self._encode(x) def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: sp_size = get_sequence_parallel_world_size() if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) decoded_slices = [ self._decode( torch.cat((z[:, :, :1], z_slices[0]), dim=2), memory_state=MemoryState.INITIALIZING, ) ] for z_idx in range(1, len(z_slices)): decoded_slices.append( self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE) ) return torch.cat(decoded_slices, dim=2) else: return self._decode(z) def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor: raise NotImplementedError def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: raise NotImplementedError def forward( self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs ): # x: [b c t h w] if mode == "encode": h = self.encode(x) return h.latent_dist elif mode == "decode": h = self.decode(x) return h.sample else: h = self.encode(x) h = self.decode(h.latent_dist.mode()) return h.sample def load_state_dict(self, state_dict, strict=False): # Newer version of diffusers changed the model keys, # causing incompatibility with old checkpoints. # They provided a method for conversion. # We call conversion before loading state_dict. 
convert_deprecated_attention_blocks = getattr( self, "_convert_deprecated_attention_blocks", None ) if callable(convert_deprecated_attention_blocks): convert_deprecated_attention_blocks(state_dict) return super().load_state_dict(state_dict, strict) class VideoAutoencoderKLWrapper(VideoAutoencoderKL): def __init__( self, *args, spatial_downsample_factor: int, temporal_downsample_factor: int, freeze_encoder: bool, **kwargs, ): self.spatial_downsample_factor = spatial_downsample_factor self.temporal_downsample_factor = temporal_downsample_factor self.freeze_encoder = freeze_encoder super().__init__(*args, **kwargs) def forward(self, x: torch.FloatTensor) -> CausalAutoencoderOutput: with torch.no_grad() if self.freeze_encoder else nullcontext(): z, p = self.encode(x) x = self.decode(z).sample return CausalAutoencoderOutput(x, z, p) def encode(self, x: torch.FloatTensor) -> CausalEncoderOutput: if x.ndim == 4: x = x.unsqueeze(2) p = super().encode(x).latent_dist z = p.sample().squeeze(2) return CausalEncoderOutput(z, p) def decode(self, z: torch.FloatTensor) -> CausalDecoderOutput: if z.ndim == 4: z = z.unsqueeze(2) x = super().decode(z).sample.squeeze(2) return CausalDecoderOutput(x) def preprocess(self, x: torch.Tensor): # x should in [B, C, T, H, W], [B, C, H, W] assert x.ndim == 4 or x.size(2) % 4 == 1 return x def postprocess(self, x: torch.Tensor): # x should in [B, C, T, H, W], [B, C, H, W] return x def set_causal_slicing( self, *, split_size: Optional[int], memory_device: _memory_device_t, ): assert ( split_size is None or memory_device is not None ), "if split_size is set, memory_device must not be None." if split_size is not None: self.enable_slicing() self.slicing_sample_min_size = split_size self.slicing_latent_min_size = split_size // self.temporal_downsample_factor else: self.disable_slicing() for module in self.modules(): if isinstance(module, InflatedCausalConv3d): module.set_memory_device(memory_device) def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]): set_norm_limit(norm_max_mem) for m in self.modules(): if isinstance(m, InflatedCausalConv3d): m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
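

# --------------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not a shipped configuration). The block/channel
# numbers below are assumptions chosen for demonstration, and the run is assumed to be a single
# process, i.e. get_sequence_parallel_world_size() == 1 with no sequence parallelism set up.
# --------------------------------------------------------------------------------------------
if __name__ == "__main__":
    vae = VideoAutoencoderKLWrapper(
        spatial_downsample_factor=4,   # assumed: two spatial downsampling stages below
        temporal_downsample_factor=4,  # assumed: temporal_scale_num=2 -> 2 ** 2
        freeze_encoder=False,
        down_block_types=("DownEncoderBlock3D",) * 3,
        up_block_types=("UpDecoderBlock3D",) * 3,
        block_out_channels=(32, 64, 128),
        latent_channels=4,
        temporal_scale_num=2,
    )
    # Optional: process long clips slice by slice. Valid `memory_device` values are defined by
    # `_memory_device_t` in modules/types and are not assumed here.
    # vae.set_causal_slicing(split_size=16, memory_device=...)

    video = torch.randn(1, 3, 17, 32, 32)  # [B, C, T, H, W] with T = 4 * k + 1
    with torch.no_grad():
        z, posterior = vae.encode(video)   # CausalEncoderOutput unpacks to (latent, posterior)
        recon = vae.decode(z)              # CausalDecoderOutput wrapping the reconstruction
    print(video.shape, z.shape)            # expect roughly T' = (T - 1) / 4 + 1 and H/4, W/4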