Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,914 Bytes
a249588 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
class FastVisualizer:
    """MMPose Fast Visualizer.

    A simple yet fast visualizer for video/webcam inference. Skeleton links
    and keypoints are drawn straight onto the given image with OpenCV.

    Args:
        metainfo (dict): Pose meta information. The keys
            ``keypoint_id2name``, ``keypoint_name2id``, ``keypoint_colors``,
            ``skeleton_links`` and ``skeleton_link_colors`` are read and
            must be present.
        radius (int, optional): Keypoint radius for visualization.
            Defaults to 6.
        line_width (int, optional): Link width for visualization.
            Defaults to 3.
        kpt_thr (float, optional): Threshold for keypoints' confidence score,
            keypoints with score below this value will not be drawn.
            Defaults to 0.3.

    Raises:
        KeyError: If ``metainfo`` lacks one of the required keys.
    """

    def __init__(self, metainfo, radius=6, line_width=3, kpt_thr=0.3):
        self.radius = radius
        self.line_width = line_width
        self.kpt_thr = kpt_thr

        self.keypoint_id2name = metainfo['keypoint_id2name']
        self.keypoint_name2id = metainfo['keypoint_name2id']
        self.keypoint_colors = metainfo['keypoint_colors']
        self.skeleton_links = metainfo['skeleton_links']
        self.skeleton_link_colors = metainfo['skeleton_link_colors']

    @staticmethod
    def _to_color(color):
        """Return one color entry as a plain list that OpenCV accepts.

        Array-like entries exposing ``tolist()`` (e.g. a numpy row) are
        converted through it, exactly as before; plain lists/tuples are
        copied into a list so metainfo without numpy colors also works.
        """
        to_list = getattr(color, 'tolist', None)
        return to_list() if callable(to_list) else list(color)

    def draw_pose(self, img, instances):
        """Draw pose estimations on the given image.

        For every instance, each skeleton link whose two endpoints both
        reach ``kpt_thr`` is drawn as a line, then every keypoint reaching
        ``kpt_thr`` is drawn as a filled circle with a white outline.

        Args:
            img (numpy.ndarray): The input image on which to
                draw the pose estimations.
            instances (object): An object containing detected instances'
                information, exposing ``keypoints`` and
                ``keypoint_scores``. Each instance's keypoints are indexed
                as ``kpts[i, 0]`` / ``kpts[i, 1]`` for x / y. When ``None``,
                a message is printed and nothing is drawn.

        Returns:
            None: The input image will be modified in place.
        """
        if instances is None:
            print('no instance detected')
            return

        keypoints = instances.keypoints
        scores = instances.keypoint_scores

        for kpts, score in zip(keypoints, scores):
            # Links first, so the keypoint circles end up on top of them.
            for sk_id, sk in enumerate(self.skeleton_links):
                start, end = sk[0], sk[1]
                if score[start] < self.kpt_thr or score[end] < self.kpt_thr:
                    # skip the link that should not be drawn
                    continue

                pos1 = (int(kpts[start, 0]), int(kpts[start, 1]))
                pos2 = (int(kpts[end, 0]), int(kpts[end, 1]))

                color = self._to_color(self.skeleton_link_colors[sk_id])
                cv2.line(img, pos1, pos2, color, thickness=self.line_width)

            for kid, kpt in enumerate(kpts):
                if score[kid] < self.kpt_thr:
                    # skip the point that should not be drawn
                    continue

                # Convert to integer pixel coordinates once and reuse them.
                center = (int(kpt[0]), int(kpt[1]))

                color = self._to_color(self.keypoint_colors[kid])
                # Filled disc in the keypoint color ...
                cv2.circle(img, center, self.radius, color, -1)
                # ... with an unfilled white ring around it.
                cv2.circle(img, center, self.radius, (255, 255, 255))
|