Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,243 Bytes
a249588 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Union
import mmcv
import numpy as np
from mmengine.structures import InstanceData
from mmpose.datasets.datasets.utils import parse_pose_metainfo
from mmpose.structures import PoseDataSample
from mmpose.visualization import PoseLocalVisualizer
# from posevis import pose_visualization
# def visualize(
# img: Union[np.ndarray, str],
# keypoints: np.ndarray,
# keypoint_score: np.ndarray = None,
# metainfo: Union[str, dict] = None,
# visualizer: PoseLocalVisualizer = None,
# show_kpt_idx: bool = False,
# skeleton_style: str = 'mmpose',
# show: bool = False,
# kpt_thr: float = 0.3,
# ):
# """Visualize 2d keypoints on an image.
# Args:
# img (str | np.ndarray): The image to be displayed.
# keypoints (np.ndarray): The keypoint to be displayed.
# keypoint_score (np.ndarray): The score of each keypoint.
# metainfo (str | dict): The metainfo of dataset.
# visualizer (PoseLocalVisualizer): The visualizer.
# show_kpt_idx (bool): Whether to show the index of keypoints.
# skeleton_style (str): Skeleton style. Options are 'mmpose' and
# 'openpose'.
# show (bool): Whether to show the image.
# wait_time (int): Value of waitKey param.
# kpt_thr (float): Keypoint threshold.
# """
# kpts = keypoints.reshape(-1, 2)
# kpts = np.concatenate([kpts, keypoint_score[:, None]], axis=1)
# kpts[kpts[:, 2] < kpt_thr, :] = 0
# pose_results = [{
# 'keypoints': kpts,
# }]
# img = pose_visualization(
# img,
# pose_results,
# format="COCO",
# greyness=1.0,
# show_markers=True,
# show_bones=True,
# line_type="solid",
# width_multiplier=1.0,
# bbox_width_multiplier=1.0,
# show_bbox=False,
# differ_individuals=False,
# )
# return img
def visualize(
    img: Union[np.ndarray, str],
    keypoints: np.ndarray,
    keypoint_score: np.ndarray = None,
    metainfo: Union[str, dict] = None,
    visualizer: PoseLocalVisualizer = None,
    show_kpt_idx: bool = False,
    skeleton_style: str = 'mmpose',
    show: bool = False,
    kpt_thr: float = 0.3,
    wait_time: int = 0,
):
    """Visualize 2d keypoints on an image.

    Args:
        img (str | np.ndarray): The image to be displayed. A path is read
            with RGB channel order; an array is assumed to be BGR and is
            converted to RGB before drawing.
        keypoints (np.ndarray): The keypoints to be displayed, presumably
            shaped ``(num_instances, num_keypoints, 2)`` as consumed by
            ``PoseLocalVisualizer`` (TODO confirm against callers).
        keypoint_score (np.ndarray, optional): The score of each keypoint,
            shaped like ``keypoints`` without its last axis. Defaults to
            all ones.
        metainfo (str | dict, optional): The dataset metainfo, given either
            as a file path (loaded via ``from_file``) or as a dict. If None,
            the visualizer's dataset meta is not updated.
        visualizer (PoseLocalVisualizer, optional): The visualizer to draw
            with. It is deep-copied, so the caller's instance is not
            modified. A new ``PoseLocalVisualizer`` is created if None.
        show_kpt_idx (bool): Whether to show the index of keypoints.
        skeleton_style (str): Skeleton style. Options are 'mmpose' and
            'openpose'.
        show (bool): Whether to show the image.
        kpt_thr (float): Keypoint score threshold.
        wait_time (int): Value of the waitKey param, passed through to
            ``add_datasample``. Defaults to 0.

    Returns:
        np.ndarray: The drawn image, as returned by
        ``visualizer.get_image()``.
    """
    supported_styles = ['mmpose', 'openpose']
    assert skeleton_style in supported_styles, (
        f'Only support skeleton style in {supported_styles}, '
        f'but got {skeleton_style!r}')

    # Work on a copy so that setting dataset meta and drawing never mutate
    # a visualizer owned by the caller.
    if visualizer is None:
        visualizer = PoseLocalVisualizer()
    else:
        visualizer = deepcopy(visualizer)

    # A string is treated as a metainfo file; a dict is parsed directly.
    if isinstance(metainfo, str):
        metainfo = parse_pose_metainfo(dict(from_file=metainfo))
    elif isinstance(metainfo, dict):
        metainfo = parse_pose_metainfo(metainfo)

    if metainfo is not None:
        visualizer.set_dataset_meta(metainfo, skeleton_style=skeleton_style)

    # Both branches hand the visualizer an RGB image; in-memory arrays are
    # assumed to arrive in BGR order (TODO confirm with callers).
    if isinstance(img, str):
        img = mmcv.imread(img, channel_order='rgb')
    elif isinstance(img, np.ndarray):
        img = mmcv.bgr2rgb(img)

    if keypoint_score is None:
        # One score per keypoint: drop only the trailing coordinate axis so
        # the visualizer can threshold each keypoint individually. For 2-D
        # input this equals the previous ``keypoints.shape[0]`` default.
        keypoint_score = np.ones(keypoints.shape[:-1])

    # Wrap the raw arrays as prediction instances so add_datasample draws
    # them.
    tmp_instances = InstanceData()
    tmp_instances.keypoints = keypoints
    tmp_instances.keypoint_score = keypoint_score

    tmp_datasample = PoseDataSample()
    tmp_datasample.pred_instances = tmp_instances

    visualizer.add_datasample(
        'visualization',
        img,
        tmp_datasample,
        show_kpt_idx=show_kpt_idx,
        skeleton_style=skeleton_style,
        show=show,
        wait_time=wait_time,
        kpt_thr=kpt_thr)

    return visualizer.get_image()
|