RemoteSensingChangeDetection-RSCD.CTTF
rscd/models/decoderheads/lgpnet/PositionAttentionModule.py
import torch
from torch.nn import Module, Conv2d, Parameter, Softmax


class PAM(Module):
    """Position Attention Module (PAM).

    Follows the position attention module from "Dual Attention Network for
    Scene Segmentation" (DANet); the self-attention formulation is adapted
    from SAGAN.
    """

    def __init__(self, in_dim):
        super(PAM, self).__init__()
        self.channel_in = in_dim
        # Query/key projections reduce the channel count by 8x to keep the
        # attention computation cheap; the value projection keeps all channels.
        self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # Learnable residual weight, initialised to zero so the module starts
        # out as an identity mapping.
        self.gamma = Parameter(torch.zeros(1))
        self.softmax = Softmax(dim=-1)

    def forward(self, x):
        """
        inputs:
            x : input feature maps (B x C x H x W)
        returns:
            out : attention-weighted features added to the input (B x C x H x W);
                  the intermediate attention map has shape B x (HxW) x (HxW)
        """
        m_batchsize, C, height, width = x.size()
        # Query: B x (HxW) x (C/8); Key: B x (C/8) x (HxW)
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
        # Pairwise similarity between all spatial positions: B x (HxW) x (HxW)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        # Value: B x C x (HxW), re-weighted by the attention map
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)
        # Residual connection scaled by the learnable gamma
        out = self.gamma * out + x
        return out
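
A minimal usage sketch, not part of the original file: it assumes only torch and the PAM class above, and checks that the output keeps the B x C x H x W shape of the input.

# Usage sketch (assumption: run as a standalone script with the PAM class above).
if __name__ == "__main__":
    pam = PAM(in_dim=64)
    feats = torch.randn(2, 64, 32, 32)   # B x C x H x W
    out = pam(feats)
    assert out.shape == feats.shape      # attention output preserves the input shape
    print(out.shape)                     # torch.Size([2, 64, 32, 32])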