Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
import torch.nn as nn | |
from mmcv.cnn.bricks import DropPath | |
from mmengine.model import BaseModule, constant_init | |
from mmengine.model.weight_init import trunc_normal_ | |
from mmpose.registry import MODELS | |
from .base_backbone import BaseBackbone | |
class Attention(BaseModule):
    """Multi-head self-attention along either the joint or the frame axis.

    Input tokens are laid out as ``(rows, N, dim)``. In ``'spatial'`` mode
    each row attends over its own ``N`` tokens. In ``'temporal'`` mode the
    rows are grouped into consecutive clips of ``seq_len`` rows, and each of
    the ``N`` token positions attends over the ``seq_len`` rows of its clip.

    Args:
        dim (int): Channel dimension of input and output tokens.
        num_heads (int): Number of attention heads. Default: 8.
        qkv_bias (bool): Whether the fused q/k/v projection has a bias.
            Default: False.
        qk_scale (float | None): Attention logit scale. A falsy value
            (``None`` or ``0``) falls back to ``head_dim ** -0.5`` where
            ``head_dim = dim // num_heads``. Default: None.
        attn_drop (float): Dropout ratio on the attention weights.
            Default: 0.
        proj_drop (float): Dropout ratio after the output projection.
            Default: 0.
        mode (str): ``'spatial'`` or ``'temporal'``. Any other value raises
            ``NotImplementedError`` from :meth:`forward`. Default: 'spatial'.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 mode='spatial'):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale if qk_scale else head_dim**-0.5
        # Creation order is kept as-is: each nn.Linear draws its default
        # initial weights from the global RNG when it is constructed.
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.mode = mode
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        # Not read or written anywhere in this class.
        self.attn_count_s = None
        self.attn_count_t = None

    def forward(self, x, seq_len=1):
        """Attend over ``x`` of shape ``(rows, N, dim)``.

        Args:
            x (Tensor): Tokens of shape ``(rows, N, dim)``.
            seq_len (int): Clip length, used only in ``'temporal'`` mode.

        Returns:
            Tensor: Tokens of shape ``(rows, N, dim)``.
        """
        rows, tokens, channels = x.shape
        if self.mode not in ('temporal', 'spatial'):
            raise NotImplementedError(self.mode)

        per_head = channels // self.num_heads
        # (3, rows, heads, tokens, per_head). Indexed rather than unpacked
        # to keep TorchScript happy (a tensor cannot be used as a tuple).
        qkv = self.qkv(x).reshape(rows, tokens, 3, self.num_heads,
                                  per_head).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        if self.mode == 'spatial':
            mixed = self.forward_spatial(q, k, v)
        else:
            mixed = self.forward_temporal(q, k, v, seq_len=seq_len)
        return self.proj_drop(self.proj(mixed))

    def forward_spatial(self, q, k, v):
        """Attend across the ``N`` tokens of each row independently.

        Args:
            q, k, v (Tensor): Shape ``(rows, heads, N, per_head)``.

        Returns:
            Tensor: Merged heads, shape ``(rows, N, heads * per_head)``.
        """
        rows, _, tokens, per_head = q.shape
        logits = q @ k.transpose(-2, -1) * self.scale
        weights = self.attn_drop(torch.softmax(logits, dim=-1))
        mixed = weights @ v  # (rows, heads, N, per_head)
        return mixed.transpose(1, 2).reshape(rows, tokens,
                                             per_head * self.num_heads)

    def forward_temporal(self, q, k, v, seq_len=8):
        """Attend across the ``seq_len`` rows of each clip, per token.

        Args:
            q, k, v (Tensor): Shape ``(rows, heads, N, per_head)`` with
                ``rows`` a multiple of ``seq_len``.
            seq_len (int): Number of consecutive rows forming one clip.

        Returns:
            Tensor: Merged heads, shape ``(rows, N, heads * per_head)``.
        """
        rows, _, tokens, per_head = q.shape
        clip_shape = (-1, seq_len, self.num_heads, tokens, per_head)
        # (clips, seq_len, H, N, C) -> (clips, H, N, seq_len, C)
        q_clip = q.reshape(clip_shape).permute(0, 2, 3, 1, 4)
        k_clip = k.reshape(clip_shape).permute(0, 2, 3, 1, 4)
        v_clip = v.reshape(clip_shape).permute(0, 2, 3, 1, 4)

        logits = q_clip @ k_clip.transpose(-2, -1) * self.scale
        weights = self.attn_drop(torch.softmax(logits, dim=-1))
        mixed = weights @ v_clip  # (clips, H, N, seq_len, C)
        # Back to (clips, seq_len, N, H, C), then flatten clips and heads.
        return mixed.permute(0, 3, 2, 1, 4).reshape(
            rows, tokens, per_head * self.num_heads)
class AttentionBlock(BaseModule):
    """Pre-norm transformer block with a spatial and a temporal branch.

    Each branch is a residual attention sub-layer followed by a residual
    MLP sub-layer, both preceded by their own LayerNorm and wrapped in a
    drop-path module shared by all four residual adds. ``st_mode`` selects
    which branch runs first.

    Args:
        dim (int): Token channel dimension.
        num_heads (int): Number of attention heads in each branch.
        mlp_ratio (float): MLP hidden width relative to ``dim``. Default: 4.
        mlp_out_ratio (float): MLP output width relative to ``dim``; the
            residual add needs this to produce ``dim`` channels. Default: 1.
        qkv_bias (bool): Whether q/k/v projections have a bias.
            Default: True.
        qk_scale (float | None): Optional override of the attention scale.
            Default: None.
        drop (float): Dropout ratio after the attention projection and
            at the end of each MLP. Default: 0.
        attn_drop (float): Dropout ratio on attention weights. Default: 0.
        drop_path (float): Stochastic depth rate; ``0`` uses an identity
            instead. Default: 0.
        st_mode (str): ``'st'`` (spatial first) or ``'ts'`` (temporal
            first); other values raise ``NotImplementedError`` from
            :meth:`forward`. Default: 'st'.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 mlp_out_ratio=1.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 st_mode='st'):
        super().__init__()
        self.st_mode = st_mode
        # Submodules are created in the same order as before; each one's
        # default init consumes the global RNG at construction time.
        self.norm1_s = nn.LayerNorm(dim, eps=1e-06)
        self.norm1_t = nn.LayerNorm(dim, eps=1e-06)

        attn_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        self.attn_s = Attention(dim, mode='spatial', **attn_kwargs)
        self.attn_t = Attention(dim, mode='temporal', **attn_kwargs)

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2_s = nn.LayerNorm(dim, eps=1e-06)
        self.norm2_t = nn.LayerNorm(dim, eps=1e-06)

        hidden_dim = int(dim * mlp_ratio)
        out_dim = int(dim * mlp_out_ratio)

        def build_mlp():
            # Linear -> GELU -> Linear -> Dropout feed-forward head.
            return nn.Sequential(
                nn.Linear(dim, hidden_dim), nn.GELU(),
                nn.Linear(hidden_dim, out_dim), nn.Dropout(drop))

        self.mlp_s = build_mlp()
        self.mlp_t = build_mlp()

    def forward(self, x, seq_len=1):
        """Run both branches in the order given by ``st_mode``.

        Args:
            x (Tensor): Tokens passed to both attention sub-layers.
            seq_len (int): Clip length forwarded to the attention layers.

        Returns:
            Tensor: Tokens with the same shape as ``x``.
        """
        spatial = (self.norm1_s, self.attn_s, self.norm2_s, self.mlp_s)
        temporal = (self.norm1_t, self.attn_t, self.norm2_t, self.mlp_t)
        if self.st_mode == 'st':
            branches = (spatial, temporal)
        elif self.st_mode == 'ts':
            branches = (temporal, spatial)
        else:
            raise NotImplementedError(self.st_mode)

        for attn_norm, attn, mlp_norm, mlp in branches:
            x = x + self.drop_path(attn(attn_norm(x), seq_len))
            x = x + self.drop_path(mlp(mlp_norm(x)))
        return x
