File size: 5,314 Bytes
226c7c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import cv2
import numpy as np
from rtmlib import Wholebody

from cosmos_transfer1.diffusion.datasets.augmentors.human_keypoint_utils import (
    coco_wholebody_133_skeleton,
    openpose134_skeleton,
)
from cosmos_transfer1.utils import log


class HumanKeypointModel:
    """Render whole-body human keypoint skeletons for every frame of a video.

    Wraps an ``rtmlib`` ``Wholebody`` pose estimator and draws the detected keypoints
    and limbs onto a black canvas, producing a keypoint video with the same resolution
    and frame rate as the input.
    """

    def __init__(
        self,
        to_openpose=True,
        conf_thres=0.6,
        mode="performance",
        backend="onnxruntime",
        device="cuda",
    ):
        """
        Args:
            to_openpose: Forwarded to ``Wholebody``; also selects the skeleton topology
                used for drawing (OpenPose 134-keypoint if True, COCO-WholeBody 133 otherwise).
            conf_thres: Minimum keypoint score for a keypoint (and the links touching it)
                to be drawn.
            mode: ``Wholebody`` model preset.
            backend: ``Wholebody`` inference backend.
            device: Device the ``Wholebody`` estimator runs on.
        """
        self.model = Wholebody(
            to_openpose=to_openpose,
            mode=mode,
            backend=backend,
            device=device,
        )
        self.to_openpose = to_openpose
        self.conf_thres = conf_thres

    def __call__(self, input_video: str, output_video: str = "keypoint.mp4") -> str:
        """
        Generate the human body keypoint plot for the keypointControlNet video2world model.

        Each input frame is run through the pose estimator, and the detected skeletons are
        drawn on a black frame of the same size. The frames are written to ``output_video``
        at the input frame rate, so the output has the same spatial and temporal dimensions
        as the input.

        Args:
            input_video: Path to the input mp4 video.
            output_video: Path of the mp4 keypoint video to write.

        Returns:
            ``output_video``, the path of the written keypoint video.

        Raises:
            FileNotFoundError: If ``input_video`` does not exist.
            RuntimeError: If the input cannot be opened for reading or the output cannot
                be opened for writing.
        """
        log.info(f"Processing video: {input_video} to generate keypoint video: {output_video}")
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not os.path.exists(input_video):
            raise FileNotFoundError(f"Input video not found: {input_video}")

        cap = cv2.VideoCapture(input_video)
        skeleton_writer = None
        try:
            if not cap.isOpened():
                raise RuntimeError(f"Failed to open input video: {input_video}")

            # Keep the frame rate as a float: truncating e.g. 29.97 to 29 would make the
            # output's duration differ from the input's.
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            frame_size = (frame_width, frame_height)

            # vid writer
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            skeleton_writer = cv2.VideoWriter(output_video, fourcc, fps, frame_size)
            if not skeleton_writer.isOpened():
                raise RuntimeError(f"Failed to open output video for writing: {output_video}")

            log.info(f"frame width: {frame_width}, frame height: {frame_height}, fps: {fps}")
            log.info("start pose estimation for frames..")

            # Process each frame
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                skeleton_frame = self._render_frame(frame)
                # Reverse the channel order before writing; presumably the skeleton colors
                # are defined as RGB while OpenCV writes BGR — TODO confirm.
                skeleton_writer.write(skeleton_frame[:, :, ::-1])
        finally:
            # Release both handles even if pose estimation or drawing raises.
            cap.release()
            if skeleton_writer is not None:
                skeleton_writer.release()

        return output_video

    def _render_frame(self, frame: np.ndarray) -> np.ndarray:
        """Run pose estimation on one frame and draw every detected person on a black canvas."""
        # Create a black background frame
        black_frame = np.zeros_like(frame)

        # Run pose estimation
        keypoints, scores = self.model(frame)

        if keypoints is None or len(keypoints) == 0:
            return black_frame

        return self.plot_person_kpts(
            black_frame,
            keypoints,
            scores,
            kpt_thr=self.conf_thres,
            # The drawing topology must match the keypoint layout requested from Wholebody;
            # drawing 133 COCO keypoints with the 134-keypoint OpenPose topology overruns the array.
            openpose_format=self.to_openpose,
            line_width=4,
        )  # (h, w, 3)

    def draw_skeleton(
        self,
        img: np.ndarray,
        keypoints: np.ndarray,
        scores: np.ndarray,
        kpt_thr: float = 0.6,
        openpose_format: bool = True,
        radius: int = 2,
        line_width: int = 4,
    ) -> np.ndarray:
        """Draw one person's keypoints and limbs onto ``img``.

        Args:
            img: Image to draw on; updated by OpenCV and returned.
            keypoints: One person's keypoint coordinates; must be 2-D, presumably
                (num_keypoints, 2) with pixel x, y — TODO confirm against rtmlib output.
            scores: Per-keypoint confidence scores, indexed like ``keypoints``.
            kpt_thr: Keypoints scoring below this are not drawn, nor are links touching them.
            openpose_format: Use the OpenPose 134-keypoint topology if True, otherwise
                the COCO-WholeBody 133-keypoint topology.
            radius: Radius of the keypoint circles.
            line_width: Thickness of the limb lines.

        Returns:
            The image with the skeleton drawn.

        Raises:
            ValueError: If ``keypoints`` is not 2-D.
        """
        skeleton_topology = openpose134_skeleton if openpose_format else coco_wholebody_133_skeleton
        if keypoints.ndim != 2:
            raise ValueError(f"Expected 2-D keypoints for a single person, got shape {keypoints.shape}")
        keypoint_info = skeleton_topology["keypoint_info"]
        skeleton_info = skeleton_topology["skeleton_info"]

        vis_kpt = [s >= kpt_thr for s in scores]
        # Maps keypoint name -> keypoint id so links (given as name pairs) can be resolved.
        link_dict = {}
        for i, kpt_info in keypoint_info.items():
            kpt_color = tuple(kpt_info["color"])
            link_dict[kpt_info["name"]] = kpt_info["id"]

            kpt = keypoints[i]

            if vis_kpt[i]:
                img = cv2.circle(img, (int(kpt[0]), int(kpt[1])), int(radius), kpt_color, -1)

        for ske_info in skeleton_info.values():
            link = ske_info["link"]
            pt0, pt1 = link_dict[link[0]], link_dict[link[1]]

            # Draw a limb only when both of its endpoints are confident enough.
            if vis_kpt[pt0] and vis_kpt[pt1]:
                link_color = tuple(ske_info["color"])
                kpt0 = keypoints[pt0]
                kpt1 = keypoints[pt1]

                img = cv2.line(
                    img, (int(kpt0[0]), int(kpt0[1])), (int(kpt1[0]), int(kpt1[1])), link_color, thickness=line_width
                )

        return img

    def plot_person_kpts(
        self,
        pose_vis_img: np.ndarray,
        keypoints: np.ndarray,
        scores: np.ndarray,
        kpt_thr: float = 0.6,
        openpose_format: bool = True,
        line_width: int = 4,
    ) -> np.ndarray:
        """
        Draw the skeletons of all detected people onto ``pose_vis_img``.

        Args:
            pose_vis_img: Image to draw on; updated in place and returned.
            keypoints: Per-person keypoint arrays, iterated in lockstep with ``scores``.
            scores: Per-person keypoint score arrays.
            kpt_thr: Score threshold forwarded to ``draw_skeleton``.
            openpose_format: Topology selector forwarded to ``draw_skeleton``.
            line_width: Limb line thickness forwarded to ``draw_skeleton``.

        Returns:
            The image with every person's skeleton drawn. A person whose drawing raises
            ``ValueError`` is logged and skipped, so one bad detection does not abort the frame.
        """
        for kpts, ss in zip(keypoints, scores):
            try:
                pose_vis_img = self.draw_skeleton(
                    pose_vis_img, kpts, ss, kpt_thr=kpt_thr, openpose_format=openpose_format, line_width=line_width
                )
            except ValueError as e:
                log.error(f"Error in draw_skeleton func, {e}")

        return pose_vis_img