import os
import sys
import tempfile

import numpy as np
import pycocotools.mask as mask_util
import torch
from PIL import Image
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor

from cosmos_transfer1.utils import log

# The vendored sam2 package must be on the path before the sam2 import below.
sys.path.append("cosmos_transfer1/auxiliary")

from sam2.sam2_video_predictor import SAM2VideoPredictor

from cosmos_transfer1.auxiliary.sam2.sam2_utils import (
    capture_fps,
    convert_masks_to_frames,
    generate_tensor_from_images,
    video_to_frames,
    write_video,
)
from cosmos_transfer1.checkpoints import GROUNDING_DINO_MODEL_CHECKPOINT, SAM2_MODEL_CHECKPOINT


def rle_encode(mask: np.ndarray) -> dict:
    """
    Encode a boolean mask of shape (T, H, W) in the pycocotools RLE format,
    matching the format of eff_segmentation.RleMaskSAMv2 (from Yotta).

    The procedure is:
      1. Convert the mask to a uint8 numpy array in Fortran order
         (pycocotools.mask.encode requires uint8 input).
      2. Reshape the array to (-1, 1), i.e. flatten it in Fortran order.
      3. Call pycocotools.mask.encode on the reshaped array.
      4. Return a dictionary with the encoded data and the original mask shape.
    """
    mask = np.array(mask, order="F", dtype=np.uint8)
    encoded = mask_util.encode(np.array(mask.reshape(-1, 1), order="F"))
    return {"data": encoded, "mask_shape": mask.shape}
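
# A minimal round-trip sketch (illustrative, not part of the pipeline): the
# (-1, 1) Fortran-order reshape means the RLE counts describe the flattened
# (T*H*W, 1) array, so decoding restores the video mask via "mask_shape":
#
#   mask = np.zeros((2, 4, 4), dtype=bool)
#   mask[0, 1:3, 1:3] = True
#   rle = rle_encode(mask)
#   decoded = mask_util.decode(rle["data"]).reshape(rle["mask_shape"], order="F")
#   assert np.array_equal(decoded.astype(bool), mask)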


class VideoSegmentationModel:
    def __init__(self, **kwargs):
        """Initialize the model and load all required components."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # SAM2 video predictor: propagates object masks across video frames.
        self.sam2_predictor = SAM2VideoPredictor.from_pretrained(SAM2_MODEL_CHECKPOINT).to(self.device)

        # GroundingDINO: zero-shot, text-prompted object detection used to seed SAM2.
        self.grounding_model_name = kwargs.get("grounding_model", GROUNDING_DINO_MODEL_CHECKPOINT)
        self.processor = AutoProcessor.from_pretrained(self.grounding_model_name)
        self.grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(self.grounding_model_name).to(
            self.device
        )
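
    # Hedged usage sketch: "grounding_model" is the only recognized kwarg; the
    # checkpoint name below is illustrative, not the project default.
    #
    #   model = VideoSegmentationModel(grounding_model="IDEA-Research/grounding-dino-tiny")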

    def get_boxes_from_text(self, image_path, text_prompt):
        """Get bounding boxes (and labels) from a text prompt using GroundingDINO."""
        image = Image.open(image_path).convert("RGB")

        inputs = self.processor(images=image, text=text_prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.grounding_model(**inputs)

        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=0.15,
            text_threshold=0.25,
            target_sizes=[image.size[::-1]],
        )

        boxes = results[0]["boxes"].cpu().numpy()
        scores = results[0]["scores"].cpu().numpy()
        labels = results[0].get("labels", None)

        # If nothing clears the default thresholds, retry once with lower ones.
        if len(boxes) == 0:
            print(f"No boxes detected for prompt: '{text_prompt}'. Trying with lower thresholds...")
            results = self.processor.post_process_grounded_object_detection(
                outputs,
                inputs.input_ids,
                box_threshold=0.1,
                text_threshold=0.1,
                target_sizes=[image.size[::-1]],
            )
            boxes = results[0]["boxes"].cpu().numpy()
            scores = results[0]["scores"].cpu().numpy()
            labels = results[0].get("labels", None)

        if len(boxes) > 0:
            print(f"Found {len(boxes)} boxes with scores: {scores}")
            # Sort detections by confidence, highest first.
            sorted_indices = np.argsort(scores)[::-1]
            boxes = boxes[sorted_indices]
            scores = scores[sorted_indices]
            if labels is not None:
                labels = np.array(labels)[sorted_indices]
        else:
            print("Still no boxes detected. Consider adjusting the prompt or using box/points mode.")

        return {"boxes": boxes, "labels": labels, "scores": scores}
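
    # Illustrative call (hedged; the frame path is hypothetical). GroundingDINO
    # prompts are typically lowercase phrases terminated by a period:
    #
    #   dets = model.get_boxes_from_text("frames/0.jpg", "a red car.")
    #   dets["boxes"]   # (N, 4) xyxy pixel coordinates, sorted by score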

    def visualize_frame(self, frame_idx, obj_ids, masks, video_dir, frame_names, visualization_data, save_dir=None):
        """
        Process a single frame: load the image, apply the segmentation mask to black out the
        detected object(s), and save both the masked frame and the binary mask image.
        """
        frame_path = os.path.join(video_dir, frame_names[frame_idx])
        img = Image.open(frame_path).convert("RGB")
        image_np = np.array(img)

        # Combine the per-object masks into a single boolean mask.
        if isinstance(masks, torch.Tensor):
            mask_np = (masks[0] > 0.0).cpu().numpy().astype(bool)
            combined_mask = mask_np
        elif isinstance(masks, dict):
            first_mask = next(iter(masks.values()))
            combined_mask = np.zeros_like(first_mask, dtype=bool)
            for m in masks.values():
                combined_mask |= m
        else:
            combined_mask = None

        if combined_mask is not None:
            combined_mask = np.squeeze(combined_mask)

            # Resize the mask to the frame resolution if the two disagree.
            if combined_mask.shape != image_np.shape[:2]:
                mask_img = Image.fromarray((combined_mask.astype(np.uint8)) * 255)
                mask_img = mask_img.resize((image_np.shape[1], image_np.shape[0]), resample=Image.NEAREST)
                combined_mask = np.array(mask_img) > 127

            # Black out the segmented region.
            image_np[combined_mask] = 0

            mask_image = (combined_mask.astype(np.uint8)) * 255
            mask_pil = Image.fromarray(mask_image)

        if save_dir:
            seg_frame_path = os.path.join(save_dir, f"frame_{frame_idx}_segmented.png")
            seg_pil = Image.fromarray(image_np)
            seg_pil.save(seg_frame_path)
            if combined_mask is not None:
                mask_save_path = os.path.join(save_dir, f"frame_{frame_idx}_mask.png")
                mask_pil.save(mask_save_path)

    def sample(self, **kwargs):
        """
        Main sampling function for video segmentation.

        Keyword args:
            video_dir (str): directory of extracted frames with integer filenames (e.g. "0.jpg").
            mode (str): "points", "box", or "prompt".
            input_data (dict): mode-dependent inputs: {"points", "labels"}, {"box"}, or {"text"}.
            save_dir (str | None): where per-frame visualizations are written.
            visualize (bool): whether to save masked frames and mask images.
            legacy_mask (bool): in "prompt" mode, track only the top-scoring box.

        Returns a list of detections in which each detection contains a phrase and
        an RLE-encoded segmentation mask (matching the output of the Grounded SAM model).
        """
        video_dir = kwargs.get("video_dir", "")
        mode = kwargs.get("mode", "points")
        input_data = kwargs.get("input_data", None)
        save_dir = kwargs.get("save_dir", None)
        visualize = kwargs.get("visualize", False)

        # Collect JPEG frames and sort them by integer frame index.
        frame_names = [p for p in os.listdir(video_dir) if os.path.splitext(p)[-1].lower() in [".jpg", ".jpeg"]]
        frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            state = self.sam2_predictor.init_state(video_path=video_dir)

            ann_frame_idx = 0  # frame on which the initial annotation is placed
            ann_obj_id = 1  # object id for the first annotated object
            boxes = None
            points = None
            labels = None
            box = None

            visualization_data = {"mode": mode, "points": None, "labels": None, "box": None, "boxes": None}

            if input_data is not None:
                if mode == "points":
                    points = input_data.get("points")
                    labels = input_data.get("labels")
                    frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
                        inference_state=state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels
                    )
                    visualization_data["points"] = points
                    visualization_data["labels"] = labels
                elif mode == "box":
                    box = input_data.get("box")
                    frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
                        inference_state=state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, box=box
                    )
                    visualization_data["box"] = box
                elif mode == "prompt":
                    # Ground the text prompt on the first frame, then seed SAM2 with the boxes.
                    text = input_data.get("text")
                    first_frame_path = os.path.join(video_dir, frame_names[0])
                    gd_results = self.get_boxes_from_text(first_frame_path, text)
                    boxes = gd_results["boxes"]
                    labels_out = gd_results["labels"]
                    scores = gd_results["scores"]
                    log.info(f"scores: {scores}")
                    if len(boxes) > 0:
                        legacy_mask = kwargs.get("legacy_mask", False)
                        if legacy_mask:
                            # Legacy behavior: track only the highest-scoring box.
                            log.info(f"using legacy_mask: {legacy_mask}")
                            frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
                                inference_state=state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, box=boxes[0]
                            )
                            boxes = boxes[:1]
                            if labels_out is not None:
                                labels_out = labels_out[:1]
                        else:
                            # Track every detected box as its own object. (Iterating over
                            # boxes directly also avoids a crash when labels_out is None.)
                            log.info(f"using new_mask: {legacy_mask}")
                            for object_id, box in enumerate(boxes):
                                frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
                                    inference_state=state, frame_idx=ann_frame_idx, obj_id=object_id, box=box
                                )
                        visualization_data["boxes"] = boxes
                        self.grounding_labels = [str(lbl) for lbl in labels_out] if labels_out is not None else [text]
                    else:
                        print("No boxes detected. Exiting.")
                        return []

            if visualize:
                self.visualize_frame(
                    frame_idx=ann_frame_idx,
                    obj_ids=obj_ids,
                    masks=masks,
                    video_dir=video_dir,
                    frame_names=frame_names,
                    visualization_data=visualization_data,
                    save_dir=save_dir,
                )

            # Propagate the annotated masks through the whole video; threshold the
            # logits at 0.0 to obtain boolean masks keyed by object id.
            video_segments = {}
            for out_frame_idx, out_obj_ids, out_mask_logits in self.sam2_predictor.propagate_in_video(state):
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids)
                }

                if visualize:
                    propagate_visualization_data = {
                        "mode": mode,
                        "points": None,
                        "labels": None,
                        "box": None,
                        "boxes": None,
                    }
                    self.visualize_frame(
                        frame_idx=out_frame_idx,
                        obj_ids=out_obj_ids,
                        masks=video_segments[out_frame_idx],
                        video_dir=video_dir,
                        frame_names=frame_names,
                        visualization_data=propagate_visualization_data,
                        save_dir=save_dir,
                    )

        if len(video_segments) == 0:
            return []

        # Use the first frame to recover the original (H, W) resolution.
        first_frame_path = os.path.join(video_dir, frame_names[0])
        first_frame = np.array(Image.open(first_frame_path).convert("RGB"))
        original_shape = first_frame.shape[:2]

        # Group masks by object id, in frame order, resizing where needed.
        object_masks = {}
        sorted_frame_indices = sorted(video_segments.keys())
        for frame_idx in sorted_frame_indices:
            segments = video_segments[frame_idx]
            for obj_id, mask in segments.items():
                mask = np.squeeze(mask)
                if mask.ndim != 2:
                    print(f"Warning: Unexpected mask shape {mask.shape} for object {obj_id} in frame {frame_idx}.")
                    continue

                if mask.shape != original_shape:
                    mask_img = Image.fromarray(mask.astype(np.uint8) * 255)
                    mask_img = mask_img.resize((original_shape[1], original_shape[0]), resample=Image.NEAREST)
                    mask = np.array(mask_img) > 127

                if obj_id not in object_masks:
                    object_masks[obj_id] = []
                object_masks[obj_id].append(mask)

        # Encode each object's (T, H, W) mask stack as RLE, paired with its phrase.
        detections = []
        for obj_id, mask_list in object_masks.items():
            mask_stack = np.stack(mask_list, axis=0)
            rle = rle_encode(mask_stack)
            if mode == "prompt" and hasattr(self, "grounding_labels"):
                # Note: the first grounding label is applied to every detection.
                phrase = self.grounding_labels[0]
            else:
                phrase = input_data.get("text", "")
            detection = {"phrase": phrase, "segmentation_mask_rle": rle}
            detections.append(detection)

        return detections
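
        # Example return value (illustrative), one entry per tracked object:
        #   [{"phrase": "car",
        #     "segmentation_mask_rle": {"data": <COCO RLE>, "mask_shape": (T, H, W)}}]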

    @staticmethod
    def parse_points(points_str):
        """Parse a string of points into a numpy array.

        Supports a single point ('200,300') or multiple points separated by ';'
        (e.g., '200,300;100,150').
        """
        points = []
        for point in points_str.split(";"):
            coords = point.split(",")
            if len(coords) != 2:
                continue
            points.append([float(coords[0]), float(coords[1])])
        return np.array(points, dtype=np.float32)

    @staticmethod
    def parse_labels(labels_str):
        """Parse a comma-separated string of labels into a numpy array."""
        return np.array([int(x) for x in labels_str.split(",")], dtype=np.int32)

    @staticmethod
    def parse_box(box_str):
        """Parse a comma-separated string of 4 box coordinates into a numpy array."""
        return np.array([float(x) for x in box_str.split(",")], dtype=np.float32)
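
    # Parser examples (illustrative):
    #   parse_points("200,300;100,150") -> [[200., 300.], [100., 150.]]
    #   parse_labels("1,1")             -> [1, 1]  (SAM convention: 1 = foreground, 0 = background)
    #   parse_box("10,20,200,240")      -> [10., 20., 200., 240.]  (xyxy)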

    def __call__(
        self,
        input_video,
        output_video=None,
        output_tensor=None,
        prompt=None,
        box=None,
        points=None,
        labels=None,
        weight_scaler=None,
        binarize_video=False,
        legacy_mask=False,
    ):
        log.info(
            f"Processing video: {input_video} to generate segmentation video: {output_video} segmentation tensor: {output_tensor}"
        )
        assert os.path.exists(input_video)

        # Choose the annotation mode from whichever input was provided.
        if points is not None:
            mode = "points"
            input_data = {"points": self.parse_points(points), "labels": self.parse_labels(labels)}
        elif box is not None:
            mode = "box"
            input_data = {"box": self.parse_box(box)}
        elif prompt is not None:
            mode = "prompt"
            input_data = {"text": prompt}
        else:
            raise ValueError("One of `points`, `box`, or `prompt` must be provided.")

        with tempfile.TemporaryDirectory() as temp_input_dir:
            fps = capture_fps(input_video)
            video_to_frames(input_video, temp_input_dir)
            with tempfile.TemporaryDirectory() as temp_output_dir:
                masks = self.sample(
                    video_dir=temp_input_dir,
                    mode=mode,
                    input_data=input_data,
                    save_dir=str(temp_output_dir),
                    visualize=True,
                    legacy_mask=legacy_mask,
                )
                if output_video:
                    os.makedirs(os.path.dirname(output_video), exist_ok=True)
                    frames = convert_masks_to_frames(masks)
                    if binarize_video:
                        frames = np.any(frames > 0, axis=-1).astype(np.uint8) * 255
                    write_video(frames, output_video, fps)
                if output_tensor:
                    generate_tensor_from_images(
                        temp_output_dir, output_tensor, fps, "mask", weight_scaler=weight_scaler
                    )
|
|
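# Hedged end-to-end sketch (paths and prompt are hypothetical):
#
#   model = VideoSegmentationModel()
#   model(
#       input_video="input.mp4",
#       output_video="outputs/masks.mp4",
#       prompt="a red car.",
#   )
#
# This writes a video of the propagated masks; pass `output_tensor` as well to
# also save the mask tensor produced by generate_tensor_from_images.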