"""
SAM2 utilities for BMP demo:
- Build and prepare SAM model
- Convert poses to segmentation
- Compute mask-pose consistency
"""

from typing import Any, List, Optional, Tuple

import numpy as np
import torch
from mmengine.structures import InstanceData
from pycocotools import mask as Mask
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Threshold for keypoint validity in mask-pose consistency
STRICT_KPT_THRESHOLD: float = 0.5

# Bboxes with a smaller area (in pixels, w * h) are skipped when `ignore_small_bboxes` is set.
_SMALL_BBOX_AREA: float = 100 * 100


def _validate_sam_args(sam_args: Any) -> None:
    """
    Validate that all required sam_args attributes are present.

    Args:
        sam_args: Configuration object (attribute access).

    Raises:
        AttributeError: If any required attribute is missing.
    """
    required = [
        "crop",
        "use_bbox",
        "confidence_thr",
        "ignore_small_bboxes",
        "num_pos_keypoints",
        "num_pos_keypoints_if_crowd",
        "crowd_by_max_iou",
        "batch",
        "exclusive_masks",
        "extend_bbox",
        "pose_mask_consistency",
        "visibility_thr",
    ]
    for param in required:
        if not hasattr(sam_args, param):
            raise AttributeError(f"Missing required arg {param} in sam_args")


def _get_max_ious(bboxes: List[np.ndarray]) -> np.ndarray:
    """
    Compute maximum IoU for each bbox against the others.

    Args:
        bboxes: Boxes in a format accepted by `pycocotools.mask.iou`
            (the callers in this module pass xywh boxes).

    Returns:
        np.ndarray: Array of shape (N,) holding, for each box, its highest IoU
        with any *other* box (0 when N == 1, empty when N == 0).
    """
    if len(bboxes) == 0:
        # pycocotools returns [] for empty input, which np.fill_diagonal rejects.
        return np.zeros((0,), dtype=np.float64)
    is_crowd = [0] * len(bboxes)
    ious = Mask.iou(bboxes, bboxes, is_crowd)
    mat = np.array(ious)
    # Ignore each box's IoU with itself.
    np.fill_diagonal(mat, 0)
    return mat.max(axis=1)


def _compute_one_mask_pose_consistency(
    mask: Optional[np.ndarray],
    pos_keypoints: Optional[np.ndarray] = None,
    neg_keypoints: Optional[np.ndarray] = None,
) -> float:
    """
    Compute a consistency score between a mask and given keypoints.

    Only keypoints whose third column exceeds STRICT_KPT_THRESHOLD are used.
    Keypoint coordinates are floored and clipped to the mask bounds before
    sampling, for both positive and negative keypoints.

    Args:
        mask (Optional[np.ndarray]): Binary mask of shape (H, W). None yields 0.0.
        pos_keypoints (Optional[np.ndarray]): Positive keypoints array (N, 3) as (x, y, score).
        neg_keypoints (Optional[np.ndarray]): Negative keypoints array (M, 3) as (x, y, score).

    Returns:
        float: 0.5 * (mean mask value at positive keypoints)
            + 0.5 * (fraction of negative keypoints outside the mask).
            A term is 0.0 when it has no valid keypoints.
    """
    if mask is None:
        return 0.0

    height, width = mask.shape[:2]

    def _sample_mask(points: np.ndarray) -> np.ndarray:
        """Return the mask values at each point's (x, y), clipped to the mask bounds."""
        pts_int = np.floor(points[:, :2]).astype(int)
        pts_int[:, 0] = np.clip(pts_int[:, 0], 0, width - 1)
        pts_int[:, 1] = np.clip(pts_int[:, 1], 0, height - 1)
        return mask[pts_int[:, 1], pts_int[:, 0]]

    pos_mean = 0.0
    if pos_keypoints is not None:
        valid = pos_keypoints[:, 2] > STRICT_KPT_THRESHOLD
        vals = _sample_mask(pos_keypoints[valid])
        pos_mean = vals.mean() if vals.size > 0 else 0.0

    neg_mean = 0.0
    if neg_keypoints is not None:
        valid = neg_keypoints[:, 2] > STRICT_KPT_THRESHOLD
        # Clip like the positive branch: out-of-image negatives would otherwise
        # raise IndexError (or silently wrap around for negative coordinates).
        inside = _sample_mask(neg_keypoints[valid]).astype(bool)
        neg_mean = (~inside).mean() if inside.size > 0 else 0.0

    return 0.5 * pos_mean + 0.5 * neg_mean


