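"""Context-parallel helpers for causal 3D convolution at inference time.

Video frames are sharded along the time dimension across ranks; the helpers
below pass boundary frames between neighboring ranks (stride-1 convs) or
gather the full frame sequence (strided convs).
"""
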
from typing import Tuple, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F

import genmo.mochi_preview.dit.joint_model.context_parallel as cp


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)
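

# Example (illustrative):
#   cast_tuple(3, length=3)         -> (3, 3, 3)
#   cast_tuple((1, 2, 2), length=3) -> (1, 2, 2)  # tuples pass through unchanged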


def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor:
    """

    Forward pass that handles communication between ranks for inference.

    Args:

        x: Tensor of shape (B, C, T, H, W)

        frames_to_send: int, number of frames to communicate between ranks

    Returns:

        output: Tensor of shape (B, C, T', H, W)

    """
    cp_rank, cp_world_size = cp.get_cp_rank_size()
    if frames_to_send == 0 or cp_world_size == 1:
        return x

    group = cp.get_cp_group()
    global_rank = dist.get_rank()

    # Send to next rank
    if cp_rank < cp_world_size - 1:
        assert x.size(2) >= frames_to_send
        tail = x[:, :, -frames_to_send:].contiguous()
        dist.send(tail, global_rank + 1, group=group)

    # Receive from previous rank
    if cp_rank > 0:
        B, C, _, H, W = x.shape
        recv_buffer = torch.empty(
            (B, C, frames_to_send, H, W),
            dtype=x.dtype,
            device=x.device,
        )
        dist.recv(recv_buffer, global_rank - 1, group=group)
        x = torch.cat([recv_buffer, x], dim=2)

    return x
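
# Illustrative layout (hypothetical sizes): with cp_world_size = 2 and
# frames_to_send = 2, rank 0 holding frames [0..7] sends frames [6, 7] to
# rank 1, which prepends them to its shard [8..15] and ends up with [6..15].
# The overlap supplies the temporal context a causal conv needs at shard
# boundaries.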


def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor:
    """Zero-pad x at the end of the time dimension up to max_T frames."""
    if max_T > x.size(2):
        pad_T = max_T - x.size(2)
        # F.pad orders pads from the last dim inward:
        # (W_left, W_right, H_left, H_right, T_left, T_right).
        pad_dims = (0, 0, 0, 0, 0, pad_T)
        return F.pad(x, pad_dims)
    return x


def gather_all_frames(x: torch.Tensor) -> torch.Tensor:
    """

    Gathers all frames from all processes for inference.

    Args:

        x: Tensor of shape (B, C, T, H, W)

    Returns:

        output: Tensor of shape (B, C, T_total, H, W)

    """
    cp_rank, cp_size = cp.get_cp_rank_size()
    cp_group = cp.get_cp_group()

    # Ensure the tensor is contiguous for collective operations
    x = x.contiguous()

    # Get the local time dimension size
    local_T = x.size(2)
    local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64)

    # Gather all T sizes from all processes
    all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)]
    dist.all_gather(all_T, local_T_tensor, group=cp_group)
    all_T = [t.item() for t in all_T]

    # Pad the tensor at the end of the time dimension to match max_T
    max_T = max(all_T)
    x = _pad_to_max(x, max_T).contiguous()
    # Prepare a list to hold the gathered tensors
    gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)]

    # Perform the all_gather operation
    dist.all_gather(gathered_x, x, group=cp_group)

    # Slice each gathered tensor back to its original T size
    for idx, t_size in enumerate(all_T):
        gathered_x[idx] = gathered_x[idx][:, :, :t_size]

    return torch.cat(gathered_x, dim=2)
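
# Worked example (hypothetical sizes): with cp_size = 3 and per-rank T sizes
# [6, 5, 5], every shard is zero-padded to T = 6 for the all_gather, then
# sliced back, giving T_total = 6 + 5 + 5 = 16 frames in rank order.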


def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool:
    """Estimate memory usage based on input tensor size and data type."""
    element_size = input.element_size()  # Size in bytes of each element
    memory_bytes = input.numel() * element_size
    memory_gb = memory_bytes / 1024**3
    return memory_gb > max_gb
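
# Worked example (hypothetical shape): a float16 tensor of shape
# (1, 128, 28, 240, 424) holds 364,707,840 elements * 2 bytes ~= 0.68 GiB,
# so excessive_memory_usage(x) is False at the default 2.0 GB threshold.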


class ContextParallelCausalConv3d(torch.nn.Conv3d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]],
        **kwargs,
    ):
        kernel_size = cast_tuple(kernel_size, 3)
        stride = cast_tuple(stride, 3)
        height_pad = (kernel_size[1] - 1) // 2
        width_pad = (kernel_size[2] - 1) // 2

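        # Temporal padding is deliberately 0 below: causal left-padding along
        # T happens in forward(), either directly (rank 0) or via frames
        # passed in from the previous rank.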
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=(1, 1, 1),
            padding=(0, height_pad, width_pad),
            **kwargs,
        )

    def forward(self, x: torch.Tensor):
        cp_rank, cp_world_size = cp.get_cp_rank_size()

        context_size = self.kernel_size[0] - 1
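        # A causal temporal conv needs kernel_size[0] - 1 past frames of left
        # context at each shard boundary.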
        if cp_rank == 0:
            # F.pad calls zero-padding "constant"; other Conv3d padding modes
            # map through unchanged.
            mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
            x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode)

        if cp_world_size == 1:
            return super().forward(x)
        # print(f"ContextParallelCausalConv3d: cp_rank: {cp_rank}, cp_world_size: {cp_world_size}, x.shape: {x.shape}, context_size: {context_size},self.kernel_size: {self.kernel_size}, self.stride: {self.stride}") # (MultiGPUContext pid=3061095) ContextParallelConv3d: cp_rank: 0, cp_world_size: 8, x.shape: torch.Size([1, 128, 22, 240, 424]), context_size: 1,self.kernel_size: (2, 2, 2), self.stride: (2, 2, 2)
        # not used

        if all(s == 1 for s in self.stride):
            # Receive some frames from previous rank.
            x = cp_pass_frames(x, context_size)
            return super().forward(x)

        # Less efficient fallback for strided convs: all-gather the full
        # sequence, run the conv once, and keep only this rank's chunk.
        x = gather_all_frames(x)  # [B, C, k - 1 + global_T, H, W]
        x = super().forward(x)
        x_chunks = x.tensor_split(cp_world_size, dim=2)
        assert len(x_chunks) == cp_world_size
        return x_chunks[cp_rank]
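

# Minimal single-process usage sketch (illustrative; assumes the context-parallel
# state in `cp` reports rank 0 / world size 1 when no process group is active):
#
#   conv = ContextParallelCausalConv3d(
#       in_channels=8, out_channels=8, kernel_size=3, stride=1)
#   x = torch.randn(1, 8, 16, 32, 32)      # (B, C, T, H, W)
#   y = conv(x)                            # rank 0 left-pads T by kernel_size[0] - 1
#   assert y.shape == (1, 8, 16, 32, 32)   # all dims preserved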