Files changed (46) hide show
  1. build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  2. build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  3. build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  4. build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} +1 -1
  5. build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +33 -15
  6. build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  7. build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  8. build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  9. build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} +1 -1
  10. build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +33 -15
  11. build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  12. build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  13. build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  14. build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} +1 -1
  15. build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +33 -15
  16. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  17. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  18. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  19. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} +1 -1
  20. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +33 -15
  21. build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  22. build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  23. build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  24. build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} +1 -1
  25. build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +33 -15
  26. build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  27. build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  28. build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  29. build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} +1 -1
  30. build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +33 -15
  31. build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  32. build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  33. build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
  34. build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} +1 -1
  35. build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +33 -15
  36. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  37. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  38. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  39. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} +1 -1
  40. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +33 -15
  41. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  42. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  43. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
  44. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} +1 -1
  45. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +33 -15
  46. torch-ext/optimizer/muon.py +33 -15
build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ
 
build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
 
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_1f13dae_dirty
3
- ops = torch.ops._optimizer_1f13dae_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_1f13dae_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_2dc97a1_dirty
3
+ ops = torch.ops._optimizer_2dc97a1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_2dc97a1_dirty::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7dc5f8a57aa60483209dfcbb0c7cc0e54f1739d643145c1e685fbe2b6675ac43
3
  size 1787368
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9112c8dde01baefa0e3130e143288cd3073ccbab47369a6dc925ce0d35400c6d
3
  size 1787368
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
- from torch.distributed._tensor import DTensor, Replicate
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
 
54
 
55
  @torch.no_grad()
56
  def _gather(p, state, rank, comm_stream, none_grad):
57
  g = p.grad
58
- mesh = g.device_mesh
59
 
60
  if rank == state.worker_rank:
61
- gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
 
62
  else:
63
  gather_list = None
64
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
67
  g.to_local(),
68
  dst=state.worker_rank,
69
  gather_list=gather_list,
70
- group=mesh.get_group(),
71
  )
72
  if rank == state.worker_rank:
73
  if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
105
  @torch.no_grad()
106
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
- mesh = p.device_mesh
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
 
112
  if state.compute_event is None:
113
  raise RuntimeError("Compute event must be set before scatter.")
114
  comm_stream.wait_event(state.compute_event)
115
- scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
116
  else:
117
  scatter_list = None
118
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
121
  u,
122
  scatter_list=scatter_list,
123
  src=state.worker_rank,
124
- group=mesh.get_group(),
125
  )
126
  if rank == state.worker_rank:
127
  # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
129
  u = DTensor.from_local(
130
  u,
131
  placements=p.placements,
132
- device_mesh=mesh,
133
  )
134
  p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
235
  adjusted_lr = lr * adjusted_ratio
236
  return adjusted_lr
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def init_state_and_assign_params(self, params, group):
239
  param_to_state = {}
240
  param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
259
 
260
  round_robin = 0
261
  mesh = None
 
 
262
  for p in ordered_params:
263
  if mesh is None:
264
  mesh = p.device_mesh
265
- if mesh.ndim != 1:
266
- raise NotImplementedError(
267
- "Muon requires a 1D mesh for distributed training yet."
268
- )
269
  elif mesh != p.device_mesh:
270
  raise ValueError("All parameters must be on the same mesh.")
271
 
272
  param_to_state[id(p)] = _muon_state()
273
- param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
 
274
 
275
- round_robin = (round_robin + 1) % mesh.mesh.numel()
276
 
277
  return param_to_state, ordered_params
278
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
372
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
  )
374
 
375
- chunk_size = params[0].device_mesh.mesh.numel()
376
 
377
  # Wait grad update
378
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
53
+ process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
58
  g = p.grad
 
59
 
60
  if rank == state.worker_rank:
61
+ num_ranks = dist.get_world_size(group=state.process_group)
62
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
63
  else:
64
  gather_list = None
65
 
 
68
  g.to_local(),
69
  dst=state.worker_rank,
70
  gather_list=gather_list,
71
+ group=state.process_group,
72
  )
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
 
106
  @torch.no_grad()
107
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
  u = state.computed_u
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
112
+ num_ranks = dist.get_world_size(group=state.process_group)
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
116
+ scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
117
  else:
118
  scatter_list = None
119
 
 
122
  u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
+ group=state.process_group,
126
  )
127
  if rank == state.worker_rank:
128
  # Clear u to free memory
 
130
  u = DTensor.from_local(
131
  u,
132
  placements=p.placements,
133
+ device_mesh=p.device_mesh,
134
  )
135
  p.data.mul_(1 - lr * weight_decay)
136
  p.data.add_(u, alpha=-lr)
 
236
  adjusted_lr = lr * adjusted_ratio
237
  return adjusted_lr
238
 
239
+ def get_shard_mesh(self, p, rank):
240
+ """
241
+ Get the shard mesh for a parameter p on the given rank.
242
+ """
243
+ assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
244
+
245
+ if p.placements == (Shard(dim=0),):
246
+ # Case for FSDP
247
+ return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
+ elif p.placements == (Replicate(), Shard(dim=0)):
249
+ # Case for HSDP
250
+ for i, shard_mesh in enumerate(p.device_mesh.mesh):
251
+ if rank in shard_mesh:
252
+ return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
253
+ else:
254
+ raise ValueError(f"Unsupported placements ({p.placements}).")
255
+
256
  def init_state_and_assign_params(self, params, group):
257
  param_to_state = {}
258
  param_to_flops = {}
 
277
 
278
  round_robin = 0
279
  mesh = None
280
+ shard_mesh = None
281
+ process_group = None
282
  for p in ordered_params:
283
  if mesh is None:
284
  mesh = p.device_mesh
285
+ shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
286
  elif mesh != p.device_mesh:
287
  raise ValueError("All parameters must be on the same mesh.")
288
 
289
  param_to_state[id(p)] = _muon_state()
290
+ param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
291
+ param_to_state[id(p)].process_group = process_group
292
 
293
+ round_robin = (round_robin + 1) % len(shard_mesh)
294
 
295
  return param_to_state, ordered_params
296
 
 
390
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
  )
392
 
393
+ chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ
 
build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
 
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_1f13dae_dirty
3
- ops = torch.ops._optimizer_1f13dae_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_1f13dae_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_2dc97a1_dirty
3
+ ops = torch.ops._optimizer_2dc97a1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_2dc97a1_dirty::{op_name}"
build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a082b5629efc4e9b8ce608713665d47904949b5d220dad350049bc806d58ecd7
3
  size 1824256
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0449cd352f44c3e848d1f9c847b00bf576673b4fef2a954ec8bd8d2524b8353a
3
  size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
