# Copyright (c) OpenMMLab. All rights reserved.
"""
BMP Demo script: sequentially runs detection, pose estimation, SAM-based mask refinement, and visualization.

Usage:
    python bmp_demo.py <bmp_config.yaml> <input_image> [--output-root <dir>] [--device <device>] [--create-gif]
"""

import os
import shutil
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Optional

import mmcv
import mmengine
import numpy as np
import yaml
from demo_utils import DotDict, concat_instances, create_GIF, filter_instances, pose_nms, visualize_itteration
from mm_utils import run_MMDetector, run_MMPose
from mmdet.apis import init_detector
from mmengine.logging import print_log
from mmengine.structures import InstanceData
from sam2_utils import prepare_model as prepare_sam2_model
from sam2_utils import process_image_with_SAM

from mmpose.apis import init_model as init_pose_estimator
from mmpose.utils import adapt_mmdet_pipeline

# Default thresholds
DEFAULT_DET_CAT_ID: int = 0  # "person"
DEFAULT_BBOX_THR: float = 0.3
DEFAULT_NMS_THR: float = 0.3
DEFAULT_KPT_THR: float = 0.3


def parse_args() -> Namespace:
    """
    Parse command-line arguments for BMP demo.

    Returns:
        Namespace: Contains bmp_config (Path), input (Path), output_root (Path),
        device (str) and create_gif (bool).
    """
    parser = ArgumentParser(description="BBoxMaskPose demo")
    parser.add_argument("bmp_config", type=Path, help="Path to BMP YAML config file")
    parser.add_argument("input", type=Path, help="Input image file")
    parser.add_argument(
        "--output-root", type=Path, default=None, help="Directory to save outputs (default: ./outputs)"
    )
    parser.add_argument("--device", type=str, default="cuda:0", help="Device for inference (e.g., cuda:0 or cpu)")
    parser.add_argument(
        "--create-gif", action="store_true", default=False, help="Create GIF of all BMP iterations"
    )
    args = parser.parse_args()
    if args.output_root is None:
        # Keep the default the same type as a user-supplied value (argparse converts those to Path).
        args.output_root = Path(__file__).parent / "outputs"
    return args


