import torch
import torch.distributed as dist
from einops import rearrange

_CONTEXT_PARALLEL_GROUP = None
_CONTEXT_PARALLEL_RANK = None
_CONTEXT_PARALLEL_GROUP_SIZE = None
_CONTEXT_PARALLEL_GROUP_RANKS = None


def get_cp_rank_size():
    """Return (rank, group size) within the context-parallel group, or (0, 1) if CP is inactive."""
    if _CONTEXT_PARALLEL_GROUP:
        return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE
    else:
        return 0, 1


def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor:
    """Return this rank's shard of x, split along `dim`. No-op when CP is inactive."""
    if not _CONTEXT_PARALLEL_GROUP:
        return x

    cp_rank, cp_size = get_cp_rank_size()
    return x.tensor_split(cp_size, dim=dim)[cp_rank]
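

# Illustrative sketch (not called anywhere in this module): sharding a full
# sequence tensor across CP ranks. The [B, heads, tokens, head_dim] layout and
# the dim=2 token axis are assumptions matching local_shard's default.
def _example_local_shard():
    x = torch.randn(1, 24, 4096, 128)  # full sequence, replicated on every rank
    x_local = local_shard(x, dim=2)    # roughly [1, 24, 4096 // cp_size, 128] per rank
    return x_local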


def set_cp_group(cp_group, ranks, global_rank):
    """Register the context-parallel process group. Must be called at most once per process."""
    global _CONTEXT_PARALLEL_GROUP, _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE, _CONTEXT_PARALLEL_GROUP_RANKS
    if _CONTEXT_PARALLEL_GROUP is not None:
        raise RuntimeError("CP group already initialized.")
    _CONTEXT_PARALLEL_GROUP = cp_group
    _CONTEXT_PARALLEL_RANK = dist.get_rank(cp_group)
    _CONTEXT_PARALLEL_GROUP_SIZE = dist.get_world_size(cp_group)
    _CONTEXT_PARALLEL_GROUP_RANKS = ranks

    assert _CONTEXT_PARALLEL_RANK == ranks.index(
        global_rank
    ), f"Rank mismatch: {global_rank} in {ranks} does not have position {_CONTEXT_PARALLEL_RANK} "
    assert _CONTEXT_PARALLEL_GROUP_SIZE == len(
        ranks
    ), f"Group size mismatch: {_CONTEXT_PARALLEL_GROUP_SIZE} != len({ranks})"
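

# Illustrative sketch (not called anywhere in this module): one-time CP setup
# a launcher might perform. Treating the entire world as a single
# context-parallel group is an assumption; real setups may carve the world
# into several smaller groups.
def _example_init_cp():
    dist.init_process_group(backend="nccl")
    ranks = list(range(dist.get_world_size()))
    set_cp_group(dist.new_group(ranks), ranks, dist.get_rank())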


def get_cp_group():
    if _CONTEXT_PARALLEL_GROUP is None:
        raise RuntimeError("CP group not initialized")
    return _CONTEXT_PARALLEL_GROUP


def is_cp_active():
    return _CONTEXT_PARALLEL_GROUP is not None


class AllGatherIntoTensorFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup):
        ctx.reduce_dtype = reduce_dtype
        ctx.group = group
        ctx.batch_size = x.size(0)
        group_size = _CONTEXT_PARALLEL_GROUP_SIZE

        x = x.contiguous()
        output = torch.empty(group_size * x.size(0), *x.shape[1:], dtype=x.dtype, device=x.device)
        dist.all_gather_into_tensor(output, x, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        group_size = _CONTEXT_PARALLEL_GROUP_SIZE
        rank = _CONTEXT_PARALLEL_RANK

        # Split the gradient along dim 0 (the context-parallel dimension) and
        # return the chunk corresponding to this rank's input shard.
        grad_chunks = grad_output.chunk(group_size, dim=0)
        return grad_chunks[rank].contiguous(), None, None


def all_gather(tensor: torch.Tensor, group=None) -> torch.Tensor:
    """Gather shards from every CP rank, concatenating along dim 0. No-op when CP is inactive."""
    if not _CONTEXT_PARALLEL_GROUP:
        return tensor

    # Use the provided group or default to _CONTEXT_PARALLEL_GROUP
    group = group if group is not None else _CONTEXT_PARALLEL_GROUP
    return AllGatherIntoTensorFunction.apply(tensor, torch.float32, group)
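

# Illustrative sketch (not called anywhere in this module): all_gather as the
# inverse of local_shard. Note it concatenates along dim 0, so the shard is
# taken along dim 0 here; the shape is hypothetical and assumed divisible by
# the CP group size.
def _example_shard_then_gather():
    tokens = torch.randn(4096, 3072)           # [tokens, hidden], replicated on every rank
    tokens_local = local_shard(tokens, dim=0)  # [tokens // cp_size, hidden] per rank
    tokens_full = all_gather(tokens_local)     # [tokens, hidden] again, on every rank
    return tokens_full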


@torch.compiler.disable()
def _all_to_all_single(output, input, group):
    # Disable compilation since torch compile changes contiguity.
    assert input.is_contiguous(), "Input tensor must be contiguous."
    assert output.is_contiguous(), "Output tensor must be contiguous."
    return dist.all_to_all_single(output, input, group=group)


class CollectTokens(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int):
        """Redistribute heads and receive tokens.

        Args:
            qkv: Concatenated query, key, and value projections. Shape: [B, M, 3 * num_heads * head_dim]
            group: The process group to use for communication.
            num_heads: Total number of attention heads.

        Returns:
            qkv of shape [3, B, N, local_heads, head_dim],
            where M is the number of local tokens,
            N = cp_size * M is the number of global tokens, and
            local_heads = num_heads // cp_size is the number of local heads.
        """
        ctx.group = group
        ctx.num_heads = num_heads
        cp_size = _CONTEXT_PARALLEL_GROUP_SIZE
        assert num_heads % cp_size == 0
        ctx.local_heads = num_heads // cp_size

        qkv = rearrange(
            qkv,
            "B M (qkv G h d) -> G M h B (qkv d)",
            qkv=3,
            G=cp_size,
            h=ctx.local_heads,
        ).contiguous()

        output_chunks = torch.empty_like(qkv)
        _all_to_all_single(output_chunks, qkv, group=group)

        return rearrange(output_chunks, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3)

    @staticmethod
    def backward(ctx, grad_output):
        group = ctx.group
        cp_size = _CONTEXT_PARALLEL_GROUP_SIZE
        local_heads = ctx.local_heads

        # Reverse the final rearrange
        grad_output = rearrange(grad_output, "qkv B (G M) h d -> G M h B (qkv d)", G=cp_size).contiguous()
        
        # Reverse the all-to-all
        grad_input = torch.empty_like(grad_output)
        _all_to_all_single(grad_input, grad_output, group=group)
        
        # Reverse the initial rearrange
        return rearrange(grad_input, "G M h B (qkv d) -> B M (qkv G h d)", 
                        qkv=3, G=cp_size, h=local_heads), None, None


def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int, group=None) -> torch.Tensor:
    """Turn [B, M, 3 * num_heads * head_dim] QKV into [3, B, N, local_heads, head_dim]."""
    if not _CONTEXT_PARALLEL_GROUP:
        # Move QKV dimension to the front.
        #   B M (3 H d) -> 3 B M H d
        B, M, _ = x.size()
        x = x.view(B, M, 3, num_heads, -1)
        return x.permute(2, 0, 1, 3, 4)

    # Use the provided group or default to _CONTEXT_PARALLEL_GROUP
    group = group if group is not None else _CONTEXT_PARALLEL_GROUP
    return CollectTokens.apply(x, group, num_heads)
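

# Illustrative sketch (not called anywhere in this module): redistributing the
# fused QKV projection so every rank sees all tokens but only its slice of the
# heads. `x_local`, `qkv_proj`, and `num_heads` are hypothetical stand-ins for
# the surrounding model code.
def _example_collect_tokens(x_local: torch.Tensor, qkv_proj: torch.nn.Module, num_heads: int):
    qkv = qkv_proj(x_local)                              # [B, M, 3 * num_heads * head_dim]
    q, k, v = all_to_all_collect_tokens(qkv, num_heads)  # each [B, N, local_heads, head_dim], N = cp_size * M
    return q, k, v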


class CollectHeads(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup):
        """Redistribute tokens and receive heads.

        Args:
            x: Output of attention. Shape: [B, N, local_heads, head_dim]
            group: The process group to use for communication.

        Returns:
            Shape: [B, M, num_heads * head_dim]
        """
        ctx.group = group
        ctx.local_heads = x.size(2)
        ctx.head_dim = x.size(3)
        group_size = _CONTEXT_PARALLEL_GROUP_SIZE
        x = rearrange(x, "B (G M) h D -> G h M B D", G=group_size).contiguous()
        output = torch.empty_like(x)
        _all_to_all_single(output, x, group=group)
        del x
        return rearrange(output, "G h M B D -> B M (G h D)")

    @staticmethod
    def backward(ctx, grad_output):
        group = ctx.group
        group_size = _CONTEXT_PARALLEL_GROUP_SIZE
        local_heads = ctx.local_heads
        head_dim = ctx.head_dim

        # Reverse the final rearrange
        grad_output = rearrange(
            grad_output, "B M (G h D) -> G h M B D", G=group_size, h=local_heads, D=head_dim
        ).contiguous()

        # Reverse the all-to-all
        grad_input = torch.empty_like(grad_output)
        _all_to_all_single(grad_input, grad_output, group=group)

        # Reverse the initial rearrange
        return rearrange(grad_input, "G h M B D -> B (G M) h D", G=group_size), None


def all_to_all_collect_heads(x: torch.Tensor, group=None) -> torch.Tensor:
    """Turn [B, N, local_heads, head_dim] attention output into [B, M, num_heads * head_dim]."""
    if not _CONTEXT_PARALLEL_GROUP:
        # Merge heads.
        return x.view(x.size(0), x.size(1), x.size(2) * x.size(3))

    # Use the provided group or default to _CONTEXT_PARALLEL_GROUP
    group = group if group is not None else _CONTEXT_PARALLEL_GROUP
    return CollectHeads.apply(x, group)
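

# Illustrative sketch (not called anywhere in this module): a full
# context-parallel attention step built from the helpers above. `qkv_proj`,
# `out_proj`, and `num_heads` are hypothetical stand-ins for the surrounding
# model; plain scaled_dot_product_attention runs over the gathered tokens and
# this rank's slice of the heads.
def _example_cp_attention(x_local, qkv_proj, out_proj, num_heads):
    qkv = qkv_proj(x_local)                              # [B, M, 3 * num_heads * head_dim]
    q, k, v = all_to_all_collect_tokens(qkv, num_heads)  # each [B, N, local_heads, head_dim]
    out = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)                                    # [B, N, local_heads, head_dim]
    out = all_to_all_collect_heads(out)                  # [B, M, num_heads * head_dim]
    return out_proj(out)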