File size: 14,495 Bytes
62a2f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
import torch
from torch import nn
import torch.nn.functional as F
import os

# Names exported by ``from <module> import *``.
__all__ = ['HRNet', 'hrnetv2_48', 'hrnetv2_32']

# Checkpoint paths of the pre-trained backbones (edit these to match your setup).
# The pre-trained HRNetV2-32 backbone can be downloaded from
# https://drive.google.com/file/d/1NxCK7Zgn5PmeS7W1jYLt5J9E0RRZ2oyF/view?usp=sharing
# The author keeps the backbone weights in the ./checkpoints folder.
# NOTE: despite its name, this dict maps an architecture name to a local file path (not a URL);
# a value of None means that no checkpoint is available for that architecture.

model_urls = {
    'hrnetv2_32': './checkpoints/model_best_epoch96_edit.pth',
    'hrnetv2_48': None
}


def check_pth(arch):
    """Report whether a backbone checkpoint exists for ``arch`` and return its path.

    Args:
        arch: Key into ``model_urls`` (e.g. ``'hrnetv2_32'``).

    Returns:
        The path registered for ``arch`` in ``model_urls``. This is ``None`` when
        no checkpoint is registered (currently the case for HRNetv2-48, which
        has to be trained from scratch). The path is returned even when the
        file does not exist on disk; a message is printed in that case.

    Raises:
        KeyError: If ``arch`` is not a key of ``model_urls``.
    """
    CKPT_PATH = model_urls[arch]
    # A ``None`` entry means "no checkpoint registered". It must be tested first because
    # os.path.exists(None) raises TypeError instead of returning False.
    if CKPT_PATH is not None and os.path.exists(CKPT_PATH):
        print(f"Backbone HRNet Pretrained weights at: {CKPT_PATH}, only usable for HRNetv2-32")
    else:
        print("No backbone checkpoint found for HRNetv2, please set pretrained=False when calling model")
    return CKPT_PATH


class Bottleneck(nn.Module):
    """Residual bottleneck unit: 1x1 conv -> 3x3 conv -> 1x1 conv, added to a shortcut.

    The unit outputs ``planes * expansion`` channels. Only the 3x3 convolution
    uses ``stride``. When ``downsample`` is given it is applied to the input to
    produce the shortcut; otherwise the input itself is the shortcut, so it must
    already have the output shape.
    """

    expansion = 4  # ratio of output channels to ``planes``

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        expanded = planes * self.expansion

        # 1x1 convolution: inplanes -> planes
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 convolution, the only strided layer of the unit
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 convolution: planes -> planes * expansion
        self.conv3 = nn.Conv2d(planes, expanded, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        # Main path; no activation after the last batch norm until the shortcut is added.
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # Shortcut path: projected input when a downsample module exists, raw input otherwise.
        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)


