Support HSDP
#4 opened by iamwyldecat
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +33 -15
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +33 -15
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +33 -15
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +33 -15
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +33 -15
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +33 -15
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +33 -15
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +33 -15
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +33 -15
- torch-ext/optimizer/muon.py +33 -15
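The substantive change is in optimizer/muon.py (mirrored into every prebuilt variant under build/): each parameter's DTensor placements are mapped to the process group it is sharded over, so the gather/compute/scatter pipeline works for HSDP (Replicate(), Shard(0) on a 2-D mesh) as well as FSDP (Shard(0) on a 1-D mesh), as the muon.py diff below shows. The standalone sketch here restates that placement dispatch outside the optimizer; the launch setup, mesh shape, and the helper name shard_group_for are illustrative assumptions, not code from the PR.

# Hedged sketch of the HSDP/FSDP placement dispatch added in this PR.
# Run under torchrun with an even world size, e.g. torchrun --nproc_per_node=4 sketch.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Replicate, Shard, distribute_tensor

def shard_group_for(p, rank):
    # Map a DTensor's placements to (ranks that shard alongside `rank`, process group to use).
    if p.placements == (Shard(dim=0),):
        # FSDP: 1-D mesh, every rank holds a dim-0 shard.
        return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
    elif p.placements == (Replicate(), Shard(dim=0)):
        # HSDP: 2-D (replicate, shard) mesh; pick the shard row containing this rank.
        for shard_row in p.device_mesh.mesh:
            if rank in shard_row:
                return shard_row, p.device_mesh.get_group(mesh_dim=1)
    raise ValueError(f"Unsupported placements ({p.placements}).")

if __name__ == "__main__":
    dist.init_process_group()
    world = dist.get_world_size()
    # Example HSDP mesh: 2 replicas x (world // 2) shards (assumes world is even).
    mesh = init_device_mesh("cpu", (2, world // 2), mesh_dim_names=("replicate", "shard"))
    w = distribute_tensor(torch.randn(8, 4), mesh, placements=[Replicate(), Shard(0)])
    ranks, group = shard_group_for(w, dist.get_rank())
    print(dist.get_rank(), ranks.tolist(), dist.get_world_size(group=group))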
build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ
build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_1f13dae_dirty
-ops = torch.ops._optimizer_1f13dae_dirty
+from . import _optimizer_2dc97a1_dirty
+ops = torch.ops._optimizer_2dc97a1_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_1f13dae_dirty::{op_name}"
+    return f"_optimizer_2dc97a1_dirty::{op_name}"
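The _ops.py change only re-points the generated namespace at the rebuilt extension; the hash in the module name tracks the .so rename below. A minimal illustration of what the prefix helper does follows; the op name and the lookup comment are assumptions about how callers resolve ops, not code from this repo.

def add_op_namespace_prefix(op_name: str) -> str:
    # Build-specific namespace; custom ops registered by the extension live under it.
    return f"_optimizer_2dc97a1_dirty::{op_name}"

# A registered op would then be addressed as torch.ops._optimizer_2dc97a1_dirty.<op_name>.
print(add_op_namespace_prefix("some_op"))  # prints: _optimizer_2dc97a1_dirty::some_op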
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9112c8dde01baefa0e3130e143288cd3073ccbab47369a6dc925ce0d35400c6d
 size 1787368
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
 
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import DTensor, Replicate
+from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
     computed_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
+    process_group = None
 
 
 @torch.no_grad()
 def _gather(p, state, rank, comm_stream, none_grad):
     g = p.grad
-    mesh = g.device_mesh
 
     if rank == state.worker_rank:
-        …
+        num_ranks = dist.get_world_size(group=state.process_group)
+        gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
     else:
         gather_list = None
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
             g.to_local(),
             dst=state.worker_rank,
             gather_list=gather_list,
-            group=…
+            group=state.process_group,
         )
     if rank == state.worker_rank:
         if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
 @torch.no_grad()
 def _scatter(p, state, lr, weight_decay, rank, comm_stream):
     u = state.computed_u
-    mesh = p.device_mesh
 
     with torch.cuda.stream(comm_stream):
         if rank == state.worker_rank:
+            num_ranks = dist.get_world_size(group=state.process_group)
             if state.compute_event is None:
                 raise RuntimeError("Compute event must be set before scatter.")
             comm_stream.wait_event(state.compute_event)
-            scatter_list = list(torch.split(u, p.size(0) // …
+            scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
         else:
             scatter_list = None
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
             u,
             scatter_list=scatter_list,
             src=state.worker_rank,
-            group=…
+            group=state.process_group,
         )
         if rank == state.worker_rank:
             # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
         u = DTensor.from_local(
             u,
             placements=p.placements,
-            device_mesh=…
+            device_mesh=p.device_mesh,
         )
         p.data.mul_(1 - lr * weight_decay)
         p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
         adjusted_lr = lr * adjusted_ratio
         return adjusted_lr
 
+    def get_shard_mesh(self, p, rank):
+        """
+        Get the shard mesh for a parameter p on the given rank.
+        """
+        assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
+
+        if p.placements == (Shard(dim=0),):
+            # Case for FSDP
+            return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
+        elif p.placements == (Replicate(), Shard(dim=0)):
+            # Case for HSDP
+            for i, shard_mesh in enumerate(p.device_mesh.mesh):
+                if rank in shard_mesh:
+                    return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
+        else:
+            raise ValueError(f"Unsupported placements ({p.placements}).")
+
     def init_state_and_assign_params(self, params, group):
         param_to_state = {}
         param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
 
         round_robin = 0
         mesh = None
+        shard_mesh = None
+        process_group = None
         for p in ordered_params:
             if mesh is None:
                 mesh = p.device_mesh
-                …
-                    raise NotImplementedError(
-                        "Muon requires a 1D mesh for distributed training yet."
-                    )
+                shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
             elif mesh != p.device_mesh:
                 raise ValueError("All parameters must be on the same mesh.")
 
             param_to_state[id(p)] = _muon_state()
-            param_to_state[id(p)].worker_rank = …
+            param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
+            param_to_state[id(p)].process_group = process_group
 
-            round_robin = (round_robin + 1) % …
+            round_robin = (round_robin + 1) % len(shard_mesh)
 
         return param_to_state, ordered_params
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
                     p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
                 )
 
-        chunk_size = params[0].…
+        chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
         # Wait grad update
         self.comm_stream.wait_stream(torch.cuda.current_stream())
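To see why _gather and _scatter now take group=state.process_group and size their buffers with dist.get_world_size(group=...), the standalone sketch below runs the same gather -> compute -> scatter round trip on an explicit process group. The function name, the stand-in computation, and the single whole-world group are assumptions for illustration, not code from the PR.

# Hedged sketch of the gather -> compute -> scatter round trip on a subgroup,
# mirroring _gather/_scatter above. Run under torchrun, e.g. --nproc_per_node=4.
import torch
import torch.distributed as dist

def gather_compute_scatter(local_shard, group, worker_rank):
    rank = dist.get_rank()
    num_ranks = dist.get_world_size(group=group)
    # Gather every shard of the parameter onto the worker rank of this group.
    if rank == worker_rank:
        gather_list = [torch.empty_like(local_shard) for _ in range(num_ranks)]
    else:
        gather_list = None
    dist.gather(local_shard, gather_list=gather_list, dst=worker_rank, group=group)
    if rank == worker_rank:
        full = torch.cat(gather_list, dim=0)
        full = full * 2.0  # stand-in for the Newton-Schulz orthogonalization step
        scatter_list = list(torch.split(full, full.size(0) // num_ranks, dim=0))
    else:
        scatter_list = None
    # Scatter the updated shards back; every rank receives its own slice.
    out = torch.empty_like(local_shard)
    dist.scatter(out, scatter_list=scatter_list, src=worker_rank, group=group)
    return out

if __name__ == "__main__":
    dist.init_process_group()
    group = dist.new_group(ranks=list(range(dist.get_world_size())))  # whole world as one "shard group"
    shard = torch.full((2, 3), float(dist.get_rank()))
    print(dist.get_rank(), gather_compute_scatter(shard, group, worker_rank=0))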
build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files differ.

build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files differ.

build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
Same changes as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above.

build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so}
RENAMED
LFS pointer updated: new oid sha256:0449cd352f44c3e848d1f9c847b00bf576673b4fef2a954ec8bd8d2524b8353a, size 1824256 (unchanged).

build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
Same changes as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files differ.

build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files differ.

build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
Same changes as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above.

build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so}
RENAMED
LFS pointer updated: new oid sha256:2e6bab72b965f42d466cd74bbda49851549f2810278e642cef8738e40de4fdc5, size 1883352 (unchanged).

build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
Same changes as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files differ.

build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files differ.

build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
Same changes as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above.

build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so}
RENAMED
LFS pointer updated: new oid sha256:bdcf9e3d8bf13aa01bf1ae7a94a12dd05c50702a24b57e4cfcc2e54ca5ed21c3, size 1749840 (unchanged).

build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
Same changes as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files differ.

build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files differ.

build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
Same changes as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above.

build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so}
RENAMED
LFS pointer updated: new oid sha256:a423eb4ab3a31c53a3326c71e34fa59fc661f8d432701e41a7de900a9c23e37c, size 1824256 (unchanged).

build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
Same changes as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ
|
|
build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
|
|
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_2dc97a1_dirty
|
3 |
+
ops = torch.ops._optimizer_2dc97a1_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_2dc97a1_dirty::{op_name}"
|
build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1883352
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:86d98863cc7ef0b271808b0ef7b1082603cfb5a76986481df37431527aaaf27b
|
3 |
size 1883352
|
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|
3 |
|
4 |
import torch
|
5 |
import torch.distributed as dist
|
6 |
-
from torch.distributed._tensor import DTensor, Replicate
|
7 |
|
8 |
|
9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
@@ -50,15 +50,16 @@ class _muon_state:
|
|
50 |
computed_u: torch.Tensor | None = None
|
51 |
gather_event: torch.cuda.Event | None = None
|
52 |
compute_event: torch.cuda.Event | None = None
|
|
|
53 |
|
54 |
|
55 |
@torch.no_grad()
|
56 |
def _gather(p, state, rank, comm_stream, none_grad):
|
57 |
g = p.grad
|
58 |
-
mesh = g.device_mesh
|
59 |
|
60 |
if rank == state.worker_rank:
|
61 |
-
|
|
|
62 |
else:
|
63 |
gather_list = None
|
64 |
|
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
67 |
g.to_local(),
|
68 |
dst=state.worker_rank,
|
69 |
gather_list=gather_list,
|
70 |
-
group=
|
71 |
)
|
72 |
if rank == state.worker_rank:
|
73 |
if state.gathered_grad is not None:
|
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
105 |
@torch.no_grad()
|
106 |
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
107 |
u = state.computed_u
|
108 |
-
mesh = p.device_mesh
|
109 |
|
110 |
with torch.cuda.stream(comm_stream):
|
111 |
if rank == state.worker_rank:
|
|
|
112 |
if state.compute_event is None:
|
113 |
raise RuntimeError("Compute event must be set before scatter.")
|
114 |
comm_stream.wait_event(state.compute_event)
|
115 |
-
scatter_list = list(torch.split(u, p.size(0) //
|
116 |
else:
|
117 |
scatter_list = None
|
118 |
|
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
|
121 |
u,
|
122 |
scatter_list=scatter_list,
|
123 |
src=state.worker_rank,
|
124 |
-
group=
|
125 |
)
|
126 |
if rank == state.worker_rank:
|
127 |
# Clear u to free memory
|
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
|
129 |
u = DTensor.from_local(
|
130 |
u,
|
131 |
placements=p.placements,
|
132 |
-
device_mesh=
|
133 |
)
|
134 |
p.data.mul_(1 - lr * weight_decay)
|
135 |
p.data.add_(u, alpha=-lr)
|
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
|
|
235 |
adjusted_lr = lr * adjusted_ratio
|
236 |
return adjusted_lr
|
237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
def init_state_and_assign_params(self, params, group):
|
239 |
param_to_state = {}
|
240 |
param_to_flops = {}
|
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
|
|
259 |
|
260 |
round_robin = 0
|
261 |
mesh = None
|
|
|
|
|
262 |
for p in ordered_params:
|
263 |
if mesh is None:
|
264 |
mesh = p.device_mesh
|
265 |
-
|
266 |
-
raise NotImplementedError(
|
267 |
-
"Muon requires a 1D mesh for distributed training yet."
|
268 |
-
)
|
269 |
elif mesh != p.device_mesh:
|
270 |
raise ValueError("All parameters must be on the same mesh.")
|
271 |
|
272 |
param_to_state[id(p)] = _muon_state()
|
273 |
-
param_to_state[id(p)].worker_rank =
|
|
|
274 |
|
275 |
-
round_robin = (round_robin + 1) %
|
276 |
|
277 |
return param_to_state, ordered_params
|
278 |
|
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
|
|
372 |
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
373 |
)
|
374 |
|
375 |
-
chunk_size = params[0].
|
376 |
|
377 |
# Wait grad update
|
378 |
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
|
|
3 |
|
4 |
import torch
|
5 |
import torch.distributed as dist
|
6 |
+
from torch.distributed._tensor import DTensor, Replicate, Shard
|
7 |
|
8 |
|
9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
|
50 |
computed_u: torch.Tensor | None = None
|
51 |
gather_event: torch.cuda.Event | None = None
|
52 |
compute_event: torch.cuda.Event | None = None
|
53 |
+
process_group = None
|
54 |
|
55 |
|
56 |
@torch.no_grad()
|
57 |
def _gather(p, state, rank, comm_stream, none_grad):
|
58 |
g = p.grad
|
|
|
59 |
|
60 |
if rank == state.worker_rank:
|
61 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
62 |
+
gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
|
63 |
else:
|
64 |
gather_list = None
|
65 |
|
|
|
68 |
g.to_local(),
|
69 |
dst=state.worker_rank,
|
70 |
gather_list=gather_list,
|
71 |
+
group=state.process_group,
|
72 |
)
|
73 |
if rank == state.worker_rank:
|
74 |
if state.gathered_grad is not None:
|
|
|
106 |
@torch.no_grad()
|
107 |
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
108 |
u = state.computed_u
|
|
|
109 |
|
110 |
with torch.cuda.stream(comm_stream):
|
111 |
if rank == state.worker_rank:
|
112 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
113 |
if state.compute_event is None:
|
114 |
raise RuntimeError("Compute event must be set before scatter.")
|
115 |
comm_stream.wait_event(state.compute_event)
|
116 |
+
scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
|
117 |
else:
|
118 |
scatter_list = None
|
119 |
|
|
|
122 |
u,
|
123 |
scatter_list=scatter_list,
|
124 |
src=state.worker_rank,
|
125 |
+
group=state.process_group,
|
126 |
)
|
127 |
if rank == state.worker_rank:
|
128 |
# Clear u to free memory
|
|
|
130 |
u = DTensor.from_local(
|
131 |
u,
|
132 |
placements=p.placements,
|
133 |
+
device_mesh=p.device_mesh,
|
134 |
)
|
135 |
p.data.mul_(1 - lr * weight_decay)
|
136 |
p.data.add_(u, alpha=-lr)
|
|
|
236 |
adjusted_lr = lr * adjusted_ratio
|
237 |
return adjusted_lr
|
238 |
|
239 |
+
def get_shard_mesh(self, p, rank):
|
240 |
+
"""
|
241 |
+
Get the shard mesh for a parameter p on the given rank.
|
242 |
+
"""
|
243 |
+
assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
|
244 |
+
|
245 |
+
if p.placements == (Shard(dim=0),):
|
246 |
+
# Case for FSDP
|
247 |
+
return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
|
248 |
+
elif p.placements == (Replicate(), Shard(dim=0)):
|
249 |
+
# Case for HSDP
|
250 |
+
for i, shard_mesh in enumerate(p.device_mesh.mesh):
|
251 |
+
if rank in shard_mesh:
|
252 |
+
return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
|
253 |
+
else:
|
254 |
+
raise ValueError(f"Unsupported placements ({p.placements}).")
|
255 |
+
|
256 |
def init_state_and_assign_params(self, params, group):
|
257 |
param_to_state = {}
|
258 |
param_to_flops = {}
|
|
|
277 |
|
278 |
round_robin = 0
|
279 |
mesh = None
|
280 |
+
shard_mesh = None
|
281 |
+
process_group = None
|
282 |
for p in ordered_params:
|
283 |
if mesh is None:
|
284 |
mesh = p.device_mesh
|
285 |
+
shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
|
|
|
|
|
|
|
286 |
elif mesh != p.device_mesh:
|
287 |
raise ValueError("All parameters must be on the same mesh.")
|
288 |
|
289 |
param_to_state[id(p)] = _muon_state()
|
290 |
+
param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
|
291 |
+
param_to_state[id(p)].process_group = process_group
|
292 |
|
293 |
+
round_robin = (round_robin + 1) % len(shard_mesh)
|
294 |
|
295 |
return param_to_state, ordered_params
|
296 |
|
|
|
390 |
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
391 |
)
|
392 |
|
393 |
+
chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
|
394 |
|
395 |
# Wait grad update
|
396 |
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
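The hunk above assigns each parameter's Muon worker by round-robin over the ranks of its shard group. The sketch below is illustrative only (not part of the diff); the rank list and parameter names are made up.

# Round-robin owner assignment over one shard group (hypothetical values).
shard_ranks = [4, 5, 6, 7]                      # ranks of an assumed HSDP shard group
params = ["wq", "wk", "wv", "wo", "w1", "w2"]   # placeholder parameter names

round_robin = 0
owner = {}
for name in params:
    owner[name] = shard_ranks[round_robin]       # becomes that parameter's worker_rank
    round_robin = (round_robin + 1) % len(shard_ranks)

print(owner)
# {'wq': 4, 'wk': 5, 'wv': 6, 'wo': 7, 'w1': 4, 'w2': 5}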
build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ
build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_1f13dae_dirty
-ops = torch.ops._optimizer_1f13dae_dirty
+from . import _optimizer_2dc97a1_dirty
+ops = torch.ops._optimizer_2dc97a1_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_1f13dae_dirty::{op_name}"
+    return f"_optimizer_2dc97a1_dirty::{op_name}"
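For reference, the rebuilt native module only changes the op namespace string. A minimal sketch (not part of the diff; "my_op" is a placeholder op name) of what the helper now returns:

def add_op_namespace_prefix(op_name: str):
    # Prefix op by namespace, as in the regenerated _ops.py above.
    return f"_optimizer_2dc97a1_dirty::{op_name}"

assert add_op_namespace_prefix("my_op") == "_optimizer_2dc97a1_dirty::my_op"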
build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:f8daaad69e6958850f848fab60c9acb938c3a5e54e3ec34a1bec03a3d32653cb
 size 1883352
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py
CHANGED
(Same +33 -15 change as torch-ext/optimizer/muon.py; see the full diff at the end of this PR.)
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_1f13dae_dirty
-ops = torch.ops._optimizer_1f13dae_dirty
+from . import _optimizer_2dc97a1_dirty
+ops = torch.ops._optimizer_2dc97a1_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_1f13dae_dirty::{op_name}"
+    return f"_optimizer_2dc97a1_dirty::{op_name}"
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:76910ba81e2c95c83207118725c4379db636346c4ccf05010e2ee00c41dff1ce
 size 1750000
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
(Same +33 -15 change as torch-ext/optimizer/muon.py; see the full diff at the end of this PR.)
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_1f13dae_dirty
-ops = torch.ops._optimizer_1f13dae_dirty
+from . import _optimizer_2dc97a1_dirty
+ops = torch.ops._optimizer_2dc97a1_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_1f13dae_dirty::{op_name}"
+    return f"_optimizer_2dc97a1_dirty::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:dd0a35a6f846a075a8f4561cfc66ef17c6358dd4a0062e63057b02625d9d6af7
 size 1750088
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py
CHANGED
(Same +33 -15 change as torch-ext/optimizer/muon.py; see the full diff at the end of this PR.)
torch-ext/optimizer/muon.py
CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
 
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import DTensor, Replicate
+from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
     computed_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
+    process_group = None
 
 
 @torch.no_grad()
 def _gather(p, state, rank, comm_stream, none_grad):
     g = p.grad
-    mesh = g.device_mesh
 
     if rank == state.worker_rank:
-        gather_list = …
+        num_ranks = dist.get_world_size(group=state.process_group)
+        gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
     else:
         gather_list = None
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
             g.to_local(),
             dst=state.worker_rank,
             gather_list=gather_list,
-            group=…
+            group=state.process_group,
         )
         if rank == state.worker_rank:
             if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
 @torch.no_grad()
 def _scatter(p, state, lr, weight_decay, rank, comm_stream):
     u = state.computed_u
-    mesh = p.device_mesh
 
     with torch.cuda.stream(comm_stream):
         if rank == state.worker_rank:
+            num_ranks = dist.get_world_size(group=state.process_group)
             if state.compute_event is None:
                 raise RuntimeError("Compute event must be set before scatter.")
             comm_stream.wait_event(state.compute_event)
-            scatter_list = list(torch.split(u, p.size(0) // …
+            scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
         else:
             scatter_list = None
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
             u,
             scatter_list=scatter_list,
             src=state.worker_rank,
-            group=…
+            group=state.process_group,
         )
         if rank == state.worker_rank:
             # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
             u = DTensor.from_local(
                 u,
                 placements=p.placements,
-                device_mesh=…
+                device_mesh=p.device_mesh,
             )
             p.data.mul_(1 - lr * weight_decay)
             p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
         adjusted_lr = lr * adjusted_ratio
         return adjusted_lr
 
+    def get_shard_mesh(self, p, rank):
+        """
+        Get the shard mesh for a parameter p on the given rank.
+        """
+        assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
+
+        if p.placements == (Shard(dim=0),):
+            # Case for FSDP
+            return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
+        elif p.placements == (Replicate(), Shard(dim=0)):
+            # Case for HSDP
+            for i, shard_mesh in enumerate(p.device_mesh.mesh):
+                if rank in shard_mesh:
+                    return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
+        else:
+            raise ValueError(f"Unsupported placements ({p.placements}).")
+
     def init_state_and_assign_params(self, params, group):
         param_to_state = {}
         param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
 
         round_robin = 0
         mesh = None
+        shard_mesh = None
+        process_group = None
         for p in ordered_params:
            if mesh is None:
                 mesh = p.device_mesh
-                if …:
-                    raise NotImplementedError(
-                        "Muon requires a 1D mesh for distributed training yet."
-                    )
+                shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
             elif mesh != p.device_mesh:
                 raise ValueError("All parameters must be on the same mesh.")
 
             param_to_state[id(p)] = _muon_state()
-            param_to_state[id(p)].worker_rank = …
+            param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
+            param_to_state[id(p)].process_group = process_group
 
-            round_robin = (round_robin + 1) % …
+            round_robin = (round_robin + 1) % len(shard_mesh)
 
         return param_to_state, ordered_params
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
                 p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
             )
 
-        chunk_size = params[0].…
+        chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
         # Wait grad update
         self.comm_stream.wait_stream(torch.cuda.current_stream())
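Note (illustration only, not part of the PR): the HSDP branch of get_shard_mesh walks the rows of a 2-D device mesh (replicate on dim 0, shard on dim 1) and returns the row that contains the calling rank together with the mesh_dim=1 process group; the FSDP branch is the 1-D case where the whole mesh is the shard group. Below is a standalone sketch of that row lookup, using a plain tensor in place of a DeviceMesh and an assumed 2x4 layout.

import torch

mesh = torch.arange(8).reshape(2, 4)   # assumed HSDP mesh: [[0, 1, 2, 3], [4, 5, 6, 7]]

def shard_ranks_for(rank: int, mesh: torch.Tensor) -> list[int]:
    # Mirror of the HSDP case: pick the shard-group row that contains this rank.
    for row in mesh:
        if rank in row:
            return row.tolist()
    raise ValueError(f"rank {rank} not found in mesh")

assert shard_ranks_for(6, mesh) == [4, 5, 6, 7]   # rank 6 gathers/scatters within ranks 4-7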