File size: 5,314 Bytes
226c7c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
from rtmlib import Wholebody
from cosmos_transfer1.diffusion.datasets.augmentors.human_keypoint_utils import (
coco_wholebody_133_skeleton,
openpose134_skeleton,
)
from cosmos_transfer1.utils import log
class HumanKeypointModel:
    """Run rtmlib whole-body pose estimation on a video and render the skeletons.

    Args:
        to_openpose: Forwarded to ``rtmlib.Wholebody``. It also selects the skeleton
            topology used for drawing: ``openpose134_skeleton`` when True,
            ``coco_wholebody_133_skeleton`` when False.
        conf_thres: Minimum keypoint score for a keypoint (and any link touching it)
            to be drawn.
    """

    def __init__(self, to_openpose=True, conf_thres=0.6):
        # The pose model configuration is fixed here: "performance" mode on the
        # ONNX Runtime backend, running on CUDA.
        self.model = Wholebody(
            to_openpose=to_openpose,
            mode="performance",
            backend="onnxruntime",
            device="cuda",
        )
        self.to_openpose = to_openpose
        self.conf_thres = conf_thres

    def __call__(self, input_video: str, output_video: str = "keypoint.mp4") -> str:
        """
        Generate the human body keypoint plot for the keypointControlNet video2world model.

        Each frame of ``input_video`` is passed through the pose model. The detected
        skeletons are drawn on a black frame of the same size. The result is written
        to ``output_video`` with the ``mp4v`` codec, at the input's frame size and
        frame rate.

        Args:
            input_video: Path to a video readable by ``cv2.VideoCapture``.
            output_video: Path of the skeleton video to write.

        Returns:
            ``output_video``, the path of the written keypoint video.

        Raises:
            FileNotFoundError: If ``input_video`` does not exist.
            RuntimeError: If the input cannot be opened for reading, or the output
                cannot be opened for writing.
        """
        log.info(f"Processing video: {input_video} to generate keypoint video: {output_video}")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(input_video):
            raise FileNotFoundError(f"Input video not found: {input_video}")

        cap = cv2.VideoCapture(input_video)
        skeleton_writer = None
        try:
            if not cap.isOpened():
                raise RuntimeError(f"Could not open input video: {input_video}")

            # Keep the frame rate as a float. Truncating it (e.g. 29.97 -> 29) would
            # change the output's duration and break temporal alignment with the input.
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            frame_size = (frame_width, frame_height)

            # vid writer
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            skeleton_writer = cv2.VideoWriter(output_video, fourcc, fps, frame_size)
            if not skeleton_writer.isOpened():
                raise RuntimeError(f"Could not open output video for writing: {output_video}")

            log.info(f"frame width: {frame_width}, frame height: {frame_height}, fps: {fps}")
            log.info("start pose estimation for frames..")

            # Process each frame until the capture runs out of frames.
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # Skeletons are drawn on a black frame with the input frame's shape and dtype.
                black_frame = np.zeros_like(frame)
                # Run pose estimation
                keypoints, scores = self.model(frame)
                if keypoints is not None and len(keypoints) > 0:
                    skeleton_frame = self.plot_person_kpts(
                        black_frame,
                        keypoints,
                        scores,
                        kpt_thr=self.conf_thres,
                        # Draw with the topology that matches the model's output format.
                        openpose_format=self.to_openpose,
                        line_width=4,
                    )  # (h, w, 3)
                else:
                    skeleton_frame = black_frame
                # Channels are reversed before writing. Presumably the topology colors
                # are RGB while VideoWriter expects BGR (TODO confirm).
                skeleton_writer.write(skeleton_frame[:, :, ::-1])
        finally:
            # Release both handles even if pose estimation or writing raised.
            cap.release()
            if skeleton_writer is not None:
                skeleton_writer.release()
        return output_video

    def draw_skeleton(
        self,
        img: np.ndarray,
        keypoints: np.ndarray,
        scores: np.ndarray,
        kpt_thr: float = 0.6,
        openpose_format: bool = True,
        radius: int = 2,
        line_width: int = 4,
    ) -> np.ndarray:
        """Draw one person's keypoints and skeleton links onto ``img``.

        A keypoint is drawn as a filled circle when its score is at least ``kpt_thr``.
        A link is drawn as a line only when both of its endpoint keypoints pass the
        threshold.

        Args:
            img: Image to draw on; ``cv2.circle``/``cv2.line`` draw into it.
            keypoints: 2-D array indexed by keypoint id. Only columns 0 and 1 are
                read, as x/y pixel positions (presumably ``(num_keypoints, 2)``;
                confirm against the model output).
            scores: Per-keypoint confidence scores, indexed like ``keypoints``.
            kpt_thr: Minimum score for a keypoint to be considered visible.
            openpose_format: Use ``openpose134_skeleton`` if True, else
                ``coco_wholebody_133_skeleton``.
            radius: Keypoint circle radius in pixels.
            line_width: Link line thickness in pixels.

        Returns:
            The image with the skeleton drawn.
        """
        skeleton_topology = openpose134_skeleton if openpose_format else coco_wholebody_133_skeleton
        # Sanity check on the model output: one row per keypoint.
        assert len(keypoints.shape) == 2
        keypoint_info, skeleton_info = (
            skeleton_topology["keypoint_info"],
            skeleton_topology["skeleton_info"],
        )
        vis_kpt = [s >= kpt_thr for s in scores]
        # Maps keypoint name -> keypoint id, used to resolve link endpoints below.
        link_dict = {}
        for i, kpt_info in keypoint_info.items():
            kpt_color = tuple(kpt_info["color"])
            link_dict[kpt_info["name"]] = kpt_info["id"]
            kpt = keypoints[i]
            if vis_kpt[i]:
                img = cv2.circle(img, (int(kpt[0]), int(kpt[1])), int(radius), kpt_color, -1)
        for ske_info in skeleton_info.values():
            link = ske_info["link"]
            pt0, pt1 = link_dict[link[0]], link_dict[link[1]]
            if vis_kpt[pt0] and vis_kpt[pt1]:
                link_color = tuple(ske_info["color"])
                kpt0 = keypoints[pt0]
                kpt1 = keypoints[pt1]
                img = cv2.line(
                    img, (int(kpt0[0]), int(kpt0[1])), (int(kpt1[0]), int(kpt1[1])), link_color, thickness=line_width
                )
        return img

    def plot_person_kpts(
        self,
        pose_vis_img: np.ndarray,
        keypoints: np.ndarray,
        scores: np.ndarray,
        kpt_thr: float = 0.6,
        openpose_format: bool = True,
        line_width: int = 4,
    ) -> np.ndarray:
        """Draw the skeleton of every detected person onto ``pose_vis_img``.

        ``keypoints`` and ``scores`` are iterated in lockstep, one entry per person.
        Each person is drawn with ``draw_skeleton``. A ``ValueError`` raised for a
        single person is logged and that person is skipped.

        Returns:
            The image with all skeletons drawn; OpenCV draws into ``pose_vis_img``.
        """
        for kpts, ss in zip(keypoints, scores):
            try:
                pose_vis_img = self.draw_skeleton(
                    pose_vis_img, kpts, ss, kpt_thr=kpt_thr, openpose_format=openpose_format, line_width=line_width
                )
            except ValueError as e:
                log.error(f"Error in draw_skeleton func, {e}")
        return pose_vis_img
|