- from torch.distributed._tensor import DTensor, Replicate
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
 
54
 
55
  @torch.no_grad()
56
  def _gather(p, state, rank, comm_stream, none_grad):
57
  g = p.grad
58
- mesh = g.device_mesh
59
 
60
  if rank == state.worker_rank:
61
- gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
 
62
  else:
63
  gather_list = None
64
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
67
  g.to_local(),
68
  dst=state.worker_rank,
69
  gather_list=gather_list,
70
- group=mesh.get_group(),
71
  )
72
  if rank == state.worker_rank:
73
  if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
105
  @torch.no_grad()
106
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
- mesh = p.device_mesh
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
 
112
  if state.compute_event is None:
113
  raise RuntimeError("Compute event must be set before scatter.")
114
  comm_stream.wait_event(state.compute_event)
115
- scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
116
  else:
117
  scatter_list = None
118
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
121
  u,
122
  scatter_list=scatter_list,
123
  src=state.worker_rank,
124
- group=mesh.get_group(),
125
  )
126
  if rank == state.worker_rank:
127
  # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
129
  u = DTensor.from_local(
130
  u,
131
  placements=p.placements,
132
- device_mesh=mesh,
133
  )
134
  p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
235
  adjusted_lr = lr * adjusted_ratio
236
  return adjusted_lr
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def init_state_and_assign_params(self, params, group):
239
  param_to_state = {}
240
  param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
259
 
260
  round_robin = 0
261
  mesh = None
 
 
262
  for p in ordered_params:
263
  if mesh is None:
264
  mesh = p.device_mesh
265
- if mesh.ndim != 1:
266
- raise NotImplementedError(
267
- "Muon requires a 1D mesh for distributed training yet."
268
- )
269
  elif mesh != p.device_mesh:
270
  raise ValueError("All parameters must be on the same mesh.")
271
 
272
  param_to_state[id(p)] = _muon_state()
273
- param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
 
274
 
275
- round_robin = (round_robin + 1) % mesh.mesh.numel()
276
 
277
  return param_to_state, ordered_params
278
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
372
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
  )
374
 
375
- chunk_size = params[0].device_mesh.mesh.numel()
376
 
377
  # Wait grad update
378
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
53
+ process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
58
  g = p.grad
 
59
 
60
  if rank == state.worker_rank:
61
+ num_ranks = dist.get_world_size(group=state.process_group)
62
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
63
  else:
64
  gather_list = None
65
 
 
68
  g.to_local(),
69
  dst=state.worker_rank,
70
  gather_list=gather_list,
71
+ group=state.process_group,
72
  )
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
 
106
  @torch.no_grad()
107
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
  u = state.computed_u
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
112
+ num_ranks = dist.get_world_size(group=state.process_group)
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
116
+ scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
117
  else:
118
  scatter_list = None
119
 
 
122
  u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
+ group=state.process_group,
126
  )
127
  if rank == state.worker_rank:
128
  # Clear u to free memory
 
130
  u = DTensor.from_local(
131
  u,
132
  placements=p.placements,
133
+ device_mesh=p.device_mesh,
134
  )
135
  p.data.mul_(1 - lr * weight_decay)
136
  p.data.add_(u, alpha=-lr)
 
236
  adjusted_lr = lr * adjusted_ratio
237
  return adjusted_lr
238
 
239
+ def get_shard_mesh(self, p, rank):
240
+ """
241
+ Get the shard mesh for a parameter p on the given rank.
242
+ """
243
+ assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
244
+
245
+ if p.placements == (Shard(dim=0),):
246
+ # Case for FSDP
247
+ return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
+ elif p.placements == (Replicate(), Shard(dim=0)):
249
+ # Case for HSDP
250
+ for i, shard_mesh in enumerate(p.device_mesh.mesh):
251
+ if rank in shard_mesh:
252
+ return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
253
+ else:
254
+ raise ValueError(f"Unsupported placements ({p.placements}).")
255
+
256
  def init_state_and_assign_params(self, params, group):
257
  param_to_state = {}
258
  param_to_flops = {}
 
277
 
278
  round_robin = 0
279
  mesh = None
280
+ shard_mesh = None
281
+ process_group = None
282
  for p in ordered_params:
283
  if mesh is None:
284
  mesh = p.device_mesh
285
+ shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
286
  elif mesh != p.device_mesh:
287
  raise ValueError("All parameters must be on the same mesh.")
288
 
289
  param_to_state[id(p)] = _muon_state()
290
+ param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
291
+ param_to_state[id(p)].process_group = process_group
292
 
293
+ round_robin = (round_robin + 1) % len(shard_mesh)
294
 
295
  return param_to_state, ordered_params
296
 
 
390
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
  )
392
 
393
+ chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ
 
build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
 
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_1f13dae_dirty
3
- ops = torch.ops._optimizer_1f13dae_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_1f13dae_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_2dc97a1_dirty
3
+ ops = torch.ops._optimizer_2dc97a1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_2dc97a1_dirty::{op_name}"
build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7d2e65e315cd82d0b6fc2043ff37ee2d1223d6bd293ef552d658db5bf4de0a45
3
  size 1883352
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e6bab72b965f42d466cd74bbda49851549f2810278e642cef8738e40de4fdc5
3
  size 1883352
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
- from torch.distributed._tensor import DTensor, Replicate
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
 
54
 
55
  @torch.no_grad()
56
  def _gather(p, state, rank, comm_stream, none_grad):
57
  g = p.grad
58
- mesh = g.device_mesh
59
 
60
  if rank == state.worker_rank:
61
- gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
 
62
  else:
63
  gather_list = None
64
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
67
  g.to_local(),
68
  dst=state.worker_rank,
69
  gather_list=gather_list,
70
- group=mesh.get_group(),
71
  )
72
  if rank == state.worker_rank:
73
  if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
105
  @torch.no_grad()
106
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
- mesh = p.device_mesh
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
 
112
  if state.compute_event is None:
113
  raise RuntimeError("Compute event must be set before scatter.")
114
  comm_stream.wait_event(state.compute_event)
115
- scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
116
  else:
117
  scatter_list = None
118
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
121
  u,
122
  scatter_list=scatter_list,
123
  src=state.worker_rank,
124
- group=mesh.get_group(),
125
  )
126
  if rank == state.worker_rank:
127
  # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
129
  u = DTensor.from_local(
130
  u,
131
  placements=p.placements,
132
- device_mesh=mesh,
133
  )
134
  p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
235
  adjusted_lr = lr * adjusted_ratio
