Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) OpenMMLab. All rights reserved. | |
from itertools import product | |
from typing import Any, List, Optional, Tuple | |
import numpy as np | |
import torch | |
from munkres import Munkres | |
from torch import Tensor | |
from mmpose.registry import KEYPOINT_CODECS | |
from mmpose.utils.tensor_utils import to_numpy | |
from .base import BaseKeypointCodec | |
from .utils import (batch_heatmap_nms, generate_gaussian_heatmaps, | |
generate_udp_gaussian_heatmaps, refine_keypoints, | |
refine_keypoints_dark_udp) | |
def _py_max_match(scores): | |
"""Apply munkres algorithm to get the best match. | |
Args: | |
scores(np.ndarray): cost matrix. | |
Returns: | |
np.ndarray: best match. | |
""" | |
m = Munkres() | |
tmp = m.compute(scores) | |
tmp = np.array(tmp).astype(int) | |
return tmp | |
def _group_keypoints_by_tags(vals: np.ndarray, | |
tags: np.ndarray, | |
locs: np.ndarray, | |
keypoint_order: List[int], | |
val_thr: float, | |
tag_thr: float = 1.0, | |
max_groups: Optional[int] = None) -> np.ndarray: | |
"""Group the keypoints by tags using Munkres algorithm. | |
Note: | |
- keypoint number: K | |
- candidate number: M | |
- tag dimenssion: L | |
- coordinate dimension: D | |
- group number: G | |
Args: | |
vals (np.ndarray): The heatmap response values of keypoints in shape | |
(K, M) | |
tags (np.ndarray): The tags of the keypoint candidates in shape | |
(K, M, L) | |
locs (np.ndarray): The locations of the keypoint candidates in shape | |
(K, M, D) | |
keypoint_order (List[int]): The grouping order of the keypoints. | |
The groupping usually starts from a keypoints around the head and | |
torso, and gruadually moves out to the limbs | |
val_thr (float): The threshold of the keypoint response value | |
tag_thr (float): The maximum allowed tag distance when matching a | |
keypoint to a group. A keypoint with larger tag distance to any | |
of the existing groups will initializes a new group | |
max_groups (int, optional): The maximum group number. ``None`` means | |
no limitation. Defaults to ``None`` | |
Returns: | |
np.ndarray: grouped keypoints in shape (G, K, D+1), where the last | |
dimenssion is the concatenated keypoint coordinates and scores. | |
""" | |
tag_k, loc_k, val_k = tags, locs, vals | |
K, M, D = locs.shape | |
assert vals.shape == tags.shape[:2] == (K, M) | |
assert len(keypoint_order) == K | |
default_ = np.zeros((K, 3 + tag_k.shape[2]), dtype=np.float32) | |
joint_dict = {} | |
tag_dict = {} | |
for i in range(K): | |
idx = keypoint_order[i] | |
tags = tag_k[idx] | |
joints = np.concatenate((loc_k[idx], val_k[idx, :, None], tags), 1) | |
mask = joints[:, 2] > val_thr | |
tags = tags[mask] # shape: [M, L] | |
joints = joints[mask] # shape: [M, 3 + L], 3: x, y, val | |
if joints.shape[0] == 0: | |
continue | |
if i == 0 or len(joint_dict) == 0: | |
for tag, joint in zip(tags, joints): | |
key = tag[0] | |
joint_dict.setdefault(key, np.copy(default_))[idx] = joint | |
tag_dict[key] = [tag] | |
else: | |
# shape: [M] | |
grouped_keys = list(joint_dict.keys()) | |
# shape: [M, L] | |
grouped_tags = [np.mean(tag_dict[i], axis=0) for i in grouped_keys] | |
# shape: [M, M, L] | |
diff = joints[:, None, 3:] - np.array(grouped_tags)[None, :, :] | |
# shape: [M, M] | |
diff_normed = np.linalg.norm(diff, ord=2, axis=2) | |
diff_saved = np.copy(diff_normed) | |
diff_normed = np.round(diff_normed) * 100 - joints[:, 2:3] | |
num_added = diff.shape[0] | |
num_grouped = diff.shape[1] | |
if num_added > num_grouped: | |
diff_normed = np.concatenate( | |
(diff_normed, | |
np.zeros((num_added, num_added - num_grouped), | |
dtype=np.float32) + 1e10), | |
axis=1) | |
pairs = _py_max_match(diff_normed) | |
for row, col in pairs: | |
if (row < num_added and col < num_grouped | |
and diff_saved[row][col] < tag_thr): | |
key = grouped_keys[col] | |
joint_dict[key][idx] = joints[row] | |
tag_dict[key].append(tags[row]) | |
else: | |
key = tags[row][0] | |
joint_dict.setdefault(key, np.copy(default_))[idx] = \ | |
joints[row] | |
tag_dict[key] = [tags[row]] | |
joint_dict_keys = list(joint_dict.keys())[:max_groups] | |
if joint_dict_keys: | |
results = np.array([joint_dict[i] | |
for i in joint_dict_keys]).astype(np.float32) | |
results = results[..., :D + 1] | |
else: | |
results = np.empty((0, K, D + 1), dtype=np.float32) | |
return results | |
class AssociativeEmbedding(BaseKeypointCodec): | |
"""Encode/decode keypoints with the method introduced in "Associative | |
Embedding". This is an asymmetric codec, where the keypoints are | |
represented as gaussian heatmaps and position indices during encoding, and | |
restored from predicted heatmaps and group tags. | |
See the paper `Associative Embedding: End-to-End Learning for Joint | |
Detection and Grouping`_ by Newell et al (2017) for details | |
Note: | |
- instance number: N | |
- keypoint number: K | |
- keypoint dimension: D | |
- embedding tag dimension: L | |
- image size: [w, h] | |
- heatmap size: [W, H] | |
Encoded: | |
- heatmaps (np.ndarray): The generated heatmap in shape (K, H, W) | |
where [W, H] is the `heatmap_size` | |
- keypoint_indices (np.ndarray): The keypoint position indices in shape | |
(N, K, 2). Each keypoint's index is [i, v], where i is the position | |
index in the heatmap (:math:`i=y*w+x`) and v is the visibility | |
- keypoint_weights (np.ndarray): The target weights in shape (N, K) | |
Args: | |
input_size (tuple): Image size in [w, h] | |
heatmap_size (tuple): Heatmap size in [W, H] | |
sigma (float): The sigma value of the Gaussian heatmap | |
use_udp (bool): Whether use unbiased data processing. See | |
`UDP (CVPR 2020)`_ for details. Defaults to ``False`` | |
decode_keypoint_order (List[int]): The grouping order of the | |
keypoint indices. The groupping usually starts from a keypoints | |
around the head and torso, and gruadually moves out to the limbs | |
decode_keypoint_thr (float): The threshold of keypoint response value | |
in heatmaps. Defaults to 0.1 | |
decode_tag_thr (float): The maximum allowed tag distance when matching | |
a keypoint to a group. A keypoint with larger tag distance to any | |
of the existing groups will initializes a new group. Defaults to | |
1.0 | |
decode_nms_kernel (int): The kernel size of the NMS during decoding, | |
which should be an odd integer. Defaults to 5 | |
decode_gaussian_kernel (int): The kernel size of the Gaussian blur | |
during decoding, which should be an odd integer. It is only used | |
when ``self.use_udp==True``. Defaults to 3 | |
decode_topk (int): The number top-k candidates of each keypoints that | |
will be retrieved from the heatmaps during dedocding. Defaults to | |
20 | |
decode_max_instances (int, optional): The maximum number of instances | |
to decode. ``None`` means no limitation to the instance number. | |
Defaults to ``None`` | |
.. _`Associative Embedding: End-to-End Learning for Joint Detection and | |
Grouping`: https://arxiv.org/abs/1611.05424 | |
.. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524 | |
""" | |
def __init__( | |
self, | |
input_size: Tuple[int, int], | |
heatmap_size: Tuple[int, int], | |
sigma: Optional[float] = None, | |
use_udp: bool = False, | |
decode_keypoint_order: List[int] = [], | |
decode_nms_kernel: int = 5, | |
decode_gaussian_kernel: int = 3, | |
decode_keypoint_thr: float = 0.1, | |
decode_tag_thr: float = 1.0, | |
decode_topk: int = 30, | |
decode_center_shift=0.0, | |
decode_max_instances: Optional[int] = None, | |
) -> None: | |
super().__init__() | |
self.input_size = input_size | |
self.heatmap_size = heatmap_size | |
self.use_udp = use_udp | |
self.decode_nms_kernel = decode_nms_kernel | |
self.decode_gaussian_kernel = decode_gaussian_kernel | |
self.decode_keypoint_thr = decode_keypoint_thr | |
self.decode_tag_thr = decode_tag_thr | |
self.decode_topk = decode_topk | |
self.decode_center_shift = decode_center_shift | |
self.decode_max_instances = decode_max_instances | |
self.decode_keypoint_order = decode_keypoint_order.copy() | |
if self.use_udp: | |
self.scale_factor = ((np.array(input_size) - 1) / | |
(np.array(heatmap_size) - 1)).astype( | |
np.float32) | |
else: | |
self.scale_factor = (np.array(input_size) / | |
heatmap_size).astype(np.float32) | |
if sigma is None: | |
sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 64 | |
self.sigma = sigma | |
def encode( | |
self, | |
keypoints: np.ndarray, | |
keypoints_visible: Optional[np.ndarray] = None | |
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | |
"""Encode keypoints into heatmaps and position indices. Note that the | |
original keypoint coordinates should be in the input image space. | |
Args: | |
keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) | |
keypoints_visible (np.ndarray): Keypoint visibilities in shape | |
(N, K) | |
Returns: | |
dict: | |
- heatmaps (np.ndarray): The generated heatmap in shape | |
(K, H, W) where [W, H] is the `heatmap_size` | |
- keypoint_indices (np.ndarray): The keypoint position indices | |
in shape (N, K, 2). Each keypoint's index is [i, v], where i | |
is the position index in the heatmap (:math:`i=y*w+x`) and v | |
is the visibility | |
- keypoint_weights (np.ndarray): The target weights in shape | |
(N, K) | |
""" | |
if keypoints_visible is None: | |
keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) | |
# keypoint coordinates in heatmap | |
_keypoints = keypoints / self.scale_factor | |
if self.use_udp: | |
heatmaps, keypoint_weights = generate_udp_gaussian_heatmaps( | |
heatmap_size=self.heatmap_size, | |
keypoints=_keypoints, | |
keypoints_visible=keypoints_visible, | |
sigma=self.sigma) | |
else: | |
heatmaps, keypoint_weights = generate_gaussian_heatmaps( | |
heatmap_size=self.heatmap_size, | |
keypoints=_keypoints, | |
keypoints_visible=keypoints_visible, | |
sigma=self.sigma) | |
keypoint_indices = self._encode_keypoint_indices( | |
heatmap_size=self.heatmap_size, | |
keypoints=_keypoints, | |
keypoints_visible=keypoints_visible) | |
encoded = dict( | |
heatmaps=heatmaps, | |
keypoint_indices=keypoint_indices, | |
keypoint_weights=keypoint_weights) | |
return encoded | |
def _encode_keypoint_indices(self, heatmap_size: Tuple[int, int], | |
keypoints: np.ndarray, | |
keypoints_visible: np.ndarray) -> np.ndarray: | |
w, h = heatmap_size | |
N, K, _ = keypoints.shape | |
keypoint_indices = np.zeros((N, K, 2), dtype=np.int64) | |
for n, k in product(range(N), range(K)): | |
x, y = (keypoints[n, k] + 0.5).astype(np.int64) | |
index = y * w + x | |
vis = (keypoints_visible[n, k] > 0.5 and 0 <= x < w and 0 <= y < h) | |
keypoint_indices[n, k] = [index, vis] | |
return keypoint_indices | |
def decode(self, encoded: Any) -> Tuple[np.ndarray, np.ndarray]: | |
raise NotImplementedError() | |
def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor, | |
k: int): | |
"""Get top-k response values from the heatmaps and corresponding tag | |
values from the tagging heatmaps. | |
Args: | |
batch_heatmaps (Tensor): Keypoint detection heatmaps in shape | |
(B, K, H, W) | |
batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where | |
the tag dim C is 2*K when using flip testing, or K otherwise | |
k (int): The number of top responses to get | |
Returns: | |
tuple: | |
- topk_vals (Tensor): Top-k response values of each heatmap in | |
shape (B, K, Topk) | |
- topk_tags (Tensor): The corresponding embedding tags of the | |
top-k responses, in shape (B, K, Topk, L) | |
- topk_locs (Tensor): The location of the top-k responses in each | |
heatmap, in shape (B, K, Topk, 2) where last dimension | |
represents x and y coordinates | |
""" | |
B, K, H, W = batch_heatmaps.shape | |
L = batch_tags.shape[1] // K | |
# shape of topk_val, top_indices: (B, K, TopK) | |
topk_vals, topk_indices = batch_heatmaps.flatten(-2, -1).topk( | |
k, dim=-1) | |
topk_tags_per_kpts = [ | |
torch.gather(_tag, dim=2, index=topk_indices) | |
for _tag in torch.unbind(batch_tags.view(B, L, K, H * W), dim=1) | |
] | |
topk_tags = torch.stack(topk_tags_per_kpts, dim=-1) # (B, K, TopK, L) | |
topk_locs = torch.stack([topk_indices % W, topk_indices // W], | |
dim=-1) # (B, K, TopK, 2) | |
return topk_vals, topk_tags, topk_locs | |
def _group_keypoints(self, batch_vals: np.ndarray, batch_tags: np.ndarray, | |
batch_locs: np.ndarray): | |
"""Group keypoints into groups (each represents an instance) by tags. | |
Args: | |
batch_vals (Tensor): Heatmap response values of keypoint | |
candidates in shape (B, K, Topk) | |
batch_tags (Tensor): Tags of keypoint candidates in shape | |
(B, K, Topk, L) | |
batch_locs (Tensor): Locations of keypoint candidates in shape | |
(B, K, Topk, 2) | |
Returns: | |
List[np.ndarray]: Grouping results of a batch, each element is a | |
np.ndarray (in shape [N, K, D+1]) that contains the groups | |
detected in an image, including both keypoint coordinates and | |
scores. | |
""" | |
def _group_func(inputs: Tuple): | |
vals, tags, locs = inputs | |
return _group_keypoints_by_tags( | |
vals, | |
tags, | |
locs, | |
keypoint_order=self.decode_keypoint_order, | |
val_thr=self.decode_keypoint_thr, | |
tag_thr=self.decode_tag_thr, | |
max_groups=self.decode_max_instances) | |
_results = map(_group_func, zip(batch_vals, batch_tags, batch_locs)) | |
results = list(_results) | |
return results | |
def _fill_missing_keypoints(self, keypoints: np.ndarray, | |
keypoint_scores: np.ndarray, | |
heatmaps: np.ndarray, tags: np.ndarray): | |
"""Fill the missing keypoints in the initial predictions. | |
Args: | |
keypoints (np.ndarray): Keypoint predictions in shape (N, K, D) | |
keypoint_scores (np.ndarray): Keypint score predictions in shape | |
(N, K), in which 0 means the corresponding keypoint is | |
missing in the initial prediction | |
heatmaps (np.ndarry): Heatmaps in shape (K, H, W) | |
tags (np.ndarray): Tagging heatmaps in shape (C, H, W) where | |
C=L*K | |
Returns: | |
tuple: | |
- keypoints (np.ndarray): Keypoint predictions with missing | |
ones filled | |
- keypoint_scores (np.ndarray): Keypoint score predictions with | |
missing ones filled | |
""" | |
N, K = keypoints.shape[:2] | |
H, W = heatmaps.shape[1:] | |
L = tags.shape[0] // K | |
keypoint_tags = [tags[k::K] for k in range(K)] | |
for n in range(N): | |
# Calculate the instance tag (mean tag of detected keypoints) | |
_tag = [] | |
for k in range(K): | |
if keypoint_scores[n, k] > 0: | |
x, y = keypoints[n, k, :2].astype(np.int64) | |
x = np.clip(x, 0, W - 1) | |
y = np.clip(y, 0, H - 1) | |
_tag.append(keypoint_tags[k][:, y, x]) | |
tag = np.mean(_tag, axis=0) | |
tag = tag.reshape(L, 1, 1) | |
# Search maximum response of the missing keypoints | |
for k in range(K): | |
if keypoint_scores[n, k] > 0: | |
continue | |
dist_map = np.linalg.norm( | |
keypoint_tags[k] - tag, ord=2, axis=0) | |
cost_map = np.round(dist_map) * 100 - heatmaps[k] # H, W | |
y, x = np.unravel_index(np.argmin(cost_map), shape=(H, W)) | |
keypoints[n, k] = [x, y] | |
keypoint_scores[n, k] = heatmaps[k, y, x] | |
return keypoints, keypoint_scores | |
def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor | |
) -> Tuple[List[np.ndarray], List[np.ndarray]]: | |
"""Decode the keypoint coordinates from a batch of heatmaps and tagging | |
heatmaps. The decoded keypoint coordinates are in the input image | |
space. | |
Args: | |
batch_heatmaps (Tensor): Keypoint detection heatmaps in shape | |
(B, K, H, W) | |
batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where | |
:math:`C=L*K` | |
Returns: | |
tuple: | |
- batch_keypoints (List[np.ndarray]): Decoded keypoint coordinates | |
of the batch, each is in shape (N, K, D) | |
- batch_scores (List[np.ndarray]): Decoded keypoint scores of the | |
batch, each is in shape (N, K). It usually represents the | |
confidience of the keypoint prediction | |
""" | |
B, _, H, W = batch_heatmaps.shape | |
assert batch_tags.shape[0] == B and batch_tags.shape[2:4] == (H, W), ( | |
f'Mismatched shapes of heatmap ({batch_heatmaps.shape}) and ' | |
f'tagging map ({batch_tags.shape})') | |
# Heatmap NMS | |
batch_heatmaps_peak = batch_heatmap_nms(batch_heatmaps, | |
self.decode_nms_kernel) | |
# Get top-k in each heatmap and and convert to numpy | |
batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy( | |
self._get_batch_topk( | |
batch_heatmaps_peak, batch_tags, k=self.decode_topk)) | |
# Group keypoint candidates into groups (instances) | |
batch_groups = self._group_keypoints(batch_topk_vals, batch_topk_tags, | |
batch_topk_locs) | |
# Convert to numpy | |
batch_heatmaps_np = to_numpy(batch_heatmaps) | |
batch_tags_np = to_numpy(batch_tags) | |
# Refine the keypoint prediction | |
batch_keypoints = [] | |
batch_keypoint_scores = [] | |
batch_instance_scores = [] | |
for i, (groups, heatmaps, tags) in enumerate( | |
zip(batch_groups, batch_heatmaps_np, batch_tags_np)): | |
keypoints, scores = groups[..., :-1], groups[..., -1] | |
instance_scores = scores.mean(axis=-1) | |
if keypoints.size > 0: | |
# refine keypoint coordinates according to heatmap distribution | |
if self.use_udp: | |
keypoints = refine_keypoints_dark_udp( | |
keypoints, | |
heatmaps, | |
blur_kernel_size=self.decode_gaussian_kernel) | |
else: | |
keypoints = refine_keypoints(keypoints, heatmaps) | |
keypoints += self.decode_center_shift * \ | |
(scores > 0).astype(keypoints.dtype)[..., None] | |
# identify missing keypoints | |
keypoints, scores = self._fill_missing_keypoints( | |
keypoints, scores, heatmaps, tags) | |
batch_keypoints.append(keypoints) | |
batch_keypoint_scores.append(scores) | |
batch_instance_scores.append(instance_scores) | |
# restore keypoint scale | |
batch_keypoints = [ | |
kpts * self.scale_factor for kpts in batch_keypoints | |
] | |
return batch_keypoints, batch_keypoint_scores, batch_instance_scores | |