File size: 4,243 Bytes
a249588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Union

import mmcv
import numpy as np
from mmengine.structures import InstanceData

from mmpose.datasets.datasets.utils import parse_pose_metainfo
from mmpose.structures import PoseDataSample
from mmpose.visualization import PoseLocalVisualizer

# from posevis import pose_visualization

# def visualize(
#     img: Union[np.ndarray, str],
#     keypoints: np.ndarray,
#     keypoint_score: np.ndarray = None,
#     metainfo: Union[str, dict] = None,
#     visualizer: PoseLocalVisualizer = None,
#     show_kpt_idx: bool = False,
#     skeleton_style: str = 'mmpose',
#     show: bool = False,
#     kpt_thr: float = 0.3,
# ):
#     """Visualize 2d keypoints on an image.

#     Args:
#         img (str | np.ndarray): The image to be displayed.
#         keypoints (np.ndarray): The keypoint to be displayed.
#         keypoint_score (np.ndarray): The score of each keypoint.
#         metainfo (str | dict): The metainfo of dataset.
#         visualizer (PoseLocalVisualizer): The visualizer.
#         show_kpt_idx (bool): Whether to show the index of keypoints.
#         skeleton_style (str): Skeleton style. Options are 'mmpose' and
#             'openpose'.
#         show (bool): Whether to show the image.
#         wait_time (int): Value of waitKey param.
#         kpt_thr (float): Keypoint threshold.
#     """
#     kpts = keypoints.reshape(-1, 2)
#     kpts = np.concatenate([kpts, keypoint_score[:, None]], axis=1)
#     kpts[kpts[:, 2] < kpt_thr, :] = 0
#     pose_results = [{
#         'keypoints': kpts,
#     }]

#     img = pose_visualization(
#         img,
#         pose_results,
#         format="COCO",
#         greyness=1.0,
#         show_markers=True,
#         show_bones=True,
#         line_type="solid",
#         width_multiplier=1.0,
#         bbox_width_multiplier=1.0,
#         show_bbox=False,
#         differ_individuals=False,
#     )
#     return img


def visualize(
    img: Union[np.ndarray, str],
    keypoints: np.ndarray,
    keypoint_score: np.ndarray = None,
    metainfo: Union[str, dict] = None,
    visualizer: PoseLocalVisualizer = None,
    show_kpt_idx: bool = False,
    skeleton_style: str = 'mmpose',
    show: bool = False,
    kpt_thr: float = 0.3,
    wait_time: int = 0,
):
    """Visualize 2d keypoints on an image.

    Args:
        img (str | np.ndarray): The image to be displayed. A string is read
            from disk in RGB order; an array is assumed to be BGR and is
            converted to RGB before drawing.
        keypoints (np.ndarray): The keypoints to be displayed, with the
            coordinate axis last (presumably ``(N, K, 2)`` for N instances
            of K keypoints each -- confirm against callers).
        keypoint_score (np.ndarray, optional): The score of each keypoint,
            shaped like ``keypoints`` without its trailing coordinate axis.
            Defaults to all ones.
        metainfo (str | dict, optional): The dataset metainfo, either a path
            to a metainfo file or an already loaded dict. If None, the
            visualizer's dataset meta is not changed.
        visualizer (PoseLocalVisualizer, optional): Visualizer to draw with.
            It is deep-copied, so the caller's instance is not modified. A
            new ``PoseLocalVisualizer()`` is created when None.
        show_kpt_idx (bool): Whether to show the index of keypoints.
        skeleton_style (str): Skeleton style. Options are 'mmpose' and
            'openpose'.
        show (bool): Whether to show the image.
        kpt_thr (float): Keypoint score threshold.
        wait_time (int): Value of the waitKey param passed to
            ``add_datasample``. Defaults to 0.

    Returns:
        The rendered image, as returned by ``visualizer.get_image()``.
    """
    valid_styles = ('mmpose', 'openpose')
    assert skeleton_style in valid_styles, (
        f'Only support skeleton style in {list(valid_styles)}, '
        f'but got {skeleton_style!r}')

    if visualizer is None:
        visualizer = PoseLocalVisualizer()
    else:
        # Work on a copy so the dataset meta and canvas set below do not
        # leak into the caller's visualizer.
        visualizer = deepcopy(visualizer)

    if isinstance(metainfo, str):
        metainfo = parse_pose_metainfo(dict(from_file=metainfo))
    elif isinstance(metainfo, dict):
        metainfo = parse_pose_metainfo(metainfo)

    if metainfo is not None:
        visualizer.set_dataset_meta(metainfo, skeleton_style=skeleton_style)

    if isinstance(img, str):
        img = mmcv.imread(img, channel_order='rgb')
    elif isinstance(img, np.ndarray):
        # In-memory arrays are treated as BGR and converted to RGB.
        img = mmcv.bgr2rgb(img)

    if keypoint_score is None:
        # One score per keypoint: the shape of ``keypoints`` minus its
        # trailing coordinate axis. (``shape[0]`` alone produced one score
        # per instance, which cannot be indexed per keypoint.)
        keypoint_score = np.ones(keypoints.shape[:-1])

    tmp_instances = InstanceData()
    tmp_instances.keypoints = keypoints
    # PoseLocalVisualizer looks scores up under ``keypoint_scores``; stored
    # only under the singular name they were silently ignored and kpt_thr
    # never filtered anything. The singular field is kept for any code that
    # may still read it.
    tmp_instances.keypoint_scores = keypoint_score
    tmp_instances.keypoint_score = keypoint_score

    tmp_datasample = PoseDataSample()
    tmp_datasample.pred_instances = tmp_instances

    visualizer.add_datasample(
        'visualization',
        img,
        tmp_datasample,
        show_kpt_idx=show_kpt_idx,
        skeleton_style=skeleton_style,
        show=show,
        wait_time=wait_time,
        kpt_thr=kpt_thr)

    return visualizer.get_image()