# kernel
# megablocks/tests/test_mb_moe_shared_expert.py
# drbh
# feat: support shared experts layer and tests
# 89e2950
import torch
import megablocks
from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights
def test_megablocks_moe_mlp_with_shared_expert_import():
    """Check that the shared-expert MLP builds with no arguments and exposes its shared-expert API."""
    layer = MegaBlocksMoeMLPWithSharedExpert()
    for attr_name in (
        'shared_up_proj_weight',
        'shared_down_proj_weight',
        'set_shared_expert_weights',
    ):
        assert hasattr(layer, attr_name)
def test_set_shared_expert_weights():
    """Check that set_shared_expert_weights stores every tensor and option it is given.

    The up projection weight is shaped (shared_expert_hidden_size, hidden_size)
    and the down projection weight (hidden_size, shared_expert_hidden_size),
    i.e. the (out_features, in_features) layout used by torch.nn.functional.linear.
    """
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32

    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device, dtype=dtype)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device, dtype=dtype)
    up_proj_bias = torch.randn(shared_expert_hidden_size, device=device, dtype=dtype)
    down_proj_bias = torch.randn(hidden_size, device=device, dtype=dtype)

    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight,
        up_proj_bias=up_proj_bias,
        down_proj_bias=down_proj_bias,
        weighted_sum=True,
        activation_fn=torch.nn.functional.gelu,
    )

    assert torch.equal(mlp.shared_up_proj_weight, up_proj_weight)
    assert torch.equal(mlp.shared_down_proj_weight, down_proj_weight)
    assert torch.equal(mlp.shared_up_proj_bias, up_proj_bias)
    assert torch.equal(mlp.shared_down_proj_bias, down_proj_bias)
    # Identity checks: `== True` would also pass for 1 / 1.0, and the stored
    # activation should be exactly the callable that was passed in.
    assert mlp.shared_expert_weighted_sum is True
    assert mlp.shared_activation_fn is torch.nn.functional.gelu
def test_create_shared_expert_weights():
    """Check the shape, device and dtype of the weights from create_shared_expert_weights.

    Also checks that no biases come back when only init_method is supplied.
    """
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32

    def xavier_init(tensor):
        torch.nn.init.xavier_uniform_(tensor)

    created = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=xavier_init,
    )
    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = created

    # Up maps hidden -> shared width; down maps it back.
    for weight, expected_shape in (
        (up_proj_weight, (shared_expert_hidden_size, hidden_size)),
        (down_proj_weight, (hidden_size, shared_expert_hidden_size)),
    ):
        assert weight.shape == expected_shape
        assert weight.device.type == device.type
        assert weight.dtype == dtype

    assert up_proj_bias is None
    assert down_proj_bias is None
def test_shared_expert_weights_none_by_default():
    """Check that a freshly constructed layer has no shared-expert weights, biases or activation.

    Weighted summation is also expected to be off by default.
    """
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    assert mlp.shared_up_proj_weight is None
    assert mlp.shared_down_proj_weight is None
    assert mlp.shared_up_proj_bias is None
    assert mlp.shared_down_proj_bias is None
    # Identity check rather than `== False`, which would also accept 0 / 0.0.
    assert mlp.shared_expert_weighted_sum is False
    assert mlp.shared_activation_fn is None
def test_inheritance_from_megablocks_moe_mlp():
    """Check that the shared-expert layer is a subclass of MegaBlocksMoeMLP and keeps a forward method."""
    from megablocks.layers import MegaBlocksMoeMLP

    shared_mlp = MegaBlocksMoeMLPWithSharedExpert()
    assert isinstance(shared_mlp, MegaBlocksMoeMLP)
    assert hasattr(shared_mlp, 'forward')
def test_shared_expert_weights_custom_init():
    """Check that init_method fills the up projection and output_layer_init_method fills the down projection.

    Both weights must come back in the requested dtype.
    """
    hidden_size = 64
    shared_expert_hidden_size = 128
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float16

    def fill_up(tensor):
        torch.nn.init.constant_(tensor, 0.5)

    def fill_down(tensor):
        torch.nn.init.constant_(tensor, 0.1)

    up_proj_weight, down_proj_weight, _, _ = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=fill_up,
        output_layer_init_method=fill_down,
    )

    # Comparing an fp16 tensor with a Python float computes in fp16, so 0.1
    # is rounded the same way as the stored values.
    assert (up_proj_weight == 0.5).all()
    assert (down_proj_weight == 0.1).all()
    assert up_proj_weight.dtype == dtype
    assert down_proj_weight.dtype == dtype
def test_shared_expert_weights_dimensions():
    """Check that stored shared-expert weights chain hidden -> shared -> hidden.

    The weights stored by set_shared_expert_weights (no biases) are applied
    with torch.nn.functional.linear to an input of shape
    (seq_len, batch_size, hidden_size). The test then checks both the
    intermediate and the output shapes.
    """
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    batch_size = 4
    seq_len = 16
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device)
    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight,
    )

    x = torch.randn(seq_len, batch_size, hidden_size, device=device)
    expected_up_output_shape = (seq_len, batch_size, shared_expert_hidden_size)
    expected_down_output_shape = (seq_len, batch_size, hidden_size)

    # Weights use the (out_features, in_features) layout expected by F.linear.
    assert up_proj_weight.shape[1] == x.shape[-1]
    assert down_proj_weight.shape[0] == x.shape[-1]

    # Project through the weights the layer actually stored, so the expected
    # shapes above are checked instead of being computed and discarded.
    up_out = torch.nn.functional.linear(x, mlp.shared_up_proj_weight)
    assert up_out.shape == expected_up_output_shape
    down_out = torch.nn.functional.linear(up_out, mlp.shared_down_proj_weight)
    assert down_out.shape == expected_down_output_shape