import gradio as gr
import spaces
from pathlib import Path
import numpy as np
import yaml
from demo.demo_utils import DotDict, concat_instances, filter_instances, pose_nms, visualize_demo
from demo.mm_utils import run_MMDetector, run_MMPose
from mmdet.apis import init_detector
from demo.sam2_utils import prepare_model as prepare_sam2_model
from demo.sam2_utils import process_image_with_SAM
from mmpose.apis import init_model as init_pose_estimator
from mmpose.utils import adapt_mmdet_pipeline
# Default thresholds
DEFAULT_CAT_ID: int = 0
DEFAULT_BBOX_THR: float = 0.3
DEFAULT_NMS_THR: float = 0.3
DEFAULT_KPT_THR: float = 0.3
# Global models variable
det_model = None
pose_model = None
sam2_model = None
def _parse_yaml_config(yaml_path: Path) -> DotDict:
"""
    Load BMP configuration from a YAML file.

    Args:
        yaml_path (Path): Path to YAML config.

    Returns:
        DotDict: Nested config dictionary.
"""
with open(yaml_path, "r") as f:
cfg = yaml.safe_load(f)
return DotDict(cfg)
def load_models(bmp_config):
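    """
    Build the detector, pose estimator, and SAM 2 model from the BMP config
    and store them in the module-level globals.
    """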
device = 'cuda:0'
global det_model, pose_model, sam2_model
# build detectors
    det_model = init_detector(
        bmp_config.detector.det_config, bmp_config.detector.det_checkpoint, device='cpu'
    )  # run the detector on CPU because of installation issues on HF
det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)
# build pose estimator
pose_model = init_pose_estimator(
bmp_config.pose_estimator.pose_config,
bmp_config.pose_estimator.pose_checkpoint,
device=device,
cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))),
)
sam2_model = prepare_sam2_model(
model_cfg=bmp_config.sam2.sam2_config,
model_checkpoint=bmp_config.sam2.sam2_checkpoint,
)
return det_model, pose_model, sam2_model
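
# On HF ZeroGPU Spaces, a GPU is attached only while a @spaces.GPU-decorated
# function runs; duration=60 requests up to 60 seconds of GPU time per call.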
@spaces.GPU(duration=60)
def process_image_with_BMP(
img: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
"""
Run the full BMP pipeline on a single image: detection, pose, SAM mask refinement, and visualization.
Args:
args (Namespace): Parsed CLI arguments.
bmp_config (DotDict): Configuration parameters.
img_path (Path): Path to the input image.
detector: Primary MMDetection model.
detector_prime: Secondary MMDetection model for iterations.
pose_estimator: MMPose model for keypoint estimation.
sam2_model: SAM model for mask refinement.
Returns:
InstanceData: Final merged detections and refined masks.
"""
bmp_config = _parse_yaml_config(Path("configs/bmp_D3.yaml"))
load_models(bmp_config)
# img: RGB -> BGR
img = img[..., ::-1]
img_for_detection = img.copy()
rtmdet_result = None
all_detections = None
for iteration in range(bmp_config.num_bmp_iters):
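        # Each iteration re-runs the detector on an image in which the
        # instances found so far are masked out (the first output of
        # visualize_demo in Step 4), so previously occluded people can surface.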
# Step 1: Detection
det_instances = run_MMDetector(
det_model,
img_for_detection,
det_cat_id=DEFAULT_CAT_ID,
bbox_thr=DEFAULT_BBOX_THR,
nms_thr=DEFAULT_NMS_THR,
)
if len(det_instances.bboxes) == 0:
continue
# Step 2: Pose estimation
pose_instances = run_MMPose(
pose_model,
img.copy(),
detections=det_instances,
kpt_thr=DEFAULT_KPT_THR,
)
# Restrict to first 17 COCO keypoints
pose_instances.keypoints = pose_instances.keypoints[:, :17, :]
pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17]
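        # Append scores so each keypoint becomes an (x, y, score) triplet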
pose_instances.keypoints = np.concatenate(
[pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1
)
# Step 3: Pose-NMS and SAM refinement
all_keypoints = (
pose_instances.keypoints
if all_detections is None
else np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0)
)
all_bboxes = (
pose_instances.bboxes
if all_detections is None
else np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)
)
num_valid_kpts = np.sum(all_keypoints[:, :, 2] > bmp_config.sam2.prompting.confidence_thr, axis=1)
keep_indices = pose_nms(
DotDict({"confidence_thr": bmp_config.sam2.prompting.confidence_thr, "oks_thr": bmp_config.oks_nms_thr}),
image_kpts=all_keypoints,
image_bboxes=all_bboxes,
num_valid_kpts=num_valid_kpts,
)
keep_indices = sorted(keep_indices) # Sort by original index
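        # Partition the surviving indices into detections kept from previous
        # iterations ("old") and detections added in this iteration ("new")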
num_old_detections = 0 if all_detections is None else len(all_detections.bboxes)
keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
keep_old_indices = [i for i in keep_indices if i < num_old_detections]
if len(keep_new_indices) == 0:
continue
# filter new detections and compute scores
new_dets = filter_instances(pose_instances, keep_new_indices)
new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)
old_dets = None
if len(keep_old_indices) > 0:
old_dets = filter_instances(all_detections, keep_old_indices)
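        # Refine the new detections with SAM 2, prompting with their keypoints;
        # previously accepted detections are passed along as well.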
new_detections = process_image_with_SAM(
DotDict(bmp_config.sam2.prompting),
img.copy(),
sam2_model,
new_dets,
            old_dets,
)
        # Merge the SAM-refined detections into the running set
        if all_detections is None:
            all_detections = new_detections
        else:
            all_detections = concat_instances(all_detections, new_detections)
# Step 4: Visualization
img_for_detection, rtmdet_r, _ = visualize_demo(
img.copy(),
all_detections,
)
if iteration == 0:
rtmdet_result = rtmdet_r
_, _, bmp_result = visualize_demo(
img.copy(),
all_detections,
)
    # Fallback: nothing was detected in any iteration
    if all_detections is None:
        img_rgb = img[..., ::-1]
        return img_rgb, img_rgb

    # img: BGR -> RGB
    rtmdet_result = rtmdet_result[..., ::-1]
    bmp_result = bmp_result[..., ::-1]
    return rtmdet_result, bmp_result
with gr.Blocks() as app:
gr.Markdown("# BBoxMaskPose Image Demo")
gr.Markdown("### [M. Purkrabek](https://mirapurkrabek.github.io/), [J. Matas](https://cmp.felk.cvut.cz/~matas/)")
gr.Markdown(
"Official demo for paper **Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle.** [ICCV 2025]"
)
gr.Markdown(
"For details, see the [project website](https://mirapurkrabek.github.io/BBox-Mask-Pose/) or [arXiv paper](https://arxiv.org/abs/2412.01562). "
"The demo showcases the capabilities of the BBoxMaskPose framework on any image. "
"If you want to play around with parameters, use the [GitHub demo](https://github.com/MiraPurkrabek/BBoxMaskPose). "
"Please note that due to HuggingFace restrictions, the demo runs much slower than the GitHub implementation."
)
gr.Markdown(
"If you find the project interesting, please like ❤️ the HF demo and star ⭐ the GH repo to help us spread the word."
)
with gr.Row():
with gr.Column():
original_image_input = gr.Image(type="numpy", label="Original Image")
submit_button = gr.Button("Run Inference")
with gr.Column():
output_standard = gr.Image(type="numpy", label="RTMDet-L + MaskPose-B")
with gr.Column():
            output_bmp = gr.Image(type="numpy", label="BBoxMaskPose 2x")
gr.Examples(
label="In-the-Wild Examples",
examples=[
["examples/prochazka_MMA.jpg"],
["examples/riner_judo.jpg"],
["examples/tackle3.jpg"],
["examples/tackle1.jpg"],
["examples/tackle2.jpg"],
["examples/tackle5.jpg"],
["examples/SKV_example1.jpg"],
["examples/SKV_example2.jpg"],
["examples/SKV_example3.jpg"],
["examples/SKV_example4.jpg"],
],
inputs=[
original_image_input,
],
        outputs=[output_standard, output_bmp],
fn=process_image_with_BMP,
cache_examples=True,
)
gr.Examples(
label="OCHuman Examples",
examples=[
["examples/004806.jpg"],
["examples/005056.jpg"],
["examples/004981.jpg"],
["examples/004655.jpg"],
["examples/004684.jpg"],
["examples/004974.jpg"],
["examples/004983.jpg"],
["examples/005017.jpg"],
["examples/004849.jpg"],
["examples/000105.jpg"],
],
inputs=[
original_image_input,
],
        outputs=[output_standard, output_bmp],
fn=process_image_with_BMP,
cache_examples=True,
)
gr.Examples(
label="Failure Cases",
examples=[
["examples/SKV_example_F1.jpg"],
["examples/tackle4.jpg"],
["examples/000061.jpg"],
["examples/000141.jpg"],
["examples/000287.jpg"],
],
inputs=[
original_image_input,
],
        outputs=[output_standard, output_bmp],
fn=process_image_with_BMP,
cache_examples=True,
)
submit_button.click(
fn=process_image_with_BMP,
inputs=[
original_image_input,
],
        outputs=[output_standard, output_bmp],
)
# Launch the demo
app.launch()