236
  return adjusted_lr
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def init_state_and_assign_params(self, params, group):
239
  param_to_state = {}
240
  param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
259
 
260
  round_robin = 0
261
  mesh = None
 
 
262
  for p in ordered_params:
263
  if mesh is None:
264
  mesh = p.device_mesh
265
- if mesh.ndim != 1:
266
- raise NotImplementedError(
267
- "Muon requires a 1D mesh for distributed training yet."
268
- )
269
  elif mesh != p.device_mesh:
270
  raise ValueError("All parameters must be on the same mesh.")
271
 
272
  param_to_state[id(p)] = _muon_state()
273
- param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
 
274
 
275
- round_robin = (round_robin + 1) % mesh.mesh.numel()
276
 
277
  return param_to_state, ordered_params
278
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
372
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
  )
374
 
375
- chunk_size = params[0].device_mesh.mesh.numel()
376
 
377
  # Wait grad update
378
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
53
+ process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
58
  g = p.grad
 
59
 
60
  if rank == state.worker_rank:
61
+ num_ranks = dist.get_world_size(group=state.process_group)
62
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
63
  else:
64
  gather_list = None
65
 
 
68
  g.to_local(),
69
  dst=state.worker_rank,
70
  gather_list=gather_list,
71
+ group=state.process_group,
72
  )
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
 
106
  @torch.no_grad()
107
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
  u = state.computed_u
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
112
+ num_ranks = dist.get_world_size(group=state.process_group)
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
116
+ scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
117
  else:
118
  scatter_list = None
119
 
 
122
  u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
+ group=state.process_group,
126
  )
127
  if rank == state.worker_rank:
128
  # Clear u to free memory
 
130
  u = DTensor.from_local(
131
  u,
132
  placements=p.placements,
133
+ device_mesh=p.device_mesh,
134
  )
135
  p.data.mul_(1 - lr * weight_decay)
136
  p.data.add_(u, alpha=-lr)
 
236
  adjusted_lr = lr * adjusted_ratio
237
  return adjusted_lr
238
 
239
+ def get_shard_mesh(self, p, rank):
240
+ """
241
+ Get the shard mesh for a parameter p on the given rank.
242
+ """
243
+ assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
244
+
245
+ if p.placements == (Shard(dim=0),):
246
+ # Case for FSDP
247
+ return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
+ elif p.placements == (Replicate(), Shard(dim=0)):
249
+ # Case for HSDP
250
+ for i, shard_mesh in enumerate(p.device_mesh.mesh):
251
+ if rank in shard_mesh:
252
+ return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
253
+ else:
254
+ raise ValueError(f"Unsupported placements ({p.placements}).")
255
+
256
  def init_state_and_assign_params(self, params, group):
257
  param_to_state = {}
258
  param_to_flops = {}
 
277
 
278
  round_robin = 0
279
  mesh = None
280
+ shard_mesh = None
281
+ process_group = None
282
  for p in ordered_params:
283
  if mesh is None:
284
  mesh = p.device_mesh
285
+ shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
286
  elif mesh != p.device_mesh:
287
  raise ValueError("All parameters must be on the same mesh.")
288
 
289
  param_to_state[id(p)] = _muon_state()
290
+ param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
291
+ param_to_state[id(p)].process_group = process_group
292
 
293
+ round_robin = (round_robin + 1) % len(shard_mesh)
294
 
295
  return param_to_state, ordered_params
296
 
 
390
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
  )
392
 
393
+ chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ
 
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
 
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_1f13dae_dirty
3
- ops = torch.ops._optimizer_1f13dae_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_1f13dae_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_2dc97a1_dirty
3
+ ops = torch.ops._optimizer_2dc97a1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_2dc97a1_dirty::{op_name}"
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3d9ee2420e8528032369c476152a1960d123034a83e2c43f38a7fb2d1423aa23
3
  size 1749840
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bdcf9e3d8bf13aa01bf1ae7a94a12dd05c50702a24b57e4cfcc2e54ca5ed21c3
3
  size 1749840
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
- from torch.distributed._tensor import DTensor, Replicate
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
 
54
 
55
  @torch.no_grad()
56
  def _gather(p, state, rank, comm_stream, none_grad):
57
  g = p.grad
58
- mesh = g.device_mesh
59
 
60
  if rank == state.worker_rank:
61
- gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
 
62
  else:
63
  gather_list = None
64
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
67
  g.to_local(),
68
  dst=state.worker_rank,
69
  gather_list=gather_list,
70
- group=mesh.get_group(),
71
  )
72
  if rank == state.worker_rank:
73
  if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
105
  @torch.no_grad()
106
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
- mesh = p.device_mesh
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
 
112
  if state.compute_event is None:
113
  raise RuntimeError("Compute event must be set before scatter.")
114
  comm_stream.wait_event(state.compute_event)
115
- scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
116
  else:
117
  scatter_list = None
118
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
121
  u,
122
  scatter_list=scatter_list,
123
  src=state.worker_rank,
124
- group=mesh.get_group(),
125
  )
126
  if rank == state.worker_rank:
127
  # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
129
  u = DTensor.from_local(
130
  u,
131
  placements=p.placements,
132
- device_mesh=mesh,
133
  )
134
  p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
235
  adjusted_lr = lr * adjusted_ratio
236
  return adjusted_lr
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def init_state_and_assign_params(self, params, group):
239
  param_to_state = {}
240
  param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
259
 
260
  round_robin = 0
261
  mesh = None
 
 
262
  for p in ordered_params:
263
  if mesh is None:
264
  mesh = p.device_mesh
265
- if mesh.ndim != 1:
266
- raise NotImplementedError(
267
- "Muon requires a 1D mesh for distributed training yet."
268
- )
269
  elif mesh != p.device_mesh:
270
  raise ValueError("All parameters must be on the same mesh.")
271
 
272
  param_to_state[id(p)] = _muon_state()
273
- param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
 
274
 
275
- round_robin = (round_robin + 1) % mesh.mesh.numel()
276
 
277
  return param_to_state, ordered_params
278
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
372
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
  )
374
 
375
- chunk_size = params[0].device_mesh.mesh.numel()
376
 
377
  # Wait grad update
378
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
53
+ process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
58
  g = p.grad
 
59
 
60
  if rank == state.worker_rank:
61
+ num_ranks = dist.get_world_size(group=state.process_group)
62
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
63
  else:
64
  gather_list = None
65
 
 
68
  g.to_local(),
69
  dst=state.worker_rank,
70
  gather_list=gather_list,
71
+ group=state.process_group,
72
  )
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
 
106
  @torch.no_grad()
