# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import Tuple

import numpy as np
import torch
from torch import Tensor, nn

from mmpose.evaluation.functional import keypoint_mpjpe
from mmpose.models.utils.tta import flip_coordinates
from mmpose.registry import KEYPOINT_CODECS, MODELS
from mmpose.utils.tensor_utils import to_numpy
from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList,
                                 Predictions)

from ..base_head import BaseHead


@MODELS.register_module()
class MotionRegressionHead(BaseHead):
"""Regression head of `MotionBERT`_ by Zhu et al (2022).
Args:
in_channels (int): Number of input channels. Default: 256.
out_channels (int): Number of output channels. Default: 3.
embedding_size (int): Number of embedding channels. Default: 512.
loss (Config): Config for keypoint loss. Defaults to use
:class:`MSELoss`
decoder (Config, optional): The decoder config that controls decoding
keypoint coordinates from the network output. Defaults to ``None``
init_cfg (Config, optional): Config to control the initialization. See
:attr:`default_init_cfg` for default settings
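
    Example (a minimal usage sketch; assumes the mmpose registries are
    importable so the default ``MSELoss`` can be built)::

        >>> import torch
        >>> head = MotionRegressionHead(in_channels=256, out_channels=3)
        >>> feats = torch.randn(2, 27, 17, 256)  # (B, F, K, C)
        >>> head.forward(feats).shape
        torch.Size([2, 27, 17, 3])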

    .. _`MotionBERT`: https://arxiv.org/abs/2210.06551
"""

    _version = 2

def __init__(self,
in_channels: int = 256,
out_channels: int = 3,
embedding_size: int = 512,
loss: ConfigType = dict(
type='MSELoss', use_target_weight=True),
decoder: OptConfigType = None,
init_cfg: OptConfigType = None):
if init_cfg is None:
init_cfg = self.default_init_cfg
super().__init__(init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.loss_module = MODELS.build(loss)
if decoder is not None:
self.decoder = KEYPOINT_CODECS.build(decoder)
else:
self.decoder = None
# Define fully-connected layers
self.pre_logits = nn.Sequential(
OrderedDict([('fc', nn.Linear(in_channels, embedding_size)),
('act', nn.Tanh())]))
self.fc = nn.Linear(
embedding_size,
out_channels) if embedding_size > 0 else nn.Identity()

    def forward(self, feats: Tuple[Tensor]) -> Tensor:
        """Forward the network. The input is multi-scale feature maps and the
        output is the coordinates.

        Args:
            feats (Tuple[Tensor]): Multi-scale feature maps.

        Returns:
            Tensor: Output coordinates (and, optionally, sigmas).
        """
x = feats # (B, F, K, in_channels)
x = self.pre_logits(x) # (B, F, K, embedding_size)
x = self.fc(x) # (B, F, K, out_channels)
return x

    def predict(self,
                feats: Tuple[Tensor],
                batch_data_samples: OptSampleList,
                test_cfg: ConfigType = {}) -> Predictions:
        """Predict results from features.

        Returns:
            preds (sequence[InstanceData]): Prediction results.
                Each contains the following fields:

                - keypoints: Predicted keypoints of shape (B, N, K, D).
                - keypoint_scores: Scores of predicted keypoints of shape
                    (B, N, K).
        """
if test_cfg.get('flip_test', False):
# TTA: flip test -> feats = [orig, flipped]
assert isinstance(feats, list) and len(feats) == 2
flip_indices = batch_data_samples[0].metainfo['flip_indices']
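            # flip_indices maps each keypoint index to its left/right
            # mirrored counterpart, so the horizontally flipped prediction
            # can be swapped back into the original keypoint order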
_feats, _feats_flip = feats
_batch_coords = self.forward(_feats)
_batch_coords_flip = torch.stack([
flip_coordinates(
_batch_coord_flip,
flip_indices=flip_indices,
shift_coords=test_cfg.get('shift_coords', True),
input_size=(1, 1))
for _batch_coord_flip in self.forward(_feats_flip)
],
dim=0)
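            # average the original and the flipped-back predictions (TTA)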
batch_coords = (_batch_coords + _batch_coords_flip) * 0.5
else:
batch_coords = self.forward(feats)
# Restore global position with camera_param and factor
camera_param = batch_data_samples[0].metainfo.get('camera_param', None)
if camera_param is not None:
w = torch.stack([
torch.from_numpy(np.array([b.metainfo['camera_param']['w']]))
for b in batch_data_samples
])
h = torch.stack([
torch.from_numpy(np.array([b.metainfo['camera_param']['h']]))
for b in batch_data_samples
])
else:
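            # no camera parameters available: pass empty per-sample
            # placeholders so the decoder does not attempt to restore
            # global (camera-space) positions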
w = torch.stack([
torch.empty((0), dtype=torch.float32)
for _ in batch_data_samples
])
h = torch.stack([
torch.empty((0), dtype=torch.float32)
for _ in batch_data_samples
])
factor = batch_data_samples[0].metainfo.get('factor', None)
if factor is not None:
factor = torch.stack([
torch.from_numpy(b.metainfo['factor'])
for b in batch_data_samples
])
else:
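            # no test-time scale factor available: likewise pass an
            # empty placeholder per sample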
factor = torch.stack([
torch.empty((0), dtype=torch.float32)
for _ in batch_data_samples
])
preds = self.decode((batch_coords, w, h, factor))
return preds

    def loss(self,
inputs: Tuple[Tensor],
batch_data_samples: OptSampleList,
train_cfg: ConfigType = {}) -> dict:
"""Calculate losses from a batch of inputs and data samples."""
pred_outputs = self.forward(inputs)
lifting_target_label = torch.stack([
d.gt_instance_labels.lifting_target_label
for d in batch_data_samples
])
lifting_target_weight = torch.stack([
d.gt_instance_labels.lifting_target_weight
for d in batch_data_samples
])
# calculate losses
losses = dict()
loss = self.loss_module(pred_outputs, lifting_target_label,
lifting_target_weight.unsqueeze(-1))
losses.update(loss_pose3d=loss)
# calculate accuracy
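        # (MPJPE is reported for monitoring only; MMEngine's loss parsing
        # sums only dict entries whose key contains 'loss', so this value
        # is logged but not back-propagated)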
mpjpe_err = keypoint_mpjpe(
pred=to_numpy(pred_outputs),
gt=to_numpy(lifting_target_label),
mask=to_numpy(lifting_target_weight) > 0)
mpjpe_pose = torch.tensor(
mpjpe_err, device=lifting_target_label.device)
losses.update(mpjpe=mpjpe_pose)
return losses

    @property
def default_init_cfg(self):
init_cfg = [dict(type='TruncNormal', layer=['Linear'], std=0.02)]
return init_cfg
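

if __name__ == '__main__':
    # Minimal usage sketch: assumes mmpose is installed and that this file
    # is executed as a module (``python -m ...``) so the relative import of
    # ``BaseHead`` resolves. It builds the head with its default ``MSELoss``
    # and checks the (B, F, K, C) -> (B, F, K, out_channels) forward
    # contract documented in ``forward``.
    head = MotionRegressionHead(in_channels=256, out_channels=3)
    dummy_feats = torch.randn(2, 27, 17, 256)  # (B, F, K, in_channels)
    coords = head.forward(dummy_feats)
    assert coords.shape == (2, 27, 17, 3)  # (B, F, K, out_channels)
    print('forward output shape:', tuple(coords.shape))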