|
from typing import Tuple |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
|
|
def zero_module(module):
    """Zero every parameter of ``module`` in place and return the same module object.

    Each tensor yielded by ``module.parameters()`` is reset with ``nn.init.zeros_``;
    buffers (e.g. running statistics) are not touched. Returning the module lets the
    call wrap a layer constructor inline, e.g. ``zero_module(nn.Conv2d(...))``.
    """
    params = module.parameters()
    for tensor in params:
        nn.init.zeros_(tensor)
    return module
|
|
|
class ControlNetConditioningEmbedding(nn.Module):
    """Small convolutional encoder that maps an image-space condition onto a coarser feature grid.

    Background (https://arxiv.org/abs/2302.05543): Stable Diffusion works on 64 × 64 latents
    derived from 512 × 512 images, so ControlNet uses a tiny network E(·) to bring an
    image-space condition down to that resolution. The paper describes four 4 × 4,
    stride-2 convolutions with ReLU, channels 16/32/64/128 and Gaussian initialisation.

    This implementation differs from that description:

    * ``conv_in`` is a 3 × 3 convolution (padding 1) from ``conditioning_channels`` to
      ``block_out_channels[0]``.
    * For each adjacent pair ``(c_in, c_out)`` in ``block_out_channels``, ``blocks`` holds
      a 3 × 3 convolution ``c_in -> c_in`` (stride 1) followed by a 3 × 3 convolution
      ``c_in -> c_out`` with stride 2 and padding 1, which halves height and width
      (rounding up). With the default channels there are three such stride-2 layers,
      i.e. an 8x spatial reduction (e.g. 512 -> 64).
    * SiLU is applied after ``conv_in`` and after every convolution in ``blocks``.
    * ``conv_out`` is a 3 × 3 convolution to ``conditioning_embedding_channels`` whose
      parameters are zeroed by ``zero_module``, so a freshly constructed module outputs
      all zeros.

    Args:
        conditioning_embedding_channels: Number of output channels produced by ``conv_out``.
        conditioning_channels: Number of channels in the conditioning input.
        block_out_channels: Channel widths of the encoder stages. It must contain at least
            one entry, because ``block_out_channels[0]`` is read unconditionally.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
    ):
        super().__init__()

        # Construction order (conv_in, then blocks in sequence, then conv_out) is kept
        # stable so module registration order and state_dict key layout do not change.
        stem_channels = block_out_channels[0]
        self.conv_in = nn.Conv2d(conditioning_channels, stem_channels, kernel_size=3, padding=1)

        stages = []
        for c_in, c_out in zip(block_out_channels[:-1], block_out_channels[1:]):
            # Same-width refinement conv, then a stride-2 conv that changes the width.
            stages.append(nn.Conv2d(c_in, c_in, kernel_size=3, padding=1))
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2))
        self.blocks = nn.ModuleList(stages)

        final_channels = block_out_channels[-1]
        # Zero-initialised projection: the embedding starts out contributing nothing.
        self.conv_out = zero_module(
            nn.Conv2d(final_channels, conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        """Encode ``conditioning`` into the conditioning embedding.

        Args:
            conditioning: Input fed to ``conv_in``. It is presumably an (N, C, H, W) tensor
                with C == ``conditioning_channels``; confirm against callers.

        Returns:
            The output of ``conv_out``. No activation is applied after this final layer.
        """
        hidden = F.silu(self.conv_in(conditioning))
        for conv in self.blocks:
            hidden = F.silu(conv(hidden))
        return self.conv_out(hidden)