# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Tuple, Union
import torch
from mmcv.cnn import build_conv_layer, build_upsample_layer
from mmengine.structures import PixelData
from torch import Tensor, nn
from mmpose.evaluation.functional import pose_pck_accuracy
from mmpose.models.utils.tta import flip_heatmaps
from mmpose.registry import KEYPOINT_CODECS, MODELS
from mmpose.utils.tensor_utils import to_numpy
from mmpose.utils.typing import (ConfigType, Features, OptConfigType,
OptSampleList, Predictions)
from ..base_head import BaseHead
import numpy as np
from sparsemax import Sparsemax
import os
import shutil
import cv2
from mmpose.structures.keypoint import fix_bbox_aspect_ratio
OptIntSeq = Optional[Sequence[int]]
@MODELS.register_module()
class CalibrationHead(BaseHead):
"""Multi-variate head predicting all information about keypoints. Apart
from the heatmap, it also predicts:
1) Heatmap for each keypoint
2) Probability of keypoint being in the heatmap
3) Visibility of each keypoint
4) Predicted OKS per keypoint
5) Predictd euclidean error per keypoint
The heatmap predicting part is the same as HeatmapHead introduced in
in `Simple Baselines`_ by Xiao et al (2018).
Args:
in_channels (int | Sequence[int]): Number of channels in the input
feature map
out_channels (int): Number of channels in the output heatmap
deconv_out_channels (Sequence[int], optional): The output channel
number of each deconv layer. Defaults to ``(256, 256, 256)``
deconv_kernel_sizes (Sequence[int | tuple], optional): The kernel size
of each deconv layer. Each element should be either an integer for
both height and width dimensions, or a tuple of two integers for
            the height and the width dimension respectively. Defaults to
``(4, 4, 4)``
conv_out_channels (Sequence[int], optional): The output channel number
of each intermediate conv layer. ``None`` means no intermediate
conv layer between deconv layers and the final conv layer.
Defaults to ``None``
conv_kernel_sizes (Sequence[int | tuple], optional): The kernel size
of each intermediate conv layer. Defaults to ``None``
final_layer_dict (dict): Arguments of the final Conv2d layer.
Defaults to ``dict(kernel_size=1)``
keypoint_loss (Config): Config of the keypoint loss. Defaults to use
:class:`KeypointMSELoss`
probability_loss (Config): Config of the probability loss. Defaults to use
:class:`BCELoss`
visibility_loss (Config): Config of the visibility loss. Defaults to use
:class:`BCELoss`
oks_loss (Config): Config of the oks loss. Defaults to use
:class:`MSELoss`
error_loss (Config): Config of the error loss. Defaults to use
:class:`L1LogLoss`
        normalize (float, optional): If set, heatmap values are normalized
            with Sparsemax over the flattened spatial dimensions and rescaled
            by this factor; ``None`` disables the normalization. Defaults to
            ``None``
detach_probability (bool): Whether to detach the probability
from gradient computation. Defaults to ``True``
detach_visibility (bool): Whether to detach the visibility
from gradient computation. Defaults to ``True``
learn_heatmaps_from_zeros (bool): Whether to learn the
heatmaps from zeros. Defaults to ``False``
freeze_heatmaps (bool): Whether to freeze the heatmaps prediction.
Defaults to ``False``
freeze_probability (bool): Whether to freeze the probability prediction.
Defaults to ``False``
freeze_visibility (bool): Whether to freeze the visibility prediction.
Defaults to ``False``
freeze_oks (bool): Whether to freeze the oks prediction.
Defaults to ``False``
freeze_error (bool): Whether to freeze the error prediction.
Defaults to ``False``
decoder (Config, optional): The decoder config that controls decoding
keypoint coordinates from the network output. Defaults to ``None``
init_cfg (Config, optional): Config to control the initialization. See
:attr:`default_init_cfg` for default settings
.. _`Simple Baselines`: https://arxiv.org/abs/1804.06208
"""
_version = 2
def __init__(self,
in_channels: Union[int, Sequence[int]],
out_channels: int,
deconv_out_channels: OptIntSeq = (256, 256, 256),
deconv_kernel_sizes: OptIntSeq = (4, 4, 4),
conv_out_channels: OptIntSeq = None,
conv_kernel_sizes: OptIntSeq = None,
final_layer_dict: dict = dict(kernel_size=1),
keypoint_loss: ConfigType = dict(
type='KeypointMSELoss', use_target_weight=True),
probability_loss: ConfigType = dict(
type='BCELoss', use_target_weight=True),
visibility_loss: ConfigType = dict(
type='BCELoss', use_target_weight=True),
oks_loss: ConfigType = dict(
type='MSELoss', use_target_weight=True),
error_loss: ConfigType = dict(
type='L1LogLoss', use_target_weight=True),
normalize: float = None,
detach_probability: bool = True,
detach_visibility: bool = True,
learn_heatmaps_from_zeros: bool = False,
freeze_heatmaps: bool = False,
freeze_probability: bool = False,
freeze_visibility: bool = False,
freeze_oks: bool = False,
freeze_error: bool = False,
decoder: OptConfigType = dict(
type='UDPHeatmap', input_size=(192, 256),
heatmap_size=(48, 64), sigma=2),
init_cfg: OptConfigType = None,
):
if init_cfg is None:
init_cfg = self.default_init_cfg
super().__init__(init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.keypoint_loss_module = MODELS.build(keypoint_loss)
self.probability_loss_module = MODELS.build(probability_loss)
self.visibility_loss_module = MODELS.build(visibility_loss)
self.oks_loss_module = MODELS.build(oks_loss)
self.error_loss_module = MODELS.build(error_loss)
self.temperature = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)
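        # Pre-compute a separable Gaussian kernel (outer product of a
        # normalized 1-D Gaussian with itself). It is currently only used by
        # the commented-out heatmap blurring in `forward_heatmap`.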
self.gauss_sigma = 2.0
self.gauss_kernel_size = int(2.0 * 3.0 * self.gauss_sigma + 1.0)
ts = torch.linspace(
- self.gauss_kernel_size // 2,
self.gauss_kernel_size // 2,
self.gauss_kernel_size
)
gauss = torch.exp(-(ts / self.gauss_sigma)**2 / 2)
gauss = gauss / gauss.sum()
self.gauss_kernel = gauss.unsqueeze(0) * gauss.unsqueeze(1)
self.decoder = KEYPOINT_CODECS.build(decoder)
self.nonlinearity = nn.ReLU(inplace=True)
self.learn_heatmaps_from_zeros = learn_heatmaps_from_zeros
self.num_iters = 0
unique_hash = np.random.randint(0, 100000)
self.loss_vis_folder = "work_dirs/loss_vis_{:05d}".format(unique_hash)
self.interval = 50
shutil.rmtree(self.loss_vis_folder, ignore_errors=True)
print("Will save heatmap visualizations to folder '{:s}'".format(self.loss_vis_folder))
self._build_heatmap_head(
in_channels=in_channels,
out_channels=out_channels,
deconv_out_channels=deconv_out_channels,
deconv_kernel_sizes=deconv_kernel_sizes,
conv_out_channels=conv_out_channels,
conv_kernel_sizes=conv_kernel_sizes,
final_layer_dict=final_layer_dict,
normalize=normalize,
freeze=freeze_heatmaps)
self.normalize = normalize
self.detach_probability = detach_probability
self._build_probability_head(
in_channels=in_channels,
out_channels=out_channels,
freeze=freeze_probability)
self.detach_visibility = detach_visibility
self._build_visibility_head(
in_channels=in_channels,
out_channels=out_channels,
freeze=freeze_visibility)
self._build_oks_head(
in_channels=in_channels,
out_channels=out_channels,
freeze=freeze_oks)
self.freeze_oks = freeze_oks
self._build_error_head(
in_channels=in_channels,
out_channels=out_channels,
freeze=freeze_error)
self.freeze_error = freeze_error
# Register the hook to automatically convert old version state dicts
self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
self._freeze_all_but_temperature()
        # Print all params and whether they require gradients
print("\n", "="*20)
for name, param in self.named_parameters():
print(name, param.requires_grad)
def _freeze_all_but_temperature(self):
for param in self.parameters():
param.requires_grad = False
self.temperature.requires_grad = True
def _build_heatmap_head(self, in_channels: int, out_channels: int,
deconv_out_channels: Sequence[int],
deconv_kernel_sizes: Sequence[int],
conv_out_channels: Sequence[int],
conv_kernel_sizes: Sequence[int],
final_layer_dict: dict,
                            normalize: float = None,
freeze: bool = False) -> nn.Module:
"""Build the heatmap head module."""
if deconv_out_channels:
if deconv_kernel_sizes is None or len(deconv_out_channels) != len(
deconv_kernel_sizes):
raise ValueError(
'"deconv_out_channels" and "deconv_kernel_sizes" should '
'be integer sequences with the same length. Got '
f'mismatched lengths {deconv_out_channels} and '
f'{deconv_kernel_sizes}')
self.deconv_layers = self._make_deconv_layers(
in_channels=in_channels,
layer_out_channels=deconv_out_channels,
layer_kernel_sizes=deconv_kernel_sizes,
)
in_channels = deconv_out_channels[-1]
else:
self.deconv_layers = nn.Identity()
if conv_out_channels:
if conv_kernel_sizes is None or len(conv_out_channels) != len(
conv_kernel_sizes):
raise ValueError(
'"conv_out_channels" and "conv_kernel_sizes" should '
'be integer sequences with the same length. Got '
f'mismatched lengths {conv_out_channels} and '
f'{conv_kernel_sizes}')
self.conv_layers = self._make_conv_layers(
in_channels=in_channels,
layer_out_channels=conv_out_channels,
layer_kernel_sizes=conv_kernel_sizes)
in_channels = conv_out_channels[-1]
else:
self.conv_layers = nn.Identity()
if final_layer_dict is not None:
cfg = dict(
type='Conv2d',
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1)
cfg.update(final_layer_dict)
self.final_layer = build_conv_layer(cfg)
else:
self.final_layer = nn.Identity()
# self.normalize_layer = lambda x: x / x.sum(dim=-1, keepdim=True) if normalize else nn.Identity()
# self.normalize_layer = nn.Softmax(dim=-1) if normalize else nn.Identity()
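        # Sparsemax over the flattened spatial dimension yields a sparse
        # per-keypoint distribution that sums to 1; `forward_heatmap` later
        # rescales it by `self.normalize` and clamps it to [0, 1]. With
        # `normalize=None` the layer is a no-op.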
self.normalize_layer = nn.Identity() if normalize is None else Sparsemax(dim=-1)
if freeze:
for param in self.deconv_layers.parameters():
param.requires_grad = False
for param in self.conv_layers.parameters():
param.requires_grad = False
for param in self.final_layer.parameters():
param.requires_grad = False
def _build_probability_head(self, in_channels: int, out_channels: int,
freeze: bool = False) -> nn.Module:
"""Build the probability head module."""
ppb_layers = []
kernel_sizes = [(4, 3), (2, 2), (2, 2)]
for i in range(len(kernel_sizes)):
ppb_layers.append(
build_conv_layer(
dict(type='Conv2d'),
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=1,
padding=1))
ppb_layers.append(
nn.BatchNorm2d(num_features=in_channels))
ppb_layers.append(
nn.MaxPool2d(kernel_size=kernel_sizes[i], stride=kernel_sizes[i], padding=0))
ppb_layers.append(self.nonlinearity)
ppb_layers.append(
build_conv_layer(
dict(type='Conv2d'),
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0))
ppb_layers.append(nn.Sigmoid())
self.probability_layers = nn.Sequential(*ppb_layers)
if freeze:
for param in self.probability_layers.parameters():
param.requires_grad = False
def _build_visibility_head(self, in_channels: int, out_channels: int,
freeze: bool = False) -> nn.Module:
"""Build the visibility head module."""
vis_layers = []
kernel_sizes = [(4, 3), (2, 2), (2, 2)]
for i in range(len(kernel_sizes)):
vis_layers.append(
build_conv_layer(
dict(type='Conv2d'),
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=1,
padding=1))
vis_layers.append(
nn.BatchNorm2d(num_features=in_channels))
vis_layers.append(
nn.MaxPool2d(kernel_size=kernel_sizes[i], stride=kernel_sizes[i], padding=0))
vis_layers.append(self.nonlinearity)
vis_layers.append(
build_conv_layer(
dict(type='Conv2d'),
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0))
vis_layers.append(nn.Sigmoid())
self.visibility_layers = nn.Sequential(*vis_layers)
if freeze:
for param in self.visibility_layers.parameters():
param.requires_grad = False
def _build_oks_head(self, in_channels: int, out_channels: int,
freeze: bool = False) -> nn.Module:
"""Build the oks head module."""
oks_layers = []
kernel_sizes = [(4, 3), (2, 2), (2, 2)]
for i in range(len(kernel_sizes)):
oks_layers.append(
build_conv_layer(
dict(type='Conv2d'),
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=1,
padding=1))
oks_layers.append(
nn.BatchNorm2d(num_features=in_channels))
oks_layers.append(
nn.MaxPool2d(kernel_size=kernel_sizes[i], stride=kernel_sizes[i], padding=0))
oks_layers.append(self.nonlinearity)
oks_layers.append(
build_conv_layer(
dict(type='Conv2d'),
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0))
oks_layers.append(nn.Sigmoid())
self.oks_layers = nn.Sequential(*oks_layers)
if freeze:
for param in self.oks_layers.parameters():
param.requires_grad = False
def _build_error_head(self, in_channels: int, out_channels: int,
freeze: bool = False) -> nn.Module:
"""Build the error head module."""
error_layers = []
kernel_sizes = [(4, 3), (2, 2), (2, 2)]
for i in range(len(kernel_sizes)):
error_layers.append(
build_conv_layer(
dict(type='Conv2d'),
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=1,
padding=1))
error_layers.append(
nn.BatchNorm2d(num_features=in_channels))
error_layers.append(
nn.MaxPool2d(kernel_size=kernel_sizes[i], stride=kernel_sizes[i], padding=0))
error_layers.append(self.nonlinearity)
error_layers.append(
build_conv_layer(
dict(type='Conv2d'),
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0))
error_layers.append(self.nonlinearity)
self.error_layers = nn.Sequential(*error_layers)
if freeze:
for param in self.error_layers.parameters():
param.requires_grad = False
def _make_conv_layers(self, in_channels: int,
layer_out_channels: Sequence[int],
layer_kernel_sizes: Sequence[int]) -> nn.Module:
"""Create convolutional layers by given parameters."""
layers = []
for out_channels, kernel_size in zip(layer_out_channels,
layer_kernel_sizes):
padding = (kernel_size - 1) // 2
cfg = dict(
type='Conv2d',
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=1,
padding=padding)
layers.append(build_conv_layer(cfg))
layers.append(nn.BatchNorm2d(num_features=out_channels))
layers.append(self.nonlinearity)
in_channels = out_channels
return nn.Sequential(*layers)
def _make_deconv_layers(self, in_channels: int,
layer_out_channels: Sequence[int],
layer_kernel_sizes: Sequence[int]) -> nn.Module:
"""Create deconvolutional layers by given parameters."""
layers = []
for out_channels, kernel_size in zip(layer_out_channels,
layer_kernel_sizes):
if kernel_size == 4:
padding = 1
output_padding = 0
elif kernel_size == 3:
padding = 1
output_padding = 1
elif kernel_size == 2:
padding = 0
output_padding = 0
else:
                raise ValueError(f'Unsupported kernel size {kernel_size} for '
                                 'deconvolutional layers in '
                                 f'{self.__class__.__name__}')
cfg = dict(
type='deconv',
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=2,
padding=padding,
output_padding=output_padding,
bias=False)
layers.append(build_upsample_layer(cfg))
layers.append(nn.BatchNorm2d(num_features=out_channels))
layers.append(self.nonlinearity)
in_channels = out_channels
return nn.Sequential(*layers)
def _error_from_heatmaps(self, gt_heatmaps: Tensor, dt_heatmaps: Tensor) -> Tensor:
"""Calculate the error from heatmaps.
Args:
heatmaps (Tensor): The predicted heatmaps.
Returns:
Tensor: The predicted error.
"""
# Transform to numpy
gt_heatmaps = to_numpy(gt_heatmaps)
dt_heatmaps = to_numpy(dt_heatmaps)
# Get locations from heatmaps
B, C, H, W = gt_heatmaps.shape
gt_coords = np.zeros((B, C, 2))
dt_coords = np.zeros((B, C, 2))
for i, (gt_htm, dt_htm) in enumerate(zip(gt_heatmaps, dt_heatmaps)):
coords, score = self.decoder.decode(gt_htm)
coords = coords.squeeze()
gt_coords[i, :, :] = coords
coords, score = self.decoder.decode(dt_htm)
coords = coords.squeeze()
dt_coords[i, :, :] = coords
# NaN coordinates mean empty heatmaps -> set them to -1
# as the error will be ignored by weight
gt_coords[np.isnan(gt_coords)] = -1
# Calculate the error
target_errors = np.linalg.norm(gt_coords - dt_coords, axis=2)
assert (target_errors >= 0).all(), "Euclidean distance cannot be negative"
return target_errors
def _oks_from_heatmaps(self, gt_heatmaps: Tensor, dt_heatmaps: Tensor, weight: Tensor) -> Tensor:
"""Calculate the OKS from heatmaps.
Args:
heatmaps (Tensor): The predicted heatmaps.
Returns:
Tensor: The predicted OKS.
"""
C = dt_heatmaps.shape[1]
# Transform to numpy
gt_heatmaps = to_numpy(gt_heatmaps)
dt_heatmaps = to_numpy(dt_heatmaps)
B, C, H, W = gt_heatmaps.shape
weight = to_numpy(weight).squeeze().reshape((B, C, 1))
# Get locations from heatmaps
gt_coords = np.zeros((B, C, 2))
dt_coords = np.zeros((B, C, 2))
for i, (gt_htm, dt_htm) in enumerate(zip(gt_heatmaps, dt_heatmaps)):
coords, score = self.decoder.decode(gt_htm)
coords = coords.squeeze()
gt_coords[i, :, :] = coords
coords, score = self.decoder.decode(dt_htm)
coords = coords.squeeze()
dt_coords[i, :, :] = coords
# NaN coordinates mean empty heatmaps -> set them to 0
gt_coords[np.isnan(gt_coords)] = 0
# Add probability as visibility
gt_coords = gt_coords * weight
dt_coords = dt_coords * weight
gt_coords = np.concatenate((gt_coords, weight*2), axis=2)
dt_coords = np.concatenate((dt_coords, weight*2), axis=2)
# Calculate the oks
target_oks = []
oks_weights = []
for i in range(len(gt_coords)):
gt_kpts = gt_coords[i]
dt_kpts = dt_coords[i]
valid_gt_kpts = gt_kpts[:, 2] > 0
if not valid_gt_kpts.any():
# Changed for per-keypoint OKS
target_oks.append(np.zeros(C))
oks_weights.append(0)
continue
gt_bbox = np.array([
0, 0,
64, 48,
])
gt = {
'keypoints': gt_kpts,
'bbox': gt_bbox,
'area': gt_bbox[2] * gt_bbox[3],
}
dt = {
'keypoints': dt_kpts,
'bbox': gt_bbox,
'area': gt_bbox[2] * gt_bbox[3],
}
# Changed for per-keypoint OKS
oks = compute_oks(gt, dt, use_area=False, per_kpt=True)
target_oks.append(oks)
oks_weights.append(1)
target_oks = np.array(target_oks)
target_oks = torch.from_numpy(target_oks).float()
oks_weights = np.array(oks_weights)
oks_weights = torch.from_numpy(oks_weights).float()
return target_oks, oks_weights
@property
def default_init_cfg(self):
init_cfg = [
dict(
type='Normal', layer=['Conv2d', 'ConvTranspose2d'], std=0.001),
dict(type='Constant', layer='BatchNorm2d', val=1)
]
return init_cfg
def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Forward the network. The input is multi scale feature maps and the
output is (1) the heatmap, (2) probability, (3) visibility, (4) oks and (5) error.
Args:
feats (Tensor): Multi scale feature maps.
Returns:
Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: outputs.
"""
x = feats[-1].detach()
heatmaps = self.forward_heatmap(x)
probabilities = self.forward_probability(x)
visibilities = self.forward_visibility(x)
oks = self.forward_oks(x)
errors = self.forward_error(x)
return heatmaps, probabilities, visibilities, oks, errors
def forward_heatmap(self, x: Tensor) -> Tensor:
"""Forward the network. The input is multi scale feature maps and the
output is the heatmap.
Args:
x (Tensor): Multi scale feature maps.
Returns:
Tensor: output heatmap.
"""
x = self.deconv_layers(x)
x = self.conv_layers(x)
x = self.final_layer(x)
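        # Flatten the spatial dims, apply temperature scaling and the
        # (optional) Sparsemax normalization, rescale by `self.normalize`
        # when set, clamp to [0, 1], then restore the (B, C, H, W) shape.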
B, C, H, W = x.shape
x = x.reshape((B, C, H*W))
x = self.normalize_layer(x/self.temperature)
if self.normalize is not None:
x = x * self.normalize
x = torch.clamp(x, 0, 1)
x = x.reshape((B, C, H, W))
# # Blur the heatmaps with Gaussian
# x = x.reshape((B*C, 1, H, W))
# x = nn.functional.conv2d(x, self.gauss_kernel[None, None, :, :].to(x.device), padding='same')
# x = x.reshape((B, C, H, W))
return x
def forward_probability(self, x: Tensor) -> Tensor:
"""Forward the network. The input is multi scale feature maps and the
output is the probability.
Args:
x (Tensor): Multi scale feature maps.
Returns:
Tensor: output probability.
"""
if self.detach_probability:
x = x.detach()
x = self.probability_layers(x)
return x
def forward_visibility(self, x: Tensor) -> Tensor:
"""Forward the network. The input is multi scale feature maps and the
output is the visibility.
Args:
x (Tensor): Multi scale feature maps.
Returns:
Tensor: output visibility.
"""
if self.detach_visibility:
x = x.detach()
x = self.visibility_layers(x)
return x
def forward_oks(self, x: Tensor) -> Tensor:
"""Forward the network. The input is multi scale feature maps and the
output is the oks.
Args:
x (Tensor): Multi scale feature maps.
Returns:
Tensor: output oks.
"""
x = x.detach()
x = self.oks_layers(x)
return x
def forward_error(self, x: Tensor) -> Tensor:
"""Forward the network. The input is multi scale feature maps and the
output is the euclidean error.
Args:
x (Tensor): Multi scale feature maps.
Returns:
Tensor: output error.
"""
x = x.detach()
x = self.error_layers(x)
return x
def predict(self,
feats: Features,
batch_data_samples: OptSampleList,
test_cfg: ConfigType = {}) -> Predictions:
"""Predict results from features.
Args:
feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
features (or multiple multi-stage features in TTA)
batch_data_samples (List[:obj:`PoseDataSample`]): The batch
data samples
test_cfg (dict): The runtime config for testing process. Defaults
to {}
Returns:
Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
``test_cfg['output_heatmap']==True``, return both pose and heatmap
prediction; otherwise only return the pose prediction.
The pose prediction is a list of ``InstanceData``, each contains
the following fields:
- keypoints (np.ndarray): predicted keypoint coordinates in
shape (num_instances, K, D) where K is the keypoint number
and D is the keypoint dimension
- keypoint_scores (np.ndarray): predicted keypoint scores in
shape (num_instances, K)
The heatmap prediction is a list of ``PixelData``, each contains
the following fields:
- heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
"""
if test_cfg.get('flip_test', False):
# TTA: flip test -> feats = [orig, flipped]
assert isinstance(feats, list) and len(feats) == 2
flip_indices = batch_data_samples[0].metainfo['flip_indices']
_feats, _feats_flip = feats
_htm, _prob, _vis, _oks, _err = self.forward(_feats)
_htm_flip, _prob_flip, _vis_flip, _oks_flip, _err_flip = self.forward(_feats_flip)
B, C, H, W = _htm.shape
# Flip back the keypoints
_htm_flip = flip_heatmaps(
_htm_flip,
flip_mode=test_cfg.get('flip_mode', 'heatmap'),
flip_indices=flip_indices,
shift_heatmap=test_cfg.get('shift_heatmap', False))
heatmaps = (_htm + _htm_flip) * 0.5
# Flip back scalars
_prob_flip = _prob_flip[:, flip_indices]
_vis_flip = _vis_flip[:, flip_indices]
_oks_flip = _oks_flip[:, flip_indices]
_err_flip = _err_flip[:, flip_indices]
probabilities = (_prob + _prob_flip) * 0.5
visibilities = (_vis + _vis_flip) * 0.5
oks = (_oks + _oks_flip) * 0.5
errors = (_err + _err_flip) * 0.5
else:
heatmaps, probabilities, visibilities, oks, errors = self.forward(feats)
B, C, H, W = heatmaps.shape
preds = self.decode(heatmaps)
probabilities = to_numpy(probabilities).reshape((B, 1, C))
visibilities = to_numpy(visibilities).reshape((B, 1, C))
oks = to_numpy(oks).reshape((B, 1, C))
errors = to_numpy(errors).reshape((B, 1, C))
# Normalize errors by dividing with the diagonal of the heatmap
htm_diagonal = np.sqrt(H**2 + W**2)
errors = errors / htm_diagonal
for pi, p in enumerate(preds):
p.set_field(p['keypoint_scores'], "keypoints_conf")
p.set_field(probabilities[pi], "keypoints_probs")
p.set_field(visibilities[pi], "keypoints_visible")
p.set_field(oks[pi], "keypoints_oks")
p.set_field(errors[pi], "keypoints_error")
# Replace the keypoint scores with OKS/errors
if not self.freeze_oks:
p.set_field(oks[pi], "keypoint_scores")
# p.set_field(1-errors[pi], "keypoint_scores")
# hm = heatmaps.detach().cpu().numpy()
# print("Heatmaps:", hm.shape, hm.min(), hm.max())
if test_cfg.get('output_heatmaps', False):
pred_fields = [
PixelData(heatmaps=hm) for hm in heatmaps.detach()
]
return preds, pred_fields
else:
return preds
def loss(self,
feats: Tuple[Tensor],
batch_data_samples: OptSampleList,
train_cfg: ConfigType = {}) -> dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
feats (Tuple[Tensor]): The multi-stage features
batch_data_samples (List[:obj:`PoseDataSample`]): The batch
data samples
train_cfg (dict): The runtime config for training process.
Defaults to {}
Returns:
dict: A dictionary of losses.
"""
dt_heatmaps, dt_probs, dt_vis, dt_oks, dt_errs = self.forward(feats)
        device = dt_heatmaps.device
B, C, H, W = dt_heatmaps.shape
# Extract GT data
gt_heatmaps = torch.stack(
[d.gt_fields.heatmaps for d in batch_data_samples])
gt_probs = np.stack(
[d.gt_instances.in_image.astype(int) for d in batch_data_samples])
gt_annotated = np.stack(
[d.gt_instances.keypoints_visible.astype(int) for d in batch_data_samples])
gt_vis = np.stack(
[d.gt_instances.keypoints_visibility.astype(int) for d in batch_data_samples])
keypoint_weights = torch.cat([
d.gt_instance_labels.keypoint_weights for d in batch_data_samples
])
# Compute GT errors and OKS
if self.freeze_error:
gt_errs = torch.zeros((B, C, 1), device=device, dtype=dt_errs.dtype)
else:
gt_errs = self._error_from_heatmaps(gt_heatmaps, dt_heatmaps)
if self.freeze_oks:
gt_oks = torch.zeros((B, C, 1), device=device, dtype=dt_oks.dtype)
oks_weight = torch.zeros((B, C, 1), device=device, dtype=dt_oks.dtype)
else:
gt_oks, oks_weight = self._oks_from_heatmaps(
gt_heatmaps,
dt_heatmaps,
gt_probs & gt_annotated,
)
# Convert everything to tensors
gt_probs = torch.tensor(gt_probs, device=device, dtype=dt_probs.dtype)
gt_vis = torch.tensor(gt_vis, device=device, dtype=dt_vis.dtype)
gt_annotated = torch.tensor(gt_annotated, device=device)
gt_oks = gt_oks.to(device).to(dt_oks.dtype)
oks_weight = oks_weight.to(device).to(dt_oks.dtype)
gt_errs = gt_errs.to(device).to(dt_errs.dtype)
# Reshape everything to comparable shapes
gt_heatmaps = gt_heatmaps.view((B, C, H, W))
dt_heatmaps = dt_heatmaps.view((B, C, H, W))
gt_probs = gt_probs.view((B, C))
dt_probs = dt_probs.view((B, C))
gt_vis = gt_vis.view((B, C))
dt_vis = dt_vis.view((B, C))
gt_oks = gt_oks.view((B, C))
dt_oks = dt_oks.view((B, C))
gt_errs = gt_errs.view((B, C))
dt_errs = dt_errs.view((B, C))
keypoint_weights = keypoint_weights.view((B, C))
gt_annotated = gt_annotated.view((B, C))
# oks_weight = oks_weight.view((B, C))
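        # Keypoints that are both annotated and (per GT) inside the image;
        # used as the weight mask for the heatmap loss and the metrics below.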
annotated_in = gt_annotated & (gt_probs > 0.5)
# calculate losses
losses = dict()
if self.learn_heatmaps_from_zeros:
heatmap_weights = gt_annotated
else:
# heatmap_weights = keypoint_weights
heatmap_weights = annotated_in
heatmap_loss_pxl = self.keypoint_loss_module(dt_heatmaps, gt_heatmaps, annotated_in, per_pixel=True)
heatmap_loss = self.keypoint_loss_module(dt_heatmaps, gt_heatmaps, annotated_in)
# probability_loss = self.probability_loss_module(dt_probs, gt_probs, gt_annotated)
# visibility_loss = self.visibility_loss_module(dt_vis, gt_vis, annotated_in)
# oks_loss = self.oks_loss_module(dt_oks, gt_oks, annotated_in)
# error_loss = self.error_loss_module(dt_errs, gt_errs, annotated_in)
# Visualize some heatmaps
for i in range(0, B):
# continue
if self.num_iters % self.interval == 0:
self.interval = int(self.interval * 1.3)
os.makedirs(self.loss_vis_folder, exist_ok=True)
for kpt_i in np.random.choice(C, 17, replace=False):
tgt = gt_heatmaps[i, kpt_i].detach().cpu().numpy()
htm = dt_heatmaps[i, kpt_i].detach().cpu().numpy()
lss = heatmap_loss_pxl[i, kpt_i].detach().cpu().numpy()
save_img = self._visualize_heatmaps(
htm, tgt, lss, keypoint_weights[i, kpt_i], gt_probs[i, kpt_i]
)
save_path = os.path.join(
self.loss_vis_folder,
"heatmap_{:07d}-{:d}-{:d}.png".format(self.num_iters, i, kpt_i)
)
cv2.imwrite(save_path, save_img)
self.num_iters += 1
losses.update(
loss_kpt=heatmap_loss
)
# calculate accuracy
if train_cfg.get('compute_acc', True):
acc_pose = self.get_pose_accuracy(
dt_heatmaps, gt_heatmaps, keypoint_weights > 0.5
)
losses.update(acc_pose=acc_pose)
# Calculate the best binary accuracy for probability
acc_prob, _ = self.get_binary_accuracy(
dt_probs,
gt_probs,
gt_annotated > 0.5,
force_balanced=True,
)
losses.update(acc_prob=acc_prob)
# Calculate the best binary accuracy for visibility
acc_vis, _ = self.get_binary_accuracy(
dt_vis,
gt_vis,
annotated_in > 0.5,
force_balanced=True,
)
losses.update(acc_vis=acc_vis)
# Calculate the MAE for OKS
acc_oks = self.get_mae(
dt_oks,
gt_oks,
annotated_in > 0.5,
)
losses.update(mae_oks=acc_oks)
# Calculate the MAE for euclidean error
acc_err = self.get_mae(
dt_errs,
gt_errs,
annotated_in > 0.5,
)
losses.update(mae_err=acc_err)
# Calculate the MAE between Euclidean error and OKS
err_to_oks_mae = self.get_mae(
self.error_to_OKS(dt_errs, area=H*W),
gt_oks,
annotated_in > 0.5,
)
losses.update(mae_err_to_oks=err_to_oks_mae)
print(self.temperature.item())
return losses
def _visualize_heatmaps(
self,
htm,
tgt,
lss,
weight,
prob
):
tgt_range = (tgt.min(), tgt.max())
htm_range = (htm.min(), htm.max())
lss_range = (lss.min(), lss.max())
tgt[tgt < 0] = 0
htm[htm < 0] = 0
lss[lss < 0] = 0
# Normalize heatmaps between 0 and 1
tgt /= (tgt.max()+1e-10)
htm /= (htm.max()+1e-10)
lss /= (lss.max()+1e-10)
scale = 6
htm_color = cv2.cvtColor((htm*255).astype(np.uint8), cv2.COLOR_GRAY2BGR)
htm_color = cv2.applyColorMap(htm_color, cv2.COLORMAP_JET)
htm_color = cv2.resize(htm_color, (htm.shape[1]*scale, htm.shape[0]*scale), interpolation=cv2.INTER_NEAREST)
tgt_color = cv2.cvtColor((tgt*255).astype(np.uint8), cv2.COLOR_GRAY2BGR)
tgt_color = cv2.applyColorMap(tgt_color, cv2.COLORMAP_JET)
tgt_color = cv2.resize(tgt_color, (htm.shape[1]*scale, htm.shape[0]*scale), interpolation=cv2.INTER_NEAREST)
lss_color = cv2.cvtColor((lss*255).astype(np.uint8), cv2.COLOR_GRAY2BGR)
lss_color = cv2.applyColorMap(lss_color, cv2.COLORMAP_JET)
lss_color = cv2.resize(lss_color, (htm.shape[1]*scale, htm.shape[0]*scale), interpolation=cv2.INTER_NEAREST)
if scale > 2:
tgt_color_text = tgt_color.copy()
cv2.putText(tgt_color_text, "tgt ({:.1f}, {:.1f})".format(tgt_range[0]*10, tgt_range[1]*10), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
tgt_color = cv2.addWeighted(tgt_color, 0.6, tgt_color_text, 0.4, 0)
htm_color_text = htm_color.copy()
cv2.putText(htm_color_text, "htm ({:.1f}, {:.1f})".format(htm_range[0]*10, htm_range[1]*10), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
htm_color = cv2.addWeighted(htm_color, 0.6, htm_color_text, 0.4, 0)
lss_color_text = lss_color.copy()
cv2.putText(lss_color_text, "lss ({:.1f}, {:.1f})".format(lss_range[0]*10, lss_range[1]*10), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
lss_color = cv2.addWeighted(lss_color, 0.6, lss_color_text, 0.4, 0)
# Get argmax of the target and draw horizontal and vertical lines
tgt_argmax = np.unravel_index(tgt.argmax(), tgt.shape)
tgt_color_line = tgt_color.copy()
cv2.line(tgt_color_line, (0, tgt_argmax[0]*scale), (tgt_color.shape[1], tgt_argmax[0]*scale), (0, 255, 255), 1)
cv2.line(tgt_color_line, (tgt_argmax[1]*scale, 0), (tgt_argmax[1]*scale, tgt_color.shape[0]), (0, 255, 255), 1)
tgt_color = cv2.addWeighted(tgt_color, 0.6, tgt_color_line, 0.4, 0)
htm_color_line = htm_color.copy()
cv2.line(htm_color_line, (0, tgt_argmax[0]*scale), (tgt_color.shape[1], tgt_argmax[0]*scale), (0, 255, 255), 1)
cv2.line(htm_color_line, (tgt_argmax[1]*scale, 0), (tgt_argmax[1]*scale, tgt_color.shape[0]), (0, 255, 255), 1)
htm_color = cv2.addWeighted(htm_color, 0.6, htm_color_line, 0.4, 0)
lss_color_line = lss_color.copy()
cv2.line(lss_color_line, (0, tgt_argmax[0]*scale), (tgt_color.shape[1], tgt_argmax[0]*scale), (0, 255, 255), 1)
cv2.line(lss_color_line, (tgt_argmax[1]*scale, 0), (tgt_argmax[1]*scale, tgt_color.shape[0]), (0, 255, 255), 1)
lss_color = cv2.addWeighted(lss_color, 0.6, lss_color_line, 0.4, 0)
white_column = np.ones((tgt_color.shape[0], 1, 3), dtype=np.uint8) * 255
save_img = np.concatenate((
tgt_color,
white_column,
htm_color,
white_column,
lss_color,
), axis=1)
if weight < 0.5:
# Draw a red X across the whole save_img
cv2.line(save_img, (0, 0), (save_img.shape[1], save_img.shape[0]), (0, 0, 255), 2)
cv2.line(save_img, (0, save_img.shape[0]), (save_img.shape[1], 0), (0, 0, 255), 2)
elif prob < 0.5:
            # Draw a yellow X across the whole save_img
cv2.line(save_img, (0, 0), (save_img.shape[1], save_img.shape[0]), (0, 255, 255), 2)
cv2.line(save_img, (0, save_img.shape[0]), (save_img.shape[1], 0), (0, 255, 255), 2)
return save_img
def get_pose_accuracy(self, dt, gt, mask):
"""Calculate the accuracy of predicted pose."""
_, avg_acc, _ = pose_pck_accuracy(
output=to_numpy(dt),
target=to_numpy(gt),
mask=to_numpy(mask),
method='argmax',
)
acc_pose = torch.tensor(avg_acc, device=gt.device)
return acc_pose
def get_binary_accuracy(self, dt, gt, mask, force_balanced=False):
"""Calculate the binary accuracy."""
assert dt.shape == gt.shape
device = gt.device
dt = to_numpy(dt)
gt = to_numpy(gt)
mask = to_numpy(mask)
dt = dt[mask]
gt = gt[mask]
gt = gt.astype(bool)
if force_balanced:
# Force the number of positive and negative samples to be balanced
pos_num = np.sum(gt)
neg_num = len(gt) - pos_num
num = min(pos_num, neg_num)
if num == 0:
return torch.tensor([0.0], device=device), torch.tensor([0.0], device=device)
pos_idx = np.where(gt)[0]
neg_idx = np.where(~gt)[0]
# Randomly sample the same number of positive and negative samples
np.random.shuffle(pos_idx)
np.random.shuffle(neg_idx)
idx = np.concatenate([pos_idx[:num], neg_idx[:num]])
dt = dt[idx]
gt = gt[idx]
n_samples = len(gt)
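        # Sweep candidate thresholds and keep the one with the highest
        # accuracy; broadcasting `dt` against all thresholds at once gives an
        # (n_samples, n_thresholds) boolean grid.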
thresholds = np.arange(0.1, 1.0, 0.05)
preds = (dt[:, None] > thresholds)
correct = preds == gt[:, None]
counts = correct.sum(axis=0)
# Find the threshold that maximizes the accuracy
best_idx = np.argmax(counts)
best_threshold = thresholds[best_idx]
best_acc = counts[best_idx] / n_samples
best_acc = torch.tensor(best_acc, device=device).float()
best_threshold = torch.tensor(best_threshold, device=device).float()
return best_acc, best_threshold
def get_mae(self, dt, gt, mask):
"""Calculate the mean absolute error."""
assert dt.shape == gt.shape
device = gt.device
dt = to_numpy(dt)
gt = to_numpy(gt)
mask = to_numpy(mask)
dt = dt[mask]
gt = gt[mask]
mae = np.abs(dt - gt).mean()
mae = torch.tensor(mae, device=device)
return mae
def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
**kwargs):
"""A hook function to convert old-version state dict of
:class:`TopdownHeatmapSimpleHead` (before MMPose v1.0.0) to a
compatible format of :class:`HeatmapHead`.
The hook will be automatically registered during initialization.
"""
version = local_meta.get('version', None)
if version and version >= self._version:
return
# convert old-version state dict
keys = list(state_dict.keys())
for _k in keys:
if not _k.startswith(prefix):
continue
v = state_dict.pop(_k)
k = _k[len(prefix):]
# In old version, "final_layer" includes both intermediate
# conv layers (new "conv_layers") and final conv layers (new
# "final_layer").
#
# If there is no intermediate conv layer, old "final_layer" will
# have keys like "final_layer.xxx", which should be still
# named "final_layer.xxx";
#
# If there are intermediate conv layers, old "final_layer" will
# have keys like "final_layer.n.xxx", where the weights of the last
# one should be renamed "final_layer.xxx", and others should be
# renamed "conv_layers.n.xxx"
k_parts = k.split('.')
if k_parts[0] == 'final_layer':
if len(k_parts) == 3:
assert isinstance(self.conv_layers, nn.Sequential)
idx = int(k_parts[1])
if idx < len(self.conv_layers):
# final_layer.n.xxx -> conv_layers.n.xxx
k_new = 'conv_layers.' + '.'.join(k_parts[1:])
else:
# final_layer.n.xxx -> final_layer.xxx
k_new = 'final_layer.' + k_parts[2]
else:
# final_layer.xxx remains final_layer.xxx
k_new = k
else:
k_new = k
state_dict[prefix + k_new] = v
def error_to_OKS(self, error, area=1.0):
"""Convert the error to OKS."""
sigmas = np.array(
[.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89])/10.0
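        # Per-keypoint OKS approximation: exp(-error^2 / (2 * (2*sigma)^2 * area)),
        # i.e. the same Gaussian weighting used in `compute_oks` below with the
        # euclidean error in place of the decoded keypoint distance.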
if isinstance(error, torch.Tensor):
sigmas = torch.tensor(sigmas, device=error.device)
vars = (sigmas * 2)**2
norm_error = error**2 / vars / area / 2.0
return torch.exp(-norm_error)
def compute_oks(gt, dt, use_area=True, per_kpt=False):
sigmas = np.array(
[.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89])/10.0
vars = (sigmas * 2)**2
k = len(sigmas)
visibility_condition = lambda x: x > 0
g = np.array(gt['keypoints']).reshape(k, 3)
xg = g[:, 0]; yg = g[:, 1]; vg = g[:, 2]
k1 = np.count_nonzero(visibility_condition(vg))
bb = gt['bbox']
x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2
y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2
d = np.array(dt['keypoints']).reshape((k, 3))
xd = d[:, 0]; yd = d[:, 1]
if k1>0:
# measure the per-keypoint distance if keypoints visible
dx = xd - xg
dy = yd - yg
else:
# measure minimum distance to keypoints in (x0,y0) & (x1,y1)
z = np.zeros((k))
dx = np.max((z, x0-xd),axis=0)+np.max((z, xd-x1),axis=0)
dy = np.max((z, y0-yd),axis=0)+np.max((z, yd-y1),axis=0)
if use_area:
e = (dx**2 + dy**2) / vars / (gt['area']+np.spacing(1)) / 2
else:
tmparea = gt['bbox'][3] * gt['bbox'][2] * 0.53
e = (dx**2 + dy**2) / vars / (tmparea+np.spacing(1)) / 2
if per_kpt:
oks = np.exp(-e)
if k1 > 0:
oks[~visibility_condition(vg)] = 0
else:
if k1 > 0:
e=e[visibility_condition(vg)]
oks = np.sum(np.exp(-e)) / e.shape[0]
return oks
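

# Minimal usage sketch for `compute_oks` (illustrative only; the toy
# keypoints, bbox and noise level below are made-up values, not data from
# this repo). It shows the dict layout the function expects and the
# per-keypoint output consumed by `CalibrationHead._oks_from_heatmaps`.
if __name__ == '__main__':
    _rng = np.random.default_rng(0)
    # 17 COCO keypoints as (x, y, visibility) with visibility flag 2
    _gt_kpts = np.concatenate(
        [_rng.uniform(0, 48, size=(17, 2)), np.full((17, 1), 2.0)], axis=1)
    # Perturb only the coordinates, not the visibility flag
    _noise = np.zeros((17, 3))
    _noise[:, :2] = _rng.normal(scale=0.5, size=(17, 2))
    _gt = {'keypoints': _gt_kpts, 'bbox': np.array([0, 0, 64, 48]),
           'area': 64 * 48}
    _dt = {'keypoints': _gt_kpts + _noise, 'bbox': _gt['bbox'],
           'area': _gt['area']}
    # One OKS value per keypoint, all in (0, 1] for visible ground truth
    print(compute_oks(_gt, _dt, use_area=False, per_kpt=True))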