|
|
|
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from diffusers import AutoencoderKL
from diffusers.models.attention_processor import Attention, AttentionProcessor
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, Decoder
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.unets.unet_2d_blocks import (
    AutoencoderTinyBlock,
    UNetMidBlock2D,
    get_down_block,
    get_up_block,
)
from diffusers.utils import is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
|
|
|
class ZeroConv2d(nn.Module):
    """1x1 convolution initialized to all zeros (ControlNet-style).

    Starting at zero makes the layer an identity-free no-op at the beginning
    of training, so newly added skip paths do not perturb a pretrained model.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        # Zero both weight and bias so the initial output is exactly zero.
        nn.init.zeros_(conv.weight)
        nn.init.zeros_(conv.bias)
        self.conv = conv

    def forward(self, x):
        return self.conv(x)
|
|
|
class CustomAutoencoderKL(AutoencoderKL):
    """`AutoencoderKL` variant with U-Net-style skip connections.

    The stock encoder/decoder are replaced by `CustomEncoder` and
    `CustomDecoder`: the encoder emits one zero-conv skip tensor per down
    block and the decoder adds them back in front of its up blocks.

    The constructor signature and defaults match `AutoencoderKL`.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        force_upcast: float = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            latent_channels=latent_channels,
            norm_num_groups=norm_num_groups,
            sample_size=sample_size,
            scaling_factor=scaling_factor,
            force_upcast=force_upcast,
            use_quant_conv=use_quant_conv,
            use_post_quant_conv=use_post_quant_conv,
            mid_block_add_attention=mid_block_add_attention,
        )

        # Swap in the skip-connection-aware modules built by the parent.
        self.decoder = CustomDecoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            mid_block_add_attention=mid_block_add_attention,
        )
        self.encoder = CustomEncoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            mid_block_add_attention=mid_block_add_attention,
        )

    def encode(self, x: torch.Tensor, return_dict: bool = True):
        """Encode `x` into a latent posterior plus the encoder skips.

        Args:
            x: Input image batch.
            return_dict: Kept for signature compatibility with the parent;
                the return shape below is fixed because `forward` needs both
                values.

        Returns:
            `(posterior, skip_connections)` where `posterior` is a
            `DiagonalGaussianDistribution` over the latents and
            `skip_connections` is the list produced by `CustomEncoder`.
        """
        # BUG FIX: the previous version discarded the encoder output and
        # returned only the skips, while `forward` unpacked two values and
        # called `.sample()/.mode()` on a posterior that was never built.
        h, skip_connections = self.encoder(x)
        if self.quant_conv is not None:
            h = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(h)
        return posterior, skip_connections

    def decode(self, z: torch.Tensor, skip_connections: list, return_dict: bool = True):
        """Decode latents `z`, merging the encoder `skip_connections`.

        Returns:
            `DecoderOutput` when `return_dict` is True, else a 1-tuple
            `(sample,)`.
        """
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)

        dec = self.decoder(z, skip_connections)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ):
        """Full encode → (sample|mode) → decode pass.

        Args:
            sample: Input image batch.
            sample_posterior: Draw from the posterior instead of its mode.
            return_dict: Return a `DecoderOutput` instead of a plain tuple.
            generator: Optional RNG for posterior sampling.
        """
        posterior, skip_connections = self.encode(sample)

        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()

        # BUG FIX: always unwrap the decoder result here; the previous code
        # wrapped an already-wrapped DecoderOutput inside another
        # DecoderOutput when return_dict was True.
        dec = self.decode(z, skip_connections, return_dict=False)[0]

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
|
|
|
|
|
class CustomDecoder(Decoder):
    """`Decoder` variant that consumes encoder skip connections.

    Before each up block, the matching skip tensor (last skip pairs with the
    first up block) is added to the running feature map. Everything else
    follows the stock diffusers `Decoder`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        up_block_types: Tuple[str, ...],
        block_out_channels: Tuple[int, ...],
        layers_per_block: int,
        norm_num_groups: int,
        act_fn: str,
        mid_block_add_attention: bool,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            mid_block_add_attention=mid_block_add_attention,
        )

    def forward(
        self,
        sample: torch.Tensor,
        skip_connections: list,
        latent_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""The forward method of the `Decoder` class.

        Args:
            sample: Latent input.
            skip_connections: Encoder skips; `skip_connections[-(i + 1)]` is
                added before up block `i`. Shapes must already match the
                decoder feature maps (the encoder prepares them).
            latent_embeds: Optional conditioning embeddings, forwarded to the
                mid block, up blocks, and output norm.
        """
        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        use_checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing

        if use_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # Non-reentrant checkpointing is only available on torch >= 1.11.
            # BUG FIX: `is_torch_version` was previously called without being
            # imported, raising NameError on every checkpointed pass.
            ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            sample = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block),
                sample,
                latent_embeds,
                **ckpt_kwargs,
            )
        else:
            sample = self.mid_block(sample, latent_embeds)
        sample = sample.to(upscale_dtype)

        # BUG FIX: the old checkpointing branch ran the up blocks WITHOUT the
        # skip additions, so training-with-checkpointing and inference
        # executed different networks. A single loop now serves both paths.
        for i, up_block in enumerate(self.up_blocks):
            if i < len(skip_connections):
                # Deepest skip first: pair up block i with skip -(i + 1).
                sample = sample + skip_connections[-(i + 1)]

            if use_checkpointing:
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block),
                    sample,
                    latent_embeds,
                    **ckpt_kwargs,
                )
            else:
                sample = up_block(sample, latent_embeds)

        # Post-process (norm accepts latent_embeds when conditioning is used).
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample
|
|
|
class CustomEncoder(Encoder):
    r"""
    Custom Encoder that adds Zero Convolution layers to each block's output
    to generate skip connections, in addition to the standard latent output.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention: bool = True,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            double_z=double_z,
            mid_block_add_attention=mid_block_add_attention,
        )

        # One zero conv per down block. NOTE(review): the first two blocks
        # double their channel count — presumably to match the decoder's
        # up-block widths; confirm against the decoder configuration.
        # (Loop variable renamed: it previously shadowed the `out_channels`
        # constructor parameter.)
        self.zero_convs = nn.ModuleList()
        for i, channels in enumerate(block_out_channels):
            if i < 2:
                self.zero_convs.append(ZeroConv2d(channels, channels * 2))
            else:
                self.zero_convs.append(ZeroConv2d(channels, channels))

    def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, list]:
        r"""
        Forward pass of the CustomEncoder.

        Args:
            sample (`torch.Tensor`): Input tensor.

        Returns:
            `Tuple[torch.Tensor, List[torch.Tensor]]`:
                - The final latent representation (moments, suitable for
                  `DiagonalGaussianDistribution` after `quant_conv`).
                - A list of skip connections, one per down block.
        """
        skip_connections = []

        sample = self.conv_in(sample)

        for i, (down_block, zero_conv) in enumerate(zip(self.down_blocks, self.zero_convs)):
            sample = down_block(sample)
            # All but the deepest skip are upsampled 2x so their spatial size
            # matches the decoder feature map they will be added to.
            if i != len(self.down_blocks) - 1:
                skip = nn.functional.interpolate(
                    zero_conv(sample), scale_factor=2, mode="bilinear", align_corners=False
                )
            else:
                skip = zero_conv(sample)
            skip_connections.append(skip)

        # BUG FIX: finish the standard Encoder pipeline. The previous version
        # returned `sample` straight after the down blocks, so the promised
        # "final latent representation" had the wrong channel count for
        # `quant_conv` / the Gaussian posterior.
        sample = self.mid_block(sample)
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample, skip_connections