# Uploaded by Lekr0 via the upload-large-folder tool (commit 62dca4c, verified).
# (The three lines above were HuggingFace web-page residue, not valid Python;
# converted to a comment so the file parses.)
import math
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from specforge.distributed import get_tp_group, shard_tensor
class VocabParallelEmbedding(nn.Module):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
if padding_idx is not None:
if padding_idx > 0:
assert (
padding_idx < self.num_embeddings
), "Padding_idx must be within num_embeddings"
elif padding_idx < 0:
assert (
padding_idx >= -self.num_embeddings
), "Padding_idx must be within num_embeddings"
padding_idx = self.num_embeddings + padding_idx
# tp-realted
self.tp_group = get_tp_group()
self.tp_rank = dist.get_rank(self.tp_group)
self.tp_size = dist.get_world_size(self.tp_group)
# deal with the case where the embedding is not divisible by the TP size
self.num_embeddings_per_shard = math.ceil(num_embeddings / self.tp_size)
self.padded_num_embeddings = (
self.num_embeddings_per_shard * self.tp_size - self.num_embeddings
)
self.vocab_start_index = self.tp_rank * self.num_embeddings_per_shard
self.vocab_end_index = min(
self.vocab_start_index + self.num_embeddings_per_shard,
self.num_embeddings,
)
if (
padding_idx is not None
and padding_idx >= self.vocab_start_index
and padding_idx < self.vocab_end_index
):
self.padding_idx = padding_idx - self.vocab_start_index
else:
self.padding_idx = None
self.weight = nn.Parameter(
torch.empty(
(self.num_embeddings_per_shard, self.embedding_dim), **factory_kwargs
),
requires_grad=True,
)
self.reset_parameters()
self._register_load_state_dict_pre_hook(self.shard_state_dict)
def shard_state_dict(self, state_dict, *args):
if "weight" in state_dict:
value = state_dict["weight"]
# pad this value if it is not divisible by the TP size
if value.shape[0] % self.tp_size != 0:
padding_size = self.tp_size - value.shape[0] % self.tp_size
value = F.pad(value, (0, 0, 0, padding_size))
state_dict["weight"] = shard_tensor(value, self.tp_group, 0)
def reset_parameters(self) -> None:
torch.nn.init.normal_(self.weight)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def generate_mask(self, input_):
# generate the mask for the vocab which is only owned by the current rank
mask = (input_ >= self.vocab_start_index) & (input_ < self.vocab_end_index)
return mask
def forward(self, input_):
if self.tp_size > 1:
# Build the mask.
mask = self.generate_mask(input_)
masked_input = input_ - self.vocab_start_index
masked_input[~mask] = 0
else:
masked_input = input_
output_parallel = F.embedding(
masked_input,
self.weight,
padding_idx=self.padding_idx,
max_norm=self.max_norm,
norm_type=self.norm_type,
scale_grad_by_freq=self.scale_grad_by_freq,
sparse=self.sparse,
)
# Mask the output embedding.
if self.tp_size > 1:
output_parallel[~mask] = 0
# Reduce across all the model parallel GPUs.
dist.all_reduce(output_parallel, op=dist.ReduceOp.SUM, group=self.tp_group)
output = output_parallel
else:
output = output_parallel
return output