"""
Reference code
[FLUX] https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/autoencoder.py
[DCAE] https://github.com/mit-han-lab/efficientvit/blob/master/efficientvit/models/efficientvit/dc_ae.py
"""
import math
import random
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint  # explicit import; `import torch` alone does not guarantee this submodule is loaded
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from einops import rearrange
from torch import Tensor, nn
from .hunyuanimage_vae import BaseOutput, DiagonalGaussianDistribution
@dataclass
class DecoderOutput(BaseOutput):
sample: torch.FloatTensor
posterior: Optional[DiagonalGaussianDistribution] = None
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
def forward_with_checkpointing(module, *inputs, use_checkpointing=False):
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if use_checkpointing:
return torch.utils.checkpoint.checkpoint(create_custom_forward(module), *inputs, use_reentrant=False)
else:
return module(*inputs)
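# Usage sketch (illustrative only): wrap any submodule call so its activations
# are recomputed during backward instead of stored, e.g.
#   h = forward_with_checkpointing(block, h, use_checkpointing=self.training and self.gradient_checkpointing)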
# Optimized implementation of CogVideoXSafeConv3d
# https://github.com/huggingface/diffusers/blob/c9ff360966327ace3faad3807dc871a4e5447501/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py#L38
class PatchCausalConv3d(nn.Conv3d):
def find_split_indices(self, seq_len, part_num):
ideal_interval = seq_len / part_num
possible_indices = list(range(0, seq_len, self.stride[0]))
selected_indices = []
for i in range(1, part_num):
closest = min(possible_indices, key=lambda x: abs(x - round(i * ideal_interval)))
if closest not in selected_indices:
selected_indices.append(closest)
merged_indices = []
prev_idx = 0
for idx in selected_indices:
if idx - prev_idx >= self.kernel_size[0]:
merged_indices.append(idx)
prev_idx = idx
return merged_indices
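    # Worked example (illustrative, stride 1, kernel 3): find_split_indices(10, 3)
    # targets indices near 10/3 and 2*10/3, picks the stride-aligned values [3, 7],
    # and keeps both since they are at least kernel_size apart; torch.tensor_split
    # at [3, 7] then yields temporal chunks of 3, 4 and 3 frames.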
def forward(self, input):
T = input.shape[2] # input: NCTHW
        memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3  # approx. GiB, assuming 2 bytes/element (fp16/bf16)
if T > self.kernel_size[0] and memory_count > 2:
kernel_size = self.kernel_size[0]
part_num = int(memory_count / 2) + 1
split_indices = self.find_split_indices(T, part_num)
input_chunks = torch.tensor_split(input, split_indices, dim=2) if len(split_indices) > 0 else [input]
if kernel_size > 1:
input_chunks = [input_chunks[0]] + [
torch.cat(
(
input_chunks[i - 1][:, :, -kernel_size + 1 :],
input_chunks[i],
),
dim=2,
)
for i in range(1, len(input_chunks))
]
output_chunks = []
for input_chunk in input_chunks:
output_chunks.append(super().forward(input_chunk))
output = torch.cat(output_chunks, dim=2)
return output
else:
return super().forward(input)
class CausalConv3d(nn.Module):
def __init__(
self,
chan_in,
chan_out,
        kernel_size: int,  # scalar only: the padding arithmetic below assumes a cubic kernel
stride: Union[int, Tuple[int, int, int]] = 1,
dilation: Union[int, Tuple[int, int, int]] = 1,
pad_mode="replicate",
disable_causal=False,
enable_patch_conv=False,
**kwargs,
):
super().__init__()
self.pad_mode = pad_mode
if disable_causal:
padding = (
kernel_size // 2,
kernel_size // 2,
kernel_size // 2,
kernel_size // 2,
kernel_size // 2,
kernel_size // 2,
)
else:
padding = (
kernel_size // 2,
kernel_size // 2,
kernel_size // 2,
kernel_size // 2,
kernel_size - 1,
0,
) # W, H, T
self.time_causal_padding = padding
if enable_patch_conv:
self.conv = PatchCausalConv3d(
chan_in,
chan_out,
kernel_size,
stride=stride,
dilation=dilation,
**kwargs,
)
else:
self.conv = nn.Conv3d(
chan_in,
chan_out,
kernel_size,
stride=stride,
dilation=dilation,
**kwargs,
)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
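# Causality sketch (illustrative): with kernel_size=3 the temporal padding is
# (2, 0), so a stride-1 CausalConv3d preserves T and output frame t depends
# only on input frames <= t:
#   conv = CausalConv3d(8, 8, kernel_size=3)
#   y = conv(torch.randn(1, 8, 5, 16, 16))  # y.shape == (1, 8, 5, 16, 16)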
class RMS_norm(nn.Module):
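    """RMS normalization with a learned per-channel gain: L2-normalizes the
    channel dim and rescales by sqrt(dim) * gamma. With channel_first=True the
    parameters broadcast over (C, 1, 1) for images and (C, 1, 1, 1) for videos
    (images=False)."""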
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class Conv3d(nn.Conv3d):
"""Perform Conv3d on patches with numerical differences from nn.Conv3d within 1e-5. Only symmetric padding is supported."""
def forward(self, input):
B, C, T, H, W = input.shape
        memory_count = (C * T * H * W) * 2 / 1024**3  # approx. GiB per sample, assuming 2 bytes/element (fp16/bf16)
if memory_count > 2:
n_split = math.ceil(memory_count / 2)
assert n_split >= 2
chunks = torch.chunk(input, chunks=n_split, dim=-3)
padded_chunks = []
for i in range(len(chunks)):
if self.padding[0] > 0:
padded_chunk = F.pad(
chunks[i],
(0, 0, 0, 0, self.padding[0], self.padding[0]),
mode="constant" if self.padding_mode == "zeros" else self.padding_mode,
value=0,
)
if i > 0:
padded_chunk[:, :, : self.padding[0]] = chunks[i - 1][:, :, -self.padding[0] :]
if i < len(chunks) - 1:
padded_chunk[:, :, -self.padding[0] :] = chunks[i + 1][:, :, : self.padding[0]]
else:
padded_chunk = chunks[i]
padded_chunks.append(padded_chunk)
padding_bak = self.padding
self.padding = (0, self.padding[1], self.padding[2])
outputs = []
for i in range(len(padded_chunks)):
outputs.append(super().forward(padded_chunks[i]))
self.padding = padding_bak
return torch.cat(outputs, dim=-3)
else:
return super().forward(input)
def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
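    """Builds a block-causal mask over flattened (frame, h*w) tokens: a query in
    frame i attends to every token in frames 0..i. Illustration for n_frame=2,
    n_hw=2 (0 = attend, -inf = masked):

        [[0, 0, -inf, -inf],
         [0, 0, -inf, -inf],
         [0, 0,    0,    0],
         [0, 0,    0,    0]]
    """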
seq_len = n_frame * n_hw
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
for i in range(seq_len):
i_frame = i // n_hw
mask[i, : (i_frame + 1) * n_hw] = 0
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
# self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.norm = RMS_norm(in_channels, images=False)
self.q = Conv3d(in_channels, in_channels, kernel_size=1)
self.k = Conv3d(in_channels, in_channels, kernel_size=1)
self.v = Conv3d(in_channels, in_channels, kernel_size=1)
self.proj_out = Conv3d(in_channels, in_channels, kernel_size=1)
def attention(self, h_: Tensor) -> Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, f, h, w = q.shape
q = rearrange(q, "b c f h w -> b 1 (f h w) c").contiguous()
k = rearrange(k, "b c f h w -> b 1 (f h w) c").contiguous()
v = rearrange(v, "b c f h w -> b 1 (f h w) c").contiguous()
attention_mask = prepare_causal_attention_mask(f, h * w, h_.dtype, h_.device, batch_size=b)
h_ = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask.unsqueeze(1))
return rearrange(h_, "b 1 (f h w) c -> b c f h w", f=f, h=h, w=w, c=c, b=b)
def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))
class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
# self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
# self.conv1 = Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm1 = RMS_norm(in_channels, images=False)
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3)
# self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
# self.conv2 = Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = RMS_norm(out_channels, images=False)
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3)
if self.in_channels != self.out_channels:
self.nin_shortcut = Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h = x
h = self.norm1(h)
h = swish(h)
h = self.conv1(h)
h = self.norm2(h)
h = swish(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int, add_temporal_downsample: bool = True):
super().__init__()
self.add_temporal_downsample = add_temporal_downsample
stride = (2, 2, 2) if add_temporal_downsample else (1, 2, 2) # THW
# no asymmetric padding in torch conv, must do it ourselves
        # the strided conv performs the downsampling; padding is applied
        # manually in forward(), hence padding=0 here
        self.conv = Conv3d(in_channels, in_channels, kernel_size=3, stride=stride, padding=0)
def forward(self, x: Tensor):
spatial_pad = (0, 1, 0, 1, 0, 0) # WHT
x = nn.functional.pad(x, spatial_pad, mode="constant", value=0)
temporal_pad = (0, 0, 0, 0, 0, 1) if self.add_temporal_downsample else (0, 0, 0, 0, 1, 1)
x = nn.functional.pad(x, temporal_pad, mode="replicate")
x = self.conv(x)
return x
class DownsampleDCAE(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
assert out_channels % factor == 0
# self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
self.conv = CausalConv3d(in_channels, out_channels // factor, kernel_size=3)
self.add_temporal_downsample = add_temporal_downsample
self.group_size = factor * in_channels // out_channels
def forward(self, x: Tensor):
r1 = 2 if self.add_temporal_downsample else 1
h = self.conv(x)
if self.add_temporal_downsample:
h_first = h[:, :, :1, :, :]
h_first = rearrange(h_first, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
h_first = torch.cat([h_first, h_first], dim=1)
h_next = h[:, :, 1:, :, :]
h_next = rearrange(
h_next,
"b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w",
r1=r1,
r2=2,
r3=2,
)
h = torch.cat([h_first, h_next], dim=2)
# shortcut computation
x_first = x[:, :, :1, :, :]
x_first = rearrange(x_first, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
B, C, T, H, W = x_first.shape
x_first = x_first.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
x_next = x[:, :, 1:, :, :]
x_next = rearrange(
x_next,
"b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w",
r1=r1,
r2=2,
r3=2,
)
B, C, T, H, W = x_next.shape
x_next = x_next.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
shortcut = torch.cat([x_first, x_next], dim=2)
else:
h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
B, C, T, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
return h + shortcut
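# Shape sketch for DownsampleDCAE with add_temporal_downsample=True: the first
# frame is downsampled only spatially (and channel-duplicated to match), so
# (B, C_in, 1 + 2f, 2h, 2w) -> (B, C_out, 1 + f, h, w); the channel-averaged
# pixel-unshuffle of the input serves as a residual shortcut.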
class Upsample(nn.Module):
def __init__(self, in_channels: int, add_temporal_upsample: bool = True):
super().__init__()
self.add_temporal_upsample = add_temporal_upsample
self.scale_factor = (2, 2, 2) if add_temporal_upsample else (1, 2, 2) # THW
# self.conv = Conv3d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.conv = CausalConv3d(in_channels, in_channels, kernel_size=3)
def forward(self, x: Tensor):
x = nn.functional.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
x = self.conv(x)
return x
class UpsampleDCAE(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
# self.conv = Conv3d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
self.conv = CausalConv3d(in_channels, out_channels * factor, kernel_size=3)
self.add_temporal_upsample = add_temporal_upsample
self.repeats = factor * out_channels // in_channels
def forward(self, x: Tensor):
r1 = 2 if self.add_temporal_upsample else 1
h = self.conv(x)
if self.add_temporal_upsample:
h_first = h[:, :, :1, :, :]
h_first = rearrange(h_first, "b (r2 r3 c) f h w -> b c f (h r2) (w r3)", r2=2, r3=2)
h_first = h_first[:, : h_first.shape[1] // 2]
h_next = h[:, :, 1:, :, :]
h_next = rearrange(
h_next,
"b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)",
r1=r1,
r2=2,
r3=2,
)
h = torch.cat([h_first, h_next], dim=2)
# shortcut computation
x_first = x[:, :, :1, :, :]
x_first = rearrange(x_first, "b (r2 r3 c) f h w -> b c f (h r2) (w r3)", r2=2, r3=2)
x_first = x_first.repeat_interleave(repeats=self.repeats // 2, dim=1)
x_next = x[:, :, 1:, :, :]
x_next = rearrange(
x_next,
"b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)",
r1=r1,
r2=2,
r3=2,
)
x_next = x_next.repeat_interleave(repeats=self.repeats, dim=1)
shortcut = torch.cat([x_first, x_next], dim=2)
else:
h = rearrange(h, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2)
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
shortcut = rearrange(
shortcut,
"b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)",
r1=r1,
r2=2,
r3=2,
)
return h + shortcut
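# Shape sketch for UpsampleDCAE with add_temporal_upsample=True (inverse of
# DownsampleDCAE): (B, C_in, 1 + f, h, w) -> (B, C_out, 1 + 2f, 2h, 2w), again
# treating the first frame as spatial-only, with a channel-repeated
# pixel-shuffle of the input as the residual shortcut.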
class Encoder(nn.Module):
def __init__(
self,
in_channels: int,
z_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
ffactor_spatial: int,
ffactor_temporal: int,
downsample_match_channel: bool = True,
):
super().__init__()
assert block_out_channels[-1] % (2 * z_channels) == 0
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
# downsampling
# self.conv_in = Conv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
self.down = nn.ModuleList()
block_in = block_out_channels[0]
for i_level, ch in enumerate(block_out_channels):
block = nn.ModuleList()
block_out = ch
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial))
add_temporal_downsample = add_spatial_downsample and bool(
i_level >= np.log2(ffactor_spatial // ffactor_temporal)
)
if add_spatial_downsample or add_temporal_downsample:
assert i_level < len(block_out_channels) - 1
block_out = block_out_channels[i_level + 1] if downsample_match_channel else block_in
down.downsample = DownsampleDCAE(block_in, block_out, add_temporal_downsample)
block_in = block_out
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
# self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
# self.conv_out = Conv3d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
self.norm_out = RMS_norm(block_in, images=False)
self.conv_out = CausalConv3d(block_in, 2 * z_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, x: Tensor) -> Tensor:
use_checkpointing = bool(self.training and self.gradient_checkpointing)
# downsampling
h = self.conv_in(x)
for i_level in range(len(self.block_out_channels)):
for i_block in range(self.num_res_blocks):
h = forward_with_checkpointing(
self.down[i_level].block[i_block],
h,
use_checkpointing=use_checkpointing,
)
if hasattr(self.down[i_level], "downsample"):
h = forward_with_checkpointing(
self.down[i_level].downsample,
h,
use_checkpointing=use_checkpointing,
)
# middle
h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)
# end
group_size = self.block_out_channels[-1] // (2 * self.z_channels)
shortcut = rearrange(h, "b (c r) f h w -> b c r f h w", r=group_size).mean(dim=2)
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
h += shortcut
return h
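# Encoder shape sketch (illustrative): with ffactor_spatial=8,
# ffactor_temporal=4 and z_channels=z, an input (B, C, 1 + 4k, 8m, 8n) encodes
# to moments (B, 2 * z, 1 + k, m, n), mean and logvar stacked along channels,
# with a channel-group mean of the features acting as a shortcut around conv_out.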
class Decoder(nn.Module):
def __init__(
self,
z_channels: int,
out_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
ffactor_spatial: int,
ffactor_temporal: int,
upsample_match_channel: bool = True,
):
super().__init__()
assert block_out_channels[0] % z_channels == 0
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
# z to block_in
block_in = block_out_channels[0]
# self.conv_in = Conv3d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = nn.ModuleList()
for i_level, ch in enumerate(block_out_channels):
block = nn.ModuleList()
block_out = ch
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
add_spatial_upsample = bool(i_level < np.log2(ffactor_spatial))
add_temporal_upsample = bool(i_level < np.log2(ffactor_temporal))
if add_spatial_upsample or add_temporal_upsample:
assert i_level < len(block_out_channels) - 1
block_out = block_out_channels[i_level + 1] if upsample_match_channel else block_in
up.upsample = UpsampleDCAE(block_in, block_out, add_temporal_upsample)
block_in = block_out
self.up.append(up)
# end
# self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
# self.conv_out = Conv3d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
self.norm_out = RMS_norm(block_in, images=False)
self.conv_out = CausalConv3d(block_in, out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, z: Tensor) -> Tensor:
use_checkpointing = bool(self.training and self.gradient_checkpointing)
# z to block_in
        repeats = self.block_out_channels[0] // self.z_channels
h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
# middle
h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)
# upsampling
for i_level in range(len(self.block_out_channels)):
for i_block in range(self.num_res_blocks + 1):
h = forward_with_checkpointing(
self.up[i_level].block[i_block],
h,
use_checkpointing=use_checkpointing,
)
if hasattr(self.up[i_level], "upsample"):
h = forward_with_checkpointing(self.up[i_level].upsample, h, use_checkpointing=use_checkpointing)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
class AutoencoderKLConv3D(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int,
out_channels: int,
latent_channels: int,
block_out_channels: Tuple[int, ...],
layers_per_block: int,
ffactor_spatial: int,
ffactor_temporal: int,
sample_size: int,
sample_tsize: int,
        scaling_factor: Optional[float] = None,
shift_factor: Optional[float] = None,
downsample_match_channel: bool = True,
upsample_match_channel: bool = True,
):
super().__init__()
self.ffactor_spatial = ffactor_spatial
self.ffactor_temporal = ffactor_temporal
self.scaling_factor = scaling_factor
self.shift_factor = shift_factor
self.encoder = Encoder(
in_channels=in_channels,
z_channels=latent_channels,
block_out_channels=block_out_channels,
num_res_blocks=layers_per_block,
ffactor_spatial=ffactor_spatial,
ffactor_temporal=ffactor_temporal,
downsample_match_channel=downsample_match_channel,
)
self.decoder = Decoder(
z_channels=latent_channels,
out_channels=out_channels,
block_out_channels=list(reversed(block_out_channels)),
num_res_blocks=layers_per_block,
ffactor_spatial=ffactor_spatial,
ffactor_temporal=ffactor_temporal,
upsample_match_channel=upsample_match_channel,
)
self.use_slicing = False
self.use_spatial_tiling = False
self.use_temporal_tiling = False
self.use_tiling_during_training = False
# only relevant if vae tiling is enabled
self.tile_sample_min_size = sample_size
self.tile_latent_min_size = sample_size // ffactor_spatial
self.tile_sample_min_tsize = sample_tsize
self.tile_latent_min_tsize = sample_tsize // ffactor_temporal
self.tile_overlap_factor = 0.25
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, Decoder)):
module.gradient_checkpointing = value
def enable_tiling_during_training(self, use_tiling: bool = True):
self.use_tiling_during_training = use_tiling
def disable_tiling_during_training(self):
self.enable_tiling_during_training(False)
def enable_temporal_tiling(self, use_tiling: bool = True):
self.use_temporal_tiling = use_tiling
def disable_temporal_tiling(self):
self.enable_temporal_tiling(False)
def enable_spatial_tiling(self, use_tiling: bool = True):
self.use_spatial_tiling = use_tiling
def disable_spatial_tiling(self):
self.enable_spatial_tiling(False)
def enable_tiling(self, use_tiling: bool = True):
self.enable_spatial_tiling(use_tiling)
def disable_tiling(self):
self.disable_spatial_tiling()
def enable_slicing(self):
self.use_slicing = True
def disable_slicing(self):
self.use_slicing = False
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
x / blend_extent
)
return b
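    # Blend sketch: the blend_* helpers linearly cross-fade the overlapping
    # border of adjacent tiles; with blend_extent=4 the result at offset x is
    # a[..., -4 + x] * (1 - x / 4) + b[..., x] * (x / 4), ramping from a to b.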
def spatial_tiled_encode(self, x: torch.Tensor):
B, C, T, H, W = x.shape
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) # 8 * 0.25 = 2
row_limit = self.tile_latent_min_size - blend_extent # 8 - 2 = 6
rows = []
for i in range(0, H, overlap_size):
row = []
for j in range(0, W, overlap_size):
tile = x[
:,
:,
:,
i : i + self.tile_sample_min_size,
j : j + self.tile_sample_min_size,
]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
return moments
def temporal_tiled_encode(self, x: torch.Tensor):
B, C, T, H, W = x.shape
overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor)) # 64 * (1 - 0.25) = 48
blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor) # 8 * 0.25 = 2
t_limit = self.tile_latent_min_tsize - blend_extent # 8 - 2 = 6
row = []
for i in range(0, T, overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
if self.use_spatial_tiling and (
tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size
):
tile = self.spatial_tiled_encode(tile)
else:
tile = self.encoder(tile)
if i > 0:
tile = tile[:, :, 1:, :, :]
row.append(tile)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_extent)
result_row.append(tile[:, :, :t_limit, :, :])
else:
result_row.append(tile[:, :, : t_limit + 1, :, :])
moments = torch.cat(result_row, dim=-3)
return moments
def spatial_tiled_decode(self, z: torch.Tensor):
B, C, T, H, W = z.shape
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) # 256 * 0.25 = 64
row_limit = self.tile_sample_min_size - blend_extent # 256 - 64 = 192
rows = []
for i in range(0, H, overlap_size):
row = []
for j in range(0, W, overlap_size):
tile = z[
:,
:,
:,
i : i + self.tile_latent_min_size,
j : j + self.tile_latent_min_size,
]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
return dec
def temporal_tiled_decode(self, z: torch.Tensor):
B, C, T, H, W = z.shape
overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor) # 64 * 0.25 = 16
t_limit = self.tile_sample_min_tsize - blend_extent # 64 - 16 = 48
assert 0 < overlap_size < self.tile_latent_min_tsize
row = []
for i in range(0, T, overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
if self.use_spatial_tiling and (
tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
):
decoded = self.spatial_tiled_decode(tile)
else:
decoded = self.decoder(tile)
if i > 0:
decoded = decoded[:, :, 1:, :, :]
row.append(decoded)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_extent)
result_row.append(tile[:, :, :t_limit, :, :])
else:
result_row.append(tile[:, :, : t_limit + 1, :, :])
dec = torch.cat(result_row, dim=-3)
return dec
def encode(self, x: Tensor, return_dict: bool = True):
def _encode(x):
if self.use_temporal_tiling and x.shape[-3] > self.tile_sample_min_tsize:
return self.temporal_tiled_encode(x)
if self.use_spatial_tiling and (
x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
):
return self.spatial_tiled_encode(x)
return self.encoder(x)
assert len(x.shape) == 5 # (B, C, T, H, W)
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [_encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = _encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def decode(self, z: Tensor, return_dict: bool = True, generator=None):
def _decode(z):
if self.use_temporal_tiling and z.shape[-3] > self.tile_latent_min_tsize:
return self.temporal_tiled_decode(z)
if self.use_spatial_tiling and (
z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size
):
return self.spatial_tiled_decode(z)
return self.decoder(z)
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [_decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = _decode(z)
# if z.shape[-3] == 1:
# decoded = decoded[:, :, -1:]
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_posterior: bool = True,
return_dict: bool = True,
):
        posterior = self.encode(sample).latent_dist
        z = posterior.sample() if sample_posterior else posterior.mode()
        dec = self.decode(z).sample
        if not return_dict:
            return (dec, posterior) if return_posterior else (dec,)
        return DecoderOutput(sample=dec, posterior=posterior if return_posterior else None)
def random_reset_tiling(self, x: torch.Tensor):
if x.shape[-3] == 1:
self.disable_spatial_tiling()
self.disable_temporal_tiling()
return
min_sample_size = int(1 / self.tile_overlap_factor) * self.ffactor_spatial
min_sample_tsize = int(1 / self.tile_overlap_factor) * self.ffactor_temporal
sample_size = random.choice([None, 1 * min_sample_size, 2 * min_sample_size, 3 * min_sample_size])
if sample_size is None:
self.disable_spatial_tiling()
else:
self.tile_sample_min_size = sample_size
self.tile_latent_min_size = sample_size // self.ffactor_spatial
self.enable_spatial_tiling()
sample_tsize = random.choice([None, 1 * min_sample_tsize, 2 * min_sample_tsize, 3 * min_sample_tsize])
if sample_tsize is None:
self.disable_temporal_tiling()
else:
self.tile_sample_min_tsize = sample_tsize
self.tile_latent_min_tsize = sample_tsize // self.ffactor_temporal
self.enable_temporal_tiling()
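if __name__ == "__main__":
    # Smoke test with a small, hypothetical config (illustrative only; real
    # checkpoints define their own channel counts and factors). The relative
    # import at the top means this file must be run as a module inside its package.
    vae = AutoencoderKLConv3D(
        in_channels=3,
        out_channels=3,
        latent_channels=4,
        block_out_channels=(8, 16, 32, 32),
        layers_per_block=1,
        ffactor_spatial=8,
        ffactor_temporal=4,
        sample_size=64,
        sample_tsize=16,
    )
    x = torch.randn(1, 3, 5, 32, 32)  # (B, C, T, H, W) with T = 1 + 4k
    with torch.no_grad():
        out = vae(x, sample_posterior=False)
    assert out.sample.shape == x.shape  # 8x spatial / 4x temporal round trip
    print("sample:", tuple(out.sample.shape), "latent:", tuple(out.posterior.mode().shape))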