# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main file for the MAXIM model."""

import functools
from typing import Any, Sequence, Tuple

import einops
import flax.linen as nn
import jax
import jax.numpy as jnp

Conv3x3 = functools.partial(nn.Conv, kernel_size=(3, 3))
Conv1x1 = functools.partial(nn.Conv, kernel_size=(1, 1))

ConvT_up = functools.partial(nn.ConvTranspose,
                             kernel_size=(2, 2),
                             strides=(2, 2))
Conv_down = functools.partial(nn.Conv,
                              kernel_size=(4, 4),
                              strides=(2, 2))

weight_initializer = nn.initializers.normal(stddev=2e-2)

class MlpBlock(nn.Module):
  """A 1-hidden-layer MLP block, applied over the last dimension."""
  mlp_dim: int
  dropout_rate: float = 0.0
  use_bias: bool = True

  @nn.compact
  def __call__(self, x, deterministic=True):
    n, h, w, d = x.shape
    x = nn.Dense(self.mlp_dim, use_bias=self.use_bias,
                 kernel_init=weight_initializer)(x)
    x = nn.gelu(x)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic)
    x = nn.Dense(d, use_bias=self.use_bias,
                 kernel_init=weight_initializer)(x)
    return x

def block_images_einops(x, patch_size):
  """Image to patches."""
  batch, height, width, channels = x.shape
  grid_height = height // patch_size[0]
  grid_width = width // patch_size[1]
  x = einops.rearrange(
      x, "n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c",
      gh=grid_height, gw=grid_width, fh=patch_size[0], fw=patch_size[1])
  return x

def unblock_images_einops(x, grid_size, patch_size):
  """Patches to images."""
  x = einops.rearrange(
      x, "n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c",
      gh=grid_size[0], gw=grid_size[1], fh=patch_size[0], fw=patch_size[1])
  return x

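
# Illustrative shape walkthrough for the two helpers above (not executed; the
# sizes are hypothetical): for x of shape (1, 8, 8, 3) and patch_size=(2, 2),
# block_images_einops returns shape (1, 16, 4, 3), where axis -3 indexes the
# 16 grid cells and axis -2 the 4 pixels inside each 2x2 patch.
# unblock_images_einops with grid_size=(4, 4) and patch_size=(2, 2) inverts
# this back to (1, 8, 8, 3).
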
class UpSampleRatio(nn.Module):
  """Upsample features given a ratio > 0."""
  features: int
  ratio: float
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    n, h, w, c = x.shape
    x = jax.image.resize(
        x,
        shape=(n, int(h * self.ratio), int(w * self.ratio), c),
        method="bilinear")
    x = Conv1x1(features=self.features, use_bias=self.use_bias)(x)
    return x

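
# Example with hypothetical shapes: UpSampleRatio(features=32, ratio=2.) maps
# a (1, 16, 16, 64) tensor to (1, 32, 32, 32): the bilinear resize scales the
# spatial dims by the ratio, and the 1x1 conv sets the channel count. A ratio
# below 1 (e.g. 0.5) downsamples in the same way.
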
class CALayer(nn.Module):
  """Squeeze-and-excitation block for channel attention.

  ref: https://arxiv.org/abs/1709.01507
  """
  features: int
  reduction: int = 4
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    # 2D global average pooling
    y = jnp.mean(x, axis=[1, 2], keepdims=True)
    # Squeeze (in Squeeze-Excitation)
    y = Conv1x1(self.features // self.reduction, use_bias=self.use_bias)(y)
    y = nn.relu(y)
    # Excitation (in Squeeze-Excitation)
    y = Conv1x1(self.features, use_bias=self.use_bias)(y)
    y = nn.sigmoid(y)
    return x * y

class RCAB(nn.Module):
  """Residual channel attention block. Contains LN, Conv, LeakyReLU, Conv, SELayer."""
  features: int
  reduction: int = 4
  lrelu_slope: float = 0.2
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    shortcut = x
    x = nn.LayerNorm(name="LayerNorm")(x)
    x = Conv3x3(features=self.features, use_bias=self.use_bias, name="conv1")(x)
    x = nn.leaky_relu(x, negative_slope=self.lrelu_slope)
    x = Conv3x3(features=self.features, use_bias=self.use_bias, name="conv2")(x)
    x = CALayer(features=self.features, reduction=self.reduction,
                use_bias=self.use_bias, name="channel_attention")(x)
    return x + shortcut

class GridGatingUnit(nn.Module):
  """A SpatialGatingUnit as defined in the gMLP paper.

  The 'spatial' dim here is the grid axis (third from the last, i.e. axis=-3).
  If applied on other dims, you should swapaxes first.
  """
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    u, v = jnp.split(x, 2, axis=-1)
    v = nn.LayerNorm(name="intermediate_layernorm")(v)
    n = x.shape[-3]  # get spatial dim
    v = jnp.swapaxes(v, -1, -3)
    v = nn.Dense(n, use_bias=self.use_bias, kernel_init=weight_initializer)(v)
    v = jnp.swapaxes(v, -1, -3)
    return u * (v + 1.)

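
# Sketch of the gating above with hypothetical sizes: for an input of shape
# (1, 16, 4, 64), jnp.split halves the channels into u and v of shape
# (1, 16, 4, 32); the Dense layer (applied after swapping axes -1 and -3)
# mixes v across the 16 grid tokens, and u * (v + 1.) gates u with that
# globally mixed signal. With the small-stddev kernel init and zero bias,
# v starts near 0, so the gate is close to identity at initialization.
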
class GridGmlpLayer(nn.Module):
  """Grid gMLP layer that performs global mixing of tokens."""
  grid_size: Sequence[int]
  use_bias: bool = True
  factor: int = 2
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    n, h, w, num_channels = x.shape
    gh, gw = self.grid_size
    fh, fw = h // gh, w // gw
    x = block_images_einops(x, patch_size=(fh, fw))
    # gMLP1: Global (grid) mixing part, provides global grid communication.
    y = nn.LayerNorm(name="LayerNorm")(x)
    y = nn.Dense(num_channels * self.factor, use_bias=self.use_bias,
                 kernel_init=weight_initializer, name="in_project")(y)
    y = nn.gelu(y)
    y = GridGatingUnit(use_bias=self.use_bias, name="GridGatingUnit")(y)
    y = nn.Dense(num_channels, use_bias=self.use_bias,
                 kernel_init=weight_initializer, name="out_project")(y)
    y = nn.Dropout(self.dropout_rate)(y, deterministic)
    x = x + y
    x = unblock_images_einops(x, grid_size=(gh, gw), patch_size=(fh, fw))
    return x

class BlockGatingUnit(nn.Module):
  """A SpatialGatingUnit as defined in the gMLP paper.

  The 'spatial' dim is defined as the **second last**.
  If applied on other dims, you should swapaxes first.
  """
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    u, v = jnp.split(x, 2, axis=-1)
    v = nn.LayerNorm(name="intermediate_layernorm")(v)
    n = x.shape[-2]  # get spatial dim
    v = jnp.swapaxes(v, -1, -2)
    v = nn.Dense(n, use_bias=self.use_bias, kernel_init=weight_initializer)(v)
    v = jnp.swapaxes(v, -1, -2)
    return u * (v + 1.)

class BlockGmlpLayer(nn.Module):
  """Block gMLP layer that performs local mixing of tokens."""
  block_size: Sequence[int]
  use_bias: bool = True
  factor: int = 2
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    n, h, w, num_channels = x.shape
    fh, fw = self.block_size
    gh, gw = h // fh, w // fw
    x = block_images_einops(x, patch_size=(fh, fw))
    # gMLP2: Local (block) mixing part, provides within-block communication.
    y = nn.LayerNorm(name="LayerNorm")(x)
    y = nn.Dense(num_channels * self.factor, use_bias=self.use_bias,
                 kernel_init=weight_initializer, name="in_project")(y)
    y = nn.gelu(y)
    y = BlockGatingUnit(use_bias=self.use_bias, name="BlockGatingUnit")(y)
    y = nn.Dense(num_channels, use_bias=self.use_bias,
                 kernel_init=weight_initializer, name="out_project")(y)
    y = nn.Dropout(self.dropout_rate)(y, deterministic)
    x = x + y
    x = unblock_images_einops(x, grid_size=(gh, gw), patch_size=(fh, fw))
    return x

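
# How the two gMLP layers above differ, on a hypothetical 64x64 feature map:
# GridGmlpLayer(grid_size=(8, 8)) partitions it into an 8x8 grid of 8x8
# patches and mixes along the grid axis, so each position exchanges
# information with the same position in every other patch (dilated, global
# communication). BlockGmlpLayer(block_size=(8, 8)) uses 8x8-pixel blocks and
# mixes along the within-block axis, so each position exchanges information
# with its neighbours in the same block (dense, local communication).
# Together they form the multi-axis mixing used throughout MAXIM.
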
class ResidualSplitHeadMultiAxisGmlpLayer(nn.Module):
  """The multi-axis gated MLP block."""
  block_size: Sequence[int]
  grid_size: Sequence[int]
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  use_bias: bool = True
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    shortcut = x
    n, h, w, num_channels = x.shape
    x = nn.LayerNorm(name="LayerNorm_in")(x)
    x = nn.Dense(num_channels * self.input_proj_factor, use_bias=self.use_bias,
                 kernel_init=weight_initializer, name="in_project")(x)
    x = nn.gelu(x)
    u, v = jnp.split(x, 2, axis=-1)

    # GridGMLPLayer
    u = GridGmlpLayer(
        grid_size=self.grid_size,
        factor=self.grid_gmlp_factor,
        use_bias=self.use_bias,
        dropout_rate=self.dropout_rate,
        name="GridGmlpLayer")(u, deterministic)

    # BlockGMLPLayer
    v = BlockGmlpLayer(
        block_size=self.block_size,
        factor=self.block_gmlp_factor,
        use_bias=self.use_bias,
        dropout_rate=self.dropout_rate,
        name="BlockGmlpLayer")(v, deterministic)

    x = jnp.concatenate([u, v], axis=-1)
    x = nn.Dense(num_channels, use_bias=self.use_bias,
                 kernel_init=weight_initializer, name="out_project")(x)
    x = nn.Dropout(self.dropout_rate)(x, deterministic)
    x = x + shortcut
    return x

class RDCAB(nn.Module):
  """Residual dense channel attention block. Used in Bottlenecks."""
  features: int
  reduction: int = 16
  use_bias: bool = True
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    y = nn.LayerNorm(name="LayerNorm")(x)
    y = MlpBlock(
        mlp_dim=self.features,
        dropout_rate=self.dropout_rate,
        use_bias=self.use_bias,
        name="channel_mixing")(y, deterministic=deterministic)
    y = CALayer(
        features=self.features,
        reduction=self.reduction,
        use_bias=self.use_bias,
        name="channel_attention")(y)
    x = x + y
    return x

class BottleneckBlock(nn.Module):
  """The bottleneck block consisting of multi-axis gMLP block and RDCAB."""
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  num_groups: int = 1
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  channels_reduction: int = 4
  dropout_rate: float = 0.0
  use_bias: bool = True

  @nn.compact
  def __call__(self, x, deterministic):
    """Applies the Mixer block to inputs."""
    assert x.ndim == 4  # Input has shape [batch, h, w, c]
    n, h, w, num_channels = x.shape

    # input projection
    x = Conv1x1(self.features, use_bias=self.use_bias, name="input_proj")(x)
    shortcut_long = x

    for i in range(self.num_groups):
      x = ResidualSplitHeadMultiAxisGmlpLayer(
          grid_size=self.grid_size,
          block_size=self.block_size,
          grid_gmlp_factor=self.grid_gmlp_factor,
          block_gmlp_factor=self.block_gmlp_factor,
          input_proj_factor=self.input_proj_factor,
          use_bias=self.use_bias,
          dropout_rate=self.dropout_rate,
          name=f"SplitHeadMultiAxisGmlpLayer_{i}")(x, deterministic)
      # Channel-mixing part, which provides within-patch communication.
      x = RDCAB(
          features=self.features,
          reduction=self.channels_reduction,
          use_bias=self.use_bias,
          name=f"channel_attention_block_1_{i}")(x)

    # long skip-connect
    x = x + shortcut_long
    return x

class UNetEncoderBlock(nn.Module):
  """Encoder block in MAXIM."""
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  num_groups: int = 1
  lrelu_slope: float = 0.2
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  channels_reduction: int = 4
  dropout_rate: float = 0.0
  downsample: bool = True
  use_global_mlp: bool = True
  use_bias: bool = True
  use_cross_gating: bool = False

  @nn.compact
  def __call__(self, x: jnp.ndarray, skip: jnp.ndarray = None,
               enc: jnp.ndarray = None, dec: jnp.ndarray = None, *,
               deterministic: bool = True) -> jnp.ndarray:
    if skip is not None:
      x = jnp.concatenate([x, skip], axis=-1)

    # convolution-in
    x = Conv1x1(self.features, use_bias=self.use_bias)(x)
    shortcut_long = x

    for i in range(self.num_groups):
      if self.use_global_mlp:
        x = ResidualSplitHeadMultiAxisGmlpLayer(
            grid_size=self.grid_size,
            block_size=self.block_size,
            grid_gmlp_factor=self.grid_gmlp_factor,
            block_gmlp_factor=self.block_gmlp_factor,
            input_proj_factor=self.input_proj_factor,
            use_bias=self.use_bias,
            dropout_rate=self.dropout_rate,
            name=f"SplitHeadMultiAxisGmlpLayer_{i}")(x, deterministic)
      x = RCAB(
          features=self.features,
          reduction=self.channels_reduction,
          use_bias=self.use_bias,
          name=f"channel_attention_block_1{i}")(x)

    x = x + shortcut_long

    if enc is not None and dec is not None:
      assert self.use_cross_gating
      x, _ = CrossGatingBlock(
          features=self.features,
          block_size=self.block_size,
          grid_size=self.grid_size,
          dropout_rate=self.dropout_rate,
          input_proj_factor=self.input_proj_factor,
          upsample_y=False,
          use_bias=self.use_bias,
          name="cross_gating_block")(x, enc + dec, deterministic=deterministic)

    if self.downsample:
      x_down = Conv_down(self.features, use_bias=self.use_bias)(x)
      return x_down, x
    else:
      return x

class UNetDecoderBlock(nn.Module):
  """Decoder block in MAXIM."""
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  num_groups: int = 1
  lrelu_slope: float = 0.2
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  channels_reduction: int = 4
  dropout_rate: float = 0.0
  downsample: bool = True
  use_global_mlp: bool = True
  use_bias: bool = True

  @nn.compact
  def __call__(self, x: jnp.ndarray, bridge: jnp.ndarray = None,
               deterministic: bool = True) -> jnp.ndarray:
    x = ConvT_up(self.features, use_bias=self.use_bias)(x)
    x = UNetEncoderBlock(
        self.features,
        num_groups=self.num_groups,
        lrelu_slope=self.lrelu_slope,
        block_size=self.block_size,
        grid_size=self.grid_size,
        block_gmlp_factor=self.block_gmlp_factor,
        grid_gmlp_factor=self.grid_gmlp_factor,
        channels_reduction=self.channels_reduction,
        use_global_mlp=self.use_global_mlp,
        dropout_rate=self.dropout_rate,
        downsample=False,
        use_bias=self.use_bias)(x, skip=bridge, deterministic=deterministic)
    return x

class GetSpatialGatingWeights(nn.Module):
  """Get gating weights for cross-gating MLP block."""
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  input_proj_factor: int = 2
  dropout_rate: float = 0.0
  use_bias: bool = True

  @nn.compact
  def __call__(self, x, deterministic):
    n, h, w, num_channels = x.shape

    # input projection
    x = nn.LayerNorm(name="LayerNorm_in")(x)
    x = nn.Dense(
        num_channels * self.input_proj_factor,
        use_bias=self.use_bias,
        name="in_project")(x)
    x = nn.gelu(x)
    u, v = jnp.split(x, 2, axis=-1)

    # Get grid MLP weights
    gh, gw = self.grid_size
    fh, fw = h // gh, w // gw
    u = block_images_einops(u, patch_size=(fh, fw))
    dim_u = u.shape[-3]
    u = jnp.swapaxes(u, -1, -3)
    u = nn.Dense(
        dim_u, use_bias=self.use_bias, kernel_init=nn.initializers.normal(2e-2),
        bias_init=nn.initializers.ones)(u)
    u = jnp.swapaxes(u, -1, -3)
    u = unblock_images_einops(u, grid_size=(gh, gw), patch_size=(fh, fw))

    # Get Block MLP weights
    fh, fw = self.block_size
    gh, gw = h // fh, w // fw
    v = block_images_einops(v, patch_size=(fh, fw))
    dim_v = v.shape[-2]
    v = jnp.swapaxes(v, -1, -2)
    v = nn.Dense(
        dim_v, use_bias=self.use_bias, kernel_init=nn.initializers.normal(2e-2),
        bias_init=nn.initializers.ones)(v)
    v = jnp.swapaxes(v, -1, -2)
    v = unblock_images_einops(v, grid_size=(gh, gw), patch_size=(fh, fw))

    x = jnp.concatenate([u, v], axis=-1)
    x = nn.Dense(num_channels, use_bias=self.use_bias, name="out_project")(x)
    x = nn.Dropout(self.dropout_rate)(x, deterministic)
    return x

class CrossGatingBlock(nn.Module):
  """Cross-gating MLP block."""
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  dropout_rate: float = 0.0
  input_proj_factor: int = 2
  upsample_y: bool = True
  use_bias: bool = True

  @nn.compact
  def __call__(self, x, y, deterministic=True):
    # Upscale Y signal, y is the gating signal.
    if self.upsample_y:
      y = ConvT_up(self.features, use_bias=self.use_bias)(y)

    x = Conv1x1(self.features, use_bias=self.use_bias)(x)
    n, h, w, num_channels = x.shape
    y = Conv1x1(num_channels, use_bias=self.use_bias)(y)
    assert y.shape == x.shape
    shortcut_x = x
    shortcut_y = y

    # Get gating weights from X
    x = nn.LayerNorm(name="LayerNorm_x")(x)
    x = nn.Dense(num_channels, use_bias=self.use_bias, name="in_project_x")(x)
    x = nn.gelu(x)
    gx = GetSpatialGatingWeights(
        features=num_channels,
        block_size=self.block_size,
        grid_size=self.grid_size,
        dropout_rate=self.dropout_rate,
        use_bias=self.use_bias,
        name="SplitHeadMultiAxisGating_x")(x, deterministic=deterministic)

    # Get gating weights from Y
    y = nn.LayerNorm(name="LayerNorm_y")(y)
    y = nn.Dense(num_channels, use_bias=self.use_bias, name="in_project_y")(y)
    y = nn.gelu(y)
    gy = GetSpatialGatingWeights(
        features=num_channels,
        block_size=self.block_size,
        grid_size=self.grid_size,
        dropout_rate=self.dropout_rate,
        use_bias=self.use_bias,
        name="SplitHeadMultiAxisGating_y")(y, deterministic=deterministic)

    # Apply cross gating: X = X * GY, Y = Y * GX
    y = y * gx
    y = nn.Dense(num_channels, use_bias=self.use_bias, name="out_project_y")(y)
    y = nn.Dropout(self.dropout_rate)(y, deterministic=deterministic)
    y = y + shortcut_y

    x = x * gy  # gating x using y
    x = nn.Dense(num_channels, use_bias=self.use_bias, name="out_project_x")(x)
    x = nn.Dropout(self.dropout_rate)(x, deterministic=deterministic)
    x = x + y + shortcut_x  # get all aggregated signals
    return x, y

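
# Usage sketch with hypothetical shapes: given features x of shape
# (1, 32, 32, 64) and a coarser gating signal y of shape (1, 16, 16, 64),
# CrossGatingBlock(features=64, block_size=(8, 8), grid_size=(8, 8),
# upsample_y=True) first transpose-convolves y up to 32x32, computes spatial
# gating weights from each stream, modulates each stream by the other's
# weights, and returns the updated (x, y) pair, both at 32x32 resolution.
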
class SAM(nn.Module):
  """Supervised attention module for multi-stage training.

  Introduced by MPRNet [CVPR2021]: https://github.com/swz30/MPRNet
  """
  features: int
  output_channels: int = 3
  use_bias: bool = True

  @nn.compact
  def __call__(self, x: jnp.ndarray, x_image: jnp.ndarray, *,
               train: bool) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Apply the SAM module to the input and features.

    Args:
      x: the output features from UNet decoder with shape (h, w, c)
      x_image: the input image with shape (h, w, 3)
      train: Whether it is training

    Returns:
      A tuple of tensors (x1, image) where x1 is the SAM features used for the
      next stage, and image is the output restored image at the current stage.
    """
    # Get features
    x1 = Conv3x3(self.features, use_bias=self.use_bias)(x)

    # Output restored image X_s
    if self.output_channels == 3:
      image = Conv3x3(self.output_channels, use_bias=self.use_bias)(x) + x_image
    else:
      image = Conv3x3(self.output_channels, use_bias=self.use_bias)(x)

    # Get attention maps for features
    x2 = nn.sigmoid(Conv3x3(self.features, use_bias=self.use_bias)(image))

    # Get attended feature maps
    x1 = x1 * x2

    # Residual connection
    x1 = x1 + x
    return x1, image

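
# Sketch of how SAM connects stages (hypothetical shapes): with decoder
# features x of shape (1, 256, 256, 32) and the stage input image x_image of
# shape (1, 256, 256, 3), SAM(features=32) returns
#   x1:    (1, 256, 256, 32), attention-reweighted features fed to the next
#          stage,
#   image: (1, 256, 256, 3), the restored image supervised at this stage.
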
class MAXIM(nn.Module):
  """The MAXIM model function with multi-stage and multi-scale supervision.

  For more model details, please check the CVPR paper:
  MAXIM: Multi-Axis MLP for Image Processing (https://arxiv.org/abs/2201.02973)

  Attributes:
    features: initial hidden dimension for the input resolution.
    depth: the number of downsampling depth for the model.
    num_stages: how many stages to use. It also affects the output list.
    num_groups: how many blocks each stage contains.
    use_bias: whether to use bias in all the conv/mlp layers.
    num_supervision_scales: the number of desired supervision scales.
    lrelu_slope: the negative slope parameter in leaky_relu layers.
    use_global_mlp: whether to use the multi-axis gated MLP block (MAB) in
      each layer.
    use_cross_gating: whether to use the cross-gating MLP block (CGB) in the
      skip connections and multi-stage feature fusion layers.
    high_res_stages: how many stages are specified as high-res stages. The
      rest (depth - high_res_stages) are called low_res_stages.
    block_size_hr: the block_size parameter for high-res stages.
    block_size_lr: the block_size parameter for low-res stages.
    grid_size_hr: the grid_size parameter for high-res stages.
    grid_size_lr: the grid_size parameter for low-res stages.
    num_bottleneck_blocks: how many bottleneck blocks.
    block_gmlp_factor: the input projection factor for block_gMLP layers.
    grid_gmlp_factor: the input projection factor for grid_gMLP layers.
    input_proj_factor: the input projection factor for the MAB block.
    channels_reduction: the channel reduction factor for SE layer.
    num_outputs: the output channels.
    dropout_rate: Dropout rate.

  Returns:
    The output contains a list of arrays consisting of multi-stage multi-scale
    outputs. For example, if num_stages = num_supervision_scales = 3 (the
    model used in the paper), the output specs are: outputs =
    [[output_stage1_scale1, output_stage1_scale2, output_stage1_scale3],
     [output_stage2_scale1, output_stage2_scale2, output_stage2_scale3],
     [output_stage3_scale1, output_stage3_scale2, output_stage3_scale3],]
    The final output can be retrieved by outputs[-1][-1].
  """

  features: int = 64
  depth: int = 3
  num_stages: int = 2
  num_groups: int = 1
  use_bias: bool = True
  num_supervision_scales: int = 1
  lrelu_slope: float = 0.2
  use_global_mlp: bool = True
  use_cross_gating: bool = True
  high_res_stages: int = 2
  block_size_hr: Sequence[int] = (16, 16)
  block_size_lr: Sequence[int] = (8, 8)
  grid_size_hr: Sequence[int] = (16, 16)
  grid_size_lr: Sequence[int] = (8, 8)
  num_bottleneck_blocks: int = 1
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  channels_reduction: int = 4
  num_outputs: int = 3
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x: jnp.ndarray, *, train: bool = False) -> Any:
    n, h, w, c = x.shape  # input image shape
    shortcuts = []
    shortcuts.append(x)

    # Get multi-scale input images
    for i in range(1, self.num_supervision_scales):
      shortcuts.append(jax.image.resize(
          x, shape=(n, h // (2**i), w // (2**i), c), method="nearest"))

    # store outputs from all stages and all scales
    # Eg, [[(64, 64, 3), (128, 128, 3), (256, 256, 3)],   # Stage-1 outputs
    #      [(64, 64, 3), (128, 128, 3), (256, 256, 3)],]  # Stage-2 outputs
    outputs_all = []
    sam_features, encs_prev, decs_prev = [], [], []

    for idx_stage in range(self.num_stages):
      # Input convolution, get multi-scale input features
      x_scales = []
      for i in range(self.num_supervision_scales):
        x_scale = Conv3x3(
            (2**i) * self.features,
            use_bias=self.use_bias,
            name=f"stage_{idx_stage}_input_conv_{i}")(shortcuts[i])

        # If later stages, fuse input features with SAM features from prev stage
        if idx_stage > 0:
          # use larger blocksize at high-res stages
          if self.use_cross_gating:
            block_size = self.block_size_hr if i < self.high_res_stages else self.block_size_lr
            grid_size = self.grid_size_hr if i < self.high_res_stages else self.block_size_lr
            x_scale, _ = CrossGatingBlock(
                features=(2**i) * self.features,
                block_size=block_size,
                grid_size=grid_size,
                dropout_rate=self.dropout_rate,
                input_proj_factor=self.input_proj_factor,
                upsample_y=False,
                use_bias=self.use_bias,
                name=f"stage_{idx_stage}_input_fuse_sam_{i}")(
                    x_scale, sam_features.pop(), deterministic=not train)
          else:
            x_scale = Conv1x1(
                (2**i) * self.features,
                use_bias=self.use_bias,
                name=f"stage_{idx_stage}_input_catconv_{i}")(
                    jnp.concatenate([x_scale, sam_features.pop()], axis=-1))

        x_scales.append(x_scale)
      # start encoder blocks
      encs = []
      x = x_scales[0]  # First full-scale input feature

      for i in range(self.depth):  # 0, 1, 2
        # use larger blocksize at high-res stages, vice versa.
        block_size = self.block_size_hr if i < self.high_res_stages else self.block_size_lr
        grid_size = self.grid_size_hr if i < self.high_res_stages else self.block_size_lr
        use_cross_gating_layer = True if idx_stage > 0 else False

        # Multi-scale input if multi-scale supervision
        x_scale = x_scales[i] if i < self.num_supervision_scales else None

        # UNet Encoder block
        enc_prev = encs_prev.pop() if idx_stage > 0 else None
        dec_prev = decs_prev.pop() if idx_stage > 0 else None
        x, bridge = UNetEncoderBlock(
            features=(2**i) * self.features,
            num_groups=self.num_groups,
            downsample=True,
            lrelu_slope=self.lrelu_slope,
            block_size=block_size,
            grid_size=grid_size,
            block_gmlp_factor=self.block_gmlp_factor,
            grid_gmlp_factor=self.grid_gmlp_factor,
            input_proj_factor=self.input_proj_factor,
            channels_reduction=self.channels_reduction,
            use_global_mlp=self.use_global_mlp,
            dropout_rate=self.dropout_rate,
            use_bias=self.use_bias,
            use_cross_gating=use_cross_gating_layer,
            name=f"stage_{idx_stage}_encoder_block_{i}")(
                x,
                skip=x_scale,
                enc=enc_prev,
                dec=dec_prev,
                deterministic=not train)

        # Cache skip signals
        encs.append(bridge)
      # Global MLP bottleneck blocks
      for i in range(self.num_bottleneck_blocks):
        x = BottleneckBlock(
            block_size=self.block_size_lr,
            grid_size=self.block_size_lr,
            features=(2**(self.depth - 1)) * self.features,
            num_groups=self.num_groups,
            block_gmlp_factor=self.block_gmlp_factor,
            grid_gmlp_factor=self.grid_gmlp_factor,
            input_proj_factor=self.input_proj_factor,
            dropout_rate=self.dropout_rate,
            use_bias=self.use_bias,
            channels_reduction=self.channels_reduction,
            name=f"stage_{idx_stage}_global_block_{i}")(
                x, deterministic=not train)
      # cache global feature for cross-gating
      global_feature = x
      # start cross gating. Use multi-scale feature fusion
      skip_features = []
      for i in reversed(range(self.depth)):  # 2, 1, 0
        # use larger blocksize at high-res stages
        block_size = self.block_size_hr if i < self.high_res_stages else self.block_size_lr
        grid_size = self.grid_size_hr if i < self.high_res_stages else self.block_size_lr

        # get additional multi-scale signals
        signal = jnp.concatenate([
            UpSampleRatio(
                (2**i) * self.features,
                ratio=2**(j - i),
                use_bias=self.use_bias)(enc) for j, enc in enumerate(encs)
        ], axis=-1)

        # Use cross-gating to cross modulate features
        if self.use_cross_gating:
          skips, global_feature = CrossGatingBlock(
              features=(2**i) * self.features,
              block_size=block_size,
              grid_size=grid_size,
              input_proj_factor=self.input_proj_factor,
              dropout_rate=self.dropout_rate,
              upsample_y=True,
              use_bias=self.use_bias,
              name=f"stage_{idx_stage}_cross_gating_block_{i}")(
                  signal, global_feature, deterministic=not train)
        else:
          skips = Conv1x1((2**i) * self.features, use_bias=self.use_bias)(signal)
          skips = Conv3x3((2**i) * self.features, use_bias=self.use_bias)(skips)

        skip_features.append(skips)
      # start decoder. Multi-scale feature fusion of cross-gated features
      outputs, decs, sam_features = [], [], []
      for i in reversed(range(self.depth)):
        # use larger blocksize at high-res stages
        block_size = self.block_size_hr if i < self.high_res_stages else self.block_size_lr
        grid_size = self.grid_size_hr if i < self.high_res_stages else self.block_size_lr

        # get multi-scale skip signals from cross-gating block
        signal = jnp.concatenate([
            UpSampleRatio(
                (2**i) * self.features,
                ratio=2**(self.depth - j - 1 - i),
                use_bias=self.use_bias)(skip)
            for j, skip in enumerate(skip_features)
        ], axis=-1)

        # Decoder block
        x = UNetDecoderBlock(
            features=(2**i) * self.features,
            num_groups=self.num_groups,
            lrelu_slope=self.lrelu_slope,
            block_size=block_size,
            grid_size=grid_size,
            block_gmlp_factor=self.block_gmlp_factor,
            grid_gmlp_factor=self.grid_gmlp_factor,
            input_proj_factor=self.input_proj_factor,
            channels_reduction=self.channels_reduction,
            use_global_mlp=self.use_global_mlp,
            dropout_rate=self.dropout_rate,
            use_bias=self.use_bias,
            name=f"stage_{idx_stage}_decoder_block_{i}")(
                x, bridge=signal, deterministic=not train)

        # Cache decoder features for later-stage's usage
        decs.append(x)

        # output conv, if not final stage, use supervised-attention-block.
        if i < self.num_supervision_scales:
          if idx_stage < self.num_stages - 1:  # not last stage, apply SAM
            sam, output = SAM(
                (2**i) * self.features,
                output_channels=self.num_outputs,
                use_bias=self.use_bias,
                name=f"stage_{idx_stage}_supervised_attention_module_{i}")(
                    x, shortcuts[i], train=train)
            outputs.append(output)
            sam_features.append(sam)
          else:  # Last stage, apply output convolutions
            output = Conv3x3(self.num_outputs,
                             use_bias=self.use_bias,
                             name=f"stage_{idx_stage}_output_conv_{i}")(x)
            output = output + shortcuts[i]
            outputs.append(output)

      # Cache encoder and decoder features for later-stage's usage
      encs_prev = encs[::-1]
      decs_prev = decs

      # Store outputs
      outputs_all.append(outputs)

    return outputs_all

def Model(*, variant=None, **kw):
  """Factory function to easily create a Model variant like "S".

  Every model file should have this Model() function that returns the flax
  model function. The function name should be fixed.

  Args:
    variant: MAXIM model variants. Options: 'S-1' | 'S-2' | 'S-3'
      | 'M-1' | 'M-2' | 'M-3'
    **kw: Other MAXIM config keyword arguments.

  Returns:
    The MAXIM() model function
  """
  if variant is not None:
    config = {
        # params: 6.108515000000001 M, GFLOPS: 93.163716608
        "S-1": {
            "features": 32,
            "depth": 3,
            "num_stages": 1,
            "num_groups": 2,
            "num_bottleneck_blocks": 2,
            "block_gmlp_factor": 2,
            "grid_gmlp_factor": 2,
            "input_proj_factor": 2,
            "channels_reduction": 4,
        },
        # params: 13.35383 M, GFLOPS: 206.743273472
        "S-2": {
            "features": 32,
            "depth": 3,
            "num_stages": 2,
            "num_groups": 2,
            "num_bottleneck_blocks": 2,
            "block_gmlp_factor": 2,
            "grid_gmlp_factor": 2,
            "input_proj_factor": 2,
            "channels_reduction": 4,
        },
        # params: 20.599145 M, GFLOPS: 320.32194560000005
        "S-3": {
            "features": 32,
            "depth": 3,
            "num_stages": 3,
            "num_groups": 2,
            "num_bottleneck_blocks": 2,
            "block_gmlp_factor": 2,
            "grid_gmlp_factor": 2,
            "input_proj_factor": 2,
            "channels_reduction": 4,
        },
        # params: 19.361219000000002 M, 308.495712256 GFLOPs
        "M-1": {
            "features": 64,
            "depth": 3,
            "num_stages": 1,
            "num_groups": 2,
            "num_bottleneck_blocks": 2,
            "block_gmlp_factor": 2,
            "grid_gmlp_factor": 2,
            "input_proj_factor": 2,
            "channels_reduction": 4,
        },
        # params: 40.83911 M, 675.25541888 GFLOPs
        "M-2": {
            "features": 64,
            "depth": 3,
            "num_stages": 2,
            "num_groups": 2,
            "num_bottleneck_blocks": 2,
            "block_gmlp_factor": 2,
            "grid_gmlp_factor": 2,
            "input_proj_factor": 2,
            "channels_reduction": 4,
        },
        # params: 62.317001 M, 1042.014666752 GFLOPs
        "M-3": {
            "features": 64,
            "depth": 3,
            "num_stages": 3,
            "num_groups": 2,
            "num_bottleneck_blocks": 2,
            "block_gmlp_factor": 2,
            "grid_gmlp_factor": 2,
            "input_proj_factor": 2,
            "channels_reduction": 4,
        },
    }[variant]

    for k, v in config.items():
      kw.setdefault(k, v)

  return MAXIM(**kw)
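
# Minimal usage sketch (not part of the original file): build the smallest
# "S-1" variant, initialize parameters with a dummy input, and run one
# forward pass. The 256x256 input size and PRNG seed are arbitrary choices
# for illustration; spatial dims should be multiples of 64 for the default
# block/grid sizes and depth.
if __name__ == "__main__":
  model = Model(variant="S-1")
  dummy = jnp.zeros((1, 256, 256, 3), jnp.float32)
  variables = model.init(jax.random.PRNGKey(0), dummy, train=False)
  outputs = model.apply(variables, dummy, train=False)
  # outputs is a list over stages of lists over supervision scales;
  # the final restored image is outputs[-1][-1].
  print(outputs[-1][-1].shape)  # (1, 256, 256, 3)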