|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import collections |
|
import io |
|
import pickle |
|
from typing import Any |
|
|
|
import torch |
|
import torch.distributed as dist |
|
|
|
|
|
|
|
def broadcast_object(
    obj: Any,
    src_rank: int,
    group: object = None,
    device: torch.device = torch.device("cpu"),
) -> Any:
    r"""
    Broadcasts a picklable object from ``src_rank`` to every rank in ``group``.

    The object is serialized with :func:`torch.save`; its byte length is
    broadcast first so receivers can allocate an exactly-sized buffer, then
    the payload itself is broadcast. Called on the source rank this sends
    ``obj`` and returns it unchanged; on every other rank it receives and
    returns the deserialized copy.

    Arguments:
        obj: object to broadcast; only used if called on the source rank.
        src_rank (int): source rank.
        group (``ProcessGroup``, optional): group used for the broadcast.
            ``None`` (the default) selects the default (WORLD) process group.
            Resolved at call time rather than in the signature because
            ``dist.group.WORLD`` is ``None`` until ``init_process_group``
            has run, so an import-time default could capture a stale value.
        device (``torch.device``, optional): device to send from or receive
            to (default: ``torch.device("cpu")``).

    Returns:
        The broadcasted object.

    .. warning::
        Deserialization goes through pickle (``weights_only=False``); only
        use this between trusted peers.
    """
    if group is None:
        group = dist.group.WORLD

    if dist.get_rank() == src_rank:
        # Sender: serialize, then broadcast length before payload so the
        # receivers know how many bytes to expect.
        buffer = io.BytesIO()
        torch.save(obj, buffer, pickle_protocol=pickle.HIGHEST_PROTOCOL)
        data = bytearray(buffer.getbuffer())
        length_tensor = torch.tensor([len(data)], dtype=torch.long, device=device)
        # frombuffer shares memory with `data` (kept alive for the call),
        # avoiding the extra copy the legacy ByteTensor(...).to(device) made.
        data_send_tensor = torch.frombuffer(data, dtype=torch.uint8).to(device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
    else:
        # Receiver: get the length first, then the payload into a buffer of
        # exactly that size.
        length_tensor = torch.tensor([0], dtype=torch.long, device=device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=device)
        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
        # NOTE: full pickle deserialization — see the docstring warning.
        obj = torch.load(buffer, map_location=device, weights_only=False)
    return obj
|
|
|
|
|
def _recursive_copy_to_device( |
|
value: Any, |
|
non_blocking: bool, |
|
device: torch.device, |
|
) -> Any: |
|
r""" |
|
Recursively searches lists, tuples, dicts and copies tensors to device if possible. |
|
|
|
Non-tensor values are passed as-is in the result. |
|
|
|
.. note: These are all copies, so if there are two objects that reference |
|
the same object, then after this call, there will be two different objects |
|
referenced on the device. |
|
""" |
|
if isinstance(value, torch.Tensor): |
|
return value.to(device, non_blocking=non_blocking) |
|
|
|
if isinstance(value, (list, tuple)): |
|
values = [_recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for val in value] |
|
return values if isinstance(value, list) else tuple(values) |
|
|
|
if isinstance(value, collections.abc.Mapping): |
|
return { |
|
key: _recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for key, val in value.items() |
|
} |
|
|
|
return value |
|
|