def _select_keypoints(
    args: Any,
    kpts: np.ndarray,
    num_visible: int,
    bbox: Optional[Tuple[float, float, float, float]] = None,
    method: Optional[str] = "distance+confidence",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Select and order keypoints for SAM prompting based on specified method.

    Of the first three keypoints only the most confident one is kept (if it
    passes `args.visibility_thr`). Then keypoints with confidence
    >= `args.visibility_thr` are ordered by the chosen method:

    - "confidence": descending confidence.
    - "distance": descending (distance to bbox center + 3 * nearest-neighbour distance).
    - "distance+confidence": greedy farthest-point ordering, restricted to
      keypoints at or above the 80th percentile of the remaining confidences.
    - "closest": keypoints with confidence > STRICT_KPT_THRESHOLD, by ascending
      distance to the bbox center.

    Args:
        args: Configuration object with selection_method and visibility_thr attributes.
        kpts (np.ndarray): Keypoints array of shape (K, 3).
        num_visible (int): Number of keypoints above visibility threshold.
        bbox (Optional[Tuple]): Optional xyxy bbox for distance methods; when
            None the mean of the visible keypoints is used as the center.
        method (Optional[str]): Override selection method; falls back to
            `args.selection_method` when falsy.

    Returns:
        Tuple[np.ndarray, np.ndarray]: Selected keypoint coordinates (N,2) and confidences (N,).
        When `num_visible == 0`, all keypoints are returned unchanged.

    Raises:
        ValueError: If an unknown method is specified.
    """
    if num_visible == 0:
        return kpts[:, :2], kpts[:, 2]

    methods = ["confidence", "distance", "distance+confidence", "closest"]
    sel_method = method or args.selection_method
    if sel_method not in methods:
        raise ValueError("Unknown method for keypoint selection: {}".format(sel_method))

    # Keep at most one of the first three (facial) keypoints: the most confident one.
    facial_kpts = kpts[:3, :]
    facial_conf = kpts[:3, 2]
    facial_point = facial_kpts[np.argmax(facial_conf)]
    if facial_point[-1] >= args.visibility_thr:
        kpts = np.concatenate([facial_point[None, :], kpts[3:]], axis=0)

    conf = kpts[:, 2]
    vis_mask = conf >= args.visibility_thr
    coords = kpts[vis_mask, :2]
    confs = conf[vis_mask]

    if sel_method == "confidence":
        order = np.argsort(confs)[::-1]
        coords = coords[order]
        confs = confs[order]

    elif sel_method == "distance":
        if bbox is None:
            bbox_center = np.array([coords[:, 0].mean(), coords[:, 1].mean()])
        else:
            bbox_center = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2])
        dists = np.linalg.norm(coords[:, :2] - bbox_center, axis=1)
        dist_matrix = np.linalg.norm(coords[:, None, :2] - coords[None, :, :2], axis=2)
        np.fill_diagonal(dist_matrix, np.inf)
        min_inter_dist = np.min(dist_matrix, axis=1)
        order = np.argsort(dists + 3 * min_inter_dist)[::-1]
        coords = coords[order, :2]
        confs = confs[order]

    elif sel_method == "distance+confidence":
        order = np.argsort(confs)[::-1]
        coords = coords[order, :2]
        confs = confs[order]
        if coords.shape[0] > 0:
            # Working copy of the *visible, sorted* confidences (the original code
            # indexed the full `kpts` array here, mismatching coords and confidences).
            # Already-selected entries are marked with -1.
            confidences = confs.copy()
            dist_matrix = np.linalg.norm(coords[:, None, :2] - coords[None, :, :2], axis=2)
            selected_idx = [0]
            confidences[0] = -1
            for _ in range(coords.shape[0] - 1):
                min_dist = np.min(dist_matrix[:, selected_idx], axis=1)
                min_dist[confidences < np.percentile(confidences, 80)] = -1
                next_idx = np.argmax(min_dist)
                selected_idx.append(next_idx)
                confidences[next_idx] = -1
            coords = coords[selected_idx]
            confs = confs[selected_idx]

    elif sel_method == "closest":
        coords = coords[confs > STRICT_KPT_THRESHOLD, :]
        confs = confs[confs > STRICT_KPT_THRESHOLD]
        if bbox is None:
            bbox_center = np.array([coords[:, 0].mean(), coords[:, 1].mean()])
        else:
            bbox_center = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2])
        dists = np.linalg.norm(coords[:, :2] - bbox_center, axis=1)
        order = np.argsort(dists)
        coords = coords[order, :2]
        confs = confs[order]

    return coords, confs


def prepare_model(model_cfg: Any, model_checkpoint: str) -> SAM2ImagePredictor:
    """
    Build and return a SAM2ImagePredictor model on the appropriate device.

    The device preference is CUDA, then Apple MPS, then CPU.

    Args:
        model_cfg: Configuration for SAM2 model.
        model_checkpoint (str): Path to model checkpoint.

    Returns:
        SAM2ImagePredictor: Initialized SAM2 image predictor.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    sam2 = build_sam2(model_cfg, model_checkpoint, device=device, apply_postprocessing=True)
    model = SAM2ImagePredictor(
        sam2,
        max_hole_area=10.0,
        max_sprinkle_area=50.0,
    )
    return model


def _compute_mask_pose_consistency(masks: List[np.ndarray], keypoints_list: List[np.ndarray]) -> np.ndarray:
    """
    Compute mask-pose consistency score for each mask-keypoints pair.

    Mask `i` is scored with `keypoints_list[i]` as positive keypoints and the
    keypoints of every other entry of `keypoints_list` as negative keypoints.
    Pairs are formed with `zip`, so if there are fewer masks than keypoint
    sets, only the first `len(masks)` entries are scored. All keypoint sets
    still serve as negatives.

    Args:
        masks (List[np.ndarray]): Binary masks list.
        keypoints_list (List[np.ndarray]): Keypoint arrays per instance, all of
            the same shape (K, 3).

    Returns:
        np.ndarray: Consistency scores array of shape (min(len(masks), len(keypoints_list)),).
    """
    keypoints_arr = np.asarray(keypoints_list)
    scores: List[float] = []
    for idx, (mask, kpts) in enumerate(zip(masks, keypoints_arr)):
        other_kpts = np.delete(keypoints_arr, idx, axis=0).reshape(-1, 3)
        scores.append(_compute_one_mask_pose_consistency(mask, kpts, other_kpts))
    return np.array(scores)


def _pose2seg(
    args: Any,
    model: SAM2ImagePredictor,
    bbox_xyxy: Optional[List[float]] = None,
    pos_kpts: Optional[np.ndarray] = None,
    neg_kpts: Optional[np.ndarray] = None,
    image: Optional[np.ndarray] = None,
    gt_mask: Optional[Any] = None,
    num_pos_keypoints: Optional[int] = None,
    gt_mask_is_binary: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
    """
    Run SAM segmentation conditioned on pose keypoints and optional ground truth mask.

    When `image` is given together with `args.crop` and `args.use_bbox`, this
    function calls `model.set_image` itself: on a 1.5x crop around the bbox, or
    on the full image if there is no bbox (crowd case). Otherwise the caller
    must have already set the image on `model`.

    Args:
        args: Configuration object with prompting settings.
        model (SAM2ImagePredictor): Prepared SAM2 model.
        bbox_xyxy (Optional[List[float]]): Bounding box coordinates in xyxy format.
            It is not modified.
        pos_kpts (Optional[np.ndarray]): Positive keypoints array (reshaped to (-1, 3)).
        neg_kpts (Optional[np.ndarray]): Negative keypoints array (reshaped to (-1, 3)).
        image (Optional[np.ndarray]): Input image array (H, W, C), used only for cropping.
        gt_mask (Optional[Any]): Reference mask for the pose-mask consistency check.
            It is a binary array if `gt_mask_is_binary`, otherwise an RLE decoded
            with `pycocotools.mask.decode`.
        num_pos_keypoints (Optional[int]): Number of positive keypoints to use
            (defaults to `args.num_pos_keypoints`).
        gt_mask_is_binary (bool): Flag indicating if ground truth mask is binary.

    Returns:
        Tuple of (mask, pos_kpts_backup, neg_kpts_backup, score):
        - mask: the full-image mask.
        - pos_kpts_backup: (N, 3) selected positive keypoints with confidences.
        - neg_kpts_backup: (M, 3) all selected negative keypoints with confidences.
        - score: SAM's score for the mask.
    """
    num_pos_keypoints = args.num_pos_keypoints if num_pos_keypoints is None else num_pos_keypoints

    # Positive prompts: select/order this instance's keypoints, keep the first N.
    if pos_kpts is not None:
        pos_kpts = pos_kpts.reshape(-1, 3)
        valid_kpts = pos_kpts[:, 2] > args.visibility_thr
        pos_kpts, conf = _select_keypoints(args, pos_kpts, num_visible=valid_kpts.sum(), bbox=bbox_xyxy)
        pos_kpts_backup = np.concatenate([pos_kpts, conf[:, None]], axis=1)
        if pos_kpts.shape[0] > num_pos_keypoints:
            pos_kpts = pos_kpts[:num_pos_keypoints, :]
            pos_kpts_backup = pos_kpts_backup[:num_pos_keypoints, :]
    else:
        pos_kpts = np.empty((0, 2), dtype=np.float32)
        pos_kpts_backup = np.empty((0, 3), dtype=np.float32)

    # Negative prompts: other instances' keypoints closest to this bbox.
    # The backup keeps all of them (used for consistency); only the first
    # `args.num_neg_keypoints` are used as prompts.
    if neg_kpts is not None:
        neg_kpts = neg_kpts.reshape(-1, 3)
        valid_kpts = neg_kpts[:, 2] > args.visibility_thr
        neg_kpts, conf = _select_keypoints(
            args, neg_kpts, num_visible=valid_kpts.sum(), bbox=bbox_xyxy, method="closest"
        )
        neg_kpts_backup = np.concatenate([neg_kpts, conf[:, None]], axis=1)
        selected_neg_kpts = neg_kpts[: args.num_neg_keypoints, :]
    else:
        selected_neg_kpts = np.empty((0, 2), dtype=np.float32)
        neg_kpts_backup = np.empty((0, 3), dtype=np.float32)

    # Concatenate positive (label 1) and negative (label 0) keypoints
    kpts = np.concatenate([pos_kpts, selected_neg_kpts], axis=0)
    kpts_labels = np.concatenate([np.ones(pos_kpts.shape[0]), np.zeros(selected_neg_kpts.shape[0])], axis=0)

    bbox = bbox_xyxy if args.use_bbox else None

    if args.extend_bbox and bbox is not None and pos_kpts.shape[0] > 0:
        # Expand the bbox such that it contains all positive keypoints (+2 px margin)
        pose_bbox = np.array(
            [pos_kpts[:, 0].min() - 2, pos_kpts[:, 1].min() - 2, pos_kpts[:, 0].max() + 2, pos_kpts[:, 1].max() + 2]
        )
        expanded_bbox = np.array(bbox, dtype=float)
        expanded_bbox[:2] = np.minimum(expanded_bbox[:2], pose_bbox[:2])
        expanded_bbox[2:] = np.maximum(expanded_bbox[2:], pose_bbox[2:])
        bbox = expanded_bbox

    wants_crop = args.crop and args.use_bbox and image is not None
    cropped = False
    if wants_crop and bbox is not None:
        # Crop the image to 1.5x the bbox size around its center.
        # Work on a float copy so the caller's bbox is never mutated.
        bbox = np.array(bbox, dtype=float)
        bbox_center = (bbox[:2] + bbox[2:]) / 2
        bbox_size = 1.5 * (bbox[2:] - bbox[:2])
        crop_bbox = np.concatenate([bbox_center - bbox_size / 2, bbox_center + bbox_size / 2])
        crop_bbox = np.round(crop_bbox).astype(int)
        crop_bbox = np.clip(crop_bbox, 0, [image.shape[1], image.shape[0], image.shape[1], image.shape[0]])
        original_image_size = image.shape[:2]
        image = image[crop_bbox[1] : crop_bbox[3], crop_bbox[0] : crop_bbox[2], :]

        # Move the prompts into crop coordinates
        kpts = kpts - crop_bbox[:2]
        bbox[:2] -= crop_bbox[:2]
        bbox[2:] -= crop_bbox[:2]

        model.set_image(image)
        cropped = True
    elif wants_crop:
        # No bbox to crop around (e.g. crowd instance): the caller skipped
        # set_image expecting a crop, so prompt on the full image instead.
        model.set_image(image)

    masks, scores, logits = model.predict(
        point_coords=kpts,
        point_labels=kpts_labels,
        box=bbox,
        multimask_output=False,
    )
    mask = masks[0]
    scores = scores[0]

    if cropped:
        # Pad the mask back to the original image size
        mask_padded = np.zeros(original_image_size, dtype=np.uint8)
        mask_padded[crop_bbox[1] : crop_bbox[3], crop_bbox[0] : crop_bbox[2]] = mask
        mask = mask_padded

    if args.pose_mask_consistency:
        if gt_mask_is_binary:
            gt_mask_binary = gt_mask
        else:
            gt_mask_binary = Mask.decode(gt_mask).astype(bool) if gt_mask is not None else None

        # Without a reference mask there is nothing to choose between; keep SAM's mask.
        if gt_mask_binary is not None:
            gt_mask_pose_consistency = _compute_one_mask_pose_consistency(
                gt_mask_binary, pos_kpts_backup, neg_kpts_backup
            )
            dt_mask_pose_consistency = _compute_one_mask_pose_consistency(mask, pos_kpts_backup, neg_kpts_backup)

            tol = 0.1
            if np.abs(dt_mask_pose_consistency - gt_mask_pose_consistency) < tol:
                # Equally consistent: prefer the smaller mask
                if gt_mask_binary.sum() < mask.sum():
                    mask = gt_mask_binary
            elif gt_mask_pose_consistency > dt_mask_pose_consistency:
                mask = gt_mask_binary

    return mask, pos_kpts_backup, neg_kpts_backup, scores


def process_image_with_SAM(
    sam_args: Any,
    image: np.ndarray,
    model: SAM2ImagePredictor,
    new_dets: InstanceData,
    old_dets: Optional[InstanceData] = None,
) -> InstanceData:
    """
    Validate `sam_args` and route to single-instance or batched SAM refinement.

    Args:
        sam_args (Any): SAM parameters (see `_process_image_single`).
        image (np.ndarray): Input image.
        model (SAM2ImagePredictor): Prepared SAM2 predictor.
        new_dets (InstanceData): Detections to refine.
        old_dets (Optional[InstanceData]): Previous detections for negative prompts.

    Returns:
        InstanceData: `new_dets` with `.refined_masks`, `.sam_scores` and `.sam_kpts` set.

    Raises:
        AttributeError: If a required attribute is missing from `sam_args`.
    """
    _validate_sam_args(sam_args)
    if sam_args.batch:
        return _process_image_batch(sam_args, image, model, new_dets, old_dets)
    return _process_image_single(sam_args, image, model, new_dets, old_dets)


def _process_image_single(
    sam_args: Any,
    image: np.ndarray,
    model: SAM2ImagePredictor,
    new_dets: InstanceData,
    old_dets: Optional[InstanceData] = None,
) -> InstanceData:
    """
    Refine instance segmentation masks using SAM2 with pose-conditioned prompts.

    Each new detection is segmented with its own keypoints as positive prompts.
    The keypoints of all other detections (old and new) serve as negative prompts.

    Args:
        sam_args (Any): DotDict containing required SAM parameters:
            crop (bool), use_bbox (bool), confidence_thr (float),
            ignore_small_bboxes (bool), num_pos_keypoints (int),
            num_pos_keypoints_if_crowd (int), crowd_by_max_iou (Optional[float]),
            batch (bool), exclusive_masks (bool), extend_bbox (bool),
            pose_mask_consistency (bool).
        image (np.ndarray): BGR image array of shape (H, W, 3).
        model (SAM2ImagePredictor): Initialized SAM2 predictor.
        new_dets (InstanceData): New detections with attributes:
            bboxes (xywh), pred_masks, keypoints, bbox_scores.
        old_dets (Optional[InstanceData]): Previous detections for negative prompts.
            With `exclusive_masks`, it must also carry refined_masks and sam_scores.

    Returns:
        InstanceData: `new_dets` updated in-place with `.refined_masks`,
        `.sam_scores`, and `.sam_kpts`. `.sam_kpts` always has
        `num_pos_keypoints` rows per instance, truncated or zero-padded.
        Skipped (small) instances keep zero masks, scores and keypoints.
    """
    _validate_sam_args(sam_args)
    if not (sam_args.crop and sam_args.use_bbox):
        # All instances are prompted on the full image: embed it once.
        model.set_image(image)

    # Ignore all keypoints with confidence below the threshold (work on copies)
    new_keypoints = new_dets.keypoints.copy()
    for kpts in new_keypoints:
        kpts[kpts[:, 2] < sam_args.confidence_thr, :] = 0

    n_new_dets = len(new_dets.bboxes)
    n_old_dets = 0
    old_keypoints = None
    if old_dets is not None:
        n_old_dets = len(old_dets.bboxes)
        old_keypoints = old_dets.keypoints.copy()
        for kpts in old_keypoints:
            kpts[kpts[:, 2] < sam_args.confidence_thr, :] = 0

    all_bboxes = new_dets.bboxes.copy()
    if old_dets is not None:
        all_bboxes = np.concatenate([all_bboxes, old_dets.bboxes], axis=0)
    max_ious = _get_max_ious(all_bboxes)

    n_out_kpts = sam_args.num_pos_keypoints
    new_dets.refined_masks = np.zeros((n_new_dets, image.shape[0], image.shape[1]), dtype=np.uint8)
    new_dets.sam_scores = np.zeros_like(new_dets.bbox_scores)
    new_dets.sam_kpts = np.zeros((n_new_dets, n_out_kpts, 3), dtype=np.float32)

    for instance_idx in range(n_new_dets):
        bbox_xywh = new_dets.bboxes[instance_idx]
        bbox_area = bbox_xywh[2] * bbox_xywh[3]
        if sam_args.ignore_small_bboxes and bbox_area < _SMALL_BBOX_AREA:
            continue

        dt_mask = new_dets.pred_masks[instance_idx] if new_dets.pred_masks is not None else None
        bbox_xyxy = [bbox_xywh[0], bbox_xywh[1], bbox_xywh[0] + bbox_xywh[2], bbox_xywh[1] + bbox_xywh[3]]

        this_kpts = new_keypoints[instance_idx].reshape(1, -1, 3)

        # Negative prompts: keypoints of every old detection and every other new detection
        other_kpts = None
        if n_old_dets > 0:
            other_kpts = old_keypoints.copy().reshape(n_old_dets, -1, 3)
        if len(new_keypoints) > 1:
            other_new_kpts = np.concatenate([new_keypoints[:instance_idx], new_keypoints[instance_idx + 1 :]], axis=0)
            other_kpts = (
                np.concatenate([other_kpts, other_new_kpts], axis=0) if other_kpts is not None else other_new_kpts
            )

        # Crowded instance: drop the bbox prompt and use the crowd keypoint budget
        num_pos_keypoints = sam_args.num_pos_keypoints
        if sam_args.crowd_by_max_iou is not None and max_ious[instance_idx] > sam_args.crowd_by_max_iou:
            bbox_xyxy = None
            num_pos_keypoints = sam_args.num_pos_keypoints_if_crowd

        dt_mask, pos_kpts, neg_kpts, scores = _pose2seg(
            sam_args,
            model,
            bbox_xyxy,
            pos_kpts=this_kpts,
            neg_kpts=other_kpts,
            image=image if (sam_args.crop and sam_args.use_bbox) else None,
            gt_mask=dt_mask,
            num_pos_keypoints=num_pos_keypoints,
            gt_mask_is_binary=True,
        )

        new_dets.refined_masks[instance_idx] = dt_mask
        new_dets.sam_scores[instance_idx] = scores

        # Store exactly `num_pos_keypoints` rows: truncate (the crowd budget may be
        # larger) or zero-pad (fewer keypoints were selected).
        pos_kpts = pos_kpts[:n_out_kpts]
        if len(pos_kpts) < n_out_kpts:
            pos_kpts = np.concatenate(
                [pos_kpts, np.zeros((n_out_kpts - len(pos_kpts), 3), dtype=np.float32)], axis=0
            )
        new_dets.sam_kpts[instance_idx] = pos_kpts

    n_masks = len(new_dets.refined_masks) + (len(old_dets.refined_masks) if old_dets is not None else 0)

    if sam_args.exclusive_masks and n_masks > 1:
        all_masks = (
            np.concatenate([new_dets.refined_masks, old_dets.refined_masks], axis=0)
            if old_dets is not None
            else new_dets.refined_masks
        )
        all_scores = (
            np.concatenate([new_dets.sam_scores, old_dets.sam_scores], axis=0)
            if old_dets is not None
            else new_dets.sam_scores
        )
        refined_masks = _apply_exclusive_masks(all_masks, all_scores)
        new_dets.refined_masks = refined_masks[: len(new_dets.refined_masks)]

    return new_dets


def _batch_instance_prompt(
    sam_args: Any, bbox_xywh: np.ndarray, keypoints: np.ndarray
) -> Optional[Tuple[np.ndarray, int]]:
    """
    Filter one detection for batched prompting.

    A detection is skipped (None is returned) if it is too small under
    `ignore_small_bboxes`, or if it has no keypoint above both
    `visibility_thr` and `confidence_thr`.

    Returns:
        Optional[Tuple[np.ndarray, int]]: A (K, 3) keypoint copy and the visible
        count. In the copy, invisible keypoints have zeroed coordinates and the
        third column is 1 for visible keypoints, 0 otherwise.
    """
    bbox_area = bbox_xywh[2] * bbox_xywh[3]
    if sam_args.ignore_small_bboxes and bbox_area < _SMALL_BBOX_AREA:
        return None

    # np.array copies, so the caller's detections are never modified in place.
    this_kpts = np.array(keypoints).reshape(-1, 3)
    kpts_conf = this_kpts[:, 2]
    visible_kpts = (kpts_conf > sam_args.visibility_thr) & (kpts_conf > sam_args.confidence_thr)
    num_visible = int(visible_kpts.sum())
    if num_visible <= 0:
        return None

    this_kpts[~visible_kpts, :2] = 0
    this_kpts[:, 2] = visible_kpts
    return this_kpts, num_visible


def _process_image_batch(
    sam_args: Any,
    image: np.ndarray,
    model: SAM2ImagePredictor,
    new_dets: InstanceData,
    old_dets: Optional[InstanceData] = None,
) -> InstanceData:
    """
    Batch process multiple detection instances with SAM2 refinement.

    All prompted instances (new, then old) go through a single `model.predict`
    call, so they share one prompt configuration. Only positive keypoints are
    used, and every instance gets the same number of them (shorter lists
    repeat their last keypoint). The crowd fallback (no bbox, crowd keypoint
    budget) therefore applies to the whole batch and triggers when any
    prompted instance overlaps another above `crowd_by_max_iou`.

    Args:
        sam_args (Any): DotDict of SAM parameters (same as `process_image_with_SAM`).
        image (np.ndarray): Input BGR image.
        model (SAM2ImagePredictor): Prepared SAM2 predictor.
        new_dets (InstanceData): New detection instances (bboxes in xywh).
        old_dets (Optional[InstanceData]): Previous detections, which are also
            segmented. Their masks compete in `exclusive_masks` and their
            keypoints act as negatives in the consistency check.

    Returns:
        InstanceData: `new_dets` updated as in `process_image_with_SAM`. Results
        are aligned with `new_dets` by index; skipped instances get zero masks,
        scores and keypoints.
    """
    n_new_dets = len(new_dets.bboxes)
    model.set_image(image)

    image_kpts = []
    image_bboxes = []
    num_valid_kpts = []
    kept_new_idx = []  # indices into new_dets of the instances that are prompted

    for instance_idx in range(n_new_dets):
        prompt = _batch_instance_prompt(sam_args, new_dets.bboxes[instance_idx], new_dets.keypoints[instance_idx])
        if prompt is None:
            continue
        kept_new_idx.append(instance_idx)
        image_kpts.append(prompt[0])
        num_valid_kpts.append(prompt[1])
        image_bboxes.append(np.array(new_dets.bboxes[instance_idx]))

    if old_dets is not None:
        for instance_idx in range(len(old_dets.bboxes)):
            prompt = _batch_instance_prompt(sam_args, old_dets.bboxes[instance_idx], old_dets.keypoints[instance_idx])
            if prompt is None:
                continue
            image_kpts.append(prompt[0])
            num_valid_kpts.append(prompt[1])
            image_bboxes.append(np.array(old_dets.bboxes[instance_idx]))

    if len(image_kpts) == 0:
        # Nothing to prompt SAM with: return empty results of the expected length.
        new_dets.refined_masks = np.zeros((n_new_dets, image.shape[0], image.shape[1]), dtype=np.uint8)
        new_dets.sam_scores = np.zeros((n_new_dets,), dtype=np.float32)
        new_dets.sam_kpts = np.zeros((n_new_dets, 0, 3), dtype=np.float32)
        return new_dets

    n_kept_new = len(kept_new_idx)
    image_bboxes = np.array(image_bboxes)
    num_valid_kpts = np.array(num_valid_kpts)
    max_visible = num_valid_kpts.max()

    # Give every instance the same number of keypoints: select/order them, and
    # if some are missing, duplicate the last selected one.
    prepared_kpts = []
    prepared_kpts_backup = []
    for bbox_xywh, kpts, num_visible in zip(image_bboxes, image_kpts, num_valid_kpts):
        # _select_keypoints expects an xyxy bbox
        bbox_xyxy = (bbox_xywh[0], bbox_xywh[1], bbox_xywh[0] + bbox_xywh[2], bbox_xywh[1] + bbox_xywh[3])
        this_kpts, this_conf = _select_keypoints(sam_args, kpts, num_visible, bbox_xyxy)
        n_missing = max_visible - this_kpts.shape[0]
        if n_missing > 0:
            this_kpts = np.concatenate([this_kpts, np.tile(this_kpts[-1], (n_missing, 1))], axis=0)
            this_conf = np.concatenate([this_conf, np.tile(this_conf[-1], (n_missing,))], axis=0)
        prepared_kpts.append(this_kpts)
        prepared_kpts_backup.append(np.concatenate([this_kpts, this_conf[:, None]], axis=1))

    image_kpts = np.array(prepared_kpts)
    image_kpts_backup = np.array(prepared_kpts_backup)
    kpts_labels = np.ones(image_kpts.shape[:2])

    # One prompt configuration for the whole batch: fall back to the crowd
    # setting when any prompted instance overlaps another too much.
    max_ious = _get_max_ious(image_bboxes)
    num_pos_keypoints = sam_args.num_pos_keypoints
    use_bbox = sam_args.use_bbox
    if sam_args.crowd_by_max_iou is not None and np.any(max_ious > sam_args.crowd_by_max_iou):
        use_bbox = False
        num_pos_keypoints = sam_args.num_pos_keypoints_if_crowd

    # Threshold the number of positive keypoints
    if 0 < num_pos_keypoints < image_kpts.shape[1]:
        image_kpts = image_kpts[:, :num_pos_keypoints, :]
        kpts_labels = kpts_labels[:, :num_pos_keypoints]
        image_kpts_backup = image_kpts_backup[:, :num_pos_keypoints, :]
    elif num_pos_keypoints == 0:
        image_kpts = None
        kpts_labels = None
        image_kpts_backup = image_kpts_backup[:, :0, :]

    image_bboxes_xyxy = None
    if use_bbox:
        image_bboxes_xyxy = np.array(image_bboxes, dtype=float)
        image_bboxes_xyxy[:, 2:] += image_bboxes_xyxy[:, :2]

        # Expand the bbox to include the positive keypoints (+2 px margin)
        if sam_args.extend_bbox and image_kpts is not None:
            pose_bbox = np.stack(
                [
                    np.min(image_kpts[:, :, 0], axis=1) - 2,
                    np.min(image_kpts[:, :, 1], axis=1) - 2,
                    np.max(image_kpts[:, :, 0], axis=1) + 2,
                    np.max(image_kpts[:, :, 1], axis=1) + 2,
                ],
                axis=1,
            )
            expanded_bbox = np.array(image_bboxes_xyxy)
            expanded_bbox[:, :2] = np.minimum(expanded_bbox[:, :2], pose_bbox[:, :2])
            expanded_bbox[:, 2:] = np.maximum(expanded_bbox[:, 2:], pose_bbox[:, 2:])
            image_bboxes_xyxy = expanded_bbox

    # Old detections are processed too, so their masks compete in exclusive_masks
    masks, scores, logits = model.predict(
        point_coords=image_kpts,
        point_labels=kpts_labels,
        box=image_bboxes_xyxy,
        multimask_output=False,
    )

    # Normalize masks to (N, H, W): a single instance may come back as (C, H, W)
    if masks.ndim == 3:
        masks = masks[None, :, :, :]
    masks = masks[:, 0, :, :]
    n_masks = masks.shape[0]
    scores = scores.reshape(n_masks)

    if sam_args.exclusive_masks and n_masks > 1:
        # Make sure the masks are non-overlapping; overlaps go to the higher score
        masks = _apply_exclusive_masks(masks, scores)

    # The first `n_kept_new` predictions belong to the prompted new detections
    new_masks = masks[:n_kept_new]
    gt_masks = new_dets.pred_masks
    if sam_args.pose_mask_consistency and gt_masks is not None and n_kept_new > 0:
        # Compare the pose consistency of SAM's mask and the detector's mask and
        # keep the more consistent one; if roughly equal, keep the smaller one.
        gt_kept = [gt_masks[i] for i in kept_new_idx]
        dt_mask_pose_consistency = _compute_mask_pose_consistency(new_masks, image_kpts_backup)
        gt_mask_pose_consistency = _compute_mask_pose_consistency(gt_kept, image_kpts_backup)
        dt_masks_area = np.array([m.sum() for m in new_masks])
        gt_masks_area = np.array([m.sum() for m in gt_kept])

        tol = 0.1
        pmc_is_equal = np.isclose(dt_mask_pose_consistency, gt_mask_pose_consistency, atol=tol)
        dt_is_worse = (dt_mask_pose_consistency < (gt_mask_pose_consistency - tol)) | (
            pmc_is_equal & (dt_masks_area > gt_masks_area)
        )
        new_masks = np.array(
            [gt_mask if dt_worse else dt_mask for dt_mask, gt_mask, dt_worse in zip(new_masks, gt_kept, dt_is_worse)]
        )

    # Scatter results back so they stay aligned with new_dets by index
    kept = np.asarray(kept_new_idx, dtype=int)
    refined_masks = np.zeros((n_new_dets,) + new_masks.shape[1:], dtype=new_masks.dtype)
    refined_masks[kept] = new_masks
    sam_scores = np.zeros((n_new_dets,), dtype=scores.dtype)
    sam_scores[kept] = scores[:n_kept_new]
    sam_kpts = np.zeros((n_new_dets,) + image_kpts_backup.shape[1:], dtype=image_kpts_backup.dtype)
    sam_kpts[kept] = image_kpts_backup[:n_kept_new]

    new_dets.refined_masks = refined_masks
    new_dets.sam_scores = sam_scores
    new_dets.sam_kpts = sam_kpts

    return new_dets


def _apply_exclusive_masks(masks: np.ndarray, scores: np.ndarray) -> np.ndarray:
    """
    Ensure masks are non-overlapping by keeping at each pixel the mask with the highest score.

    A pixel is only ever assigned to a mask that covers it, even when the
    covering masks have zero or negative scores. Pixels covered by no mask
    stay empty. Ties go to the lowest mask index.

    Args:
        masks (np.ndarray): (N, H, W) masks (non-zero = covered).
        scores (np.ndarray): (N,) per-mask scores.

    Returns:
        np.ndarray: (N, H, W) boolean non-overlapping masks.
    """
    covered = masks > 0
    no_mask = ~covered.any(axis=0)
    # -inf for uncovered entries so argmax can never pick a mask absent at that pixel
    masked_scores = np.where(covered, scores[:, None, None], -np.inf)
    argmax_masks = np.argmax(masked_scores, axis=0)
    new_masks = argmax_masks[None, :, :] == (np.arange(masks.shape[0])[:, None, None])
    new_masks[:, no_mask] = False
    return new_masks