Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy | |
from typing import Union | |
import mmcv | |
import numpy as np | |
from mmengine.structures import InstanceData | |
from mmpose.datasets.datasets.utils import parse_pose_metainfo | |
from mmpose.structures import PoseDataSample | |
from mmpose.visualization import PoseLocalVisualizer | |
# from posevis import pose_visualization | |
# def visualize( | |
# img: Union[np.ndarray, str], | |
# keypoints: np.ndarray, | |
# keypoint_score: np.ndarray = None, | |
# metainfo: Union[str, dict] = None, | |
# visualizer: PoseLocalVisualizer = None, | |
# show_kpt_idx: bool = False, | |
# skeleton_style: str = 'mmpose', | |
# show: bool = False, | |
# kpt_thr: float = 0.3, | |
# ): | |
# """Visualize 2d keypoints on an image. | |
# Args: | |
# img (str | np.ndarray): The image to be displayed. | |
# keypoints (np.ndarray): The keypoint to be displayed. | |
# keypoint_score (np.ndarray): The score of each keypoint. | |
# metainfo (str | dict): The metainfo of dataset. | |
# visualizer (PoseLocalVisualizer): The visualizer. | |
# show_kpt_idx (bool): Whether to show the index of keypoints. | |
# skeleton_style (str): Skeleton style. Options are 'mmpose' and | |
# 'openpose'. | |
# show (bool): Whether to show the image. | |
# wait_time (int): Value of waitKey param. | |
# kpt_thr (float): Keypoint threshold. | |
# """ | |
# kpts = keypoints.reshape(-1, 2) | |
# kpts = np.concatenate([kpts, keypoint_score[:, None]], axis=1) | |
# kpts[kpts[:, 2] < kpt_thr, :] = 0 | |
# pose_results = [{ | |
# 'keypoints': kpts, | |
# }] | |
# img = pose_visualization( | |
# img, | |
# pose_results, | |
# format="COCO", | |
# greyness=1.0, | |
# show_markers=True, | |
# show_bones=True, | |
# line_type="solid", | |
# width_multiplier=1.0, | |
# bbox_width_multiplier=1.0, | |
# show_bbox=False, | |
# differ_individuals=False, | |
# ) | |
# return img | |
def visualize(
    img: Union[np.ndarray, str],
    keypoints: np.ndarray,
    keypoint_score: np.ndarray = None,
    metainfo: Union[str, dict] = None,
    visualizer: PoseLocalVisualizer = None,
    show_kpt_idx: bool = False,
    skeleton_style: str = 'mmpose',
    show: bool = False,
    kpt_thr: float = 0.3,
):
    """Visualize 2d keypoints on an image.

    Args:
        img (str | np.ndarray): The image to be displayed. A string is
            loaded with ``mmcv.imread`` in RGB channel order; an array is
            passed through ``mmcv.bgr2rgb`` before drawing.
        keypoints (np.ndarray): The keypoints to be displayed.
        keypoint_score (np.ndarray, optional): The score of each keypoint.
            When ``None``, an array of ones with length
            ``keypoints.shape[0]`` is used.
        metainfo (str | dict, optional): The metainfo of the dataset. A
            string is wrapped as ``dict(from_file=metainfo)`` and a dict is
            used as given; both go through ``parse_pose_metainfo``. When
            ``None``, the visualizer's dataset meta is left untouched.
        visualizer (PoseLocalVisualizer, optional): The visualizer. A deep
            copy of the given object is used; when ``None``, a new
            ``PoseLocalVisualizer`` is created.
        show_kpt_idx (bool): Whether to show the index of keypoints.
        skeleton_style (str): Skeleton style. Options are 'mmpose' and
            'openpose'.
        show (bool): Whether to show the image.
        kpt_thr (float): Keypoint threshold.

    Returns:
        The value of ``visualizer.get_image()`` after the sample is drawn.

    Raises:
        AssertionError: If ``skeleton_style`` is not a supported option.
    """
    supported_styles = ('mmpose', 'openpose')
    assert skeleton_style in supported_styles, (
        f'Only support skeleton style in {list(supported_styles)}, '
        f'but got {skeleton_style!r}')

    # Work on a private visualizer so the caller's instance is not touched
    # by the dataset-meta update or the drawing below.
    if visualizer is None:
        visualizer = PoseLocalVisualizer()
    else:
        visualizer = deepcopy(visualizer)

    # Normalise the two accepted metainfo forms into a parsed metainfo.
    if isinstance(metainfo, str):
        metainfo = parse_pose_metainfo(dict(from_file=metainfo))
    elif isinstance(metainfo, dict):
        metainfo = parse_pose_metainfo(metainfo)

    if metainfo is not None:
        visualizer.set_dataset_meta(metainfo, skeleton_style=skeleton_style)

    # The image is handed to the visualizer in RGB order: paths are read as
    # RGB directly, arrays are converted from BGR.
    if isinstance(img, str):
        img = mmcv.imread(img, channel_order='rgb')
    elif isinstance(img, np.ndarray):
        img = mmcv.bgr2rgb(img)

    if keypoint_score is None:
        keypoint_score = np.ones(keypoints.shape[0])

    # Wrap the raw arrays in the data-sample structure that
    # ``add_datasample`` takes as its third argument.
    tmp_instances = InstanceData()
    tmp_instances.keypoints = keypoints
    tmp_instances.keypoint_score = keypoint_score

    tmp_datasample = PoseDataSample()
    tmp_datasample.pred_instances = tmp_instances

    # ``wait_time`` is fixed at 0 here; it is not exposed as a parameter.
    visualizer.add_datasample(
        'visualization',
        img,
        tmp_datasample,
        show_kpt_idx=show_kpt_idx,
        skeleton_style=skeleton_style,
        show=show,
        wait_time=0,
        kpt_thr=kpt_thr)

    return visualizer.get_image()