Miroslav Purkrabek
add code
a249588
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Tuple, Union
import torch
import torch.nn.functional as F
from mmengine.model import normal_init
from mmengine.structures import InstanceData
from torch import Tensor, nn
from mmpose.evaluation.functional import multilabel_classification_accuracy
from mmpose.models.necks import GlobalAveragePooling
from mmpose.models.utils.tta import flip_heatmaps
from mmpose.registry import KEYPOINT_CODECS, MODELS
from mmpose.utils.tensor_utils import to_numpy
from mmpose.utils.typing import (ConfigType, Features, InstanceList,
OptConfigType, OptSampleList, Predictions)
from ..base_head import BaseHead
from .heatmap_head import HeatmapHead
# Optional sequence of ints, used for per-layer settings such as deconv
# output channels and kernel sizes; ``None`` is accepted instead of a sequence.
OptIntSeq = Union[Sequence[int], None]
def make_linear_layers(feat_dims, relu_final=False):
    """Build an MLP as an ``nn.Sequential`` of Linear (+ ReLU) layers.

    One ``nn.Linear`` is created for every consecutive pair of entries in
    ``feat_dims``. Every linear layer except the last is followed by an
    in-place ReLU. The last one gets a ReLU only when ``relu_final`` is true.

    Args:
        feat_dims (Sequence[int]): Feature sizes, input first, output last.
            Fewer than two entries produce an empty ``nn.Sequential``.
        relu_final (bool): Whether to append a ReLU after the last linear
            layer. Defaults to ``False``.

    Returns:
        nn.Sequential: The assembled layers.
    """
    last_idx = len(feat_dims) - 2
    modules = []
    for idx, (dim_in, dim_out) in enumerate(zip(feat_dims[:-1],
                                                feat_dims[1:])):
        modules.append(nn.Linear(dim_in, dim_out))
        if relu_final or idx != last_idx:
            modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)
class Heatmap3DHead(HeatmapHead):
    """Sub-module of :class:`InternetHead` that predicts 3D heatmaps.

    The inherited deconvolution stack and final layer produce a 2D heatmap
    tensor. Its channel axis is then split into ``(channels // depth_size,
    depth_size)`` to form a volumetric (depth-discretized) heatmap.

    Args:
        in_channels (int | Sequence[int]): Number of input channels.
        out_channels (int): Number of output channels of the final layer.
            Must be divisible by ``depth_size``.
        depth_size (int): Number of depth discretization bins.
            Defaults to 64.
        deconv_out_channels (Sequence[int], optional): The output channel
            number of each deconv layer. Defaults to ``(256, 256, 256)``.
        deconv_kernel_sizes (Sequence[int | tuple], optional): The kernel
            size of each deconv layer. Each element is either an int used
            for both height and width, or a tuple of two ints for the height
            and the width respectively. Defaults to ``(4, 4, 4)``.
        final_layer (dict): Arguments of the final Conv2d layer.
            Defaults to ``dict(kernel_size=1)``.
        init_cfg (Config, optional): Config to control the initialization;
            forwarded unchanged to :class:`HeatmapHead`.
    """

    def __init__(self,
                 in_channels: Union[int, Sequence[int]],
                 out_channels: int,
                 depth_size: int = 64,
                 deconv_out_channels: OptIntSeq = (256, 256, 256),
                 deconv_kernel_sizes: OptIntSeq = (4, 4, 4),
                 final_layer: dict = dict(kernel_size=1),
                 init_cfg: OptConfigType = None):
        parent_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            deconv_out_channels=deconv_out_channels,
            deconv_kernel_sizes=deconv_kernel_sizes,
            final_layer=final_layer,
            init_cfg=init_cfg)
        super().__init__(**parent_kwargs)
        # ``forward`` regroups the output channels into blocks of
        # ``depth_size``, so the channel count must split evenly.
        assert out_channels % depth_size == 0
        self.depth_size = depth_size

    def forward(self, feats: Tensor) -> Tensor:
        """Predict 3D heatmaps from a single feature map.

        Args:
            feats (Tensor): Feature map fed to the deconv layers.

        Returns:
            Tensor: Heatmaps of shape
            ``(N, C // depth_size, depth_size, H, W)``, where ``C``, ``H``
            and ``W`` come from the final layer's 2D output.
        """
        planar = self.final_layer(self.deconv_layers(feats))
        batch, channels, height, width = planar.shape
        num_maps = channels // self.depth_size
        # Split the channel axis into (maps, depth) to obtain 3D heatmaps.
        return planar.reshape(batch, num_maps, self.depth_size, height,
                              width)
