RemoteSensingChangeDetection-RSCD.CTTF
/
rscd
/models
/decoderheads
/lgpnet
/FeaturePyramidModule.py
| import torch.nn as nn | |
| """ | |
| This code follows "Pyramid Attention Network for Semantic Segmentation"; the reference implementation is at | |
| "https://github.com/JaveyWang/Pyramid-Attention-Networks-pytorch/blob/f719365c1780f062058dd0c94550c6c4766cd937/networks.py#L41" | |
| """ | |
class FPM(nn.Module):
    """Feature pyramid attention module.

    A 1x1 "master" branch keeps the input resolution, while a three-level
    pyramid (7x7, 5x5 and 3x3 convolutions, each level downsampling by 2)
    is merged coarse-to-fine through transposed convolutions. The merged
    pyramid output is multiplied element-wise with the master branch.
    """

    def __init__(self, channels: int = 1024):
        """
        Build the master branch, the pyramid and the upsampling path.

        :param channels: number of input (and output) channels; the pyramid
            works internally on ``channels // 4`` channels.
        :type channels: int
        """
        super().__init__()
        channels_mid = channels // 4

        self.channels_cond = channels

        # Master branch: 1x1 projection at full resolution.
        self.conv_master = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_master = nn.BatchNorm2d(channels)

        # Feature pyramid: the *_1 layers downsample by 2, the *_2 layers
        # refine each level at its own resolution.
        self.conv7x7_1 = nn.Conv2d(self.channels_cond, channels_mid, kernel_size=(7, 7), stride=2, padding=3, bias=False)
        self.bn1_1 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=2, padding=2, bias=False)
        self.bn2_1 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=2, padding=1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(channels_mid)

        self.conv7x7_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(7, 7), stride=1, padding=3, bias=False)
        self.bn1_2 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=1, padding=2, bias=False)
        self.bn2_2 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=1, padding=1, bias=False)
        self.bn3_2 = nn.BatchNorm2d(channels_mid)

        # Upsample: each transposed convolution doubles the spatial size.
        self.conv_upsample_3 = nn.ConvTranspose2d(channels_mid, channels_mid, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_3 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_2 = nn.ConvTranspose2d(channels_mid, channels_mid, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_2 = nn.BatchNorm2d(channels_mid)

        # The last upsampling step also restores the full channel count so
        # the result can be multiplied with the master branch.
        self.conv_upsample_1 = nn.ConvTranspose2d(channels_mid, channels, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_1 = nn.BatchNorm2d(channels)

        self.relu = nn.ReLU(inplace=False)

    @staticmethod
    def _crop_like(x, ref):
        """Crop the two trailing (spatial) dims of ``x`` to those of ``ref``.

        A stride-2 convolution maps an odd size ``n`` to ``(n + 1) // 2``,
        and the matching transposed convolution maps that back to ``n + 1``.
        Cropping removes the extra row/column; it is a no-op when the sizes
        already agree.
        """
        return x[..., :ref.shape[-2], :ref.shape[-1]]

    def forward(self, x):
        """
        :param x: Shape: [b, channels, h, w]
        :return: out: Feature maps. Shape: [b, channels, h, w]
        """
        x_master = self.conv_master(x)
        x_master = self.bn_master(x_master)

        # Branch 1 (1/2 resolution)
        x1_1 = self.conv7x7_1(x)
        x1_1 = self.bn1_1(x1_1)
        x1_1 = self.relu(x1_1)
        x1_2 = self.conv7x7_2(x1_1)
        x1_2 = self.bn1_2(x1_2)

        # Branch 2 (1/4 resolution)
        x2_1 = self.conv5x5_1(x1_1)
        x2_1 = self.bn2_1(x2_1)
        x2_1 = self.relu(x2_1)
        x2_2 = self.conv5x5_2(x2_1)
        x2_2 = self.bn2_2(x2_2)

        # Branch 3 (1/8 resolution)
        x3_1 = self.conv3x3_1(x2_1)
        x3_1 = self.bn3_1(x3_1)
        x3_1 = self.relu(x3_1)
        x3_2 = self.conv3x3_2(x3_1)
        x3_2 = self.bn3_2(x3_2)

        # Merge the pyramid coarse-to-fine. Every upsampled map is cropped to
        # the size of the tensor it is combined with, so inputs whose height
        # or width is not a multiple of 8 no longer raise a size mismatch.
        x3_upsample = self._crop_like(self.conv_upsample_3(x3_2), x2_2)
        x3_upsample = self.relu(self.bn_upsample_3(x3_upsample))
        x2_merge = self.relu(x2_2 + x3_upsample)

        x2_upsample = self._crop_like(self.conv_upsample_2(x2_merge), x1_2)
        x2_upsample = self.relu(self.bn_upsample_2(x2_upsample))
        x1_merge = self.relu(x1_2 + x2_upsample)

        # Weight the master branch with the pyramid output.
        x1_upsample = self._crop_like(self.conv_upsample_1(x1_merge), x_master)
        x_master = x_master * self.relu(self.bn_upsample_1(x1_upsample))

        out = self.relu(x_master)

        return out