from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from einops import rearrange
from torch import Tensor, nn
from torch.nn import Conv2d


class DiagonalGaussianDistribution:
    """Diagonal Gaussian posterior parameterized by mean and logvar stacked along the channel dim."""

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        if parameters.ndim == 3:
            dim = 2  # (B, L, C)
        elif parameters.ndim == 5 or parameters.ndim == 4:
            dim = 1  # (B, C, T, H, W) / (B, C, H, W)
        else:
            raise NotImplementedError
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            zero_tensor = torch.zeros_like(self.mean)
            self.var = zero_tensor
            self.std = zero_tensor

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        return self.mean + self.std * sample

    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
        if self.deterministic:
            return torch.tensor([0.0], device=self.parameters.device, dtype=self.parameters.dtype)
        reduce_dim = list(range(1, self.mean.ndim))
        if other is None:
            return 0.5 * torch.sum(
                self.mean.pow(2) + self.var - 1.0 - self.logvar,
                dim=reduce_dim,
            )
        else:
            return 0.5 * torch.sum(
                (self.mean - other.mean).pow(2) / other.var
                + self.var / other.var
                - 1.0
                - self.logvar
                + other.logvar,
                dim=reduce_dim,
            )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
        if self.deterministic:
            return torch.tensor([0.0], device=self.parameters.device, dtype=self.parameters.dtype)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + (sample - self.mean).pow(2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        return self.mean
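

# Usage sketch (illustrative only; shapes here are assumptions, not fixed by
# the model): the posterior wraps a tensor whose channel dim stacks mean and
# logvar, so a 2*C-channel input yields C-channel statistics.
#
#   moments = torch.randn(1, 32, 16, 16)            # (B, 2*C, H, W)
#   posterior = DiagonalGaussianDistribution(moments)
#   z = posterior.sample()                          # (1, 16, 16, 16)
#   kl = posterior.kl()                             # (1,), summed over non-batch dims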


@dataclass
class DecoderOutput(BaseOutput):
    """Output of the decoder with sample and optional posterior distribution."""

    sample: torch.FloatTensor
    posterior: Optional[DiagonalGaussianDistribution] = None


def swish(x: Tensor) -> Tensor:
    """Swish (SiLU) activation function: x * sigmoid(x)."""
    return x * torch.sigmoid(x)


def forward_with_checkpointing(module, *inputs, use_checkpointing=False):
    """
    Forward pass with optional gradient checkpointing for memory efficiency.

    Parameters
    ----------
    module : nn.Module
        The module to run.
    *inputs : Tensor
        Inputs to the module.
    use_checkpointing : bool
        Whether to use gradient checkpointing.
    """

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)

        return custom_forward

    if use_checkpointing:
        return torch.utils.checkpoint.checkpoint(create_custom_forward(module), *inputs, use_reentrant=False)
    return module(*inputs)
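

# Example (illustrative): with checkpointing enabled, activations inside the
# module are recomputed during backward instead of being stored, trading
# compute for memory. Outputs are equivalent either way:
#
#   x = torch.randn(1, 64, 32, 32, requires_grad=True)
#   y1 = forward_with_checkpointing(nn.SiLU(), x, use_checkpointing=False)
#   y2 = forward_with_checkpointing(nn.SiLU(), x, use_checkpointing=True)
#   torch.testing.assert_close(y1, y2)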


class AttnBlock(nn.Module):
    """Single-head self-attention block over the flattened spatial grid of a (B, C, H, W) tensor."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, x: Tensor) -> Tensor:
        x = self.norm(x)
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        b, c, h, w = q.shape
        # Flatten the spatial grid into a sequence for scaled dot-product attention.
        q = rearrange(q, "b c h w -> b (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b (h w) c").contiguous()
        x = nn.functional.scaled_dot_product_attention(q, k, v)
        return rearrange(x, "b (h w) c -> b c h w", h=h, w=w)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))
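

# Shape walkthrough (illustrative): attention runs over the flattened spatial
# grid, so the sequence length is H*W and cost grows quadratically with it.
#
#   attn = AttnBlock(in_channels=128)
#   x = torch.randn(2, 128, 32, 32)   # sequence length 32*32 = 1024
#   y = attn(x)                       # (2, 128, 32, 32), residual added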


class ResnetBlock(nn.Module):
    """
    Residual block with two convolutions and optional channel change.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : Optional[int]
        Number of output channels; defaults to ``in_channels`` when None.
    """

    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: Tensor) -> Tensor:
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
        return x + h


class Downsample(nn.Module):
    """
    Downsampling block: 2x spatial reduction via space-to-depth, with a conv
    branch plus a channel-averaged shortcut.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        factor = 4
        assert out_channels % factor == 0
        self.conv = Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
        self.group_size = factor * in_channels // out_channels

    def forward(self, x: Tensor) -> Tensor:
        h = self.conv(x)
        # Space-to-depth: halve H and W, folding each 2x2 neighborhood into channels.
        h = rearrange(h, "b c (h r1) (w r2) -> b (r1 r2 c) h w", r1=2, r2=2)
        shortcut = rearrange(x, "b c (h r1) (w r2) -> b (r1 r2 c) h w", r1=2, r2=2)
        B, C, H, W = shortcut.shape
        shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2)
        return h + shortcut
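

# Shape example (illustrative): the conv branch first maps to out_channels // 4,
# then space-to-depth restores out_channels at half resolution; the shortcut is
# a channel-averaged space-to-depth of the input.
#
#   down = Downsample(in_channels=128, out_channels=256)
#   x = torch.randn(1, 128, 64, 64)
#   y = down(x)   # (1, 256, 32, 32)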


class Upsample(nn.Module):
    """
    Upsampling block: 2x spatial expansion via depth-to-space, with a conv
    branch plus a channel-repeated shortcut.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        factor = 4
        self.conv = Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
        self.repeats = factor * out_channels // in_channels

    def forward(self, x: Tensor) -> Tensor:
        h = self.conv(x)
        # Depth-to-space: trade a factor of 4 in channels for doubled H and W.
        h = rearrange(h, "b (r1 r2 c) h w -> b c (h r1) (w r2)", r1=2, r2=2)
        shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
        shortcut = rearrange(shortcut, "b (r1 r2 c) h w -> b c (h r1) (w r2)", r1=2, r2=2)
        return h + shortcut
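

# Shape example (illustrative): the conv branch maps to out_channels * 4, then
# depth-to-space trades those channels for a 2x larger H and W; the shortcut
# repeats input channels to match.
#
#   up = Upsample(in_channels=256, out_channels=128)
#   z = torch.randn(1, 256, 32, 32)
#   y = up(z)     # (1, 128, 64, 64)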


class Encoder(nn.Module):
    """
    Encoder network that compresses the input to a latent representation.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    z_channels : int
        Number of latent channels.
    block_out_channels : Tuple[int, ...]
        Output channels for each resolution level.
    num_res_blocks : int
        Number of residual blocks per resolution level.
    ffactor_spatial : int
        Total spatial downsampling factor.
    downsample_match_channel : bool
        Whether to switch to the next level's channel count when downsampling.
    """

    def __init__(
        self,
        in_channels: int,
        z_channels: int,
        block_out_channels: Tuple[int, ...],
        num_res_blocks: int,
        ffactor_spatial: int,
        downsample_match_channel: bool = True,
    ):
        super().__init__()
        assert block_out_channels[-1] % (2 * z_channels) == 0
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks
        self.conv_in = Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
        self.down = nn.ModuleList()
        block_in = block_out_channels[0]
        for i_level, ch in enumerate(block_out_channels):
            block = nn.ModuleList()
            block_out = ch
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial))
            if add_spatial_downsample:
                assert i_level < len(block_out_channels) - 1
                block_out = block_out_channels[i_level + 1] if downsample_match_channel else block_in
                down.downsample = Downsample(block_in, block_out)
                block_in = block_out
            self.down.append(down)
        # Middle blocks with attention
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        # Output layers
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
        self.gradient_checkpointing = False

    def forward(self, x: Tensor) -> Tensor:
        use_checkpointing = bool(self.training and self.gradient_checkpointing)
        # Downsampling
        h = self.conv_in(x)
        for i_level in range(len(self.block_out_channels)):
            for i_block in range(self.num_res_blocks):
                h = forward_with_checkpointing(
                    self.down[i_level].block[i_block], h, use_checkpointing=use_checkpointing
                )
            if hasattr(self.down[i_level], "downsample"):
                h = forward_with_checkpointing(self.down[i_level].downsample, h, use_checkpointing=use_checkpointing)
        # Middle processing
        h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
        h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
        h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)
        # Output with channel-averaged shortcut connection
        group_size = self.block_out_channels[-1] // (2 * self.z_channels)
        shortcut = rearrange(h, "b (c r) h w -> b c r h w", r=group_size).mean(dim=2)
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        h = h + shortcut
        return h
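

# Encoding example (illustrative hyperparameters, not the released config):
# with ffactor_spatial=8, the first three levels downsample and the output
# stacks mean and logvar along channels.
#
#   enc = Encoder(in_channels=3, z_channels=16,
#                 block_out_channels=(128, 256, 512, 512),
#                 num_res_blocks=2, ffactor_spatial=8)
#   x = torch.randn(1, 3, 256, 256)
#   moments = enc(x)   # (1, 32, 32, 32) == (B, 2*z_channels, H/8, W/8)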


class Decoder(nn.Module):
    """
    Decoder network that reconstructs the output from a latent representation.

    Parameters
    ----------
    z_channels : int
        Number of latent channels.
    out_channels : int
        Number of output channels.
    block_out_channels : Tuple[int, ...]
        Output channels for each resolution level.
    num_res_blocks : int
        Number of residual blocks per resolution level.
    ffactor_spatial : int
        Total spatial upsampling factor.
    upsample_match_channel : bool
        Whether to switch to the next level's channel count when upsampling.
    """

    def __init__(
        self,
        z_channels: int,
        out_channels: int,
        block_out_channels: Tuple[int, ...],
        num_res_blocks: int,
        ffactor_spatial: int,
        upsample_match_channel: bool = True,
    ):
        super().__init__()
        assert block_out_channels[0] % z_channels == 0
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks
        block_in = block_out_channels[0]
        self.conv_in = Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
        # Middle blocks with attention
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        # Upsampling blocks
        self.up = nn.ModuleList()
        for i_level, ch in enumerate(block_out_channels):
            block = nn.ModuleList()
            block_out = ch
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            # Determine upsampling strategy
            add_spatial_upsample = bool(i_level < np.log2(ffactor_spatial))
            if add_spatial_upsample:
                assert i_level < len(block_out_channels) - 1
                block_out = block_out_channels[i_level + 1] if upsample_match_channel else block_in
                up.upsample = Upsample(block_in, block_out)
                block_in = block_out
            self.up.append(up)
        # Output layers
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
        self.gradient_checkpointing = False

    def forward(self, z: Tensor) -> Tensor:
        use_checkpointing = bool(self.training and self.gradient_checkpointing)
        # Input projection with a channel-repeated shortcut
        repeats = self.block_out_channels[0] // self.z_channels
        h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
        # Middle processing
        h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
        h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
        h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)
        # Upsampling
        for i_level in range(len(self.block_out_channels)):
            for i_block in range(self.num_res_blocks + 1):
                h = forward_with_checkpointing(self.up[i_level].block[i_block], h, use_checkpointing=use_checkpointing)
            if hasattr(self.up[i_level], "upsample"):
                h = forward_with_checkpointing(self.up[i_level].upsample, h, use_checkpointing=use_checkpointing)
        # Output layers
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
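

# Decoding example (illustrative hyperparameters, mirroring the encoder
# example above; channels are reversed relative to the encoder):
#
#   dec = Decoder(z_channels=16, out_channels=3,
#                 block_out_channels=(512, 512, 256, 128),
#                 num_res_blocks=2, ffactor_spatial=8)
#   z = torch.randn(1, 16, 32, 32)
#   img = dec(z)   # (1, 3, 256, 256)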


class HunyuanVAE2D(ModelMixin, ConfigMixin):
    """
    HunyuanVAE2D: a 2D image VAE model with spatial tiling support.

    This model implements a variational autoencoder designed for image data,
    with support for memory-efficient processing through tiling and slicing.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        latent_channels: int,
        block_out_channels: Tuple[int, ...],
        layers_per_block: int,
        ffactor_spatial: int,
        sample_size: int,
        sample_tsize: int,  # retained in the config; unused by the 2D model
        scaling_factor: Optional[float] = None,
        shift_factor: Optional[float] = None,
        downsample_match_channel: bool = True,
        upsample_match_channel: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.ffactor_spatial = ffactor_spatial
        self.scaling_factor = scaling_factor
        self.shift_factor = shift_factor
        self.encoder = Encoder(
            in_channels=in_channels,
            z_channels=latent_channels,
            block_out_channels=block_out_channels,
            num_res_blocks=layers_per_block,
            ffactor_spatial=ffactor_spatial,
            downsample_match_channel=downsample_match_channel,
        )
        self.decoder = Decoder(
            z_channels=latent_channels,
            out_channels=out_channels,
            block_out_channels=list(reversed(block_out_channels)),
            num_res_blocks=layers_per_block,
            ffactor_spatial=ffactor_spatial,
            upsample_match_channel=upsample_match_channel,
        )
        # Tiling and slicing configuration
        self.use_slicing = False
        self.use_spatial_tiling = False
        self.use_tiling_during_training = False
        # Tiling parameters
        self.tile_sample_min_size = sample_size
        self.tile_latent_min_size = sample_size // ffactor_spatial
        self.tile_overlap_factor = 0.25

    def _set_gradient_checkpointing(self, module, value=False):
        """
        Enable or disable gradient checkpointing for memory efficiency.

        Parameters
        ----------
        module : nn.Module
            The module to set.
        value : bool
            Whether to enable gradient checkpointing.
        """
        if isinstance(module, (Encoder, Decoder)):
            module.gradient_checkpointing = value

    def enable_spatial_tiling(self, use_tiling: bool = True):
        """Enable or disable spatial tiling."""
        self.use_spatial_tiling = use_tiling

    def disable_spatial_tiling(self):
        """Disable spatial tiling."""
        self.use_spatial_tiling = False

    def enable_tiling(self, use_tiling: bool = True):
        """Enable or disable spatial tiling (alias for enable_spatial_tiling)."""
        self.enable_spatial_tiling(use_tiling)

    def disable_tiling(self):
        """Disable spatial tiling (alias for disable_spatial_tiling)."""
        self.disable_spatial_tiling()

    def enable_slicing(self):
        """Enable sliced encoding/decoding along the batch dimension."""
        self.use_slicing = True

    def disable_slicing(self):
        """Disable sliced encoding/decoding."""
        self.use_slicing = False

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """
        Blend two tensors horizontally with a smooth linear transition.

        Parameters
        ----------
        a : torch.Tensor
            Left tensor.
        b : torch.Tensor
            Right tensor, modified in place and returned.
        blend_extent : int
            Number of columns to blend.
        """
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        for x in range(blend_extent):
            b[..., x] = a[..., -blend_extent + x] * (1 - x / blend_extent) + b[..., x] * (x / blend_extent)
        return b

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """
        Blend two tensors vertically with a smooth linear transition.

        Parameters
        ----------
        a : torch.Tensor
            Top tensor.
        b : torch.Tensor
            Bottom tensor, modified in place and returned.
        blend_extent : int
            Number of rows to blend.
        """
        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
        for y in range(blend_extent):
            b[..., y, :] = a[..., -blend_extent + y, :] * (1 - y / blend_extent) + b[..., y, :] * (y / blend_extent)
        return b
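
    # Blend weight illustration: for blend_extent=4 the seam is mixed with
    # linear ramps, so tile boundaries fade smoothly into each other:
    #
    #   x:          0     1     2     3
    #   weight a:   1.00  0.75  0.50  0.25
    #   weight b:   0.00  0.25  0.50  0.75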

    def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode the input tile by tile, blending overlapping seams.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, C, H, W).
        """
        B, C, H, W = x.shape
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent
        # Encode overlapping tiles independently.
        rows = []
        for i in range(0, H, overlap_size):
            row = []
            for j in range(0, W, overlap_size):
                tile = x[..., i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                row.append(tile)
            rows.append(row)
        # Blend each tile with its top and left neighbors, then crop and stitch.
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[..., :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=-1))
        moments = torch.cat(result_rows, dim=-2)
        return moments
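
    # Tiling arithmetic (illustrative: sample_size=256 and ffactor_spatial=8,
    # so tile_latent_min_size=32, with tile_overlap_factor=0.25):
    #   overlap_size = 256 * 0.75 = 192  -> a new tile starts every 192 pixels
    #   blend_extent = 32 * 0.25  = 8    -> 8 latent rows/cols blended per seam
    #   row_limit    = 32 - 8     = 24   -> kept rows/cols per tile (== 192 / 8)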

    def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decode the latent tile by tile, blending overlapping seams.

        Parameters
        ----------
        z : torch.Tensor
            Latent tensor of shape (B, C, H, W).
        """
        B, C, H, W = z.shape
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent
        # Decode overlapping tiles independently.
        rows = []
        for i in range(0, H, overlap_size):
            row = []
            for j in range(0, W, overlap_size):
                tile = z[..., i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        # Blend each tile with its top and left neighbors, then crop and stitch.
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[..., :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=-1))
        dec = torch.cat(result_rows, dim=-2)
        return dec

    def encode(self, x: Tensor, return_dict: bool = True):
        """
        Encode an input tensor to a latent posterior.

        Parameters
        ----------
        x : Tensor
            Input tensor; a 5D input with a singleton frame dimension is
            squeezed to 4D and the dimension is restored on output.
        return_dict : bool
            Whether to return an AutoencoderKLOutput instead of a tuple.
        """
        original_ndim = x.ndim
        if original_ndim == 5:
            x = x.squeeze(2)

        def _encode(x):
            if self.use_spatial_tiling and (
                x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
            ):
                return self.spatial_tiled_encode(x)
            return self.encoder(x)

        # Process with or without slicing
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [_encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = _encode(x)
        if original_ndim == 5:
            h = h.unsqueeze(2)
        posterior = DiagonalGaussianDistribution(h)
        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)
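
    # Usage sketch (illustrative; `vae` and `big_image` are hypothetical
    # names): tiling bounds peak memory for large inputs, and results match
    # the untiled path up to blending at tile seams.
    #
    #   vae.enable_spatial_tiling()
    #   posterior = vae.encode(big_image).latent_dist   # big_image: (B, C, H, W)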

    def decode(self, z: Tensor, return_dict: bool = True, generator=None):
        """
        Decode a latent representation to an output tensor.

        Parameters
        ----------
        z : Tensor
            Latent tensor; a 5D input with a singleton frame dimension is
            squeezed to 4D and the dimension is restored on output.
        return_dict : bool
            Whether to return a DecoderOutput instead of a tuple.
        generator : unused
            Kept for interface compatibility.
        """
        original_ndim = z.ndim
        if original_ndim == 5:
            z = z.squeeze(2)

        def _decode(z):
            if self.use_spatial_tiling and (
                z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size
            ):
                return self.spatial_tiled_decode(z)
            return self.decoder(z)

        # Process with or without slicing
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [_decode(z_slice) for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = _decode(z)
        if original_ndim == 5:
            decoded = decoded.unsqueeze(2)
        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_posterior: bool = True,
        return_dict: bool = True,
    ):
        """
        Full forward pass through the VAE (encode, then decode).

        Parameters
        ----------
        sample : torch.Tensor
            Input tensor.
        sample_posterior : bool
            Whether to sample from the posterior instead of taking its mode.
        return_posterior : bool
            Whether to return the posterior.
        return_dict : bool
            Whether to return a DecoderOutput instead of a tuple.
        """
        posterior = self.encode(sample).latent_dist
        z = posterior.sample() if sample_posterior else posterior.mode()
        dec = self.decode(z).sample
        if return_dict:
            return DecoderOutput(sample=dec, posterior=posterior)
        return (dec, posterior)

    def load_state_dict(self, state_dict, strict=True):
        """
        Load a state dict, squeezing singleton temporal dimensions from 5D
        convolution weights so that such checkpoints load into this 2D model.

        Parameters
        ----------
        state_dict : dict
            State dictionary.
        strict : bool
            Whether to strictly enforce that the keys in state_dict match the
            keys returned by this module's state_dict function.
        """
        converted_state_dict = {}
        for key, value in state_dict.items():
            if "weight" in key and value.ndim == 5 and value.shape[2] == 1:
                # (out, in, 1, kH, kW) -> (out, in, kH, kW)
                converted_state_dict[key] = value.squeeze(2)
            else:
                converted_state_dict[key] = value
        return super().load_state_dict(converted_state_dict, strict=strict)
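

if __name__ == "__main__":
    # Smoke test with assumed (illustrative) hyperparameters; released
    # checkpoints may use different values. Verifies that shapes round-trip.
    vae = HunyuanVAE2D(
        in_channels=3,
        out_channels=3,
        latent_channels=16,
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
        ffactor_spatial=8,
        sample_size=256,
        sample_tsize=1,
    )
    x = torch.randn(1, 3, 256, 256)
    out = vae(x, sample_posterior=True)
    assert out.sample.shape == x.shape
    print(out.sample.shape, out.posterior.kl().shape)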