from functools import partial import math from typing import Iterable # from black import diff from torch import nn, einsum import numpy as np import torch as th import torch.nn as nn import functools import torch.nn.functional as F from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer, constant_init, normal_init) from typing import Optional, Union, Tuple, List, Callable, Dict import math from einops import rearrange, repeat import torch import torch.nn.functional as F from torch import nn, Tensor from einops import rearrange import copy from torchvision import transforms from torchvision.transforms import InterpolationMode from .position_encoding import PositionEmbeddingSine import fvcore.nn.weight_init as weight_init class SelfAttentionLayer(nn.Module): def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False): super().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post(self, tgt, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): q = k = self.with_pos_embed(tgt, query_pos) tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] tgt = tgt + self.dropout(tgt2) tgt = self.norm(tgt) return tgt def forward_pre(self, tgt, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): tgt2 = self.norm(tgt) q = k = self.with_pos_embed(tgt2, query_pos) tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] tgt = tgt + self.dropout(tgt2) return tgt def forward(self, tgt, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): if self.normalize_before: return self.forward_pre(tgt, tgt_mask, tgt_key_padding_mask, query_pos) return self.forward_post(tgt, tgt_mask, tgt_key_padding_mask, query_pos) class CrossAttentionLayer(nn.Module): def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False): super().__init__() self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post(self, tgt, memory, memory_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0] tgt = tgt + self.dropout(tgt2) tgt = self.norm(tgt) return tgt def forward_pre(self, tgt, memory, memory_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): tgt2 = self.norm(tgt) tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0] tgt = tgt + self.dropout(tgt2) return tgt def forward(self, tgt, memory, memory_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): if self.normalize_before: return self.forward_pre(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos) return self.forward_post(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos) class FFNLayer(nn.Module): def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, activation="relu", normalize_before=False): super().__init__() # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm = nn.LayerNorm(d_model) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post(self, tgt): tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) tgt = tgt + self.dropout(tgt2) tgt = self.norm(tgt) return tgt def forward_pre(self, tgt): tgt2 = self.norm(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) tgt = tgt + self.dropout(tgt2) return tgt def forward(self, tgt): if self.normalize_before: return self.forward_pre(tgt) return self.forward_post(tgt) def _get_activation_fn(activation): """Return an activation function given a string""" if activation == "relu": return F.relu if activation == "gelu": return F.gelu if activation == "glu": return F.glu raise RuntimeError(F"activation should be relu/gelu, not {activation}.") class MLP(nn.Module): """ Very simple multi-layer perceptron (also called FFN)""" def __init__(self, input_dim, hidden_dim, output_dim, num_layers): super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) def forward(self, x): for i, layer in enumerate(self.layers): x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) return x def resize_fn(img, size): return transforms.Resize(size, InterpolationMode.BICUBIC)( transforms.ToPILImage()(img)) import math def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) class TransformerDecoder(nn.Module): def __init__(self, decoder_layer, num_layers): super().__init__() self.layers = _get_clones(decoder_layer, num_layers) self.num_layers = num_layers def forward(self, tgt, memory, pos = None, query_pos = None): output = tgt for layer in self.layers: output = layer(output, memory, pos=pos, query_pos=query_pos) return output class TransformerDecoderLayer(nn.Module): def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, no_norm = False, activation="relu"): super().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False) self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False) # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) if not no_norm else nn.Identity() self.norm2 = nn.LayerNorm(d_model) if not no_norm else nn.Identity() self.norm3 = nn.LayerNorm(d_model) if not no_norm else nn.Identity() self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.dropout3 = nn.Dropout(dropout) self.activation = _get_activation_fn(activation) def with_pos_embed(self, tensor, pos): return tensor if pos is None else tensor + pos def forward(self, tgt, memory, pos = None, query_pos = None): tgt2 = self.norm1(tgt) q = k = self.with_pos_embed(tgt2, query_pos) # print(tgt2.shape) tgt2 = self.self_attn(q, k, value=tgt2)[0] tgt = tgt + self.dropout1(tgt2) tgt2 = self.norm2(tgt) tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), key=self.with_pos_embed(memory, pos), value=memory)[0] tgt = tgt + self.dropout2(tgt2) tgt2 = self.norm3(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) tgt = tgt + self.dropout3(tgt2) return tgt # Projection of x onto y def proj(x, y): return torch.mm(y, x.t()) * y / torch.mm(y, y.t()) # Orthogonalize x wrt list of vectors ys def gram_schmidt(x, ys): for y in ys: x = x - proj(x, y) return x def power_iteration(W, u_, update=True, eps=1e-12): # Lists holding singular vectors and values us, vs, svs = [], [], [] for i, u in enumerate(u_): # Run one step of the power iteration with torch.no_grad(): v = torch.matmul(u, W) # Run Gram-Schmidt to subtract components of all other singular vectors v = F.normalize(gram_schmidt(v, vs), eps=eps) # Add to the list vs += [v] # Update the other singular vector u = torch.matmul(v, W.t()) # Run Gram-Schmidt to subtract components of all other singular vectors u = F.normalize(gram_schmidt(u, us), eps=eps) # Add to the list us += [u] if update: u_[i][:] = u # Compute this singular value and add it to the list svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))] #svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)] return svs, us, vs # Spectral normalization base class class SN(object): def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12): # Number of power iterations per step self.num_itrs = num_itrs # Number of singular values self.num_svs = num_svs # Transposed? self.transpose = transpose # Epsilon value for avoiding divide-by-0 self.eps = eps # Register a singular vector for each sv for i in range(self.num_svs): self.register_buffer('u%d' % i, torch.randn(1, num_outputs)) self.register_buffer('sv%d' % i, torch.ones(1)) # Singular vectors (u side) @property def u(self): return [getattr(self, 'u%d' % i) for i in range(self.num_svs)] # Singular values; # note that these buffers are just for logging and are not used in training. @property def sv(self): return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)] # Compute the spectrally-normalized weight def W_(self): W_mat = self.weight.view(self.weight.size(0), -1) if self.transpose: W_mat = W_mat.t() # Apply num_itrs power iterations for _ in range(self.num_itrs): svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps) # Update the svs if self.training: with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks! for i, sv in enumerate(svs): self.sv[i][:] = sv return self.weight / svs[0] # Linear layer with spectral norm class SNLinear(nn.Linear, SN): def __init__(self, in_features, out_features, bias=True, num_svs=1, num_itrs=1, eps=1e-12): nn.Linear.__init__(self, in_features, out_features, bias) SN.__init__(self, num_svs, num_itrs, out_features, eps=eps) def forward(self, x): return F.linear(x, self.W_(), self.bias) # 2D Conv layer with spectral norm class SNConv2d(nn.Conv2d, SN): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, num_svs=1, num_itrs=1, eps=1e-12): nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps) def forward(self, x): return F.conv2d(x, self.W_(), self.bias, self.stride, self.padding, self.dilation, self.groups) class SegBlock(nn.Module): def __init__(self, in_channels, out_channels, con_channels, which_conv=nn.Conv2d, which_linear=None, activation=None, upsample=None): super(SegBlock, self).__init__() self.in_channels, self.out_channels = in_channels, out_channels self.which_conv, self.which_linear = which_conv, which_linear self.activation = activation self.upsample = upsample self.conv1 = self.which_conv(self.in_channels, self.out_channels) self.conv2 = self.which_conv(self.out_channels, self.out_channels) self.learnable_sc = in_channels != out_channels or upsample if self.learnable_sc: self.conv_sc = self.which_conv(in_channels, out_channels, kernel_size=1, padding=0) self.register_buffer('stored_mean1', torch.zeros(in_channels)) self.register_buffer('stored_var1', torch.ones(in_channels)) self.register_buffer('stored_mean2', torch.zeros(out_channels)) self.register_buffer('stored_var2', torch.ones(out_channels)) self.upsample = upsample def forward(self, x, y=None): x = F.batch_norm(x, self.stored_mean1, self.stored_var1, None, None, self.training, 0.1, 1e-4) h = self.activation(x) if self.upsample: h = self.upsample(h) x = self.upsample(x) h = self.conv1(h) h = F.batch_norm(h, self.stored_mean2, self.stored_var2, None, None, self.training, 0.1, 1e-4) h = self.activation(h) h = self.conv2(h) if self.learnable_sc: x = self.conv_sc(x) return h + x def make_coord(shape, ranges=None, flatten=True): """ Make coordinates at grid centers. """ coord_seqs = [] for i, n in enumerate(shape): if ranges is None: v0, v1 = -1, 1 else: v0, v1 = ranges[i] r = (v1 - v0) / (2 * n) seq = v0 + r + (2 * r) * torch.arange(n).float() coord_seqs.append(seq) ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1) if flatten: ret = ret.view(-1, ret.shape[-1]) return ret class Embedder: def __init__(self, **kwargs): self.kwargs = kwargs self.create_embedding_fn() def create_embedding_fn(self): embed_fns = [] d = self.kwargs['input_dims'] out_dim = 0 if self.kwargs['include_input']: embed_fns.append(lambda x : x) out_dim += d max_freq = self.kwargs['max_freq_log2'] N_freqs = self.kwargs['num_freqs'] if self.kwargs['log_sampling']: freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs).double() else: freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) for freq in freq_bands: for p_fn in self.kwargs['periodic_fns']: embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x.double() * freq)) out_dim += d self.embed_fns = embed_fns self.out_dim = out_dim def embed(self, inputs): return torch.cat([fn(inputs) for fn in self.embed_fns], -1) def get_embedder(multires, i=0): if i == -1: return nn.Identity(), 3 embed_kwargs = { 'include_input' : False, 'input_dims' : 2, 'max_freq_log2' : multires-1, 'num_freqs' : multires, 'log_sampling' : True, 'periodic_fns' : [torch.sin, torch.cos], } embedder_obj = Embedder(**embed_kwargs) embed = lambda x, eo=embedder_obj : eo.embed(x) return embed, embedder_obj.out_dim def aggregate_attention(attention_store, res: int, from_where: List[str], is_cross: bool, select: int, prompts=None): out = [] attention_maps = attention_store.get_average_attention() num_pixels = res ** 2 for location in from_where: for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: if item.shape[1] == num_pixels: cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1]) # print(cross_maps.shape) out.append(cross_maps) out = torch.cat(out, dim=1) # print(out.shape) return out class seg_decorder_open_word(nn.Module): def __init__(self, embedding_dim=512, num_heads=8, num_layers=3, dropout_rate=0, num_queries=100, hidden_dim=256, num_classes=19, mask_dim= 256, dim_feedforward= 2048): super().__init__() self.num_queries = num_queries # learnable query features self.query_feat = nn.Embedding(num_queries, hidden_dim) # learnable query p.e. self.query_embed = nn.Embedding(num_queries, hidden_dim) self.query_feat_mlp = nn.Linear(hidden_dim+768, hidden_dim, bias=False) self.query_embed_mlp = nn.Linear(hidden_dim+768, hidden_dim, bias=False) self.decoder_norm = nn.LayerNorm(hidden_dim) # positional encoding N_steps = hidden_dim // 2 self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) # level embedding (we always use 3 scales) self.num_feature_levels = 3 self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim) self.input_proj = nn.ModuleList() for _ in range(self.num_feature_levels): self.input_proj.append(nn.Sequential()) # output FFNs self.mask_classification = True if self.mask_classification: self.class_embed = nn.Linear(hidden_dim, num_classes + 1) self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) # define Transformer decoder here self.num_heads = 8 self.num_layers = 10 self.transformer_self_attention_layers = nn.ModuleList() self.transformer_cross_attention_layers = nn.ModuleList() self.transformer_ffn_layers = nn.ModuleList() pre_norm = False for _ in range(self.num_layers): self.transformer_self_attention_layers.append( SelfAttentionLayer( d_model=hidden_dim, nhead=self.num_heads, dropout=0.0, normalize_before=pre_norm, ) ) self.transformer_cross_attention_layers.append( CrossAttentionLayer( d_model=hidden_dim, nhead=self.num_heads, dropout=0.0, normalize_before=pre_norm, ) ) self.transformer_ffn_layers.append( FFNLayer( d_model=hidden_dim, dim_feedforward=dim_feedforward, dropout=0.0, normalize_before=False, ) ) low_feature_channel = 256 mid_feature_channel = 256 high_feature_channel = 256 highest_feature_channel=256 self.low_feature_conv = nn.Sequential( nn.Conv2d(1280*14+8*77, low_feature_channel, kernel_size=1, bias=False), ) self.mid_feature_conv = nn.Sequential( nn.Conv2d(16640+40*77, mid_feature_channel, kernel_size=1, bias=False), ) self.mid_feature_mix_conv = SegBlock( in_channels=low_feature_channel+mid_feature_channel, out_channels=mask_dim, con_channels=128, which_conv=functools.partial(SNConv2d, kernel_size=3, padding=1, num_svs=1, num_itrs=1, eps=1e-04), which_linear=functools.partial(SNLinear, num_svs=1, num_itrs=1, eps=1e-04), activation=nn.ReLU(inplace=True), upsample=False, ) self.high_feature_conv = nn.Sequential( nn.Conv2d(9600+40*77, high_feature_channel, kernel_size=1, bias=False), ) self.high_feature_mix_conv = SegBlock( in_channels=mask_dim+high_feature_channel, out_channels=mask_dim, con_channels=128, which_conv=functools.partial(SNConv2d, kernel_size=3, padding=1, num_svs=1, num_itrs=1, eps=1e-04), which_linear=functools.partial(SNLinear, num_svs=1, num_itrs=1, eps=1e-04), activation=nn.ReLU(inplace=True), upsample=False, ) self.highest_feature_conv = nn.Sequential( nn.Conv2d((640+320*6)*2+40*77, highest_feature_channel, kernel_size=1, bias=False), ) self.highest_feature_mix_conv = SegBlock( in_channels=mask_dim+highest_feature_channel, out_channels=mask_dim, con_channels=128, which_conv=functools.partial(SNConv2d, kernel_size=3, padding=1, num_svs=1, num_itrs=1, eps=1e-04), which_linear=functools.partial(SNLinear, num_svs=1, num_itrs=1, eps=1e-04), activation=nn.ReLU(inplace=True), upsample=False, ) def forward_prediction_heads(self, output, mask_features, attn_mask_target_size): decoder_output = self.decoder_norm(output) decoder_output = decoder_output.transpose(0, 1) outputs_class = self.class_embed(decoder_output) mask_embed = self.mask_embed(decoder_output) outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) # NOTE: prediction is of higher-resolution # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW] attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False) # must use bool type # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged. attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool() attn_mask = attn_mask.detach() return outputs_class, outputs_mask, attn_mask def forward(self,diffusion_features,controller,prompts,tokenizer,text_embeddings): x, mask_features=self._prepare_features(diffusion_features,controller,prompts,tokenizer) b=mask_features.size()[0] src = [] pos = [] size_list = [] # x # [b, 256, 20, 20] # [b, 256, 40, 40] # [b, 256, 80, 80] for i in range(self.num_feature_levels): size_list.append(x[i].shape[-2:]) pos.append(self.pe_layer(x[i], None).flatten(2)) src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None]) # flatten NxCxHxW to HWxNxC pos[-1] = pos[-1].permute(2, 0, 1) src[-1] = src[-1].permute(2, 0, 1) # QxNxC # query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, b, 1) # output = self.query_feat.weight.unsqueeze(1).repeat(1, b, 1) query_embed = self.query_embed.weight output = self.query_feat.weight # B, L, D text_embeddings = text_embeddings.repeat(self.num_queries, b, 1) text_embeddings=rearrange(text_embeddings, 'b n d -> (b n) d ') query_embed = torch.cat([query_embed, text_embeddings], dim=1) output = torch.cat([output, text_embeddings], dim=1) query_embed = self.query_feat_mlp(query_embed).unsqueeze(1).repeat(1, b, 1) output = self.query_embed_mlp(output).unsqueeze(1).repeat(1, b, 1) predictions_class = [] predictions_mask = [] # prediction heads on learnable query features outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0]) predictions_class.append(outputs_class) predictions_mask.append(outputs_mask) for i in range(self.num_layers): level_index = i % self.num_feature_levels attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False # attention: cross-attention first output = self.transformer_cross_attention_layers[i]( output, src[level_index], memory_mask=attn_mask, memory_key_padding_mask=None, # here we do not apply masking on padded region pos=pos[level_index], query_pos=query_embed ) output = self.transformer_self_attention_layers[i]( output, tgt_mask=None, tgt_key_padding_mask=None, query_pos=query_embed ) # FFN output = self.transformer_ffn_layers[i]( output ) outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels]) predictions_class.append(outputs_class) predictions_mask.append(outputs_mask) assert len(predictions_class) == self.num_layers + 1 out = { 'pred_logits': predictions_class[-1], 'pred_masks': predictions_mask[-1] } return out def _prepare_features(self, features, attention_store,prompts,tokenizer, upsample='bilinear'): self.low_feature_size = 16 self.mid_feature_size = 32 self.high_feature_size = 64 self.final_high_feature_size = 160 low_features = [ F.interpolate(i, size=self.low_feature_size, mode=upsample, align_corners=False) for i in features["low"] ] low_features = torch.cat(low_features, dim=1) mid_features = [ F.interpolate(i, size=self.mid_feature_size, mode=upsample, align_corners=False) for i in features["mid"] ] mid_features = torch.cat(mid_features, dim=1) high_features = [ F.interpolate(i, size=self.high_feature_size, mode=upsample, align_corners=False) for i in features["high"] ] high_features = torch.cat(high_features, dim=1) highest_features=torch.cat(features["highest"],dim=1) ## Attention map from_where=("up", "down") select = 0 # tokens = tokenizer.encode(prompts[select]) # decoder = tokenizer.decode # category_list_check = VOC_category_list_check[classs[0]] # "up", "down" attention_maps_8s = aggregate_attention(attention_store, 8, ("up", "mid", "down"), True, select,prompts=prompts) attention_maps_16s = aggregate_attention(attention_store, 16, from_where, True, select,prompts=prompts) attention_maps_32 = aggregate_attention(attention_store, 32, from_where, True, select,prompts=prompts) attention_maps_64 = aggregate_attention(attention_store, 64, from_where, True, select,prompts=prompts) attention_maps_8s = rearrange(attention_maps_8s, 'b c h w d-> b (c d) h w') attention_maps_16s = rearrange(attention_maps_16s, 'b c h w d-> b (c d) h w') attention_maps_32 = rearrange(attention_maps_32, 'b c h w d-> b (c d) h w') attention_maps_64 = rearrange(attention_maps_64, 'b c h w d-> b (c d) h w') attention_maps_8s = F.interpolate(attention_maps_8s, size=self.low_feature_size, mode=upsample, align_corners=False) attention_maps_16s = F.interpolate(attention_maps_16s, size=self.mid_feature_size, mode=upsample, align_corners=False) attention_maps_32 = F.interpolate(attention_maps_32, size=self.high_feature_size, mode=upsample, align_corners=False) features_dict = { 'low': torch.cat([low_features, attention_maps_8s], dim=1) , 'mid': torch.cat([mid_features, attention_maps_16s], dim=1) , 'high': torch.cat([high_features, attention_maps_32], dim=1) , 'highest':torch.cat([highest_features, attention_maps_64], dim=1) , } low_feat = self.low_feature_conv(features_dict['low']) low_feat = F.interpolate(low_feat, size=self.mid_feature_size, mode='bilinear', align_corners=False) mid_feat = self.mid_feature_conv(features_dict['mid']) mid_feat = torch.cat([low_feat, mid_feat], dim=1) mid_feat = self.mid_feature_mix_conv(mid_feat, y=None) mid_feat = F.interpolate(mid_feat, size=self.high_feature_size, mode='bilinear', align_corners=False) high_feat = self.high_feature_conv(features_dict['high']) high_feat = torch.cat([mid_feat, high_feat], dim=1) high_feat = self.high_feature_mix_conv(high_feat, y=None) highest_feat=self.highest_feature_conv(features_dict['highest']) highest_feat=torch.cat([high_feat,highest_feat],dim=1) highest_feat=self.highest_feature_mix_conv(highest_feat,y=None) highest_feat = F.interpolate(highest_feat, size=self.final_high_feature_size, mode='bilinear', align_corners=False) low_feat = F.interpolate(low_feat, size=20, mode='bilinear', align_corners=False) mid_feat = F.interpolate(mid_feat, size=40, mode='bilinear', align_corners=False) high_feat = F.interpolate(high_feat, size=80, mode='bilinear', align_corners=False) return [low_feat,mid_feat,high_feat], highest_feat