from typing import Tuple, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F

import genmo.mochi_preview.dit.joint_model.context_parallel as cp


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor:
    """
    Pass each rank's trailing frames to the next rank (a halo exchange)
    for context-parallel inference.

    Args:
        x: Tensor of shape (B, C, T, H, W)
        frames_to_send: int, number of frames to communicate between ranks

    Returns:
        output: Tensor of shape (B, C, T', H, W), where T' = T on rank 0
            and T' = T + frames_to_send on every other rank
    """
cp_rank, cp_world_size = cp.get_cp_rank_size()
if frames_to_send == 0 or cp_world_size == 1:
return x
group = cp.get_cp_group()
global_rank = dist.get_rank()
# print(f"cp_rank: {cp_rank}, cp_world_size: {cp_world_size}, global_rank: {global_rank}, frames_to_send: {frames_to_send}, x.shape: {x.shape}")
# """print:
# (MultiGPUContext pid=3061092) cp_rank: 3, cp_world_size: 8, global_rank: 3, frames_to_send: 2, x.shape: torch.Size([1, 768, 4, 60, 106])
# (MultiGPUContext pid=3061092) cp_rank: 3, cp_world_size: 8, global_rank: 3, frames_to_send: 2, x.shape: torch.Size([1, 768, 4, 60, 106])
# (MultiGPUContext pid=3061092) cp_rank: 3, cp_world_size: 8, global_rank: 3, frames_to_send: 2, x.shape: torch.Size([1, 512, 12, 120, 212])
# (MultiGPUContext pid=3061092) cp_rank: 3, cp_world_size: 8, global_rank: 3, frames_to_send: 2, x.shape: torch.Size([1, 512, 12, 120, 212])
# ...
# """
# Send to next rank
if cp_rank < cp_world_size - 1:
assert x.size(2) >= frames_to_send
tail = x[:, :, -frames_to_send:].contiguous()
dist.send(tail, global_rank + 1, group=group)
# Receive from previous rank
if cp_rank > 0:
B, C, _, H, W = x.shape
recv_buffer = torch.empty(
(B, C, frames_to_send, H, W),
dtype=x.dtype,
device=x.device,
)
dist.recv(recv_buffer, global_rank - 1, group=group)
x = torch.cat([recv_buffer, x], dim=2)
return x
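

# A hedged usage sketch for cp_pass_frames, assuming a default process group
# and this repo's context-parallel group are already initialized (e.g. under
# torchrun). `_demo_cp_pass_frames` is an illustrative name, not part of the
# original module.
def _demo_cp_pass_frames():
    cp_rank, _ = cp.get_cp_rank_size()
    x = torch.randn(1, 8, 4, 16, 16, device="cuda")  # (B, C, T=4, H, W)
    y = cp_pass_frames(x, frames_to_send=2)
    # Rank 0 keeps T=4; every later rank prepends the 2 trailing frames of
    # its predecessor, growing T to 6.
    assert y.size(2) == (4 if cp_rank == 0 else 6)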


def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor:
    """Zero-pad x at the end of its time dimension (dim 2) up to max_T."""
    if max_T > x.size(2):
        pad_T = max_T - x.size(2)
        # F.pad orders padding from the last dimension backwards: (W, H, T).
        pad_dims = (0, 0, 0, 0, 0, pad_T)
        return F.pad(x, pad_dims)
    return x
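

# A minimal CPU check of _pad_to_max: padding is appended after the last
# real frame along dim 2. `_demo_pad_to_max` is an illustrative name.
def _demo_pad_to_max():
    x = torch.ones(1, 1, 3, 2, 2)
    y = _pad_to_max(x, 5)
    assert y.shape == (1, 1, 5, 2, 2)
    assert y[:, :, 3:].abs().sum().item() == 0  # appended frames are zeros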


def gather_all_frames(x: torch.Tensor) -> torch.Tensor:
    """
    Gather the frames from all context-parallel ranks onto every rank.

    Args:
        x: Tensor of shape (B, C, T, H, W); T may differ across ranks

    Returns:
        output: Tensor of shape (B, C, T_total, H, W), where T_total is the
            sum of the per-rank T sizes
    """
cp_rank, cp_size = cp.get_cp_rank_size()
cp_group = cp.get_cp_group()
# Ensure the tensor is contiguous for collective operations
x = x.contiguous()
# Get the local time dimension size
local_T = x.size(2)
local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64)
# Gather all T sizes from all processes
all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)]
dist.all_gather(all_T, local_T_tensor, group=cp_group)
all_T = [t.item() for t in all_T]
# Pad the tensor at the end of the time dimension to match max_T
max_T = max(all_T)
x = _pad_to_max(x, max_T).contiguous()
# Prepare a list to hold the gathered tensors
gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)]
# Perform the all_gather operation
dist.all_gather(gathered_x, x, group=cp_group)
# Slice each gathered tensor back to its original T size
for idx, t_size in enumerate(all_T):
gathered_x[idx] = gathered_x[idx][:, :, :t_size]
return torch.cat(gathered_x, dim=2)
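

# A hedged sketch of gather_all_frames with ragged per-rank lengths, assuming
# exactly two initialized context-parallel ranks. `_demo_gather_all_frames`
# is an illustrative name, not part of the original module.
def _demo_gather_all_frames():
    cp_rank, cp_size = cp.get_cp_rank_size()
    assert cp_size == 2  # the shapes below assume exactly two ranks
    local_T = 3 if cp_rank == 0 else 5
    x = torch.randn(1, 8, local_T, 16, 16, device="cuda")
    out = gather_all_frames(x)
    # Each rank pads to max_T = 5 for the collective, the padding is then
    # sliced off, and every rank sees the full T_total = 3 + 5 = 8 frames.
    assert out.shape == (1, 8, 8, 16, 16)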


def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool:
    """Return True if `input` is estimated to occupy more than `max_gb` GiB."""
    element_size = input.element_size()  # Size in bytes of each element
    memory_bytes = input.numel() * element_size
    memory_gb = memory_bytes / 1024**3
    return memory_gb > max_gb
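

# A worked example of the estimate above, on the meta device so nothing is
# actually allocated. The shape is illustrative: 512 * 28 * 240 * 424
# float16 elements at 2 bytes each come to roughly 2.7 GiB, which is over
# the 2.0 GiB default threshold.
def _demo_excessive_memory_usage():
    x = torch.empty(1, 512, 28, 240, 424, dtype=torch.float16, device="meta")
    assert excessive_memory_usage(x)
    assert not excessive_memory_usage(x[:, :64])  # ~0.34 GiB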


class ContextParallelCausalConv3d(torch.nn.Conv3d):
def __init__(
self,
in_channels,
out_channels,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]],
**kwargs,
):
        kernel_size = cast_tuple(kernel_size, 3)
        stride = cast_tuple(stride, 3)
        # "Same" padding spatially; temporal padding is left at 0 because the
        # causal (front-only) temporal pad is applied in forward().
        height_pad = (kernel_size[1] - 1) // 2
        width_pad = (kernel_size[2] - 1) // 2
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=(1, 1, 1),
padding=(0, height_pad, width_pad),
**kwargs,
)

    def forward(self, x: torch.Tensor):
cp_rank, cp_world_size = cp.get_cp_rank_size()
        context_size = self.kernel_size[0] - 1
        if cp_rank == 0:
            # Causal temporal padding: pad only the front of the clip, and
            # only on the first rank, which holds the earliest frames.
            mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
            x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode)
if cp_world_size == 1:
return super().forward(x)
# print(f"ContextParallelCausalConv3d: cp_rank: {cp_rank}, cp_world_size: {cp_world_size}, x.shape: {x.shape}, context_size: {context_size},self.kernel_size: {self.kernel_size}, self.stride: {self.stride}") # (MultiGPUContext pid=3061095) ContextParallelConv3d: cp_rank: 0, cp_world_size: 8, x.shape: torch.Size([1, 128, 22, 240, 424]), context_size: 1,self.kernel_size: (2, 2, 2), self.stride: (2, 2, 2)
# not used
if all(s == 1 for s in self.stride):
# Receive some frames from previous rank.
x = cp_pass_frames(x, context_size)
return super().forward(x)
        # Less efficient fallback for strided convs:
        # all-gather x across ranks, run the conv, then chunk the result.
x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W]
x = super().forward(x)
x_chunks = x.tensor_split(cp_world_size, dim=2)
assert len(x_chunks) == cp_world_size
return x_chunks[cp_rank]
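

# A hedged single-process sanity check. It assumes cp.get_cp_rank_size()
# falls back to rank 0 / world size 1 when no context-parallel group is set
# up, in which case the layer reduces to a plain causal Conv3d.
if __name__ == "__main__":
    conv = ContextParallelCausalConv3d(8, 8, kernel_size=3, stride=1)
    x = torch.randn(1, 8, 5, 16, 16)
    y = conv(x)
    # Front-only temporal padding of kernel_size[0] - 1 frames keeps T fixed
    # for a unit-stride kernel, and "same" spatial padding keeps H and W.
    assert y.shape == (1, 8, 5, 16, 16)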