import os

# Must be set before cv2 is imported, otherwise OpenCV ignores it.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

import cv2
import json
import torch
import numpy as np
from PIL import Image
from skimage import morphology
from typing import Optional, Tuple
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

from zim_anything import zim_model_registry, ZimPredictor


class DetPredictor(ZimPredictor):
    def predict(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict masks for the given input prompts, using the currently set image.

        Arguments:
          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (np.ndarray or None): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          box (np.ndarray or None): A length 4 array giving a box prompt to the
            model, in XYXY format.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded mask logits
            instead of a binary mask.

        Returns:
          (np.ndarray): The output masks in CxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (np.ndarray): An array of length C containing the model's
            predictions for the quality of each mask.
          (np.ndarray): An array of shape CxHxW, where C is the number
            of masks and H=W=256. These low resolution logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        # Transform input prompts
        coords_torch = None
        labels_torch = None
        box_torch = None

        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
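            # The pixel-space prompts below are remapped into the model's
            # resized input frame, then batched as 1xNx2 / 1xN tensors on
            # the predictor's device.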
            point_coords = self.transform.apply_coords(point_coords, self.original_size)
            coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
            labels_torch = torch.as_tensor(point_labels, dtype=torch.float, device=self.device)
            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
        if box is not None:
            box = self.transform.apply_boxes(box, self.original_size)
            box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)

        masks, iou_predictions, low_res_masks = self.predict_torch(
            coords_torch,
            labels_torch,
            box_torch,
            multimask_output,
            return_logits=return_logits,
        )

        if not return_logits:
            masks = masks > 0.5

        masks_np = masks.squeeze(0).float().detach().cpu().numpy()
        iou_predictions_np = iou_predictions[0].squeeze(0).float().detach().cpu().numpy()
        low_res_masks_np = low_res_masks[0].squeeze(0).float().detach().cpu().numpy()
        return masks_np, iou_predictions_np, low_res_masks_np


def build_gd_model(GROUNDING_MODEL, device="cuda"):
    """Build Grounding DINO model from HuggingFace

    Args:
        GROUNDING_MODEL: Model identifier
        device: Device to load model on (default: "cuda")

    Returns:
        processor: Model processor
        grounding_model: Loaded model
    """
    model_id = GROUNDING_MODEL
    processor = AutoProcessor.from_pretrained(model_id)
    grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(
        model_id).to(device)
    return processor, grounding_model


def build_zim_model(ZIM_MODEL_CONFIG, ZIM_CHECKPOINT, device="cuda"):
    """Build ZIM-Anything model from HuggingFace

    Args:
        ZIM_MODEL_CONFIG: Model configuration
        ZIM_CHECKPOINT: Model checkpoint path
        device: Device to load model on (default: "cuda")

    Returns:
        zim_predictor: Initialized ZIM predictor
    """
    zim_model = zim_model_registry[ZIM_MODEL_CONFIG](
        checkpoint=ZIM_CHECKPOINT).to(device)
    zim_predictor = DetPredictor(zim_model)
    return zim_predictor


def mask_nms(masks, scores, threshold=0.5):
    """Perform Non-Maximum Suppression based on mask overlap

    Suppression uses intersection-over-minimum-area (IoMin) rather than IoU,
    so a mask fully contained in a larger one is always suppressed.

    Args:
        masks: Input masks tensor (N,H,W)
        scores: Confidence scores for each mask
        threshold: IoMin threshold for suppression (default: 0.5)

    Returns:
        keep: Indices of kept masks
    """
    areas = torch.sum(masks, dim=(1, 2))  # [N,]
    _, order = scores.sort(0, descending=True)

    keep = []
    while order.numel() > 0:
        if order.numel() == 1:
            i = order.item()
            keep.append(i)
            break
        else:
            i = order[0].item()
            keep.append(i)

        inter = torch.sum(torch.logical_and(
            masks[order[1:]], masks[i]), dim=(1, 2))  # [N-1,]
        min_areas = torch.minimum(areas[i], areas[order[1:]])  # [N-1,]
        iomin = inter / min_areas
        idx = (iomin <= threshold).nonzero().squeeze()
        if idx.numel() == 0:
            break
        order = order[idx + 1]
    return torch.LongTensor(keep)


def filter_small_bboxes(results, max_num=100):
    """Keep only the max_num largest bounding boxes to avoid memory overflow

    Args:
        results: Detection results containing boxes
        max_num: Maximum number of boxes to keep (default: 100)

    Returns:
        keep: Indices of kept boxes
    """
    bboxes = results[0]["boxes"]
    x1 = bboxes[:, 0]
    y1 = bboxes[:, 1]
    x2 = bboxes[:, 2]
    y2 = bboxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)
    _, order = areas.sort(0, descending=True)
    keep = [order[i].item() for i in range(min(max_num, order.numel()))]
    return torch.LongTensor(keep)


def filter_by_general_score(results, score_threshold=0.35):
    """Filter results by confidence score

    Args:
        results: Detection results
        score_threshold: Minimum confidence score (default: 0.35)

    Returns:
        filtered_data: Filtered results
    """
    filtered_data = []
    for entry in results:
        scores = entry['scores']
        labels = entry['labels']
        mask = scores > score_threshold
        filtered_scores = scores[mask]
        filtered_boxes = entry['boxes'][mask]
        mask_list = mask.tolist()
        filtered_labels = [labels[i] for i in range(len(labels)) if mask_list[i]]
        filtered_entry = {
            'scores': filtered_scores,
            'labels': filtered_labels,
            'boxes': filtered_boxes
        }
        filtered_data.append(filtered_entry)
    return filtered_data
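
# A minimal sketch of how mask_nms behaves (editor's illustration; the
# tensors below are made up):
#
#   masks = torch.zeros(3, 8, 8, dtype=torch.bool)
#   masks[0, :4, :4] = True            # area 16
#   masks[1, :4, :4] = True            # exact duplicate of mask 0
#   masks[2, 4:, 4:] = True            # disjoint, area 16
#   scores = torch.tensor([0.9, 0.8, 0.7])
#   mask_nms(masks, scores, threshold=0.5)
#   # -> tensor([0, 2]): mask 1 is suppressed because its
#   #    intersection-over-minimum-area with mask 0 is 1.0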
def filter_by_location(results, edge_threshold=20):
    """Filter out boxes near the left edge

    Args:
        results: Detection results
        edge_threshold: Distance threshold from left edge (default: 20)

    Returns:
        keep: Indices of kept boxes
    """
    bboxes = results[0]["boxes"]
    keep = []
    for i in range(bboxes.shape[0]):
        x1 = bboxes[i][0]
        if x1 < edge_threshold:
            continue
        keep.append(i)
    return torch.LongTensor(keep)


def unpad_mask(results, masks, pad_len):
    """Remove padding from masks and adjust boxes

    Args:
        results: Detection results
        masks: Padded masks
        pad_len: Padding length to remove

    Returns:
        results: Adjusted results
        masks: Unpadded masks
    """
    # Shift boxes back into the unpadded coordinate frame.
    results[0]["boxes"][:, 0] = results[0]["boxes"][:, 0] - pad_len
    results[0]["boxes"][:, 2] = results[0]["boxes"][:, 2] - pad_len
    for i in range(results[0]["boxes"].shape[0]):
        # Boxes that started inside the wrapped pad belong at the right edge
        # of the panorama: move them there and rotate the mask columns to match.
        if results[0]["boxes"][i][0] < 0:
            results[0]["boxes"][i][0] += pad_len * 2
            new_mask = torch.cat(
                (masks[i][:, pad_len:pad_len*2], masks[i][:, :pad_len]), dim=1)
            masks[i] = torch.cat((masks[i][:, :pad_len], new_mask), dim=1)
        if results[0]["boxes"][i][2] < 0:
            results[0]["boxes"][i][2] += pad_len * 2
    # Drop the padded columns.
    return results, masks[:, :, pad_len:]


def remove_small_objects(masks, min_size=1000):
    """Remove small objects from masks

    Args:
        masks: Input masks
        min_size: Minimum object size (default: 1000)

    Returns:
        masks: Cleaned masks
    """
    for i in range(masks.shape[0]):
        masks[i] = morphology.remove_small_objects(
            masks[i], min_size=min_size, connectivity=2)
    return masks


def remove_sky_floaters(mask, min_size=1000):
    """Remove small disconnected regions from sky mask

    Args:
        mask: Input sky mask
        min_size: Minimum region size (default: 1000)

    Returns:
        mask: Cleaned sky mask
    """
    mask = morphology.remove_small_objects(
        mask, min_size=min_size, connectivity=2)
    return mask


def remove_disconnected_masks(masks):
    """Remove masks with too many disconnected components

    Args:
        masks: Input masks

    Returns:
        keep: Indices of kept masks
    """
    keep = []
    for i in range(masks.shape[0]):
        binary = masks[i].astype(np.uint8) * 255
        # connectedComponents counts the background as one component, so a
        # single-object mask yields num == 2.
        num, _ = cv2.connectedComponents(
            binary, connectivity=8, ltype=cv2.CV_32S)
        if num > 2:
            continue
        keep.append(i)
    return torch.LongTensor(keep)


def get_contours_sky(mask):
    """Get contours of sky mask and fill them

    Args:
        mask: Input sky mask

    Returns:
        mask: Filled contour mask
    """
    binary = mask.astype(np.uint8) * 255
    contours, _ = cv2.findContours(
        binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        return mask
    mask = np.zeros_like(binary)
    cv2.drawContours(mask, contours, -1, 1, -1)
    return mask.astype(np.bool_)
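
# get_fg_pad below wrap-pads the panorama by half its width so that objects
# straddling the 0/360 degree seam are detected as single instances;
# unpad_mask later folds the padded columns back. Shape sketch (editor's
# illustration):
#
#   pad_len = img.shape[1] // 2
#   padded = cv2.copyMakeBorder(img, 0, 0, pad_len, 0, cv2.BORDER_WRAP)
#   # padded.shape[1] == img.shape[1] + pad_len; the new left columns
#   # replicate the right half of the original image.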
def get_fg_pad(
    OUTPUT_DIR,
    IMG_PATH,
    IMG_SR_PATH,
    zim_predictor,
    processor,
    grounding_model,
    text,
    layer,
    scale=2,
    is_outdoor=True
):
    """Process foreground layer with padding and segmentation

    Args:
        OUTPUT_DIR: Output directory
        IMG_PATH: Input image path
        IMG_SR_PATH: Super-resolved image path
        zim_predictor: ZIM model predictor
        processor: Grounding model processor
        grounding_model: Grounding model
        text: Text prompt for detection
        layer: Layer identifier (0=fg1, else=fg2)
        scale: Scaling factor (default: 2)
        is_outdoor: Whether outdoor scene (default: True)
    """
    # Load and pad input image
    image = cv2.imread(IMG_PATH, cv2.IMREAD_UNCHANGED)
    pad_len = image.shape[1] // 2
    image = cv2.copyMakeBorder(image, 0, 0, pad_len, 0, cv2.BORDER_WRAP)
    # OpenCV loads channels in BGR order; reverse them before handing to PIL.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(image).convert("RGB")

    # Process super-resolution image
    image_sr = Image.open(IMG_SR_PATH)
    H, W = image_sr.height, image_sr.width
    image_sr = np.array(image_sr.convert("RGB"))
    pad_len_sr = W // 2
    image_sr_pad = cv2.copyMakeBorder(
        image_sr, 0, 0, pad_len_sr, 0, cv2.BORDER_WRAP)
    zim_predictor.set_image(image_sr_pad)

    # Run object detection
    inputs = processor(images=image, text=text, return_tensors="pt").to(
        grounding_model.device)
    with torch.no_grad():
        outputs = grounding_model(**inputs)

    # Process detection results
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=0.3,
        text_threshold=0.3,
        target_sizes=[image.size[::-1]]
    )

    saved_json = {"bboxes": []}

    # Apply filters based on scene type
    if is_outdoor:
        results = filter_by_general_score(results, score_threshold=0.35)
        location_keep = filter_by_location(results)
        results[0]["boxes"] = results[0]["boxes"][location_keep]
        results[0]["scores"] = results[0]["scores"][location_keep]
        results[0]["labels"] = [results[0]["labels"][i] for i in location_keep]

    # Prepare box prompts for ZIM
    results[0]["boxes"] = results[0]["boxes"] * scale
    filter_keep = filter_small_bboxes(results)
    results[0]["boxes"] = results[0]["boxes"][filter_keep]
    results[0]["scores"] = results[0]["scores"][filter_keep]
    results[0]["labels"] = [results[0]["labels"][i] for i in filter_keep]
    input_boxes = results[0]["boxes"].cpu().numpy()
    if input_boxes.shape[0] == 0:
        return

    # Get masks from ZIM predictor
    masks, scores, _ = zim_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_boxes,
        multimask_output=False,
    )

    # Post-process masks
    if masks.ndim == 4:
        masks = masks.squeeze(1)
    min_floater = 500
    masks = masks.astype(np.bool_)
    masks = remove_small_objects(masks, min_size=min_floater*(scale**2))
    disconnect_keep = remove_disconnected_masks(masks)
    masks = torch.tensor(masks).bool()[disconnect_keep]
    results[0]["boxes"] = results[0]["boxes"][disconnect_keep]
    results[0]["scores"] = results[0]["scores"][disconnect_keep]
    results[0]["labels"] = [results[0]["labels"][i] for i in disconnect_keep]
    results, masks = unpad_mask(results, masks, pad_len=pad_len_sr)

    # Apply NMS, ranking masks by area
    areas = torch.sum(masks, dim=(1, 2))
    keep = mask_nms(masks, areas, threshold=0.5)
    masks = masks[keep]
    results[0]["boxes"] = results[0]["boxes"][keep]
    results[0]["scores"] = results[0]["scores"][keep]
    results[0]["labels"] = [results[0]["labels"][i] for i in keep]
    if masks.shape[0] == 0:
        return

    # Create final foreground mask
    fg_mask = np.zeros((H, W), dtype=np.uint8)
    masks = masks.float().detach().cpu().numpy().astype(np.bool_)

    cnt = 0
    min_sum = 3000
    name = "fg1" if layer == 0 else "fg2"

    # Process each valid mask
    for i in range(masks.shape[0]):
        mask = masks[i]
        if mask.sum() < min_sum*(scale**2):
            continue
        saved_json["bboxes"].append({
            "label": results[0]["labels"][i],
            "bbox": results[0]["boxes"][i].cpu().numpy().tolist(),
            "score": results[0]["scores"][i].item(),
            "area": int(mask.sum())
        })
        cnt += 1
        fg_mask[mask] = cnt
    if cnt == 0:
        return

    # Save outputs
    with open(os.path.join(OUTPUT_DIR, f"{name}.json"), "w") as f:
        json.dump(saved_json, f, indent=4)
    Image.fromarray(fg_mask).save(os.path.join(OUTPUT_DIR, f"{name}_mask.png"))
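
# Editor's sketch for consuming the saved layer (file names as written by
# get_fg_pad above):
#
#   fg = np.array(Image.open(os.path.join(OUTPUT_DIR, "fg1_mask.png")))
#   inst1 = fg == 1            # boolean mask of the first kept instance
#   with open(os.path.join(OUTPUT_DIR, "fg1.json")) as f:
#       meta = json.load(f)
#   meta["bboxes"][0]          # {"label": ..., "bbox": [...], "score": ..., "area": ...}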
def get_fg_pad_outdoor(
    OUTPUT_DIR,
    IMG_PATH,
    IMG_SR_PATH,
    zim_predictor,
    processor,
    grounding_model,
    text,
    layer,
    scale=2,
):
    """Write the foreground layer for an outdoor scene"""
    return get_fg_pad(
        OUTPUT_DIR,
        IMG_PATH,
        IMG_SR_PATH,
        zim_predictor,
        processor,
        grounding_model,
        text,
        layer,
        scale=scale,
        is_outdoor=True
    )


def get_fg_pad_indoor(
    OUTPUT_DIR,
    IMG_PATH,
    IMG_SR_PATH,
    zim_predictor,
    processor,
    grounding_model,
    text,
    layer,
    scale=2,
):
    """Write the foreground layer for an indoor scene"""
    return get_fg_pad(
        OUTPUT_DIR,
        IMG_PATH,
        IMG_SR_PATH,
        zim_predictor,
        processor,
        grounding_model,
        text,
        layer,
        scale=scale,
        is_outdoor=False
    )


def get_sky(
    OUTPUT_DIR,
    IMG_PATH,
    IMG_SR_PATH,
    zim_predictor,
    processor,
    grounding_model,
    text,
    scale=2
):
    """Extract and process sky layer from input image

    Args:
        OUTPUT_DIR: Output directory
        IMG_PATH: Input image path
        IMG_SR_PATH: Super-resolved image path
        zim_predictor: ZIM model predictor
        processor: Grounding model processor
        grounding_model: Grounding model
        text: Text prompt for detection
        scale: Scaling factor (default: 2)
    """
    # Load input images
    image = Image.open(IMG_PATH).convert("RGB")
    image_sr = Image.open(IMG_SR_PATH)
    H, W = image_sr.height, image_sr.width
    zim_predictor.set_image(np.array(image_sr.convert("RGB")))

    # Run object detection
    inputs = processor(images=image, text=text, return_tensors="pt").to(
        grounding_model.device)
    with torch.no_grad():
        outputs = grounding_model(**inputs)

    # Process detection results
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=0.3,
        text_threshold=0.3,
        target_sizes=[image.size[::-1]]
    )

    # Prepare box prompts for ZIM
    results[0]["boxes"] = results[0]["boxes"] * scale
    filter_keep = filter_small_bboxes(results)
    results[0]["boxes"] = results[0]["boxes"][filter_keep]
    results[0]["scores"] = results[0]["scores"][filter_keep]
    results[0]["labels"] = [results[0]["labels"][i] for i in filter_keep]
    input_boxes = results[0]["boxes"].cpu().numpy()
    if input_boxes.shape[0] == 0:
        # Nothing detected: save an empty mask so downstream steps still
        # find sky_mask.png.
        sky_mask = np.zeros((H, W), dtype=np.uint8)
        Image.fromarray(sky_mask).save(os.path.join(OUTPUT_DIR, "sky_mask.png"))
        return

    # Get masks from ZIM predictor
    masks, _, _ = zim_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_boxes,
        multimask_output=False,
    )

    # Post-process masks
    if masks.ndim == 4:
        masks = masks.squeeze(1)

    # Combine all detected masks
    sky_mask = np.zeros((H, W), dtype=np.bool_)
    for i in range(masks.shape[0]):
        mask = masks[i].astype(np.bool_)
        sky_mask[mask] = 1

    # Clean up sky mask
    min_floater = 1000
    sky_mask = get_contours_sky(sky_mask)
    sky_mask = np.logical_not(sky_mask)  # Invert to get sky area
    sky_mask = remove_sky_floaters(sky_mask, min_size=min_floater*(scale**2))
    sky_mask = get_contours_sky(sky_mask)

    # Save output mask
    Image.fromarray(sky_mask.astype(np.uint8) * 255).save(
        os.path.join(OUTPUT_DIR, "sky_mask.png"))
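
# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (editor's illustration, not part of the original
# pipeline). The Grounding DINO model id, ZIM config/checkpoint, prompts and
# file paths below are placeholders; adjust them to your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    processor, grounding_model = build_gd_model(
        "IDEA-Research/grounding-dino-tiny")           # assumed model id
    zim_predictor = build_zim_model(
        "vit_l", "checkpoints/zim_vit_l.pt")           # assumed config/checkpoint

    OUTPUT_DIR = "outputs"
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Foreground layer 0 on an outdoor panorama, then the sky layer.
    get_fg_pad_outdoor(
        OUTPUT_DIR, "pano.png", "pano_sr.png",
        zim_predictor, processor, grounding_model,
        text="person. car. tree.", layer=0,
    )
    get_sky(
        OUTPUT_DIR, "pano.png", "pano_sr.png",
        zim_predictor, processor, grounding_model,
        text="sky.",
    )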