File size: 5,563 Bytes
2e237ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch
from torch import nn


class FeatureMerge(nn.Module):
    """Multimodal (visual + semantic) feature fusion used in VSR.

    Supported fusion strategies:
        - 'Sum': element-wise addition of visual and semantic maps.
        - 'Concat': channel concatenation followed by a linear projection
          back to the visual dimension.
        - 'Weighted': a learned sigmoid gate blends the two modalities
          per spatial position.
    """

    def __init__(
        self,
        feature_names,
        visual_dim,
        semantic_dim,
        merge_type="Sum",
        dropout_ratio=0.1,
        with_extra_fc=True,
        shortcut=False,
    ):
        """Multimodal feature merge used in VSR.

        Args:
            feature_names (list[str]): keys of the feature maps to fuse,
                one per pyramid level.
            visual_dim (list): per-level channel dims of visual features, e.g. [256]
            semantic_dim (list): per-level channel dims of semantic features, e.g. [256]
            merge_type (str): fusion type, one of 'Sum', 'Concat', 'Weighted'
            dropout_ratio (float): dropout ratio of fusion features
            with_extra_fc (bool): whether to add extra fc layers for adaptation
                before fusing ('Concat' and 'Weighted' only)
            shortcut (bool): whether to use a shortcut connection instead of
                soft selection ('Weighted' only)

        Raises:
            ValueError: if ``merge_type`` is not one of the supported types.
        """
        super().__init__()

        # merge params
        self.feature_names = feature_names
        self.merge_type = merge_type
        self.visual_dim = visual_dim
        self.textual_dim = semantic_dim
        self.with_extra_fc = with_extra_fc
        self.shortcut = shortcut
        self.relu = nn.ReLU(inplace=True)

        if self.merge_type == "Sum":
            assert len(self.visual_dim) == len(self.textual_dim)
        elif self.merge_type in ("Concat", "Weighted"):
            # Both types need the same per-level projection layers; the
            # difference is only in how forward() combines them.
            assert len(self.visual_dim) == len(self.textual_dim)
            self.total_num = len(self.visual_dim)
            self._build_projections()
        else:
            # Bug fix: the original raised a plain string, which is itself a
            # TypeError in Python 3 ("exceptions must derive from
            # BaseException"); raise a proper exception instead.
            raise ValueError("Unknown merge type {}".format(self.merge_type))

        self.dropout = nn.Dropout(dropout_ratio)

    def _build_projections(self):
        """Create per-level linear layers shared by 'Concat' and 'Weighted'."""
        self.vis_proj = nn.ModuleList()
        self.text_proj = nn.ModuleList()
        self.alpha_proj = nn.ModuleList()
        for idx in range(self.total_num):
            if self.with_extra_fc:
                # Optional adaptation layers applied (with ReLU) before fusion.
                self.vis_proj.append(nn.Linear(self.visual_dim[idx], self.visual_dim[idx]))
                self.text_proj.append(nn.Linear(self.textual_dim[idx], self.textual_dim[idx]))
            # Projects the concatenated modalities back to the visual dim:
            # the fused output for 'Concat', the gating logits for 'Weighted'.
            self.alpha_proj.append(
                nn.Linear(self.visual_dim[idx] + self.textual_dim[idx], self.visual_dim[idx])
            )

    def forward(self, visual_feat=None, textual_feat=None):
        """Fuse visual and textual feature maps.

        Args:
            visual_feat (dict[str, Tensor]): visual feature maps keyed by
                ``feature_names``, each in shape B x C x H x W.
            textual_feat (dict[str, Tensor]): textual feature maps with the
                same keys and spatial sizes as ``visual_feat``.

        Returns:
            dict[str, Tensor]: fused feature maps, same keys and
            B x C x H x W shapes as ``visual_feat``.
        """
        assert len(visual_feat) == len(textual_feat)

        # feature merge
        merged_feat = {}
        if self.merge_type == "Sum":
            for name in self.feature_names:
                merged_feat[name] = visual_feat[name] + textual_feat[name]
        elif self.merge_type == "Concat":
            for idx, name in enumerate(self.feature_names):
                # Move to channels-last so nn.Linear applies per spatial position.
                per_vis = visual_feat[name].permute(0, 2, 3, 1)
                per_text = textual_feat[name].permute(0, 2, 3, 1)
                if self.with_extra_fc:
                    per_vis = self.relu(self.vis_proj[idx](per_vis))
                    per_text = self.relu(self.text_proj[idx](per_text))
                fused = self.alpha_proj[idx](torch.cat((per_vis, per_text), -1))
                merged_feat[name] = fused.permute(0, 3, 1, 2).contiguous()
        else:  # 'Weighted' (__init__ rejects anything else)
            # NOTE(review): when total_num == 1 but multiple levels are passed,
            # indexing self.alpha_proj[idx] below would fail for idx >= 1 —
            # presumably callers only use total_num == len(feature_names).
            assert self.total_num == len(visual_feat) or self.total_num == 1
            for idx, name in enumerate(self.feature_names):
                per_vis = visual_feat[name].permute(0, 2, 3, 1)
                per_text = textual_feat[name].permute(0, 2, 3, 1)
                if self.with_extra_fc:
                    per_vis = self.relu(self.vis_proj[idx](per_vis))
                    per_text = self.relu(self.text_proj[idx](per_text))

                # Learned gate in [0, 1] deciding per position how much of
                # each modality to keep.
                alpha = torch.sigmoid(self.alpha_proj[idx](torch.cat((per_vis, per_text), -1)))
                if self.shortcut:
                    # shortcut: keep the visual stream, add gated text
                    fused = per_vis + alpha * per_text
                else:
                    # selection: soft blend of the two modalities
                    fused = alpha * per_vis + (1 - alpha) * per_text

                merged_feat[name] = fused.permute(0, 3, 1, 2).contiguous()

        return merged_feat