|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
|
from cosmos_transfer1.auxiliary.upsampler.model.upsampler import PixtralPromptUpsampler |
|
from cosmos_transfer1.utils.misc import extract_video_frames |
|
|
|
|
|
def parse_args():
    """Build the prompt-upsampler command line and parse ``sys.argv``.

    Options:
        --prompt: optional text prompt to upsample (``None`` when omitted).
        --input_video: required path to the input video file.
        --checkpoint_dir: base checkpoint directory, ``"checkpoints"`` by default.
        --offload_prompt_upsampler: boolean flag, ``False`` unless given.

    Returns:
        The ``argparse.Namespace`` produced by parsing the process arguments.
    """
    # (flag, keyword arguments for add_argument) in the order they appear in --help.
    option_specs = (
        ("--prompt", {"type": str, "required": False, "help": "Prompt to upsample"}),
        ("--input_video", {"type": str, "required": True, "help": "Path to input video file"}),
        (
            "--checkpoint_dir",
            {"type": str, "default": "checkpoints", "help": "Base directory containing model checkpoints"},
        ),
        (
            "--offload_prompt_upsampler",
            {"action": "store_true", "help": "Offload prompt upsampler model after inference"},
        ),
    )

    cli = argparse.ArgumentParser(description="Prompt upsampler pipeline")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    namespace = cli.parse_args()
    return namespace
|
|
|
|
|
def main():
    """Upsample a prompt for one input video and print the result.

    Reads the command-line options, builds a ``PixtralPromptUpsampler`` from the
    checkpoint directory, extracts frames from ``--input_video`` and passes them,
    together with the (possibly omitted) ``--prompt``, to the upsampler.
    """
    cli = parse_args()

    # The model is constructed before frame extraction, matching the original order.
    upsampler = PixtralPromptUpsampler(
        cli.checkpoint_dir,
        offload_prompt_upsampler=cli.offload_prompt_upsampler,
    )

    video_frames = extract_video_frames(cli.input_video)

    # NOTE(review): this calls a private method of the upsampler; cli.prompt is None
    # when --prompt is not given — confirm the method accepts that.
    result = upsampler._prompt_upsample_with_offload(cli.prompt, video_frames)
    print("Upsampled prompt:", result)
|
|
|
|
|
if __name__ == "__main__":
    import os

    # When launched through torchrun/torchelastic only rank 0 runs the pipeline.
    # Default to rank 0 so the script also works when started directly with
    # `python`, where RANK is not set (previously this raised KeyError).
    rank = int(os.environ.get("RANK", "0"))

    # Launcher-provided distributed variables. They are removed before main() runs,
    # presumably so the upsampler does not pick up the launcher's distributed
    # configuration — NOTE(review): confirm against the model code.
    dist_keys = [
        "RANK",
        "LOCAL_RANK",
        "WORLD_SIZE",
        "LOCAL_WORLD_SIZE",
        "GROUP_RANK",
        "ROLE_RANK",
        "ROLE_NAME",
        "OMP_NUM_THREADS",
        "MASTER_ADDR",
        "MASTER_PORT",
        "TORCHELASTIC_USE_AGENT_STORE",
        "TORCHELASTIC_MAX_RESTARTS",
        "TORCHELASTIC_RUN_ID",
        "TORCH_NCCL_ASYNC_ERROR_HANDLING",
        "TORCHELASTIC_ERROR_FILE",
    ]

    for dist_key in dist_keys:
        # pop() instead of `del`: some of these (e.g. OMP_NUM_THREADS,
        # TORCHELASTIC_ERROR_FILE) may be absent depending on launcher settings or
        # torch version, and none are set at all outside torchrun — a missing key
        # must not abort the script.
        os.environ.pop(dist_key, None)

    if rank == 0:
        main()
|
|