Spaces:
Running
on
Zero
Running
on
Zero
""" | |
SAM2 utilities for BMP demo: | |
- Build and prepare SAM model | |
- Convert poses to segmentation | |
- Compute mask-pose consistency | |
""" | |
from typing import Any, List, Optional, Tuple | |
import numpy as np | |
import torch | |
from mmengine.structures import InstanceData | |
from pycocotools import mask as Mask | |
from sam2.build_sam import build_sam2 | |
from sam2.sam2_image_predictor import SAM2ImagePredictor | |
# Confidence cutoff above which a keypoint is treated as reliable. It is used when
# scoring mask-pose consistency and by the "closest" keypoint-selection method;
# both compare with a strict `>`.
STRICT_KPT_THRESHOLD: float = 0.5
def _validate_sam_args(sam_args): | |
""" | |
Validate that all required sam_args attributes are present. | |
""" | |
required = [ | |
"crop", | |
"use_bbox", | |
"confidence_thr", | |
"ignore_small_bboxes", | |
"num_pos_keypoints", | |
"num_pos_keypoints_if_crowd", | |
"crowd_by_max_iou", | |
"batch", | |
"exclusive_masks", | |
"extend_bbox", | |
"pose_mask_consistency", | |
"visibility_thr", | |
] | |
for param in required: | |
if not hasattr(sam_args, param): | |
raise AttributeError(f"Missing required arg {param} in sam_args") | |
def _get_max_ious(bboxes: List[np.ndarray]) -> np.ndarray: | |
""" | |
Compute maximum IoU for each bbox against others. | |
""" | |
is_crowd = [0] * len(bboxes) | |
ious = Mask.iou(bboxes, bboxes, is_crowd) | |
mat = np.array(ious) | |
np.fill_diagonal(mat, 0) | |
return mat.max(axis=1) | |
def _compute_one_mask_pose_consistency(
    mask: np.ndarray, pos_keypoints: Optional[np.ndarray] = None, neg_keypoints: Optional[np.ndarray] = None
) -> float:
    """
    Compute a consistency score between a mask and given keypoints.

    The score is the unweighted average of two terms:
    the mean mask value under the confident positive keypoints, and
    the fraction of confident negative keypoints that fall outside the mask.
    Only keypoints whose confidence exceeds ``STRICT_KPT_THRESHOLD`` take part.
    A term with no confident keypoints contributes 0.

    Args:
        mask (np.ndarray): Binary mask of shape (H, W). ``None`` yields a score of 0.
        pos_keypoints (Optional[np.ndarray]): Positive keypoints array (N, 3) as (x, y, conf).
        neg_keypoints (Optional[np.ndarray]): Negative keypoints array (M, 3) as (x, y, conf).

    Returns:
        float: 0.5 * positive term + 0.5 * negative term.
    """
    if mask is None:
        return 0.0

    height, width = mask.shape[0], mask.shape[1]

    def _mean_inside(points: np.ndarray) -> float:
        if points.size == 0:
            return 0.0
        pts_int = np.floor(points[:, :2]).astype(int)
        # Positive keypoints slightly outside the image are snapped to the border pixel.
        pts_int[:, 0] = np.clip(pts_int[:, 0], 0, width - 1)
        pts_int[:, 1] = np.clip(pts_int[:, 1], 0, height - 1)
        vals = mask[pts_int[:, 1], pts_int[:, 0]]
        return vals.mean() if vals.size > 0 else 0.0

    pos_mean = 0.0
    if pos_keypoints is not None:
        valid = pos_keypoints[:, 2] > STRICT_KPT_THRESHOLD
        pos_mean = _mean_inside(pos_keypoints[valid])

    neg_mean = 0.0
    if neg_keypoints is not None:
        valid = neg_keypoints[:, 2] > STRICT_KPT_THRESHOLD
        pts = np.floor(neg_keypoints[valid][:, :2]).astype(int)
        if pts.shape[0] > 0:
            # A negative keypoint outside the image cannot be inside the mask.
            # Checking bounds avoids both IndexError and silent wrap-around of
            # negative indices.
            in_bounds = (pts[:, 0] >= 0) & (pts[:, 0] < width) & (pts[:, 1] >= 0) & (pts[:, 1] < height)
            inside = np.zeros(pts.shape[0], dtype=bool)
            inside[in_bounds] = mask[pts[in_bounds, 1], pts[in_bounds, 0]].astype(bool)
            neg_mean = (~inside).mean()

    return 0.5 * pos_mean + 0.5 * neg_mean
