from typing import Tuple, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F

import genmo.mochi_preview.dit.joint_model.context_parallel as cp


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)
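
# Illustrative examples (not from the original source):
#   cast_tuple(2, 3)         -> (2, 2, 2)
#   cast_tuple((2, 1, 1), 3) -> (2, 1, 1)  # tuples pass through unchanged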


def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor:
    """
    Forward pass that handles communication between ranks for inference.

    Args:
        x: Tensor of shape (B, C, T, H, W)
        frames_to_send: int, number of frames to communicate between ranks

    Returns:
        output: Tensor of shape (B, C, T', H, W)
    """
    cp_rank, cp_world_size = cp.get_cp_rank_size()
    if frames_to_send == 0 or cp_world_size == 1:
        return x

    group = cp.get_cp_group()
    global_rank = dist.get_rank()
    # print(f"cp_rank: {cp_rank}, cp_world_size: {cp_world_size}, global_rank: {global_rank}, frames_to_send: {frames_to_send}, x.shape: {x.shape}")
    # Sample output:
    #   (MultiGPUContext pid=3061092) cp_rank: 3, cp_world_size: 8, global_rank: 3, frames_to_send: 2, x.shape: torch.Size([1, 768, 4, 60, 106])
    #   (MultiGPUContext pid=3061092) cp_rank: 3, cp_world_size: 8, global_rank: 3, frames_to_send: 2, x.shape: torch.Size([1, 768, 4, 60, 106])
    #   (MultiGPUContext pid=3061092) cp_rank: 3, cp_world_size: 8, global_rank: 3, frames_to_send: 2, x.shape: torch.Size([1, 512, 12, 120, 212])
    #   (MultiGPUContext pid=3061092) cp_rank: 3, cp_world_size: 8, global_rank: 3, frames_to_send: 2, x.shape: torch.Size([1, 512, 12, 120, 212])
    #   ...

    # Send to next rank
    if cp_rank < cp_world_size - 1:
        assert x.size(2) >= frames_to_send
        tail = x[:, :, -frames_to_send:].contiguous()
        dist.send(tail, global_rank + 1, group=group)

    # Receive from previous rank
    if cp_rank > 0:
        B, C, _, H, W = x.shape
        recv_buffer = torch.empty(
            (B, C, frames_to_send, H, W),
            dtype=x.dtype,
            device=x.device,
        )
        dist.recv(recv_buffer, global_rank - 1, group=group)
        x = torch.cat([recv_buffer, x], dim=2)

    return x
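
# Hedged sketch of what the exchange above achieves (single-process illustration,
# hypothetical values, no dist calls involved):
#   rank 0 holds frames [f0, f1, f2], rank 1 holds frames [f3, f4, f5];
#   with frames_to_send = 1, rank 0 sends its tail f2, and rank 1 ends up holding
#   [f2, f3, f4, f5], so a causal temporal conv on rank 1 can see the frame that
#   immediately precedes its chunk.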


def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor:
    if max_T > x.size(2):
        pad_T = max_T - x.size(2)
        pad_dims = (0, 0, 0, 0, 0, pad_T)
        return F.pad(x, pad_dims)
    return x
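
# Note (added for illustration): F.pad consumes the pad tuple from the last
# dimension backwards, i.e. (W_left, W_right, H_left, H_right, T_left, T_right)
# for a (B, C, T, H, W) tensor, so pad_dims above appends pad_T zero frames at
# the end of the time dimension. For example, a hypothetical x of shape
# (1, 4, 3, 8, 8) padded to max_T = 5 comes back as (1, 4, 5, 8, 8).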


def gather_all_frames(x: torch.Tensor) -> torch.Tensor:
    """
    Gathers all frames from all processes for inference.

    Args:
        x: Tensor of shape (B, C, T, H, W)

    Returns:
        output: Tensor of shape (B, C, T_total, H, W)
    """
    cp_rank, cp_size = cp.get_cp_rank_size()
    cp_group = cp.get_cp_group()

    # Ensure the tensor is contiguous for collective operations.
    x = x.contiguous()

    # Get the local time dimension size.
    local_T = x.size(2)
    local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64)

    # Gather all T sizes from all processes.
    all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)]
    dist.all_gather(all_T, local_T_tensor, group=cp_group)
    all_T = [t.item() for t in all_T]

    # Pad the tensor at the end of the time dimension to match max_T.
    max_T = max(all_T)
    x = _pad_to_max(x, max_T).contiguous()
    # if cp_rank == 0:
    #     print(f"gather_all_frames before: cp_rank: {cp_rank}, x.shape: {x.shape}, max_T: {max_T}")

    # Prepare a list to hold the gathered tensors.
    gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)]
    # if cp_rank == 0:
    #     print(f"gather_all_frames after: cp_rank: {cp_rank}, gathered_x num: {len(gathered_x)}, max_T: {max_T}")

    # Perform the all_gather operation.
    dist.all_gather(gathered_x, x, group=cp_group)

    # Slice each gathered tensor back to its original T size.
    for idx, t_size in enumerate(all_T):
        gathered_x[idx] = gathered_x[idx][:, :, :t_size]

    return torch.cat(gathered_x, dim=2)
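
# Hedged illustration of the gather (hypothetical sizes, not from the original
# source): with cp_size = 2, rank 0 holding T = 4 frames and rank 1 holding
# T = 3, both tensors are right-padded to T = 4, all_gathered, trimmed back to
# 4 and 3 frames respectively, and concatenated in rank order into a
# T_total = 7 tensor that every rank receives.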


def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool:
    """Estimate memory usage based on input tensor size and data type."""
    element_size = input.element_size()  # Size in bytes of each element
    memory_bytes = input.numel() * element_size
    memory_gb = memory_bytes / 1024**3
    return memory_gb > max_gb
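
# Illustrative values (hypothetical shapes): a float32 tensor of shape
# (1, 512, 32, 256, 256) occupies ~4.0 GiB, so excessive_memory_usage() returns
# True at the default 2.0 GB threshold, while a float16 tensor of shape
# (1, 768, 4, 60, 106) is well under 0.1 GiB and returns False.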


class ContextParallelCausalConv3d(torch.nn.Conv3d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]],
        **kwargs,
    ):
        kernel_size = cast_tuple(kernel_size, 3)
        stride = cast_tuple(stride, 3)
        height_pad = (kernel_size[1] - 1) // 2
        width_pad = (kernel_size[2] - 1) // 2

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=(1, 1, 1),
            padding=(0, height_pad, width_pad),
            **kwargs,
        )

    def forward(self, x: torch.Tensor):
        cp_rank, cp_world_size = cp.get_cp_rank_size()
        context_size = self.kernel_size[0] - 1

        # Causal padding: only the first rank pads the start of the time dimension.
        if cp_rank == 0:
            mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
            x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode)

        if cp_world_size == 1:
            return super().forward(x)
# print(f"ContextParallelCausalConv3d: cp_rank: {cp_rank}, cp_world_size: {cp_world_size}, x.shape: {x.shape}, context_size: {context_size},self.kernel_size: {self.kernel_size}, self.stride: {self.stride}") # (MultiGPUContext pid=3061095) ContextParallelConv3d: cp_rank: 0, cp_world_size: 8, x.shape: torch.Size([1, 128, 22, 240, 424]), context_size: 1,self.kernel_size: (2, 2, 2), self.stride: (2, 2, 2) | |
# not used | |
        if all(s == 1 for s in self.stride):
            # Receive some frames from previous rank.
            x = cp_pass_frames(x, context_size)
            return super().forward(x)

        # Less efficient implementation for strided convs:
        # all-gather x, infer, and chunk.
        x = gather_all_frames(x)  # [B, C, k - 1 + global_T, H, W]
        x = super().forward(x)
        x_chunks = x.tensor_split(cp_world_size, dim=2)
        assert len(x_chunks) == cp_world_size
        return x_chunks[cp_rank]
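

# A minimal single-device usage sketch (added for illustration, not part of the
# original module). It assumes cp.get_cp_rank_size() falls back to (0, 1) when
# no context-parallel group has been initialized, so the causal-padding and the
# plain super().forward() paths are exercised without any dist calls.
if __name__ == "__main__":
    conv = ContextParallelCausalConv3d(
        in_channels=8, out_channels=8, kernel_size=3, stride=1, bias=True
    )
    x = torch.randn(1, 8, 5, 16, 16)  # (B, C, T, H, W); shapes are hypothetical
    y = conv(x)
    # Causal front-padding of kernel_size[0] - 1 frames keeps T unchanged for
    # stride-1 convs, and the spatial "same" padding keeps H and W unchanged.
    assert y.shape == (1, 8, 5, 16, 16)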