class BasicBlock(nn.Module):
    """Residual unit made of two 3x3 convolutions, added to a shortcut.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels (``expansion`` is 1).
        stride: Stride of the first convolution; the second one always uses 1.
        downsample: Optional module applied to the input to build the shortcut.
            Without it the input itself is the shortcut, so it must already
            have the output shape (``inplanes == planes`` and ``stride == 1``).
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # conv2 consumes the output of conv1/bn1, which has ``planes`` channels, so its
        # input width must be ``planes`` (not ``inplanes``). The two only coincide when
        # inplanes == planes, in which case the parameter shapes are unchanged.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when a downsample module was supplied.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class StageModule(nn.Module):
    """Parallel multi-resolution branches followed by an exchange (fusion) between them.

    Branch ``i`` works on ``c * 2**i`` channels and consists of four ``BasicBlock``
    units. After the branches, every output branch sums a contribution from every
    input branch: identity for the same index, 1x1 conv + batch norm + nearest
    upsampling from higher-index branches, and a chain of strided 3x3 convs from
    lower-index branches. A ReLU is applied to each sum.

    Args:
        stage: Number of parallel input branches.
        output_branches: Number of fused outputs produced.
        c: Channel width of branch 0.
    """

    def __init__(self, stage, output_branches, c):
        super().__init__()

        self.number_of_branches = stage
        self.output_branches = output_branches

        def width(level):
            # Channel count of the branch with index ``level``: doubles with each level.
            return c * (2 ** level)

        # One sequence of four basic blocks per branch; channels are constant within a branch.
        self.branches = nn.ModuleList()
        for level in range(self.number_of_branches):
            ch = width(level)
            self.branches.append(nn.Sequential(*[BasicBlock(ch, ch) for _ in range(4)]))

        # fuse_layers[dst][src] maps the output of branch ``src`` onto the shape of output ``dst``.
        self.fuse_layers = nn.ModuleList()
        for dst in range(self.output_branches):
            row = nn.ModuleList()
            self.fuse_layers.append(row)

            for src in range(self.number_of_branches):
                if src == dst:
                    # Empty Sequential acts as a callable identity.
                    row.append(nn.Sequential())
                elif src > dst:
                    # Source has more channels and is upsampled by 2**(src - dst).
                    row.append(nn.Sequential(
                        nn.Conv2d(width(src), width(dst), kernel_size=1, stride=1, bias=False),
                        nn.BatchNorm2d(width(dst), eps=1e-05, momentum=0.1, affine=True,
                                       track_running_stats=True),
                        nn.Upsample(scale_factor=(2.0 ** (src - dst)), mode='nearest'),
                    ))
                else:
                    # Source has fewer channels: (dst - src) stride-2 convs in total. All but the
                    # last keep the source width and end with ReLU; the last one changes the
                    # width to the destination's and has no activation.
                    steps = [
                        nn.Sequential(
                            nn.Conv2d(width(src), width(src), kernel_size=3, stride=2, padding=1, bias=False),
                            nn.BatchNorm2d(width(src), eps=1e-05, momentum=0.1, affine=True,
                                           track_running_stats=True),
                            nn.ReLU(inplace=True),
                        )
                        for _ in range(dst - src - 1)
                    ]
                    steps.append(nn.Sequential(
                        nn.Conv2d(width(src), width(dst), kernel_size=3, stride=2, padding=1, bias=False),
                        nn.BatchNorm2d(width(dst), eps=1e-05, momentum=0.1, affine=True,
                                       track_running_stats=True),
                    ))
                    row.append(nn.Sequential(*steps))

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # ``x`` is a list with one input per branch; run every branch on its own input.
        branch_outs = [branch(inp) for branch, inp in zip(self.branches, x)]

        # Sum the contributions of all input branches for every output branch.
        fused = []
        for dst in range(self.output_branches):
            row = self.fuse_layers[dst]
            total = row[0](branch_outs[0])
            for src in range(1, self.number_of_branches):
                total = total + row[src](branch_outs[src])
            fused.append(total)

        # Activation is applied only after every sum has been formed.
        return [self.relu(t) for t in fused]


class HRNet(nn.Module):
    """HRNetV2 backbone with a classification head.

    Layout: a stem of two stride-2 3x3 convolutions, a stage of four
    ``Bottleneck`` units, then three multi-branch stages (2, 3 and 4 branches)
    built from ``StageModule``. The four final streams are upsampled to the
    resolution of the first stream, concatenated along the channel axis and fed
    to the classifier.

    Args:
        c: Channel width of the highest-resolution branch; branch ``i`` uses ``c * 2**i``.
        num_blocks: Sequence of three ints giving how many ``StageModule`` units
            are stacked in stage 2, stage 3 and stage 4.
        num_classes: Size of the classifier output.

    Raises:
        ValueError: If ``num_blocks`` does not contain exactly three entries.
    """

    def __init__(self, c=48, num_blocks=(1, 4, 3), num_classes=1000):
        super(HRNet, self).__init__()

        # The network always has stages 2-4 and forward() always concatenates four streams,
        # so any other length would fail later with an IndexError or a channel mismatch.
        if len(num_blocks) != 3:
            raise ValueError(
                f"num_blocks must have exactly 3 entries (stage 2, 3, 4), got {len(num_blocks)}")

        # Stem:
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, eps=1e-05, affine=True, track_running_stats=True)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, eps=1e-05, affine=True, track_running_stats=True)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1:
        # Projection for the shortcut of the first bottleneck (64 -> 256 channels).
        downsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(256, eps=1e-05, affine=True, track_running_stats=True),
        )
        # A bottleneck outputs planes * Bottleneck.expansion channels, so every unit after
        # the first receives expansion * 64 input channels.
        bn_expansion = Bottleneck.expansion
        self.layer1 = nn.Sequential(
            Bottleneck(64, 64, downsample=downsample),  # Input is 64 for first module connection
            Bottleneck(bn_expansion * 64, 64),
            Bottleneck(bn_expansion * 64, 64),
            Bottleneck(bn_expansion * 64, 64),
        )

        # Transition 1 - creates the first two branches (one full and one half resolution).
        self.transition1 = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(256, c, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(c, eps=1e-05, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            ),
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(256, c * 2, kernel_size=3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c * 2, eps=1e-05, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            )),
        ])

        # Stage 2:
        self.stage2 = nn.Sequential(
            *[StageModule(stage=2, output_branches=2, c=c) for _ in range(num_blocks[0])])

        # Transition 2 - creates the third branch from the last branch of stage 2.
        self.transition2 = self._make_transition_layers(c, transition_number=2)

        # Stage 3:
        self.stage3 = nn.Sequential(
            *[StageModule(stage=3, output_branches=3, c=c) for _ in range(num_blocks[1])])

        # Transition 3 - creates the fourth branch from the last branch of stage 3.
        self.transition3 = self._make_transition_layers(c, transition_number=3)

        # Stage 4:
        self.stage4 = nn.Sequential(
            *[StageModule(stage=4, output_branches=4, c=c) for _ in range(num_blocks[2])])

        # Classifier (extra module if want to use for classification):
        # reduce channels, pool to a fixed grid, flatten, connect to a linear layer.
        num_streams = 4  # stage 4 always produces four streams
        out_channels = sum(c * 2 ** i for i in range(num_streams))  # channels after concatenation
        pool_feature_map = 8
        self.bn_classifier = nn.Sequential(
            nn.Conv2d(out_channels, out_channels // 4, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels // 4, eps=1e-05, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(pool_feature_map),
            nn.Flatten(),
            nn.Linear(pool_feature_map * pool_feature_map * (out_channels // 4), num_classes),
        )

    @staticmethod
    def _make_transition_layers(c, transition_number):
        """Build the stride-2 conv + BN + ReLU that spawns a new, lower-resolution branch.

        Maps ``c * 2**(transition_number - 1)`` channels to ``c * 2**transition_number``.
        """
        return nn.Sequential(
            nn.Conv2d(c * (2 ** (transition_number - 1)), c * (2 ** transition_number), kernel_size=3, stride=2,
                      padding=1, bias=False),
            nn.BatchNorm2d(c * (2 ** transition_number), eps=1e-05, affine=True,
                           track_running_stats=True),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the network on an image batch and return the classifier output.

        ``x`` must have 3 channels (the first convolution expects them). The
        result has shape ``(batch, num_classes)``.
        """
        # Stem:
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))

        # Stage 1
        x = self.layer1(x)
        x = [trans(x) for trans in self.transition1]  # split to 2 branches, form a list.

        # Stage 2, then spawn the third branch from the lowest-resolution stream
        x = self.stage2(x)
        x.append(self.transition2(x[-1]))

        # Stage 3, then spawn the fourth branch
        x = self.stage3(x)
        x.append(self.transition3(x[-1]))

        # Stage 4
        x = self.stage4(x)

        # HRNetV2 head: bilinearly upsample every other stream to the size of the first
        # stream and concatenate all of them (rather than summing them).
        target_size = (x[0].size(2), x[0].size(3))
        streams = [x[0]]
        for stream in x[1:]:
            streams.append(F.interpolate(stream, size=target_size, mode='bilinear', align_corners=False))

        x = torch.cat(streams, dim=1)
        x = self.bn_classifier(x)
        return x