def _select_keypoints( | |
args: Any, | |
kpts: np.ndarray, | |
num_visible: int, | |
bbox: Optional[Tuple[float, float, float, float]] = None, | |
method: Optional[str] = "distance+confidence", | |
) -> Tuple[np.ndarray, np.ndarray]: | |
""" | |
Select and order keypoints for SAM prompting based on specified method. | |
Args: | |
args: Configuration object with selection_method and visibility_thr attributes. | |
kpts (np.ndarray): Keypoints array of shape (K, 3). | |
num_visible (int): Number of keypoints above visibility threshold. | |
bbox (Optional[Tuple]): Optional bbox for distance methods. | |
method (Optional[str]): Override selection method. | |
Returns: | |
Tuple[np.ndarray, np.ndarray]: Selected keypoint coordinates (N,2) and confidences (N,). | |
Raises: | |
ValueError: If an unknown method is specified. | |
""" | |
if num_visible == 0: | |
return kpts[:, :2], kpts[:, 2] | |
methods = ["confidence", "distance", "distance+confidence", "closest"] | |
sel_method = method or args.selection_method | |
if sel_method not in methods: | |
raise ValueError("Unknown method for keypoint selection: {}".format(sel_method)) | |
# Select at maximum keypoint from the face | |
facial_kpts = kpts[:3, :] | |
facial_conf = kpts[:3, 2] | |
facial_point = facial_kpts[np.argmax(facial_conf)] | |
if facial_point[-1] >= args.visibility_thr: | |
kpts = np.concatenate([facial_point[None, :], kpts[3:]], axis=0) | |
conf = kpts[:, 2] | |
vis_mask = conf >= args.visibility_thr | |
coords = kpts[vis_mask, :2] | |
confs = conf[vis_mask] | |
if sel_method == "confidence": | |
order = np.argsort(confs)[::-1] | |
coords = coords[order] | |
confs = confs[order] | |
elif sel_method == "distance": | |
if bbox is None: | |
bbox_center = np.array([coords[:, 0].mean(), coords[:, 1].mean()]) | |
else: | |
bbox_center = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]) | |
dists = np.linalg.norm(coords[:, :2] - bbox_center, axis=1) | |
dist_matrix = np.linalg.norm(coords[:, None, :2] - coords[None, :, :2], axis=2) | |
np.fill_diagonal(dist_matrix, np.inf) | |
min_inter_dist = np.min(dist_matrix, axis=1) | |
order = np.argsort(dists + 3 * min_inter_dist)[::-1] | |
coords = coords[order, :2] | |
confs = confs[order] | |
elif sel_method == "distance+confidence": | |
order = np.argsort(confs)[::-1] | |
confidences = kpts[order, 2] | |
coords = coords[order, :2] | |
confs = confs[order] | |
dist_matrix = np.linalg.norm(coords[:, None, :2] - coords[None, :, :2], axis=2) | |
selected_idx = [0] | |
confidences[0] = -1 | |
for _ in range(coords.shape[0] - 1): | |
min_dist = np.min(dist_matrix[:, selected_idx], axis=1) | |
min_dist[confidences < np.percentile(confidences, 80)] = -1 | |
next_idx = np.argmax(min_dist) | |
selected_idx.append(next_idx) | |
confidences[next_idx] = -1 | |
coords = coords[selected_idx] | |
confs = confs[selected_idx] | |
elif sel_method == "closest": | |
coords = coords[confs > STRICT_KPT_THRESHOLD, :] | |
confs = confs[confs > STRICT_KPT_THRESHOLD] | |
if bbox is None: | |
bbox_center = np.array([coords[:, 0].mean(), coords[:, 1].mean()]) | |
else: | |
bbox_center = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]) | |
dists = np.linalg.norm(coords[:, :2] - bbox_center, axis=1) | |
order = np.argsort(dists) | |
coords = coords[order, :2] | |
confs = confs[order] | |
return coords, confs | |
def prepare_model(model_cfg: Any, model_checkpoint: str) -> SAM2ImagePredictor:
    """
    Construct a SAM2 image predictor on the best available torch device.

    The device is chosen in this order of preference: CUDA, then MPS, then CPU.

    Args:
        model_cfg: Configuration passed to ``build_sam2``.
        model_checkpoint (str): Path to the model checkpoint passed to ``build_sam2``.

    Returns:
        SAM2ImagePredictor: Predictor wrapping the built model, created with
        ``max_hole_area=10.0`` and ``max_sprinkle_area=50.0``.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"

    sam2 = build_sam2(
        model_cfg,
        model_checkpoint,
        device=torch.device(backend),
        apply_postprocessing=True,
    )
    return SAM2ImagePredictor(sam2, max_hole_area=10.0, max_sprinkle_area=50.0)
def _compute_mask_pose_consistency(masks: List[np.ndarray], keypoints_list: List[np.ndarray]) -> np.ndarray: | |
""" | |
Compute mask-pose consistency score for each mask-keypoints pair. | |
Args: | |
masks (List[np.ndarray]): Binary masks list. | |
keypoints_list (List[np.ndarray]): List of keypoint arrays per instance. | |
Returns: | |
np.ndarray: Consistency scores array of shape (N,). | |
""" | |
scores: List[float] = [] | |
for mask, kpts in zip(masks, keypoints_list): | |
other_kpts = np.concatenate([keypoints_list[:idx], keypoints_list[idx + 1 :]], axis=0).reshape(-1, 3) | |
score = _compute_one_mask_pose_consistency(mask, kpts, other_kpts) | |
scores.append(score) | |
return np.array(scores) | |
def _pose2seg(
    args: Any,
    model: SAM2ImagePredictor,
    bbox_xyxy: Optional[List[float]] = None,
    pos_kpts: Optional[np.ndarray] = None,
    neg_kpts: Optional[np.ndarray] = None,
    image: Optional[np.ndarray] = None,
    gt_mask: Optional[Any] = None,
    num_pos_keypoints: Optional[int] = None,
    gt_mask_is_binary: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
    """
    Run SAM segmentation conditioned on pose keypoints and an optional reference mask.

    Steps: select positive and negative point prompts, optionally grow the bbox to
    contain the positive prompts, optionally crop the image to 1.5x the bbox,
    predict a single mask, and optionally swap in ``gt_mask`` when it agrees
    better with the pose.

    Args:
        args: Configuration object with prompting settings. Reads ``num_pos_keypoints``,
            ``visibility_thr``, ``use_bbox``, ``extend_bbox``, ``crop``,
            ``pose_mask_consistency`` and, when ``neg_kpts`` is given, ``num_neg_keypoints``.
        model (SAM2ImagePredictor): Prepared SAM2 model. Unless an image is passed
            here for cropping, the caller must already have called ``set_image``.
        bbox_xyxy (Optional[List[float]]): Bounding box coordinates in xyxy format.
            Not modified by this function.
        pos_kpts (Optional[np.ndarray]): Positive keypoints, reshapeable to (-1, 3).
        neg_kpts (Optional[np.ndarray]): Negative keypoints, reshapeable to (-1, 3).
        image (Optional[np.ndarray]): Input image array; only used when
            ``args.crop and args.use_bbox``.
        gt_mask (Optional[Any]): Reference mask (binary array, or RLE when
            ``gt_mask_is_binary`` is False).
        num_pos_keypoints (Optional[int]): Number of positive keypoints to use;
            defaults to ``args.num_pos_keypoints``.
        gt_mask_is_binary (bool): Whether ``gt_mask`` is already a binary array.

    Returns:
        Tuple of (mask, pos_kpts_backup, neg_kpts_backup, score), where the backups
        are (N, 3) arrays of (x, y, conf) in original image coordinates.
    """
    num_pos_keypoints = args.num_pos_keypoints if num_pos_keypoints is None else num_pos_keypoints

    # Filter out un-annotated and invisible keypoints
    if pos_kpts is not None:
        pos_kpts = pos_kpts.reshape(-1, 3)
        valid_kpts = pos_kpts[:, 2] > args.visibility_thr
        pos_kpts, conf = _select_keypoints(args, pos_kpts, num_visible=valid_kpts.sum(), bbox=bbox_xyxy)
        pos_kpts_backup = np.concatenate([pos_kpts, conf[:, None]], axis=1)
        if pos_kpts.shape[0] > num_pos_keypoints:
            pos_kpts = pos_kpts[:num_pos_keypoints, :]
            pos_kpts_backup = pos_kpts_backup[:num_pos_keypoints, :]
    else:
        pos_kpts = np.empty((0, 2), dtype=np.float32)
        pos_kpts_backup = np.empty((0, 3), dtype=np.float32)

    if neg_kpts is not None:
        neg_kpts = neg_kpts.reshape(-1, 3)
        valid_kpts = neg_kpts[:, 2] > args.visibility_thr
        neg_kpts, conf = _select_keypoints(
            args, neg_kpts, num_visible=valid_kpts.sum(), bbox=bbox_xyxy, method="closest"
        )
        selected_neg_kpts = neg_kpts
        # The backup keeps every selected negative; only the prompt list is truncated.
        neg_kpts_backup = np.concatenate([neg_kpts, conf[:, None]], axis=1)
        if neg_kpts.shape[0] > args.num_neg_keypoints:
            selected_neg_kpts = neg_kpts[: args.num_neg_keypoints, :]
    else:
        selected_neg_kpts = np.empty((0, 2), dtype=np.float32)
        neg_kpts_backup = np.empty((0, 3), dtype=np.float32)

    # Concatenate positive and negative keypoints (label 1 = positive, 0 = negative)
    kpts = np.concatenate([pos_kpts, selected_neg_kpts], axis=0)
    kpts_labels = np.concatenate([np.ones(pos_kpts.shape[0]), np.zeros(selected_neg_kpts.shape[0])], axis=0)

    bbox = bbox_xyxy if args.use_bbox else None

    if args.extend_bbox and bbox is not None and pos_kpts.shape[0] > 0:
        # Expand the bbox such that it contains all positive keypoints (2 px margin)
        pose_bbox = np.array(
            [pos_kpts[:, 0].min() - 2, pos_kpts[:, 1].min() - 2, pos_kpts[:, 0].max() + 2, pos_kpts[:, 1].max() + 2]
        )
        expanded_bbox = np.array(bbox, dtype=float)
        expanded_bbox[:2] = np.minimum(expanded_bbox[:2], pose_bbox[:2])
        expanded_bbox[2:] = np.maximum(expanded_bbox[2:], pose_bbox[2:])
        bbox = expanded_bbox

    # The caller hands over the image only when it expects this function to set it.
    wants_crop = bool(args.crop and args.use_bbox and image is not None)
    # Without a bbox (e.g. crowd instances) there is nothing to crop to.
    do_crop = wants_crop and bbox is not None

    if do_crop:
        # Crop the image to the 1.5 * bbox size
        original_image_size = image.shape[:2]
        box = np.array(bbox, dtype=float)
        bbox_center = (box[:2] + box[2:]) / 2
        bbox_size = 1.5 * (box[2:] - box[:2])
        crop_bbox = np.concatenate([bbox_center - bbox_size / 2, bbox_center + bbox_size / 2])
        crop_bbox = np.round(crop_bbox).astype(int)
        crop_bbox = np.clip(crop_bbox, 0, [image.shape[1], image.shape[0], image.shape[1], image.shape[0]])
        image = image[crop_bbox[1] : crop_bbox[3], crop_bbox[0] : crop_bbox[2], :]

        # Move the prompts into crop coordinates. A new array is used for the bbox so
        # the caller's list is never modified.
        kpts = kpts - crop_bbox[:2]
        bbox = box - np.tile(crop_bbox[:2], 2)

    if wants_crop:
        # Either the crop, or the full image when cropping was not possible.
        model.set_image(image)

    masks, scores, logits = model.predict(
        point_coords=kpts,
        point_labels=kpts_labels,
        box=bbox,
        multimask_output=False,
    )
    mask = masks[0]
    scores = scores[0]

    if do_crop:
        # Pad the mask back to the original image size
        mask_padded = np.zeros(original_image_size, dtype=np.uint8)
        mask_padded[crop_bbox[1] : crop_bbox[3], crop_bbox[0] : crop_bbox[2]] = mask
        mask = mask_padded

    if args.pose_mask_consistency:
        if gt_mask_is_binary:
            gt_mask_binary = gt_mask
        else:
            gt_mask_binary = Mask.decode(gt_mask).astype(bool) if gt_mask is not None else None

        # Without a reference mask there is nothing to compare against.
        if gt_mask_binary is not None:
            gt_mask_pose_consistency = _compute_one_mask_pose_consistency(
                gt_mask_binary, pos_kpts_backup, neg_kpts_backup
            )
            dt_mask_pose_consistency = _compute_one_mask_pose_consistency(mask, pos_kpts_backup, neg_kpts_backup)

            tol = 0.1
            dt_is_same = np.abs(dt_mask_pose_consistency - gt_mask_pose_consistency) < tol
            if dt_is_same:
                # Consistency is about equal: prefer the smaller mask.
                mask = gt_mask_binary if gt_mask_binary.sum() < mask.sum() else mask
            else:
                mask = gt_mask_binary if gt_mask_pose_consistency > dt_mask_pose_consistency else mask

    return mask, pos_kpts_backup, neg_kpts_backup, scores
def process_image_with_SAM(
    sam_args: Any,
    image: np.ndarray,
    model: SAM2ImagePredictor,
    new_dets: InstanceData,
    old_dets: Optional[InstanceData] = None,
) -> InstanceData:
    """
    Validate ``sam_args`` and dispatch to the batch or per-instance SAM refinement.

    Args:
        sam_args (Any): SAM parameters; ``sam_args.batch`` selects the batch path.
        image (np.ndarray): Input image.
        model (SAM2ImagePredictor): Initialized SAM2 predictor.
        new_dets (InstanceData): Detections to refine.
        old_dets (Optional[InstanceData]): Previous detections, forwarded unchanged.

    Returns:
        InstanceData: Whatever the selected processing function returns.

    Raises:
        AttributeError: If ``sam_args`` lacks a required attribute.
    """
    _validate_sam_args(sam_args)
    handler = _process_image_batch if sam_args.batch else _process_image_single
    return handler(sam_args, image, model, new_dets, old_dets)
def _process_image_single(
    sam_args: Any,
    image: np.ndarray,
    model: SAM2ImagePredictor,
    new_dets: InstanceData,
    old_dets: Optional[InstanceData] = None,
) -> InstanceData:
    """
    Refine instance segmentation masks one by one using SAM2 with pose-conditioned prompts.

    Each new detection is prompted with its own keypoints as positives and the
    keypoints of all other detections (new and old) as negatives. Detections
    skipped because of a small bbox keep an all-zero mask, score and keypoints.

    Args:
        sam_args (Any): DotDict containing required SAM parameters:
            crop (bool), use_bbox (bool), confidence_thr (float),
            ignore_small_bboxes (bool), num_pos_keypoints (int),
            num_pos_keypoints_if_crowd (int), crowd_by_max_iou (Optional[float]),
            batch (bool), exclusive_masks (bool), extend_bbox (bool), pose_mask_consistency (bool).
        image (np.ndarray): BGR image array of shape (H, W, 3).
        model (SAM2ImagePredictor): Initialized SAM2 predictor.
        new_dets (InstanceData): New detections with attributes:
            bboxes (xywh), pred_masks, keypoints, bbox_scores.
        old_dets (Optional[InstanceData]): Previous detections for negative prompts;
            its ``refined_masks`` and ``sam_scores`` are read when ``exclusive_masks`` is set.

    Returns:
        InstanceData: `new_dets` updated in-place with
            `.refined_masks` (N, H, W), `.sam_scores` (N,), and
            `.sam_kpts` (N, num_pos_keypoints, 3).

    Raises:
        AttributeError: If ``sam_args`` lacks a required attribute.
    """
    _validate_sam_args(sam_args)

    # When cropping, _pose2seg sets the (cropped) image itself for every instance.
    crop_to_bbox = sam_args.crop and sam_args.use_bbox
    if not crop_to_bbox:
        model.set_image(image)

    # Ignore all keypoints with confidence below the threshold
    new_keypoints = new_dets.keypoints.copy()
    for kpts in new_keypoints:
        kpts[kpts[:, 2] < sam_args.confidence_thr, :] = 0

    n_new_dets = len(new_dets.bboxes)
    n_old_dets = len(old_dets.bboxes) if old_dets is not None else 0
    old_keypoints = None
    if n_old_dets > 0:
        old_keypoints = old_dets.keypoints.copy()
        for kpts in old_keypoints:
            kpts[kpts[:, 2] < sam_args.confidence_thr, :] = 0

    all_bboxes = new_dets.bboxes.copy()
    if n_old_dets > 0:
        all_bboxes = np.concatenate([all_bboxes, old_dets.bboxes], axis=0)
    # New detections come first, so max_ious[i] belongs to new detection i.
    max_ious = _get_max_ious(all_bboxes) if len(all_bboxes) > 0 else np.zeros(0)

    num_kpt_slots = sam_args.num_pos_keypoints
    new_dets.refined_masks = np.zeros((n_new_dets, image.shape[0], image.shape[1]), dtype=np.uint8)
    new_dets.sam_scores = np.zeros_like(new_dets.bbox_scores)
    new_dets.sam_kpts = np.zeros((n_new_dets, num_kpt_slots, 3), dtype=np.float32)

    for instance_idx in range(n_new_dets):
        bbox_xywh = new_dets.bboxes[instance_idx]
        bbox_area = bbox_xywh[2] * bbox_xywh[3]

        if sam_args.ignore_small_bboxes and bbox_area < 100 * 100:
            continue

        dt_mask = new_dets.pred_masks[instance_idx] if new_dets.pred_masks is not None else None
        bbox_xyxy = [bbox_xywh[0], bbox_xywh[1], bbox_xywh[0] + bbox_xywh[2], bbox_xywh[1] + bbox_xywh[3]]

        this_kpts = new_keypoints[instance_idx].reshape(1, -1, 3)

        # Negative prompts: keypoints of every other instance, old ones first.
        other_kpts = None
        if old_keypoints is not None:
            other_kpts = old_keypoints.copy().reshape(n_old_dets, -1, 3)
        if len(new_keypoints) > 1:
            other_new_kpts = np.concatenate([new_keypoints[:instance_idx], new_keypoints[instance_idx + 1 :]], axis=0)
            other_kpts = (
                np.concatenate([other_kpts, other_new_kpts], axis=0) if other_kpts is not None else other_new_kpts
            )

        num_pos_keypoints = sam_args.num_pos_keypoints
        if sam_args.crowd_by_max_iou is not None and max_ious[instance_idx] > sam_args.crowd_by_max_iou:
            # Heavily overlapped instance: drop the box prompt and rely on keypoints.
            bbox_xyxy = None
            num_pos_keypoints = sam_args.num_pos_keypoints_if_crowd

        dt_mask, pos_kpts, neg_kpts, scores = _pose2seg(
            sam_args,
            model,
            bbox_xyxy,
            pos_kpts=this_kpts,
            neg_kpts=other_kpts,
            image=image if crop_to_bbox else None,
            gt_mask=dt_mask,
            num_pos_keypoints=num_pos_keypoints,
            gt_mask_is_binary=True,
        )

        new_dets.refined_masks[instance_idx] = dt_mask
        new_dets.sam_scores[instance_idx] = scores

        # Fit the prompts into the fixed-size output: crowd instances may return more
        # keypoints than there are slots (truncate), others may return fewer (zero-pad).
        pos_kpts = pos_kpts[:num_kpt_slots]
        n_missing = num_kpt_slots - len(pos_kpts)
        if n_missing > 0:
            pos_kpts = np.concatenate([pos_kpts, np.zeros((n_missing, 3), dtype=np.float32)], axis=0)
        new_dets.sam_kpts[instance_idx] = pos_kpts

    n_masks = len(new_dets.refined_masks) + (len(old_dets.refined_masks) if old_dets is not None else 0)

    if sam_args.exclusive_masks and n_masks > 1:
        all_masks = (
            np.concatenate([new_dets.refined_masks, old_dets.refined_masks], axis=0)
            if old_dets is not None
            else new_dets.refined_masks
        )
        all_scores = (
            np.concatenate([new_dets.sam_scores, old_dets.sam_scores], axis=0)
            if old_dets is not None
            else new_dets.sam_scores
        )
        refined_masks = _apply_exclusive_masks(all_masks, all_scores)
        # Old detections only take part in the competition; keep the new ones.
        new_dets.refined_masks = refined_masks[: len(new_dets.refined_masks)]

    return new_dets
def _collect_batch_prompts(
    sam_args: Any, dets: InstanceData
) -> Tuple[List[int], List[np.ndarray], List[np.ndarray], List[int]]:
    """
    Gather bbox/keypoint prompts for the detections that can be prompted.

    A detection is kept when its bbox is large enough (if ``ignore_small_bboxes``)
    and it has at least one keypoint above both ``visibility_thr`` and
    ``confidence_thr``. The detections themselves are never modified.

    Returns:
        Tuple of (kept indices into ``dets``, xywh bboxes, (K, 3) keypoints with
        invisible coordinates zeroed and the third column replaced by the 0/1
        visibility flag, number of visible keypoints).
    """
    kept_idx: List[int] = []
    bboxes: List[np.ndarray] = []
    kpts_list: List[np.ndarray] = []
    n_visible: List[int] = []

    for instance_idx in range(len(dets.bboxes)):
        bbox_xywh = np.array(dets.bboxes[instance_idx])
        if sam_args.ignore_small_bboxes and bbox_xywh[2] * bbox_xywh[3] < 100 * 100:
            continue

        # np.array copies, so the zeroing below never leaks into `dets`.
        this_kpts = np.array(dets.keypoints[instance_idx]).reshape(-1, 3)
        kpt_conf = this_kpts[:, 2]
        visible = (kpt_conf > sam_args.visibility_thr) & (kpt_conf > sam_args.confidence_thr)
        num_visible = int(visible.sum())
        if num_visible <= 0:
            continue

        this_kpts[~visible, :2] = 0
        this_kpts[:, 2] = visible

        kept_idx.append(instance_idx)
        bboxes.append(bbox_xywh)
        kpts_list.append(this_kpts)
        n_visible.append(num_visible)

    return kept_idx, bboxes, kpts_list, n_visible


def _process_image_batch(
    sam_args: Any,
    image: np.ndarray,
    model: SAM2ImagePredictor,
    new_dets: InstanceData,
    old_dets: Optional[InstanceData] = None,
) -> InstanceData:
    """
    Batch process multiple detection instances with SAM2 refinement.

    All promptable new detections (and, when given, old detections) are sent to
    SAM2 in a single ``predict`` call. Old detections only take part so that
    exclusive-mask resolution and negative keypoints see them; their results are
    discarded. New detections that cannot be prompted (small bbox, no visible
    keypoints) keep an all-zero mask, score and keypoints, so the outputs always
    stay index-aligned with ``new_dets``.

    Because one ``predict`` call shares its prompt layout across instances, the
    crowd rule is applied batch-wide: if any bbox overlaps another by more than
    ``crowd_by_max_iou``, box prompts are dropped and
    ``num_pos_keypoints_if_crowd`` is used for every instance.

    Args:
        sam_args (Any): DotDict of SAM parameters (same as `process_image_with_SAM`).
        image (np.ndarray): Input BGR image.
        model (SAM2ImagePredictor): Prepared SAM2 predictor.
        new_dets (InstanceData): New detection instances with ``bboxes`` (xywh),
            ``keypoints`` and ``pred_masks`` (may be None).
        old_dets (Optional[InstanceData]): Previous detections for negative prompts.
            Not modified.

    Returns:
        InstanceData: `new_dets` updated in-place with `.refined_masks` (N, H, W),
        `.sam_scores` (N,) and `.sam_kpts` (N, P, 3), where N = len(new_dets.bboxes)
        and P is the number of point prompts used per instance.
    """
    n_new_dets = len(new_dets.bboxes)

    model.set_image(image)

    kept_new_idx, image_bboxes, image_kpts, num_valid_kpts = _collect_batch_prompts(sam_args, new_dets)
    n_kept_new = len(kept_new_idx)

    if n_kept_new == 0:
        # Nothing to refine: return empty results aligned with new_dets.
        new_dets.refined_masks = np.zeros((n_new_dets, image.shape[0], image.shape[1]), dtype=np.uint8)
        new_dets.sam_scores = np.zeros(n_new_dets, dtype=np.float32)
        new_dets.sam_kpts = np.zeros((n_new_dets, max(int(sam_args.num_pos_keypoints), 0), 3), dtype=np.float32)
        return new_dets

    if old_dets is not None:
        # Process even old detections to get their 'negative' keypoints
        _, old_bboxes, old_kpts, old_num_valid = _collect_batch_prompts(sam_args, old_dets)
        image_bboxes = image_bboxes + old_bboxes
        image_kpts = image_kpts + old_kpts
        num_valid_kpts = num_valid_kpts + old_num_valid

    image_bboxes = np.array(image_bboxes)
    num_valid_kpts = np.array(num_valid_kpts)

    # Prepare keypoints such that all instances have the same number of keypoints.
    # If some are missing, duplicate the last one.
    selections = []
    for bbox, kpts, num_visible in zip(image_bboxes, image_kpts, num_valid_kpts):
        # _select_keypoints derives the bbox center assuming xyxy; stored boxes are xywh.
        bbox_xyxy = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]
        selections.append(_select_keypoints(sam_args, kpts, num_visible, bbox_xyxy))

    target_len = max(int(num_valid_kpts.max()), max(sel_kpts.shape[0] for sel_kpts, _ in selections))

    prepared_kpts = []
    prepared_kpts_backup = []
    for this_kpts, this_conf in selections:
        n_missing = target_len - this_kpts.shape[0]
        if n_missing > 0:
            this_kpts = np.concatenate([this_kpts, np.tile(this_kpts[-1], (n_missing, 1))], axis=0)
            this_conf = np.concatenate([this_conf, np.tile(this_conf[-1], (n_missing,))], axis=0)
        prepared_kpts.append(this_kpts)
        prepared_kpts_backup.append(np.concatenate([this_kpts, this_conf[:, None]], axis=1))

    image_kpts = np.array(prepared_kpts)
    image_kpts_backup = np.array(prepared_kpts_backup)
    kpts_labels = np.ones(image_kpts.shape[:2])

    # Compute IoUs between all bounding boxes
    max_ious = _get_max_ious(image_bboxes)
    num_pos_keypoints = sam_args.num_pos_keypoints
    use_bbox = sam_args.use_bbox
    if sam_args.crowd_by_max_iou is not None and max_ious.max() > sam_args.crowd_by_max_iou:
        use_bbox = False
        num_pos_keypoints = sam_args.num_pos_keypoints_if_crowd

    # Threshold the number of positive keypoints
    if num_pos_keypoints > 0 and num_pos_keypoints < image_kpts.shape[1]:
        image_kpts = image_kpts[:, :num_pos_keypoints, :]
        kpts_labels = kpts_labels[:, :num_pos_keypoints]
        image_kpts_backup = image_kpts_backup[:, :num_pos_keypoints, :]
    elif num_pos_keypoints == 0:
        image_kpts = None
        kpts_labels = None
        image_kpts_backup = np.empty((len(image_bboxes), 0, 3), dtype=np.float32)

    image_bboxes_xyxy = None
    if use_bbox:
        image_bboxes_xyxy = np.array(image_bboxes, dtype=float)
        image_bboxes_xyxy[:, 2:] += image_bboxes_xyxy[:, :2]

        # Expand the bbox to include the positive keypoints (2 px margin)
        if sam_args.extend_bbox and image_kpts is not None:
            pose_bbox = np.stack(
                [
                    np.min(image_kpts[:, :, 0], axis=1) - 2,
                    np.min(image_kpts[:, :, 1], axis=1) - 2,
                    np.max(image_kpts[:, :, 0], axis=1) + 2,
                    np.max(image_kpts[:, :, 1], axis=1) + 2,
                ],
                axis=1,
            )
            expanded_bbox = np.array(image_bboxes_xyxy)
            expanded_bbox[:, :2] = np.minimum(expanded_bbox[:, :2], pose_bbox[:, :2])
            expanded_bbox[:, 2:] = np.maximum(expanded_bbox[:, 2:], pose_bbox[:, 2:])
            image_bboxes_xyxy = expanded_bbox

    masks, scores, logits = model.predict(
        point_coords=image_kpts,
        point_labels=kpts_labels,
        box=image_bboxes_xyxy,
        multimask_output=False,
    )

    # Reshape the masks to (N, C, H, W). If the model outputs (C, H, W), add a number of masks dimension
    if len(masks.shape) == 3:
        masks = masks[None, :, :, :]
    masks = masks[:, 0, :, :]
    N = masks.shape[0]
    scores = scores.reshape(N)

    if sam_args.exclusive_masks and N > 1:
        # Make sure the masks are non-overlapping
        # If two masks overlap, set the pixel to the one with the highest score
        masks = _apply_exclusive_masks(masks, scores)

    pred_masks = new_dets.pred_masks
    if sam_args.pose_mask_consistency and pred_masks is not None:
        # Measure 'mask-pose consistency' for both the reference and the predicted masks
        # and keep the more consistent one. Only the kept new detections have a reference
        # mask; they are the first `n_kept_new` entries of the batch.
        gt_masks = np.asarray(pred_masks)[kept_new_idx]
        dt_masks = masks[:n_kept_new]

        # Keypoints of the remaining instances (incl. old detections) act as negatives.
        dt_mask_pose_consistency = _compute_mask_pose_consistency(dt_masks, image_kpts_backup)
        gt_mask_pose_consistency = _compute_mask_pose_consistency(gt_masks, image_kpts_backup)

        dt_masks_area = np.array([m.sum() for m in dt_masks])
        gt_masks_area = np.array([m.sum() for m in gt_masks])

        # If PM-c is approx the same, prefer the smaller mask
        tol = 0.1
        pmc_is_equal = np.isclose(dt_mask_pose_consistency, gt_mask_pose_consistency, atol=tol)
        dt_is_worse = (dt_mask_pose_consistency < (gt_mask_pose_consistency - tol)) | (
            pmc_is_equal & (dt_masks_area > gt_masks_area)
        )

        if dt_is_worse.any():
            masks = masks.copy()
            masks[:n_kept_new][dt_is_worse] = (gt_masks[dt_is_worse] != 0).astype(masks.dtype)

    # Scatter results back so that row i always belongs to new detection i.
    refined_masks = np.zeros((n_new_dets,) + masks.shape[1:], dtype=masks.dtype)
    refined_masks[kept_new_idx] = masks[:n_kept_new]
    sam_scores = np.zeros(n_new_dets, dtype=scores.dtype)
    sam_scores[kept_new_idx] = scores[:n_kept_new]
    sam_kpts = np.zeros((n_new_dets,) + image_kpts_backup.shape[1:], dtype=image_kpts_backup.dtype)
    sam_kpts[kept_new_idx] = image_kpts_backup[:n_kept_new]

    new_dets.refined_masks = refined_masks
    new_dets.sam_scores = sam_scores
    new_dets.sam_kpts = sam_kpts

    return new_dets
def _apply_exclusive_masks(masks: np.ndarray, scores: np.ndarray) -> np.ndarray: | |
""" | |
Ensure masks are non-overlapping by keeping at each pixel the mask with the highest score. | |
""" | |
no_mask = masks.sum(axis=0) == 0 | |
masked_scores = masks * scores[:, None, None] | |
argmax_masks = np.argmax(masked_scores, axis=0) | |
new_masks = argmax_masks[None, :, :] == (np.arange(masks.shape[0])[:, None, None]) | |
new_masks[:, no_mask] = 0 | |
return new_masks | |