# Implementation for paper 'Attention on Attention for Image Captioning'
# https://arxiv.org/abs/1908.06954
# RT: Code from original author's repo: https://github.com/husthuaan/AoANet/

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

from .AttModel import pack_wrapper, AttModel, Attention
from .TransformerModel import LayerNorm, attention, clones, SublayerConnection, PositionwiseFeedForward


class MultiHeadedDotAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1, scale=1, project_k_v=1, use_output_layer=1, do_aoa=0, norm_q=0, dropout_aoa=0.3):
        super(MultiHeadedDotAttention, self).__init__()
        assert d_model * scale % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model * scale // h
        self.h = h

        # Do we need to do linear projections on K and V?
        self.project_k_v = project_k_v

        # normalize the query?
        if norm_q:
            self.norm = LayerNorm(d_model)
        else:
            self.norm = lambda x: x
        self.linears = clones(nn.Linear(d_model, d_model * scale), 1 + 2 * project_k_v)

        # output linear layer after the multi-head attention?
        self.output_layer = nn.Linear(d_model * scale, d_model)

        # apply AoA after attention?
        self.use_aoa = do_aoa
        if self.use_aoa:
            self.aoa_layer = nn.Sequential(nn.Linear((1 + scale) * d_model, 2 * d_model), nn.GLU())
            # dropout to the input of the AoA layer
            if dropout_aoa > 0:
                self.dropout_aoa = nn.Dropout(p=dropout_aoa)
            else:
                self.dropout_aoa = lambda x: x

        if self.use_aoa or not use_output_layer:
            # AoA doesn't need the output linear layer
            del self.output_layer
            self.output_layer = lambda x: x

        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
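    # The AoA gate applied in forward() when do_aoa=1 (a sketch of the idea, stated
    # here as a comment only): with f = multi-head attention output and q = query,
    #   i = W_i [f; q] + b_i            (information vector)
    #   g = sigmoid(W_g [f; q] + b_g)   (attention gate)
    #   AoA(f, q) = g * i
    # nn.GLU splits the 2 * d_model output of the linear layer in half and computes
    # exactly this elementwise sigmoid-gated product.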
    def forward(self, query, value, key, mask=None):
        if mask is not None:
            if len(mask.size()) == 2:
                mask = mask.unsqueeze(-2)
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)

        single_query = 0
        if len(query.size()) == 2:
            single_query = 1
            query = query.unsqueeze(1)

        nbatches = query.size(0)

        query = self.norm(query)

        # Do all the linear projections in batch from d_model => h x d_k
        if self.project_k_v == 0:
            query_ = self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            key_ = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            value_ = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        else:
            query_, key_, value_ = \
                [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
                 for l, x in zip(self.linears, (query, key, value))]

        # Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query_, key_, value_, mask=mask,
                                 dropout=self.dropout)

        # "Concat" using a view
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)

        if self.use_aoa:
            # Apply AoA
            x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query], -1)))
        x = self.output_layer(x)

        if single_query:
            query = query.squeeze(1)
            x = x.squeeze(1)
        return x
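
# A minimal usage sketch (hypothetical shapes and hyper-parameters, not from the
# original repo):
#   attn = MultiHeadedDotAttention(h=8, d_model=512, do_aoa=1)
#   feats = torch.randn(2, 36, 512)      # e.g. 36 region features per image
#   out = attn(feats, feats, feats)      # self-attention followed by the AoA gate
#   out.shape                            # -> torch.Size([2, 36, 512])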

class AoA_Refiner_Layer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(AoA_Refiner_Layer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.use_ff = 0
        if self.feed_forward is not None:
            self.use_ff = 1
        self.sublayer = clones(SublayerConnection(size, dropout), 1 + self.use_ff)
        self.size = size

    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[-1](x, self.feed_forward) if self.use_ff else x
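
# Note: this mirrors a Transformer encoder layer; roughly speaking, when use_ff is 0
# the AoA gate inside self_attn stands in for the position-wise feed-forward sublayer.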

class AoA_Refiner_Core(nn.Module):
    def __init__(self, opt):
        super(AoA_Refiner_Core, self).__init__()
        attn = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=1, scale=opt.multi_head_scale, do_aoa=opt.refine_aoa, norm_q=0, dropout_aoa=getattr(opt, 'dropout_aoa', 0.3))
        layer = AoA_Refiner_Layer(opt.rnn_size, attn, PositionwiseFeedForward(opt.rnn_size, 2048, 0.1) if opt.use_ff else None, 0.1)
        self.layers = clones(layer, 6)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
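
# Construction sketch (hypothetical opt values, not from the original repo): the refiner
# stacks 6 identical layers over the embedded region features.
#   opt.num_heads, opt.rnn_size, opt.multi_head_scale = 8, 1024, 1
#   opt.refine_aoa, opt.use_ff = 1, 0
#   refiner = AoA_Refiner_Core(opt)
#   refined = refiner(torch.randn(2, 36, 1024), mask=None)   # -> (2, 36, 1024)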

class AoA_Decoder_Core(nn.Module):
    def __init__(self, opt):
        super(AoA_Decoder_Core, self).__init__()
        self.drop_prob_lm = opt.drop_prob_lm
        self.d_model = opt.rnn_size
        self.use_multi_head = opt.use_multi_head
        self.multi_head_scale = opt.multi_head_scale
        self.use_ctx_drop = getattr(opt, 'ctx_drop', 0)
        self.out_res = getattr(opt, 'out_res', 0)
        self.decoder_type = getattr(opt, 'decoder_type', 'AoA')
        # inputs: word embedding (we), mean/fc feats, and the previous context h^2_{t-1}
        self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size)
        self.out_drop = nn.Dropout(self.drop_prob_lm)

        if self.decoder_type == 'AoA':
            # AoA layer
            self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, 2 * opt.rnn_size), nn.GLU())
        elif self.decoder_type == 'LSTM':
            # LSTM layer
            self.att2ctx = nn.LSTMCell(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size)
        else:
            # Base linear layer
            self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size), nn.ReLU())

        # if opt.use_multi_head == 1: # TODO, not implemented for now
        #     self.attention = MultiHeadedAddAttention(opt.num_heads, opt.d_model, scale=opt.multi_head_scale)
        if opt.use_multi_head == 2:
            self.attention = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=0, scale=opt.multi_head_scale, use_output_layer=0, do_aoa=0, norm_q=1)
        else:
            self.attention = Attention(opt)

        if self.use_ctx_drop:
            self.ctx_drop = nn.Dropout(self.drop_prob_lm)
        else:
            self.ctx_drop = lambda x: x
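
    # State layout used by forward() below: state is a pair (h, c) of stacked tensors;
    # state[0][0] / state[1][0] belong to att_lstm, while state[0][1] carries the
    # context vector (or the second LSTM's hidden state when decoder_type == 'LSTM').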
    def forward(self, xt, mean_feats, att_feats, p_att_feats, state, att_masks=None):
        # state[0][1] is the context vector at the last step
        h_att, c_att = self.att_lstm(torch.cat([xt, mean_feats + self.ctx_drop(state[0][1])], 1), (state[0][0], state[1][0]))

        if self.use_multi_head == 2:
            att = self.attention(h_att, p_att_feats.narrow(2, 0, self.multi_head_scale * self.d_model), p_att_feats.narrow(2, self.multi_head_scale * self.d_model, self.multi_head_scale * self.d_model), att_masks)
        else:
            att = self.attention(h_att, att_feats, p_att_feats, att_masks)

        ctx_input = torch.cat([att, h_att], 1)
        if self.decoder_type == 'LSTM':
            output, c_logic = self.att2ctx(ctx_input, (state[0][1], state[1][1]))
            state = (torch.stack((h_att, output)), torch.stack((c_att, c_logic)))
        else:
            output = self.att2ctx(ctx_input)
            # save the context vector to state[0][1]
            state = (torch.stack((h_att, output)), torch.stack((c_att, state[1][1])))

        if self.out_res:
            # add residual connection
            output = output + h_att

        output = self.out_drop(output)
        return output, state
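
# Single-step sketch (hypothetical shapes, not from the original repo): with
# use_multi_head == 2, p_att_feats holds the projected keys and values concatenated
# along the last dimension (see AoAModel.ctx2att below), so:
#   xt:          (batch, input_encoding_size)   word embedding at step t
#   mean_feats:  (batch, rnn_size)
#   att_feats:   (batch, num_regions, rnn_size)
#   p_att_feats: (batch, num_regions, 2 * multi_head_scale * rnn_size)
#   output, state = core(xt, mean_feats, att_feats, p_att_feats, state, att_masks)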

class AoAModel(AttModel):
    def __init__(self, opt):
        super(AoAModel, self).__init__(opt)
        self.num_layers = 2
        # mean pooling
        self.use_mean_feats = getattr(opt, 'mean_feats', 1)
        if opt.use_multi_head == 2:
            del self.ctx2att
            self.ctx2att = nn.Linear(opt.rnn_size, 2 * opt.multi_head_scale * opt.rnn_size)

        if self.use_mean_feats:
            del self.fc_embed
        if opt.refine:
            self.refiner = AoA_Refiner_Core(opt)
        else:
            self.refiner = lambda x, y: x
        self.core = AoA_Decoder_Core(opt)

        self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)

    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        att_feats, att_masks = self.clip_att(att_feats, att_masks)

        # embed att feats
        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
        att_feats = self.refiner(att_feats, att_masks)

        if self.use_mean_feats:
            # mean pooling
            if att_masks is None:
                mean_feats = torch.mean(att_feats, dim=1)
            else:
                mean_feats = (torch.sum(att_feats * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1))
        else:
            mean_feats = self.fc_embed(fc_feats)

        # Project the attention feats first to reduce memory and computation.
        p_att_feats = self.ctx2att(att_feats)

        return mean_feats, att_feats, p_att_feats, att_masks
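
# End-to-end sketch (hypothetical opt values, not from the original repo):
#   opt.use_multi_head, opt.num_heads, opt.multi_head_scale = 2, 8, 1
#   opt.refine, opt.refine_aoa, opt.use_ff = 1, 1, 0
#   opt.rnn_size = opt.input_encoding_size = 1024
#   opt.drop_prob_lm, opt.mean_feats = 0.5, 1
#   model = AoAModel(opt)
# AttModel supplies the word embedding, training forward pass, sampling and beam
# search; it calls _prepare_feature() once per image and then the decoder core per step.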