Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) OpenMMLab. All rights reserved.
import cv2 | |
class FastVisualizer:
    """MMPose Fast Visualizer.

    A simple yet fast visualizer for video/webcam inference.

    Args:
        metainfo (dict): Pose meta information. The keys
            ``keypoint_id2name``, ``keypoint_name2id``, ``keypoint_colors``,
            ``skeleton_links`` and ``skeleton_link_colors`` are read at
            construction time; a missing key raises ``KeyError``.
        radius (int, optional): Keypoint radius for visualization.
            Defaults to 6.
        line_width (int, optional): Link width for visualization.
            Defaults to 3.
        kpt_thr (float, optional): Threshold for keypoints' confidence score,
            keypoints with score below this value will not be drawn.
            Defaults to 0.3.
    """

    def __init__(self, metainfo, radius=6, line_width=3, kpt_thr=0.3):
        self.radius = radius
        self.line_width = line_width
        self.kpt_thr = kpt_thr

        # Keep direct references to the pieces of the meta information that
        # drawing needs, so ``draw_pose`` does no dict lookups per frame.
        self.keypoint_id2name = metainfo['keypoint_id2name']
        self.keypoint_name2id = metainfo['keypoint_name2id']
        self.keypoint_colors = metainfo['keypoint_colors']
        self.skeleton_links = metainfo['skeleton_links']
        self.skeleton_link_colors = metainfo['skeleton_link_colors']

    def draw_pose(self, img, instances):
        """Draw pose estimations on the given image.

        For every instance, skeleton links are drawn first and keypoints
        are drawn on top of them. A link is drawn only when both of its
        endpoints reach ``kpt_thr``; a keypoint is drawn only when its own
        score reaches ``kpt_thr``.

        Args:
            img (numpy.ndarray): The input image on which to
                draw the pose estimations.
            instances (object): An object containing detected instances'
                information. Its ``keypoints`` and ``keypoint_scores``
                attributes are iterated in lockstep, one entry per
                instance. Each ``keypoints`` entry must support
                ``kpts[index, axis]`` indexing (presumably a
                ``(num_keypoints, 2)`` array -- confirm against the caller).

        Returns:
            None: The input image will be modified in place.
        """
        if instances is None:
            print('no instance detected')
            return

        keypoints = instances.keypoints
        scores = instances.keypoint_scores

        for kpts, score in zip(keypoints, scores):
            # Links go first so the keypoint discs end up on top of them.
            for sk_id, sk in enumerate(self.skeleton_links):
                start, end = sk[0], sk[1]
                if score[start] < self.kpt_thr or score[end] < self.kpt_thr:
                    # skip the link that should not be drawn
                    continue

                pos1 = (int(kpts[start, 0]), int(kpts[start, 1]))
                pos2 = (int(kpts[end, 0]), int(kpts[end, 1]))

                # Colour entries must expose ``tolist()`` (e.g. numpy rows).
                color = self.skeleton_link_colors[sk_id].tolist()
                cv2.line(img, pos1, pos2, color, thickness=self.line_width)

            for kid, kpt in enumerate(kpts):
                if score[kid] < self.kpt_thr:
                    # skip the point that should not be drawn
                    continue

                # Convert to integer pixel coordinates once and reuse the
                # point for both the filled disc and its outline.
                center = (int(kpt[0]), int(kpt[1]))

                color = self.keypoint_colors[kid].tolist()
                # A thickness of -1 asks OpenCV for a filled circle.
                cv2.circle(img, center, self.radius, color, -1)
                # White outline drawn with OpenCV's default thickness.
                cv2.circle(img, center, self.radius, (255, 255, 255))