File size: 9,573 Bytes
a249588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
# Copyright (c) OpenMMLab. All rights reserved.
"""
BMP demo script: iteratively runs detection, pose estimation, SAM-based mask refinement, and visualization.
Usage:
    python bmp_demo.py <config.yaml> <input_image> [--output-root <dir>] [--device <device>] [--create-gif]
"""

import os
import shutil
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Optional, Sequence

import mmcv
import mmengine
import numpy as np
import yaml
from demo_utils import DotDict, concat_instances, create_GIF, filter_instances, pose_nms, visualize_itteration
from mm_utils import run_MMDetector, run_MMPose
from mmdet.apis import init_detector
from mmengine.logging import print_log
from mmengine.structures import InstanceData
from sam2_utils import prepare_model as prepare_sam2_model
from sam2_utils import process_image_with_SAM

from mmpose.apis import init_model as init_pose_estimator
from mmpose.utils import adapt_mmdet_pipeline

# Default thresholds for the detection and pose steps in process_one_image.
# These are fixed module-level values; they are not read from the YAML config.
DEFAULT_DET_CAT_ID: int = 0  # "person"; passed as det_cat_id to run_MMDetector
DEFAULT_BBOX_THR: float = 0.3  # passed as bbox_thr to run_MMDetector
DEFAULT_NMS_THR: float = 0.3  # passed as nms_thr to run_MMDetector
DEFAULT_KPT_THR: float = 0.3  # passed as kpt_thr to run_MMPose


def parse_args(argv: Optional[Sequence[str]] = None) -> Namespace:
    """
    Parse command-line arguments for BMP demo.

    Args:
        argv (Sequence[str], optional): Argument list to parse instead of ``sys.argv[1:]``.
            Defaults to ``None``, which reads the real command line.
    Returns:
        Namespace: Contains
            bmp_config (Path): path to the BMP YAML config file,
            input (Path): input image file,
            output_root (Path): output directory; defaults to an ``outputs`` folder located
                next to this script (not the current working directory),
            device (str): inference device, e.g. ``cuda:0`` or ``cpu``,
            create_gif (bool): whether to assemble a GIF of all BMP iterations.
    """
    parser = ArgumentParser(description="BBoxMaskPose demo")
    parser.add_argument("bmp_config", type=Path, help="Path to BMP YAML config file")
    parser.add_argument("input", type=Path, help="Input image file")
    parser.add_argument(
        "--output-root",
        type=Path,
        default=None,
        help="Directory to save outputs (default: 'outputs' next to this script)",
    )
    parser.add_argument("--device", type=str, default="cuda:0", help="Device for inference (e.g., cuda:0 or cpu)")
    parser.add_argument("--create-gif", action="store_true", default=False, help="Create GIF of all BMP iterations")
    args = parser.parse_args(argv)
    if args.output_root is None:
        # Anchor the default to the script's directory. Keep it a Path so the attribute has the
        # same type whether or not --output-root was given explicitly.
        args.output_root = Path(__file__).parent / "outputs"
    return args


def parse_yaml_config(yaml_path: Path) -> DotDict:
    """
    Load BMP configuration from a YAML file.

    Args:
        yaml_path (Path): Path to YAML config.
    Returns:
        DotDict: Nested config dictionary.
    Raises:
        ValueError: If the file's top-level YAML value is not a mapping
            (for example an empty file, a list, or a scalar).
    """
    # Explicit encoding so parsing does not depend on the platform's locale default.
    with open(yaml_path, "r", encoding="utf-8") as f:
        # safe_load only builds plain Python types; it never instantiates arbitrary objects.
        cfg = yaml.safe_load(f)
    if not isinstance(cfg, dict):
        # An empty file yields None; fail here with a clear message instead of deep inside DotDict.
        raise ValueError(
            "BMP config {} must contain a YAML mapping at the top level, got {}.".format(
                yaml_path, type(cfg).__name__
            )
        )
    return DotDict(cfg)


