File size: 7,364 Bytes
89e2950 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os
import pytest
from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights
def run_distributed_shared_expert_test(rank, world_size):
    """Per-rank worker: attach shared-expert weights + biases (GELU, weighted sum).

    Launched through ``mp.spawn`` (which supplies ``rank`` as the first
    argument) by ``test_shared_expert_distributed_functionality``. Each rank
    joins a gloo process group, builds a ``MegaBlocksMoeMLPWithSharedExpert``,
    attaches freshly created shared up/down projection weights and biases with
    ``weighted_sum=True`` and a GELU activation, and asserts that the
    configuration was stored on the model.

    Args:
        rank: Index of this process within the group.
        world_size: Total number of processes in the group.
    """
    # Rendezvous settings for the default env:// init method.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12356"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )
    # Tear the group down even when an assertion below fails, so a failing
    # rank does not leave the process group (and its store/port) behind.
    try:
        model = MegaBlocksMoeMLPWithSharedExpert()
        hidden_size = 128
        shared_expert_hidden_size = 192
        device = "cuda" if torch.cuda.is_available() else "cpu"

        def simple_init(tensor):
            torch.nn.init.xavier_uniform_(tensor)

        (
            shared_up_proj_weight,
            shared_down_proj_weight,
            shared_up_proj_bias,
            shared_down_proj_bias,
        ) = create_shared_expert_weights(
            hidden_size=hidden_size,
            shared_expert_hidden_size=shared_expert_hidden_size,
            device=torch.device(device),
            dtype=torch.float32,
            init_method=simple_init,
        )
        model.set_shared_expert_weights(
            up_proj_weight=shared_up_proj_weight,
            down_proj_weight=shared_down_proj_weight,
            up_proj_bias=shared_up_proj_bias,
            down_proj_bias=shared_down_proj_bias,
            weighted_sum=True,
            activation_fn=torch.nn.functional.gelu,
        )
        assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
        assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
        assert model.shared_expert_weighted_sum == True, f"Weighted sum not set correctly on rank {rank}"
        print(f"Rank {rank}: Shared expert setup test passed!")
    finally:
        dist.destroy_process_group()
