SeedVR2-3B / models /video_vae_v3 /modules /causal_inflation_lib.py
IceClear
upload files
42f2c22
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
import math
from contextlib import contextmanager
from typing import List, Optional, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
from diffusers.models.normalization import RMSNorm
from einops import rearrange
from torch import Tensor, nn
from torch.nn import Conv3d
from common.distributed.advanced import (
get_next_sequence_parallel_rank,
get_prev_sequence_parallel_rank,
get_sequence_parallel_group,
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
)
from common.logger import get_logger
from models.video_vae_v3.modules.context_parallel_lib import cache_send_recv, get_cache_size
from models.video_vae_v3.modules.global_config import get_norm_limit
from models.video_vae_v3.modules.types import MemoryState, _inflation_mode_t, _memory_device_t
logger = get_logger(__name__)
@contextmanager
def ignore_padding(model):
orig_padding = model.padding
model.padding = (0, 0, 0)
try:
yield
finally:
model.padding = orig_padding
class InflatedCausalConv3d(Conv3d):
def __init__(
self,
*args,
inflation_mode: _inflation_mode_t,
memory_device: _memory_device_t = "same",
**kwargs,
):
self.inflation_mode = inflation_mode
self.memory = None
super().__init__(*args, **kwargs)
self.temporal_padding = self.padding[0]
self.memory_device = memory_device
self.padding = (0, *self.padding[1:]) # Remove temporal pad to keep causal.
self.memory_limit = float("inf")
def set_memory_limit(self, value: float):
self.memory_limit = value
def set_memory_device(self, memory_device: _memory_device_t):
self.memory_device = memory_device
def memory_limit_conv(
self,
x,
*,
split_dim=3,
padding=(0, 0, 0, 0, 0, 0),
prev_cache=None,
):
# Compatible with no limit.
if math.isinf(self.memory_limit):
if prev_cache is not None:
x = torch.cat([prev_cache, x], dim=split_dim - 1)
return super().forward(x)
# Compute tensor shape after concat & padding.
shape = torch.tensor(x.size())
if prev_cache is not None:
shape[split_dim - 1] += prev_cache.size(split_dim - 1)
shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0)
memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB
logger.debug(
f"x:{(shape, x.dtype)} {memory_occupy:.3f}GiB "
f"prev_cache:{prev_cache.shape if prev_cache is not None else None}"
)
if memory_occupy < self.memory_limit or split_dim == x.ndim:
if prev_cache is not None:
x = torch.cat([prev_cache, x], dim=split_dim - 1)
x = F.pad(x, padding, value=0.0)
with ignore_padding(self):
return super().forward(x)
logger.debug(
f"Exceed memory limit {memory_occupy} > {self.memory_limit}, split dim {split_dim}"
)
# Split input (& prev_cache).
num_splits = math.ceil(memory_occupy / self.memory_limit)
size_per_split = x.size(split_dim) // num_splits
split_sizes = [size_per_split] * (num_splits - 1)
split_sizes += [x.size(split_dim) - sum(split_sizes)]
x = list(x.split(split_sizes, dim=split_dim))
logger.debug(f"Conv inputs: {[inp.size() for inp in x]} {x[0].dtype}")
if prev_cache is not None:
prev_cache = list(prev_cache.split(split_sizes, dim=split_dim))
# Loop Fwd.
cache = None
for idx in range(len(x)):
# Concat prev cache from last dim
if prev_cache is not None:
x[idx] = torch.cat([prev_cache[idx], x[idx]], dim=split_dim - 1)
# Get padding pattern.
lpad_dim = (x[idx].ndim - split_dim - 1) * 2
rpad_dim = lpad_dim + 1
padding = list(padding)
padding[lpad_dim] = self.padding[split_dim - 2] if idx == 0 else 0
padding[rpad_dim] = self.padding[split_dim - 2] if idx == len(x) - 1 else 0
pad_len = padding[lpad_dim] + padding[rpad_dim]
padding = tuple(padding)
# Prepare cache for next slice (this dim).
next_cache = None
cache_len = cache.size(split_dim) if cache is not None else 0
next_catch_size = get_cache_size(
conv_module=self,
input_len=x[idx].size(split_dim) + cache_len,
pad_len=pad_len,
dim=split_dim - 2,
)
if next_catch_size != 0:
assert next_catch_size <= x[idx].size(split_dim)
next_cache = (
x[idx].transpose(0, split_dim)[-next_catch_size:].transpose(0, split_dim)
)
# Recursive.
x[idx] = self.memory_limit_conv(
x[idx],
split_dim=split_dim + 1,
padding=padding,
prev_cache=cache,
)
# Update cache.
cache = next_cache
logger.debug(f"Conv outputs, concat(dim={split_dim}): {[d.size() for d in x]}")
return torch.cat(x, split_dim)
def forward(
self,
input: Union[Tensor, List[Tensor]],
memory_state: MemoryState = MemoryState.UNSET,
) -> Tensor:
assert memory_state != MemoryState.UNSET
if memory_state != MemoryState.ACTIVE:
self.memory = None
if (
math.isinf(self.memory_limit)
and torch.is_tensor(input)
and get_sequence_parallel_group() is None
):
return self.basic_forward(input, memory_state)
return self.slicing_forward(input, memory_state)
def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET):
mem_size = self.stride[0] - self.kernel_size[0]
if (self.memory is not None) and (memory_state == MemoryState.ACTIVE):
input = extend_head(input, memory=self.memory, times=-1)
else:
input = extend_head(input, times=self.temporal_padding * 2)
memory = (
input[:, :, mem_size:].detach()
if (mem_size != 0 and memory_state != MemoryState.DISABLED)
else None
)
if (
memory_state != MemoryState.DISABLED
and not self.training
and (self.memory_device is not None)
):
self.memory = memory
if self.memory_device == "cpu" and self.memory is not None:
self.memory = self.memory.to("cpu")
return super().forward(input)
def slicing_forward(
self,
input: Union[Tensor, List[Tensor]],
memory_state: MemoryState = MemoryState.UNSET,
) -> Tensor:
squeeze_out = False
if torch.is_tensor(input):
input = [input]
squeeze_out = True
cache_size = self.kernel_size[0] - self.stride[0]
cache = cache_send_recv(
input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2
)
# For slice=4 and sp=2, and 17 frames in total
# sp0 sp1
# slice 0: [`0 0` 0 1 2 {3 4}] [{3 4} 5 6 (7 8)] extend=`0 0` cache={3 4} memory=(7 8)
# slice 1: [(7 8) 9 10 {11 12}] [{11 12} 13 14 15 16]
sp_rank = get_sequence_parallel_rank()
sp_size = get_sequence_parallel_world_size()
sp_group = get_sequence_parallel_group()
send_dst = get_next_sequence_parallel_rank()
recv_src = get_prev_sequence_parallel_rank()
if (
memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing
and not self.training
and (self.memory_device is not None)
and sp_rank in [0, sp_size - 1]
and cache_size != 0
):
if cache_size > input[-1].size(2) and cache is not None and len(input) == 1:
input[0] = torch.cat([cache, input[0]], dim=2)
cache = None
assert cache_size <= input[-1].size(2)
if sp_size == 1:
self.memory = input[-1][:, :, -cache_size:].detach().contiguous()
else:
if sp_rank == sp_size - 1:
dist.send(
input[-1][:, :, -cache_size:].detach().contiguous(),
send_dst,
group=sp_group,
)
if sp_rank == 0:
shape = list(input[0].size())
shape[2] = cache_size
self.memory = torch.empty(
*shape, device=input[0].device, dtype=input[0].dtype
).contiguous()
dist.recv(self.memory, recv_src, group=sp_group)
if self.memory_device == "cpu" and self.memory is not None:
self.memory = self.memory.to("cpu")
padding = tuple(x for x in reversed(self.padding) for _ in range(2))
for i in range(len(input)):
# Prepare cache for next input slice.
next_cache = None
cache_size = 0
if i < len(input) - 1:
cache_len = cache.size(2) if cache is not None else 0
cache_size = get_cache_size(self, input[i].size(2) + cache_len, pad_len=0)
if cache_size != 0:
if cache_size > input[i].size(2) and cache is not None:
input[i] = torch.cat([cache, input[i]], dim=2)
cache = None
assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}"
next_cache = input[i][:, :, -cache_size:]
# Conv forward for this input slice.
input[i] = self.memory_limit_conv(
input[i],
padding=padding,
prev_cache=cache,
)
# Update cache.
cache = next_cache
return input[0] if squeeze_out else input
def tflops(self, args, kwargs, output) -> float:
if torch.is_tensor(output):
output_numel = output.numel()
elif isinstance(output, list):
output_numel = sum(o.numel() for o in output)
else:
raise NotImplementedError
return (2 * math.prod(self.kernel_size) * self.in_channels * (output_numel / 1e6)) / 1e6
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
if self.inflation_mode != "none":
state_dict = modify_state_dict(
self,
state_dict,
prefix,
inflate_weight_fn=inflate_weight,
inflate_bias_fn=inflate_bias,
)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
(strict and self.inflation_mode == "none"),
missing_keys,
unexpected_keys,
error_msgs,
)
def init_causal_conv3d(
*args,
inflation_mode: _inflation_mode_t,
**kwargs,
):
"""
Initialize a Causal-3D convolution layer.
Parameters:
inflation_mode: Listed as below. It's compatible with all the 3D-VAE checkpoints we have.
- none: No inflation will be conducted.
The loading logic of state dict will fall back to default.
- tail / replicate: Refer to the definition of `InflatedCausalConv3d`.
"""
return InflatedCausalConv3d(*args, inflation_mode=inflation_mode, **kwargs)
def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
input_dtype = x.dtype
if isinstance(norm_layer, (nn.LayerNorm, RMSNorm)):
if x.ndim == 4:
x = rearrange(x, "b c h w -> b h w c")
x = norm_layer(x)
x = rearrange(x, "b h w c -> b c h w")
return x.to(input_dtype)
if x.ndim == 5:
x = rearrange(x, "b c t h w -> b t h w c")
x = norm_layer(x)
x = rearrange(x, "b t h w c -> b c t h w")
return x.to(input_dtype)
if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
if x.ndim <= 4:
return norm_layer(x).to(input_dtype)
if x.ndim == 5:
t = x.size(2)
x = rearrange(x, "b c t h w -> (b t) c h w")
memory_occupy = x.numel() * x.element_size() / 1024**3
if isinstance(norm_layer, nn.GroupNorm) and memory_occupy > get_norm_limit():
num_chunks = min(4 if x.element_size() == 2 else 2, norm_layer.num_groups)
logger.debug(f"large tensor {x.shape}, norm in {num_chunks} chunks")
assert norm_layer.num_groups % num_chunks == 0
num_groups_per_chunk = norm_layer.num_groups // num_chunks
x = list(x.chunk(num_chunks, dim=1))
weights = norm_layer.weight.chunk(num_chunks, dim=0)
biases = norm_layer.bias.chunk(num_chunks, dim=0)
for i, (w, b) in enumerate(zip(weights, biases)):
x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps)
x[i] = x[i].to(input_dtype)
x = torch.cat(x, dim=1)
else:
x = norm_layer(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
return x.to(input_dtype)
raise NotImplementedError
def remove_head(tensor: Tensor, times: int = 1) -> Tensor:
"""
Remove duplicated first frame features in the up-sampling process.
"""
sp_rank = get_sequence_parallel_rank()
if times == 0 or sp_rank > 0:
return tensor
return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2)
def extend_head(tensor: Tensor, times: int = 2, memory: Optional[Tensor] = None) -> Tensor:
"""
When memory is None:
- Duplicate first frame features in the down-sampling process.
When memory is not None:
- Concatenate memory features with the input features to keep temporal consistency.
"""
if memory is not None:
return torch.cat((memory.to(tensor), tensor), dim=2)
assert times >= 0, "Invalid input for function 'extend_head'!"
if times == 0:
return tensor
else:
tile_repeat = [1] * tensor.ndim
tile_repeat[2] = times
return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2)
def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str):
"""
Inflate a 2D convolution weight matrix to a 3D one.
Parameters:
weight_2d: The weight matrix of 2D conv to be inflated.
weight_3d: The weight matrix of 3D conv to be initialized.
inflation_mode: the mode of inflation
"""
assert inflation_mode in ["tail", "replicate"]
assert weight_3d.shape[:2] == weight_2d.shape[:2]
with torch.no_grad():
if inflation_mode == "replicate":
depth = weight_3d.size(2)
weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)
else:
weight_3d.fill_(0.0)
weight_3d[:, :, -1].copy_(weight_2d)
return weight_3d
def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str):
"""
Inflate a 2D convolution bias tensor to a 3D one
Parameters:
bias_2d: The bias tensor of 2D conv to be inflated.
bias_3d: The bias tensor of 3D conv to be initialized.
inflation_mode: Placeholder to align `inflate_weight`.
"""
assert bias_3d.shape == bias_2d.shape
with torch.no_grad():
bias_3d.copy_(bias_2d)
return bias_3d
def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn):
"""
the main function to inflated 2D parameters to 3D.
"""
weight_name = prefix + "weight"
bias_name = prefix + "bias"
if weight_name in state_dict:
weight_2d = state_dict[weight_name]
if weight_2d.dim() == 4:
# Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w)
weight_3d = inflate_weight_fn(
weight_2d=weight_2d,
weight_3d=layer.weight,
inflation_mode=layer.inflation_mode,
)
state_dict[weight_name] = weight_3d
else:
return state_dict
# It's a 3d state dict, should not do inflation on both bias and weight.
if bias_name in state_dict:
bias_2d = state_dict[bias_name]
if bias_2d.dim() == 1:
# Assuming the 2D biases are 1D tensors (out_channels,)
bias_3d = inflate_bias_fn(
bias_2d=bias_2d,
bias_3d=layer.bias,
inflation_mode=layer.inflation_mode,
)
state_dict[bias_name] = bias_3d
return state_dict