# NOTE: web-page artifacts (status text, file size, commit hash, line-number gutter) removed.
import torch
from torch import nn
class FeatureMerge(nn.Module):
    """Multimodal (visual + textual) feature fusion used in VSR.

    Supports three fusion strategies, selected by ``merge_type``:

    * ``"Sum"``      — element-wise addition of visual and textual maps.
    * ``"Concat"``   — (optional per-modality FC + ReLU, then) channel
      concatenation projected back to the visual dimension.
    * ``"Weighted"`` — a sigmoid gate ``alpha`` computed from the concatenated
      features blends the two modalities, either as a shortcut
      (``vis + alpha * text``) or a soft selection
      (``alpha * vis + (1 - alpha) * text``).
    """

    def __init__(
        self,
        feature_names,
        visual_dim,
        semantic_dim,
        merge_type="Sum",
        dropout_ratio=0.1,
        with_extra_fc=True,
        shortcut=False,
    ):
        """Build the fusion module.

        Args:
            feature_names (list[str]): keys identifying each feature level in
                the input dicts passed to :meth:`forward`.
            visual_dim (list[int]): channel dims of visual features, e.g. [256].
            semantic_dim (list[int]): channel dims of semantic features,
                e.g. [256]; must be the same length as ``visual_dim``.
            merge_type (str): fusion type: 'Sum', 'Concat' or 'Weighted'.
            dropout_ratio (float): dropout ratio of fusion features.
                NOTE(review): the dropout layer is created but not applied in
                :meth:`forward`; kept for interface/state-dict compatibility.
            with_extra_fc (bool): whether to add extra fc layers (with ReLU)
                for per-modality adaptation before fusion.
            shortcut (bool): Weighted mode only — use the shortcut form
                instead of the gated selection form.

        Raises:
            ValueError: if ``merge_type`` is not one of the supported values.
        """
        super().__init__()
        # merge params
        self.feature_names = feature_names
        self.merge_type = merge_type
        self.visual_dim = visual_dim
        self.textual_dim = semantic_dim
        self.with_extra_fc = with_extra_fc
        self.shortcut = shortcut
        self.relu = nn.ReLU(inplace=True)

        if self.merge_type == "Sum":
            # Element-wise addition needs matching per-level dims.
            assert len(self.visual_dim) == len(self.textual_dim)
        elif self.merge_type == "Concat":
            assert len(self.visual_dim) == len(self.textual_dim)
            self.vis_proj = nn.ModuleList()
            self.text_proj = nn.ModuleList()
            # Projects the concatenated (vis + text) channels back to the
            # visual channel count.
            self.alpha_proj = nn.ModuleList()
            for idx in range(len(self.visual_dim)):
                if self.with_extra_fc:
                    self.vis_proj.append(
                        nn.Linear(self.visual_dim[idx], self.visual_dim[idx]))
                    self.text_proj.append(
                        nn.Linear(self.textual_dim[idx], self.textual_dim[idx]))
                self.alpha_proj.append(
                    nn.Linear(self.visual_dim[idx] + self.textual_dim[idx],
                              self.visual_dim[idx]))
        elif self.merge_type == "Weighted":
            assert len(self.visual_dim) == len(self.textual_dim)
            self.total_num = len(self.visual_dim)
            # Per-modality projections and the gate that produces alpha.
            self.vis_proj = nn.ModuleList()
            self.vis_proj_relu = nn.ModuleList()
            self.text_proj = nn.ModuleList()
            self.text_proj_relu = nn.ModuleList()
            self.alpha_proj = nn.ModuleList()
            for idx in range(self.total_num):
                if self.with_extra_fc:
                    self.vis_proj.append(
                        nn.Linear(self.visual_dim[idx], self.visual_dim[idx]))
                    self.text_proj.append(
                        nn.Linear(self.textual_dim[idx], self.textual_dim[idx]))
                self.alpha_proj.append(
                    nn.Linear(self.visual_dim[idx] + self.textual_dim[idx],
                              self.visual_dim[idx]))
        else:
            # BUGFIX: the original raised a plain string, which is a
            # TypeError ("exceptions must derive from BaseException") in
            # Python 3 and hides the intended message.
            raise ValueError("Unknown merge type {}".format(self.merge_type))

        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, visual_feat=None, textual_feat=None):
        """Fuse visual and textual feature maps level by level.

        Args:
            visual_feat (dict[str, Tensor]): visual feature maps keyed by
                ``feature_names``, each in shape B x C x H x W.
            textual_feat (dict[str, Tensor]): textual feature maps with the
                same keys and spatial shape as ``visual_feat``.

        Returns:
            dict[str, Tensor]: fused feature maps, same keys, each
            B x C x H x W with C equal to the visual channel dim.
        """
        assert len(visual_feat) == len(textual_feat)

        merged_feat = {}
        if self.merge_type == "Sum":
            for name in self.feature_names:
                merged_feat[name] = visual_feat[name] + textual_feat[name]
        elif self.merge_type == "Concat":
            for idx, name in enumerate(self.feature_names):
                # Work in channels-last so nn.Linear applies per location.
                per_vis = visual_feat[name].permute(0, 2, 3, 1)
                per_text = textual_feat[name].permute(0, 2, 3, 1)
                if self.with_extra_fc:
                    per_vis = self.relu(self.vis_proj[idx](per_vis))
                    per_text = self.relu(self.text_proj[idx](per_text))
                x_sentence = self.alpha_proj[idx](
                    torch.cat((per_vis, per_text), -1))
                # Back to channels-first layout.
                x_sentence = x_sentence.permute(0, 3, 1, 2).contiguous()
                merged_feat[name] = x_sentence
        else:
            # Weighted fusion (only remaining type after __init__ validation).
            assert self.total_num == len(visual_feat) or self.total_num == 1
            for idx, name in enumerate(self.feature_names):
                per_vis = visual_feat[name].permute(0, 2, 3, 1)
                per_text = textual_feat[name].permute(0, 2, 3, 1)
                if self.with_extra_fc:
                    per_vis = self.relu(self.vis_proj[idx](per_vis))
                    per_text = self.relu(self.text_proj[idx](per_text))
                # Gate in [0, 1] computed from both modalities.
                alpha = torch.sigmoid(
                    self.alpha_proj[idx](torch.cat((per_vis, per_text), -1)))
                if self.shortcut:
                    # Residual-style: keep visual, add gated textual.
                    x_sentence = per_vis + alpha * per_text
                else:
                    # Soft selection between the two modalities.
                    x_sentence = alpha * per_vis + (1 - alpha) * per_text
                x_sentence = x_sentence.permute(0, 3, 1, 2).contiguous()
                merged_feat[name] = x_sentence
        return merged_feat
# NOTE: trailing web-page artifact ("|") removed.