class Heatmap1DHead(nn.Module):
    """Sub-module of :class:`InternetHead` that regresses one scalar per
    sample via a 1D heatmap and soft-argmax.

    An MLP maps the input features to ``heatmap_size`` logits. These are
    softmax-normalized, and the expected bin index in
    ``[0, heatmap_size - 1]`` is returned.

    Args:
        in_channels (int): Number of input channels. Defaults to 2048.
        heatmap_size (int): Number of 1D heatmap bins. Defaults to 64.
        hidden_dims (Sequence[int]): Sizes of the hidden FC layers.
            Defaults to ``(512, )``.
    """

    def __init__(self,
                 in_channels: int = 2048,
                 heatmap_size: int = 64,
                 hidden_dims: Sequence[int] = (512, )):
        super().__init__()
        self.in_channels = in_channels
        self.heatmap_size = heatmap_size
        layer_dims = [in_channels] + list(hidden_dims) + [heatmap_size]
        self.fc = make_linear_layers(layer_dims, relu_final=False)

    def soft_argmax_1d(self, heatmap1d):
        """Return the softmax-weighted mean bin index along dim 1.

        Args:
            heatmap1d (Tensor): Logits whose dim 1 has ``heatmap_size``
                entries.

        Returns:
            Tensor: The expected bin index, with dim 1 reduced away.
        """
        probs = F.softmax(heatmap1d, 1)
        bins = torch.arange(
            self.heatmap_size, dtype=probs.dtype, device=probs.device)
        return (probs * bins.unsqueeze(0)).sum(dim=1)

    def forward(self, feats: Tuple[Tensor]) -> Tensor:
        """Regress the scalar value.

        Args:
            feats (Tuple[Tensor]): Features passed directly to the MLP.
                :class:`InternetHead` feeds globally pooled features here.

        Returns:
            Tensor: Soft-argmax values reshaped to ``(-1, 1)``.
        """
        logits = self.fc(feats)
        return self.soft_argmax_1d(logits).view(-1, 1)

    def init_weights(self):
        """Initialize every Linear layer with N(0, 0.01) weights and zero
        bias."""
        linear_layers = (
            module for module in self.fc.modules()
            if isinstance(module, nn.Linear))
        for layer in linear_layers:
            normal_init(layer, mean=0, std=0.01, bias=0)
class MultilabelClassificationHead(nn.Module):
    """Sub-module of :class:`InternetHead` producing hand type scores.

    A plain MLP. No activation is applied after its last linear layer, so
    the outputs are the raw scores of that layer.

    Args:
        in_channels (int): Number of input channels. Defaults to 2048.
        num_labels (int): Number of output labels. Defaults to 2.
        hidden_dims (Sequence[int]): Sizes of the hidden FC layers.
            Defaults to ``(512, )``.
    """

    def __init__(self,
                 in_channels: int = 2048,
                 num_labels: int = 2,
                 hidden_dims: Sequence[int] = (512, )):
        super().__init__()
        self.in_channels = in_channels
        layer_dims = [in_channels]
        layer_dims.extend(hidden_dims)
        layer_dims.append(num_labels)
        self.fc = make_linear_layers(layer_dims, relu_final=False)

    def init_weights(self):
        """Initialize every Linear layer with N(0, 0.01) weights and zero
        bias."""
        linear_layers = [
            module for module in self.fc.modules()
            if isinstance(module, nn.Linear)
        ]
        for layer in linear_layers:
            normal_init(layer, mean=0, std=0.01, bias=0)

    def forward(self, x):
        """Return the MLP output (one score per label) for ``x``."""
        return self.fc(x)