107
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
  u = state.computed_u
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
112
+ num_ranks = dist.get_world_size(group=state.process_group)
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
116
+ scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
117
  else:
118
  scatter_list = None
119
 
 
122
  u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
+ group=state.process_group,
126
  )
127
  if rank == state.worker_rank:
128
  # Clear u to free memory
 
130
  u = DTensor.from_local(
131
  u,
132
  placements=p.placements,
133
+ device_mesh=p.device_mesh,
134
  )
135
  p.data.mul_(1 - lr * weight_decay)
136
  p.data.add_(u, alpha=-lr)
 
236
  adjusted_lr = lr * adjusted_ratio
237
  return adjusted_lr
238
 
239
+ def get_shard_mesh(self, p, rank):
240
+ """
241
+ Get the shard mesh for a parameter p on the given rank.
242
+ """
243
+ assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
244
+
245
+ if p.placements == (Shard(dim=0),):
246
+ # Case for FSDP
247
+ return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
+ elif p.placements == (Replicate(), Shard(dim=0)):
249
+ # Case for HSDP
250
+ for i, shard_mesh in enumerate(p.device_mesh.mesh):
251
+ if rank in shard_mesh:
252
+ return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
253
+ else:
254
+ raise ValueError(f"Unsupported placements ({p.placements}).")
255
+
256
  def init_state_and_assign_params(self, params, group):
257
  param_to_state = {}
258
  param_to_flops = {}
 
277
 
278
  round_robin = 0
279
  mesh = None
280
+ shard_mesh = None
281
+ process_group = None
282
  for p in ordered_params:
283
  if mesh is None:
284
  mesh = p.device_mesh
285
+ shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
286
  elif mesh != p.device_mesh:
287
  raise ValueError("All parameters must be on the same mesh.")
288
 
289
  param_to_state[id(p)] = _muon_state()
290
+ param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
291
+ param_to_state[id(p)].process_group = process_group
292
 
293
+ round_robin = (round_robin + 1) % len(shard_mesh)
294
 
295
  return param_to_state, ordered_params
296
 
 
390
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
  )
392
 
393
+ chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ
 
build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
 
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_1f13dae_dirty
3
- ops = torch.ops._optimizer_1f13dae_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_1f13dae_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_2dc97a1_dirty
3
+ ops = torch.ops._optimizer_2dc97a1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_2dc97a1_dirty::{op_name}"
build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:96c7e281f9634e3b252f720f4fea4f61490f2f1a1ef1280a3e259decb41c846f
3
  size 1824256
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a423eb4ab3a31c53a3326c71e34fa59fc661f8d432701e41a7de900a9c23e37c
3
  size 1824256
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
- from torch.distributed._tensor import DTensor, Replicate
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
 
54
 
55
  @torch.no_grad()
56
  def _gather(p, state, rank, comm_stream, none_grad):
57
  g = p.grad
58
- mesh = g.device_mesh
59
 
60
  if rank == state.worker_rank:
61
- gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
 
62
  else:
63
  gather_list = None
64
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
67
  g.to_local(),
68
  dst=state.worker_rank,
69
  gather_list=gather_list,
70
- group=mesh.get_group(),
71
  )
72
  if rank == state.worker_rank:
73
  if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
105
  @torch.no_grad()
106
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
- mesh = p.device_mesh
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
 
112
  if state.compute_event is None:
113
  raise RuntimeError("Compute event must be set before scatter.")
114
  comm_stream.wait_event(state.compute_event)
115
- scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
116
  else:
117
  scatter_list = None
118
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
121
  u,
122
  scatter_list=scatter_list,
123
  src=state.worker_rank,
124
- group=mesh.get_group(),
125
  )
126
  if rank == state.worker_rank:
127
  # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
129
  u = DTensor.from_local(
130
  u,
131
  placements=p.placements,
132
- device_mesh=mesh,
133
  )
134
  p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
235
  adjusted_lr = lr * adjusted_ratio
236
  return adjusted_lr
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def init_state_and_assign_params(self, params, group):
239
  param_to_state = {}
240
  param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
259
 
260
  round_robin = 0
261
  mesh = None
 
 
262
  for p in ordered_params:
263
  if mesh is None:
264
  mesh = p.device_mesh
265
- if mesh.ndim != 1:
266
- raise NotImplementedError(
267
- "Muon requires a 1D mesh for distributed training yet."
268
- )
269
  elif mesh != p.device_mesh:
270
  raise ValueError("All parameters must be on the same mesh.")
271
 
272
  param_to_state[id(p)] = _muon_state()
273
- param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
 
274
 
275
- round_robin = (round_robin + 1) % mesh.mesh.numel()
276
 
277
  return param_to_state, ordered_params
278
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
372
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
  )
374
 
375
- chunk_size = params[0].device_mesh.mesh.numel()
376
 
377
  # Wait grad update
378
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
53
+ process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
58
  g = p.grad
 
59
 
60
  if rank == state.worker_rank:
61
+ num_ranks = dist.get_world_size(group=state.process_group)
62
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
63
  else:
64
  gather_list = None
65
 
 
68
  g.to_local(),
69
  dst=state.worker_rank,
70
  gather_list=gather_list,
71
+ group=state.process_group,
72
  )
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
 
106
  @torch.no_grad()
107
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
  u = state.computed_u
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
112
+ num_ranks = dist.get_world_size(group=state.process_group)
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
116
+ scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
117
  else:
118
  scatter_list = None
119
 
 
122
  u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
+ group=state.process_group,
126
  )
127
  if rank == state.worker_rank:
128
  # Clear u to free memory
 
130
  u = DTensor.from_local(
131
  u,
132
  placements=p.placements,
133
+ device_mesh=p.device_mesh,
134
  )
135
  p.data.mul_(1 - lr * weight_decay)
136
  p.data.add_(u, alpha=-lr)
 
236
  adjusted_lr = lr * adjusted_ratio
237
  return adjusted_lr
238
 
239
+ def get_shard_mesh(self, p, rank):
240
+ """
241
+ Get the shard mesh for a parameter p on the given rank.
242
+ """
243
+ assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
244
+
245
+ if p.placements == (Shard(dim=0),):
246
+ # Case for FSDP
247
+ return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
+ elif p.placements == (Replicate(), Shard(dim=0)):
249
+ # Case for HSDP
250
+ for i, shard_mesh in enumerate(p.device_mesh.mesh):
251
+ if rank in shard_mesh:
252
+ return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
253
+ else:
254
+ raise ValueError(f"Unsupported placements ({p.placements}).")
255
+
256
  def init_state_and_assign_params(self, params, group):
257
  param_to_state = {}
