# Copyright (c) OpenMMLab. All rights reserved.
from itertools import zip_longest
from typing import Tuple, Union

import torch
from torch import Tensor

from mmpose.models.utils import check_and_update_config
from mmpose.models.utils.tta import flip_coordinates
from mmpose.registry import MODELS
from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType,
                                 Optional, OptMultiConfig, OptSampleList,
                                 PixelDataList, SampleList)
from .base import BasePoseEstimator


@MODELS.register_module()
class PoseLifter(BasePoseEstimator):
    """Base class for pose lifter.

    Args:
        backbone (dict): The backbone config
        neck (dict, optional): The neck config. Defaults to ``None``
        head (dict, optional): The head config. Defaults to ``None``
        traj_backbone (dict, optional): The backbone config for trajectory
            model. Defaults to ``None``
        traj_neck (dict, optional): The neck config for trajectory model.
            Defaults to ``None``
        traj_head (dict, optional): The head config for trajectory model.
            Defaults to ``None``
        semi_loss (dict, optional): The semi-supervised loss config.
            Defaults to ``None``
        train_cfg (dict, optional): The runtime config for training process.
            Defaults to ``None``
        test_cfg (dict, optional): The runtime config for testing process.
            Defaults to ``None``
        data_preprocessor (dict, optional): The data preprocessing config to
            build the instance of :class:`BaseDataPreprocessor`. Defaults to
            ``None``
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to ``None``
        metainfo (dict): Meta information for dataset, such as keypoints
            definition and properties. If set, the metainfo of the input data
            batch will be overridden. For more details, please refer to
            https://mmpose.readthedocs.io/en/latest/user_guides/
            prepare_datasets.html#create-a-custom-dataset-info-
            config-file-for-the-dataset. Defaults to ``None``
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 head: OptConfigType = None,
                 traj_backbone: OptConfigType = None,
                 traj_neck: OptConfigType = None,
                 traj_head: OptConfigType = None,
                 semi_loss: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 metainfo: Optional[dict] = None):
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg,
            metainfo=metainfo)

        # trajectory model
        self.share_backbone = False
        if traj_head is not None:
            if traj_backbone is not None:
                self.traj_backbone = MODELS.build(traj_backbone)
            else:
                self.share_backbone = True

            # PRs #2108 and #2126 modified the interface of neck and head.
            # The following function automatically detects outdated
            # configurations and updates them accordingly, while also
            # providing clear and concise information on the changes made.
            traj_neck, traj_head = check_and_update_config(
                traj_neck, traj_head)

            if traj_neck is not None:
                self.traj_neck = MODELS.build(traj_neck)

            self.traj_head = MODELS.build(traj_head)

        # semi-supervised loss
        self.semi_supervised = semi_loss is not None
        if self.semi_supervised:
            assert any([head, traj_head])
            self.semi_loss = MODELS.build(semi_loss)

    @property
    def with_traj_backbone(self):
        """bool: Whether the pose lifter has trajectory backbone."""
        return hasattr(self, 'traj_backbone') and \
            self.traj_backbone is not None

    @property
    def with_traj_neck(self):
        """bool: Whether the pose lifter has trajectory neck."""
        return hasattr(self, 'traj_neck') and self.traj_neck is not None

    @property
    def with_traj(self):
        """bool: Whether the pose lifter has trajectory head."""
        return hasattr(self, 'traj_head')

    @property
    def causal(self):
        """bool: Whether the pose lifter is causal."""
        if hasattr(self.backbone, 'causal'):
            return self.backbone.causal
        else:
            raise AttributeError('A PoseLifter\'s backbone should have '
                                 'the bool attribute "causal" to indicate '
                                 'if it performs causal inference.')

    def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        Args:
            inputs (Tensor): Input keypoint coordinates with shape
                (N, K, C, T).

        Returns:
            tuple[Tensor]: Multi-level features that may have various
            resolutions.
        """
        # supervised learning
        # pose model
        feats = self.backbone(inputs)
        if self.with_neck:
            feats = self.neck(feats)

        # trajectory model
        if self.with_traj:
            if self.share_backbone:
                traj_x = feats
            else:
                traj_x = self.traj_backbone(inputs)
            if self.with_traj_neck:
                traj_x = self.traj_neck(traj_x)
            return feats, traj_x
        else:
            return feats
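
    # Illustrative shape note (not part of the original file): for a batch of
    # 4 sequences of 17 two-dimensional keypoints over 27 frames, the input to
    # ``extract_feat`` would be a tensor such as
    #
    #   inputs = torch.randn(4, 17, 2, 27)   # (N, K, C, T)
    #
    # and the method returns either ``feats`` or ``(feats, traj_x)`` depending
    # on whether a trajectory head is configured.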

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None
                 ) -> Union[Tensor, Tuple[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            inputs (Tensor): Inputs with shape (N, K, C, T).

        Returns:
            Union[Tensor | Tuple[Tensor]]: forward output of the network.
        """
        feats = self.extract_feat(inputs)

        if self.with_traj:
            # forward with trajectory model
            x, traj_x = feats
            if self.with_head:
                x = self.head.forward(x)

            traj_x = self.traj_head.forward(traj_x)
            return x, traj_x
        else:
            # forward without trajectory model
            x = feats
            if self.with_head:
                x = self.head.forward(x)
            return x

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): Inputs with shape (N, K, C, T).
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples.

        Returns:
            dict: A dictionary of losses.
        """
        feats = self.extract_feat(inputs)

        losses = {}

        if self.with_traj:
            x, traj_x = feats
            # loss of trajectory model
            losses.update(
                self.traj_head.loss(
                    traj_x, data_samples, train_cfg=self.train_cfg))
        else:
            x = feats

        if self.with_head:
            # loss of pose model
            losses.update(
                self.head.loss(x, data_samples, train_cfg=self.train_cfg))

        # TODO: support semi-supervised learning
        if self.semi_supervised:
            losses.update(semi_loss=self.semi_loss(inputs, data_samples))

        return losses

    def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Note:
            - batch_size: B
            - num_input_keypoints: K
            - input_keypoint_dim: C
            - input_sequence_len: T

        Args:
            inputs (Tensor): Inputs with shape (B, K, C, T).
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples.

        Returns:
            list[:obj:`PoseDataSample`]: The pose estimation results of the
            input keypoint sequences. The return value is `PoseDataSample`
            instances with ``pred_instances`` and ``pred_fields`` (optional)
            fields, and ``pred_instances`` usually contains the following keys:

                - keypoints (Tensor): predicted keypoint coordinates in shape
                  (num_instances, K, D) where K is the keypoint number and D
                  is the keypoint dimension
                - keypoint_scores (Tensor): predicted keypoint scores in shape
                  (num_instances, K)
        """
        assert self.with_head, (
            'The model must have head to perform prediction.')

        if self.test_cfg.get('flip_test', False):
            flip_indices = data_samples[0].metainfo['flip_indices']
            _feats = self.extract_feat(inputs)
            _feats_flip = self.extract_feat(
                torch.stack([
                    flip_coordinates(
                        _input,
                        flip_indices=flip_indices,
                        shift_coords=self.test_cfg.get('shift_coords', True),
                        input_size=(1, 1)) for _input in inputs
                ], dim=0))
            feats = [_feats, _feats_flip]
        else:
            feats = self.extract_feat(inputs)

        pose_preds, batch_pred_instances, batch_pred_fields = None, None, None
        traj_preds, batch_traj_instances, batch_traj_fields = None, None, None
        if self.with_traj:
            x, traj_x = feats
            traj_preds = self.traj_head.predict(
                traj_x, data_samples, test_cfg=self.test_cfg)
        else:
            x = feats

        if self.with_head:
            pose_preds = self.head.predict(
                x, data_samples, test_cfg=self.test_cfg)

        if isinstance(pose_preds, tuple):
            batch_pred_instances, batch_pred_fields = pose_preds
        else:
            batch_pred_instances = pose_preds

        if isinstance(traj_preds, tuple):
            batch_traj_instances, batch_traj_fields = traj_preds
        else:
            batch_traj_instances = traj_preds

        results = self.add_pred_to_datasample(batch_pred_instances,
                                              batch_pred_fields,
                                              batch_traj_instances,
                                              batch_traj_fields, data_samples)

        return results
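
    # Illustrative sketch (not part of the original file): the test-time
    # options read above come from ``test_cfg``, e.g. a config such as
    #
    #   test_cfg = dict(
    #       flip_test=True,        # also run a flipped copy of the input
    #       shift_coords=True,     # shift coordinates when flipping
    #       output_keypoint_indices=[0, 1, 2],  # keep only selected keypoints
    #   )
    #
    # Keys that are absent fall back to the defaults used by ``predict`` and
    # ``add_pred_to_datasample``; the index list above is only an example.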

    def add_pred_to_datasample(
        self,
        batch_pred_instances: InstanceList,
        batch_pred_fields: Optional[PixelDataList],
        batch_traj_instances: InstanceList,
        batch_traj_fields: Optional[PixelDataList],
        batch_data_samples: SampleList,
    ) -> SampleList:
        """Add predictions into data samples.

        Args:
            batch_pred_instances (List[InstanceData]): The predicted instances
                of the input data batch
            batch_pred_fields (List[PixelData], optional): The predicted
                fields (e.g. heatmaps) of the input batch
            batch_traj_instances (List[InstanceData]): The predicted instances
                of the input data batch
            batch_traj_fields (List[PixelData], optional): The predicted
                fields (e.g. heatmaps) of the input batch
            batch_data_samples (List[PoseDataSample]): The input data batch

        Returns:
            List[PoseDataSample]: A list of data samples where the predictions
            are stored in the ``pred_instances`` field of each data sample.
        """
        assert len(batch_pred_instances) == len(batch_data_samples)
        if batch_pred_fields is None:
            batch_pred_fields, batch_traj_fields = [], []
        if batch_traj_instances is None:
            batch_traj_instances = []
        output_keypoint_indices = self.test_cfg.get('output_keypoint_indices',
                                                    None)

        for (pred_instances, pred_fields, traj_instances, traj_fields,
             data_sample) in zip_longest(batch_pred_instances,
                                         batch_pred_fields,
                                         batch_traj_instances,
                                         batch_traj_fields,
                                         batch_data_samples):

            if output_keypoint_indices is not None:
                # select output keypoints with given indices
                num_keypoints = pred_instances.keypoints.shape[1]
                for key, value in pred_instances.all_items():
                    if key.startswith('keypoint'):
                        pred_instances.set_field(
                            value[:, output_keypoint_indices], key)

            data_sample.pred_instances = pred_instances

            if pred_fields is not None:
                if output_keypoint_indices is not None:
                    # select output heatmap channels with keypoint indices
                    # when the number of heatmap channel matches num_keypoints
                    for key, value in pred_fields.all_items():
                        if value.shape[0] != num_keypoints:
                            continue
                        pred_fields.set_field(value[output_keypoint_indices],
                                              key)
                data_sample.pred_fields = pred_fields

        return batch_data_samples
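
# Illustrative usage sketch (not part of the original file). The config and
# data below are placeholders; only the call pattern mirrors the methods
# defined above, and building a real model requires a concrete mmpose
# pose-lifting config plus a list of PoseDataSample with proper metainfo.
#
#   from mmpose.registry import MODELS
#
#   model = MODELS.build(model_cfg)              # model_cfg: PoseLifter config
#   inputs = torch.randn(1, 17, 2, 27)           # (B, K, C, T)
#   losses = model.loss(inputs, data_samples)    # training: dict of loss terms
#   preds = model.predict(inputs, data_samples)  # inference: list[PoseDataSample]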