import gradio as gr
import spaces
from pathlib import Path

import numpy as np
import yaml

from demo.demo_utils import DotDict, concat_instances, filter_instances, pose_nms, visualize_demo
from demo.mm_utils import run_MMDetector, run_MMPose
from demo.sam2_utils import prepare_model as prepare_sam2_model
from demo.sam2_utils import process_image_with_SAM
from mmdet.apis import init_detector
from mmpose.apis import init_model as init_pose_estimator
from mmpose.utils import adapt_mmdet_pipeline

# Default thresholds
DEFAULT_CAT_ID: int = 0
DEFAULT_BBOX_THR: float = 0.3
DEFAULT_NMS_THR: float = 0.3
DEFAULT_KPT_THR: float = 0.3

# Global model handles, populated by load_models()
det_model = None
pose_model = None
sam2_model = None


def _parse_yaml_config(yaml_path: Path) -> DotDict:
    """
    Load BMP configuration from a YAML file.

    Args:
        yaml_path (Path): Path to YAML config.

    Returns:
        DotDict: Nested config dictionary.
    """
    with open(yaml_path, "r") as f:
        cfg = yaml.safe_load(f)
    return DotDict(cfg)


def load_models(bmp_config):
    """Build the detector, pose estimator, and SAM 2 models."""
    device = "cuda:0"
    global det_model, pose_model, sam2_model

    # Build the detector. It runs on CPU because of installation issues on HF.
    det_model = init_detector(
        bmp_config.detector.det_config,
        bmp_config.detector.det_checkpoint,
        device="cpu",
    )
    det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)

    # Build the pose estimator.
    pose_model = init_pose_estimator(
        bmp_config.pose_estimator.pose_config,
        bmp_config.pose_estimator.pose_checkpoint,
        device=device,
        cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))),
    )

    # Build SAM 2 for mask refinement.
    sam2_model = prepare_sam2_model(
        model_cfg=bmp_config.sam2.sam2_config,
        model_checkpoint=bmp_config.sam2.sam2_checkpoint,
    )

    return det_model, pose_model, sam2_model


@spaces.GPU(duration=60)
def process_image_with_BMP(img: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Run the full BMP pipeline on a single image: detection, pose estimation,
    SAM mask refinement, and visualization.

    Args:
        img (np.ndarray): Input image as an RGB array of shape (H, W, 3).

    Returns:
        tuple[np.ndarray, np.ndarray]: RGB visualizations of the
        first-iteration result (RTMDet-L + MaskPose-B) and of the final
        BBoxMaskPose result after all iterations.
    """
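    # NOTE: the config and all three models are (re)loaded on every call.
    # This keeps each Gradio request self-contained at the cost of extra
    # start-up time per image.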
""" bmp_config = _parse_yaml_config(Path("configs/bmp_D3.yaml")) load_models(bmp_config) # img: RGB -> BGR img = img[..., ::-1] img_for_detection = img.copy() rtmdet_result = None all_detections = None for iteration in range(bmp_config.num_bmp_iters): # Step 1: Detection det_instances = run_MMDetector( det_model, img_for_detection, det_cat_id=DEFAULT_CAT_ID, bbox_thr=DEFAULT_BBOX_THR, nms_thr=DEFAULT_NMS_THR, ) if len(det_instances.bboxes) == 0: continue # Step 2: Pose estimation pose_instances = run_MMPose( pose_model, img.copy(), detections=det_instances, kpt_thr=DEFAULT_KPT_THR, ) # Restrict to first 17 COCO keypoints pose_instances.keypoints = pose_instances.keypoints[:, :17, :] pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17] pose_instances.keypoints = np.concatenate( [pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1 ) # Step 3: Pose-NMS and SAM refinement all_keypoints = ( pose_instances.keypoints if all_detections is None else np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0) ) all_bboxes = ( pose_instances.bboxes if all_detections is None else np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0) ) num_valid_kpts = np.sum(all_keypoints[:, :, 2] > bmp_config.sam2.prompting.confidence_thr, axis=1) keep_indices = pose_nms( DotDict({"confidence_thr": bmp_config.sam2.prompting.confidence_thr, "oks_thr": bmp_config.oks_nms_thr}), image_kpts=all_keypoints, image_bboxes=all_bboxes, num_valid_kpts=num_valid_kpts, ) keep_indices = sorted(keep_indices) # Sort by original index num_old_detections = 0 if all_detections is None else len(all_detections.bboxes) keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections] keep_old_indices = [i for i in keep_indices if i < num_old_detections] if len(keep_new_indices) == 0: continue # filter new detections and compute scores new_dets = filter_instances(pose_instances, keep_new_indices) new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1) old_dets = None if len(keep_old_indices) > 0: old_dets = filter_instances(all_detections, keep_old_indices) new_detections = process_image_with_SAM( DotDict(bmp_config.sam2.prompting), img.copy(), sam2_model, new_dets, old_dets if old_dets is not None else None, ) # Merge detections if all_detections is None: all_detections = new_detections else: all_detections = concat_instances(all_detections, new_dets) # Step 4: Visualization img_for_detection, rtmdet_r, _ = visualize_demo( img.copy(), all_detections, ) if iteration == 0: rtmdet_result = rtmdet_r _, _, bmp_result = visualize_demo( img.copy(), all_detections, ) # img: BGR -> RGB rtmdet_result = rtmdet_result[..., ::-1] bmp_result = bmp_result[..., ::-1] return rtmdet_result, bmp_result with gr.Blocks() as app: gr.Markdown("# BBoxMaskPose Image Demo") gr.Markdown("### [M. Purkrabek](https://mirapurkrabek.github.io/), [J. Matas](https://cmp.felk.cvut.cz/~matas/)") gr.Markdown( "Official demo for paper **Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle.** [ICCV 2025]" ) gr.Markdown( "For details, see the [project website](https://mirapurkrabek.github.io/BBox-Mask-Pose/) or [arXiv paper](https://arxiv.org/abs/2412.01562). " "The demo showcases the capabilities of the BBoxMaskPose framework on any image. " "If you want to play around with parameters, use the [GitHub demo](https://github.com/MiraPurkrabek/BBoxMaskPose). 
" "Please note that due to HuggingFace restrictions, the demo runs much slower than the GitHub implementation." ) gr.Markdown( "If you find the project interesting, please like ❤️ the HF demo and star ⭐ the GH repo to help us spread the word." ) with gr.Row(): with gr.Column(): original_image_input = gr.Image(type="numpy", label="Original Image") submit_button = gr.Button("Run Inference") with gr.Column(): output_standard = gr.Image(type="numpy", label="RTMDet-L + MaskPose-B") with gr.Column(): output_sahi_sliced = gr.Image(type="numpy", label="BBoxMaskPose 2x") gr.Examples( label="In-the-Wild Examples", examples=[ ["examples/prochazka_MMA.jpg"], ["examples/riner_judo.jpg"], ["examples/tackle3.jpg"], ["examples/tackle1.jpg"], ["examples/tackle2.jpg"], ["examples/tackle5.jpg"], ["examples/SKV_example1.jpg"], ["examples/SKV_example2.jpg"], ["examples/SKV_example3.jpg"], ["examples/SKV_example4.jpg"], ], inputs=[ original_image_input, ], outputs=[output_standard, output_sahi_sliced], fn=process_image_with_BMP, cache_examples=True, ) gr.Examples( label="OCHuman Examples", examples=[ ["examples/004806.jpg"], ["examples/005056.jpg"], ["examples/004981.jpg"], ["examples/004655.jpg"], ["examples/004684.jpg"], ["examples/004974.jpg"], ["examples/004983.jpg"], ["examples/005017.jpg"], ["examples/004849.jpg"], ["examples/000105.jpg"], ], inputs=[ original_image_input, ], outputs=[output_standard, output_sahi_sliced], fn=process_image_with_BMP, cache_examples=True, ) gr.Examples( label="Failure Cases", examples=[ ["examples/SKV_example_F1.jpg"], ["examples/tackle4.jpg"], ["examples/000061.jpg"], ["examples/000141.jpg"], ["examples/000287.jpg"], ], inputs=[ original_image_input, ], outputs=[output_standard, output_sahi_sliced], fn=process_image_with_BMP, cache_examples=True, ) submit_button.click( fn=process_image_with_BMP, inputs=[ original_image_input, ], outputs=[output_standard, output_sahi_sliced], ) # Launch the demo app.launch()