258
  param_to_flops = {}
 
277
 
278
  round_robin = 0
279
  mesh = None
280
+ shard_mesh = None
281
+ process_group = None
282
  for p in ordered_params:
283
  if mesh is None:
284
  mesh = p.device_mesh
285
+ shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
286
  elif mesh != p.device_mesh:
287
  raise ValueError("All parameters must be on the same mesh.")
288
 
289
  param_to_state[id(p)] = _muon_state()
290
+ param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
291
+ param_to_state[id(p)].process_group = process_group
292
 
293
+ round_robin = (round_robin + 1) % len(shard_mesh)
294
 
295
  return param_to_state, ordered_params
296
 
 
390
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
  )
392
 
393
+ chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ
 
build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
 
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_1f13dae_dirty
3
- ops = torch.ops._optimizer_1f13dae_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_1f13dae_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_2dc97a1_dirty
3
+ ops = torch.ops._optimizer_2dc97a1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_2dc97a1_dirty::{op_name}"
build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:89fea7bfad71c806bc10bf2dc6aa66a6e154c09fc418498b1cab7f48a83432d4
3
  size 1883352
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86d98863cc7ef0b271808b0ef7b1082603cfb5a76986481df37431527aaaf27b
3
  size 1883352
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
- from torch.distributed._tensor import DTensor, Replicate
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
 
54
 
55
  @torch.no_grad()
56
  def _gather(p, state, rank, comm_stream, none_grad):
57
  g = p.grad
58
- mesh = g.device_mesh
59
 
60
  if rank == state.worker_rank:
61
- gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
 
62
  else:
63
  gather_list = None
64
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
67
  g.to_local(),
68
  dst=state.worker_rank,
69
  gather_list=gather_list,
70
- group=mesh.get_group(),
71
  )
72
  if rank == state.worker_rank:
73
  if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
105
  @torch.no_grad()
106
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
- mesh = p.device_mesh
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
 
112
  if state.compute_event is None:
113
  raise RuntimeError("Compute event must be set before scatter.")
114
  comm_stream.wait_event(state.compute_event)
115
- scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
116
  else:
117
  scatter_list = None
118
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
121
  u,
122
  scatter_list=scatter_list,
123
  src=state.worker_rank,
124
- group=mesh.get_group(),
125
  )
126
  if rank == state.worker_rank:
127
  # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
129
  u = DTensor.from_local(
130
  u,
131
  placements=p.placements,
132
- device_mesh=mesh,
133
  )
134
  p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
235
  adjusted_lr = lr * adjusted_ratio
236
  return adjusted_lr
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def init_state_and_assign_params(self, params, group):
239
  param_to_state = {}
240
  param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
259
 
260
  round_robin = 0
261
  mesh = None
 
 
262
  for p in ordered_params:
263
  if mesh is None:
264
  mesh = p.device_mesh
265
- if mesh.ndim != 1:
266
- raise NotImplementedError(
267
- "Muon requires a 1D mesh for distributed training yet."
268
- )
269
  elif mesh != p.device_mesh:
270
  raise ValueError("All parameters must be on the same mesh.")
271
 
272
  param_to_state[id(p)] = _muon_state()
273
- param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
 
274
 
275
- round_robin = (round_robin + 1) % mesh.mesh.numel()
276
 
277
  return param_to_state, ordered_params
278
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
372
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
  )
374
 
375
- chunk_size = params[0].device_mesh.mesh.numel()
376
 
377
  # Wait grad update
378
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
53
+ process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
58
  g = p.grad
 
59
 
60
  if rank == state.worker_rank:
61
+ num_ranks = dist.get_world_size(group=state.process_group)
62
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
63
  else:
64
  gather_list = None
65
 
 
68
  g.to_local(),
69
  dst=state.worker_rank,
70
  gather_list=gather_list,
71
+ group=state.process_group,
72
  )
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
 
106
  @torch.no_grad()
107
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
  u = state.computed_u
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
112
+ num_ranks = dist.get_world_size(group=state.process_group)
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
116
+ scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
117
  else:
118
  scatter_list = None
119
 
 
122
  u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
+ group=state.process_group,
126
  )
127
  if rank == state.worker_rank:
128
  # Clear u to free memory
 
130
  u = DTensor.from_local(
131
  u,
132
  placements=p.placements,
133
+ device_mesh=p.device_mesh,
134
  )
135
  p.data.mul_(1 - lr * weight_decay)
136
  p.data.add_(u, alpha=-lr)
 
236
  adjusted_lr = lr * adjusted_ratio
237
  return adjusted_lr
238
 
239
+ def get_shard_mesh(self, p, rank):
240
+ """
241
+ Get the shard mesh for a parameter p on the given rank.
242
+ """
243
+ assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
244
+
245
+ if p.placements == (Shard(dim=0),):
246
+ # Case for FSDP
247
+ return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
+ elif p.placements == (Replicate(), Shard(dim=0)):
249
+ # Case for HSDP
250
+ for i, shard_mesh in enumerate(p.device_mesh.mesh):
251
+ if rank in shard_mesh:
252
+ return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
253
+ else:
254
+ raise ValueError(f"Unsupported placements ({p.placements}).")
255
+
256
  def init_state_and_assign_params(self, params, group):
257
  param_to_state = {}
258
  param_to_flops = {}
 
277
 
278
  round_robin = 0
279
  mesh = None
280
+ shard_mesh = None
281
+ process_group = None
282
  for p in ordered_params:
283
  if mesh is None:
284
  mesh = p.device_mesh
285
+ shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
286
  elif mesh != p.device_mesh:
287
  raise ValueError("All parameters must be on the same mesh.")
288
 
289
  param_to_state[id(p)] = _muon_state()
290
+ param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
291
+ param_to_state[id(p)].process_group = process_group
292
 
293
+ round_robin = (round_robin + 1) % len(shard_mesh)
294
 
295
  return param_to_state, ordered_params
296
 
 
390
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
  )
392
 
393
+ chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ
 
build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
 
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_1f13dae_dirty
3
- ops = torch.ops._optimizer_1f13dae_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_1f13dae_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_2dc97a1_dirty
3
+ ops = torch.ops._optimizer_2dc97a1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_2dc97a1_dirty::{op_name}"
build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:046a45fae81c2b7d79ff2237a1d26277f4883ef8a8b87a3980bf06d1182711b1
3
  size 1883352
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8daaad69e6958850f848fab60c9acb938c3a5e54e3ec34a1bec03a3d32653cb
3
  size 1883352
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
- from torch.distributed._tensor import DTensor, Replicate
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
 
54
 
55
  @torch.no_grad()
