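"""Correctness test for parallel Muon under FSDP (``fully_shard``).

Loads Motif-2.6B, attaches identical synthetic gradients, runs a single Muon
optimizer step once with the model sharded and once without, and checks that
both runs produce exactly the same parameters.
"""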
import logging

import torch
import torch.distributed as dist
from muon import Muon, get_default_muon_param_groups
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate
from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def load_model(fsdp: bool) -> torch.nn.Module:
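    """Load Motif-2.6B in bfloat16 and attach deterministic synthetic gradients.

    When ``fsdp`` is True, the model is sharded with ``fully_shard`` and each
    gradient is redistributed to match its parameter's sharding.
    """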
    model_name = "Motif-Technologies/Motif-2.6B"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
    ).bfloat16().cuda()

    torch.manual_seed(0)
    random_grads = []
    for param in model.parameters():
        random_grad = torch.randn_like(param,
                                       device=param.device,
                                       dtype=param.dtype)
        random_grads.append(random_grad)

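    # Shard only after the full-size gradients have been generated, so the
    # FSDP and non-FSDP runs see exactly the same synthetic gradients.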
    if fsdp:
        for layer in model.model.layers:
            fully_shard(layer)
            layer.reshard()
        fully_shard(model)
        model.reshard()

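    # Attach the gradients. For sharded (DTensor) parameters, wrap the full
    # gradient as a replicated DTensor and redistribute it to the parameter's
    # placements; plain tensors receive the full gradient directly.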
    for i, param in enumerate(model.parameters()):
        if isinstance(param.data, DTensor):
            unsharded_grad = DTensor.from_local(
                random_grads[i],
                device_mesh=param.data.device_mesh,
                placements=[Replicate()] * param.data.device_mesh.ndim,
            )
            sharded_grad = unsharded_grad.redistribute(
                device_mesh=param.data.device_mesh,
                placements=param.data.placements)
            param.grad = sharded_grad
        else:
            param.grad = random_grads[i]

    return model


def run_muon(fsdp: bool) -> torch.nn.Module:
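    """Build the model and run a single Muon step on the synthetic gradients."""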
    model = load_model(fsdp=fsdp)
    params = get_default_muon_param_groups(model)
    optim = Muon(params=params)
    optim.step()

    return model


def compare_results(parallel_muon_result: torch.nn.Module,
                    sequential_muon_result: torch.nn.Module) -> None:
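    """Compare parameters from the FSDP (parallel) and non-FSDP (sequential) runs.

    Sharded parameters are gathered with ``full_tensor()`` before the
    element-wise comparison; any nonzero difference is logged as an error.
    """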
    for (name_p, p), (_, s) in zip(parallel_muon_result.named_parameters(),
                                   sequential_muon_result.named_parameters()):
        # Gather the sharded (DTensor) parameter into a full tensor on every rank.
        p = p.data.full_tensor()
        s = s.data
        # Parallel Muon should exactly match Sequential Muon.
        max_diff = torch.abs(p - s).max()
        if max_diff > 0:
            logger.error(f"Models differ at parameter {name_p} "
                         f"(max abs diff: {max_diff.item()})")
            return
    logger.info("Models match!")


def test_muon():
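    """Run Muon with and without FSDP and verify the results match exactly."""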
    parallel_muon_result = run_muon(fsdp=True)
    sequential_muon_result = run_muon(fsdp=False)

    compare_results(parallel_muon_result, sequential_muon_result)


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    test_muon()
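
    # Tear down the process group so every rank exits cleanly.
    dist.destroy_process_group()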