def process_one_image(
    args: Namespace,
    bmp_config: DotDict,
    img_path: Path,
    detector: object,
    detector_prime: object,
    pose_estimator: object,
    sam2_model: object,
) -> InstanceData:
    """
    Run the full BMP pipeline on a single image: detection, pose, SAM mask refinement, and visualization.

    The pipeline runs ``bmp_config.num_bmp_iters`` iterations. Each iteration:
      1. detects people (``detector`` on the first iteration, ``detector_prime`` afterwards) on the
         image returned by the previous iteration's visualization (the raw image at first);
      2. estimates poses on the original image for those detections;
      3. runs pose-NMS over the previously accumulated and the new instances, refines the new
         survivors with SAM, and appends SAM's result to the accumulated instances;
      4. passes the accumulated instances to ``visualize_itteration`` with the per-image output
         directory as output root.

    Args:
        args (Namespace): Parsed CLI arguments (``output_root`` and ``create_gif`` are used).
        bmp_config (DotDict): Configuration parameters.
        img_path (Path): Path to the input image (a ``str`` path is also accepted).
        detector: Primary MMDetection model.
        detector_prime: Secondary MMDetection model for iterations.
        pose_estimator: MMPose model for keypoint estimation.
        sam2_model: SAM model for mask refinement.
    Returns:
        InstanceData: Final merged detections and refined masks, or ``None`` if no iteration
        produced any instance.
    Raises:
        ValueError: If the image cannot be read.
    """
    img_path = Path(img_path)  # `.stem` is used below, so normalise str inputs to Path
    img = mmcv.imread(str(img_path), channel_order="bgr")
    if img is None:
        raise ValueError("Failed to read image from {}.".format(img_path))

    # Start from a clean per-image output directory: results of earlier runs are deleted.
    output_dir = os.path.join(args.output_root, img_path.stem)
    shutil.rmtree(output_dir, ignore_errors=True)
    mmengine.mkdir_or_exist(output_dir)

    num_iters = bmp_config.num_bmp_iters
    confidence_thr = bmp_config.sam2.prompting.confidence_thr
    # Pose-NMS settings are identical for every iteration, so build them once.
    pose_nms_cfg = DotDict({"confidence_thr": confidence_thr, "oks_thr": bmp_config.oks_nms_thr})

    img_for_detection = img.copy()
    all_detections = None
    for iteration in range(num_iters):
        print_log("BMP Iteration {}/{} started".format(iteration + 1, num_iters), logger="current")

        # Step 1: Detection (D on the first pass, D' on later passes).
        det_instances = run_MMDetector(
            detector if iteration == 0 else detector_prime,
            img_for_detection,
            det_cat_id=DEFAULT_DET_CAT_ID,
            bbox_thr=DEFAULT_BBOX_THR,
            nms_thr=DEFAULT_NMS_THR,
        )
        print_log("Detected {} instances".format(len(det_instances.bboxes)), logger="current")
        if len(det_instances.bboxes) == 0:
            print_log("No detections found, skipping.", logger="current")
            continue

        # Step 2: Pose estimation on the original (un-annotated) image.
        pose_instances = run_MMPose(
            pose_estimator,
            img.copy(),
            detections=det_instances,
            kpt_thr=DEFAULT_KPT_THR,
        )
        # Restrict to the first 17 COCO keypoints and append each keypoint's score as an extra
        # last column. With (x, y) keypoints — as the [:, :, 2] confidence lookup below assumes —
        # this gives keypoints of shape (N, 17, 3).
        pose_instances.keypoints = pose_instances.keypoints[:, :17, :]
        pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17]
        pose_instances.keypoints = np.concatenate(
            [pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1
        )

        # Step 3: Pose-NMS over old + new instances. Old instances are stacked first, so indices
        # >= num_old_detections refer to instances found in this iteration.
        if all_detections is None:
            num_old_detections = 0
            all_keypoints = pose_instances.keypoints
            all_bboxes = pose_instances.bboxes
        else:
            num_old_detections = len(all_detections.bboxes)
            all_keypoints = np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0)
            all_bboxes = np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)
        num_valid_kpts = np.sum(all_keypoints[:, :, 2] > confidence_thr, axis=1)
        keep_indices = pose_nms(
            pose_nms_cfg,
            image_kpts=all_keypoints,
            image_bboxes=all_bboxes,
            num_valid_kpts=num_valid_kpts,
        )
        keep_indices = sorted(keep_indices)  # Sort by original index
        keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
        keep_old_indices = [i for i in keep_indices if i < num_old_detections]
        if not keep_new_indices:
            print_log("No new instances passed pose NMS, skipping SAM refinement.", logger="current")
            continue

        # Keep the surviving new detections; their score is the mean keypoint confidence.
        new_dets = filter_instances(pose_instances, keep_new_indices)
        new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)
        # Surviving old instances are passed to SAM alongside the new ones.
        # NOTE(review): all_detections below still keeps *every* old instance, including those
        # pose-NMS did not keep — confirm that is intended.
        old_dets = filter_instances(all_detections, keep_old_indices) if keep_old_indices else None
        print_log(
            "Pose NMS reduced instances to {:d} ({:d}+{:d}) instances".format(
                len(new_dets.bboxes) + num_old_detections, num_old_detections, len(new_dets.bboxes)
            ),
            logger="current",
        )

        # SAM refinement of the new instances.
        new_detections = process_image_with_SAM(
            DotDict(bmp_config.sam2.prompting),
            img.copy(),
            sam2_model,
            new_dets,
            old_dets,
        )

        # Step 4: Merge. Use SAM's output in both branches so the refined instances (not the
        # pre-SAM `new_dets`) are accumulated on every iteration, not only the first.
        if all_detections is None:
            all_detections = new_detections
        else:
            all_detections = concat_instances(all_detections, new_detections)

        # Step 5: Visualization; the returned image is the detector input for the next iteration.
        img_for_detection = visualize_itteration(
            img.copy(),
            all_detections,
            iteration_idx=iteration,
            output_root=output_dir,
            img_name=img_path.stem,
        )
        print_log("Iteration {} completed".format(iteration + 1), logger="current")

    # Create GIF of iterations if requested
    if args.create_gif:
        image_file = os.path.join(output_dir, "{:s}.jpg".format(img_path.stem))
        create_GIF(
            img_path=image_file,
            output_root=output_dir,
            bmp_x=num_iters,
        )
    return all_detections


