# Miroslav Purkrabek
# add code
# a249588
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import zip_longest
from typing import Tuple, Union
import torch
from torch import Tensor
from mmpose.models.utils import check_and_update_config
from mmpose.models.utils.tta import flip_coordinates
from mmpose.registry import MODELS
from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType,
Optional, OptMultiConfig, OptSampleList,
PixelDataList, SampleList)
from .base import BasePoseEstimator
@MODELS.register_module()
class PoseLifter(BasePoseEstimator):
    """Base class for pose lifter.

    A pose lifter consists of a pose branch (backbone, optional neck,
    optional head) and an optional trajectory branch (optional backbone,
    optional neck, head). The trajectory branch is only built when
    ``traj_head`` is given.

    Args:
        backbone (dict): The backbone config
        neck (dict, optional): The neck config. Defaults to ``None``
        head (dict, optional): The head config. Defaults to ``None``
        traj_backbone (dict, optional): The backbone config for trajectory
            model. If ``traj_head`` is given but ``traj_backbone`` is not,
            the trajectory model reuses the features produced by the pose
            model's backbone (and neck, if any). Defaults to ``None``
        traj_neck (dict, optional): The neck config for trajectory model.
            Only used when ``traj_head`` is given. Defaults to ``None``
        traj_head (dict, optional): The head config for trajectory model.
            Defaults to ``None``
        semi_loss (dict, optional): The semi-supervised loss config. When
            given, at least one of ``head`` / ``traj_head`` must be given.
            Defaults to ``None``
        train_cfg (dict, optional): The runtime config for training process.
            Defaults to ``None``
        test_cfg (dict, optional): The runtime config for testing process.
            Defaults to ``None``
        data_preprocessor (dict, optional): The data preprocessing config to
            build the instance of :class:`BaseDataPreprocessor`. Defaults to
            ``None``
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to ``None``
        metainfo (dict): Meta information for dataset, such as keypoints
            definition and properties. If set, the metainfo of the input data
            batch will be overridden. For more details, please refer to
            https://mmpose.readthedocs.io/en/latest/user_guides/
            prepare_datasets.html#create-a-custom-dataset-info-
            config-file-for-the-dataset. Defaults to ``None``
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 head: OptConfigType = None,
                 traj_backbone: OptConfigType = None,
                 traj_neck: OptConfigType = None,
                 traj_head: OptConfigType = None,
                 semi_loss: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 metainfo: Optional[dict] = None):
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg,
            metainfo=metainfo)

        # trajectory model
        self.share_backbone = False
        if traj_head is not None:
            if traj_backbone is not None:
                self.traj_backbone = MODELS.build(traj_backbone)
            else:
                # No dedicated trajectory backbone: ``extract_feat`` feeds
                # the pose branch's (backbone + neck) features to the
                # trajectory branch instead.
                self.share_backbone = True

            # the PR #2108 and #2126 modified the interface of neck and head.
            # The following function automatically detects outdated
            # configurations and updates them accordingly, while also providing
            # clear and concise information on the changes made.
            traj_neck, traj_head = check_and_update_config(
                traj_neck, traj_head)

            if traj_neck is not None:
                self.traj_neck = MODELS.build(traj_neck)
            self.traj_head = MODELS.build(traj_head)

        # semi-supervised loss
        self.semi_supervised = semi_loss is not None
        if self.semi_supervised:
            assert any([head, traj_head])
            self.semi_loss = MODELS.build(semi_loss)

    @property
    def with_traj_backbone(self):
        """bool: Whether the pose lifter has trajectory backbone."""
        return hasattr(self, 'traj_backbone') and \
            self.traj_backbone is not None

    @property
    def with_traj_neck(self):
        """bool: Whether the pose lifter has trajectory neck."""
        return hasattr(self, 'traj_neck') and self.traj_neck is not None

    @property
    def with_traj(self):
        """bool: Whether the pose lifter has trajectory head."""
        # Same "exists and is not None" check as the sibling properties.
        return hasattr(self, 'traj_head') and self.traj_head is not None

    @property
    def causal(self):
        """bool: Whether the pose lifter is causal.

        The value is read from the backbone's ``causal`` attribute.

        Raises:
            AttributeError: If the backbone has no ``causal`` attribute.
        """
        # NOTE(review): if this class is an ``nn.Module`` (presumably, via
        # BasePoseEstimator), an AttributeError raised from a property may be
        # replaced by ``nn.Module.__getattr__``'s generic message.
        if hasattr(self.backbone, 'causal'):
            return self.backbone.causal
        else:
            raise AttributeError('A PoseLifter\'s backbone should have '
                                 'the bool attribute "causal" to indicate if '
                                 'it performs causal inference.')

    def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        Args:
            inputs (Tensor): Inputs with shape (N, K, C, T).

        Returns:
            tuple[Tensor] | tuple: The pose branch features (after the neck,
            if any). When a trajectory model exists, a 2-tuple
            ``(pose_feats, traj_feats)`` is returned instead.
        """
        # supervised learning
        # pose model
        feats = self.backbone(inputs)
        if self.with_neck:
            feats = self.neck(feats)

        if not self.with_traj:
            return feats

        # trajectory model: either reuse the pose features or run the
        # dedicated trajectory backbone on the raw inputs
        if self.share_backbone:
            traj_x = feats
        else:
            traj_x = self.traj_backbone(inputs)

        if self.with_traj_neck:
            traj_x = self.traj_neck(traj_x)

        return feats, traj_x

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None
                 ) -> Union[Tensor, Tuple[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            inputs (Tensor): Inputs with shape (N, K, C, T).
            data_samples (List[:obj:`PoseDataSample`], optional): Unused.

        Returns:
            Union[Tensor | Tuple[Tensor]]: The pose head output (or the pose
            features if there is no head). With a trajectory model, a tuple
            ``(pose_output, traj_output)``.
        """
        feats = self.extract_feat(inputs)

        if self.with_traj:
            # forward with trajectory model
            x, traj_x = feats
            if self.with_head:
                x = self.head.forward(x)

            traj_x = self.traj_head.forward(traj_x)
            return x, traj_x
        else:
            # forward without trajectory model
            x = feats
            if self.with_head:
                x = self.head.forward(x)
            return x

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): Inputs with shape (N, K, C, T).
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples.

        Returns:
            dict: A dictionary of losses.
        """
        feats = self.extract_feat(inputs)

        losses = {}

        if self.with_traj:
            x, traj_x = feats
            # loss of trajectory model
            losses.update(
                self.traj_head.loss(
                    traj_x, data_samples, train_cfg=self.train_cfg))
        else:
            x = feats

        if self.with_head:
            # loss of pose model
            losses.update(
                self.head.loss(x, data_samples, train_cfg=self.train_cfg))

        # TODO: support semi-supervised learning
        if self.semi_supervised:
            losses.update(semi_loss=self.semi_loss(inputs, data_samples))

        return losses

    def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Note:
            - batch_size: B
            - num_input_keypoints: K
            - input_keypoint_dim: C
            - input_sequence_len: T

        Args:
            inputs (Tensor): Inputs with shape like (B, K, C, T).
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples

        Returns:
            list[:obj:`PoseDataSample`]: The pose estimation results of the
            input images. The return value is `PoseDataSample` instances with
            ``pred_instances`` and ``pred_fields`` (optional) fields, and
            ``pred_instances`` usually contains the following keys:

                - keypoints (Tensor): predicted keypoint coordinates in shape
                    (num_instances, K, D) where K is the keypoint number and D
                    is the keypoint dimension
                - keypoint_scores (Tensor): predicted keypoint scores in shape
                    (num_instances, K)
        """
        assert self.with_head, (
            'The model must have head to perform prediction.')

        if self.test_cfg.get('flip_test', False):
            flip_indices = data_samples[0].metainfo['flip_indices']
            _feats = self.extract_feat(inputs)
            _feats_flip = self.extract_feat(
                torch.stack([
                    flip_coordinates(
                        _input,
                        flip_indices=flip_indices,
                        shift_coords=self.test_cfg.get('shift_coords', True),
                        input_size=(1, 1)) for _input in inputs
                ],
                            dim=0))
            if self.with_traj:
                # With a trajectory model each extract_feat call returns
                # ``(pose_feats, traj_feats)``. Regroup so that each head
                # receives its own ``[orig, flipped]`` pair rather than one
                # head getting both branches of a single pass.
                feats = ([_feats[0], _feats_flip[0]],
                         [_feats[1], _feats_flip[1]])
            else:
                feats = [_feats, _feats_flip]
        else:
            feats = self.extract_feat(inputs)

        pose_preds, batch_pred_instances, batch_pred_fields = None, None, None
        traj_preds, batch_traj_instances, batch_traj_fields = None, None, None
        if self.with_traj:
            x, traj_x = feats
            traj_preds = self.traj_head.predict(
                traj_x, data_samples, test_cfg=self.test_cfg)
        else:
            x = feats

        if self.with_head:
            pose_preds = self.head.predict(
                x, data_samples, test_cfg=self.test_cfg)

        # heads may return either instances only or (instances, fields)
        if isinstance(pose_preds, tuple):
            batch_pred_instances, batch_pred_fields = pose_preds
        else:
            batch_pred_instances = pose_preds

        if isinstance(traj_preds, tuple):
            batch_traj_instances, batch_traj_fields = traj_preds
        else:
            batch_traj_instances = traj_preds

        results = self.add_pred_to_datasample(batch_pred_instances,
                                              batch_pred_fields,
                                              batch_traj_instances,
                                              batch_traj_fields, data_samples)

        return results

    def add_pred_to_datasample(
        self,
        batch_pred_instances: InstanceList,
        batch_pred_fields: Optional[PixelDataList],
        batch_traj_instances: InstanceList,
        batch_traj_fields: Optional[PixelDataList],
        batch_data_samples: SampleList,
    ) -> SampleList:
        """Add predictions into data samples.

        Args:
            batch_pred_instances (List[InstanceData]): The predicted pose
                instances of the input data batch
            batch_pred_fields (List[PixelData], optional): The predicted
                pose fields (e.g. heatmaps) of the input batch
            batch_traj_instances (List[InstanceData], optional): The predicted
                trajectory instances of the input data batch
            batch_traj_fields (List[PixelData], optional): The predicted
                trajectory fields of the input batch
            batch_data_samples (List[PoseDataSample]): The input data batch

        Returns:
            List[PoseDataSample]: A list of data samples where the predictions
            are stored in the ``pred_instances`` field of each data sample.
        """
        assert len(batch_pred_instances) == len(batch_data_samples)
        # Normalize each optional list independently so that zip_longest
        # never receives ``None`` (e.g. pose head returns fields while the
        # trajectory head does not).
        if batch_pred_fields is None:
            batch_pred_fields = []
        if batch_traj_instances is None:
            batch_traj_instances = []
        if batch_traj_fields is None:
            batch_traj_fields = []
        output_keypoint_indices = self.test_cfg.get('output_keypoint_indices',
                                                    None)

        # NOTE(review): trajectory instances/fields are iterated but not
        # stored on the data samples; only pose predictions are written.
        for (pred_instances, pred_fields, traj_instances, traj_fields,
             data_sample) in zip_longest(batch_pred_instances,
                                         batch_pred_fields,
                                         batch_traj_instances,
                                         batch_traj_fields,
                                         batch_data_samples):

            if output_keypoint_indices is not None:
                # select output keypoints with given indices; remember the
                # original keypoint count to match heatmap channels below
                num_keypoints = pred_instances.keypoints.shape[1]
                for key, value in pred_instances.all_items():
                    if key.startswith('keypoint'):
                        pred_instances.set_field(
                            value[:, output_keypoint_indices], key)

            data_sample.pred_instances = pred_instances

            if pred_fields is not None:
                if output_keypoint_indices is not None:
                    # select output heatmap channels with keypoint indices
                    # when the number of heatmap channel matches num_keypoints
                    for key, value in pred_fields.all_items():
                        if value.shape[0] != num_keypoints:
                            continue
                        pred_fields.set_field(value[output_keypoint_indices],
                                              key)
                data_sample.pred_fields = pred_fields

        return batch_data_samples