@MODELS.register_module()
class InternetHead(BaseHead):
    """Internet head introduced in `Interhand 2.6M`_ by Moon et al (2020).

    The head is composed of sub-heads that share the last feature map:

    - two :class:`Heatmap3DHead` instances (right and left hand), both
      built from ``keypoint_head_cfg``, whose 3D heatmaps are concatenated
      along the keypoint axis (right hand first);
    - a :class:`Heatmap1DHead` for the relative hand root depth;
    - a :class:`MultilabelClassificationHead` for hand type classification.

    The root-depth and hand-type heads run on globally average-pooled
    features.

    Args:
        keypoint_head_cfg (dict): Configs of Heatmap3DHead for hand
            keypoint estimation (shared by the right and left hand heads).
        root_head_cfg (dict): Configs of Heatmap1DHead for relative
            hand root depth estimation.
        hand_type_head_cfg (dict): Configs of ``MultilabelClassificationHead``
            for hand type classification.
        loss (Config): Config of the keypoint loss.
            Default: :class:`KeypointMSELoss`.
        loss_root_depth (dict): Config for relative root depth loss.
            Default: :class:`L1Loss`.
        loss_hand_type (dict): Config for hand type classification
            loss. Default: :class:`BCELoss`.
        decoder (Config, optional): The decoder config that controls decoding
            keypoint coordinates from the network output. Default: ``None``.
        init_cfg (Config, optional): Config to control the initialization.
            Forwarded to the base module. Default: ``None``.

    .. _`Interhand 2.6M`: https://arxiv.org/abs/2008.09309
    """

    _version = 2

    def __init__(self,
                 keypoint_head_cfg: ConfigType,
                 root_head_cfg: ConfigType,
                 hand_type_head_cfg: ConfigType,
                 loss: ConfigType = dict(
                     type='KeypointMSELoss', use_target_weight=True),
                 loss_root_depth: ConfigType = dict(
                     type='L1Loss', use_target_weight=True),
                 loss_hand_type: ConfigType = dict(
                     type='BCELoss', use_target_weight=True),
                 decoder: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        # Previously ``init_cfg`` was accepted but silently dropped; pass it
        # on so the documented initialization config actually takes effect.
        super().__init__(init_cfg=init_cfg)

        # build sub-module heads
        self.right_hand_head = Heatmap3DHead(**keypoint_head_cfg)
        self.left_hand_head = Heatmap3DHead(**keypoint_head_cfg)
        self.root_head = Heatmap1DHead(**root_head_cfg)
        self.hand_type_head = MultilabelClassificationHead(
            **hand_type_head_cfg)
        self.neck = GlobalAveragePooling()

        self.loss_module = MODELS.build(loss)
        self.root_loss_module = MODELS.build(loss_root_depth)
        self.hand_loss_module = MODELS.build(loss_hand_type)

        if decoder is not None:
            self.decoder = KEYPOINT_CODECS.build(decoder)
        else:
            self.decoder = None

    def forward(self, feats: Tuple[Tensor]) -> Sequence[Tensor]:
        """Forward the network on the last of the multi-scale feature maps.

        Args:
            feats (Tuple[Tensor]): Multi scale feature maps. Only
                ``feats[-1]`` is used.

        Returns:
            List[Tensor]: ``[heatmaps, root_depth, hand_type]``. The heatmaps
            hold the right-hand then left-hand 3D heatmaps concatenated on
            dim 1. ``root_depth`` and ``hand_type`` are computed from the
            globally pooled features.
        """
        x = feats[-1]
        outputs = []
        outputs.append(
            torch.cat([self.right_hand_head(x),
                       self.left_hand_head(x)], dim=1))
        x = self.neck(x)
        outputs.append(self.root_head(x))
        outputs.append(self.hand_type_head(x))
        return outputs

    def predict(self,
                feats: Features,
                batch_data_samples: OptSampleList,
                test_cfg: ConfigType = {}) -> Predictions:
        """Predict results from features.

        Args:
            feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
                features (or multiple multi-stage features in TTA)
            batch_data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples
            test_cfg (dict): The runtime config for testing process. Reads
                ``flip_test``, ``flip_mode`` and ``shift_heatmap``.
                Defaults to {}

        Returns:
            InstanceList: Return the pose prediction.

            The pose prediction is a list of ``InstanceData``, each contains
            the following fields:

                - keypoints (np.ndarray): predicted keypoint coordinates in
                    shape (num_instances, K, D) where K is the keypoint number
                    and D is the keypoint dimension
                - keypoint_scores (np.ndarray): predicted keypoint scores in
                    shape (num_instances, K)
                - rel_root_depth: relative hand root depth as returned by
                    the decoder
                - hand_type: hand type as returned by the decoder
        """
        if test_cfg.get('flip_test', False):
            # TTA: flip test -> feats = [orig, flipped]
            assert isinstance(feats, list) and len(feats) == 2
            flip_indices = batch_data_samples[0].metainfo['flip_indices']
            _feats, _feats_flip = feats
            _batch_outputs = self.forward(_feats)
            _batch_outputs_flip = self.forward(_feats_flip)

            # heatmaps: map the flipped prediction back, then average
            _batch_heatmaps = _batch_outputs[0]
            _batch_heatmaps_flip = flip_heatmaps(
                _batch_outputs_flip[0],
                flip_mode=test_cfg.get('flip_mode', 'heatmap'),
                flip_indices=flip_indices,
                shift_heatmap=test_cfg.get('shift_heatmap', False))
            batch_heatmaps = (_batch_heatmaps + _batch_heatmaps_flip) * 0.5

            # relative hand root depth: the flipped prediction is negated
            # before averaging (mirroring presumably swaps the two hands)
            _batch_root = _batch_outputs[1]
            _batch_root_flip = -_batch_outputs_flip[1]
            batch_root = (_batch_root + _batch_root_flip) * 0.5

            # hand type: swap columns 0 and 1 of the *flipped* prediction
            # before averaging. The previous code swapped the un-flipped
            # prediction, which ignored the flipped branch and made both
            # columns equal. Cloning keeps any further columns defined
            # instead of leaving them uninitialized.
            _batch_type = _batch_outputs[2]
            _batch_type_raw_flip = _batch_outputs_flip[2]
            _batch_type_flip = _batch_type_raw_flip.clone()
            _batch_type_flip[:, 0] = _batch_type_raw_flip[:, 1]
            _batch_type_flip[:, 1] = _batch_type_raw_flip[:, 0]
            batch_type = (_batch_type + _batch_type_flip) * 0.5

            batch_outputs = [batch_heatmaps, batch_root, batch_type]
        else:
            batch_outputs = self.forward(feats)

        preds = self.decode(tuple(batch_outputs))

        return preds

    def loss(self,
             feats: Tuple[Tensor],
             batch_data_samples: OptSampleList,
             train_cfg: ConfigType = {}) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            feats (Tuple[Tensor]): The multi-stage features
            batch_data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples
            train_cfg (dict): The runtime config for training process.
                Reads ``compute_acc`` (default ``True``). Defaults to {}

        Returns:
            dict: A dictionary of losses with keys ``loss_kpt``,
            ``loss_rel_root`` and ``loss_hand_type``, plus ``acc_pose``
            (hand type classification accuracy) when ``compute_acc`` is set.
        """
        pred_fields = self.forward(feats)
        pred_heatmaps = pred_fields[0]
        # pred_heatmaps is (N, K, D, H, W), per the Heatmap3DHead reshape
        _, K, D, H, W = pred_heatmaps.shape
        gt_heatmaps = torch.stack([
            d.gt_fields.heatmaps.reshape(K, D, H, W)
            for d in batch_data_samples
        ])
        keypoint_weights = torch.cat([
            d.gt_instance_labels.keypoint_weights for d in batch_data_samples
        ])

        # calculate losses
        losses = dict()

        # hand keypoint loss
        loss = self.loss_module(pred_heatmaps, gt_heatmaps, keypoint_weights)
        losses.update(loss_kpt=loss)

        # relative root depth loss
        gt_roots = torch.stack(
            [d.gt_instance_labels.root_depth for d in batch_data_samples])
        root_weights = torch.stack([
            d.gt_instance_labels.root_depth_weight for d in batch_data_samples
        ])
        loss_root = self.root_loss_module(pred_fields[1], gt_roots,
                                          root_weights)
        losses.update(loss_rel_root=loss_root)

        # hand type loss
        gt_types = torch.stack([
            d.gt_instance_labels.type.reshape(-1) for d in batch_data_samples
        ])
        type_weights = torch.stack(
            [d.gt_instance_labels.type_weight for d in batch_data_samples])
        loss_type = self.hand_loss_module(pred_fields[2], gt_types,
                                          type_weights)
        losses.update(loss_hand_type=loss_type)

        # calculate (hand type) accuracy
        if train_cfg.get('compute_acc', True):
            acc = multilabel_classification_accuracy(
                pred=to_numpy(pred_fields[2]),
                gt=to_numpy(gt_types),
                mask=to_numpy(type_weights))

            acc_pose = torch.tensor(acc, device=gt_types.device)
            losses.update(acc_pose=acc_pose)

        return losses

    def decode(self, batch_outputs: Union[Tensor,
                                          Tuple[Tensor]]) -> InstanceList:
        """Decode keypoints from outputs.

        Args:
            batch_outputs (Tensor | Tuple[Tensor]): The network outputs of
                a data batch, indexed as ``(heatmaps, root_depth,
                hand_type)``.

        Returns:
            List[InstanceData]: A list of InstanceData, each contains the
            decoded pose information of the instances of one data sample
            (``keypoints``, ``keypoint_scores``, ``rel_root_depth`` and
            ``hand_type``).

        Raises:
            RuntimeError: If no decoder was configured.
        """
        if self.decoder is None:
            raise RuntimeError(
                f'The decoder has not been set in {self.__class__.__name__}. '
                'Please set the decoder configs in the init parameters to '
                'enable head methods `head.predict()` and `head.decode()`')

        # split each batched output into per-sample numpy arrays
        batch_output_np = to_numpy(batch_outputs[0], unzip=True)
        batch_root_np = to_numpy(batch_outputs[1], unzip=True)
        batch_type_np = to_numpy(batch_outputs[2], unzip=True)

        preds = []
        for outputs, roots, types in zip(batch_output_np, batch_root_np,
                                         batch_type_np):
            keypoints, scores, rel_root_depth, hand_type = \
                self.decoder.decode(outputs, roots, types)
            preds.append(
                InstanceData(
                    keypoints=keypoints,
                    keypoint_scores=scores,
                    rel_root_depth=rel_root_depth,
                    hand_type=hand_type))

        return preds