harry900000's picture
add cosmos-tranfer1/ into repo
226c7c9
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import tempfile
import numpy as np
from cosmos_transfer1.auxiliary.sam2.sam2_model import VideoSegmentationModel
from cosmos_transfer1.auxiliary.sam2.sam2_utils import (
capture_fps,
generate_tensor_from_images,
generate_video_from_images,
video_to_frames,
)
def parse_args():
    """Parse the command-line arguments for the SAM2 video segmentation script.

    Reads ``sys.argv`` and returns an ``argparse.Namespace`` with the attributes
    ``input_video``, ``output_video``, ``output_tensor``, ``mode``, ``prompt``,
    ``grounding_model_path``, ``points``, ``labels``, ``box`` and ``visualize``.
    Coordinate-like options (``points``, ``labels``, ``box``) are returned as the
    raw strings given on the command line; they are not parsed here.

    Exits with argparse's usual usage error (status 2) when ``--input_video`` is
    missing, ``--mode`` is not one of the allowed choices, or ``--mode prompt`` is
    selected without a non-empty ``--prompt``.
    """
    parser = argparse.ArgumentParser(description="Video Segmentation using SAM2")
    parser.add_argument("--input_video", type=str, required=True, help="Path to input video file")
    parser.add_argument(
        "--output_video", type=str, default="./outputs/output_video.mp4", help="Path to save the output video"
    )
    parser.add_argument(
        "--output_tensor", type=str, default="./outputs/output_tensor.pt", help="Path to save the output tensor"
    )
    parser.add_argument(
        "--mode", type=str, choices=["points", "box", "prompt"], default="points", help="Segmentation mode"
    )
    parser.add_argument("--prompt", type=str, help="Text prompt for prompt mode")
    parser.add_argument(
        "--grounding_model_path",
        type=str,
        default="IDEA-Research/grounding-dino-tiny",
        help="Local directory for GroundingDINO model files",
    )
    parser.add_argument(
        "--points",
        type=str,
        default="200,300",
        help="Comma-separated point coordinates for points mode (e.g., '200,300' or for multiple points use ';' as a separator, e.g., '200,300;100,150').",
    )
    parser.add_argument(
        "--labels",
        type=str,
        default="1",
        help="Comma-separated labels for points mode (e.g., '1' or '1,0' for multiple points).",
    )
    parser.add_argument(
        "--box",
        type=str,
        default="300,0,500,400",
        help="Comma-separated box coordinates for box mode (e.g., '300,0,500,400').",
    )
    # Boolean switch: False unless --visualize is passed.
    parser.add_argument("--visualize", action="store_true", help="If set, visualize segmentation frames (save images)")

    args = parser.parse_args()

    # --prompt has no default, so prompt mode without it would forward None as
    # the text query. Reject that here with a normal usage error instead.
    if args.mode == "prompt" and not args.prompt:
        parser.error("--prompt is required when --mode is 'prompt'")

    return args
def parse_points(points_str):
    """Parse a string of points into a numpy array.

    Supports a single point ('200,300') or multiple points separated by ';'
    (e.g., '200,300;100,150').

    Args:
        points_str: Points encoded as 'a,b' pairs joined by ';'. Empty segments
            (for example from a trailing ';') are ignored.

    Returns:
        A float32 numpy array of shape (N, 2) with one row per point, in input
        order. N is 0 when the string contains no points.

    Raises:
        ValueError: If a non-empty segment does not consist of exactly two
            comma-separated values, or if a value is not a valid float.
    """
    points = []
    for segment in points_str.split(";"):
        if not segment.strip():
            # Tolerate stray separators such as a trailing ';'.
            continue
        coords = segment.split(",")
        if len(coords) != 2:
            # Silently dropping a malformed point would leave the points out of
            # step with the labels supplied alongside them, so fail loudly.
            raise ValueError(f"Invalid point {segment!r}: expected two comma-separated values")
        points.append([float(coords[0]), float(coords[1])])
    # reshape keeps the (N, 2) layout even when no points were given.
    return np.array(points, dtype=np.float32).reshape(-1, 2)
def parse_labels(labels_str):
    """Convert a comma-separated label string (e.g. '1' or '1,0') to an int32 numpy array.

    Each comma-separated token is converted with ``int``; a token that is not a
    valid integer literal raises ``ValueError``.
    """
    tokens = labels_str.split(",")
    label_values = list(map(int, tokens))
    return np.array(label_values, dtype=np.int32)
def parse_box(box_str):
    """Parse a comma-separated string of 4 box coordinates into a numpy array.

    Args:
        box_str: Four comma-separated numbers, e.g. '300,0,500,400'.

    Returns:
        A float32 numpy array of shape (4,) holding the values in the order given.

    Raises:
        ValueError: If a value is not a valid float, or if the string does not
            contain exactly four values.
    """
    coords = [float(value) for value in box_str.split(",")]
    if len(coords) != 4:
        # Enforce the documented 4-value contract here, where the offending
        # input can still be reported.
        raise ValueError(f"Expected 4 comma-separated box coordinates, got {len(coords)}: {box_str!r}")
    return np.array(coords, dtype=np.float32)
def main():
    """Run the segmentation pipeline for the options given on the command line.

    Steps, in order:
      1. Parse the CLI options and construct ``VideoSegmentationModel`` with every
         parsed option passed as a keyword argument.
      2. Build the mode-specific ``input_data`` dict: ``points``/``labels`` arrays
         for points mode, a ``box`` array for box mode, or the raw ``text`` prompt
         for prompt mode.
      3. Inside a temporary directory, call ``capture_fps`` and ``video_to_frames``
         on the input video, then call ``model.sample`` with a second temporary
         directory as ``save_dir``.
      4. Call ``generate_video_from_images`` and ``generate_tensor_from_images`` on
         that second directory to write ``--output_video`` and ``--output_tensor``.

    Both temporary directories are removed when their ``with`` blocks exit.
    """
    cli = parse_args()

    # The model constructor receives the full set of CLI options as keyword arguments.
    segmenter = VideoSegmentationModel(**vars(cli))

    # Assemble the input payload for the chosen mode (argparse restricts --mode
    # to exactly these three values).
    selected_mode = cli.mode
    if selected_mode == "points":
        point_coords = parse_points(cli.points)
        point_labels = parse_labels(cli.labels)
        payload = {"points": point_coords, "labels": point_labels}
    elif selected_mode == "box":
        payload = {"box": parse_box(cli.box)}
    elif selected_mode == "prompt":
        payload = {"text": cli.prompt}

    source_video = cli.input_video
    with tempfile.TemporaryDirectory() as frames_dir:
        fps = capture_fps(source_video)
        video_to_frames(source_video, frames_dir)

        with tempfile.TemporaryDirectory() as results_dir:
            # NOTE(review): visualize is fixed to True here and the --visualize
            # flag is not consulted. The two generate_* calls below read from the
            # same directory passed as save_dir — confirm whether they depend on
            # the frames written when visualize is on before wiring the flag in.
            segmenter.sample(
                video_dir=frames_dir,
                mode=selected_mode,
                input_data=payload,
                save_dir=str(results_dir),
                visualize=True,
            )
            generate_video_from_images(results_dir, cli.output_video, fps)
            generate_tensor_from_images(results_dir, cli.output_tensor, fps, "mask")
# Script entry point: runs only when the file is executed directly, not on import.
if __name__ == "__main__":
    print("Starting video segmentation...")
    main()