"""
This code refers to "Pyramid attention network for semantic segmentation", that is
"https://github.com/JaveyWang/Pyramid-Attention-Networks-pytorch/blob/f719365c1780f062058dd0c94550c6c4766cd937/networks.py#L41"
"""
import torch.nn as nn


class FPM(nn.Module):
    def __init__(self, channels=1024):
        """
        Feature Pyramid Attention
        :type channels: int
        """
        super(FPM, self).__init__()
        channels_mid = int(channels / 4)
        self.channels_cond = channels

        # Master branch: 1x1 conv on the input feature map
        self.conv_master = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_master = nn.BatchNorm2d(channels)

        # Feature pyramid: downsampling path (each conv halves the spatial size)
        self.conv7x7_1 = nn.Conv2d(self.channels_cond, channels_mid, kernel_size=(7, 7), stride=2, padding=3, bias=False)
        self.bn1_1 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=2, padding=2, bias=False)
        self.bn2_1 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=2, padding=1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(channels_mid)

        # Per-level refinement convs (stride 1, spatial size preserved)
        self.conv7x7_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(7, 7), stride=1, padding=3, bias=False)
        self.bn1_2 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=1, padding=2, bias=False)
        self.bn2_2 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=1, padding=1, bias=False)
        self.bn3_2 = nn.BatchNorm2d(channels_mid)

        # Upsample path (each transposed conv doubles the spatial size)
        self.conv_upsample_3 = nn.ConvTranspose2d(channels_mid, channels_mid, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_3 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_2 = nn.ConvTranspose2d(channels_mid, channels_mid, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_2 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_1 = nn.ConvTranspose2d(channels_mid, channels, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_1 = nn.BatchNorm2d(channels)

        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """
        :param x: Shape: [b, channels, h, w]
        :return: out: Feature maps. Shape: [b, channels, h, w]
        """
        x_master = self.conv_master(x)
        x_master = self.bn_master(x_master)

        # Branch 1: 7x7 level (1/2 resolution)
        x1_1 = self.conv7x7_1(x)
        x1_1 = self.bn1_1(x1_1)
        x1_1 = self.relu(x1_1)
        x1_2 = self.conv7x7_2(x1_1)
        x1_2 = self.bn1_2(x1_2)

        # Branch 2: 5x5 level (1/4 resolution)
        x2_1 = self.conv5x5_1(x1_1)
        x2_1 = self.bn2_1(x2_1)
        x2_1 = self.relu(x2_1)
        x2_2 = self.conv5x5_2(x2_1)
        x2_2 = self.bn2_2(x2_2)

        # Branch 3: 3x3 level (1/8 resolution)
        x3_1 = self.conv3x3_1(x2_1)
        x3_1 = self.bn3_1(x3_1)
        x3_1 = self.relu(x3_1)
        x3_2 = self.conv3x3_2(x3_1)
        x3_2 = self.bn3_2(x3_2)

        # Merge branch 3 into branch 2, then branch 2 into branch 1
        x3_upsample = self.relu(self.bn_upsample_3(self.conv_upsample_3(x3_2)))
        x2_merge = self.relu(x2_2 + x3_upsample)
        x2_upsample = self.relu(self.bn_upsample_2(self.conv_upsample_2(x2_merge)))
        x1_merge = self.relu(x1_2 + x2_upsample)

        # Weight the master branch by the upsampled pyramid attention features
        x_master = x_master * self.relu(self.bn_upsample_1(self.conv_upsample_1(x1_merge)))

        out = self.relu(x_master)
        return out
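

# A minimal usage sketch, assuming an input tensor whose channel count matches the
# `channels` argument and whose spatial size is divisible by 8 (the pyramid
# downsamples by 2 three times before upsampling back). The batch size and the
# 32x32 spatial size below are arbitrary illustration values, not requirements
# from the original code.
if __name__ == "__main__":
    import torch

    fpm = FPM(channels=1024)
    x = torch.randn(2, 1024, 32, 32)  # [b, channels, h, w], h and w divisible by 8
    out = fpm(x)
    print(out.shape)  # expected: torch.Size([2, 1024, 32, 32])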