56
  def _gather(p, state, rank, comm_stream, none_grad):
57
  g = p.grad
58
- mesh = g.device_mesh
59
 
60
  if rank == state.worker_rank:
61
- gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
 
62
  else:
63
  gather_list = None
64
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
67
  g.to_local(),
68
  dst=state.worker_rank,
69
  gather_list=gather_list,
70
- group=mesh.get_group(),
71
  )
72
  if rank == state.worker_rank:
73
  if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
105
  @torch.no_grad()
106
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
- mesh = p.device_mesh
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
 
112
  if state.compute_event is None:
113
  raise RuntimeError("Compute event must be set before scatter.")
114
  comm_stream.wait_event(state.compute_event)
115
- scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
116
  else:
117
  scatter_list = None
118
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
121
  u,
122
  scatter_list=scatter_list,
123
  src=state.worker_rank,
124
- group=mesh.get_group(),
125
  )
126
  if rank == state.worker_rank:
127
  # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
129
  u = DTensor.from_local(
130
  u,
131
  placements=p.placements,
132
- device_mesh=mesh,
133
  )
134
  p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
235
  adjusted_lr = lr * adjusted_ratio
236
  return adjusted_lr
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def init_state_and_assign_params(self, params, group):
239
  param_to_state = {}
240
  param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
259
 
260
  round_robin = 0
261
  mesh = None
 
 
262
  for p in ordered_params:
263
  if mesh is None:
264
  mesh = p.device_mesh
265
- if mesh.ndim != 1:
266
- raise NotImplementedError(
267
- "Muon requires a 1D mesh for distributed training yet."
268
- )
269
  elif mesh != p.device_mesh:
270
  raise ValueError("All parameters must be on the same mesh.")
271
 
272
  param_to_state[id(p)] = _muon_state()
273
- param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
 
274
 
275
- round_robin = (round_robin + 1) % mesh.mesh.numel()
276
 
277
  return param_to_state, ordered_params
278
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
372
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
  )
374
 
375
- chunk_size = params[0].device_mesh.mesh.numel()
376
 
377
  # Wait grad update
378
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
53
+ process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
58
  g = p.grad
 
59
 
60
  if rank == state.worker_rank:
61
+ num_ranks = dist.get_world_size(group=state.process_group)
62
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
63
  else:
64
  gather_list = None
65
 
 
68
  g.to_local(),
69
  dst=state.worker_rank,
70
  gather_list=gather_list,
71
+ group=state.process_group,
72
  )
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
 
106
  @torch.no_grad()
107
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
  u = state.computed_u
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
112
+ num_ranks = dist.get_world_size(group=state.process_group)
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
116
+ scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
117
  else:
118
  scatter_list = None
119
 
 
122
  u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
+ group=state.process_group,
126
  )
127
  if rank == state.worker_rank:
128
  # Clear u to free memory
 
130
  u = DTensor.from_local(
131
  u,
132
  placements=p.placements,
133
+ device_mesh=p.device_mesh,
134
  )
135
  p.data.mul_(1 - lr * weight_decay)
136
  p.data.add_(u, alpha=-lr)
 
236
  adjusted_lr = lr * adjusted_ratio
237
  return adjusted_lr
238
 
239
+ def get_shard_mesh(self, p, rank):
240
+ """
241
+ Get the shard mesh for a parameter p on the given rank.
242
+ """
243
+ assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
244
+
245
+ if p.placements == (Shard(dim=0),):
246
+ # Case for FSDP
247
+ return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
+ elif p.placements == (Replicate(), Shard(dim=0)):
249
+ # Case for HSDP
250
+ for i, shard_mesh in enumerate(p.device_mesh.mesh):
251
+ if rank in shard_mesh:
252
+ return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
253
+ else:
254
+ raise ValueError(f"Unsupported placements ({p.placements}).")
255
+
256
  def init_state_and_assign_params(self, params, group):
257
  param_to_state = {}
258
  param_to_flops = {}
 
277
 
278
  round_robin = 0
279
  mesh = None
280
+ shard_mesh = None
281
+ process_group = None
282
  for p in ordered_params:
283
  if mesh is None:
284
  mesh = p.device_mesh
285
+ shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
286
  elif mesh != p.device_mesh:
287
  raise ValueError("All parameters must be on the same mesh.")
288
 
289
  param_to_state[id(p)] = _muon_state()
290
+ param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
291
+ param_to_state[id(p)].process_group = process_group
292
 
293
+ round_robin = (round_robin + 1) % len(shard_mesh)
294
 
295
  return param_to_state, ordered_params
296
 
 
390
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
  )
392
 
393
+ chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ
 
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
 
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_1f13dae_dirty
3
- ops = torch.ops._optimizer_1f13dae_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_1f13dae_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_2dc97a1_dirty
3
+ ops = torch.ops._optimizer_2dc97a1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_2dc97a1_dirty::{op_name}"
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0805952950efdbe79c378ca84ae62b77d2d11cd2ba680c8ffccfd79301489ac5
3
  size 1750000
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76910ba81e2c95c83207118725c4379db636346c4ccf05010e2ee00c41dff1ce
3
  size 1750000
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
- from torch.distributed._tensor import DTensor, Replicate
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
 
54
 
55
  @torch.no_grad()
56
  def _gather(p, state, rank, comm_stream, none_grad):
57
  g = p.grad
58
- mesh = g.device_mesh
59
 
60
  if rank == state.worker_rank:
61
- gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
 
62
  else:
63
  gather_list = None
64
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
67
  g.to_local(),
68
  dst=state.worker_rank,
69
  gather_list=gather_list,
70
- group=mesh.get_group(),
71
  )
72
  if rank == state.worker_rank:
73
  if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
105
  @torch.no_grad()
106
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
- mesh = p.device_mesh
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
 
112
  if state.compute_event is None:
113
  raise RuntimeError("Compute event must be set before scatter.")
114
  comm_stream.wait_event(state.compute_event)
115
- scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
116
  else:
117
  scatter_list = None
118
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
121
  u,
122
  scatter_list=scatter_list,
123
  src=state.worker_rank,
124
- group=mesh.get_group(),
125
  )
126
  if rank == state.worker_rank:
127
  # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
129
  u = DTensor.from_local(
130
  u,
131
  placements=p.placements,
132
- device_mesh=mesh,
133
  )
134
  p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
235
  adjusted_lr = lr * adjusted_ratio
236
  return adjusted_lr
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def init_state_and_assign_params(self, params, group):
239
  param_to_state = {}
240
  param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
259
 
260
  round_robin = 0
261
  mesh = None
 
 
