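"""Tests for megablocks' Triton binned_gather op.

The forward gather is checked against a pure-PyTorch reference, and the
backward pass against a reference binned_scatter, with expert_capacity
stressed near the CUDA grid-dimension limit of 65535.
"""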
import torch
import pytest

from megablocks.ops.binned_gather import BinnedGatherOp

binned_gather_triton = BinnedGatherOp.apply
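
# Expected shape contract for the op under test (a sketch inferred from the
# reference implementations below, not an authoritative spec):
#   x:       (seq_len, hidden_size)          fp16 token activations
#   indices: (seq_len * top_k,)              assignments sorted by expert
#   bins:    (num_experts,) int32            inclusive cumulative counts
#   returns: (num_experts, expert_capacity, hidden_size), zero-padded slots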

def set_seeds(seed=0):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Stress-test expert_capacity, especially near and above the upper limit of
# 65535 (CUDA's maximum size for the second and third grid dimensions).
def make_stress_expert_capacity_tests():
    tests = []
    # Small cases for sanity
    for seq_len, hidden_size, num_experts, top_k in [
        (4, 2, 2, 1),
        (4, 2, 2, 2),
        (4, 2, 2, 4),
    ]:
        for expert_capacity in [1, 2, 4]:
            tests.append((seq_len, hidden_size, num_experts, top_k, expert_capacity))
    # Medium cases
    for seq_len, hidden_size, num_experts, top_k in [
        (1024, 1536, 4, 1),
        (1024, 1536, 4, 2),
        (1024, 1536, 4, 4),
        (1024, 1536, 64, 1),
        (1024, 1536, 64, 2),
        (1024, 1536, 64, 4),
        (1024, 1536, 128, 1),
        (1024, 1536, 128, 2),
        (1024, 1536, 128, 4),
    ]:
        for expert_capacity in [1, 2, 4, 128, 1024]:
            tests.append((seq_len, hidden_size, num_experts, top_k, expert_capacity))

    # Large cases: stress expert_capacity at and beyond 65535, the CUDA limit
    # on the second grid dimension.
    for seq_len, hidden_size, num_experts, top_k in [
        (4096, 768, 32, 4),
    ]:
        for expert_capacity in [65535, 70000, 90000]:
            tests.append((seq_len, hidden_size, num_experts, top_k, expert_capacity))

    return tuple(tests)

BINNED_GATHER_TESTS = make_stress_expert_capacity_tests()
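
# To run a single case locally, select it by its parametrize id (pytest
# derives ids from the tuple values; the file name below is a placeholder):
#   pytest test_binned_gather.py -k "4-2-2-1-1" -x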

@pytest.mark.parametrize(('seq_len', 'hidden_size', 'num_experts', 'top_k', 'expert_capacity'), BINNED_GATHER_TESTS)
def test_binned_gather(seq_len: int, hidden_size: int, num_experts: int, top_k: int, expert_capacity: int):
    # NOTE: expert_capacity is parametrized directly rather than derived from
    # a capacity factor, so tokens past capacity are deliberately dropped.
    set_seeds(42)
    # Create the input with gradient tracking. Use randn rather than arange:
    # for the larger shapes, seq_len * hidden_size exceeds the fp16 maximum
    # (65504), so an arange in half precision would overflow to inf.
    x = torch.randn(seq_len, hidden_size, device='cuda', dtype=torch.half)
    x.requires_grad_(True)

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, num_experts, (seq_len * top_k,), device='cuda', dtype=torch.int)
    _, indices = torch.sort(top_expert)
    bins = torch.cumsum(torch.bincount(top_expert, minlength=num_experts), 0).to(torch.int32)
    # Example: if the per-expert counts are [12, 2, 3], bins is [12, 14, 17].
    # This tells the gather function:
    #   Expert 0's assignments are in indices[0:12].
    #   Expert 1's assignments are in indices[12:14].
    #   Expert 2's assignments are in indices[14:17].
    # The last boundary equals the total number of assignments, seq_len * top_k.
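    # Sanity check that follows from the cumsum construction above: the last
    # bin boundary must equal the total number of assignments.
    assert int(bins[-1]) == seq_len * top_k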

    def binned_gather_pytorch(
        x: torch.Tensor,
        indices: torch.Tensor,
        bins: torch.Tensor,
        expert_capacity: int,
        top_k: int,
    ):
        start = 0
        out = torch.zeros((num_experts, expert_capacity, hidden_size), dtype=x.dtype, device=x.device)
        for i in range(num_experts):
            end = int(bins[i])
            num_tokens = min(expert_capacity, end - start)
            if num_tokens > 0:
                # indices[start:end] hold this expert's assignments. Each
                # assignment maps back to source token index // top_k, since
                # every token appears top_k times in the flattened routing.
                idx = indices[start : start + num_tokens] // top_k
                out[i, :num_tokens, :] = x[idx, :]
            start = end
        return out

    out = binned_gather_triton(x, indices, bins, expert_capacity, top_k)
    expected_out = binned_gather_pytorch(x, indices, bins, expert_capacity, top_k)
    assert torch.all(torch.eq(out, expected_out))

    # Test the backward pass. Use randn_like rather than arange: out.numel()
    # can run into the billions for the large cases, far past the fp16 maximum.
    grad_output = torch.randn_like(out)
    out.backward(grad_output)

    # Verify gradients were computed
    assert x.grad is not None, "Gradients should be computed for input x"
    assert x.grad.shape == x.shape, f"Gradient shape {x.grad.shape} should match input shape {x.shape}"
    
    # Reference implementation for backward pass (binned_scatter)
    def binned_scatter_pytorch(
        x: torch.Tensor,
        indices: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        top_k: int,
    ):
        # x: (ne, ec, hs)
        # indices: (sl * top_k,)
        # weights: (sl * top_k,)
        # bins: (ne,)
        # Output: (sl, hs)
        out = torch.zeros((seq_len, hidden_size), device=x.device, dtype=x.dtype)
        start = 0
        for i in range(num_experts):
            end = int(bins[i])
            num_tokens = min(expert_capacity, end - start)
            for j in range(num_tokens):
                index = indices[start + j]
                scale = weights[index] if weights is not None else 1.0
                token_pos = index // top_k
                
                out[token_pos, :] += scale * x[i, j, :]
            start = end
        return out
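
    # binned_scatter is the adjoint of binned_gather: every row the forward
    # pass copied into an (expert, slot) pair sends its upstream gradient back
    # to the token it came from, so scattering grad_output with unit weights
    # reproduces d(loss)/dx.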
    
    expected_grad = binned_scatter_pytorch(grad_output, indices, None, bins, top_k)

    # Use torch.allclose instead of exact equality: with top_k > 1 the kernel
    # may accumulate a token's contributions in a different order than the
    # reference loop, so bitwise fp16 equality is too strict.
    if not torch.allclose(x.grad, expected_grad, rtol=1e-3, atol=1e-3):
        # Distinguish a routing bug (right values, wrong positions) from a
        # numerical bug (wrong values entirely) in the failure message.
        grad_sorted = torch.sort(x.grad.flatten())[0]
        expected_sorted = torch.sort(expected_grad.flatten())[0]
        if torch.allclose(grad_sorted, expected_sorted, rtol=1e-3, atol=1e-3):
            pytest.fail("Gradients hold the right values in the wrong positions - likely a routing issue")
        pytest.fail("Gradients don't match the binned_scatter reference")