|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
|
import cv2 |
|
import numpy as np |
|
from rtmlib import Wholebody |
|
|
|
from cosmos_transfer1.diffusion.datasets.augmentors.human_keypoint_utils import ( |
|
coco_wholebody_133_skeleton, |
|
openpose134_skeleton, |
|
) |
|
from cosmos_transfer1.utils import log |
|
|
|
|
|
class HumanKeypointModel:
    """Render human whole-body keypoint skeleton videos from an input video.

    Wraps the rtmlib ``Wholebody`` pose estimator and draws the keypoints and
    links of every detected person on a black canvas, frame by frame.
    """

    def __init__(self, to_openpose: bool = True, conf_thres: float = 0.6) -> None:
        """Create the pose estimator.

        Args:
            to_openpose: Forwarded to ``Wholebody``. Also selects the skeleton
                topology used for drawing: ``openpose134_skeleton`` when true,
                ``coco_wholebody_133_skeleton`` otherwise.
            conf_thres: Minimum keypoint score for a keypoint (and any link
                touching it) to be drawn.
        """
        self.model = Wholebody(
            to_openpose=to_openpose,
            mode="performance",
            backend="onnxruntime",
            device="cuda",
        )
        self.to_openpose = to_openpose
        self.conf_thres = conf_thres

    def __call__(self, input_video: str, output_video: str = "keypoint.mp4") -> str:
        """
        Generate the human body keypoint plot for the keypointControlNet video2world model.

        Input: mp4 video
        Output: mp4 keypoint video, of the same spatial and temporal dimensions as the input video.

        One output frame is written for every frame decoded from the input,
        using the input's frame size.

        Args:
            input_video: Path of the video to analyse.
            output_video: Path the keypoint video is written to.

        Returns:
            ``output_video``, the path of the written keypoint video.

        Raises:
            AssertionError: If ``input_video`` does not exist.
        """
        log.info(f"Processing video: {input_video} to generate keypoint video: {output_video}")
        assert os.path.exists(input_video), f"Input video not found: {input_video}"

        cap = cv2.VideoCapture(input_video)
        skeleton_writer = None
        try:
            # NOTE(review): fractional frame rates (e.g. 29.97) are truncated
            # to an integer here; confirm whether downstream relies on that.
            fps = int(cap.get(cv2.CAP_PROP_FPS))
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            frame_size = (frame_width, frame_height)

            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            skeleton_writer = cv2.VideoWriter(output_video, fourcc, fps, frame_size)

            log.info(f"frame width: {frame_width}, frame height: {frame_height}, fps: {fps}")
            log.info("start pose estimation for frames..")

            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                skeleton_frame = self._render_skeleton_frame(frame)
                # Reverse the channel axis before writing; presumably the
                # skeleton colours are RGB while the writer expects BGR.
                skeleton_writer.write(skeleton_frame[:, :, ::-1])
        finally:
            # Release the capture/writer even when pose estimation or drawing
            # raises, so file handles are not leaked and the output is closed.
            cap.release()
            if skeleton_writer is not None:
                skeleton_writer.release()

        return output_video

    def _render_skeleton_frame(self, frame: np.ndarray) -> np.ndarray:
        """Return a black canvas shaped like ``frame`` with all detected skeletons drawn."""
        canvas = np.zeros_like(frame)

        keypoints, scores = self.model(frame)
        if keypoints is None or len(keypoints) == 0:
            return canvas

        return self.plot_person_kpts(
            canvas,
            keypoints,
            scores,
            kpt_thr=self.conf_thres,
            # Draw with the topology matching the format the model was
            # configured to emit, instead of always assuming OpenPose.
            openpose_format=self.to_openpose,
            line_width=4,
        )

    def draw_skeleton(
        self,
        img: np.ndarray,
        keypoints: np.ndarray,
        scores: np.ndarray,
        kpt_thr: float = 0.6,
        openpose_format: bool = True,
        radius: int = 2,
        line_width: int = 4,
    ) -> np.ndarray:
        """Draw the keypoints and links of one person onto ``img``.

        Args:
            img: Image to draw on; it is passed to the OpenCV drawing calls.
            keypoints: 2-D array of one person's keypoints; columns 0 and 1
                are used as the x and y pixel coordinates.
            scores: Per-keypoint scores, indexed like ``keypoints``.
            kpt_thr: Keypoints scoring below this value are not drawn, nor is
                any link that touches them.
            openpose_format: Use ``openpose134_skeleton`` when true, otherwise
                ``coco_wholebody_133_skeleton``.
            radius: Radius of the filled keypoint circles.
            line_width: Thickness of the link lines.

        Returns:
            The image with the skeleton drawn.
        """
        skeleton_topology = openpose134_skeleton if openpose_format else coco_wholebody_133_skeleton
        assert len(keypoints.shape) == 2
        keypoint_info = skeleton_topology["keypoint_info"]
        skeleton_info = skeleton_topology["skeleton_info"]

        # Boolean visibility mask, one entry per keypoint.
        vis_kpt = np.asarray(scores) >= kpt_thr

        # Maps keypoint name -> keypoint id so links (given by name) can be resolved.
        link_dict = {}
        for i, kpt_info in keypoint_info.items():
            kpt_color = tuple(kpt_info["color"])
            link_dict[kpt_info["name"]] = kpt_info["id"]

            kpt = keypoints[i]
            if vis_kpt[i]:
                img = cv2.circle(img, (int(kpt[0]), int(kpt[1])), int(radius), kpt_color, -1)

        for ske_info in skeleton_info.values():
            link = ske_info["link"]
            pt0, pt1 = link_dict[link[0]], link_dict[link[1]]

            # Only connect two keypoints when both ends are confident enough.
            if vis_kpt[pt0] and vis_kpt[pt1]:
                link_color = ske_info["color"]
                kpt0 = keypoints[pt0]
                kpt1 = keypoints[pt1]

                img = cv2.line(
                    img,
                    (int(kpt0[0]), int(kpt0[1])),
                    (int(kpt1[0]), int(kpt1[1])),
                    link_color,
                    thickness=line_width,
                )

        return img

    def plot_person_kpts(
        self,
        pose_vis_img: np.ndarray,
        keypoints: np.ndarray,
        scores: np.ndarray,
        kpt_thr: float = 0.6,
        openpose_format: bool = True,
        line_width: int = 4,
    ) -> np.ndarray:
        """Draw every detected person onto the pose image.

        Iterates over ``keypoints`` and ``scores`` in lockstep (one entry per
        person) and draws each skeleton with ``draw_skeleton``. A person whose
        drawing raises ``ValueError`` is logged and skipped so the remaining
        people are still rendered.

        Args:
            pose_vis_img: Image the skeletons are drawn on.
            keypoints: Per-person keypoint arrays.
            scores: Per-person keypoint score arrays.
            kpt_thr: Minimum score for a keypoint to be drawn.
            openpose_format: Skeleton topology selector passed to ``draw_skeleton``.
            line_width: Thickness of the link lines.

        Returns:
            The pose image with all skeletons drawn; returned unchanged when
            there are no people.
        """
        for kpts, ss in zip(keypoints, scores):
            try:
                pose_vis_img = self.draw_skeleton(
                    pose_vis_img,
                    kpts,
                    ss,
                    kpt_thr=kpt_thr,
                    openpose_format=openpose_format,
                    line_width=line_width,
                )
            except ValueError as e:
                log.error(f"Error in draw_skeleton func, {e}")

        return pose_vis_img
|
|