import torch
import megablocks
from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights


def test_megablocks_moe_mlp_with_shared_expert_import():
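    """The shared-expert MLP exposes shared projection attributes and a setter."""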
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    assert hasattr(mlp, 'shared_up_proj_weight')
    assert hasattr(mlp, 'shared_down_proj_weight')
    assert hasattr(mlp, 'set_shared_expert_weights')


def test_set_shared_expert_weights():
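    """set_shared_expert_weights stores the given weights, biases, and options verbatim."""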
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32
    
    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device, dtype=dtype)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device, dtype=dtype)
    up_proj_bias = torch.randn(shared_expert_hidden_size, device=device, dtype=dtype)
    down_proj_bias = torch.randn(hidden_size, device=device, dtype=dtype)
    
    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight,
        up_proj_bias=up_proj_bias,
        down_proj_bias=down_proj_bias,
        weighted_sum=True,
        activation_fn=torch.nn.functional.gelu
    )
    
    assert torch.equal(mlp.shared_up_proj_weight, up_proj_weight)
    assert torch.equal(mlp.shared_down_proj_weight, down_proj_weight)
    assert torch.equal(mlp.shared_up_proj_bias, up_proj_bias)
    assert torch.equal(mlp.shared_down_proj_bias, down_proj_bias)
    assert mlp.shared_expert_weighted_sum is True
    assert mlp.shared_activation_fn is torch.nn.functional.gelu


def test_create_shared_expert_weights():
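    """create_shared_expert_weights returns correctly shaped, placed, and typed weights and no biases."""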
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32
    
    def init_method(tensor):
        torch.nn.init.xavier_uniform_(tensor)
    
    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=init_method
    )
    
    assert up_proj_weight.shape == (shared_expert_hidden_size, hidden_size)
    assert down_proj_weight.shape == (hidden_size, shared_expert_hidden_size)
    assert up_proj_weight.device.type == device.type
    assert down_proj_weight.device.type == device.type
    assert up_proj_weight.dtype == dtype
    assert down_proj_weight.dtype == dtype
    assert up_proj_bias is None
    assert down_proj_bias is None


def test_shared_expert_weights_none_by_default():
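    """A freshly constructed module has no shared-expert weights configured."""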
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    
    assert mlp.shared_up_proj_weight is None
    assert mlp.shared_down_proj_weight is None
    assert mlp.shared_up_proj_bias is None
    assert mlp.shared_down_proj_bias is None
    assert mlp.shared_expert_weighted_sum is False
    assert mlp.shared_activation_fn is None


def test_inheritance_from_megablocks_moe_mlp():
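    """The shared-expert variant subclasses MegaBlocksMoeMLP and keeps its forward method."""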
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    
    from megablocks.layers import MegaBlocksMoeMLP
    assert isinstance(mlp, MegaBlocksMoeMLP)
    assert hasattr(mlp, 'forward')


def test_shared_expert_weights_custom_init():
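    """Separate init methods are applied to the up and down projection weights."""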
    hidden_size = 64
    shared_expert_hidden_size = 128
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float16
    
    def custom_init(tensor):
        torch.nn.init.constant_(tensor, 0.5)
    
    def custom_output_init(tensor):
        torch.nn.init.constant_(tensor, 0.1)
    
    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=custom_init,
        output_layer_init_method=custom_output_init
    )
    
    assert torch.all(up_proj_weight == 0.5)
    assert torch.all(down_proj_weight == 0.1)
    assert up_proj_weight.dtype == dtype
    assert down_proj_weight.dtype == dtype


def test_shared_expert_weights_dimensions():
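    """Shared-expert projection weights are dimensionally compatible with the input."""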
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    
    batch_size = 4
    seq_len = 16
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device)
    
    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight
    )
    
    x = torch.randn(seq_len, batch_size, hidden_size, device=device)

    # Shapes the shared-expert up/down projections would produce for this input.
    expected_up_output_shape = (seq_len, batch_size, shared_expert_hidden_size)
    expected_down_output_shape = (seq_len, batch_size, hidden_size)

    # The projection weights must line up with the input features and with the
    # expected intermediate/output feature sizes.
    assert up_proj_weight.shape[1] == x.shape[-1]
    assert down_proj_weight.shape[0] == x.shape[-1]
    assert up_proj_weight.shape[0] == expected_up_output_shape[-1]
    assert down_proj_weight.shape[0] == expected_down_output_shape[-1]
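

# A minimal sketch (not part of the original suite) that wires the two helpers
# exercised above together: weights produced by create_shared_expert_weights
# are attached to the module via set_shared_expert_weights. It only uses call
# signatures already exercised in this file and assumes the setter accepts the
# None biases returned by the helper; no forward pass is run, since that would
# require a full MoE configuration not constructed here.
def test_set_weights_from_create_helper():
    mlp = MegaBlocksMoeMLPWithSharedExpert()

    hidden_size = 64
    shared_expert_hidden_size = 128
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32

    def init_method(tensor):
        torch.nn.init.xavier_uniform_(tensor)

    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=init_method
    )

    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight,
        up_proj_bias=up_proj_bias,
        down_proj_bias=down_proj_bias
    )

    # The stored tensors should be exactly the ones produced by the helper,
    # and the (None) biases should pass through unchanged.
    assert torch.equal(mlp.shared_up_proj_weight, up_proj_weight)
    assert torch.equal(mlp.shared_down_proj_weight, down_proj_weight)
    assert mlp.shared_up_proj_bias is None
    assert mlp.shared_down_proj_bias is None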