# Compare Parallel Muon (FSDP-sharded model) against Sequential Muon (unsharded
# model) after a single optimizer step on identical, deterministic gradients.
import logging

import torch
import torch.distributed as dist
from muon import Muon, get_default_muon_param_groups
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate
from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def load_model(fsdp: bool) -> torch.nn.Module:
    model_name = "Motif-Technologies/Motif-2.6B"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
    ).bfloat16().cuda()

    # Generate deterministic synthetic gradients so both the sharded and the
    # unsharded run feed identical inputs to the optimizer.
    torch.manual_seed(0)
    random_grads = []
    for param in model.parameters():
        random_grad = torch.randn_like(param,
                                       device=param.device,
                                       dtype=param.dtype)
        random_grads.append(random_grad)

    if fsdp:
        # Shard each decoder layer, then the root module.
        for layer in model.model.layers:
            fully_shard(layer)
            layer.reshard()
        fully_shard(model)
        model.reshard()

    for i, param in enumerate(model.parameters()):
        if isinstance(param.data, DTensor):
            # Wrap the full gradient as a replicated DTensor, then
            # redistribute it to match the parameter's sharded placement.
            unsharded_grad = DTensor.from_local(
                random_grads[i],
                device_mesh=param.data.device_mesh,
                placements=[Replicate()] * param.data.device_mesh.ndim,
            )
            sharded_grad = unsharded_grad.redistribute(
                device_mesh=param.data.device_mesh,
                placements=param.data.placements)
            param.grad = sharded_grad
        else:
            param.grad = random_grads[i]

    return model


def run_muon(fsdp: bool) -> torch.nn.Module:
    model = load_model(fsdp=fsdp)
    params = get_default_muon_param_groups(model)
    optim = Muon(params=params)
    optim.step()
    return model


def compare_results(parallel_muon_result: torch.nn.Module,
                    sequential_muon_result: torch.nn.Module) -> None:
    for (name_p, p), (name_s, s) in zip(
            parallel_muon_result.named_parameters(),
            sequential_muon_result.named_parameters()):
        # Gather the sharded parameter to a full tensor before comparing.
        p = p.data.full_tensor()
        s = s.data

        # Parallel Muon should exactly match Sequential Muon.
        if torch.abs(p - s).max() > 0:
            max_diff_index = torch.argmax(torch.abs(p - s))
            logger.error(f"Models differ at parameter {name_p} "
                         f"(flat index {max_diff_index})")
            return

    logger.info("Models match!")


def test_muon():
    parallel_muon_result = run_muon(fsdp=True)
    sequential_muon_result = run_muon(fsdp=False)
    compare_results(parallel_muon_result, sequential_muon_result)


if __name__ == "__main__":
    # One NCCL process per GPU; the launcher supplies rank and world size.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    test_muon()
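
# Usage note (a sketch, not part of the original script): dist.init_process_group()
# expects an external launcher to provide the rank and world size. Assuming this
# file is saved as test_parallel_muon.py and 8 GPUs are available (both the file
# name and the GPU count are assumptions), a typical launch would be:
#
#   torchrun --nproc_per_node=8 test_parallel_muon.py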