def run_distributed_shared_expert_weighted_sum_test(rank, world_size):
    """Per-rank worker: attach bias-free shared-expert weights (ReLU, no weighted sum).

    Launched through ``mp.spawn`` (which supplies ``rank`` as the first
    argument) by ``test_shared_expert_distributed_weighted_sum``. Each rank
    joins a gloo process group, builds a ``MegaBlocksMoeMLPWithSharedExpert``,
    attaches only the shared up/down projection weights (the bias tensors
    returned by ``create_shared_expert_weights`` are discarded) with
    ``weighted_sum=False`` and a ReLU activation, and asserts that the
    configuration was stored on the model.

    Args:
        rank: Index of this process within the group.
        world_size: Total number of processes in the group.
    """
    # Rendezvous settings for the default env:// init method; the port differs
    # from the other worker's so the two tests do not collide.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12357"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )
    # Tear the group down even when an assertion below fails, so a failing
    # rank does not leave the process group (and its store/port) behind.
    try:
        model = MegaBlocksMoeMLPWithSharedExpert()
        hidden_size = 64
        device = "cuda" if torch.cuda.is_available() else "cpu"

        def simple_init(tensor):
            torch.nn.init.xavier_uniform_(tensor)

        shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
            hidden_size=hidden_size,
            shared_expert_hidden_size=96,
            device=torch.device(device),
            dtype=torch.float32,
            init_method=simple_init,
        )
        model.set_shared_expert_weights(
            up_proj_weight=shared_up_proj_weight,
            down_proj_weight=shared_down_proj_weight,
            weighted_sum=False,
            activation_fn=torch.nn.functional.relu,
        )
        assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
        assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
        assert model.shared_expert_weighted_sum == False, f"Weighted sum not set correctly on rank {rank}"
        assert model.shared_activation_fn == torch.nn.functional.relu, f"Activation function not set correctly on rank {rank}"
        print(f"Rank {rank}: Weighted sum setup test passed!")
    finally:
        dist.destroy_process_group()
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
def test_shared_expert_distributed_functionality(world_size):
    """Verify shared-expert setup (weights + biases, GELU, weighted sum).

    For any ``world_size`` other than 1, spawns that many
    ``run_distributed_shared_expert_test`` workers and waits for them;
    a worker failure propagates out of ``mp.spawn``. For ``world_size == 1``
    the same checks run directly in this process without a process group.
    """
    if world_size != 1:
        mp.spawn(
            run_distributed_shared_expert_test,
            args=(world_size,),
            nprocs=world_size,
            join=True,
        )
        print("Multi-process shared expert test completed successfully!")
        return

    # Single-process path: mirror the per-rank worker's checks inline.
    moe = MegaBlocksMoeMLPWithSharedExpert()
    target = "cpu"
    if torch.cuda.is_available():
        target = "cuda"

    def xavier(weight):
        torch.nn.init.xavier_uniform_(weight)

    generated = create_shared_expert_weights(
        hidden_size=128,
        shared_expert_hidden_size=192,
        device=torch.device(target),
        dtype=torch.float32,
        init_method=xavier,
    )
    up_w, down_w, up_b, down_b = generated

    moe.set_shared_expert_weights(
        up_proj_weight=up_w,
        down_proj_weight=down_w,
        up_proj_bias=up_b,
        down_proj_bias=down_b,
        weighted_sum=True,
        activation_fn=torch.nn.functional.gelu,
    )

    assert moe.shared_up_proj_weight is not None, "Shared up proj weight not set"
    assert moe.shared_down_proj_weight is not None, "Shared down proj weight not set"
    assert moe.shared_expert_weighted_sum == True, "Weighted sum not set correctly"
    print("Single process shared expert setup test passed!")
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
def test_shared_expert_distributed_weighted_sum(world_size):
    """Verify bias-free shared-expert setup with ReLU and ``weighted_sum=False``.

    For any ``world_size`` other than 1, spawns that many
    ``run_distributed_shared_expert_weighted_sum_test`` workers and waits for
    them; a worker failure propagates out of ``mp.spawn``. For
    ``world_size == 1`` the same checks run directly in this process.
    """
    if world_size != 1:
        mp.spawn(
            run_distributed_shared_expert_weighted_sum_test,
            args=(world_size,),
            nprocs=world_size,
            join=True,
        )
        print("Multi-process shared expert weighted sum test completed successfully!")
        return

    # Single-process path: mirror the per-rank worker's checks inline.
    moe = MegaBlocksMoeMLPWithSharedExpert()
    target = "cpu"
    if torch.cuda.is_available():
        target = "cuda"

    def xavier(weight):
        torch.nn.init.xavier_uniform_(weight)

    generated = create_shared_expert_weights(
        hidden_size=64,
        shared_expert_hidden_size=96,
        device=torch.device(target),
        dtype=torch.float32,
        init_method=xavier,
    )
    # Only the projection weights are used; the two bias tensors are ignored.
    up_w, down_w, _up_b, _down_b = generated

    moe.set_shared_expert_weights(
        up_proj_weight=up_w,
        down_proj_weight=down_w,
        weighted_sum=False,
        activation_fn=torch.nn.functional.relu,
    )

    assert moe.shared_up_proj_weight is not None, "Shared up proj weight not set"
    assert moe.shared_down_proj_weight is not None, "Shared down proj weight not set"
    assert moe.shared_expert_weighted_sum == False, "Weighted sum not set correctly"
    assert moe.shared_activation_fn == torch.nn.functional.relu, "Activation function not set correctly"
    print("Single process weighted sum setup test passed!")
def test_shared_expert_single_process():
    """Smoke-test a freshly constructed, not-yet-configured shared-expert layer.

    A new ``MegaBlocksMoeMLPWithSharedExpert`` must start with both shared
    projection weights unset and must expose ``set_shared_expert_weights``.
    """
    fresh = MegaBlocksMoeMLPWithSharedExpert()
    # Nothing has been attached yet, so both projections should be empty.
    for attr_name in ("shared_up_proj_weight", "shared_down_proj_weight"):
        assert getattr(fresh, attr_name) is None
    assert hasattr(fresh, "set_shared_expert_weights")
    print("Single process shared expert basic test passed!")
if __name__ == "__main__":
    # Script entry point for running the tests without pytest. The two
    # distributed tests are pytest-parametrized over ``world_size`` and so
    # require it explicitly; calling them with no argument raises TypeError.
    # Run them with two processes, matching the WORLD_SIZE exported below.
    world_size = 2
    test_shared_expert_single_process()
    print("Single process test passed!")
    os.environ['WORLD_SIZE'] = str(world_size)
    test_shared_expert_distributed_functionality(world_size)
    print("Distributed functionality test passed!")
    test_shared_expert_distributed_weighted_sum(world_size)
    print("Distributed weighted sum test passed!")