import logging
import torch
import torch.distributed as dist
from muon import Muon, get_default_muon_param_groups
from torch.distributed.fsdp import FSDPModule, fully_shard
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
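

# Build the model with synthetic gradients; when `fsdp` is True, each transformer layer
# and the root module are wrapped with FSDP2 (fully_shard) so Muon operates on DTensors.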
def load_model(fsdp: bool) -> torch.nn.Module:
    model_name = "Motif-Technologies/Motif-2.6B"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
    ).bfloat16().cuda()
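
    # A fixed seed makes the synthetic gradients identical across the FSDP and
    # single-device runs.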
    torch.manual_seed(0)
    random_grads = []
    for param in model.parameters():
        random_grad = torch.randn_like(param,
                                       device=param.device,
                                       dtype=param.dtype)
        random_grads.append(random_grad)
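
    # Shard every decoder layer and then the root module with FSDP2.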
    if fsdp:
        for layer in model.model.layers:
            fully_shard(layer)
            layer.reshard()
        fully_shard(model)
        model.reshard()
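
    # Attach the pre-generated gradients. For FSDP-managed (DTensor) parameters, wrap the
    # local gradient as a replicated DTensor and redistribute it to match the parameter's
    # sharded placement before assigning it to `param.grad`.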
    for i, param in enumerate(model.parameters()):
        if isinstance(param.data, DTensor):
            unsharded_grad = DTensor.from_local(
                random_grads[i],
                device_mesh=param.data.device_mesh,
                placements=[Replicate()] * param.data.device_mesh.ndim,
            )
            sharded_grad = unsharded_grad.redistribute(
                device_mesh=param.data.device_mesh,
                placements=param.data.placements)
            param.grad = sharded_grad
        else:
            param.grad = random_grads[i]

    return model
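

# Load the model, build the default Muon parameter groups, and take one optimizer step.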
def run_muon(fsdp: bool) -> torch.nn.Module:
    model = load_model(fsdp=fsdp)
    params = get_default_muon_param_groups(model)
    optim = Muon(params=params)
    optim.step()
    return model
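

# Compare post-step parameters from the FSDP (parallel) run against the single-device
# (sequential) run; parallel Muon is expected to match sequential Muon exactly.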
def compare_results(parallel_muon_result: torch.nn.Module,
                    sequential_muon_result: torch.nn.Module) -> None:
    for (name_p, p), (name_s, s) in zip(
            parallel_muon_result.named_parameters(),
            sequential_muon_result.named_parameters()):
        p = p.data.full_tensor()
        s = s.data
        # Parallel Muon should exactly match Sequential Muon
        if torch.abs(p - s).max() > 0:
            max_diff_index = torch.argmax(torch.abs(p - s))
            logger.error(f"Models differ at parameter {name_p} "
                         f"(largest mismatch at flat index {max_diff_index})")
            return
    logger.info("Models match!")
def test_muon():
    parallel_muon_result = run_muon(fsdp=True)
    sequential_muon_result = run_muon(fsdp=False)
    compare_results(parallel_muon_result, sequential_muon_result)
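

# Requires a distributed launch, e.g. `torchrun --nproc_per_node=<num_gpus> <this_script>`.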
if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    test_muon()
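    # Tear down the default process group before exiting.
    dist.destroy_process_group()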