class DSTFormer(BaseBackbone):
    """Dual-stream Spatio-temporal Transformer Module.

    Each of the ``depth`` layers feeds the same features through two
    :class:`AttentionBlock` streams, one ordered spatial-then-temporal
    (``'st'``) and one temporal-then-spatial (``'ts'``). Their outputs are
    merged with learned per-token weights (``att_fuse=True``) or by plain
    averaging.

    Args:
        in_channels (int): Number of input channels per keypoint.
        feat_size (int): Number of feature channels. Default: 256.
        depth (int): Number of dual-stream layers. Default: 5.
        num_heads (int): Number of heads in multi-head self-attention
            blocks. Default: 8.
        mlp_ratio (float, optional): The expansion ratio of FFN.
            Default: 4.
        num_keypoints (int): Number of keypoints. Default: 17.
        seq_len (int): Number of frames covered by the learnable temporal
            positional embedding. Default: 243.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float, optional): Dropout ratio applied to the embedded
            input, and passed to every attention block for its projection
            and FFN dropout. Default: 0.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.
        drop_path_rate (float, optional): Stochastic depth rate of the last
            layer; per-layer rates grow linearly from 0. Default: 0.
        att_fuse (bool): Whether to fuse the two streams with learned
            weights (True) or average them (False). Default: True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None

    Example:
        >>> from mmpose.models import DSTFormer
        >>> import torch
        >>> self = DSTFormer(in_channels=3)
        >>> self.eval()
        >>> inputs = torch.rand(1, 2, 17, 3)
        >>> level_outputs = self.forward(inputs)
        >>> print(tuple(level_outputs.shape))
        (1, 2, 17, 256)
    """

    def __init__(self,
                 in_channels,
                 feat_size=256,
                 depth=5,
                 num_heads=8,
                 mlp_ratio=4,
                 num_keypoints=17,
                 seq_len=243,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 att_fuse=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.feat_size = feat_size

        self.joints_embed = nn.Linear(in_channels, feat_size)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule: rates rise linearly across layers.
        dpr = torch.linspace(0, drop_path_rate, depth).tolist()

        def make_stream(st_mode):
            # One AttentionBlock per layer, all in the given branch order.
            return nn.ModuleList([
                AttentionBlock(
                    dim=feat_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    st_mode=st_mode) for i in range(depth)
            ])

        self.blocks_st = make_stream('st')
        self.blocks_ts = make_stream('ts')
        self.norm = nn.LayerNorm(feat_size, eps=1e-06)

        # Learnable positional embeddings, broadcast over batch / joints
        # (temporal) and over batch*frames (spatial).
        self.temp_embed = nn.Parameter(torch.zeros(1, seq_len, 1, feat_size))
        self.spat_embed = nn.Parameter(
            torch.zeros(1, num_keypoints, feat_size))
        trunc_normal_(self.temp_embed, std=.02)
        trunc_normal_(self.spat_embed, std=.02)

        self.att_fuse = att_fuse
        if self.att_fuse:
            # One head per layer mapping [x_st, x_ts] to two stream logits.
            self.attn_regress = nn.ModuleList(
                [nn.Linear(feat_size * 2, 2) for _ in range(depth)])
            self._init_fusion_weights()

    def _init_fusion_weights(self):
        """Make every fusion head start by weighting both streams equally.

        Zero weights plus a shared constant bias give equal logits, so the
        softmax in :meth:`forward` yields 0.5 / 0.5 for any input.
        """
        if not self.att_fuse:
            return
        for head in self.attn_regress:
            nn.init.constant_(head.weight, 0)
            nn.init.constant_(head.bias, 0.5)

    def forward(self, x):
        """Encode a keypoint sequence.

        Args:
            x (Tensor): Keypoints of shape ``(B, F, K, C)`` or ``(F, K, C)``
                (treated as a batch of one). ``C`` must equal
                ``in_channels``. ``K`` is expected to match
                ``num_keypoints``, since the spatial embedding is added per
                keypoint.

        Returns:
            Tensor: Features of shape ``(B, F, K, feat_size)``.
        """
        if len(x.shape) == 3:
            x = x[None, :]
        assert len(x.shape) == 4

        B, F, K, C = x.shape
        x = x.reshape(-1, K, C)
        BF = x.shape[0]
        x = self.joints_embed(x)  # (BF, K, feat_size)
        x = x + self.spat_embed
        _, K, C = x.shape
        # Add the temporal embedding frame-wise, then fold frames back into
        # the leading axis that the attention blocks expect.
        x = x.reshape(-1, F, K, C) + self.temp_embed[:, :F, :, :]
        x = x.reshape(BF, K, C)  # (BF, K, feat_size)
        x = self.pos_drop(x)

        for idx, (blk_st, blk_ts) in enumerate(
                zip(self.blocks_st, self.blocks_ts)):
            x_st = blk_st(x, F)
            x_ts = blk_ts(x, F)
            if self.att_fuse:
                # Per-token softmax weights over the two streams.
                alpha = self.attn_regress[idx](
                    torch.cat([x_st, x_ts], dim=-1))
                alpha = alpha.softmax(dim=-1)
                x = x_st * alpha[:, :, 0:1] + x_ts * alpha[:, :, 1:2]
            else:
                x = (x_st + x_ts) * 0.5

        x = self.norm(x)  # (BF, K, feat_size)
        return x.reshape(B, F, K, -1)

    def init_weights(self):
        """Initialize the weights in backbone.

        After the base-class hook, returns early when ``init_cfg`` is a dict
        whose ``type`` is ``'Pretrained'``. Otherwise every Linear gets
        truncated-normal weights (std 0.02) and a zero bias, every LayerNorm
        gets unit weight and zero bias, and the stream-fusion heads are reset
        to equal weighting.
        """
        super().init_weights()
        if (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            return
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    # mmengine's ``constant_init`` takes a module, not a
                    # tensor (a bare tensor is silently ignored), so use the
                    # tensor-level initializer here.
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, 1.0, bias=0)
        # The generic Linear init above overwrote the fusion heads; restore
        # their equal-weighting start.
        self._init_fusion_weights()