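"""Consistency test for the Muon optimizer under FSDP.

The same model is updated twice with identical synthetic gradients: once with
FSDP2-sharded parameters and once unsharded. After a single Muon step the two
sets of parameters are compared element-wise and must match exactly.

Launch with a multi-process launcher, e.g.
``torchrun --nproc_per_node=<num_gpus> <this script>``.
"""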
import logging

import torch
import torch.distributed as dist
from muon import Muon, get_default_muon_param_groups
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate
from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def load_model(fsdp: bool) -> torch.nn.Module:
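    """Load Motif-2.6B in bf16, optionally shard it with FSDP2, and attach
    seeded random gradients to every parameter."""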
    model_name = "Motif-Technologies/Motif-2.6B"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
    ).bfloat16().cuda()

    # Create one full-sized random gradient per parameter. The fixed seed makes
    # these gradients identical across ranks and across the FSDP / non-FSDP runs.
    torch.manual_seed(0)
    random_grads = []
    for param in model.parameters():
        random_grad = torch.randn_like(param,
                                       device=param.device,
                                       dtype=param.dtype)
        random_grads.append(random_grad)

    if fsdp:
        # Shard each transformer block and then the root module with FSDP2;
        # reshard immediately so the unsharded parameters are freed.
        for layer in model.model.layers:
            fully_shard(layer)
            layer.reshard()
        fully_shard(model)
        model.reshard()

    for i, param in enumerate(model.parameters()):
        if isinstance(param.data, DTensor):
            # Wrap the full gradient as a replicated DTensor, then redistribute
            # it to match the parameter's sharded placements.
            unsharded_grad = DTensor.from_local(
                random_grads[i],
                device_mesh=param.data.device_mesh,
                placements=[Replicate()] * param.data.device_mesh.ndim,
            )
            sharded_grad = unsharded_grad.redistribute(
                device_mesh=param.data.device_mesh,
                placements=param.data.placements)
            param.grad = sharded_grad
        else:
            param.grad = random_grads[i]

    return model


def run_muon(fsdp: bool) -> torch.nn.Module:
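    """Build the model (sharded or not), run one Muon step, and return it."""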
    model = load_model(fsdp=fsdp)
    params = get_default_muon_param_groups(model)
    optim = Muon(params=params)
    optim.step()

    return model


def compare_results(parallel_muon_result: torch.nn.Module,
                    sequential_muon_result: torch.nn.Module) -> None:
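    """Compare the FSDP-updated parameters against the unsharded reference."""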
    for (name_p, p), (_, s) in zip(parallel_muon_result.named_parameters(),
                                   sequential_muon_result.named_parameters()):
        # Gather the sharded parameter into a full tensor before comparing.
        p = p.data.full_tensor()
        s = s.data

        diff = torch.abs(p - s)
        if diff.max() > 0:
            logger.error(f"Models differ at parameter {name_p}, "
                         f"max abs diff = {diff.max().item()}")
            return
    logger.info("Models match!")


def test_muon():
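    """Muon should produce identical parameter updates with and without FSDP."""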
    parallel_muon_result = run_muon(fsdp=True)
    sequential_muon_result = run_muon(fsdp=False)

    compare_results(parallel_muon_result, sequential_muon_result)


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    test_muon()
    # Clean up the process group before exiting.
    dist.destroy_process_group()