# Copyright (c) OpenMMLab. All rights reserved.
"""
BMP Demo script: sequentially runs detection, pose estimation, SAM-based mask refinement, and visualization.
Usage:
python bmp_demo.py <config.yaml> <input_image> [--output-root <dir>]
"""
import os
import shutil
from argparse import ArgumentParser, Namespace
from pathlib import Path

import mmcv
import mmengine
import numpy as np
import yaml
from demo_utils import DotDict, concat_instances, create_GIF, filter_instances, pose_nms, visualize_itteration
from mm_utils import run_MMDetector, run_MMPose
from mmdet.apis import init_detector
from mmengine.logging import print_log
from mmengine.structures import InstanceData
from sam2_utils import prepare_model as prepare_sam2_model
from sam2_utils import process_image_with_SAM

from mmpose.apis import init_model as init_pose_estimator
from mmpose.utils import adapt_mmdet_pipeline

# Default thresholds
DEFAULT_DET_CAT_ID: int = 0  # "person"
DEFAULT_BBOX_THR: float = 0.3
DEFAULT_NMS_THR: float = 0.3
DEFAULT_KPT_THR: float = 0.3


def parse_args() -> Namespace:
"""
Parse command-line arguments for BMP demo.
Returns:
        Namespace: Contains bmp_config (Path), input (Path), output_root (Path), device (str), create_gif (bool).
"""
parser = ArgumentParser(description="BBoxMaskPose demo")
parser.add_argument("bmp_config", type=Path, help="Path to BMP YAML config file")
parser.add_argument("input", type=Path, help="Input image file")
parser.add_argument("--output-root", type=Path, default=None, help="Directory to save outputs (default: ./outputs)")
parser.add_argument("--device", type=str, default="cuda:0", help="Device for inference (e.g., cuda:0 or cpu)")
parser.add_argument("--create-gif", action="store_true", default=False, help="Create GIF of all BMP iterations")
args = parser.parse_args()
if args.output_root is None:
        args.output_root = Path(__file__).parent / "outputs"
    return args


def parse_yaml_config(yaml_path: Path) -> DotDict:
"""
Load BMP configuration from a YAML file.
Args:
yaml_path (Path): Path to YAML config.
Returns:
DotDict: Nested config dictionary.
"""
with open(yaml_path, "r") as f:
cfg = yaml.safe_load(f)
    return DotDict(cfg)


def process_one_image(
args: Namespace,
bmp_config: DotDict,
img_path: Path,
detector: object,
detector_prime: object,
pose_estimator: object,
sam2_model: object,
) -> InstanceData:
"""
Run the full BMP pipeline on a single image: detection, pose, SAM mask refinement, and visualization.
Args:
args (Namespace): Parsed CLI arguments.
bmp_config (DotDict): Configuration parameters.
img_path (Path): Path to the input image.
detector: Primary MMDetection model.
detector_prime: Secondary MMDetection model for iterations.
pose_estimator: MMPose model for keypoint estimation.
sam2_model: SAM model for mask refinement.
Returns:
InstanceData: Final merged detections and refined masks.
"""
# Load image
img = mmcv.imread(str(img_path), channel_order="bgr")
if img is None:
raise ValueError("Failed to read image from {}.".format(img_path))
# Prepare output directory
output_dir = os.path.join(args.output_root, img_path.stem)
shutil.rmtree(str(output_dir), ignore_errors=True)
mmengine.mkdir_or_exist(str(output_dir))
    # Each iteration re-runs the detector on the previous iteration's visualization output,
    # while pose estimation and SAM always receive the original image.
    img_for_detection = img.copy()
    all_detections = None
for iteration in range(bmp_config.num_bmp_iters):
print_log("BMP Iteration {}/{} started".format(iteration + 1, bmp_config.num_bmp_iters), logger="current")
# Step 1: Detection
det_instances = run_MMDetector(
detector if iteration == 0 else detector_prime,
img_for_detection,
det_cat_id=DEFAULT_DET_CAT_ID,
bbox_thr=DEFAULT_BBOX_THR,
nms_thr=DEFAULT_NMS_THR,
)
print_log("Detected {} instances".format(len(det_instances.bboxes)), logger="current")
if len(det_instances.bboxes) == 0:
print_log("No detections found, skipping.", logger="current")
continue
# Step 2: Pose estimation
pose_instances = run_MMPose(
pose_estimator,
img.copy(),
detections=det_instances,
kpt_thr=DEFAULT_KPT_THR,
)
        # Keep only the first 17 COCO keypoints and append per-keypoint scores as (x, y, score)
