import json
import os
import random
from abc import ABC, abstractmethod
from contextlib import contextmanager
from datetime import datetime
from functools import partial
from typing import Any, Dict, List, Literal, Optional, Union, cast

import ipdb
import numpy as np
import ray
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
from safetensors.torch import load_file
from torch.distributed.fsdp import (
    BackwardPrefetch,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
)
from torch.distributed.fsdp.wrap import (
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers import T5EncoderModel, T5Tokenizer
from transformers.models.t5.modeling_t5 import T5Block

import genmo.mochi_preview.dit.joint_model.context_parallel as cp
import genmo.mochi_preview.vae.cp_conv as cp_conv
from genmo.lib.progress import get_new_progress_bar, progress_bar
from genmo.lib.utils import Timer
from genmo.mochi_preview.vae.latent_dist import LatentDistribution
from genmo.mochi_preview.vae.models import (
    Decoder,
    Encoder,
    add_fourier_features,
    decode_latents,
    decode_latents_tiled_full,
    decode_latents_tiled_spatial,
    encode_latents,
)
from genmo.mochi_preview.vae.vae_stats import dit_latents_to_vae_latents


def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
    if linear_steps is None:
        linear_steps = num_steps // 2
    linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
    threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
    quadratic_steps = num_steps - linear_steps
    quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
    linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
    const = quadratic_coef * (linear_steps**2)
    quadratic_sigma_schedule = [
        quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
    ]
    sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
    sigma_schedule = [1.0 - x for x in sigma_schedule]
    return sigma_schedule
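
# Illustrative example (a sketch, not from the original source): the schedules consumed by
# sample_model() below are typically built from this helper, e.g.
#   sigma_schedule = linear_quadratic_schedule(64, 0.025)
# which returns 65 values decreasing from 1.0 to 0.0, i.e. one more entry than the number
# of sampling steps, matching the `sample_steps + 1` length asserted in sample_model().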

# T5_MODEL = "google/t5-v1_1-xxl"
T5_MODEL = "/home/dyvm6xra/dyvm6xrauser02/AIGC/t5-v1_1-xxl"
MAX_T5_TOKEN_LENGTH = 256


def setup_fsdp_sync(model, device_id, *, param_dtype, auto_wrap_policy) -> FSDP:
    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.float32,
        ),
        auto_wrap_policy=auto_wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        limit_all_gathers=True,
        device_id=device_id,
        sync_module_states=True,
        use_orig_params=True,
    )
    torch.cuda.synchronize()
    return model


class ModelFactory(ABC):
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @abstractmethod
    def get_model(self, *, local_rank: int, device_id: Union[int, Literal["cpu"]], world_size: int) -> Any:
        if device_id == "cpu":
            assert world_size == 1, "CPU offload only supports single-GPU inference"


class T5ModelFactory(ModelFactory):
    def __init__(self):
        super().__init__()

    def get_model(self, *, local_rank, device_id, world_size):
        super().get_model(local_rank=local_rank, device_id=device_id, world_size=world_size)
        model = T5EncoderModel.from_pretrained(T5_MODEL)
        if world_size > 1:
            model = setup_fsdp_sync(
                model,
                device_id=device_id,
                param_dtype=torch.float32,
                auto_wrap_policy=partial(
                    transformer_auto_wrap_policy,
                    transformer_layer_cls={
                        T5Block,
                    },
                ),
            )
        elif isinstance(device_id, int):
            model = model.to(torch.device(f"cuda:{device_id}"))  # type: ignore
        return model.eval()


class DitModelFactory(ModelFactory):
    def __init__(self, *, model_path: str, model_dtype: str, attention_mode: Optional[str] = None):
        if attention_mode is None:
            from genmo.lib.attn_imports import flash_varlen_qkvpacked_attn  # type: ignore

            attention_mode = "sdpa" if flash_varlen_qkvpacked_attn is None else "flash"
        print(f"Attention mode: {attention_mode}")
        super().__init__(model_path=model_path, model_dtype=model_dtype, attention_mode=attention_mode)

    def get_model(self, *, local_rank, device_id, world_size):
        # TODO(ved): Set flag for torch.compile
        from genmo.mochi_preview.dit.joint_model.asymm_models_joint import (
            AsymmDiTJoint,
        )

        model: nn.Module = torch.nn.utils.skip_init(
            AsymmDiTJoint,
            depth=48,
            patch_size=2,
            num_heads=24,
            hidden_size_x=3072,
            hidden_size_y=1536,
            mlp_ratio_x=4.0,
            mlp_ratio_y=4.0,
            in_channels=12,
            qk_norm=True,
            qkv_bias=False,
            out_bias=True,
            patch_embed_bias=True,
            timestep_mlp_bias=True,
            timestep_scale=1000.0,
            t5_feat_dim=4096,
            t5_token_length=256,
            rope_theta=10000.0,
            attention_mode=self.kwargs["attention_mode"],
        )
        if local_rank == 0:
            # FSDP syncs weights from rank 0 to all other ranks
            model.load_state_dict(load_file(self.kwargs["model_path"]))
        if world_size > 1:
            assert self.kwargs["model_dtype"] == "bf16", "FP8 is not supported for multi-GPU inference"
            model = setup_fsdp_sync(
                model,
                device_id=device_id,
                param_dtype=torch.bfloat16,
                auto_wrap_policy=partial(
                    lambda_auto_wrap_policy,
                    lambda_fn=lambda m: m in model.blocks,
                ),
            )
        elif isinstance(device_id, int):
            model = model.to(torch.device(f"cuda:{device_id}"))
        return model.eval()


class DecoderModelFactory(ModelFactory):
    def __init__(self, *, model_path: str):
        super().__init__(model_path=model_path)

    def get_model(self, *, local_rank, device_id, world_size):
        # TODO(ved): Set flag for torch.compile
        # TODO(ved): Use skip_init
        decoder = Decoder(
            out_channels=3,
            base_channels=128,
            channel_multipliers=[1, 2, 4, 6],
            temporal_expansions=[1, 2, 3],
            spatial_expansions=[2, 2, 2],
            num_res_blocks=[3, 3, 4, 6, 3],
            latent_dim=12,
            has_attention=[False, False, False, False, False],
            output_norm=False,
            nonlinearity="silu",
            output_nonlinearity="silu",
            causal=True,
        )
        # VAE is not FSDP-wrapped
        state_dict = load_file(self.kwargs["model_path"])
        decoder.load_state_dict(state_dict, strict=True)
        device = torch.device(f"cuda:{device_id}") if isinstance(device_id, int) else "cpu"
        decoder.eval().to(device)
        return decoder


class EncoderModelFactory(ModelFactory):
    def __init__(self, *, model_path: str):
        super().__init__(model_path=model_path)

    def get_model(self, *, local_rank, device_id, world_size):
        config = dict(
            prune_bottlenecks=[False, False, False, False, False],
            has_attentions=[False, True, True, True, True],
            affine=True,
            bias=True,
            input_is_conv_1x1=True,
            padding_mode="replicate",
        )
        encoder = Encoder(
            in_channels=15,
            base_channels=64,
            channel_multipliers=[1, 2, 4, 6],
            temporal_reductions=[1, 2, 3],
            spatial_reductions=[2, 2, 2],
            num_res_blocks=[3, 3, 4, 6, 3],
            latent_dim=12,
            **config,
        )
        state_dict = load_file(self.kwargs["model_path"])
        encoder.load_state_dict(state_dict, strict=True)
        device = torch.device(f"cuda:{device_id}") if isinstance(device_id, int) else "cpu"
        encoder = encoder.to(memory_format=torch.channels_last_3d)
        encoder.eval().to(device)
        return encoder


def get_conditioning(tokenizer, encoder, device, batch_inputs, *, prompt: str, negative_prompt: str):
    if batch_inputs:
        return dict(batched=get_conditioning_for_prompts(tokenizer, encoder, device, [prompt, negative_prompt]))
    else:
        cond_input = get_conditioning_for_prompts(tokenizer, encoder, device, [prompt])
        null_input = get_conditioning_for_prompts(tokenizer, encoder, device, [negative_prompt])
        return dict(cond=cond_input, null=null_input)


def get_conditioning_for_prompts(tokenizer, encoder, device, prompts: List[str]):
    assert len(prompts) in [1, 2]  # [neg] or [pos] or [pos, neg]
    B = len(prompts)
    t5_toks = tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=MAX_T5_TOKEN_LENGTH,
        return_tensors="pt",
        return_attention_mask=True,
    )
    caption_input_ids_t5 = t5_toks["input_ids"]
    caption_attention_mask_t5 = t5_toks["attention_mask"].bool()
    del t5_toks

    assert caption_input_ids_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
    assert caption_attention_mask_t5.shape == (B, MAX_T5_TOKEN_LENGTH)

    # Special-case empty negative prompt by zero-ing it
    if prompts[-1] == "":
        caption_input_ids_t5[-1] = 0
        caption_attention_mask_t5[-1] = False

    caption_input_ids_t5 = caption_input_ids_t5.to(device, non_blocking=True)
    caption_attention_mask_t5 = caption_attention_mask_t5.to(device, non_blocking=True)

    y_mask = [caption_attention_mask_t5]
    y_feat = [encoder(caption_input_ids_t5, caption_attention_mask_t5).last_hidden_state.detach()]
    # Sometimes returns a tensor, other times a tuple; not sure why.
    # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
    assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
    assert y_feat[-1].dtype == torch.float32

    return dict(y_mask=y_mask, y_feat=y_feat)
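
# Shape sketch (illustrative, not from the original source): with batch_cfg=True,
# get_conditioning() above returns roughly
#   {"batched": {"y_mask": [bool tensor (2, 256)], "y_feat": [float32 tensor (2, 256, 4096)]}}
# and with batch_cfg=False it returns
#   {"cond": {...}, "null": {...}}
# where each sub-dict holds a (1, 256) mask and a (1, 256, 4096) feature tensor;
# 256 is MAX_T5_TOKEN_LENGTH and 4096 is the T5-XXL hidden size.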


def compute_packed_indices(
    device: torch.device, text_mask: torch.Tensor, num_latents: int
) -> Dict[str, Union[torch.Tensor, int]]:
    """
    Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80

    Args:
        num_latents: Number of latent tokens.
        text_mask: (B, L) boolean tensor indicating which text tokens are not padding.

    Returns:
        packed_indices: Dict with keys for Flash Attention:
            - valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding)
              in the packed sequence.
            - cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
            - max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
    """
    # Create an expanded token mask saying which tokens are valid across both visual and text tokens.
    PATCH_SIZE = 2
    num_visual_tokens = num_latents // (PATCH_SIZE**2)
    assert num_visual_tokens > 0

    mask = F.pad(text_mask, (num_visual_tokens, 0), value=True)  # (B, N + L)
    seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # (B,)
    valid_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()  # up to (B * (N + L),)
    assert valid_token_indices.size(0) >= text_mask.size(0) * num_visual_tokens  # At least (B * N,)

    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen_in_batch = seqlens_in_batch.max().item()

    return {
        "cu_seqlens_kv": cu_seqlens.to(device, non_blocking=True),
        "max_seqlen_in_batch_kv": cast(int, max_seqlen_in_batch),
        "valid_token_indices_kv": valid_token_indices.to(device, non_blocking=True),
    }
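
# Worked example (assuming the 848x480x163 setup hard-coded in sample_model below, B=1):
# latent_t=28, latent_h=60, latent_w=106, so num_latents = 28 * 60 * 106 = 178,080 and
# num_visual_tokens = 178,080 // 4 = 44,520. If 20 of the 256 text tokens are non-padding,
# the packed sequence has 44,520 + 20 = 44,540 valid tokens, giving
# cu_seqlens_kv == [0, 44540] and max_seqlen_in_batch_kv == 44540.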


def assert_eq(x, y, msg=None):
    assert x == y, f"{msg or 'Assertion failed'}: {x} != {y}"


def sample_model(device, dit, encoder, condition_image, condition_frame_idx, conditioning, **args):
    random.seed(args["seed"])
    np.random.seed(args["seed"])
    torch.manual_seed(args["seed"])
    generator = torch.Generator(device=device)
    generator.manual_seed(args["seed"])

    w, h, t = 848, 480, 163  # Directly initialized width, height, num_frames
    sample_steps = args["num_inference_steps"]
    cfg_schedule = args["cfg_schedule"]
    sigma_schedule = args["sigma_schedule"]
    noise_multiplier = args["noise_multiplier"]

    assert_eq(len(cfg_schedule), sample_steps, "cfg_schedule must have length sample_steps")
    assert_eq(
        len(sigma_schedule),
        sample_steps + 1,
        "sigma_schedule must have length sample_steps + 1",
    )

    if condition_image is not None:
        B = condition_image.shape[0]
    else:
        B = 1

    SPATIAL_DOWNSAMPLE = 8
    TEMPORAL_DOWNSAMPLE = 6
    IN_CHANNELS = 12
    latent_t = ((t - 1) // TEMPORAL_DOWNSAMPLE) + 1
    latent_w, latent_h = w // SPATIAL_DOWNSAMPLE, h // SPATIAL_DOWNSAMPLE

    z_0 = torch.zeros(
        (B, IN_CHANNELS, latent_t, latent_h, latent_w),
        device=device,
        dtype=torch.float32,
    )
    cond_latent = condition_image
    if isinstance(condition_frame_idx, list):
        z_0[:, :, condition_frame_idx, :, :] = cond_latent[:, :, condition_frame_idx]
    elif isinstance(condition_frame_idx, int):
        z_0[:, :, condition_frame_idx:(condition_frame_idx + 1), :, :] = cond_latent

    num_latents = latent_t * latent_h * latent_w
    cond_batched = cond_text = cond_null = None
    if "cond" in conditioning:
        cond_text = conditioning["cond"]
        cond_null = conditioning["null"]
        cond_text["packed_indices"] = compute_packed_indices(device, cond_text["y_mask"][0], num_latents)
        cond_null["packed_indices"] = compute_packed_indices(device, cond_null["y_mask"][0], num_latents)
    else:
        cond_batched = conditioning["batched"]
        cond_batched["packed_indices"] = compute_packed_indices(device, cond_batched["y_mask"][0], num_latents)
        z_0 = repeat(z_0, "b ... -> (repeat b) ...", repeat=2)

    def model_fn(*, z, sigma, cfg_scale):
        if cond_batched:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                out = dit(z, sigma, **cond_batched)
            out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
        else:
            nonlocal cond_text, cond_null
            with torch.autocast("cuda", dtype=torch.bfloat16):
                out_cond = dit(z, sigma, **cond_text)
                out_uncond = dit(z, sigma, **cond_null)
        assert out_cond.shape == out_uncond.shape
        out_uncond = out_uncond.to(z)
        out_cond = out_cond.to(z)
        return out_uncond + cfg_scale * (out_cond - out_uncond)

    # Euler sampler w/ customizable sigma schedule & cfg scale
    for i in get_new_progress_bar(range(0, sample_steps), desc="Sampling"):
        sigma = sigma_schedule[i]
        bs = B if cond_text else B * 2
        sigma = torch.tensor([sigma] * (bs * latent_t), device=device).reshape((bs, latent_t))
        if isinstance(condition_frame_idx, list):  # any-frames-to-video conditioning
            sigma[:, condition_frame_idx] = sigma[:, condition_frame_idx] / 4
        elif isinstance(condition_frame_idx, int):  # image-to-video (I2V) conditioning
            sigma[:, condition_frame_idx] = sigma[:, condition_frame_idx] * float(noise_multiplier)
        if i == 0:
            z = (1.0 - sigma[:B].view(B, 1, latent_t, 1, 1)) * z_0 + sigma[:B].view(B, 1, latent_t, 1, 1) * torch.randn(
                (B, IN_CHANNELS, latent_t, latent_h, latent_w),
                device=device,
                dtype=torch.bfloat16,
            )
            if "cond" not in conditioning:
                z = repeat(z, "b ... -> (repeat b) ...", repeat=2)

        dsigma = sigma - sigma_schedule[i + 1]
        if isinstance(condition_frame_idx, list):
            dsigma[:, condition_frame_idx] = sigma[:, condition_frame_idx] - sigma_schedule[i + 1] / 4
        elif isinstance(condition_frame_idx, int):
            dsigma[:, condition_frame_idx] = sigma[:, condition_frame_idx] - sigma_schedule[i + 1] * float(noise_multiplier)

        pred = model_fn(
            z=z,
            sigma=sigma,
            cfg_scale=cfg_schedule[i],
        )
        assert pred.dtype == torch.float32
        z = z + dsigma.view(1, 1, z.shape[2], 1, 1) * pred
        pred_last = pred

    z = z[:B] if cond_batched else z
    return dit_latents_to_vae_latents(z)
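
# Call sketch (assumed argument values, not from the original source): the pipelines below
# drive sample_model with keyword arguments along these lines:
#   latents = sample_model(
#       device, dit, None, condition_image, condition_frame_idx, conditioning,
#       seed=12345,
#       num_inference_steps=64,
#       sigma_schedule=linear_quadratic_schedule(64, 0.025),
#       cfg_schedule=[4.5] * 64,
#       noise_multiplier=0.0,
#   )
# cfg_schedule needs one entry per step and sigma_schedule one extra entry, as enforced by
# the assert_eq checks at the top of the function.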


@contextmanager
def move_to_device(model: nn.Module, target_device):
    og_device = next(model.parameters()).device
    if og_device == target_device:
        print(f"move_to_device is a no-op; model is already on {target_device}")
    else:
        print(f"moving model from {og_device} -> {target_device}")

    model.to(target_device)
    yield
    if og_device != target_device:
        print(f"moving model from {target_device} -> {og_device}")
        model.to(og_device)


def t5_tokenizer():
    return T5Tokenizer.from_pretrained(T5_MODEL, legacy=False)


class MochiSingleGPUPipeline:
    def __init__(
        self,
        *,
        text_encoder_factory: ModelFactory,
        dit_factory: ModelFactory,
        decoder_factory: ModelFactory,
        encoder_factory: ModelFactory,
        cpu_offload: Optional[bool] = False,
        decode_type: str = "full",
        decode_args: Optional[Dict[str, Any]] = None,
    ):
        self.device = torch.device("cuda:0")
        self.tokenizer = t5_tokenizer()
        t = Timer()
        self.cpu_offload = cpu_offload
        self.decode_args = decode_args or {}
        self.decode_type = decode_type
        init_id = "cpu" if cpu_offload else 0
        with t("load_text_encoder"):
            self.text_encoder = text_encoder_factory.get_model(
                local_rank=0,
                device_id=init_id,
                world_size=1,
            )
        with t("load_dit"):
            self.dit = dit_factory.get_model(local_rank=0, device_id=init_id, world_size=1)
        with t("load_vae"):
            self.decoder = decoder_factory.get_model(local_rank=0, device_id=init_id, world_size=1)
        # self.encoder = encoder_factory.get_model(local_rank=0, device_id=init_id, world_size=1)
        self.encoder = None
        t.print_stats()

    def __call__(self, batch_cfg, prompt, negative_prompt, condition_image=None, condition_frame_idx=None, **kwargs):
        with torch.inference_mode():
            print_max_memory = lambda: print(
                f"Max memory reserved: {torch.cuda.max_memory_reserved() / 1024**3:.2f} GB"
            )
            print_max_memory()
            with move_to_device(self.text_encoder, self.device):
                conditioning = get_conditioning(
                    self.tokenizer,
                    self.text_encoder,
                    self.device,
                    batch_cfg,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                )
            # del self.text_encoder
            self.text_encoder = self.text_encoder.to("cpu")
            print_max_memory()
            with move_to_device(self.dit, self.device):
                latents = sample_model(
                    self.device,
                    self.dit,
                    self.encoder,
                    condition_image,
                    condition_frame_idx,
                    conditioning,
                    **kwargs,
                )
            print_max_memory()
            torch.cuda.empty_cache()
            del self.dit
            with move_to_device(self.decoder, self.device):
                frames = (
                    decode_latents_tiled_full(self.decoder, latents, **self.decode_args)
                    if self.decode_type == "tiled_full"
                    else decode_latents_tiled_spatial(
                        self.decoder, latents, num_tiles_w=2, num_tiles_h=2, **self.decode_args
                    )
                    if self.decode_type == "tiled_spatial"
                    else decode_latents(self.decoder, latents)
                )
            print_max_memory()
            return frames.cpu().numpy()
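
# Usage sketch (illustrative only; the weight paths are placeholders, not from the
# original source):
#
#   pipeline = MochiSingleGPUPipeline(
#       text_encoder_factory=T5ModelFactory(),
#       dit_factory=DitModelFactory(model_path="<dit>.safetensors", model_dtype="bf16"),
#       decoder_factory=DecoderModelFactory(model_path="<decoder>.safetensors"),
#       encoder_factory=EncoderModelFactory(model_path="<encoder>.safetensors"),
#       cpu_offload=True,
#       decode_type="tiled_spatial",
#   )
#   frames = pipeline(
#       batch_cfg=False,
#       prompt="a red panda eating bamboo",
#       negative_prompt="",
#       seed=12345,
#       num_inference_steps=64,
#       sigma_schedule=linear_quadratic_schedule(64, 0.025),
#       cfg_schedule=[4.5] * 64,
#       noise_multiplier=0.0,
#   )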


### ALL CODE BELOW HERE IS FOR MULTI-GPU MODE ###
# In multi-GPU mode, all models must belong to a device which has a predefined context parallel group,
# so it doesn't make sense to work with models individually.
class MultiGPUContext:
    def __init__(
        self,
        *,
        text_encoder_factory,
        dit_factory,
        decoder_factory,
        encoder_factory,
        device_id,
        local_rank,
        world_size,
    ):
        t = Timer()
        self.device = torch.device(f"cuda:{device_id}")
        print(f"Initializing rank {local_rank + 1}/{world_size}")
        assert world_size > 1, f"Multi-GPU mode requires world_size > 1, got {world_size}"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29503"
        with t("init_process_group"):
            dist.init_process_group(
                "nccl",
                rank=local_rank,
                world_size=world_size,
                device_id=self.device,  # force non-lazy init
            )
        pg = dist.group.WORLD
        cp.set_cp_group(pg, list(range(world_size)), local_rank)
        distributed_kwargs = dict(local_rank=local_rank, device_id=device_id, world_size=world_size)
        self.world_size = world_size
        self.tokenizer = t5_tokenizer()
        with t("load_text_encoder"):
            self.text_encoder = text_encoder_factory.get_model(**distributed_kwargs)
        with t("load_dit"):
            self.dit = dit_factory.get_model(**distributed_kwargs)
        with t("load_vae"):
            self.decoder = decoder_factory.get_model(**distributed_kwargs)
        # self.encoder = encoder_factory.get_model(**distributed_kwargs)
        self.encoder = None
        self.local_rank = local_rank
        t.print_stats()

    def run(self, *, fn, **kwargs):
        return fn(self, **kwargs)


class MochiMultiGPUPipeline:
    def __init__(
        self,
        *,
        text_encoder_factory: ModelFactory,
        dit_factory: ModelFactory,
        decoder_factory: ModelFactory,
        encoder_factory: ModelFactory,
        world_size: int,
    ):
        ray.init(
            address="local",  # Force new cluster creation
            # port=6380,  # Use a different port than the existing ray cluster
            include_dashboard=False,
            num_cpus=8 * world_size,
            num_gpus=world_size,
            # logging_level="DEBUG",
            object_store_memory=512 * 1024 * 1024 * 1024,
        )
        RemoteClass = ray.remote(MultiGPUContext)
        self.ctxs = [
            RemoteClass.options(num_gpus=1).remote(
                text_encoder_factory=text_encoder_factory,
                dit_factory=dit_factory,
                decoder_factory=decoder_factory,
                encoder_factory=encoder_factory,
                world_size=world_size,
                device_id=0,
                local_rank=i,
            )
            for i in range(world_size)
        ]
        for ctx in self.ctxs:
            ray.get(ctx.__ray_ready__.remote())

    def __call__(self, **kwargs):
        def sample(ctx, *, batch_cfg, prompt, negative_prompt, condition_image=None, condition_frame_idx=None, **kwargs):
            with progress_bar(type="ray_tqdm", enabled=ctx.local_rank == 0), torch.inference_mode():
                # Move condition_image to the appropriate device
                if condition_image is not None:
                    condition_image = condition_image.to(ctx.device)
                print(prompt)
                conditioning = get_conditioning(
                    ctx.tokenizer,
                    ctx.text_encoder,
                    ctx.device,
                    batch_cfg,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                )
                latents = sample_model(
                    ctx.device,
                    ctx.dit,
                    ctx.encoder,
                    condition_image=condition_image,
                    condition_frame_idx=condition_frame_idx,
                    conditioning=conditioning,
                    **kwargs,
                )
                if ctx.local_rank == 0:
                    torch.save(latents, "latents.pt")
                frames = decode_latents(ctx.decoder, latents)
            return frames.cpu().numpy()

        return ray.get([ctx.run.remote(fn=sample, **kwargs, show_progress=i == 0) for i, ctx in enumerate(self.ctxs)])[0]
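
# Usage sketch (illustrative only, mirroring the single-GPU example above; world_size is
# an assumption): construct with the same four factories plus world_size, e.g.
#   pipeline = MochiMultiGPUPipeline(
#       text_encoder_factory=T5ModelFactory(),
#       dit_factory=DitModelFactory(model_path="<dit>.safetensors", model_dtype="bf16"),
#       decoder_factory=DecoderModelFactory(model_path="<decoder>.safetensors"),
#       encoder_factory=EncoderModelFactory(model_path="<encoder>.safetensors"),
#       world_size=torch.cuda.device_count(),
#   )
# and call it with the same keyword arguments as MochiSingleGPUPipeline (batch_cfg, prompt,
# negative_prompt, seed, num_inference_steps, sigma_schedule, cfg_schedule, noise_multiplier).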