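"""Tests for MegaBlocksMoeMLPWithSharedExpert and create_shared_expert_weights.

Covers the default (unset) shared-expert state, weight assignment via
set_shared_expert_weights, the weight-factory helper with default and custom
initializers, inheritance from MegaBlocksMoeMLP, and weight/input dimension
compatibility.
"""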
import torch
import megablocks
from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights


def test_megablocks_moe_mlp_with_shared_expert_import():
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    assert hasattr(mlp, 'shared_up_proj_weight')
    assert hasattr(mlp, 'shared_down_proj_weight')
    assert hasattr(mlp, 'set_shared_expert_weights')


def test_set_shared_expert_weights():
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32
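    # Weights use the (out_features, in_features) layout seen below: the up
    # projection maps hidden_size -> shared_expert_hidden_size and the down
    # projection maps it back.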
    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device, dtype=dtype)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device, dtype=dtype)
    up_proj_bias = torch.randn(shared_expert_hidden_size, device=device, dtype=dtype)
    down_proj_bias = torch.randn(hidden_size, device=device, dtype=dtype)
    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight,
        up_proj_bias=up_proj_bias,
        down_proj_bias=down_proj_bias,
        weighted_sum=True,
        activation_fn=torch.nn.functional.gelu
    )
    assert torch.equal(mlp.shared_up_proj_weight, up_proj_weight)
    assert torch.equal(mlp.shared_down_proj_weight, down_proj_weight)
    assert torch.equal(mlp.shared_up_proj_bias, up_proj_bias)
    assert torch.equal(mlp.shared_down_proj_bias, down_proj_bias)
    assert mlp.shared_expert_weighted_sum == True
    assert mlp.shared_activation_fn == torch.nn.functional.gelu


def test_create_shared_expert_weights():
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32

    def init_method(tensor):
        torch.nn.init.xavier_uniform_(tensor)
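    # The factory is expected to return a 4-tuple of (up weight, down weight,
    # up bias, down bias); the bias entries should come back as None here,
    # as the asserts below check.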
    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=init_method
    )
    assert up_proj_weight.shape == (shared_expert_hidden_size, hidden_size)
    assert down_proj_weight.shape == (hidden_size, shared_expert_hidden_size)
    assert up_proj_weight.device.type == device.type
    assert down_proj_weight.device.type == device.type
    assert up_proj_weight.dtype == dtype
    assert down_proj_weight.dtype == dtype
    assert up_proj_bias is None
    assert down_proj_bias is None


def test_shared_expert_weights_none_by_default():
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    assert mlp.shared_up_proj_weight is None
    assert mlp.shared_down_proj_weight is None
    assert mlp.shared_up_proj_bias is None
    assert mlp.shared_down_proj_bias is None
    assert mlp.shared_expert_weighted_sum == False
    assert mlp.shared_activation_fn is None


def test_inheritance_from_megablocks_moe_mlp():
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    from megablocks.layers import MegaBlocksMoeMLP
    assert isinstance(mlp, MegaBlocksMoeMLP)
    assert hasattr(mlp, 'forward')


def test_shared_expert_weights_custom_init():
    hidden_size = 64
    shared_expert_hidden_size = 128
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float16

    def custom_init(tensor):
        torch.nn.init.constant_(tensor, 0.5)

    def custom_output_init(tensor):
        torch.nn.init.constant_(tensor, 0.1)
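    # init_method should initialize the up projection and
    # output_layer_init_method the down (output) projection, which is what
    # the value asserts below verify.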
    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=custom_init,
        output_layer_init_method=custom_output_init
    )
    assert torch.all(up_proj_weight == 0.5)
    assert torch.all(down_proj_weight == 0.1)
    assert up_proj_weight.dtype == dtype
    assert down_proj_weight.dtype == dtype


def test_shared_expert_weights_dimensions():
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    batch_size = 4
    seq_len = 16
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device)
    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight
    )
    x = torch.randn(seq_len, batch_size, hidden_size, device=device)
    expected_up_output_shape = (seq_len, batch_size, shared_expert_hidden_size)
    expected_down_output_shape = (seq_len, batch_size, hidden_size)
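    # Only dimension compatibility between the input and the shared-expert
    # weights is asserted; the forward pass itself is not exercised, so the
    # expected output shapes above serve as documentation of the intent.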
    assert up_proj_weight.shape[1] == x.shape[-1]
    assert down_proj_weight.shape[0] == x.shape[-1]