def _hrnet(arch, channels, num_blocks, pretrained, progress, **kwargs):
    """Build an ``HRNet`` and optionally load the checkpoint registered for ``arch``.

    Args:
        arch: Key into ``model_urls`` selecting the checkpoint.
        channels: Channel width passed to ``HRNet`` as ``c``.
        num_blocks: Passed to ``HRNet`` as ``num_blocks``.
        pretrained: When true, load ``checkpoint['state_dict']`` into the model.
        progress: Accepted for signature compatibility; not used.
        **kwargs: Forwarded to the ``HRNet`` constructor.

    Returns:
        The constructed model (on CPU).

    Raises:
        ValueError: If ``pretrained`` is true but no checkpoint is registered for ``arch``.
    """
    model = HRNet(channels, num_blocks, **kwargs)
    if pretrained:
        if model_urls[arch] is None:
            # Fail with a clear message instead of an opaque error from the path helpers.
            raise ValueError(
                f"No pretrained checkpoint is registered for '{arch}'; call with pretrained=False")
        CKPT_PATH = check_pth(arch)
        # The model has just been built on CPU, so map the checkpoint there as well. Without
        # map_location, a checkpoint saved from GPU tensors cannot be loaded on a CPU-only host.
        # NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(CKPT_PATH, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
    return model


