# Hugging Face Spaces page banner ("Running on Zero") captured together with this file; not part of the program.
import gradio as gr
import spaces
from pathlib import Path
import numpy as np
import yaml
from demo.demo_utils import DotDict, concat_instances, filter_instances, pose_nms, visualize_demo
from demo.mm_utils import run_MMDetector, run_MMPose
from mmdet.apis import init_detector
from demo.sam2_utils import prepare_model as prepare_sam2_model
from demo.sam2_utils import process_image_with_SAM
from mmpose.apis import init_model as init_pose_estimator
from mmpose.utils import adapt_mmdet_pipeline
# Default thresholds
DEFAULT_CAT_ID: int = 0  # passed to run_MMDetector as det_cat_id
DEFAULT_BBOX_THR: float = 0.3  # passed to run_MMDetector as bbox_thr
DEFAULT_NMS_THR: float = 0.3  # passed to run_MMDetector as nms_thr
DEFAULT_KPT_THR: float = 0.3  # passed to run_MMPose as kpt_thr

# Global models variable
# These stay None until load_models() assigns them.
det_model = None
pose_model = None
sam2_model = None
def _parse_yaml_config(yaml_path: Path) -> DotDict:
    """
    Load BMP configuration from a YAML file.

    Args:
        yaml_path (Path): Path to YAML config.

    Returns:
        DotDict: Nested config dictionary.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    with open(yaml_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    return DotDict(cfg)
def load_models(bmp_config):
    """
    Build the detector, pose estimator and SAM2 model and store them in the module globals.

    Args:
        bmp_config: Configuration object exposing ``detector.det_config``,
            ``detector.det_checkpoint``, ``pose_estimator.pose_config``,
            ``pose_estimator.pose_checkpoint``, ``sam2.sam2_config`` and
            ``sam2.sam2_checkpoint``.

    Returns:
        tuple: ``(det_model, pose_model, sam2_model)``, the same objects that
        are assigned to the module-level globals of those names.
    """
    global det_model, pose_model, sam2_model
    device = 'cuda:0'

    # build detectors
    det_model = init_detector(
        bmp_config.detector.det_config,
        bmp_config.detector.det_checkpoint,
        device='cpu',  # Detect with CPU because of installation issues on HF
    )
    det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)

    # build pose estimator
    pose_model = init_pose_estimator(
        bmp_config.pose_estimator.pose_config,
        bmp_config.pose_estimator.pose_checkpoint,
        device=device,
        cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))),
    )

    # build SAM2 model (no device is passed here; placement is decided by the helper)
    sam2_model = prepare_sam2_model(
        model_cfg=bmp_config.sam2.sam2_config,
        model_checkpoint=bmp_config.sam2.sam2_checkpoint,
    )
    return det_model, pose_model, sam2_model
def process_image_with_BMP(
    img: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """
    Run the full BMP pipeline on a single image: detection, pose, SAM mask refinement, and visualization.

    The configuration is read from ``configs/bmp_D3.yaml`` and the models are
    loaded through ``load_models`` on every call.

    Args:
        img (np.ndarray): Input image with channels in RGB order; it is
            converted to BGR internally before being passed to the models.

    Returns:
        tuple[np.ndarray, np.ndarray]: Two RGB visualizations. The first is the
        one produced at the end of the first BMP iteration, the second is the
        one produced from all detections accumulated over every iteration.
        When nothing is detected, both are unannotated copies of the input.

    Raises:
        ValueError: If ``img`` is None (no image was supplied).
    """
    # Fail fast, before the expensive config/model loading below.
    if img is None:
        raise ValueError("No input image provided.")

    bmp_config = _parse_yaml_config(Path("configs/bmp_D3.yaml"))
    load_models(bmp_config)

    rgb_input = img  # kept for the "nothing detected" fallback

    # img: RGB -> BGR
    img = img[..., ::-1]
    img_for_detection = img.copy()
    rtmdet_result = None
    all_detections = None

    for iteration in range(bmp_config.num_bmp_iters):
        # Step 1: Detection
        det_instances = run_MMDetector(
            det_model,
            img_for_detection,
            det_cat_id=DEFAULT_CAT_ID,
            bbox_thr=DEFAULT_BBOX_THR,
            nms_thr=DEFAULT_NMS_THR,
        )
        if len(det_instances.bboxes) == 0:
            continue

        # Step 2: Pose estimation
        pose_instances = run_MMPose(
            pose_model,
            img.copy(),
            detections=det_instances,
            kpt_thr=DEFAULT_KPT_THR,
        )

        # Restrict to first 17 COCO keypoints
        pose_instances.keypoints = pose_instances.keypoints[:, :17, :]
        pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17]
        # Append the score as an extra last channel of every keypoint.
        pose_instances.keypoints = np.concatenate(
            [pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1
        )

        # Step 3: Pose-NMS and SAM refinement
        # Previously accepted detections come first, so indices below
        # num_old_detections refer to old instances.
        all_keypoints = (
            pose_instances.keypoints
            if all_detections is None
            else np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0)
        )
        all_bboxes = (
            pose_instances.bboxes
            if all_detections is None
            else np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)
        )
        num_valid_kpts = np.sum(all_keypoints[:, :, 2] > bmp_config.sam2.prompting.confidence_thr, axis=1)
        keep_indices = pose_nms(
            DotDict({"confidence_thr": bmp_config.sam2.prompting.confidence_thr, "oks_thr": bmp_config.oks_nms_thr}),
            image_kpts=all_keypoints,
            image_bboxes=all_bboxes,
            num_valid_kpts=num_valid_kpts,
        )
        keep_indices = sorted(keep_indices)  # Sort by original index

        num_old_detections = 0 if all_detections is None else len(all_detections.bboxes)
        keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
        keep_old_indices = [i for i in keep_indices if i < num_old_detections]
        if len(keep_new_indices) == 0:
            continue

        # filter new detections and compute scores
        new_dets = filter_instances(pose_instances, keep_new_indices)
        new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)

        old_dets = None
        if len(keep_old_indices) > 0:
            old_dets = filter_instances(all_detections, keep_old_indices)

        new_detections = process_image_with_SAM(
            DotDict(bmp_config.sam2.prompting),
            img.copy(),
            sam2_model,
            new_dets,
            old_dets,
        )

        # Merge detections
        if all_detections is None:
            all_detections = new_detections
        else:
            # NOTE(review): this merges `new_dets`, not `new_detections`. That is
            # only equivalent if process_image_with_SAM updates `new_dets` in
            # place and returns it - confirm against demo.sam2_utils.
            all_detections = concat_instances(all_detections, new_dets)

        # Step 4: Visualization
        # The first returned image becomes the detector input of the next iteration.
        img_for_detection, rtmdet_r, _ = visualize_demo(
            img.copy(),
            all_detections,
        )
        if iteration == 0:
            rtmdet_result = rtmdet_r

    if all_detections is None:
        # Nothing was detected in any iteration: there is nothing to draw, so
        # hand back the untouched input instead of indexing into None.
        return rgb_input.copy(), rgb_input.copy()

    _, _, bmp_result = visualize_demo(
        img.copy(),
        all_detections,
    )

    if rtmdet_result is None:
        # The first iteration produced no visualization; fall back to the
        # plain (BGR) image, which is converted back to RGB below.
        rtmdet_result = img.copy()

    # img: BGR -> RGB
    rtmdet_result = rtmdet_result[..., ::-1]
    bmp_result = bmp_result[..., ::-1]
    return rtmdet_result, bmp_result
# Example galleries shown under the demo: (gallery label, image paths).
_EXAMPLE_GROUPS = (
    (
        "In-the-Wild Examples",
        (
            "examples/prochazka_MMA.jpg",
            "examples/riner_judo.jpg",
            "examples/tackle3.jpg",
            "examples/tackle1.jpg",
            "examples/tackle2.jpg",
            "examples/tackle5.jpg",
            "examples/SKV_example1.jpg",
            "examples/SKV_example2.jpg",
            "examples/SKV_example3.jpg",
            "examples/SKV_example4.jpg",
        ),
    ),
    (
        "OCHuman Examples",
        (
            "examples/004806.jpg",
            "examples/005056.jpg",
            "examples/004981.jpg",
            "examples/004655.jpg",
            "examples/004684.jpg",
            "examples/004974.jpg",
            "examples/004983.jpg",
            "examples/005017.jpg",
            "examples/004849.jpg",
            "examples/000105.jpg",
        ),
    ),
    (
        "Failure Cases",
        (
            "examples/SKV_example_F1.jpg",
            "examples/tackle4.jpg",
            "examples/000061.jpg",
            "examples/000141.jpg",
            "examples/000287.jpg",
        ),
    ),
)

# Build the Gradio UI: header text, one input column, two output columns,
# the example galleries and the click handler.
# NOTE(review): the original indentation was lost; the galleries and the click
# handler are placed at Blocks level, below the row - confirm against the
# deployed layout.
with gr.Blocks() as app:
    gr.Markdown("# BBoxMaskPose Image Demo")
    gr.Markdown("### [M. Purkrabek](https://mirapurkrabek.github.io/), [J. Matas](https://cmp.felk.cvut.cz/~matas/)")
    gr.Markdown(
        "Official demo for paper **Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle.** [ICCV 2025]"
    )
    gr.Markdown(
        "For details, see the [project website](https://mirapurkrabek.github.io/BBox-Mask-Pose/) or [arXiv paper](https://arxiv.org/abs/2412.01562). "
        "The demo showcases the capabilities of the BBoxMaskPose framework on any image. "
        "If you want to play around with parameters, use the [GitHub demo](https://github.com/MiraPurkrabek/BBoxMaskPose). "
        "Please note that due to HuggingFace restrictions, the demo runs much slower than the GitHub implementation."
    )
    gr.Markdown(
        "If you find the project interesting, please like ❤️ the HF demo and star ⭐ the GH repo to help us spread the word."
    )

    with gr.Row():
        with gr.Column():
            original_image_input = gr.Image(type="numpy", label="Original Image")
            submit_button = gr.Button("Run Inference")
        with gr.Column():
            output_standard = gr.Image(type="numpy", label="RTMDet-L + MaskPose-B")
        with gr.Column():
            output_sahi_sliced = gr.Image(type="numpy", label="BBoxMaskPose 2x")

    # One gallery per group, created in the same order as the table above.
    for gallery_label, image_paths in _EXAMPLE_GROUPS:
        gr.Examples(
            label=gallery_label,
            examples=[[image_path] for image_path in image_paths],
            inputs=[
                original_image_input,
            ],
            outputs=[output_standard, output_sahi_sliced],
            fn=process_image_with_BMP,
            cache_examples=True,
        )

    submit_button.click(
        fn=process_image_with_BMP,
        inputs=[
            original_image_input,
        ],
        outputs=[output_standard, output_sahi_sliced],
    )
# Launch the demo
# Guarded so that importing this module does not start the server; running the
# file as a script behaves as before.
if __name__ == "__main__":
    app.launch()