import torch
import megablocks
from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights


def test_megablocks_moe_mlp_with_shared_expert_import():
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    assert hasattr(mlp, 'shared_up_proj_weight')
    assert hasattr(mlp, 'shared_down_proj_weight')
    assert hasattr(mlp, 'set_shared_expert_weights')


def test_set_shared_expert_weights():
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32

    # Weights use the (out_features, in_features) layout.
    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device, dtype=dtype)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device, dtype=dtype)
    up_proj_bias = torch.randn(shared_expert_hidden_size, device=device, dtype=dtype)
    down_proj_bias = torch.randn(hidden_size, device=device, dtype=dtype)

    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight,
        up_proj_bias=up_proj_bias,
        down_proj_bias=down_proj_bias,
        weighted_sum=True,
        activation_fn=torch.nn.functional.gelu,
    )

    assert torch.equal(mlp.shared_up_proj_weight, up_proj_weight)
    assert torch.equal(mlp.shared_down_proj_weight, down_proj_weight)
    assert torch.equal(mlp.shared_up_proj_bias, up_proj_bias)
    assert torch.equal(mlp.shared_down_proj_bias, down_proj_bias)
    assert mlp.shared_expert_weighted_sum == True
    assert mlp.shared_activation_fn == torch.nn.functional.gelu


def test_create_shared_expert_weights():
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32

    def init_method(tensor):
        torch.nn.init.xavier_uniform_(tensor)

    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=init_method,
    )

    assert up_proj_weight.shape == (shared_expert_hidden_size, hidden_size)
    assert down_proj_weight.shape == (hidden_size, shared_expert_hidden_size)
    assert up_proj_weight.device.type == device.type
    assert down_proj_weight.device.type == device.type
    assert up_proj_weight.dtype == dtype
    assert down_proj_weight.dtype == dtype
    assert up_proj_bias is None
    assert down_proj_bias is None


def test_shared_expert_weights_none_by_default():
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    assert mlp.shared_up_proj_weight is None
    assert mlp.shared_down_proj_weight is None
    assert mlp.shared_up_proj_bias is None
    assert mlp.shared_down_proj_bias is None
    assert mlp.shared_expert_weighted_sum == False
    assert mlp.shared_activation_fn is None


def test_inheritance_from_megablocks_moe_mlp():
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    from megablocks.layers import MegaBlocksMoeMLP
    assert isinstance(mlp, MegaBlocksMoeMLP)
    assert hasattr(mlp, 'forward')


def test_shared_expert_weights_custom_init():
    hidden_size = 64
    shared_expert_hidden_size = 128
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float16

    def custom_init(tensor):
        torch.nn.init.constant_(tensor, 0.5)

    def custom_output_init(tensor):
        torch.nn.init.constant_(tensor, 0.1)

    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=custom_init,
        output_layer_init_method=custom_output_init,
    )

    assert torch.all(up_proj_weight == 0.5)
    assert torch.all(down_proj_weight == 0.1)
    assert up_proj_weight.dtype == dtype
    assert down_proj_weight.dtype == dtype


def test_shared_expert_weights_dimensions():
    mlp = MegaBlocksMoeMLPWithSharedExpert()
    batch_size = 4
    seq_len = 16
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device)

    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight,
    )

    x = torch.randn(seq_len, batch_size, hidden_size, device=device)
    # Output shapes the shared expert would produce for x (forward is not invoked here).
    expected_up_output_shape = (seq_len, batch_size, shared_expert_hidden_size)
    expected_down_output_shape = (seq_len, batch_size, hidden_size)

    # With the (out_features, in_features) weight layout, the input's last dim must
    # match up_proj_weight's second dim and down_proj_weight's first dim.
    assert up_proj_weight.shape[1] == x.shape[-1]
    assert down_proj_weight.shape[0] == x.shape[-1]
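

# A minimal sketch, not part of the original test suite: it assumes the shared expert
# applies a standard up-projection, activation, and down-projection, and checks the
# expected output shapes directly with torch.nn.functional.linear instead of calling
# MegaBlocksMoeMLPWithSharedExpert.forward (which also needs router/expert setup).
def test_shared_expert_projection_shapes_sketch():
    seq_len, batch_size = 16, 4
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device)
    x = torch.randn(seq_len, batch_size, hidden_size, device=device)

    # F.linear computes x @ weight.T, matching the (out_features, in_features)
    # weight layout used throughout these tests.
    up = torch.nn.functional.linear(x, up_proj_weight)
    down = torch.nn.functional.linear(torch.nn.functional.gelu(up), down_proj_weight)

    assert up.shape == (seq_len, batch_size, shared_expert_hidden_size)
    assert down.shape == (seq_len, batch_size, hidden_size)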