262
  for p in ordered_params:
263
  if mesh is None:
264
  mesh = p.device_mesh
265
- if mesh.ndim != 1:
266
- raise NotImplementedError(
267
- "Muon requires a 1D mesh for distributed training yet."
268
- )
269
  elif mesh != p.device_mesh:
270
  raise ValueError("All parameters must be on the same mesh.")
271
 
272
  param_to_state[id(p)] = _muon_state()
273
- param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
 
274
 
275
- round_robin = (round_robin + 1) % mesh.mesh.numel()
276
 
277
  return param_to_state, ordered_params
278
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
372
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
  )
374
 
375
- chunk_size = params[0].device_mesh.mesh.numel()
376
 
377
  # Wait grad update
378
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
53
+ process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
58
  g = p.grad
 
59
 
60
  if rank == state.worker_rank:
61
+ num_ranks = dist.get_world_size(group=state.process_group)
62
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
63
  else:
64
  gather_list = None
65
 
 
68
  g.to_local(),
69
  dst=state.worker_rank,
70
  gather_list=gather_list,
71
+ group=state.process_group,
72
  )
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
 
106
  @torch.no_grad()
107
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
  u = state.computed_u
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
112
+ num_ranks = dist.get_world_size(group=state.process_group)
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
116
+ scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
117
  else:
118
  scatter_list = None
119
 
 
122
  u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
+ group=state.process_group,
126
  )
127
  if rank == state.worker_rank:
128
  # Clear u to free memory
 
130
  u = DTensor.from_local(
131
  u,
132
  placements=p.placements,
133
+ device_mesh=p.device_mesh,
134
  )
135
  p.data.mul_(1 - lr * weight_decay)
136
  p.data.add_(u, alpha=-lr)
 
236
  adjusted_lr = lr * adjusted_ratio
237
  return adjusted_lr
238
 
239
+ def get_shard_mesh(self, p, rank):
240
+ """
241
+ Get the shard mesh for a parameter p on the given rank.
242
+ """
243
+ assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
244
+
245
+ if p.placements == (Shard(dim=0),):
246
+ # Case for FSDP
247
+ return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
+ elif p.placements == (Replicate(), Shard(dim=0)):
249
+ # Case for HSDP
250
+ for i, shard_mesh in enumerate(p.device_mesh.mesh):
251
+ if rank in shard_mesh:
252
+ return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
253
+ else:
254
+ raise ValueError(f"Unsupported placements ({p.placements}).")
255
+
256
  def init_state_and_assign_params(self, params, group):
257
  param_to_state = {}
258
  param_to_flops = {}
 
277
 
278
  round_robin = 0
279
  mesh = None
280
+ shard_mesh = None
281
+ process_group = None
282
  for p in ordered_params:
283
  if mesh is None:
284
  mesh = p.device_mesh
285
+ shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
286
  elif mesh != p.device_mesh:
287
  raise ValueError("All parameters must be on the same mesh.")
288
 
289
  param_to_state[id(p)] = _muon_state()
290
+ param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
291
+ param_to_state[id(p)].process_group = process_group
292
 
293
+ round_robin = (round_robin + 1) % len(shard_mesh)
294
 
295
  return param_to_state, ordered_params
296
 
 
390
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
  )
392
 
393
+ chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ
 
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
 
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_1f13dae_dirty
3
- ops = torch.ops._optimizer_1f13dae_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_1f13dae_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_2dc97a1_dirty
3
+ ops = torch.ops._optimizer_2dc97a1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_2dc97a1_dirty::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:af91f4eec9fc14d66f3db4e120d4913a0e62102c76b9b8cd9c25d8af427be290
3
  size 1750088
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd0a35a6f846a075a8f4561cfc66ef17c6358dd4a0062e63057b02625d9d6af7
3
  size 1750088
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
- from torch.distributed._tensor import DTensor, Replicate
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
 
54
 
55
  @torch.no_grad()
56
  def _gather(p, state, rank, comm_stream, none_grad):
57
  g = p.grad
58
- mesh = g.device_mesh
59
 
60
  if rank == state.worker_rank:
61
- gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
 
62
  else:
63
  gather_list = None
64
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
67
  g.to_local(),
68
  dst=state.worker_rank,
69
  gather_list=gather_list,
70
- group=mesh.get_group(),
71
  )
72
  if rank == state.worker_rank:
73
  if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
105
  @torch.no_grad()
106
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
- mesh = p.device_mesh
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
 
112
  if state.compute_event is None:
113
  raise RuntimeError("Compute event must be set before scatter.")
114
  comm_stream.wait_event(state.compute_event)
115
- scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
116
  else:
117
  scatter_list = None
118
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
121
  u,
122
  scatter_list=scatter_list,
123
  src=state.worker_rank,
124
- group=mesh.get_group(),
125
  )
126
  if rank == state.worker_rank:
127
  # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
129
  u = DTensor.from_local(
130
  u,
131
  placements=p.placements,
132
- device_mesh=mesh,
133
  )
134
  p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
235
  adjusted_lr = lr * adjusted_ratio
236
  return adjusted_lr
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def init_state_and_assign_params(self, params, group):
239
  param_to_state = {}
240
  param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
259
 
260
  round_robin = 0
261
  mesh = None
 
 
262
  for p in ordered_params:
263
  if mesh is None:
264
  mesh = p.device_mesh
265
- if mesh.ndim != 1:
266
- raise NotImplementedError(
267
- "Muon requires a 1D mesh for distributed training yet."
268
- )
269
  elif mesh != p.device_mesh:
270
  raise ValueError("All parameters must be on the same mesh.")
271
 
272
  param_to_state[id(p)] = _muon_state()
273
- param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
 
274
 
275
- round_robin = (round_robin + 1) % mesh.mesh.numel()
276
 
277
  return param_to_state, ordered_params
278
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
372
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
  )
374
 
375
- chunk_size = params[0].device_mesh.mesh.numel()
376
 
377
  # Wait grad update
378
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
53
+ process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
58
  g = p.grad
 
59
 
60
  if rank == state.worker_rank:
61
+ num_ranks = dist.get_world_size(group=state.process_group)
62
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
63
  else:
64
  gather_list = None
65
 
 
68
  g.to_local(),
69
  dst=state.worker_rank,
70
  gather_list=gather_list,
71
+ group=state.process_group,
72
  )
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
 
106
  @torch.no_grad()
107
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
  u = state.computed_u
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
112
+ num_ranks = dist.get_world_size(group=state.process_group)
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
116
+ scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
117
  else:
118
  scatter_list = None
119
 
 
122
  u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
+ group=state.process_group,
126
  )
