from typing import Dict, Optional, Tuple

from omegaconf import DictConfig

import torch
import torch.nn as nn

from matanyone.model.group_modules import GConv2d
from matanyone.utils.tensor_utils import aggregate
from matanyone.model.transformer.positional_encoding import PositionalEncoding
from matanyone.model.transformer.transformer_layers import CrossAttention, SelfAttention, FFN, PixelFFN


class QueryTransformerBlock(nn.Module):
    """One object-transformer block: the object queries read from the pixel
    features, refine themselves with self-attention and an FFN, and the pixel
    features then read back from the updated queries."""

    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        this_cfg = model_cfg.object_transformer
        self.embed_dim = this_cfg.embed_dim
        self.num_heads = this_cfg.num_heads
        self.num_queries = this_cfg.num_queries
        self.ff_dim = this_cfg.ff_dim

        # object queries attend to the pixel features
        self.read_from_pixel = CrossAttention(self.embed_dim,
                                              self.num_heads,
                                              add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv)
        self.self_attn = SelfAttention(self.embed_dim,
                                       self.num_heads,
                                       add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv)
        self.ffn = FFN(self.embed_dim, self.ff_dim)
        # pixel features attend to the object queries
        self.read_from_query = CrossAttention(self.embed_dim,
                                              self.num_heads,
                                              add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv,
                                              norm=this_cfg.read_from_query.output_norm)
        self.pixel_ffn = PixelFFN(self.embed_dim)

    def forward(
            self,
            x: torch.Tensor,
            pixel: torch.Tensor,
            query_pe: torch.Tensor,
            pixel_pe: torch.Tensor,
            attn_mask: torch.Tensor,
            need_weights: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
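        # x: (bs*num_objects)*num_queries*embed_dim -- object query tokens
        # pixel: bs*num_objects*C*H*W -- pixel feature map
        # query_pe: (bs*num_objects)*num_queries*embed_dim -- query positional embedding
        # pixel_pe: (bs*num_objects)*(H*W)*C -- pixel positional embedding
        # attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W)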
        # flatten the pixel features: bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C
        pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
        # object queries read from the pixel features
        x, q_weights = self.read_from_pixel(x,
                                            pixel_flat,
                                            query_pe,
                                            pixel_pe,
                                            attn_mask=attn_mask,
                                            need_weights=need_weights)
        x = self.self_attn(x, query_pe)
        x = self.ffn(x)

        # pixel features read from the updated queries
        pixel_flat, p_weights = self.read_from_query(pixel_flat,
                                                     x,
                                                     pixel_pe,
                                                     query_pe,
                                                     need_weights=need_weights)
        pixel = self.pixel_ffn(pixel, pixel_flat)

        if need_weights:
            # reshape the attention weights into per-query spatial maps
            bs, num_objects, _, h, w = pixel.shape
            q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w)
            p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads,
                                                       self.num_queries, h, w)

        return x, pixel, q_weights, p_weights


class QueryTransformer(nn.Module):
    """Object transformer: iteratively refines object queries and pixel
    features with a stack of QueryTransformerBlocks, predicting auxiliary
    mask logits and an attention mask after each block."""

    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        this_cfg = model_cfg.object_transformer
        self.value_dim = model_cfg.value_dim
        self.embed_dim = this_cfg.embed_dim
        self.num_heads = this_cfg.num_heads
        self.num_queries = this_cfg.num_queries

        # query initialization and positional embedding
        self.query_init = nn.Embedding(self.num_queries, self.embed_dim)
        self.query_emb = nn.Embedding(self.num_queries, self.embed_dim)

        # projections from the object summaries to query initialization and embedding
        self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim)
        self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim)

        # pixel feature projections and spatial positional encoding
        self.pixel_pe_scale = model_cfg.pixel_pe_scale
        self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
        self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
        self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
        self.spatial_pe = PositionalEncoding(self.embed_dim,
                                             scale=self.pixel_pe_scale,
                                             temperature=self.pixel_pe_temperature,
                                             channel_last=False,
                                             transpose_output=True)

        # transformer blocks and one auxiliary mask predictor per block (plus one for the input)
        self.num_blocks = this_cfg.num_blocks
        self.blocks = nn.ModuleList(
            QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks))
        self.mask_pred = nn.ModuleList(
            nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1))
            for _ in range(self.num_blocks + 1))

        self.act = nn.ReLU(inplace=True)

    def forward(self,
                pixel: torch.Tensor,
                obj_summaries: torch.Tensor,
                selector: Optional[torch.Tensor] = None,
                need_weights: bool = False,
                seg_pass: bool = False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
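        # pixel: bs*num_objects*embed_dim*H*W -- pixel features
        # obj_summaries: bs*num_objects*T*num_queries*(embed_dim+1); the last
        #     channel holds the accumulated area used for normalization below
        # selector: bs*num_objects*1*1 or None -- per-object weighting applied
        #     to the sigmoid probabilities when building the attention mask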
        T = obj_summaries.shape[2]
        bs, num_objects, _, H, W = pixel.shape

        # normalize the object summaries by their accumulated area (last channel)
        obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries,
                                           self.embed_dim + 1)
        # sum over time
        obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1)
        obj_area = obj_summaries[:, :, :, -1:].sum(dim=1)
        obj_values = obj_sums / (obj_area + 1e-4)
        obj_init = self.summary_to_query_init(obj_values)
        obj_emb = self.summary_to_query_emb(obj_values)

        # initialization and positional embeddings for the object queries
        query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init
        query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb

        # projections and positional embeddings for the pixel features
        pixel_init = self.pixel_init_proj(pixel)
        pixel_emb = self.pixel_emb_proj(pixel)
        pixel_pe = self.spatial_pe(pixel.flatten(0, 1))
        pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
        pixel_pe = pixel_pe.flatten(1, 2) + pixel_emb

        pixel = pixel_init

        # run the transformer blocks
        aux_features = {'logits': []}

        # first auxiliary prediction, made from the input pixel features
        aux_logits = self.mask_pred[0](pixel).squeeze(2)
        attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
        aux_features['logits'].append(aux_logits)
        for i in range(self.num_blocks):
            query, pixel, q_weights, p_weights = self.blocks[i](query,
                                                                pixel,
                                                                query_emb,
                                                                pixel_pe,
                                                                attn_mask,
                                                                need_weights=need_weights)

            # note: i <= self.num_blocks - 1 always holds, so the auxiliary
            # prediction is refreshed after every block
            if self.training or i <= self.num_blocks - 1 or need_weights:
                aux_logits = self.mask_pred[i + 1](pixel).squeeze(2)
                attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
                aux_features['logits'].append(aux_logits)

        aux_features['q_weights'] = q_weights  # from the last block only
        aux_features['p_weights'] = p_weights  # from the last block only

        if self.training:
            # no need to save all heads
            aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads,
                                                       self.num_queries, H, W)[:, :, 0]

        return pixel, aux_features

    def _get_aux_mask(self, logits: torch.Tensor, selector: Optional[torch.Tensor],
                      seg_pass: bool = False) -> torch.Tensor:
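        # logits: bs*num_objects*H*W
        # selector: bs*num_objects*1*1 or None
        # returns a mask of shape (bs*num_objects*num_heads)*num_queries*(H*W)
        # where True means the attention is blocked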
        if selector is None:
            prob = logits.sigmoid()
        else:
            prob = logits.sigmoid() * selector
        logits = aggregate(prob, dim=1)

        is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0])
        foreground_mask = is_foreground.bool().flatten(start_dim=2)
        inv_foreground_mask = ~foreground_mask
        inv_background_mask = foreground_mask

        # the first half of the queries is restricted to foreground pixels,
        # the second half to background pixels
        aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat(
            1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
        aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat(
            1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)

        aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1)

        # if a query would have every position blocked, unblock it entirely
        aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False

        return aux_mask