# Wasim
# Sync: robust vehicle parser + full project
# 2e237ce
import torch
from torch import nn
class FeatureMerge(nn.Module):
    """Multimodal feature fusion used in VSR.

    Fuses, per named feature level, a visual feature map with a textual
    (semantic) feature map using one of three strategies:

    * ``"Sum"``      -- element-wise addition, no learnable parameters.
    * ``"Concat"``   -- concatenate along the channel axis, then project back
      to the visual channel count with a linear layer.
    * ``"Weighted"`` -- learn a per-element gate ``alpha`` (sigmoid of a linear
      projection of the concatenation) and blend the two modalities with it.
    """

    def __init__(
        self,
        feature_names,
        visual_dim,
        semantic_dim,
        merge_type="Sum",
        dropout_ratio=0.1,
        with_extra_fc=True,
        shortcut=False,
    ):
        """Multimodal feature merge used in VSR.

        Args:
            feature_names (list): keys used to look up each feature level in
                the dicts passed to :meth:`forward`; the i-th name is paired
                with the i-th entry of ``visual_dim`` / ``semantic_dim``.
            visual_dim (list): the dim of visual features, e.g. [256]
            semantic_dim (list): the dim of semantic features, e.g. [256]
            merge_type (str): fusion type, one of 'Sum', 'Concat', 'Weighted'
            dropout_ratio (float): ratio of the ``self.dropout`` layer
            with_extra_fc (bool): whether add extra fc layers (followed by
                ReLU) on each modality before fusion
            shortcut (bool): 'Weighted' only -- if True the output is
                ``vis + alpha * text`` instead of the gated selection
                ``alpha * vis + (1 - alpha) * text``

        Raises:
            ValueError: if ``merge_type`` is not a supported fusion type.
            AssertionError: if ``visual_dim`` and ``semantic_dim`` differ in
                length.
        """
        super().__init__()

        # merge param
        self.feature_names = feature_names
        self.merge_type = merge_type
        self.visual_dim = visual_dim
        self.textual_dim = semantic_dim
        self.with_extra_fc = with_extra_fc
        self.shortcut = shortcut
        self.relu = nn.ReLU(inplace=True)

        # Fail fast with a real exception. The previous code raised a plain
        # string, which Python 3 rejects with an unrelated TypeError and
        # thereby hides the actual message.
        if self.merge_type not in ("Sum", "Concat", "Weighted"):
            raise ValueError("Unknown merge type {}".format(self.merge_type))

        assert len(self.visual_dim) == len(self.textual_dim), (
            "visual_dim and semantic_dim must have the same number of levels"
        )

        if self.merge_type == "Concat":
            self.vis_proj, self.text_proj, self.alpha_proj = self._build_projections()
        elif self.merge_type == "Weighted":
            self.total_num = len(self.visual_dim)
            vis_proj, text_proj, alpha_proj = self._build_projections()
            # Attribute assignment order is kept as before so that submodule
            # registration order (and thus repr / named_modules) is unchanged.
            # The *_relu lists are never populated; they are retained only for
            # backward compatibility with code that may reference them.
            self.vis_proj = vis_proj
            self.vis_proj_relu = nn.ModuleList()
            self.text_proj = text_proj
            self.text_proj_relu = nn.ModuleList()
            self.alpha_proj = alpha_proj

        # NOTE: this layer is created but not applied anywhere in forward().
        self.dropout = nn.Dropout(dropout_ratio)

    def _build_projections(self):
        """Create the per-level linear layers shared by 'Concat' and 'Weighted'.

        Returns:
            tuple(nn.ModuleList, nn.ModuleList, nn.ModuleList): the visual
            projections, textual projections (both empty when
            ``with_extra_fc`` is False) and the fusion projections that map
            ``visual_dim + textual_dim`` channels to ``visual_dim`` channels.
        """
        vis_proj = nn.ModuleList()
        text_proj = nn.ModuleList()
        alpha_proj = nn.ModuleList()
        # Layers are created in the order vis -> text -> alpha for every level
        # so that parameter initialisation consumes the RNG in a stable order.
        for vis_dim, text_dim in zip(self.visual_dim, self.textual_dim):
            if self.with_extra_fc:
                vis_proj.append(nn.Linear(vis_dim, vis_dim))
                text_proj.append(nn.Linear(text_dim, text_dim))
            alpha_proj.append(nn.Linear(vis_dim + text_dim, vis_dim))
        return vis_proj, text_proj, alpha_proj

    def _project(self, idx, vis, text):
        """Move channels last and optionally apply the extra fc + ReLU.

        Args:
            idx (int): feature level index selecting the projection layers.
            vis (Tensor): visual feature map with channels on dim 1.
            text (Tensor): textual feature map with channels on dim 1.

        Returns:
            tuple(Tensor, Tensor): both inputs permuted with ``(0, 2, 3, 1)``
            (channels last, as required by ``nn.Linear``) and, when
            ``with_extra_fc`` is set, passed through their projection and ReLU.
        """
        per_vis = vis.permute(0, 2, 3, 1)
        per_text = text.permute(0, 2, 3, 1)
        if self.with_extra_fc:
            per_vis = self.relu(self.vis_proj[idx](per_vis))
            per_text = self.relu(self.text_proj[idx](per_text))
        return per_vis, per_text

    def forward(self, visual_feat=None, textual_feat=None):
        """Forward computation

        Args:
            visual_feat (dict(str, Tensor)): visual feature maps indexed by the
                names in ``feature_names``. For 'Concat' / 'Weighted' each map
                must be 4-D with channels on dim 1 (size ``visual_dim[i]``).
            textual_feat (dict(str, Tensor)): textual feature maps indexed the
                same way, channels on dim 1 (size ``semantic_dim[i]``).
                'Sum' and 'Weighted' combine the two maps element-wise, so
                their shapes must be broadcast-compatible.

        Returns:
            dict(str, Tensor): fused feature map per feature name. For
            'Concat' / 'Weighted' the result is permuted back to the input
            layout (channels on dim 1) and made contiguous.
        """
        assert len(visual_feat) == len(textual_feat)

        # feature merge
        merged_feat = {}
        if self.merge_type == "Sum":
            for name in self.feature_names:
                merged_feat[name] = visual_feat[name] + textual_feat[name]
        elif self.merge_type == "Concat":
            for idx, name in enumerate(self.feature_names):
                per_vis, per_text = self._project(idx, visual_feat[name], textual_feat[name])
                fused = self.alpha_proj[idx](torch.cat((per_vis, per_text), -1))
                merged_feat[name] = fused.permute(0, 3, 1, 2).contiguous()
        else:
            assert self.total_num == len(visual_feat) or self.total_num == 1
            for idx, name in enumerate(self.feature_names):
                per_vis, per_text = self._project(idx, visual_feat[name], textual_feat[name])
                # Gate in (0, 1), one value per output element.
                alpha = torch.sigmoid(self.alpha_proj[idx](torch.cat((per_vis, per_text), -1)))
                if self.shortcut:
                    # shortcut: keep the visual feature, add gated text
                    fused = per_vis + alpha * per_text
                else:
                    # selection: convex combination of both modalities
                    fused = alpha * per_vis + (1 - alpha) * per_text
                merged_feat[name] = fused.permute(0, 3, 1, 2).contiguous()
        return merged_feat