# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Optional, Sequence, Tuple, Union | |
import torch | |
from mmcv.cnn import build_conv_layer, build_upsample_layer | |
from mmengine.structures import PixelData | |
from torch import Tensor, nn | |
from mmpose.evaluation.functional import pose_pck_accuracy | |
from mmpose.models.utils.tta import flip_heatmaps | |
from mmpose.registry import KEYPOINT_CODECS, MODELS | |
from mmpose.utils.tensor_utils import to_numpy | |
from mmpose.utils.typing import (ConfigType, Features, OptConfigType, | |
OptSampleList, Predictions) | |
from ..base_head import BaseHead | |
import numpy as np | |
from sparsemax import Sparsemax | |
import os | |
import shutil | |
import cv2 | |
from mmpose.structures.keypoint import fix_bbox_aspect_ratio | |
OptIntSeq = Optional[Sequence[int]] | |
class CalibrationHead(BaseHead): | |
"""Multi-variate head predicting all information about keypoints. Apart | |
from the heatmap, it also predicts: | |
1) Heatmap for each keypoint | |
2) Probability of keypoint being in the heatmap | |
3) Visibility of each keypoint | |
4) Predicted OKS per keypoint | |
5) Predictd euclidean error per keypoint | |
The heatmap predicting part is the same as HeatmapHead introduced in | |
in `Simple Baselines`_ by Xiao et al (2018). | |
Args: | |
in_channels (int | Sequence[int]): Number of channels in the input | |
feature map | |
out_channels (int): Number of channels in the output heatmap | |
deconv_out_channels (Sequence[int], optional): The output channel | |
number of each deconv layer. Defaults to ``(256, 256, 256)`` | |
deconv_kernel_sizes (Sequence[int | tuple], optional): The kernel size | |
of each deconv layer. Each element should be either an integer for | |
both height and width dimensions, or a tuple of two integers for | |
            the height and the width dimension, respectively. Defaults to
            ``(4, 4, 4)``
conv_out_channels (Sequence[int], optional): The output channel number | |
of each intermediate conv layer. ``None`` means no intermediate | |
conv layer between deconv layers and the final conv layer. | |
Defaults to ``None`` | |
conv_kernel_sizes (Sequence[int | tuple], optional): The kernel size | |
of each intermediate conv layer. Defaults to ``None`` | |
final_layer_dict (dict): Arguments of the final Conv2d layer. | |
Defaults to ``dict(kernel_size=1)`` | |
keypoint_loss (Config): Config of the keypoint loss. Defaults to use | |
:class:`KeypointMSELoss` | |
probability_loss (Config): Config of the probability loss. Defaults to use | |
:class:`BCELoss` | |
visibility_loss (Config): Config of the visibility loss. Defaults to use | |
:class:`BCELoss` | |
oks_loss (Config): Config of the oks loss. Defaults to use | |
:class:`MSELoss` | |
error_loss (Config): Config of the error loss. Defaults to use | |
:class:`L1LogLoss` | |
        normalize (float, optional): If not ``None``, normalize each flattened
            heatmap with Sparsemax and rescale it by this factor, clamping the
            result to [0, 1]. ``None`` disables the normalization. Defaults to
            ``None``
detach_probability (bool): Whether to detach the probability | |
from gradient computation. Defaults to ``True`` | |
detach_visibility (bool): Whether to detach the visibility | |
from gradient computation. Defaults to ``True`` | |
learn_heatmaps_from_zeros (bool): Whether to learn the | |
heatmaps from zeros. Defaults to ``False`` | |
freeze_heatmaps (bool): Whether to freeze the heatmaps prediction. | |
Defaults to ``False`` | |
freeze_probability (bool): Whether to freeze the probability prediction. | |
Defaults to ``False`` | |
freeze_visibility (bool): Whether to freeze the visibility prediction. | |
Defaults to ``False`` | |
freeze_oks (bool): Whether to freeze the oks prediction. | |
Defaults to ``False`` | |
freeze_error (bool): Whether to freeze the error prediction. | |
Defaults to ``False`` | |
decoder (Config, optional): The decoder config that controls decoding | |
keypoint coordinates from the network output. Defaults to ``None`` | |
init_cfg (Config, optional): Config to control the initialization. See | |
:attr:`default_init_cfg` for default settings | |
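
    Example (illustrative values only; ``in_channels`` depends on the chosen
    backbone and is not prescribed by this head)::

        head = CalibrationHead(
            in_channels=768,
            out_channels=17,
            decoder=dict(
                type='UDPHeatmap',
                input_size=(192, 256),
                heatmap_size=(48, 64),
                sigma=2))
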
.. _`Simple Baselines`: https://arxiv.org/abs/1804.06208 | |
""" | |
_version = 2 | |
def __init__(self, | |
in_channels: Union[int, Sequence[int]], | |
out_channels: int, | |
deconv_out_channels: OptIntSeq = (256, 256, 256), | |
deconv_kernel_sizes: OptIntSeq = (4, 4, 4), | |
conv_out_channels: OptIntSeq = None, | |
conv_kernel_sizes: OptIntSeq = None, | |
final_layer_dict: dict = dict(kernel_size=1), | |
keypoint_loss: ConfigType = dict( | |
type='KeypointMSELoss', use_target_weight=True), | |
probability_loss: ConfigType = dict( | |
type='BCELoss', use_target_weight=True), | |
visibility_loss: ConfigType = dict( | |
type='BCELoss', use_target_weight=True), | |
oks_loss: ConfigType = dict( | |
type='MSELoss', use_target_weight=True), | |
error_loss: ConfigType = dict( | |
type='L1LogLoss', use_target_weight=True), | |
normalize: float = None, | |
detach_probability: bool = True, | |
detach_visibility: bool = True, | |
learn_heatmaps_from_zeros: bool = False, | |
freeze_heatmaps: bool = False, | |
freeze_probability: bool = False, | |
freeze_visibility: bool = False, | |
freeze_oks: bool = False, | |
freeze_error: bool = False, | |
decoder: OptConfigType = dict( | |
type='UDPHeatmap', input_size=(192, 256), | |
heatmap_size=(48, 64), sigma=2), | |
init_cfg: OptConfigType = None, | |
): | |
if init_cfg is None: | |
init_cfg = self.default_init_cfg | |
super().__init__(init_cfg) | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.keypoint_loss_module = MODELS.build(keypoint_loss) | |
self.probability_loss_module = MODELS.build(probability_loss) | |
self.visibility_loss_module = MODELS.build(visibility_loss) | |
self.oks_loss_module = MODELS.build(oks_loss) | |
self.error_loss_module = MODELS.build(error_loss) | |
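        # Scalar temperature used to calibrate the sharpness of the predicted
        # heatmaps; it is the only parameter of this head left trainable
        # (see `_freeze_all_but_temperature`, called at the end of __init__).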
self.temperature = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True) | |
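        # Separable Gaussian blur kernel (size = 6 * sigma + 1) built as the
        # outer product of two 1-D Gaussians; currently only used by the
        # commented-out blurring step in `forward_heatmap`.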
self.gauss_sigma = 2.0 | |
self.gauss_kernel_size = int(2.0 * 3.0 * self.gauss_sigma + 1.0) | |
ts = torch.linspace( | |
- self.gauss_kernel_size // 2, | |
self.gauss_kernel_size // 2, | |
self.gauss_kernel_size | |
) | |
gauss = torch.exp(-(ts / self.gauss_sigma)**2 / 2) | |
gauss = gauss / gauss.sum() | |
self.gauss_kernel = gauss.unsqueeze(0) * gauss.unsqueeze(1) | |
self.decoder = KEYPOINT_CODECS.build(decoder) | |
self.nonlinearity = nn.ReLU(inplace=True) | |
self.learn_heatmaps_from_zeros = learn_heatmaps_from_zeros | |
self.num_iters = 0 | |
unique_hash = np.random.randint(0, 100000) | |
self.loss_vis_folder = "work_dirs/loss_vis_{:05d}".format(unique_hash) | |
self.interval = 50 | |
shutil.rmtree(self.loss_vis_folder, ignore_errors=True) | |
print("Will save heatmap visualizations to folder '{:s}'".format(self.loss_vis_folder)) | |
self._build_heatmap_head( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
deconv_out_channels=deconv_out_channels, | |
deconv_kernel_sizes=deconv_kernel_sizes, | |
conv_out_channels=conv_out_channels, | |
conv_kernel_sizes=conv_kernel_sizes, | |
final_layer_dict=final_layer_dict, | |
normalize=normalize, | |
freeze=freeze_heatmaps) | |
self.normalize = normalize | |
self.detach_probability = detach_probability | |
self._build_probability_head( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
freeze=freeze_probability) | |
self.detach_visibility = detach_visibility | |
self._build_visibility_head( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
freeze=freeze_visibility) | |
self._build_oks_head( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
freeze=freeze_oks) | |
self.freeze_oks = freeze_oks | |
self._build_error_head( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
freeze=freeze_error) | |
self.freeze_error = freeze_error | |
# Register the hook to automatically convert old version state dicts | |
self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) | |
self._freeze_all_but_temperature() | |
# Print all params and their gradients | |
print("\n", "="*20) | |
for name, param in self.named_parameters(): | |
print(name, param.requires_grad) | |
def _freeze_all_but_temperature(self): | |
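        """Freeze all parameters except the calibration temperature."""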
for param in self.parameters(): | |
param.requires_grad = False | |
self.temperature.requires_grad = True | |
def _build_heatmap_head(self, in_channels: int, out_channels: int, | |
deconv_out_channels: Sequence[int], | |
deconv_kernel_sizes: Sequence[int], | |
conv_out_channels: Sequence[int], | |
conv_kernel_sizes: Sequence[int], | |
final_layer_dict: dict, | |
                            normalize: float = None,
freeze: bool = False) -> nn.Module: | |
"""Build the heatmap head module.""" | |
if deconv_out_channels: | |
if deconv_kernel_sizes is None or len(deconv_out_channels) != len( | |
deconv_kernel_sizes): | |
raise ValueError( | |
'"deconv_out_channels" and "deconv_kernel_sizes" should ' | |
'be integer sequences with the same length. Got ' | |
f'mismatched lengths {deconv_out_channels} and ' | |
f'{deconv_kernel_sizes}') | |
self.deconv_layers = self._make_deconv_layers( | |
in_channels=in_channels, | |
layer_out_channels=deconv_out_channels, | |
layer_kernel_sizes=deconv_kernel_sizes, | |
) | |
in_channels = deconv_out_channels[-1] | |
else: | |
self.deconv_layers = nn.Identity() | |
if conv_out_channels: | |
if conv_kernel_sizes is None or len(conv_out_channels) != len( | |
conv_kernel_sizes): | |
raise ValueError( | |
'"conv_out_channels" and "conv_kernel_sizes" should ' | |
'be integer sequences with the same length. Got ' | |
f'mismatched lengths {conv_out_channels} and ' | |
f'{conv_kernel_sizes}') | |
self.conv_layers = self._make_conv_layers( | |
in_channels=in_channels, | |
layer_out_channels=conv_out_channels, | |
layer_kernel_sizes=conv_kernel_sizes) | |
in_channels = conv_out_channels[-1] | |
else: | |
self.conv_layers = nn.Identity() | |
if final_layer_dict is not None: | |
cfg = dict( | |
type='Conv2d', | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=1) | |
cfg.update(final_layer_dict) | |
self.final_layer = build_conv_layer(cfg) | |
else: | |
self.final_layer = nn.Identity() | |
# self.normalize_layer = lambda x: x / x.sum(dim=-1, keepdim=True) if normalize else nn.Identity() | |
# self.normalize_layer = nn.Softmax(dim=-1) if normalize else nn.Identity() | |
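        # When `normalize` is set, Sparsemax over the flattened spatial
        # dimension turns each heatmap into a sparse distribution; otherwise
        # the heatmaps are passed through unchanged.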
self.normalize_layer = nn.Identity() if normalize is None else Sparsemax(dim=-1) | |
if freeze: | |
for param in self.deconv_layers.parameters(): | |
param.requires_grad = False | |
for param in self.conv_layers.parameters(): | |
param.requires_grad = False | |
for param in self.final_layer.parameters(): | |
param.requires_grad = False | |
def _build_probability_head(self, in_channels: int, out_channels: int, | |
freeze: bool = False) -> nn.Module: | |
"""Build the probability head module.""" | |
ppb_layers = [] | |
kernel_sizes = [(4, 3), (2, 2), (2, 2)] | |
for i in range(len(kernel_sizes)): | |
ppb_layers.append( | |
build_conv_layer( | |
dict(type='Conv2d'), | |
in_channels=in_channels, | |
out_channels=in_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1)) | |
ppb_layers.append( | |
nn.BatchNorm2d(num_features=in_channels)) | |
ppb_layers.append( | |
nn.MaxPool2d(kernel_size=kernel_sizes[i], stride=kernel_sizes[i], padding=0)) | |
ppb_layers.append(self.nonlinearity) | |
ppb_layers.append( | |
build_conv_layer( | |
dict(type='Conv2d'), | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0)) | |
ppb_layers.append(nn.Sigmoid()) | |
self.probability_layers = nn.Sequential(*ppb_layers) | |
if freeze: | |
for param in self.probability_layers.parameters(): | |
param.requires_grad = False | |
def _build_visibility_head(self, in_channels: int, out_channels: int, | |
freeze: bool = False) -> nn.Module: | |
"""Build the visibility head module.""" | |
vis_layers = [] | |
kernel_sizes = [(4, 3), (2, 2), (2, 2)] | |
for i in range(len(kernel_sizes)): | |
vis_layers.append( | |
build_conv_layer( | |
dict(type='Conv2d'), | |
in_channels=in_channels, | |
out_channels=in_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1)) | |
vis_layers.append( | |
nn.BatchNorm2d(num_features=in_channels)) | |
vis_layers.append( | |
nn.MaxPool2d(kernel_size=kernel_sizes[i], stride=kernel_sizes[i], padding=0)) | |
vis_layers.append(self.nonlinearity) | |
vis_layers.append( | |
build_conv_layer( | |
dict(type='Conv2d'), | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0)) | |
vis_layers.append(nn.Sigmoid()) | |
self.visibility_layers = nn.Sequential(*vis_layers) | |
if freeze: | |
for param in self.visibility_layers.parameters(): | |
param.requires_grad = False | |
def _build_oks_head(self, in_channels: int, out_channels: int, | |
freeze: bool = False) -> nn.Module: | |
"""Build the oks head module.""" | |
oks_layers = [] | |
kernel_sizes = [(4, 3), (2, 2), (2, 2)] | |
for i in range(len(kernel_sizes)): | |
oks_layers.append( | |
build_conv_layer( | |
dict(type='Conv2d'), | |
in_channels=in_channels, | |
out_channels=in_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1)) | |
oks_layers.append( | |
nn.BatchNorm2d(num_features=in_channels)) | |
oks_layers.append( | |
nn.MaxPool2d(kernel_size=kernel_sizes[i], stride=kernel_sizes[i], padding=0)) | |
oks_layers.append(self.nonlinearity) | |
oks_layers.append( | |
build_conv_layer( | |
dict(type='Conv2d'), | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0)) | |
oks_layers.append(nn.Sigmoid()) | |
self.oks_layers = nn.Sequential(*oks_layers) | |
if freeze: | |
for param in self.oks_layers.parameters(): | |
param.requires_grad = False | |
def _build_error_head(self, in_channels: int, out_channels: int, | |
freeze: bool = False) -> nn.Module: | |
"""Build the error head module.""" | |
error_layers = [] | |
kernel_sizes = [(4, 3), (2, 2), (2, 2)] | |
for i in range(len(kernel_sizes)): | |
error_layers.append( | |
build_conv_layer( | |
dict(type='Conv2d'), | |
in_channels=in_channels, | |
out_channels=in_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1)) | |
error_layers.append( | |
nn.BatchNorm2d(num_features=in_channels)) | |
error_layers.append( | |
nn.MaxPool2d(kernel_size=kernel_sizes[i], stride=kernel_sizes[i], padding=0)) | |
error_layers.append(self.nonlinearity) | |
error_layers.append( | |
build_conv_layer( | |
dict(type='Conv2d'), | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0)) | |
error_layers.append(self.nonlinearity) | |
self.error_layers = nn.Sequential(*error_layers) | |
if freeze: | |
for param in self.error_layers.parameters(): | |
param.requires_grad = False | |
def _make_conv_layers(self, in_channels: int, | |
layer_out_channels: Sequence[int], | |
layer_kernel_sizes: Sequence[int]) -> nn.Module: | |
"""Create convolutional layers by given parameters.""" | |
layers = [] | |
for out_channels, kernel_size in zip(layer_out_channels, | |
layer_kernel_sizes): | |
padding = (kernel_size - 1) // 2 | |
cfg = dict( | |
type='Conv2d', | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=kernel_size, | |
stride=1, | |
padding=padding) | |
layers.append(build_conv_layer(cfg)) | |
layers.append(nn.BatchNorm2d(num_features=out_channels)) | |
layers.append(self.nonlinearity) | |
in_channels = out_channels | |
return nn.Sequential(*layers) | |
def _make_deconv_layers(self, in_channels: int, | |
layer_out_channels: Sequence[int], | |
layer_kernel_sizes: Sequence[int]) -> nn.Module: | |
"""Create deconvolutional layers by given parameters.""" | |
layers = [] | |
for out_channels, kernel_size in zip(layer_out_channels, | |
layer_kernel_sizes): | |
if kernel_size == 4: | |
padding = 1 | |
output_padding = 0 | |
elif kernel_size == 3: | |
padding = 1 | |
output_padding = 1 | |
elif kernel_size == 2: | |
padding = 0 | |
output_padding = 0 | |
else: | |
                raise ValueError(
                    f'Unsupported kernel size {kernel_size} for '
                    f'deconvolutional layers in {self.__class__.__name__}')
cfg = dict( | |
type='deconv', | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=kernel_size, | |
stride=2, | |
padding=padding, | |
output_padding=output_padding, | |
bias=False) | |
layers.append(build_upsample_layer(cfg)) | |
layers.append(nn.BatchNorm2d(num_features=out_channels)) | |
layers.append(self.nonlinearity) | |
in_channels = out_channels | |
return nn.Sequential(*layers) | |
def _error_from_heatmaps(self, gt_heatmaps: Tensor, dt_heatmaps: Tensor) -> Tensor: | |
"""Calculate the error from heatmaps. | |
Args: | |
heatmaps (Tensor): The predicted heatmaps. | |
Returns: | |
Tensor: The predicted error. | |
""" | |
# Transform to numpy | |
gt_heatmaps = to_numpy(gt_heatmaps) | |
dt_heatmaps = to_numpy(dt_heatmaps) | |
# Get locations from heatmaps | |
B, C, H, W = gt_heatmaps.shape | |
gt_coords = np.zeros((B, C, 2)) | |
dt_coords = np.zeros((B, C, 2)) | |
for i, (gt_htm, dt_htm) in enumerate(zip(gt_heatmaps, dt_heatmaps)): | |
coords, score = self.decoder.decode(gt_htm) | |
coords = coords.squeeze() | |
gt_coords[i, :, :] = coords | |
coords, score = self.decoder.decode(dt_htm) | |
coords = coords.squeeze() | |
dt_coords[i, :, :] = coords | |
# NaN coordinates mean empty heatmaps -> set them to -1 | |
# as the error will be ignored by weight | |
gt_coords[np.isnan(gt_coords)] = -1 | |
# Calculate the error | |
target_errors = np.linalg.norm(gt_coords - dt_coords, axis=2) | |
assert (target_errors >= 0).all(), "Euclidean distance cannot be negative" | |
return target_errors | |
def _oks_from_heatmaps(self, gt_heatmaps: Tensor, dt_heatmaps: Tensor, weight: Tensor) -> Tensor: | |
"""Calculate the OKS from heatmaps. | |
Args: | |
heatmaps (Tensor): The predicted heatmaps. | |
Returns: | |
Tensor: The predicted OKS. | |
""" | |
C = dt_heatmaps.shape[1] | |
# Transform to numpy | |
gt_heatmaps = to_numpy(gt_heatmaps) | |
dt_heatmaps = to_numpy(dt_heatmaps) | |
B, C, H, W = gt_heatmaps.shape | |
weight = to_numpy(weight).squeeze().reshape((B, C, 1)) | |
# Get locations from heatmaps | |
gt_coords = np.zeros((B, C, 2)) | |
dt_coords = np.zeros((B, C, 2)) | |
for i, (gt_htm, dt_htm) in enumerate(zip(gt_heatmaps, dt_heatmaps)): | |
coords, score = self.decoder.decode(gt_htm) | |
coords = coords.squeeze() | |
gt_coords[i, :, :] = coords | |
coords, score = self.decoder.decode(dt_htm) | |
coords = coords.squeeze() | |
dt_coords[i, :, :] = coords | |
# NaN coordinates mean empty heatmaps -> set them to 0 | |
gt_coords[np.isnan(gt_coords)] = 0 | |
# Add probability as visibility | |
gt_coords = gt_coords * weight | |
dt_coords = dt_coords * weight | |
gt_coords = np.concatenate((gt_coords, weight*2), axis=2) | |
dt_coords = np.concatenate((dt_coords, weight*2), axis=2) | |
# Calculate the oks | |
target_oks = [] | |
oks_weights = [] | |
for i in range(len(gt_coords)): | |
gt_kpts = gt_coords[i] | |
dt_kpts = dt_coords[i] | |
valid_gt_kpts = gt_kpts[:, 2] > 0 | |
if not valid_gt_kpts.any(): | |
# Changed for per-keypoint OKS | |
target_oks.append(np.zeros(C)) | |
oks_weights.append(0) | |
continue | |
gt_bbox = np.array([ | |
0, 0, | |
64, 48, | |
]) | |
gt = { | |
'keypoints': gt_kpts, | |
'bbox': gt_bbox, | |
'area': gt_bbox[2] * gt_bbox[3], | |
} | |
dt = { | |
'keypoints': dt_kpts, | |
'bbox': gt_bbox, | |
'area': gt_bbox[2] * gt_bbox[3], | |
} | |
# Changed for per-keypoint OKS | |
oks = compute_oks(gt, dt, use_area=False, per_kpt=True) | |
target_oks.append(oks) | |
oks_weights.append(1) | |
target_oks = np.array(target_oks) | |
target_oks = torch.from_numpy(target_oks).float() | |
oks_weights = np.array(oks_weights) | |
oks_weights = torch.from_numpy(oks_weights).float() | |
return target_oks, oks_weights | |
    @property
    def default_init_cfg(self):
init_cfg = [ | |
dict( | |
type='Normal', layer=['Conv2d', 'ConvTranspose2d'], std=0.001), | |
dict(type='Constant', layer='BatchNorm2d', val=1) | |
] | |
return init_cfg | |
def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: | |
"""Forward the network. The input is multi scale feature maps and the | |
output is (1) the heatmap, (2) probability, (3) visibility, (4) oks and (5) error. | |
Args: | |
feats (Tensor): Multi scale feature maps. | |
Returns: | |
Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: outputs. | |
""" | |
x = feats[-1].detach() | |
heatmaps = self.forward_heatmap(x) | |
probabilities = self.forward_probability(x) | |
visibilities = self.forward_visibility(x) | |
oks = self.forward_oks(x) | |
errors = self.forward_error(x) | |
return heatmaps, probabilities, visibilities, oks, errors | |
def forward_heatmap(self, x: Tensor) -> Tensor: | |
"""Forward the network. The input is multi scale feature maps and the | |
output is the heatmap. | |
Args: | |
x (Tensor): Multi scale feature maps. | |
Returns: | |
Tensor: output heatmap. | |
""" | |
x = self.deconv_layers(x) | |
x = self.conv_layers(x) | |
x = self.final_layer(x) | |
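        # Flatten the spatial dimensions, apply temperature scaling and the
        # normalization layer, then optionally rescale by `self.normalize` and
        # clamp to [0, 1] before restoring the (B, C, H, W) shape.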
B, C, H, W = x.shape | |
x = x.reshape((B, C, H*W)) | |
x = self.normalize_layer(x/self.temperature) | |
if self.normalize is not None: | |
x = x * self.normalize | |
x = torch.clamp(x, 0, 1) | |
x = x.reshape((B, C, H, W)) | |
# # Blur the heatmaps with Gaussian | |
# x = x.reshape((B*C, 1, H, W)) | |
# x = nn.functional.conv2d(x, self.gauss_kernel[None, None, :, :].to(x.device), padding='same') | |
# x = x.reshape((B, C, H, W)) | |
return x | |
def forward_probability(self, x: Tensor) -> Tensor: | |
"""Forward the network. The input is multi scale feature maps and the | |
output is the probability. | |
Args: | |
x (Tensor): Multi scale feature maps. | |
detach (bool): Whether to detach the probability from gradient | |
Returns: | |
Tensor: output probability. | |
""" | |
if self.detach_probability: | |
x = x.detach() | |
x = self.probability_layers(x) | |
return x | |
def forward_visibility(self, x: Tensor) -> Tensor: | |
"""Forward the network. The input is multi scale feature maps and the | |
output is the visibility. | |
Args: | |
x (Tensor): Multi scale feature maps. | |
detach (bool): Whether to detach the visibility from gradient | |
Returns: | |
Tensor: output visibility. | |
""" | |
if self.detach_visibility: | |
x = x.detach() | |
x = self.visibility_layers(x) | |
return x | |
def forward_oks(self, x: Tensor) -> Tensor: | |
"""Forward the network. The input is multi scale feature maps and the | |
output is the oks. | |
Args: | |
x (Tensor): Multi scale feature maps. | |
Returns: | |
Tensor: output oks. | |
""" | |
x = x.detach() | |
x = self.oks_layers(x) | |
return x | |
def forward_error(self, x: Tensor) -> Tensor: | |
"""Forward the network. The input is multi scale feature maps and the | |
output is the euclidean error. | |
Args: | |
x (Tensor): Multi scale feature maps. | |
Returns: | |
Tensor: output error. | |
""" | |
x = x.detach() | |
x = self.error_layers(x) | |
return x | |
def predict(self, | |
feats: Features, | |
batch_data_samples: OptSampleList, | |
test_cfg: ConfigType = {}) -> Predictions: | |
"""Predict results from features. | |
Args: | |
feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage | |
features (or multiple multi-stage features in TTA) | |
batch_data_samples (List[:obj:`PoseDataSample`]): The batch | |
data samples | |
test_cfg (dict): The runtime config for testing process. Defaults | |
to {} | |
Returns: | |
Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If | |
``test_cfg['output_heatmap']==True``, return both pose and heatmap | |
prediction; otherwise only return the pose prediction. | |
The pose prediction is a list of ``InstanceData``, each contains | |
the following fields: | |
- keypoints (np.ndarray): predicted keypoint coordinates in | |
shape (num_instances, K, D) where K is the keypoint number | |
and D is the keypoint dimension | |
- keypoint_scores (np.ndarray): predicted keypoint scores in | |
shape (num_instances, K) | |
The heatmap prediction is a list of ``PixelData``, each contains | |
the following fields: | |
- heatmaps (Tensor): The predicted heatmaps in shape (K, h, w) | |
""" | |
if test_cfg.get('flip_test', False): | |
# TTA: flip test -> feats = [orig, flipped] | |
assert isinstance(feats, list) and len(feats) == 2 | |
flip_indices = batch_data_samples[0].metainfo['flip_indices'] | |
_feats, _feats_flip = feats | |
_htm, _prob, _vis, _oks, _err = self.forward(_feats) | |
_htm_flip, _prob_flip, _vis_flip, _oks_flip, _err_flip = self.forward(_feats_flip) | |
B, C, H, W = _htm.shape | |
# Flip back the keypoints | |
_htm_flip = flip_heatmaps( | |
_htm_flip, | |
flip_mode=test_cfg.get('flip_mode', 'heatmap'), | |
flip_indices=flip_indices, | |
shift_heatmap=test_cfg.get('shift_heatmap', False)) | |
heatmaps = (_htm + _htm_flip) * 0.5 | |
# Flip back scalars | |
_prob_flip = _prob_flip[:, flip_indices] | |
_vis_flip = _vis_flip[:, flip_indices] | |
_oks_flip = _oks_flip[:, flip_indices] | |
_err_flip = _err_flip[:, flip_indices] | |
probabilities = (_prob + _prob_flip) * 0.5 | |
visibilities = (_vis + _vis_flip) * 0.5 | |
oks = (_oks + _oks_flip) * 0.5 | |
errors = (_err + _err_flip) * 0.5 | |
else: | |
heatmaps, probabilities, visibilities, oks, errors = self.forward(feats) | |
B, C, H, W = heatmaps.shape | |
preds = self.decode(heatmaps) | |
probabilities = to_numpy(probabilities).reshape((B, 1, C)) | |
visibilities = to_numpy(visibilities).reshape((B, 1, C)) | |
oks = to_numpy(oks).reshape((B, 1, C)) | |
errors = to_numpy(errors).reshape((B, 1, C)) | |
# Normalize errors by dividing with the diagonal of the heatmap | |
htm_diagonal = np.sqrt(H**2 + W**2) | |
errors = errors / htm_diagonal | |
for pi, p in enumerate(preds): | |
p.set_field(p['keypoint_scores'], "keypoints_conf") | |
p.set_field(probabilities[pi], "keypoints_probs") | |
p.set_field(visibilities[pi], "keypoints_visible") | |
p.set_field(oks[pi], "keypoints_oks") | |
p.set_field(errors[pi], "keypoints_error") | |
# Replace the keypoint scores with OKS/errors | |
if not self.freeze_oks: | |
p.set_field(oks[pi], "keypoint_scores") | |
# p.set_field(1-errors[pi], "keypoint_scores") | |
# hm = heatmaps.detach().cpu().numpy() | |
# print("Heatmaps:", hm.shape, hm.min(), hm.max()) | |
if test_cfg.get('output_heatmaps', False): | |
pred_fields = [ | |
PixelData(heatmaps=hm) for hm in heatmaps.detach() | |
] | |
return preds, pred_fields | |
else: | |
return preds | |
def loss(self, | |
feats: Tuple[Tensor], | |
batch_data_samples: OptSampleList, | |
train_cfg: ConfigType = {}) -> dict: | |
"""Calculate losses from a batch of inputs and data samples. | |
Args: | |
feats (Tuple[Tensor]): The multi-stage features | |
batch_data_samples (List[:obj:`PoseDataSample`]): The batch | |
data samples | |
train_cfg (dict): The runtime config for training process. | |
Defaults to {} | |
Returns: | |
dict: A dictionary of losses. | |
""" | |
dt_heatmaps, dt_probs, dt_vis, dt_oks, dt_errs = self.forward(feats) | |
device=dt_heatmaps.device | |
B, C, H, W = dt_heatmaps.shape | |
# Extract GT data | |
gt_heatmaps = torch.stack( | |
[d.gt_fields.heatmaps for d in batch_data_samples]) | |
gt_probs = np.stack( | |
[d.gt_instances.in_image.astype(int) for d in batch_data_samples]) | |
gt_annotated = np.stack( | |
[d.gt_instances.keypoints_visible.astype(int) for d in batch_data_samples]) | |
gt_vis = np.stack( | |
[d.gt_instances.keypoints_visibility.astype(int) for d in batch_data_samples]) | |
keypoint_weights = torch.cat([ | |
d.gt_instance_labels.keypoint_weights for d in batch_data_samples | |
]) | |
# Compute GT errors and OKS | |
if self.freeze_error: | |
gt_errs = torch.zeros((B, C, 1), device=device, dtype=dt_errs.dtype) | |
else: | |
gt_errs = self._error_from_heatmaps(gt_heatmaps, dt_heatmaps) | |
if self.freeze_oks: | |
gt_oks = torch.zeros((B, C, 1), device=device, dtype=dt_oks.dtype) | |
oks_weight = torch.zeros((B, C, 1), device=device, dtype=dt_oks.dtype) | |
else: | |
gt_oks, oks_weight = self._oks_from_heatmaps( | |
gt_heatmaps, | |
dt_heatmaps, | |
gt_probs & gt_annotated, | |
) | |
# Convert everything to tensors | |
gt_probs = torch.tensor(gt_probs, device=device, dtype=dt_probs.dtype) | |
gt_vis = torch.tensor(gt_vis, device=device, dtype=dt_vis.dtype) | |
gt_annotated = torch.tensor(gt_annotated, device=device) | |
gt_oks = gt_oks.to(device).to(dt_oks.dtype) | |
oks_weight = oks_weight.to(device).to(dt_oks.dtype) | |
gt_errs = gt_errs.to(device).to(dt_errs.dtype) | |
# Reshape everything to comparable shapes | |
gt_heatmaps = gt_heatmaps.view((B, C, H, W)) | |
dt_heatmaps = dt_heatmaps.view((B, C, H, W)) | |
gt_probs = gt_probs.view((B, C)) | |
dt_probs = dt_probs.view((B, C)) | |
gt_vis = gt_vis.view((B, C)) | |
dt_vis = dt_vis.view((B, C)) | |
gt_oks = gt_oks.view((B, C)) | |
dt_oks = dt_oks.view((B, C)) | |
gt_errs = gt_errs.view((B, C)) | |
dt_errs = dt_errs.view((B, C)) | |
keypoint_weights = keypoint_weights.view((B, C)) | |
gt_annotated = gt_annotated.view((B, C)) | |
# oks_weight = oks_weight.view((B, C)) | |
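        # Keypoints that are annotated and whose ground truth lies inside the image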
annotated_in = gt_annotated & (gt_probs > 0.5) | |
# calculate losses | |
losses = dict() | |
if self.learn_heatmaps_from_zeros: | |
heatmap_weights = gt_annotated | |
else: | |
# heatmap_weights = keypoint_weights | |
heatmap_weights = annotated_in | |
heatmap_loss_pxl = self.keypoint_loss_module(dt_heatmaps, gt_heatmaps, annotated_in, per_pixel=True) | |
heatmap_loss = self.keypoint_loss_module(dt_heatmaps, gt_heatmaps, annotated_in) | |
# probability_loss = self.probability_loss_module(dt_probs, gt_probs, gt_annotated) | |
# visibility_loss = self.visibility_loss_module(dt_vis, gt_vis, annotated_in) | |
# oks_loss = self.oks_loss_module(dt_oks, gt_oks, annotated_in) | |
# error_loss = self.error_loss_module(dt_errs, gt_errs, annotated_in) | |
# Visualize some heatmaps | |
for i in range(0, B): | |
# continue | |
if self.num_iters % self.interval == 0: | |
self.interval = int(self.interval * 1.3) | |
os.makedirs(self.loss_vis_folder, exist_ok=True) | |
for kpt_i in np.random.choice(C, 17, replace=False): | |
tgt = gt_heatmaps[i, kpt_i].detach().cpu().numpy() | |
htm = dt_heatmaps[i, kpt_i].detach().cpu().numpy() | |
lss = heatmap_loss_pxl[i, kpt_i].detach().cpu().numpy() | |
save_img = self._visualize_heatmaps( | |
htm, tgt, lss, keypoint_weights[i, kpt_i], gt_probs[i, kpt_i] | |
) | |
save_path = os.path.join( | |
self.loss_vis_folder, | |
"heatmap_{:07d}-{:d}-{:d}.png".format(self.num_iters, i, kpt_i) | |
) | |
cv2.imwrite(save_path, save_img) | |
self.num_iters += 1 | |
losses.update( | |
loss_kpt=heatmap_loss | |
) | |
# calculate accuracy | |
if train_cfg.get('compute_acc', True): | |
acc_pose = self.get_pose_accuracy( | |
dt_heatmaps, gt_heatmaps, keypoint_weights > 0.5 | |
) | |
losses.update(acc_pose=acc_pose) | |
# Calculate the best binary accuracy for probability | |
acc_prob, _ = self.get_binary_accuracy( | |
dt_probs, | |
gt_probs, | |
gt_annotated > 0.5, | |
force_balanced=True, | |
) | |
losses.update(acc_prob=acc_prob) | |
# Calculate the best binary accuracy for visibility | |
acc_vis, _ = self.get_binary_accuracy( | |
dt_vis, | |
gt_vis, | |
annotated_in > 0.5, | |
force_balanced=True, | |
) | |
losses.update(acc_vis=acc_vis) | |
# Calculate the MAE for OKS | |
acc_oks = self.get_mae( | |
dt_oks, | |
gt_oks, | |
annotated_in > 0.5, | |
) | |
losses.update(mae_oks=acc_oks) | |
# Calculate the MAE for euclidean error | |
acc_err = self.get_mae( | |
dt_errs, | |
gt_errs, | |
annotated_in > 0.5, | |
) | |
losses.update(mae_err=acc_err) | |
# Calculate the MAE between Euclidean error and OKS | |
err_to_oks_mae = self.get_mae( | |
self.error_to_OKS(dt_errs, area=H*W), | |
gt_oks, | |
annotated_in > 0.5, | |
) | |
losses.update(mae_err_to_oks=err_to_oks_mae) | |
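        # Log the current calibration temperature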
print(self.temperature.item()) | |
return losses | |
def _visualize_heatmaps( | |
self, | |
htm, | |
tgt, | |
lss, | |
weight, | |
prob | |
): | |
tgt_range = (tgt.min(), tgt.max()) | |
htm_range = (htm.min(), htm.max()) | |
lss_range = (lss.min(), lss.max()) | |
tgt[tgt < 0] = 0 | |
htm[htm < 0] = 0 | |
lss[lss < 0] = 0 | |
# Normalize heatmaps between 0 and 1 | |
tgt /= (tgt.max()+1e-10) | |
htm /= (htm.max()+1e-10) | |
lss /= (lss.max()+1e-10) | |
scale = 6 | |
htm_color = cv2.cvtColor((htm*255).astype(np.uint8), cv2.COLOR_GRAY2BGR) | |
htm_color = cv2.applyColorMap(htm_color, cv2.COLORMAP_JET) | |
htm_color = cv2.resize(htm_color, (htm.shape[1]*scale, htm.shape[0]*scale), interpolation=cv2.INTER_NEAREST) | |
tgt_color = cv2.cvtColor((tgt*255).astype(np.uint8), cv2.COLOR_GRAY2BGR) | |
tgt_color = cv2.applyColorMap(tgt_color, cv2.COLORMAP_JET) | |
tgt_color = cv2.resize(tgt_color, (htm.shape[1]*scale, htm.shape[0]*scale), interpolation=cv2.INTER_NEAREST) | |
lss_color = cv2.cvtColor((lss*255).astype(np.uint8), cv2.COLOR_GRAY2BGR) | |
lss_color = cv2.applyColorMap(lss_color, cv2.COLORMAP_JET) | |
lss_color = cv2.resize(lss_color, (htm.shape[1]*scale, htm.shape[0]*scale), interpolation=cv2.INTER_NEAREST) | |
if scale > 2: | |
tgt_color_text = tgt_color.copy() | |
cv2.putText(tgt_color_text, "tgt ({:.1f}, {:.1f})".format(tgt_range[0]*10, tgt_range[1]*10), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) | |
tgt_color = cv2.addWeighted(tgt_color, 0.6, tgt_color_text, 0.4, 0) | |
htm_color_text = htm_color.copy() | |
cv2.putText(htm_color_text, "htm ({:.1f}, {:.1f})".format(htm_range[0]*10, htm_range[1]*10), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) | |
htm_color = cv2.addWeighted(htm_color, 0.6, htm_color_text, 0.4, 0) | |
lss_color_text = lss_color.copy() | |
cv2.putText(lss_color_text, "lss ({:.1f}, {:.1f})".format(lss_range[0]*10, lss_range[1]*10), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) | |
lss_color = cv2.addWeighted(lss_color, 0.6, lss_color_text, 0.4, 0) | |
# Get argmax of the target and draw horizontal and vertical lines | |
tgt_argmax = np.unravel_index(tgt.argmax(), tgt.shape) | |
tgt_color_line = tgt_color.copy() | |
cv2.line(tgt_color_line, (0, tgt_argmax[0]*scale), (tgt_color.shape[1], tgt_argmax[0]*scale), (0, 255, 255), 1) | |
cv2.line(tgt_color_line, (tgt_argmax[1]*scale, 0), (tgt_argmax[1]*scale, tgt_color.shape[0]), (0, 255, 255), 1) | |
tgt_color = cv2.addWeighted(tgt_color, 0.6, tgt_color_line, 0.4, 0) | |
htm_color_line = htm_color.copy() | |
cv2.line(htm_color_line, (0, tgt_argmax[0]*scale), (tgt_color.shape[1], tgt_argmax[0]*scale), (0, 255, 255), 1) | |
cv2.line(htm_color_line, (tgt_argmax[1]*scale, 0), (tgt_argmax[1]*scale, tgt_color.shape[0]), (0, 255, 255), 1) | |
htm_color = cv2.addWeighted(htm_color, 0.6, htm_color_line, 0.4, 0) | |
lss_color_line = lss_color.copy() | |
cv2.line(lss_color_line, (0, tgt_argmax[0]*scale), (tgt_color.shape[1], tgt_argmax[0]*scale), (0, 255, 255), 1) | |
cv2.line(lss_color_line, (tgt_argmax[1]*scale, 0), (tgt_argmax[1]*scale, tgt_color.shape[0]), (0, 255, 255), 1) | |
lss_color = cv2.addWeighted(lss_color, 0.6, lss_color_line, 0.4, 0) | |
white_column = np.ones((tgt_color.shape[0], 1, 3), dtype=np.uint8) * 255 | |
save_img = np.concatenate(( | |
tgt_color, | |
white_column, | |
htm_color, | |
white_column, | |
lss_color, | |
), axis=1) | |
if weight < 0.5: | |
# Draw a red X across the whole save_img | |
cv2.line(save_img, (0, 0), (save_img.shape[1], save_img.shape[0]), (0, 0, 255), 2) | |
cv2.line(save_img, (0, save_img.shape[0]), (save_img.shape[1], 0), (0, 0, 255), 2) | |
elif prob < 0.5: | |
            # Draw a yellow X across the whole save_img
cv2.line(save_img, (0, 0), (save_img.shape[1], save_img.shape[0]), (0, 255, 255), 2) | |
cv2.line(save_img, (0, save_img.shape[0]), (save_img.shape[1], 0), (0, 255, 255), 2) | |
return save_img | |
def get_pose_accuracy(self, dt, gt, mask): | |
"""Calculate the accuracy of predicted pose.""" | |
_, avg_acc, _ = pose_pck_accuracy( | |
output=to_numpy(dt), | |
target=to_numpy(gt), | |
mask=to_numpy(mask), | |
method='argmax', | |
) | |
acc_pose = torch.tensor(avg_acc, device=gt.device) | |
return acc_pose | |
def get_binary_accuracy(self, dt, gt, mask, force_balanced=False): | |
"""Calculate the binary accuracy.""" | |
assert dt.shape == gt.shape | |
device = gt.device | |
dt = to_numpy(dt) | |
gt = to_numpy(gt) | |
mask = to_numpy(mask) | |
dt = dt[mask] | |
gt = gt[mask] | |
gt = gt.astype(bool) | |
if force_balanced: | |
# Force the number of positive and negative samples to be balanced | |
pos_num = np.sum(gt) | |
neg_num = len(gt) - pos_num | |
num = min(pos_num, neg_num) | |
if num == 0: | |
return torch.tensor([0.0], device=device), torch.tensor([0.0], device=device) | |
pos_idx = np.where(gt)[0] | |
neg_idx = np.where(~gt)[0] | |
# Randomly sample the same number of positive and negative samples | |
np.random.shuffle(pos_idx) | |
np.random.shuffle(neg_idx) | |
idx = np.concatenate([pos_idx[:num], neg_idx[:num]]) | |
dt = dt[idx] | |
gt = gt[idx] | |
n_samples = len(gt) | |
thresholds = np.arange(0.1, 1.0, 0.05) | |
preds = (dt[:, None] > thresholds) | |
correct = preds == gt[:, None] | |
counts = correct.sum(axis=0) | |
# Find the threshold that maximizes the accuracy | |
best_idx = np.argmax(counts) | |
best_threshold = thresholds[best_idx] | |
best_acc = counts[best_idx] / n_samples | |
best_acc = torch.tensor(best_acc, device=device).float() | |
best_threshold = torch.tensor(best_threshold, device=device).float() | |
return best_acc, best_threshold | |
def get_mae(self, dt, gt, mask): | |
"""Calculate the mean absolute error.""" | |
assert dt.shape == gt.shape | |
device = gt.device | |
dt = to_numpy(dt) | |
gt = to_numpy(gt) | |
mask = to_numpy(mask) | |
dt = dt[mask] | |
gt = gt[mask] | |
mae = np.abs(dt - gt).mean() | |
mae = torch.tensor(mae, device=device) | |
return mae | |
def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args, | |
**kwargs): | |
"""A hook function to convert old-version state dict of | |
:class:`TopdownHeatmapSimpleHead` (before MMPose v1.0.0) to a | |
compatible format of :class:`HeatmapHead`. | |
The hook will be automatically registered during initialization. | |
""" | |
version = local_meta.get('version', None) | |
if version and version >= self._version: | |
return | |
# convert old-version state dict | |
keys = list(state_dict.keys()) | |
for _k in keys: | |
if not _k.startswith(prefix): | |
continue | |
v = state_dict.pop(_k) | |
k = _k[len(prefix):] | |
# In old version, "final_layer" includes both intermediate | |
# conv layers (new "conv_layers") and final conv layers (new | |
# "final_layer"). | |
# | |
# If there is no intermediate conv layer, old "final_layer" will | |
# have keys like "final_layer.xxx", which should be still | |
# named "final_layer.xxx"; | |
# | |
# If there are intermediate conv layers, old "final_layer" will | |
# have keys like "final_layer.n.xxx", where the weights of the last | |
# one should be renamed "final_layer.xxx", and others should be | |
# renamed "conv_layers.n.xxx" | |
k_parts = k.split('.') | |
if k_parts[0] == 'final_layer': | |
if len(k_parts) == 3: | |
assert isinstance(self.conv_layers, nn.Sequential) | |
idx = int(k_parts[1]) | |
if idx < len(self.conv_layers): | |
# final_layer.n.xxx -> conv_layers.n.xxx | |
k_new = 'conv_layers.' + '.'.join(k_parts[1:]) | |
else: | |
# final_layer.n.xxx -> final_layer.xxx | |
k_new = 'final_layer.' + k_parts[2] | |
else: | |
# final_layer.xxx remains final_layer.xxx | |
k_new = k | |
else: | |
k_new = k | |
state_dict[prefix + k_new] = v | |
def error_to_OKS(self, error, area=1.0): | |
"""Convert the error to OKS.""" | |
sigmas = np.array( | |
[.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89])/10.0 | |
if isinstance(error, torch.Tensor): | |
sigmas = torch.tensor(sigmas, device=error.device) | |
vars = (sigmas * 2)**2 | |
norm_error = error**2 / vars / area / 2.0 | |
return torch.exp(-norm_error) | |
def compute_oks(gt, dt, use_area=True, per_kpt=False): | |
sigmas = np.array( | |
[.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89])/10.0 | |
vars = (sigmas * 2)**2 | |
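    # COCO OKS term per keypoint: exp(-d_i^2 / (2 * s^2 * k_i^2)), where
    # k_i = 2 * sigma_i and s^2 is the object area (or 0.53 * bbox area when
    # `use_area` is False).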
k = len(sigmas) | |
visibility_condition = lambda x: x > 0 | |
g = np.array(gt['keypoints']).reshape(k, 3) | |
xg = g[:, 0]; yg = g[:, 1]; vg = g[:, 2] | |
k1 = np.count_nonzero(visibility_condition(vg)) | |
bb = gt['bbox'] | |
x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2 | |
y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2 | |
d = np.array(dt['keypoints']).reshape((k, 3)) | |
xd = d[:, 0]; yd = d[:, 1] | |
if k1>0: | |
# measure the per-keypoint distance if keypoints visible | |
dx = xd - xg | |
dy = yd - yg | |
else: | |
# measure minimum distance to keypoints in (x0,y0) & (x1,y1) | |
z = np.zeros((k)) | |
dx = np.max((z, x0-xd),axis=0)+np.max((z, xd-x1),axis=0) | |
dy = np.max((z, y0-yd),axis=0)+np.max((z, yd-y1),axis=0) | |
if use_area: | |
e = (dx**2 + dy**2) / vars / (gt['area']+np.spacing(1)) / 2 | |
else: | |
tmparea = gt['bbox'][3] * gt['bbox'][2] * 0.53 | |
e = (dx**2 + dy**2) / vars / (tmparea+np.spacing(1)) / 2 | |
if per_kpt: | |
oks = np.exp(-e) | |
if k1 > 0: | |
oks[~visibility_condition(vg)] = 0 | |
else: | |
if k1 > 0: | |
e=e[visibility_condition(vg)] | |
oks = np.sum(np.exp(-e)) / e.shape[0] | |
return oks |
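

# ---------------------------------------------------------------------------
# Minimal sanity-check sketch (not part of the original pipeline; assumes the
# module's imports are installed): with identical ground-truth and predicted
# keypoints, `compute_oks` should return an OKS of exactly 1 for every visible
# keypoint. The 17 keypoint locations and the bbox below are arbitrary
# illustrative numbers.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    _rng = np.random.default_rng(0)
    _kpts = np.concatenate(
        [_rng.uniform(0, 48, size=(17, 2)),  # random (x, y) locations
         np.ones((17, 1))],                  # all keypoints marked visible
        axis=1)
    _ann = {
        'keypoints': _kpts,
        'bbox': np.array([0.0, 0.0, 64.0, 48.0]),  # (x, y, w, h)
        'area': 64.0 * 48.0,
    }
    print('per-keypoint OKS:',
          compute_oks(_ann, _ann, use_area=False, per_kpt=True))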