# selfmask/networks/maskformer/maskformer.py
from typing import Dict, List
from math import sqrt, log
import torch
import torch.nn as nn
import torch.nn.functional as F
from networks.maskformer.transformer_decoder import TransformerDecoderLayer, TransformerDecoder
from utils import get_model
class MaskFormer(nn.Module):
def __init__(
self,
n_queries: int = 100,
arch: str = "vit_small",
patch_size: int = 8,
training_method: str = "dino",
n_decoder_layers: int = 6,
normalize_before: bool = False,
return_intermediate: bool = False,
learnable_pixel_decoder: bool = False,
lateral_connection: bool = False,
scale_factor: int = 2,
abs_2d_pe_init: bool = False,
use_binary_classifier: bool = False
):
"""Define a encoder and decoder along with queries to be learned through the decoder."""
super(MaskFormer, self).__init__()
if arch == "vit_small":
self.encoder = get_model(arch=arch, patch_size=patch_size, training_method=training_method)
n_dims: int = self.encoder.n_embs
n_heads: int = self.encoder.n_heads
mlp_ratio: int = self.encoder.mlp_ratio
else:
self.encoder = get_model(arch=arch, training_method=training_method)
n_dims_resnet: int = self.encoder.n_embs
n_dims: int = 384
n_heads: int = 6
mlp_ratio: int = 4
self.linear_layer = nn.Conv2d(n_dims_resnet, n_dims, kernel_size=1)
decoder_layer = TransformerDecoderLayer(
n_dims, n_heads, n_dims * mlp_ratio, 0., activation="relu", normalize_before=normalize_before
)
self.decoder = TransformerDecoder(
decoder_layer,
n_decoder_layers,
norm=nn.LayerNorm(n_dims),
return_intermediate=return_intermediate
)
self.query_embed = nn.Embedding(n_queries, n_dims).weight # initialized with gaussian(0, 1)
if use_binary_classifier:
# self.ffn = MLP(n_dims, n_dims, n_dims, num_layers=3)
# self.linear_classifier = nn.Linear(n_dims, 1)
self.ffn = MLP(n_dims, n_dims, 1, num_layers=3)
# self.norm = nn.LayerNorm(n_dims)
else:
# self.ffn = None
# self.linear_classifier = None
# self.norm = None
self.ffn = MLP(n_dims, n_dims, n_dims, num_layers=3)
self.linear_classifier = nn.Linear(n_dims, 2)
self.norm = nn.LayerNorm(n_dims)
self.arch = arch
self.use_binary_classifier = use_binary_classifier
self.lateral_connection = lateral_connection
self.learnable_pixel_decoder = learnable_pixel_decoder
self.scale_factor = scale_factor
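
    # Construction sketch (illustrative; the argument values below are assumptions,
    # not values taken from this file):
    #     model = MaskFormer(arch="vit_small", patch_size=8, training_method="dino",
    #                        return_intermediate=True, use_binary_classifier=True)
    # return_intermediate=True and use_binary_classifier=True are assumed here because
    # the objectness branch in forward() expects per-layer (4-D) query outputs.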
# copy-pasted from https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
@staticmethod
def positional_encoding_2d(n_dims: int, height: int, width: int):
"""
:param n_dims: dimension of the model
:param height: height of the positions
:param width: width of the positions
        :return: n_dims x height x width position matrix
"""
if n_dims % 4 != 0:
raise ValueError("Cannot use sin/cos positional encoding with "
"odd dimension (got dim={:d})".format(n_dims))
pe = torch.zeros(n_dims, height, width)
        # each of the two spatial axes (width, height) uses half of the channels
d_model = int(n_dims / 2)
div_term = torch.exp(torch.arange(0., d_model, 2) * -(log(10000.0) / d_model))
pos_w = torch.arange(0., width).unsqueeze(1)
pos_h = torch.arange(0., height).unsqueeze(1)
pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
return pe
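
    # Usage sketch: positional_encoding_2d is kept as a utility and is not called
    # elsewhere in this file. For a ViT-S/8 input of 224x224 the patch grid is 28x28, so
    #     pe = MaskFormer.positional_encoding_2d(384, 28, 28)  # -> (384, 28, 28)
    # where channels [0, 192) encode the width index and [192, 384) the height index.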
def forward_encoder(self, x: torch.Tensor):
"""
:param x: b x c x h x w
:return patch_tokens: b x depth x hw x n_dims
"""
if self.arch == "vit_small":
encoder_outputs: Dict[str, torch.Tensor] = self.encoder(x) # [:, 1:, :]
all_patch_tokens: List[torch.Tensor] = list()
for layer_name in [f"layer{num_layer}" for num_layer in range(1, self.encoder.depth + 1)]:
patch_tokens: torch.Tensor = encoder_outputs[layer_name][:, 1:, :] # b x hw x n_dims
all_patch_tokens.append(patch_tokens)
all_patch_tokens: torch.Tensor = torch.stack(all_patch_tokens, dim=0) # depth x b x hw x n_dims
all_patch_tokens = all_patch_tokens.permute(1, 0, 3, 2) # b x depth x n_dims x hw
return all_patch_tokens
else:
encoder_outputs = self.linear_layer(self.encoder(x)[-1]) # b x n_dims x h x w
return encoder_outputs
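
    # Shape sketch for forward_encoder (assuming a 224x224 input; backbone-dependent):
    #   vit_small, patch_size=8 -> b x 12 x 384 x 784 (12 blocks, 28x28 = 784 patch tokens)
    #   other backbones         -> b x 384 x h x w after the 1x1 projection, where h x w is
    #                              the final-stage resolution (e.g. 7x7 for a ResNet-50)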
def forward_transformer_decoder(self, patch_tokens: torch.Tensor, skip_decoder: bool = False) -> torch.Tensor:
"""Forward transformer decoder given patch tokens from the encoder's last layer.
:param patch_tokens: b x n_dims x hw -> hw x b x n_dims
:param skip_decoder: if True, skip the decoder and produce mask predictions directly by matrix multiplication
between learnable queries and encoder features (i.e., patch tokens). This is for the purpose of an overfitting
experiment.
:return queries: n_queries x b x n_dims -> b x n_queries x n_dims or b x n_layers x n_queries x n_dims
"""
b = patch_tokens.shape[0]
patch_tokens = patch_tokens.permute(2, 0, 1) # b x n_dims x hw -> hw x b x n_dims
# n_queries x n_dims -> n_queries x b x n_dims
queries: torch.Tensor = self.query_embed.unsqueeze(1).repeat(1, b, 1)
queries: torch.Tensor = self.decoder.forward(
tgt=torch.zeros_like(queries),
memory=patch_tokens,
query_pos=queries
).squeeze(dim=0)
if len(queries.shape) == 3:
queries: torch.Tensor = queries.permute(1, 0, 2) # n_queries x b x n_dims -> b x n_queries x n_dims
elif len(queries.shape) == 4:
# n_layers x n_queries x b x n_dims -> b x n_layers x n_queries x n_dims
queries: torch.Tensor = queries.permute(2, 0, 1, 3)
return queries
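
    # Decoder sketch: the learnable query embeddings are used as query_pos while the
    # target sequence starts from zeros; each decoder layer then cross-attends from the
    # n_queries tokens to the hw patch tokens (memory). With return_intermediate=True,
    # the per-layer outputs are stacked, yielding b x n_layers x n_queries x n_dims.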
def forward_pixel_decoder(self, patch_tokens: torch.Tensor, input_size=None):
""" Upsample patch tokens by self.scale_factor and produce mask predictions
:param patch_tokens: b (x depth) x n_dims x hw -> b (x depth) x n_dims x h x w
:param queries: b x n_queries x n_dims
:return mask_predictions: b x n_queries x h x w
"""
if input_size is None:
# assume square shape features
hw = patch_tokens.shape[-1]
h = w = int(sqrt(hw))
else:
# arbitrary shape features
h, w = input_size
patch_tokens = patch_tokens.view(*patch_tokens.shape[:-1], h, w)
assert len(patch_tokens.shape) == 4
patch_tokens = F.interpolate(patch_tokens, scale_factor=self.scale_factor, mode="bilinear")
return patch_tokens
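
    # Example (sketch): with a 28x28 token grid, n_dims=384 and scale_factor=2,
    # forward_pixel_decoder turns b x 384 x 784 into b x 384 x 56 x 56; forward() then
    # computes mask logits as dot products (einsum) between queries and these features.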
def forward(self, x, encoder_only=False, skip_decoder: bool = False):
"""
x: b x c x h x w
patch_tokens: b x n_patches x n_dims -> n_patches x b x n_dims
query_emb: n_queries x n_dims -> n_queries x b x n_dims
"""
dict_outputs: dict = dict()
# b x depth x n_dims x hw (vit) or b x n_dims x h x w (resnet50)
features: torch.Tensor = self.forward_encoder(x)
if self.arch == "vit_small":
# extract the last layer for decoder input
last_layer_features: torch.Tensor = features[:, -1, ...] # b x n_dims x hw
else:
            # reshape the features into the b x n_dims x hw layout expected by the transformer decoder
b, n_dims, h, w = features.shape
last_layer_features: torch.Tensor = features.view(b, n_dims, h * w) # b x n_dims x hw
if encoder_only:
_h, _w = self.encoder.make_input_divisible(x).shape[-2:]
_h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size
b, n_dims, hw = last_layer_features.shape
dict_outputs.update({"patch_tokens": last_layer_features.view(b, _h, _w, n_dims)})
return dict_outputs
# transformer decoder forward
queries: torch.Tensor = self.forward_transformer_decoder(
last_layer_features,
skip_decoder=skip_decoder
) # b x n_queries x n_dims or b x n_layers x n_queries x n_dims
# pixel decoder forward (upsampling the patch tokens by self.scale_factor)
if self.arch == "vit_small":
_h, _w = self.encoder.make_input_divisible(x).shape[-2:]
_h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size
else:
_h, _w = h, w
features: torch.Tensor = self.forward_pixel_decoder(
patch_tokens=features if self.lateral_connection else last_layer_features,
input_size=(_h, _w)
) # b x n_dims x h x w
# queries: b x n_queries x n_dims or b x n_layers x n_queries x n_dims
# features: b x n_dims x h x w
# mask_pred: b x n_queries x h x w or b x n_layers x n_queries x h x w
if len(queries.shape) == 3:
mask_pred = torch.einsum("bqn,bnhw->bqhw", queries, features)
else:
if self.use_binary_classifier:
mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", queries, features))
else:
mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", self.ffn(queries), features))
        if self.use_binary_classifier:
            # NOTE: this branch assumes return_intermediate=True, i.e. queries has a layer dimension.
            # queries: b x n_layers x n_queries x n_dims -> n_layers x b x n_queries x n_dims
            queries = queries.permute(1, 0, 2, 3)
objectness: List[torch.Tensor] = list()
for n_layer, queries_per_layer in enumerate(queries): # queries_per_layer: b x n_queries x n_dims
# objectness_per_layer = self.linear_classifier(
# self.ffn(self.norm(queries_per_layer))
# ) # b x n_queries x 1
objectness_per_layer = self.ffn(queries_per_layer) # b x n_queries x 1
objectness.append(objectness_per_layer)
# n_layers x b x n_queries x 1 -> # b x n_layers x n_queries x 1
objectness: torch.Tensor = torch.stack(objectness).permute(1, 0, 2, 3)
        dict_outputs["mask_pred"] = mask_pred
        if self.use_binary_classifier:
            # objectness is only computed in the binary-classifier branch above
            dict_outputs["objectness"] = torch.sigmoid(objectness)
        return dict_outputs
class MLP(nn.Module):
"""Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
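

# Example (sketch): MLP(384, 384, 1, num_layers=3), as built in MaskFormer.__init__ when
# use_binary_classifier=True with ViT-S features (n_dims=384), maps b x n_queries x 384
# query embeddings to b x n_queries x 1 logits, applying ReLU after all but the last layer.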
class UpsampleBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, n_groups=32, scale_factor=2):
super(UpsampleBlock, self).__init__()
self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding),
nn.GroupNorm(n_groups, out_channels),
nn.ReLU()
)
self.scale_factor = scale_factor
def forward(self, x):
return F.interpolate(self.block(x), scale_factor=self.scale_factor, mode="bilinear")
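

# ------------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the original training code).
# It only exercises the self-contained pieces defined above (positional_encoding_2d,
# MLP, UpsampleBlock and the einsum used for mask prediction); building the full
# MaskFormer additionally requires the backbone returned by utils.get_model, and
# running this file assumes the imports at the top resolve within the repository.
if __name__ == "__main__":
    b, n_queries, n_dims, h, w = 2, 20, 384, 28, 28

    # 2D sinusoidal positional encoding for a 28x28 patch grid (n_dims must be divisible by 4).
    pe = MaskFormer.positional_encoding_2d(n_dims, h, w)
    assert pe.shape == (n_dims, h, w)

    # DETR-style MLP head mapping query embeddings to per-query logits.
    ffn = MLP(n_dims, n_dims, 1, num_layers=3)
    queries = torch.randn(b, n_queries, n_dims)
    assert ffn(queries).shape == (b, n_queries, 1)

    # Upsample a feature map by 2x, as a learnable pixel-decoder block would.
    features = torch.randn(b, n_dims, h, w)
    upsampled = UpsampleBlock(n_dims, n_dims)(features)
    assert upsampled.shape == (b, n_dims, 2 * h, 2 * w)

    # Mask prediction: dot products between query embeddings and per-pixel features.
    mask_pred = torch.einsum("bqn,bnhw->bqhw", queries, upsampled)
    assert mask_pred.shape == (b, n_queries, 2 * h, 2 * w)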