File size: 3,767 Bytes
e6ac593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# adapted from: https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/encoder.py and https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/decoder.py

import torch.nn as nn
import torch.nn.functional as F

from .backbone_base import BackboneBase
from .vgg_utils import VGG19, ConvRefiner, Decoder


class VGG(BackboneBase):
    """VGG19 encoder followed by a multi-scale ConvRefiner decoder.

    The ``mode`` selects how many channels the decoder predicts and which
    keys the forward pass returns:

    * ``"dect"``      -- 1 output channel, returned as ``"heatmap"``.
    * ``"desc"``      -- 256 output channels, returned as ``"fine_descs"``.
    * ``"dect+desc"`` -- 257 output channels; channel 0 is returned as
      ``"heatmap"`` and the remaining 256 as ``"fine_descs"``.

    In every mode the raw encoder features are also returned under
    ``"coarse_descs"``.
    """

    # Single source of truth: mode -> (num_output_channels, hidden_blocks).
    _MODE_PARAMS = {
        "dect": (1, 8),
        "desc": (256, 5),
        "dect+desc": (256 + 1, 8),
    }
    _MODE_ERROR = "mode should be 'dect', 'desc' or 'dect+desc'"

    def __init__(self, nchannels=3, pretrained=True, use_instance_norm=True, mode="dect"):
        """Build the encoder and the decoder for the requested mode.

        Args:
            nchannels: Number of input channels; forwarded to the base class
                and to ``VGG19`` as ``num_input_channels``.
            pretrained: Forwarded to ``VGG19``.
            use_instance_norm: Forwarded to the base class.
            mode: One of ``"dect"``, ``"desc"`` or ``"dect+desc"``.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        super().__init__(nchannels=nchannels, use_instance_norm=use_instance_norm)

        self.nchannels = nchannels
        self.mode = mode

        # `_get_mode_params` doubles as the validator: it raises ValueError
        # for any unsupported mode before any sub-module is constructed.
        num_output_channels, hidden_blocks = self._get_mode_params(mode)
        conv_refiner = self._create_conv_refiner(num_output_channels, hidden_blocks)

        self.encoder = VGG19(pretrained=pretrained, num_input_channels=nchannels)
        self.decoder = Decoder(conv_refiner, num_prototypes=num_output_channels)

    def _get_mode_params(self, mode):
        """Get the number of output channels and the number of hidden blocks for the ConvRefiner.

        Depending on the mode, the ConvRefiner will have a different number of output channels.

        Args:
            mode: One of ``"dect"``, ``"desc"`` or ``"dect+desc"``.

        Returns:
            A ``(num_output_channels, hidden_blocks)`` tuple.

        Raises:
            ValueError: If ``mode`` is not supported. Previously an unknown
                mode fell through and returned ``None``, which surfaced as an
                unrelated unpacking ``TypeError`` at the call site.
        """
        try:
            return self._MODE_PARAMS[mode]
        except (KeyError, TypeError):
            # TypeError covers unhashable values (e.g. a list passed as mode).
            raise ValueError(self._MODE_ERROR) from None

    def _create_conv_refiner(self, num_output_channels, hidden_blocks):
        """Create the per-scale ConvRefiner modules consumed by the decoder.

        The dict keys ("8", "4", "2", "1") presumably name the downsampling
        factor each refiner operates at, and the three positional arguments
        look like (input, hidden, output) channel counts -- TODO confirm
        against ``ConvRefiner`` / ``Decoder`` in ``vgg_utils``.

        Args:
            num_output_channels: Channels predicted for the task; added to
                the third positional argument of every refiner.
            hidden_blocks: Forwarded to every ``ConvRefiner``.

        Returns:
            An ``nn.ModuleDict`` mapping scale key to ``ConvRefiner``.
        """
        return nn.ModuleDict(
            {
                "8": ConvRefiner(
                    512,
                    512,
                    256 + num_output_channels,
                    hidden_blocks=hidden_blocks,
                    residual=True,
                ),
                "4": ConvRefiner(
                    256 + 256,
                    256,
                    128 + num_output_channels,
                    hidden_blocks=hidden_blocks,
                    residual=True,
                ),
                "2": ConvRefiner(
                    128 + 128,
                    128,
                    64 + num_output_channels,
                    hidden_blocks=hidden_blocks,
                    residual=True,
                ),
                "1": ConvRefiner(
                    64 + 64,
                    64,
                    1 + num_output_channels,
                    hidden_blocks=hidden_blocks,
                    residual=True,
                ),
            }
        )

    def get_dim_layers_encoder(self):
        """Return whatever ``self.encoder.get_dim_layers()`` reports."""
        return self.encoder.get_dim_layers()

    def _forward(self, x):
        """Run the encoder, then accumulate decoder predictions across scales.

        Encoder features are visited in reverse order, paired with
        ``self.decoder.scales``. At each step the decoder's prediction is
        added to a running ``output``; between steps both ``output`` and the
        decoder ``context`` are bilinearly resized to the size recorded for
        the next feature map.

        Args:
            x: Input passed straight to ``self.encoder``.

        Returns:
            A dict whose keys depend on ``self.mode``:
            ``"heatmap"`` and/or ``"fine_descs"``, plus ``"coarse_descs"``
            holding the encoder features as returned by the encoder.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        features, sizes = self.encoder(x)
        scales = self.decoder.scales
        last_idx = len(scales) - 1

        # `output` starts as the scalar 0 so the first addition simply adopts
        # the shape of the first decoder prediction.
        output = 0
        context = None
        for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)):
            delta_descriptor, context = self.decoder(feature_map, scale=scale, context=context)
            output = output + delta_descriptor
            if idx < last_idx:
                # features are walked back-to-front, so the next step uses the
                # size stored one position earlier: sizes[-(idx + 2)].
                size = sizes[-(idx + 2)]
                output = F.interpolate(output, size=size, mode="bilinear", align_corners=False)
                context = F.interpolate(context, size=size, mode="bilinear", align_corners=False)

        if self.mode == "dect":
            return {"heatmap": output, "coarse_descs": features}
        elif self.mode == "desc":
            return {"fine_descs": output, "coarse_descs": features}
        elif self.mode == "dect+desc":
            # Channel 0 carries the detection logits, the rest the descriptors.
            logits = output[:, :1].contiguous()
            descs = output[:, 1:].contiguous()

            return {"heatmap": logits, "fine_descs": descs, "coarse_descs": features}
        else:
            # Only reachable if `self.mode` was reassigned after construction.
            raise ValueError(self._MODE_ERROR)