def parse_yaml_config(yaml_path: Path) -> DotDict:
    """
    Load BMP configuration from a YAML file.

    Args:
        yaml_path (Path): Path to YAML config.

    Returns:
        DotDict: Nested config dictionary.

    Raises:
        ValueError: If the file does not contain a YAML mapping at its top level
            (e.g. the file is empty).
    """
    with open(yaml_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # An empty file loads as None, and a list/scalar top level has no named keys;
    # the pipeline reads the config attribute-style (bmp_config.num_bmp_iters, ...),
    # so fail here with a clear message rather than deep inside the pipeline.
    if not isinstance(cfg, dict):
        raise ValueError(
            "BMP config {} must contain a YAML mapping, got {}.".format(yaml_path, type(cfg).__name__)
        )
    return DotDict(cfg)


def process_one_image(
    args: Namespace,
    bmp_config: DotDict,
    img_path: Path,
    detector: object,
    detector_prime: object,
    pose_estimator: object,
    sam2_model: object,
) -> Optional[InstanceData]:
    """
    Run the full BMP pipeline on a single image: detection, pose, SAM mask refinement, and visualization.

    Each of the ``bmp_config.num_bmp_iters`` iterations detects people, estimates their poses,
    runs pose NMS over the instances kept so far plus the new ones, refines the surviving new
    instances with SAM, and appends them to the accumulated result.

    Args:
        args (Namespace): Parsed CLI arguments (uses output_root and create_gif).
        bmp_config (DotDict): Configuration parameters.
        img_path (Path): Path to the input image.
        detector: Primary MMDetection model (used in the first iteration).
        detector_prime: Secondary MMDetection model for subsequent iterations.
        pose_estimator: MMPose model for keypoint estimation.
        sam2_model: SAM model for mask refinement.

    Returns:
        Optional[InstanceData]: Final merged detections and refined masks, or None if no
        iteration kept any instance.

    Raises:
        ValueError: If the image cannot be read.
    """
    # Load image
    img = mmcv.imread(str(img_path), channel_order="bgr")
    if img is None:
        raise ValueError("Failed to read image from {}.".format(img_path))

    # Prepare a clean per-image output directory (previous results for this image are removed)
    output_dir = os.path.join(args.output_root, img_path.stem)
    shutil.rmtree(str(output_dir), ignore_errors=True)
    mmengine.mkdir_or_exist(str(output_dir))

    # The first iteration detects on the raw image; later iterations detect on the image
    # returned by the previous iteration's visualization step.
    img_for_detection = img.copy()
    all_detections = None

    for iteration in range(bmp_config.num_bmp_iters):
        print_log("BMP Iteration {}/{} started".format(iteration + 1, bmp_config.num_bmp_iters), logger="current")

        # Step 1: Detection (D in the first iteration, D' afterwards)
        det_instances = run_MMDetector(
            detector if iteration == 0 else detector_prime,
            img_for_detection,
            det_cat_id=DEFAULT_DET_CAT_ID,
            bbox_thr=DEFAULT_BBOX_THR,
            nms_thr=DEFAULT_NMS_THR,
        )
        print_log("Detected {} instances".format(len(det_instances.bboxes)), logger="current")

        if len(det_instances.bboxes) == 0:
            print_log("No detections found, skipping.", logger="current")
            continue

        # Step 2: Pose estimation (always on the unmodified image)
        pose_instances = run_MMPose(
            pose_estimator,
            img.copy(),
            detections=det_instances,
            kpt_thr=DEFAULT_KPT_THR,
        )

        # Restrict to first 17 COCO keypoints
        pose_instances.keypoints = pose_instances.keypoints[:, :17, :]
        pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17]
        # Append each keypoint's score as an extra last channel so that
        # keypoints[..., 2] is the score used for thresholding below.
        pose_instances.keypoints = np.concatenate(
            [pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1
        )

        # Step 3: Pose-NMS and SAM refinement.
        # Candidates are the instances kept in earlier iterations (first) followed by the new ones.
        if all_detections is None:
            all_keypoints = pose_instances.keypoints
            all_bboxes = pose_instances.bboxes
        else:
            all_keypoints = np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0)
            all_bboxes = np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)

        confidence_thr = bmp_config.sam2.prompting.confidence_thr
        num_valid_kpts = np.sum(all_keypoints[:, :, 2] > confidence_thr, axis=1)
        keep_indices = pose_nms(
            DotDict({"confidence_thr": confidence_thr, "oks_thr": bmp_config.oks_nms_thr}),
            image_kpts=all_keypoints,
            image_bboxes=all_bboxes,
            num_valid_kpts=num_valid_kpts,
        )
        keep_indices = sorted(keep_indices)  # Sort by original index

        # Indices below num_old_detections refer to previously kept instances; the rest are
        # shifted back so they index into this iteration's pose_instances.
        num_old_detections = 0 if all_detections is None else len(all_detections.bboxes)
        keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
        keep_old_indices = [i for i in keep_indices if i < num_old_detections]
        if len(keep_new_indices) == 0:
            print_log("No new instances passed pose NMS, skipping SAM refinement.", logger="current")
            continue

        # filter new detections and compute scores (mean of the 17 keypoint scores)
        new_dets = filter_instances(pose_instances, keep_new_indices)
        new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)

        old_dets = None
        if len(keep_old_indices) > 0:
            old_dets = filter_instances(all_detections, keep_old_indices)
        print_log(
            "Pose NMS reduced instances to {:d} ({:d}+{:d}) instances".format(
                len(new_dets.bboxes) + num_old_detections, num_old_detections, len(new_dets.bboxes)
            ),
            logger="current",
        )

        new_detections = process_image_with_SAM(
            DotDict(bmp_config.sam2.prompting),
            img.copy(),
            sam2_model,
            new_dets,
            old_dets,
        )

        # Merge detections.
        # Both branches must use the SAM output: concatenating the pre-SAM `new_dets` here
        # would discard whatever process_image_with_SAM produced for every iteration after
        # the first, and mix processed and unprocessed instances in all_detections.
        if all_detections is None:
            all_detections = new_detections
        else:
            all_detections = concat_instances(all_detections, new_detections)

        # Step 4: Visualization. The returned image is the detection input for the next iteration.
        img_for_detection = visualize_itteration(
            img.copy(),
            all_detections,
            iteration_idx=iteration,
            output_root=str(output_dir),
            img_name=img_path.stem,
        )
        print_log("Iteration {} completed".format(iteration + 1), logger="current")

    # Create GIF of iterations if requested
    if args.create_gif:
        image_file = os.path.join(output_dir, "{:s}.jpg".format(img_path.stem))
        create_GIF(
            img_path=str(image_file),
            output_root=str(output_dir),
            bmp_x=bmp_config.num_bmp_iters,
        )

    return all_detections


def main() -> None:
    """
    Entry point for the BMP demo: loads models and processes one image.
    """
    args = parse_args()
    bmp_config = parse_yaml_config(args.bmp_config)

    # Ensure output root exists
    mmengine.mkdir_or_exist(str(args.output_root))

    # build detectors
    det_cfg = bmp_config.detector
    detector = init_detector(det_cfg.det_config, det_cfg.det_checkpoint, device=args.device)
    detector.cfg = adapt_mmdet_pipeline(detector.cfg)

    # Reuse D as D' when D' is not configured or is configured identically to D,
    # to avoid loading the same weights twice.
    same_as_primary = (
        det_cfg.det_config == det_cfg.det_prime_config and det_cfg.det_checkpoint == det_cfg.det_prime_checkpoint
    )
    prime_missing = det_cfg.det_prime_config is None or det_cfg.det_prime_checkpoint is None
    if same_as_primary or prime_missing:
        print_log("Using the same detector as D and D'", logger="current")
        detector_prime = detector
    else:
        detector_prime = init_detector(det_cfg.det_prime_config, det_cfg.det_prime_checkpoint, device=args.device)
        detector_prime.cfg = adapt_mmdet_pipeline(detector_prime.cfg)
        print_log("Using a different detector for D'", logger="current")

    # build pose estimator
    pose_estimator = init_pose_estimator(
        bmp_config.pose_estimator.pose_config,
        bmp_config.pose_estimator.pose_checkpoint,
        device=args.device,
        cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))),
    )

    sam2 = prepare_sam2_model(
        model_cfg=bmp_config.sam2.sam2_config,
        model_checkpoint=bmp_config.sam2.sam2_checkpoint,
    )

    # Run inference on one image
    _ = process_one_image(args, bmp_config, args.input, detector, detector_prime, pose_estimator, sam2)


if __name__ == "__main__":
    main()