def hrnetv2_48(pretrained=False, progress=True, number_blocks=[1, 4, 3], **kwargs):
    """Build the HRNetV2 variant whose base channel width is 48.

    All arguments are forwarded to ``_hrnet`` under the architecture name
    ``'hrnetv2_48'``; extra keyword arguments are passed through unchanged.
    """
    return _hrnet(
        arch='hrnetv2_48',
        channels=48,
        num_blocks=number_blocks,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )


def hrnetv2_32(pretrained=False, progress=True, number_blocks=[1, 4, 3], **kwargs):
    """Build the HRNetV2 variant whose base channel width is 32.

    All arguments are forwarded to ``_hrnet`` under the architecture name
    ``'hrnetv2_32'``; extra keyword arguments are passed through unchanged.
    """
    return _hrnet(
        arch='hrnetv2_32',
        channels=32,
        num_blocks=number_blocks,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )


if __name__ == '__main__':
    # Smoke test: build HRNetV2-32, run one dummy batch through it and print the output shape.
    print("--- Running file as MAIN ---")

    # Check the path that hrnetv2_32(pretrained=True) would actually load, so the message
    # below matches what happens. If the file is missing, fall back to random weights
    # instead of crashing inside torch.load.
    ckpt_path = model_urls['hrnetv2_32']
    use_pretrained = ckpt_path is not None and os.path.exists(ckpt_path)
    if use_pretrained:
        print(f"Backbone HRNET Pretrained weights as __main__ at: {ckpt_path}")
    else:
        print("No backbone checkpoint found for HRNetv2, running with pretrained=False")

    # Models
    model = hrnetv2_32(pretrained=use_pretrained)
    #model = hrnetv2_48(pretrained=False)

    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    model.to(device)
    model.eval()  # inference only: use running batch-norm statistics

    in_ = torch.ones(1, 3, 768, 768).to(device)
    with torch.no_grad():  # no gradients needed for a shape check
        y = model(in_)
    print(y.shape)

    # Calculate total number of parameters:
    # pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # print(pytorch_total_params)