# Copyright (c) 2023 HuggingFace Team # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates. # SPDX-License-Identifier: Apache License, Version 2.0 (the "License") # # This file has been modified by ByteDance Ltd. and/or its affiliates. on 1st June 2025 # # Original file was released under Apache License, Version 2.0 (the "License"), with the full license text # available at http://www.apache.org/licenses/LICENSE-2.0. # # This modified file is released under the same license. from contextlib import nullcontext from typing import Optional, Tuple, Literal, Callable, Union import torch import torch.nn as nn import torch.nn.functional as F from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution from einops import rearrange from common.distributed.advanced import get_sequence_parallel_world_size from common.logger import get_logger from models.video_vae_v3.modules.causal_inflation_lib import ( InflatedCausalConv3d, causal_norm_wrapper, init_causal_conv3d, remove_head, ) from models.video_vae_v3.modules.context_parallel_lib import ( causal_conv_gather_outputs, causal_conv_slice_inputs, ) from models.video_vae_v3.modules.global_config import set_norm_limit from models.video_vae_v3.modules.types import ( CausalAutoencoderOutput, CausalDecoderOutput, CausalEncoderOutput, MemoryState, _inflation_mode_t, _memory_device_t, _receptive_field_t, _selective_checkpointing_t, ) logger = get_logger(__name__) # pylint: disable=invalid-name # Fake func, no checkpointing is required for inference def gradient_checkpointing(module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs): return module(*args, **kwargs) class ResnetBlock2D(nn.Module): r""" A Resnet block. Parameters: in_channels (`int`): The number of channels in the input. out_channels (`int`, *optional*, default to be `None`): The number of output channels for the first conv2d layer. If None, same as `in_channels`. dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. """ def __init__( self, *, in_channels: int, out_channels: Optional[int] = None, dropout: float = 0.0 ): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.nonlinearity = nn.SiLU() self.norm1 = torch.nn.GroupNorm( num_groups=32, num_channels=in_channels, eps=1e-6, affine=True ) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.norm2 = torch.nn.GroupNorm( num_groups=32, num_channels=out_channels, eps=1e-6, affine=True ) self.dropout = torch.nn.Dropout(dropout) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) self.use_in_shortcut = self.in_channels != out_channels self.conv_shortcut = None if self.use_in_shortcut: self.conv_shortcut = nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0 ) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: hidden = input_tensor hidden = self.norm1(hidden) hidden = self.nonlinearity(hidden) hidden = self.conv1(hidden) hidden = self.norm2(hidden) hidden = self.nonlinearity(hidden) hidden = self.dropout(hidden) hidden = self.conv2(hidden) if self.conv_shortcut is not None: input_tensor = self.conv_shortcut(input_tensor) output_tensor = input_tensor + hidden return output_tensor class Upsample3D(nn.Module): """A 3D upsampling layer.""" def __init__( self, channels: int, inflation_mode: _inflation_mode_t = "tail", temporal_up: bool = False, spatial_up: bool = True, slicing: bool = False, ): super().__init__() self.channels = channels self.conv = init_causal_conv3d( self.channels, self.channels, kernel_size=3, padding=1, inflation_mode=inflation_mode ) self.temporal_up = temporal_up self.spatial_up = spatial_up self.temporal_ratio = 2 if temporal_up else 1 self.spatial_ratio = 2 if spatial_up else 1 self.slicing = slicing upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio self.upscale_conv = nn.Conv3d( self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 ) identity = ( torch.eye(self.channels).repeat(upscale_ratio, 1).reshape_as(self.upscale_conv.weight) ) self.upscale_conv.weight.data.copy_(identity) nn.init.zeros_(self.upscale_conv.bias) self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, memory_state: MemoryState, ) -> torch.FloatTensor: return gradient_checkpointing( self.custom_forward, hidden_states, memory_state, enabled=self.training and self.gradient_checkpointing, ) def custom_forward( self, hidden_states: torch.FloatTensor, memory_state: MemoryState, ) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels if self.slicing: split_size = hidden_states.size(2) // 2 hidden_states = list( hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2) ) else: hidden_states = [hidden_states] for i in range(len(hidden_states)): hidden_states[i] = self.upscale_conv(hidden_states[i]) hidden_states[i] = rearrange( hidden_states[i], "b (x y z c) f h w -> b c (f z) (h x) (w y)", x=self.spatial_ratio, y=self.spatial_ratio, z=self.temporal_ratio, ) # [Overridden] For causal temporal conv if self.temporal_up and memory_state != MemoryState.ACTIVE: hidden_states[0] = remove_head(hidden_states[0]) if self.slicing: hidden_states = self.conv(hidden_states, memory_state=memory_state) return torch.cat(hidden_states, dim=2) else: return self.conv(hidden_states[0], memory_state=memory_state) class Downsample3D(nn.Module): """A 3D downsampling layer.""" def __init__( self, channels: int, inflation_mode: _inflation_mode_t = "tail", temporal_down: bool = False, spatial_down: bool = True, ): super().__init__() self.channels = channels self.temporal_down = temporal_down self.spatial_down = spatial_down self.temporal_ratio = 2 if temporal_down else 1 self.spatial_ratio = 2 if spatial_down else 1 self.temporal_kernel = 3 if temporal_down else 1 self.spatial_kernel = 3 if spatial_down else 1 self.conv = init_causal_conv3d( self.channels, self.channels, kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel), stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), padding=((1 if self.temporal_down else 0), 0, 0), inflation_mode=inflation_mode, ) self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, memory_state: MemoryState, ) -> torch.FloatTensor: return gradient_checkpointing( self.custom_forward, hidden_states, memory_state, enabled=self.training and self.gradient_checkpointing, ) def custom_forward( self, hidden_states: torch.FloatTensor, memory_state: MemoryState, ) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels if self.spatial_down: hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0) hidden_states = self.conv(hidden_states, memory_state=memory_state) return hidden_states class ResnetBlock3D(ResnetBlock2D): def __init__( self, *args, inflation_mode: _inflation_mode_t = "tail", time_receptive_field: _receptive_field_t = "half", **kwargs, ): super().__init__(*args, **kwargs) self.conv1 = init_causal_conv3d( self.in_channels, self.out_channels, kernel_size=3, stride=1, padding=1, inflation_mode=inflation_mode, ) self.conv2 = init_causal_conv3d( self.out_channels, self.out_channels, kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3), stride=1, padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1), inflation_mode=inflation_mode, ) if self.use_in_shortcut: self.conv_shortcut = init_causal_conv3d( self.in_channels, self.out_channels, kernel_size=1, stride=1, padding=0, bias=(self.conv_shortcut.bias is not None), inflation_mode=inflation_mode, ) self.gradient_checkpointing = False def forward(self, input_tensor: torch.Tensor, memory_state: MemoryState = MemoryState.UNSET): return gradient_checkpointing( self.custom_forward, input_tensor, memory_state, enabled=self.training and self.gradient_checkpointing, ) def custom_forward( self, input_tensor: torch.Tensor, memory_state: MemoryState = MemoryState.UNSET ): assert memory_state != MemoryState.UNSET hidden_states = input_tensor hidden_states = causal_norm_wrapper(self.norm1, hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv1(hidden_states, memory_state=memory_state) hidden_states = causal_norm_wrapper(self.norm2, hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states, memory_state=memory_state) if self.conv_shortcut is not None: input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state) output_tensor = input_tensor + hidden_states return output_tensor class DownEncoderBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, dropout: float = 0.0, num_layers: int = 1, add_downsample: bool = True, inflation_mode: _inflation_mode_t = "tail", time_receptive_field: _receptive_field_t = "half", temporal_down: bool = True, spatial_down: bool = True, ): super().__init__() resnets = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock3D( in_channels=in_channels, out_channels=out_channels, dropout=dropout, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) ) self.resnets = nn.ModuleList(resnets) self.downsamplers = None if add_downsample: # Todo: Refactor this line before V5 Image VAE Training. self.downsamplers = nn.ModuleList( [ Downsample3D( channels=out_channels, inflation_mode=inflation_mode, temporal_down=temporal_down, spatial_down=spatial_down, ) ] ) def forward( self, hidden_states: torch.FloatTensor, memory_state: MemoryState ) -> torch.FloatTensor: for resnet in self.resnets: hidden_states = resnet(hidden_states, memory_state=memory_state) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states, memory_state=memory_state) return hidden_states class UpDecoderBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, dropout: float = 0.0, num_layers: int = 1, add_upsample: bool = True, inflation_mode: _inflation_mode_t = "tail", time_receptive_field: _receptive_field_t = "half", temporal_up: bool = True, spatial_up: bool = True, slicing: bool = False, ): super().__init__() resnets = [] for i in range(num_layers): input_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock3D( in_channels=input_channels, out_channels=out_channels, dropout=dropout, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) ) self.resnets = nn.ModuleList(resnets) self.upsamplers = None # Todo: Refactor this line before V5 Image VAE Training. if add_upsample: self.upsamplers = nn.ModuleList( [ Upsample3D( channels=out_channels, inflation_mode=inflation_mode, temporal_up=temporal_up, spatial_up=spatial_up, slicing=slicing, ) ] ) def forward( self, hidden_states: torch.FloatTensor, memory_state: MemoryState ) -> torch.FloatTensor: for resnet in self.resnets: hidden_states = resnet(hidden_states, memory_state=memory_state) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, memory_state=memory_state) return hidden_states class UNetMidBlock3D(nn.Module): def __init__( self, channels: int, dropout: float = 0.0, inflation_mode: _inflation_mode_t = "tail", time_receptive_field: _receptive_field_t = "half", ): super().__init__() self.resnets = nn.ModuleList( [ ResnetBlock3D( in_channels=channels, out_channels=channels, dropout=dropout, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ), ResnetBlock3D( in_channels=channels, out_channels=channels, dropout=dropout, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ), ] ) def forward(self, hidden_states: torch.Tensor, memory_state: MemoryState): for resnet in self.resnets: hidden_states = resnet(hidden_states, memory_state) return hidden_states class Encoder3D(nn.Module): r""" The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. """ def __init__( self, in_channels: int = 3, out_channels: int = 3, block_out_channels: Tuple[int, ...] = (64,), layers_per_block: int = 2, double_z: bool = True, temporal_down_num: int = 2, inflation_mode: _inflation_mode_t = "tail", time_receptive_field: _receptive_field_t = "half", selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), ): super().__init__() self.layers_per_block = layers_per_block self.temporal_down_num = temporal_down_num self.conv_in = init_causal_conv3d( in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1, inflation_mode=inflation_mode, ) self.down_blocks = nn.ModuleList([]) # down output_channel = block_out_channels[0] for i in range(len(block_out_channels)): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 # Note: take the last one down_block = DownEncoderBlock3D( num_layers=self.layers_per_block, in_channels=input_channel, out_channels=output_channel, add_downsample=not is_final_block, temporal_down=is_temporal_down_block, spatial_down=True, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) self.down_blocks.append(down_block) # mid self.mid_block = UNetMidBlock3D( channels=block_out_channels[-1], inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) # out self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[-1], num_groups=32, eps=1e-6 ) self.conv_act = nn.SiLU() conv_out_channels = 2 * out_channels if double_z else out_channels self.conv_out = init_causal_conv3d( block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode ) assert len(selective_checkpointing) == len(self.down_blocks) self.set_gradient_checkpointing(selective_checkpointing) def set_gradient_checkpointing(self, checkpointing_types): gradient_checkpointing = [] for down_block, sac_type in zip(self.down_blocks, checkpointing_types): if sac_type == "coarse": gradient_checkpointing.append(True) elif sac_type == "fine": for n, m in down_block.named_modules(): if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = True logger.debug(f"set gradient_checkpointing: {n}") gradient_checkpointing.append(False) else: gradient_checkpointing.append(False) self.gradient_checkpointing = gradient_checkpointing logger.info(f"[Encoder3D] gradient_checkpointing: {checkpointing_types}") def forward(self, sample: torch.FloatTensor, memory_state: MemoryState) -> torch.FloatTensor: r"""The forward method of the `Encoder` class.""" sample = self.conv_in(sample, memory_state=memory_state) # down for down_block, sac in zip(self.down_blocks, self.gradient_checkpointing): sample = gradient_checkpointing( down_block, sample, memory_state=memory_state, enabled=self.training and sac, ) # middle sample = self.mid_block(sample, memory_state=memory_state) # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) sample = self.conv_out(sample, memory_state=memory_state) return sample class Decoder3D(nn.Module): r""" The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. """ def __init__( self, in_channels: int = 3, out_channels: int = 3, block_out_channels: Tuple[int, ...] = (64,), layers_per_block: int = 2, inflation_mode: _inflation_mode_t = "tail", time_receptive_field: _receptive_field_t = "half", temporal_up_num: int = 2, slicing_up_num: int = 0, selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), ): super().__init__() self.layers_per_block = layers_per_block self.temporal_up_num = temporal_up_num self.conv_in = init_causal_conv3d( in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1, inflation_mode=inflation_mode, ) self.up_blocks = nn.ModuleList([]) # mid self.mid_block = UNetMidBlock3D( channels=block_out_channels[-1], inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) # up reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] for i in range(len(reversed_block_out_channels)): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 is_temporal_up_block = i < self.temporal_up_num is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num # Note: Keep symmetric up_block = UpDecoderBlock3D( num_layers=self.layers_per_block + 1, in_channels=prev_output_channel, out_channels=output_channel, add_upsample=not is_final_block, temporal_up=is_temporal_up_block, slicing=is_slicing_up_block, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) self.up_blocks.append(up_block) # out self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=32, eps=1e-6 ) self.conv_act = nn.SiLU() self.conv_out = init_causal_conv3d( block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode ) assert len(selective_checkpointing) == len(self.up_blocks) self.set_gradient_checkpointing(selective_checkpointing) def set_gradient_checkpointing(self, checkpointing_types): gradient_checkpointing = [] for up_block, sac_type in zip(self.up_blocks, checkpointing_types): if sac_type == "coarse": gradient_checkpointing.append(True) elif sac_type == "fine": for n, m in up_block.named_modules(): if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = True logger.debug(f"set gradient_checkpointing: {n}") gradient_checkpointing.append(False) else: gradient_checkpointing.append(False) self.gradient_checkpointing = gradient_checkpointing logger.info(f"[Decoder3D] gradient_checkpointing: {checkpointing_types}") def forward(self, sample: torch.FloatTensor, memory_state: MemoryState) -> torch.FloatTensor: r"""The forward method of the `Decoder` class.""" sample = self.conv_in(sample, memory_state=memory_state) # middle sample = self.mid_block(sample, memory_state=memory_state) # up for up_block, sac in zip(self.up_blocks, self.gradient_checkpointing): sample = gradient_checkpointing( up_block, sample, memory_state=memory_state, enabled=self.training and sac, ) # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) sample = self.conv_out(sample, memory_state=memory_state) return sample class VideoAutoencoderKL(nn.Module): def __init__( self, in_channels: int = 3, out_channels: int = 3, block_out_channels: Tuple[int] = (64,), layers_per_block: int = 1, latent_channels: int = 4, use_quant_conv: bool = True, use_post_quant_conv: bool = True, enc_selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), dec_selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), temporal_scale_num: int = 3, slicing_up_num: int = 0, inflation_mode: _inflation_mode_t = "tail", time_receptive_field: _receptive_field_t = "half", slicing_sample_min_size: int = None, spatial_downsample_factor: int = 16, temporal_downsample_factor: int = 8, freeze_encoder: bool = False, ): super().__init__() self.spatial_downsample_factor = spatial_downsample_factor self.temporal_downsample_factor = temporal_downsample_factor self.freeze_encoder = freeze_encoder if slicing_sample_min_size is None: slicing_sample_min_size = temporal_downsample_factor self.slicing_sample_min_size = slicing_sample_min_size self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) # pass init params to Encoder self.encoder = Encoder3D( in_channels=in_channels, out_channels=latent_channels, block_out_channels=block_out_channels, layers_per_block=layers_per_block, double_z=True, temporal_down_num=temporal_scale_num, selective_checkpointing=enc_selective_checkpointing, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) # pass init params to Decoder self.decoder = Decoder3D( in_channels=latent_channels, out_channels=out_channels, block_out_channels=block_out_channels, layers_per_block=layers_per_block, # [Override] add temporal_up_num parameter temporal_up_num=temporal_scale_num, slicing_up_num=slicing_up_num, selective_checkpointing=dec_selective_checkpointing, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) self.quant_conv = ( init_causal_conv3d( in_channels=2 * latent_channels, out_channels=2 * latent_channels, kernel_size=1, inflation_mode=inflation_mode, ) if use_quant_conv else None ) self.post_quant_conv = ( init_causal_conv3d( in_channels=latent_channels, out_channels=latent_channels, kernel_size=1, inflation_mode=inflation_mode, ) if use_post_quant_conv else None ) self.use_slicing = False def enable_slicing(self): self.use_slicing = True def disable_slicing(self): self.use_slicing = False def encode(self, x: torch.FloatTensor) -> CausalEncoderOutput: if x.ndim == 4: x = x.unsqueeze(2) h = self.slicing_encode(x) p = DiagonalGaussianDistribution(h) z = p.sample() return CausalEncoderOutput(z, p) def decode(self, z: torch.FloatTensor) -> CausalDecoderOutput: if z.ndim == 4: z = z.unsqueeze(2) x = self.slicing_decode(z) return CausalDecoderOutput(x) def _encode(self, x: torch.Tensor, memory_state: MemoryState) -> torch.Tensor: x = causal_conv_slice_inputs(x, self.slicing_sample_min_size, memory_state=memory_state) h = self.encoder(x, memory_state=memory_state) h = self.quant_conv(h, memory_state=memory_state) if self.quant_conv is not None else h h = causal_conv_gather_outputs(h) return h def _decode(self, z: torch.Tensor, memory_state: MemoryState) -> torch.Tensor: z = causal_conv_slice_inputs(z, self.slicing_latent_min_size, memory_state=memory_state) z = ( self.post_quant_conv(z, memory_state=memory_state) if self.post_quant_conv is not None else z ) x = self.decoder(z, memory_state=memory_state) x = causal_conv_gather_outputs(x) return x def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: sp_size = get_sequence_parallel_world_size() if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2) encoded_slices = [ self._encode( torch.cat((x[:, :, :1], x_slices[0]), dim=2), memory_state=MemoryState.INITIALIZING, ) ] for x_idx in range(1, len(x_slices)): encoded_slices.append( self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE) ) return torch.cat(encoded_slices, dim=2) else: return self._encode(x, memory_state=MemoryState.DISABLED) def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: sp_size = get_sequence_parallel_world_size() if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) decoded_slices = [ self._decode( torch.cat((z[:, :, :1], z_slices[0]), dim=2), memory_state=MemoryState.INITIALIZING, ) ] for z_idx in range(1, len(z_slices)): decoded_slices.append( self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE) ) return torch.cat(decoded_slices, dim=2) else: return self._decode(z, memory_state=MemoryState.DISABLED) def forward(self, x: torch.FloatTensor) -> CausalAutoencoderOutput: with torch.no_grad() if self.freeze_encoder else nullcontext(): z, p = self.encode(x) x = self.decode(z).sample return CausalAutoencoderOutput(x, z, p) def preprocess(self, x: torch.Tensor): # x should in [B, C, T, H, W], [B, C, H, W] assert x.ndim == 4 or x.size(2) % self.temporal_downsample_factor == 1 return x def postprocess(self, x: torch.Tensor): # x should in [B, C, T, H, W], [B, C, H, W] return x def set_causal_slicing( self, *, split_size: Optional[int], memory_device: _memory_device_t, ): assert ( split_size is None or memory_device is not None ), "if split_size is set, memory_device must not be None." if split_size is not None: self.enable_slicing() self.slicing_sample_min_size = split_size self.slicing_latent_min_size = split_size // self.temporal_downsample_factor else: self.disable_slicing() for module in self.modules(): if isinstance(module, InflatedCausalConv3d): module.set_memory_device(memory_device) def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]): set_norm_limit(norm_max_mem) for m in self.modules(): if isinstance(m, InflatedCausalConv3d): m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) class VideoAutoencoderKLWrapper(VideoAutoencoderKL): def __init__( self, *args, spatial_downsample_factor: int, temporal_downsample_factor: int, **kwargs ): self.spatial_downsample_factor = spatial_downsample_factor self.temporal_downsample_factor = temporal_downsample_factor super().__init__(*args, **kwargs) def forward(self, x) -> CausalAutoencoderOutput: z, _, p = self.encode(x) x, _ = self.decode(z) return CausalAutoencoderOutput(x, z, None, p) def encode(self, x) -> CausalEncoderOutput: if x.ndim == 4: x = x.unsqueeze(2) p = super().encode(x).latent_dist z = p.sample().squeeze(2) return CausalEncoderOutput(z, None, p) def decode(self, z) -> CausalDecoderOutput: if z.ndim == 4: z = z.unsqueeze(2) x = super().decode(z).sample.squeeze(2) return CausalDecoderOutput(x, None) def preprocess(self, x): # x should in [B, C, T, H, W], [B, C, H, W] assert x.ndim == 4 or x.size(2) % 4 == 1 return x def postprocess(self, x): # x should in [B, C, T, H, W], [B, C, H, W] return x def set_causal_slicing( self, *, split_size: Optional[int], memory_device: Optional[Literal["cpu", "same"]], ): assert ( split_size is None or memory_device is not None ), "if split_size is set, memory_device must not be None." if split_size is not None: self.enable_slicing() else: self.disable_slicing() self.slicing_sample_min_size = split_size if split_size is not None: self.slicing_latent_min_size = split_size // self.temporal_downsample_factor for module in self.modules(): if isinstance(module, InflatedCausalConv3d): module.set_memory_device(memory_device)