pose_instances.keypoints = pose_instances.keypoints[:, :17, :]
pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17]
pose_instances.keypoints = np.concatenate(
[pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1
)
# Step 3: Pose-NMS and SAM refinement
all_keypoints = (
pose_instances.keypoints
if all_detections is None
else np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0)
)
all_bboxes = (
pose_instances.bboxes
if all_detections is None
else np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)
)
num_valid_kpts = np.sum(all_keypoints[:, :, 2] > bmp_config.sam2.prompting.confidence_thr, axis=1)
keep_indices = pose_nms(
DotDict({"confidence_thr": bmp_config.sam2.prompting.confidence_thr, "oks_thr": bmp_config.oks_nms_thr}),
image_kpts=all_keypoints,
image_bboxes=all_bboxes,
num_valid_kpts=num_valid_kpts,
)
keep_indices = sorted(keep_indices) # Sort by original index
        # Split surviving indices into previously refined ("old") instances and newly detected ones
        num_old_detections = 0 if all_detections is None else len(all_detections.bboxes)
keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
keep_old_indices = [i for i in keep_indices if i < num_old_detections]
if len(keep_new_indices) == 0:
print_log("No new instances passed pose NMS, skipping SAM refinement.", logger="current")
continue
        # Filter new detections and compute their instance scores as the mean keypoint score
new_dets = filter_instances(pose_instances, keep_new_indices)
new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)
old_dets = None
if len(keep_old_indices) > 0:
old_dets = filter_instances(all_detections, keep_old_indices)
        print_log(
            "Pose NMS kept {:d} instances ({:d} old + {:d} new)".format(
                len(new_dets.bboxes) + num_old_detections, num_old_detections, len(new_dets.bboxes)
            ),
            logger="current",
        )
new_detections = process_image_with_SAM(
DotDict(bmp_config.sam2.prompting),
img.copy(),
sam2_model,
new_dets,
            old_dets,
)
# Merge detections
if all_detections is None:
all_detections = new_detections
else:
all_detections = concat_instances(all_detections, new_dets)
# Step 4: Visualization
img_for_detection = visualize_itteration(
img.copy(),
all_detections,
iteration_idx=iteration,
output_root=str(output_dir),
img_name=img_path.stem,
)
print_log("Iteration {} completed".format(iteration + 1), logger="current")
# Create GIF of iterations if requested
if args.create_gif:
image_file = os.path.join(output_dir, "{:s}.jpg".format(img_path.stem))
create_GIF(
img_path=str(image_file),
output_root=str(output_dir),
bmp_x=bmp_config.num_bmp_iters,
)
    return all_detections


def main() -> None:
"""
Entry point for the BMP demo: loads models and processes one image.
"""
args = parse_args()
bmp_config = parse_yaml_config(args.bmp_config)
# Ensure output root exists
mmengine.mkdir_or_exist(str(args.output_root))
    # Build detectors: D for the first iteration, D' for subsequent iterations
detector = init_detector(bmp_config.detector.det_config, bmp_config.detector.det_checkpoint, device=args.device)
detector.cfg = adapt_mmdet_pipeline(detector.cfg)
if (
bmp_config.detector.det_config == bmp_config.detector.det_prime_config
and bmp_config.detector.det_checkpoint == bmp_config.detector.det_prime_checkpoint
) or (bmp_config.detector.det_prime_config is None or bmp_config.detector.det_prime_checkpoint is None):
print_log("Using the same detector as D and D'", logger="current")
detector_prime = detector
else:
detector_prime = init_detector(
bmp_config.detector.det_prime_config, bmp_config.detector.det_prime_checkpoint, device=args.device
)
detector_prime.cfg = adapt_mmdet_pipeline(detector_prime.cfg)
print_log("Using a different detector for D'", logger="current")
    # Build pose estimator
pose_estimator = init_pose_estimator(
bmp_config.pose_estimator.pose_config,
bmp_config.pose_estimator.pose_checkpoint,
device=args.device,
cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))),
)
    # Build SAM 2 model for mask refinement
    sam2 = prepare_sam2_model(
model_cfg=bmp_config.sam2.sam2_config,
model_checkpoint=bmp_config.sam2.sam2_checkpoint,
)
# Run inference on one image
    _ = process_one_image(args, bmp_config, args.input, detector, detector_prime, pose_estimator, sam2)


if __name__ == "__main__":
main()