#!/usr/bin/env python
import json
import os
import time

import click
import numpy as np
import torch

from genmo.lib.progress import progress_bar
from genmo.lib.utils import save_video
from genmo.mochi_preview.pipelines_ti2v_release import (
    DecoderModelFactory,
    EncoderModelFactory,
    DitModelFactory,
    MochiMultiGPUPipeline,
    MochiSingleGPUPipeline,
    T5ModelFactory,
    linear_quadratic_schedule,
)


from safetensors.torch import load_file

from genmo.mochi_preview.vae.models import Encoder, add_fourier_features
# NOTE: DecoderModelFactory here shadows the one imported above from pipelines_ti2v_release;
# the genmo.mochi_preview.pipelines version is the one load_model() actually binds to.
from genmo.mochi_preview.pipelines import DecoderModelFactory
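
# Text- and image-to-video generation demo built on the Mochi TI2V pipeline.
# The module-level globals below cache the configured paths and the constructed pipeline,
# so the pipeline is built once and reused across repeated generate_video() calls.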


pipeline = None
model_dir_path = None
num_gpus = torch.cuda.device_count()
cpu_offload = False
dit_path = None

def configure_model(model_dir_path_, dit_path_, cpu_offload_):
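    """Record the model directory, DiT checkpoint path, and CPU-offload flag in module globals.

    The pipeline itself is built lazily by load_model().
    """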
    global model_dir_path, dit_path, cpu_offload
    model_dir_path = model_dir_path_
    dit_path = dit_path_
    cpu_offload = cpu_offload_


def load_model():
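    """Build the Mochi pipeline on first use, choosing single- or multi-GPU mode from the visible GPU count."""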
    global num_gpus, pipeline, model_dir_path, dit_path
    if pipeline is None:
        MOCHI_DIR = model_dir_path
        print(f"Launching with {num_gpus} GPUs. If you want to force single GPU mode use CUDA_VISIBLE_DEVICES=0.")
        klass = MochiSingleGPUPipeline if num_gpus == 1 else MochiMultiGPUPipeline
        kwargs = dict(
            text_encoder_factory=T5ModelFactory(),
            dit_factory=DitModelFactory(
                model_path=dit_path,
                model_dtype="bf16"
            ),
            decoder_factory=DecoderModelFactory(
                model_path=f"{MOCHI_DIR}/decoder.safetensors",
            ),
            encoder_factory=EncoderModelFactory(
                model_path=f"{MOCHI_DIR}/encoder.safetensors",
            ),
        )
        if num_gpus > 1:
            assert not cpu_offload, "CPU offload not supported in multi-GPU mode"
            kwargs["world_size"] = num_gpus
        else:
            kwargs["cpu_offload"] = cpu_offload
            # kwargs["decode_type"] = "tiled_full"
            kwargs["decode_type"] = "tiled_spatial"
        pipeline = klass(**kwargs)


def generate_video(
    prompt,
    negative_prompt,
    width,
    height,
    num_frames,
    seed,
    cfg_scale,
    num_inference_steps,
    data_path,
    input_image=None,
    noise_multiplier=0,
):
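    """Run the Mochi pipeline with the given settings and save the result.

    Writes an MP4 to a timestamped results directory plus a JSON sidecar of the
    generation arguments. If `input_image` is given, it must be a dict with keys
    "tensor" (VAE-encoded conditioning latent), "cond_position" (frame index to
    condition on) and optionally "filename".
    """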
    load_model()
    global dit_path

    # sigma_schedule should be a list of floats of length (num_inference_steps + 1),
    # such that sigma_schedule[0] == 1.0 and sigma_schedule[-1] == 0.0 and monotonically decreasing.
    sigma_schedule = linear_quadratic_schedule(num_inference_steps, 0.025)

    # cfg_schedule should be a list of floats of length num_inference_steps.
    # For simplicity, we just use the same cfg scale at all timesteps,
    # but more optimal schedules may use varying cfg, e.g:
    # [5.0] * (num_inference_steps // 2) + [4.5] * (num_inference_steps // 2)
    cfg_schedule = [cfg_scale] * num_inference_steps

    args = {
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "sigma_schedule": sigma_schedule,
        "cfg_schedule": cfg_schedule,
        "num_inference_steps": num_inference_steps,
        # We *need* flash attention to batch cfg
        # and it's only worth doing in a high-memory regime (assume multiple GPUs)
        "batch_cfg": False,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "seed": seed,
        "data_path": data_path,
        "noise_multiplier": noise_multiplier,
    }
    
    # Image conditioning: pass the encoded conditioning frame and the frame index to condition on.
    # (An earlier variant instead passed a pre-computed latent as args["input_image"].)
    if input_image is not None:
        args["condition_image"] = input_image["tensor"]
        args["condition_frame_idx"] = input_image["cond_position"]

    # print(args)
    with progress_bar(type="tqdm"):
        final_frames = pipeline(**args)

        final_frames = final_frames[0]

        assert isinstance(final_frames, np.ndarray)
        assert final_frames.dtype == np.float32

        # Create a results directory named after the model, checkpoint, and generation settings
        model_name = os.path.basename(dit_path.split('/')[-2])
        checkpoint_name = dit_path.split('/')[-1].split('train_loss')[0]
        from datetime import datetime
        timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        
        if input_image is not None:
            cond_position = input_image["cond_position"]
        else:
            cond_position = ""
        results_base_dir = "/home/dyvm6xra/dyvm6xrauser02/raphael/mochi-1-preview/models/video_test_demos_results"
        results_dir = os.path.join(results_base_dir, f"{model_name}_{checkpoint_name}_github_user_demo_{cond_position}pos_{num_inference_steps}steps_crop_{noise_multiplier}sigma")
        os.makedirs(results_dir, exist_ok=True)
        
        # Extract filename from input_image if available
        filename_prefix = ""
        if isinstance(input_image, dict) and "filename" in input_image:
            filename_prefix = f"{os.path.basename(input_image['filename']).split('.')[0]}_"
            
        output_path = os.path.join(
            results_dir,
            f"{filename_prefix}{timestamp_str}.mp4"
        )
        save_video(final_frames, output_path)
        json_path = os.path.splitext(output_path)[0] + ".json"
        
        # Save args to JSON but remove input_image tensor and convert non-serializable objects
        json_args = args.copy()
        
        # Handle input_image for JSON serialization
        if "input_image" in json_args:
            json_args["input_image"] = None
        
        # Handle condition_image for JSON serialization
        if "condition_image" in json_args:
            json_args["condition_image"] = "Image tensor (removed for JSON)"
            
        if isinstance(input_image, dict):
            json_args["input_filename"] = input_image.get("filename", None)
            if "cond_position" in input_image:
                json_args["condition_frame_idx"] = input_image["cond_position"]
        
        # Convert sigma_schedule and cfg_schedule from tensors to lists if needed
        if isinstance(json_args["sigma_schedule"], torch.Tensor):
            json_args["sigma_schedule"] = json_args["sigma_schedule"].tolist()
        if isinstance(json_args["cfg_schedule"], torch.Tensor):
            json_args["cfg_schedule"] = json_args["cfg_schedule"].tolist()
            
        # Handle prompt if it's a tensor or other non-serializable object
        if not isinstance(json_args["prompt"], (str, type(None))):
            if hasattr(json_args["prompt"], "tolist"):
                json_args["prompt"] = "Tensor prompt (converted to string for JSON)"
            else:
                json_args["prompt"] = str(json_args["prompt"])
                
        # Handle negative_prompt if it's a tensor
        if not isinstance(json_args["negative_prompt"], (str, type(None))):
            if hasattr(json_args["negative_prompt"], "tolist"):
                json_args["negative_prompt"] = "Tensor negative prompt (converted to string for JSON)"
            else:
                json_args["negative_prompt"] = str(json_args["negative_prompt"])
                
        with open(json_path, "w") as f:
            json.dump(json_args, f, indent=4)

        return output_path

from textwrap import dedent


@click.command()
@click.option("--prompt", default="A man is playing the basketball", help="Prompt for video generation.")
@click.option("--negative_prompt", default="", help="Negative prompt for video generation.")
@click.option("--width", default=848, type=int, help="Width of the video.")
@click.option("--height", default=480, type=int, help="Height of the video.")
@click.option("--num_frames", default=163, type=int, help="Number of frames.")
@click.option("--seed", default=1710977262, type=int, help="Random seed.")
@click.option("--cfg_scale", default=4.5, type=float, help="CFG Scale.")
@click.option("--num_steps", default=64, type=int, help="Number of inference steps.")
@click.option("--model_dir", required=True, help="Path to the model directory.")
@click.option("--dit_path", required=True, help="Path to the dit model directory.")
@click.option("--cpu_offload", is_flag=True, help="Whether to offload model to CPU")
@click.option("--data_path", required=True, default="/home/dyvm6xra/dyvm6xrauser02/data/vidgen1m", help="Path to the data directory.")
@click.option("--image_dir", default=None, help="Path to image or directory of images for conditioning.")
@click.option("--prompt_dir", default=None, help="Path to directory containing prompt text files.")
@click.option("--cond_position", default=0, type=int, help="Frame position to place the conditioning image, from 0 to 27.")
@click.option("--noise_multiplier", default=0, type=float, help="Noise multiplier for noise on the conditioning image.")

def generate_cli(
    prompt, negative_prompt, width, height, num_frames, seed, cfg_scale, num_steps, model_dir,
    dit_path, cpu_offload, data_path, image_dir, prompt_dir, cond_position, noise_multiplier
):
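    """Entry point. Dispatches on the arguments:

      * no --image_dir: text-to-video, either from --prompt or from every .txt file in --prompt_dir;
      * --image_dir is a single image file: image-to-video conditioned on that image;
      * --image_dir is a directory: image-to-video for every image, pairing each with a same-named
        .txt prompt file in --prompt_dir when available.

    Illustrative invocation (paths and script name are placeholders, not part of this repository):
        python this_script.py --model_dir /path/to/weights --dit_path /path/to/dit.safetensors \\
            --data_path /path/to/data --image_dir /path/to/images --prompt_dir /path/to/prompts
    """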
    configure_model(model_dir, dit_path, cpu_offload)

    # Case 1: Text to video generation
    if image_dir is None:
        
        # Check if we have a prompt directory to process multiple prompts
        if prompt_dir is not None and os.path.isdir(prompt_dir):
            prompt_files = [f for f in os.listdir(prompt_dir) if f.lower().endswith('.txt')]

            prompt_files = prompt_files[210:] # TODO: Remove this

            if not prompt_files:
                click.echo(f"No prompt files found in {prompt_dir}")
                return
                
            click.echo(f"Found {len(prompt_files)} prompt files to process")
            
            for i, prompt_file in enumerate(prompt_files):
                file_path = os.path.join(prompt_dir, prompt_file)
                click.echo(f"Processing prompt file {i+1}/{len(prompt_files)}: {file_path}")
                
                # Read prompt from file
                with open(file_path, 'r') as f:
                    file_prompt = f.read().strip()
                
                click.echo(f"Using prompt: {file_prompt}")
                
                with torch.inference_mode():
                    output = generate_video(
                        file_prompt,
                        negative_prompt,
                        width,
                        height,
                        num_frames,
                        seed,
                        cfg_scale,
                        num_steps,
                        data_path,
                        input_image=None,
                        noise_multiplier=noise_multiplier,
                    )
                    click.echo(f"Video generated at: {output}")
        else:
            # Process single prompt as before
            click.echo("Running text-to-video generation with provided prompt")
            with torch.inference_mode():
                output = generate_video(
                    prompt,
                    negative_prompt,
                    width,
                    height,
                    num_frames,
                    seed,
                    cfg_scale,
                    num_steps,
                    data_path,
                    input_image=None,
                )
                click.echo(f"Video generated at: {output}")
        return

    config = dict(
        prune_bottlenecks=[False, False, False, False, False],
        has_attentions=[False, True, True, True, True],
        affine=True,
        bias=True,
        input_is_conv_1x1=True,
        padding_mode="replicate",
    )
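    # The VAE encoder below is shared by cases 2 and 3: it turns each conditioning image
    # into a latent that generate_video() passes to the DiT as condition_image.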
    # Create VAE encoder
    encoder = Encoder(
        in_channels=15,
        base_channels=64,
        channel_multipliers=[1, 2, 4, 6],
        num_res_blocks=[3, 3, 4, 6, 3],
        latent_dim=12,
        temporal_reductions=[1, 2, 3],
        spatial_reductions=[2, 2, 2],
        **config,
    )
    device = torch.device("cuda:0")
    encoder = encoder.to(device, memory_format=torch.channels_last_3d)
    encoder.load_state_dict(load_file(f"{model_dir}/encoder.safetensors"))
    encoder.eval()

    # Case 2: Image-to-video, image_dir is a single file
    if image_dir is not None and os.path.isfile(image_dir) and image_dir.lower().endswith(('.jpg', '.jpeg', '.png')):
        click.echo(f"Processing single image: {image_dir}")
        
        # Load the image
        from PIL import Image
        import torchvision.transforms as transforms
        
        image = Image.open(image_dir)
        
        # Crop and resize the image to the target dimensions rather than directly resize
        # Calculate crop dimensions to maintain aspect ratio
        target_ratio = width / height
        current_ratio = image.width / image.height
        
        if current_ratio > target_ratio:
            # Image is wider than target ratio - crop width
            new_width = int(image.height * target_ratio)
            x1 = (image.width - new_width) // 2
            image = image.crop((x1, 0, x1 + new_width, image.height))
        else:
            # Image is taller than target ratio - crop height
            new_height = int(image.width / target_ratio)
            y1 = (image.height - new_height) // 2
            image = image.crop((0, y1, image.width, y1 + new_height))
        
        # Now resize the cropped image
        transform = transforms.Compose([
            transforms.Resize((height, width)),
            transforms.ToTensor(),
        ])
        
        # Scale to [-1, 1] and add batch and time dims: [1, C, 1, H, W]
        image_tensor = (transform(image) * 2 - 1).unsqueeze(1).unsqueeze(0)
        print("image_tensor.shape", image_tensor.shape)
        image_tensor = add_fourier_features(image_tensor.to(device))
        # Encode image to latent
        with torch.inference_mode():
            with torch.autocast("cuda", dtype=torch.bfloat16):
                t0 = time.time()
                encoder = encoder.to(device)
                ldist = encoder(image_tensor)
                image_tensor = ldist.sample()
        
        torch.cuda.empty_cache()
        encoder = encoder.to("cpu")
        del ldist
        
        # Package input for generate_video
        input_image = {
            "tensor": image_tensor,
            "filename": os.path.basename(image_dir),
            "cond_position": cond_position
        }
        
        with torch.inference_mode():
            output = generate_video(
                prompt,
                negative_prompt,
                width,
                height,
                num_frames,
                seed,
                cfg_scale,
                num_steps,
                data_path,
                input_image,
                noise_multiplier=noise_multiplier,
            )
            click.echo(f"Video generated at: {output}")
        return
    

    # Case 3: image_dir is a directory of images
    if image_dir is not None and os.path.isdir(image_dir):
        # Get all image files in the directory
        image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        
        if not image_files:
            click.echo(f"No image files found in {image_dir}")
            return
        
        click.echo(f"Found {len(image_files)} image files to process")
        
        from PIL import Image
        import torchvision.transforms as transforms
        
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        
        with torch.inference_mode():
            for i, image_file in enumerate(image_files):
                file_path = os.path.join(image_dir, image_file)
                click.echo(f"Processing file {i+1}/{len(image_files)}: {file_path}")
                
                # Load image
                image = Image.open(file_path)
                
                # Calculate crop dimensions to maintain aspect ratio
                target_ratio = width / height
                current_ratio = image.width / image.height
                
                if current_ratio > target_ratio:
                    # Image is wider than target ratio - crop width
                    new_width = int(image.height * target_ratio)
                    x1 = (image.width - new_width) // 2
                    image = image.crop((x1, 0, x1 + new_width, image.height))
                else:
                    # Image is taller than target ratio - crop height
                    new_height = int(image.width / target_ratio)
                    y1 = (image.height - new_height) // 2
                    image = image.crop((0, y1, image.width, y1 + new_height))
                
                # Now resize the cropped image
                image = image.resize((width, height))
                
                # Scale to [-1, 1] and add batch and time dims: [1, C, 1, H, W]
                image_tensor = (transform(image) * 2 - 1).unsqueeze(1).unsqueeze(0)
                print("image_tensor.shape", image_tensor.shape)
                image_tensor = add_fourier_features(image_tensor.to(device))
                # Encode image to latent
                with torch.inference_mode():
                    with torch.autocast("cuda", dtype=torch.bfloat16):
                        t0 = time.time()
                        encoder = encoder.to(device)
                        ldist = encoder(image_tensor)
                        image_tensor = ldist.sample()
                torch.cuda.empty_cache()
                encoder = encoder.to("cpu")
                del ldist

                # Get the corresponding prompt; fall back to the --prompt option if no matching .txt exists
                img_basename = os.path.basename(file_path).split('.')[0]
                prompt_file = os.path.join(prompt_dir, f"{img_basename}.txt") if prompt_dir is not None else None

                if prompt_file is not None and os.path.exists(prompt_file):
                    with open(prompt_file, 'r') as f:
                        file_prompt = f.read().strip()
                    click.echo(f"Using prompt from file: {file_prompt}")
                else:
                    click.echo(f"Warning: Prompt file not found for {file_path}. Using default prompt.")
                    file_prompt = prompt
            
                # Package input for generate_video
                input_image = {
                    "tensor": image_tensor,
                    "filename": os.path.basename(file_path),
                    "cond_position": cond_position
                }
                
                output = generate_video(
                    file_prompt,
                    negative_prompt,
                    width,
                    height,
                    num_frames,
                    seed,
                    cfg_scale,
                    num_steps,
                    data_path,
                    input_image,
                    noise_multiplier=noise_multiplier,
                )
                click.echo(f"Video generated at: {output}")
        return


if __name__ == "__main__":
    generate_cli()