# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

import numpy as np
import pycocotools.mask as mask_util
import torch

from cosmos_transfer1.utils import log

sys.path.append("cosmos_transfer1/auxiliary")

import tempfile

from PIL import Image
from sam2.sam2_video_predictor import SAM2VideoPredictor
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor

from cosmos_transfer1.auxiliary.sam2.sam2_utils import (
    capture_fps,
    convert_masks_to_frames,
    generate_tensor_from_images,
    video_to_frames,
    write_video,
)
from cosmos_transfer1.checkpoints import GROUNDING_DINO_MODEL_CHECKPOINT, SAM2_MODEL_CHECKPOINT


def rle_encode(mask: np.ndarray) -> dict:
    """
    Encode a boolean mask (of shape (T, H, W)) using the pycocotools RLE format,
    matching the format of eff_segmentation.RleMaskSAMv2 (from Yotta).

    The procedure is:
        1. Convert the mask to a numpy array in Fortran order.
        2. Reshape the array to (-1, 1) (i.e. flatten it into a single column).
        3. Call pycocotools.mask.encode on the reshaped array.
        4. Return a dictionary with the encoded data and the original mask shape.
    """
    mask = np.array(mask, order="F")
    # Reshape the mask to (-1, 1) and encode it. pycocotools expects a
    # Fortran-contiguous uint8 array, so cast the boolean mask explicitly.
    encoded = mask_util.encode(np.array(mask.reshape(-1, 1), order="F", dtype=np.uint8))
    return {"data": encoded, "mask_shape": mask.shape}
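

# Example (illustrative): the encoding produced by rle_encode can be round-tripped
# through pycocotools' decoder. Shapes and values here are arbitrary placeholders.
#
#   demo = np.zeros((2, 4, 4), dtype=bool)
#   demo[0, 1:3, 1:3] = True
#   rle = rle_encode(demo)
#   restored = mask_util.decode(rle["data"]).reshape(rle["mask_shape"]).astype(bool)
#   assert np.array_equal(restored, demo)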


class VideoSegmentationModel:
    def __init__(self, **kwargs):
        """Initialize the model and load all required components."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize the SAM2 video predictor.
        self.sam2_predictor = SAM2VideoPredictor.from_pretrained(SAM2_MODEL_CHECKPOINT).to(self.device)

        # Initialize GroundingDINO for text-based detection.
        self.grounding_model_name = kwargs.get("grounding_model", GROUNDING_DINO_MODEL_CHECKPOINT)
        self.processor = AutoProcessor.from_pretrained(self.grounding_model_name)
        self.grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(self.grounding_model_name).to(
            self.device
        )

    def get_boxes_from_text(self, image_path, text_prompt):
        """Get bounding boxes (and labels) from a text prompt using GroundingDINO."""
        image = Image.open(image_path).convert("RGB")

        inputs = self.processor(images=image, text=text_prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.grounding_model(**inputs)

        # Try with the initial thresholds.
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=0.15,
            text_threshold=0.25,
            target_sizes=[image.size[::-1]],
        )

        boxes = results[0]["boxes"].cpu().numpy()
        scores = results[0]["scores"].cpu().numpy()
        labels = results[0].get("labels", None)

        if len(boxes) == 0:
            print(f"No boxes detected for prompt: '{text_prompt}'. Trying with lower thresholds...")
            results = self.processor.post_process_grounded_object_detection(
                outputs,
                inputs.input_ids,
                box_threshold=0.1,
                text_threshold=0.1,
                target_sizes=[image.size[::-1]],
            )
            boxes = results[0]["boxes"].cpu().numpy()
            scores = results[0]["scores"].cpu().numpy()
            labels = results[0].get("labels", None)

        if len(boxes) > 0:
            print(f"Found {len(boxes)} boxes with scores: {scores}")
            # Sort boxes by confidence score in descending order.
            sorted_indices = np.argsort(scores)[::-1]
            boxes = boxes[sorted_indices]
            scores = scores[sorted_indices]
            if labels is not None:
                labels = np.array(labels)[sorted_indices]
        else:
            print("Still no boxes detected. Consider adjusting the prompt or using box/points mode.")

        return {"boxes": boxes, "labels": labels, "scores": scores}

    def visualize_frame(self, frame_idx, obj_ids, masks, video_dir, frame_names, visualization_data, save_dir=None):
        """
        Process a single frame: load the image, apply the segmentation mask to black out
        the detected object(s), and save both the masked frame and the binary mask image.
        """
        # Load the frame.
        frame_path = os.path.join(video_dir, frame_names[frame_idx])
        img = Image.open(frame_path).convert("RGB")
        image_np = np.array(img)

        # Combine masks from the detection output.
        if isinstance(masks, torch.Tensor):
            mask_np = (masks[0] > 0.0).cpu().numpy().astype(bool)
            combined_mask = mask_np
        elif isinstance(masks, dict):
            first_mask = next(iter(masks.values()))
            combined_mask = np.zeros_like(first_mask, dtype=bool)
            for m in masks.values():
                combined_mask |= m
        else:
            combined_mask = None

        if combined_mask is not None:
            combined_mask = np.squeeze(combined_mask)
            # If the mask shape doesn't match the image, resize it.
            if combined_mask.shape != image_np.shape[:2]:
                mask_img = Image.fromarray((combined_mask.astype(np.uint8)) * 255)
                mask_img = mask_img.resize((image_np.shape[1], image_np.shape[0]), resample=Image.NEAREST)
                combined_mask = np.array(mask_img) > 127
            # Black out the detected region.
            image_np[combined_mask] = 0
            mask_image = (combined_mask.astype(np.uint8)) * 255
            mask_pil = Image.fromarray(mask_image)

        if save_dir:
            seg_frame_path = os.path.join(save_dir, f"frame_{frame_idx}_segmented.png")
            seg_pil = Image.fromarray(image_np)
            seg_pil.save(seg_frame_path)
            if combined_mask is not None:
                mask_save_path = os.path.join(save_dir, f"frame_{frame_idx}_mask.png")
                mask_pil.save(mask_save_path)

    def sample(self, **kwargs):
        """
        Main sampling function for video segmentation.

        Returns a list of detections in which each detection contains a phrase and an
        RLE-encoded segmentation mask (matching the output of the Grounded SAM model).
        """
        video_dir = kwargs.get("video_dir", "")
        mode = kwargs.get("mode", "points")
        input_data = kwargs.get("input_data", None)
        save_dir = kwargs.get("save_dir", None)
        visualize = kwargs.get("visualize", False)
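
        # Expected `input_data` layouts per mode (illustrative values):
        #   "points": {"points": np.array([[x, y], ...]), "labels": np.array([1, ...])}
        #   "box":    {"box": np.array([x1, y1, x2, y2])}
        #   "prompt": {"text": "a short object description"}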

        # Get frame names (expecting frames named as numbers with a .jpg/.jpeg extension).
        frame_names = [p for p in os.listdir(video_dir) if os.path.splitext(p)[-1].lower() in [".jpg", ".jpeg"]]
        frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            state = self.sam2_predictor.init_state(video_path=video_dir)

            ann_frame_idx = 0
            ann_obj_id = 1
            boxes = None
            points = None
            labels = None
            box = None

            visualization_data = {"mode": mode, "points": None, "labels": None, "box": None, "boxes": None}

            if input_data is not None:
                if mode == "points":
                    points = input_data.get("points")
                    labels = input_data.get("labels")
                    frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
                        inference_state=state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels
                    )
                    visualization_data["points"] = points
                    visualization_data["labels"] = labels
                elif mode == "box":
                    box = input_data.get("box")
                    frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
                        inference_state=state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, box=box
                    )
                    visualization_data["box"] = box
                elif mode == "prompt":
                    text = input_data.get("text")
                    first_frame_path = os.path.join(video_dir, frame_names[0])
                    gd_results = self.get_boxes_from_text(first_frame_path, text)
                    boxes = gd_results["boxes"]
                    labels_out = gd_results["labels"]
                    scores = gd_results["scores"]
                    log.info(f"scores: {scores}")
                    if len(boxes) > 0:
                        legacy_mask = kwargs.get("legacy_mask", False)
                        if legacy_mask:
                            # Use only the highest-confidence box for the legacy mask.
                            log.info(f"using legacy_mask: {legacy_mask}")
                            frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
                                inference_state=state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, box=boxes[0]
                            )
                            # Update boxes and labels after processing.
                            boxes = boxes[:1]
                            if labels_out is not None:
                                labels_out = labels_out[:1]
                        else:
                            log.info(f"using new_mask: {legacy_mask}")
                            for object_id, (box, label) in enumerate(zip(boxes, labels_out)):
                                frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
                                    inference_state=state, frame_idx=ann_frame_idx, obj_id=object_id, box=box
                                )
                        visualization_data["boxes"] = boxes
                        self.grounding_labels = [str(lbl) for lbl in labels_out] if labels_out is not None else [text]
                    else:
                        print("No boxes detected. Exiting.")
                        return []  # Return an empty list if there are no detections.

            if visualize:
                self.visualize_frame(
                    frame_idx=ann_frame_idx,
                    obj_ids=obj_ids,
                    masks=masks,
                    video_dir=video_dir,
                    frame_names=frame_names,
                    visualization_data=visualization_data,
                    save_dir=save_dir,
                )

            video_segments = {}  # keys: frame index, values: {obj_id: mask}
            for out_frame_idx, out_obj_ids, out_mask_logits in self.sam2_predictor.propagate_in_video(state):
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids)
                }

                # For propagated frames, visualization_data is not used.
                if visualize:
                    propagate_visualization_data = {
                        "mode": mode,
                        "points": None,
                        "labels": None,
                        "box": None,
                        "boxes": None,
                    }
                    self.visualize_frame(
                        frame_idx=out_frame_idx,
                        obj_ids=out_obj_ids,
                        masks=video_segments[out_frame_idx],
                        video_dir=video_dir,
                        frame_names=frame_names,
                        visualization_data=propagate_visualization_data,
                        save_dir=save_dir,
                    )

            # --- Post-process video_segments to produce a list of detections ---
            if len(video_segments) == 0:
                return []

            first_frame_path = os.path.join(video_dir, frame_names[0])
            first_frame = np.array(Image.open(first_frame_path).convert("RGB"))
            original_shape = first_frame.shape[:2]  # (height, width)

            object_masks = {}  # key: obj_id, value: list of 2D boolean masks
            sorted_frame_indices = sorted(video_segments.keys())
            for frame_idx in sorted_frame_indices:
                segments = video_segments[frame_idx]
                for obj_id, mask in segments.items():
                    mask = np.squeeze(mask)
                    if mask.ndim != 2:
                        print(f"Warning: Unexpected mask shape {mask.shape} for object {obj_id} in frame {frame_idx}.")
                        continue
                    if mask.shape != original_shape:
                        mask_img = Image.fromarray(mask.astype(np.uint8) * 255)
                        mask_img = mask_img.resize((original_shape[1], original_shape[0]), resample=Image.NEAREST)
                        mask = np.array(mask_img) > 127
                    if obj_id not in object_masks:
                        object_masks[obj_id] = []
                    object_masks[obj_id].append(mask)

            detections = []
            for obj_id, mask_list in object_masks.items():
                mask_stack = np.stack(mask_list, axis=0)  # shape: (T, H, W)
                # Encode with rle_encode, which follows the eff_segmentation.RleMaskSAMv2 format.
                rle = rle_encode(mask_stack)
                if mode == "prompt" and hasattr(self, "grounding_labels"):
                    phrase = self.grounding_labels[0]
                else:
                    phrase = input_data.get("text", "")
                detection = {"phrase": phrase, "segmentation_mask_rle": rle}
                detections.append(detection)

            return detections

    @staticmethod
    def parse_points(points_str):
        """Parse a string of points into a numpy array.

        Supports a single point ('200,300') or multiple points separated by ';'
        (e.g., '200,300;100,150').
        """
        points = []
        for point in points_str.split(";"):
            coords = point.split(",")
            if len(coords) != 2:
                continue
            points.append([float(coords[0]), float(coords[1])])
        return np.array(points, dtype=np.float32)

    @staticmethod
    def parse_labels(labels_str):
        """Parse a comma-separated string of labels into a numpy array."""
        return np.array([int(x) for x in labels_str.split(",")], dtype=np.int32)

    @staticmethod
    def parse_box(box_str):
        """Parse a comma-separated string of 4 box coordinates into a numpy array."""
        return np.array([float(x) for x in box_str.split(",")], dtype=np.float32)

    def __call__(
        self,
        input_video,
        output_video=None,
        output_tensor=None,
        prompt=None,
        box=None,
        points=None,
        labels=None,
        weight_scaler=None,
        binarize_video=False,
        legacy_mask=False,
    ):
        log.info(
            f"Processing video: {input_video} to generate segmentation video: {output_video} "
            f"and segmentation tensor: {output_tensor}"
        )
        assert os.path.exists(input_video)
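
        # Illustrative argument string formats (parsed by parse_points / parse_labels / parse_box):
        #   points="200,300;100,150"   labels="1,1" (1 = positive, 0 = negative point)
        #   box="50,60,400,500"        prompt="a short object description"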

        # Prepare input data based on the selected mode.
        if points is not None:
            mode = "points"
            input_data = {"points": self.parse_points(points), "labels": self.parse_labels(labels)}
        elif box is not None:
            mode = "box"
            input_data = {"box": self.parse_box(box)}
        elif prompt is not None:
            mode = "prompt"
            input_data = {"text": prompt}
        else:
            raise ValueError("One of `points`, `box`, or `prompt` must be provided.")

        with tempfile.TemporaryDirectory() as temp_input_dir:
            fps = capture_fps(input_video)
            video_to_frames(input_video, temp_input_dir)
            with tempfile.TemporaryDirectory() as temp_output_dir:
                masks = self.sample(
                    video_dir=temp_input_dir,
                    mode=mode,
                    input_data=input_data,
                    save_dir=str(temp_output_dir),
                    visualize=True,
                    legacy_mask=legacy_mask,
                )
                if output_video:
                    os.makedirs(os.path.dirname(output_video), exist_ok=True)
                    frames = convert_masks_to_frames(masks)
                    if binarize_video:
                        frames = np.any(frames > 0, axis=-1).astype(np.uint8) * 255
                    write_video(frames, output_video, fps)
                if output_tensor:
                    generate_tensor_from_images(
                        temp_output_dir, output_tensor, fps, "mask", weight_scaler=weight_scaler
                    )
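

if __name__ == "__main__":
    # Minimal usage sketch (illustrative): segment a video from a text prompt.
    # It assumes a CUDA-capable GPU and that the SAM2 / GroundingDINO checkpoints
    # referenced in cosmos_transfer1.checkpoints are available locally or downloadable.
    # The file paths and the prompt below are placeholders.
    model = VideoSegmentationModel()
    model(
        input_video="assets/example_input.mp4",
        output_video="outputs/example_segmentation.mp4",
        prompt="a person walking",
    )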