import spaces  # imported before the CUDA-touching libraries so ZeroGPU can patch them

from pathlib import Path

import gradio as gr
import numpy as np
import yaml
from mmdet.apis import init_detector
from mmpose.apis import init_model as init_pose_estimator
from mmpose.utils import adapt_mmdet_pipeline

from demo.demo_utils import DotDict, concat_instances, filter_instances, pose_nms, visualize_demo
from demo.mm_utils import run_MMDetector, run_MMPose
from demo.sam2_utils import prepare_model as prepare_sam2_model
from demo.sam2_utils import process_image_with_SAM

# Default detection category (0 = 'person' in COCO) and score thresholds
DEFAULT_CAT_ID: int = 0

DEFAULT_BBOX_THR: float = 0.3
DEFAULT_NMS_THR: float = 0.3
DEFAULT_KPT_THR: float = 0.3

# Global model handles, initialized in load_models()
det_model = None
pose_model = None
sam2_model = None

def _parse_yaml_config(yaml_path: Path) -> DotDict:
    """
    Load BMP configuration from a YAML file.

    Args:
        yaml_path (Path): Path to YAML config.
    Returns:
        DotDict: Nested config dictionary.
    """
    with open(yaml_path, "r") as f:
        cfg = yaml.safe_load(f)
    return DotDict(cfg)

def load_models(bmp_config):
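    """
    Initialize the detector, pose estimator, and SAM2 model from `bmp_config`
    and store them in the module-level globals.
    """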
    device = 'cuda:0'

    global det_model, pose_model, sam2_model

    # build detector (kept on CPU because of installation issues on HF)
    det_model = init_detector(
        bmp_config.detector.det_config,
        bmp_config.detector.det_checkpoint,
        device='cpu',
    )
    det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)

    # build pose estimator
    pose_model = init_pose_estimator(
        bmp_config.pose_estimator.pose_config,
        bmp_config.pose_estimator.pose_checkpoint,
        device=device,
        cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))),
    )

    sam2_model = prepare_sam2_model(
        model_cfg=bmp_config.sam2.sam2_config,
        model_checkpoint=bmp_config.sam2.sam2_checkpoint,
    )

    return det_model, pose_model, sam2_model

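# `spaces.GPU` requests a ZeroGPU slot on Hugging Face Spaces for each call of
# the decorated function; `duration` caps the allocation at 60 seconds.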
@spaces.GPU(duration=60)
def process_image_with_BMP(
    img: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """
    Run the full BMP pipeline on a single image: detection, pose, SAM mask refinement, and visualization.

    Args:
        args (Namespace): Parsed CLI arguments.
        bmp_config (DotDict): Configuration parameters.
        img_path (Path): Path to the input image.
        detector: Primary MMDetection model.
        detector_prime: Secondary MMDetection model for iterations.
        pose_estimator: MMPose model for keypoint estimation.
        sam2_model: SAM model for mask refinement.
    Returns:
        InstanceData: Final merged detections and refined masks.
    """
    bmp_config = _parse_yaml_config(Path("configs/bmp_D3.yaml"))
    load_models(bmp_config)

    # Gradio supplies RGB; the MMCV-based models expect BGR
    img = img[..., ::-1]

    img_for_detection = img.copy()
    rtmdet_result = None
    all_detections = None
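    # BBoxMaskPose loop: each round detects bodies, estimates poses, suppresses
    # duplicates with pose-NMS, refines masks with SAM2, and feeds the updated
    # image back to the detector.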
    for iteration in range(bmp_config.num_bmp_iters):

        # Step 1: Detection
        det_instances = run_MMDetector(
            det_model,
            img_for_detection,
            det_cat_id=DEFAULT_CAT_ID,
            bbox_thr=DEFAULT_BBOX_THR,
            nms_thr=DEFAULT_NMS_THR,
        )
        if len(det_instances.bboxes) == 0:
            continue

        # Step 2: Pose estimation
        pose_instances = run_MMPose(
            pose_model,
            img.copy(),
            detections=det_instances,
            kpt_thr=DEFAULT_KPT_THR,
        )

        # Restrict to first 17 COCO keypoints
        pose_instances.keypoints = pose_instances.keypoints[:, :17, :]
        pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17]
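        # append scores as a third channel -> (N, 17, 3) arrays of (x, y, score)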
        pose_instances.keypoints = np.concatenate(
            [pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1
        )

        # Step 3: Pose-NMS and SAM refinement
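        # pool keypoints/bboxes from previous rounds with the new ones so that
        # OKS-based pose-NMS can suppress duplicates across rounds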
        all_keypoints = (
            pose_instances.keypoints
            if all_detections is None
            else np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0)
        )
        all_bboxes = (
            pose_instances.bboxes
            if all_detections is None
            else np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)
        )
        num_valid_kpts = np.sum(all_keypoints[:, :, 2] > bmp_config.sam2.prompting.confidence_thr, axis=1)
        keep_indices = pose_nms(
            DotDict({"confidence_thr": bmp_config.sam2.prompting.confidence_thr, "oks_thr": bmp_config.oks_nms_thr}),
            image_kpts=all_keypoints,
            image_bboxes=all_bboxes,
            num_valid_kpts=num_valid_kpts,
        )
        keep_indices = sorted(keep_indices)  # Sort by original index
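        # split kept indices into previously accepted ('old') and newly found
        # ('new') instances; only the new ones are refined with SAM2 below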
        num_old_detections = 0 if all_detections is None else len(all_detections.bboxes)
        keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
        keep_old_indices = [i for i in keep_indices if i < num_old_detections]
        if len(keep_new_indices) == 0:
            continue
        # filter new detections and compute scores
        new_dets = filter_instances(pose_instances, keep_new_indices)
        new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)
        old_dets = None
        if len(keep_old_indices) > 0:
            old_dets = filter_instances(all_detections, keep_old_indices)

        new_detections = process_image_with_SAM(
            DotDict(bmp_config.sam2.prompting),
            img.copy(),
            sam2_model,
            new_dets,
            old_dets,
        )

        # Merge the SAM-refined new detections with those kept from earlier rounds
        if all_detections is None:
            all_detections = new_detections
        else:
            all_detections = concat_instances(all_detections, new_detections)

        # Step 4: Visualization; also produces the image fed to the detector in the next round
        img_for_detection, rtmdet_r, _ = visualize_demo(
            img.copy(),
            all_detections,
        )

        if rtmdet_result is None:
            # keep the first successful round as the detector-only baseline
            rtmdet_result = rtmdet_r

    if all_detections is None:
        # no people were detected in any round; return the input image unchanged
        rgb_img = img[..., ::-1]
        return rgb_img, rgb_img

    _, _, bmp_result = visualize_demo(
        img.copy(),
        all_detections,
    )

    # img: BGR -> RGB
    rtmdet_result = rtmdet_result[..., ::-1]
    bmp_result = bmp_result[..., ::-1]

    return rtmdet_result, bmp_result


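# Gradio UI: one input image, two outputs (first-pass detector + pose result
# and the final BBoxMaskPose result), plus cached example galleries.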
with gr.Blocks() as app:
    gr.Markdown("# BBoxMaskPose Image Demo")
    gr.Markdown("### [M. Purkrabek](https://mirapurkrabek.github.io/), [J. Matas](https://cmp.felk.cvut.cz/~matas/)")
    gr.Markdown(
        "Official demo for paper **Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle.** [ICCV 2025]"
    )
    gr.Markdown(
        "For details, see the [project website](https://mirapurkrabek.github.io/BBox-Mask-Pose/) or [arXiv paper](https://arxiv.org/abs/2412.01562). "
        "The demo showcases the capabilities of the BBoxMaskPose framework on any image. "
        "If you want to play around with parameters, use the [GitHub demo](https://github.com/MiraPurkrabek/BBoxMaskPose). "
        "Please note that due to HuggingFace restrictions, the demo runs much slower than the GitHub implementation."
    )
    gr.Markdown(
        "If you find the project interesting, please like ❤️ the HF demo and star ⭐ the GH repo to help us spread the word."
    )

    with gr.Row():
        with gr.Column():
            original_image_input = gr.Image(type="numpy", label="Original Image")
            submit_button = gr.Button("Run Inference")

        with gr.Column():
            output_standard = gr.Image(type="numpy", label="RTMDet-L + MaskPose-B")
        
        with gr.Column():
            output_bmp = gr.Image(type="numpy", label="BBoxMaskPose 2x")

    
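    # Example galleries; outputs are precomputed and cached by Gradio.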
    gr.Examples(
        label="In-the-Wild Examples",
        examples=[
            ["examples/prochazka_MMA.jpg"],
            ["examples/riner_judo.jpg"],
            ["examples/tackle3.jpg"],
            ["examples/tackle1.jpg"],
            ["examples/tackle2.jpg"],
            ["examples/tackle5.jpg"],
            ["examples/SKV_example1.jpg"],
            ["examples/SKV_example2.jpg"],
            ["examples/SKV_example3.jpg"],
            ["examples/SKV_example4.jpg"],
        ],
        inputs=[
            original_image_input,
        ],
        outputs=[output_standard, output_bmp],
        fn=process_image_with_BMP,
        cache_examples=True,
    )
    gr.Examples(
        label="OCHuman Examples",
        examples=[
            ["examples/004806.jpg"],
            ["examples/005056.jpg"],
            ["examples/004981.jpg"],
            ["examples/004655.jpg"],
            ["examples/004684.jpg"],
            ["examples/004974.jpg"],
            ["examples/004983.jpg"],
            ["examples/005017.jpg"],
            ["examples/004849.jpg"],
            ["examples/000105.jpg"],
        ],
        inputs=[
            original_image_input,
        ],
        outputs=[output_standard, output_bmp],
        fn=process_image_with_BMP,
        cache_examples=True,
    )
    gr.Examples(
        label="Failure Cases",
        examples=[
            ["examples/SKV_example_F1.jpg"],
            ["examples/tackle4.jpg"],
            ["examples/000061.jpg"],
            ["examples/000141.jpg"],
            ["examples/000287.jpg"],
        ],
        inputs=[
            original_image_input,
        ],
        outputs=[output_standard, output_bmp],
        fn=process_image_with_BMP,
        cache_examples=True,
    )

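    # Run the full pipeline when the user clicks the button.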
    submit_button.click(
        fn=process_image_with_BMP,
        inputs=[
            original_image_input,
        ],
        outputs=[output_standard, output_bmp],
    )

# Launch the demo
app.launch()