def main() -> None:
    """
    Command-line entry point: build every model named in the BMP config, then run the
    pipeline on the single input image given on the command line.
    """
    args = parse_args()
    bmp_config = parse_yaml_config(args.bmp_config)
    det_cfg = bmp_config.detector

    # Per-image folders are created inside this root later on, so make sure it exists now.
    mmengine.mkdir_or_exist(str(args.output_root))

    # Detector D, used on the first BMP iteration.
    detector = init_detector(det_cfg.det_config, det_cfg.det_checkpoint, device=args.device)
    detector.cfg = adapt_mmdet_pipeline(detector.cfg)

    # Detector D', used on later iterations. Reuse D when D' is identical to it or not fully specified.
    prime_is_same = (
        det_cfg.det_config == det_cfg.det_prime_config
        and det_cfg.det_checkpoint == det_cfg.det_prime_checkpoint
    )
    if prime_is_same or det_cfg.det_prime_config is None or det_cfg.det_prime_checkpoint is None:
        print_log("Using the same detector as D and D'", logger="current")
        detector_prime = detector
    else:
        detector_prime = init_detector(det_cfg.det_prime_config, det_cfg.det_prime_checkpoint, device=args.device)
        detector_prime.cfg = adapt_mmdet_pipeline(detector_prime.cfg)
        print_log("Using a different detector for D'", logger="current")

    # Pose estimator; heatmap outputs are switched off through cfg_options.
    pose_cfg = bmp_config.pose_estimator
    pose_estimator = init_pose_estimator(
        pose_cfg.pose_config,
        pose_cfg.pose_checkpoint,
        device=args.device,
        cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))),
    )

    # SAM2 model used for mask refinement.
    sam_cfg = bmp_config.sam2
    sam2 = prepare_sam2_model(model_cfg=sam_cfg.sam2_config, model_checkpoint=sam_cfg.sam2_checkpoint)

    # The returned instances are not needed here; outputs are written under args.output_root.
    process_one_image(args, bmp_config, args.input, detector, detector_prime, pose_estimator, sam2)


# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()