127
  if rank == state.worker_rank:
128
  # Clear u to free memory
 
130
  u = DTensor.from_local(
131
  u,
132
  placements=p.placements,
133
+ device_mesh=p.device_mesh,
134
  )
135
  p.data.mul_(1 - lr * weight_decay)
136
  p.data.add_(u, alpha=-lr)
 
236
  adjusted_lr = lr * adjusted_ratio
237
  return adjusted_lr
238
 
239
+ def get_shard_mesh(self, p, rank):
240
+ """
241
+ Get the shard mesh for a parameter p on the given rank.
242
+ """
243
+ assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
244
+
245
+ if p.placements == (Shard(dim=0),):
246
+ # Case for FSDP
247
+ return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
+ elif p.placements == (Replicate(), Shard(dim=0)):
249
+ # Case for HSDP
250
+ for i, shard_mesh in enumerate(p.device_mesh.mesh):
251
+ if rank in shard_mesh:
252
+ return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
253
+ else:
254
+ raise ValueError(f"Unsupported placements ({p.placements}).")
255
+
256
  def init_state_and_assign_params(self, params, group):
257
  param_to_state = {}
258
  param_to_flops = {}
 
277
 
278
  round_robin = 0
279
  mesh = None
280
+ shard_mesh = None
281
+ process_group = None
282
  for p in ordered_params:
283
  if mesh is None:
284
  mesh = p.device_mesh
285
+ shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
286
  elif mesh != p.device_mesh:
287
  raise ValueError("All parameters must be on the same mesh.")
288
 
289
  param_to_state[id(p)] = _muon_state()
290
+ param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
291
+ param_to_state[id(p)].process_group = process_group
292
 
293
+ round_robin = (round_robin + 1) % len(shard_mesh)
294
 
295
  return param_to_state, ordered_params
296
 
 
390
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
  )
392
 
393
+ chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
torch-ext/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
- from torch.distributed._tensor import DTensor, Replicate
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
 
54
 
55
  @torch.no_grad()
56
  def _gather(p, state, rank, comm_stream, none_grad):
57
  g = p.grad
58
- mesh = g.device_mesh
59
 
60
  if rank == state.worker_rank:
61
- gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
 
62
  else:
63
  gather_list = None
64
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
67
  g.to_local(),
68
  dst=state.worker_rank,
69
  gather_list=gather_list,
70
- group=mesh.get_group(),
71
  )
72
  if rank == state.worker_rank:
73
  if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
105
  @torch.no_grad()
106
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
- mesh = p.device_mesh
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
 
112
  if state.compute_event is None:
113
  raise RuntimeError("Compute event must be set before scatter.")
114
  comm_stream.wait_event(state.compute_event)
115
- scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
116
  else:
117
  scatter_list = None
118
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
121
  u,
122
  scatter_list=scatter_list,
123
  src=state.worker_rank,
124
- group=mesh.get_group(),
125
  )
126
  if rank == state.worker_rank:
127
  # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
129
  u = DTensor.from_local(
130
  u,
131
  placements=p.placements,
132
- device_mesh=mesh,
133
  )
134
  p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
235
  adjusted_lr = lr * adjusted_ratio
236
  return adjusted_lr
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def init_state_and_assign_params(self, params, group):
239
  param_to_state = {}
240
  param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
259
 
260
  round_robin = 0
261
  mesh = None
 
 
262
  for p in ordered_params:
263
  if mesh is None:
264
  mesh = p.device_mesh
265
- if mesh.ndim != 1:
266
- raise NotImplementedError(
267
- "Muon requires a 1D mesh for distributed training yet."
268
- )
269
  elif mesh != p.device_mesh:
270
  raise ValueError("All parameters must be on the same mesh.")
271
 
272
  param_to_state[id(p)] = _muon_state()
273
- param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
 
274
 
275
- round_robin = (round_robin + 1) % mesh.mesh.numel()
276
 
277
  return param_to_state, ordered_params
278
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
372
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
  )
374
 
375
- chunk_size = params[0].device_mesh.mesh.numel()
376
 
377
  # Wait grad update
378
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
53
+ process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
58
  g = p.grad
 
59
 
60
  if rank == state.worker_rank:
61
+ num_ranks = dist.get_world_size(group=state.process_group)
62
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
63
  else:
64
  gather_list = None
65
 
 
68
  g.to_local(),
69
  dst=state.worker_rank,
70
  gather_list=gather_list,
71
+ group=state.process_group,
72
  )
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
 
106
  @torch.no_grad()
107
  def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
  u = state.computed_u
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
112
+ num_ranks = dist.get_world_size(group=state.process_group)
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
116
+ scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
117
  else:
118
  scatter_list = None
119
 
 
122
  u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
+ group=state.process_group,
126
  )
127
  if rank == state.worker_rank:
128
  # Clear u to free memory
 
130
  u = DTensor.from_local(
131
  u,
132
  placements=p.placements,
133
+ device_mesh=p.device_mesh,
134
  )
135
  p.data.mul_(1 - lr * weight_decay)
136
  p.data.add_(u, alpha=-lr)
 
236
  adjusted_lr = lr * adjusted_ratio
237
  return adjusted_lr
238
 
239
+ def get_shard_mesh(self, p, rank):
240
+ """
241
+ Get the shard mesh for a parameter p on the given rank.
242
+ """
243
+ assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
244
+
245
+ if p.placements == (Shard(dim=0),):
246
+ # Case for FSDP
247
+ return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
+ elif p.placements == (Replicate(), Shard(dim=0)):
249
+ # Case for HSDP
250
+ for i, shard_mesh in enumerate(p.device_mesh.mesh):
251
+ if rank in shard_mesh:
252
+ return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
253
+ else:
254
+ raise ValueError(f"Unsupported placements ({p.placements}).")
255
+
256
  def init_state_and_assign_params(self, params, group):
257
  param_to_state = {}
258
  param_to_flops = {}
 
277
 
278
  round_robin = 0
279
  mesh = None
280
+ shard_mesh = None
281
+ process_group = None
282
  for p in ordered_params:
283
  if mesh is None:
284
  mesh = p.device_mesh
285
+ shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
286
  elif mesh != p.device_mesh:
287
  raise ValueError("All parameters must be on the same mesh.")
288
 
289
  param_to_state[id(p)] = _muon_state()
290
+ param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
291
+ param_to_state[id(p)].process_group = process_group
292
 
293
+ round_robin = (round_robin + 1) % len(shard_mesh)
294
 
295
  return param_to_state, ordered_params
296
 
 
390
  p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
  )
392
 
393
+ chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())