InPeerReview's picture
Upload 161 files
226675b verified
import torch.nn as nn
"""
This code refers to "Pyramid attention network for semantic segmentation", that is
"https://github.com/JaveyWang/Pyramid-Attention-Networks-pytorch/blob/f719365c1780f062058dd0c94550c6c4766cd937/networks.py#L41"
"""
class FPM(nn.Module):
    """Feature Pyramid Attention block.

    A 1x1 "master" branch is multiplied element-wise by an attention map
    produced by a three-level pyramid: stride-2 convolutions with 7x7, 5x5
    and 3x3 kernels downsample the input step by step, each level is refined
    by a stride-1 convolution of the same kernel size, and the levels are
    merged coarse-to-fine through transposed convolutions.
    """

    def __init__(self, channels=1024):
        """
        Feature Pyramid Attention
        :param channels: number of channels of both the input and the output
            feature map; the pyramid works internally on ``channels // 4``.
        :type channels: int
        """
        super(FPM, self).__init__()
        # Integer division keeps the width an exact int (no float round-trip).
        channels_mid = channels // 4

        self.channels_cond = channels

        # Master branch: 1x1 projection of the input.
        self.conv_master = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_master = nn.BatchNorm2d(channels)

        # Feature Pyramid: each *_1 conv halves the resolution (stride 2),
        # each *_2 conv refines at the same resolution (stride 1).
        self.conv7x7_1 = nn.Conv2d(self.channels_cond, channels_mid, kernel_size=(7, 7), stride=2, padding=3, bias=False)
        self.bn1_1 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=2, padding=2, bias=False)
        self.bn2_1 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=2, padding=1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(channels_mid)

        self.conv7x7_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(7, 7), stride=1, padding=3, bias=False)
        self.bn1_2 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=1, padding=2, bias=False)
        self.bn2_2 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=1, padding=1, bias=False)
        self.bn3_2 = nn.BatchNorm2d(channels_mid)

        # Upsample: each transposed conv doubles the spatial resolution.
        self.conv_upsample_3 = nn.ConvTranspose2d(channels_mid, channels_mid, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_3 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_2 = nn.ConvTranspose2d(channels_mid, channels_mid, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_2 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_1 = nn.ConvTranspose2d(channels_mid, channels, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_1 = nn.BatchNorm2d(channels)

        self.relu = nn.ReLU(inplace=False)

    @staticmethod
    def _crop_like(x, ref):
        """Crop the last two dims of ``x`` to the last two dims of ``ref``.

        The stride-2 convolutions round odd sizes up, so doubling on the way
        back can overshoot the target by one row/column. When the sizes
        already agree the slice keeps every element.
        """
        height, width = ref.shape[-2:]
        return x[..., :height, :width]

    def forward(self, x):
        """
        :param x: Shape: [b, channels, h, w]
        :return: out: Feature maps. Shape: [b, channels, h, w]
        """
        x_master = self.conv_master(x)
        x_master = self.bn_master(x_master)

        # Branch 1 (1/2 resolution)
        x1_1 = self.conv7x7_1(x)
        x1_1 = self.bn1_1(x1_1)
        x1_1 = self.relu(x1_1)
        x1_2 = self.conv7x7_2(x1_1)
        x1_2 = self.bn1_2(x1_2)

        # Branch 2 (1/4 resolution)
        x2_1 = self.conv5x5_1(x1_1)
        x2_1 = self.bn2_1(x2_1)
        x2_1 = self.relu(x2_1)
        x2_2 = self.conv5x5_2(x2_1)
        x2_2 = self.bn2_2(x2_2)

        # Branch 3 (1/8 resolution)
        x3_1 = self.conv3x3_1(x2_1)
        x3_1 = self.bn3_1(x3_1)
        x3_1 = self.relu(x3_1)
        x3_2 = self.conv3x3_2(x3_1)
        x3_2 = self.bn3_2(x3_2)

        # Merge branch 3 into branch 2. The crop aligns the upsampled map
        # with its merge target when a spatial size was odd.
        x3_upsample = self._crop_like(self.conv_upsample_3(x3_2), x2_2)
        x3_upsample = self.relu(self.bn_upsample_3(x3_upsample))
        x2_merge = self.relu(x2_2 + x3_upsample)

        # Merge the result into branch 1.
        x2_upsample = self._crop_like(self.conv_upsample_2(x2_merge), x1_2)
        x2_upsample = self.relu(self.bn_upsample_2(x2_upsample))
        x1_merge = self.relu(x1_2 + x2_upsample)

        # Bring the pyramid output back to input resolution and use it as a
        # multiplicative attention map on the master branch.
        attention = self._crop_like(self.conv_upsample_1(x1_merge), x_master)
        x_master = x_master * self.relu(self.bn_upsample_1(attention))

        out = self.relu(x_master)

        return out