import math
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from specforge.distributed import get_tp_group, shard_tensor


class VocabParallelEmbedding(nn.Module):
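    """Embedding layer whose vocabulary is sharded across the tensor-parallel group.

    The constructor mirrors ``nn.Embedding``. Each rank owns a contiguous slice of
    ``ceil(num_embeddings / tp_size)`` rows of the embedding table. During the
    forward pass, token ids outside the local slice are masked, their rows in the
    local lookup are zeroed, and the partial outputs are summed with an all-reduce
    so every rank ends up with the full embedding output.

    For example, with ``num_embeddings=50257`` and ``tp_size=4`` each shard holds
    ``ceil(50257 / 4) = 12565`` rows, so the padded vocabulary size is ``50260``
    and a loaded checkpoint weight is padded with 3 extra rows before sharding.
    """
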
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        if padding_idx is not None:
            if padding_idx > 0:
                assert (
                    padding_idx < self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert (
                    padding_idx >= -self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
                padding_idx = self.num_embeddings + padding_idx

        # tensor-parallel group information
        self.tp_group = get_tp_group()
        self.tp_rank = dist.get_rank(self.tp_group)
        self.tp_size = dist.get_world_size(self.tp_group)

        # each rank owns a contiguous slice of ceil(num_embeddings / tp_size) rows;
        # the global vocabulary is padded up to a multiple of tp_size
        self.num_embeddings_per_shard = math.ceil(num_embeddings / self.tp_size)
        self.padded_num_embeddings = self.num_embeddings_per_shard * self.tp_size
        self.vocab_start_index = self.tp_rank * self.num_embeddings_per_shard
        self.vocab_end_index = min(
            self.vocab_start_index + self.num_embeddings_per_shard,
            self.num_embeddings,
        )

        # keep padding_idx only if it falls inside this rank's shard, re-indexed
        # relative to the shard's start
        if (
            padding_idx is not None
            and padding_idx >= self.vocab_start_index
            and padding_idx < self.vocab_end_index
        ):
            self.padding_idx = padding_idx - self.vocab_start_index
        else:
            self.padding_idx = None

        self.weight = nn.Parameter(
            torch.empty(
                (self.num_embeddings_per_shard, self.embedding_dim), **factory_kwargs
            ),
            requires_grad=True,
        )
        self.reset_parameters()
        self._register_load_state_dict_pre_hook(self.shard_state_dict)

    def shard_state_dict(self, state_dict, prefix, *args):
        # load_state_dict pre-hook: replace the full-vocab weight in the incoming
        # state dict with this rank's shard
        key = prefix + "weight"
        if key in state_dict:
            value = state_dict[key]

            # pad rows so the vocab size divides evenly by tp_size
            if value.shape[0] % self.tp_size != 0:
                padding_size = self.tp_size - value.shape[0] % self.tp_size
                value = F.pad(value, (0, 0, 0, padding_size))
            state_dict[key] = shard_tensor(value, self.tp_group, 0)

    def reset_parameters(self) -> None:
        torch.nn.init.normal_(self.weight)
        self._fill_padding_idx_with_zero()

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def generate_mask(self, input_):
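        """Return a boolean mask of ``input_`` ids that fall in this rank's vocab shard."""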
        mask = (input_ >= self.vocab_start_index) & (input_ < self.vocab_end_index)
        return mask

    def forward(self, input_):
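        """Embed ``input_`` using only this rank's shard of the vocabulary.

        When ``tp_size > 1``, out-of-shard token ids are looked up as row 0,
        their rows in the partial output are zeroed, and the partial outputs
        are summed across the tensor-parallel group, so every rank returns
        the complete embedding.
        """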
        if self.tp_size > 1:
            # shift global token ids into this shard's local index space and
            # send out-of-range ids to row 0 (their output is zeroed below)
            mask = self.generate_mask(input_)
            masked_input = input_ - self.vocab_start_index
            masked_input[~mask] = 0
        else:
            masked_input = input_

        output_parallel = F.embedding(
            masked_input,
            self.weight,
            padding_idx=self.padding_idx,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
        )

        if self.tp_size > 1:
            # zero the rows of tokens owned by other ranks, then sum the
            # partial results across the tensor-parallel group
            output_parallel[~mask] = 0
            dist.all_reduce(output_parallel, op=dist.ReduceOp.SUM, group=self.tp_group)
        return output_parallel

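if __name__ == "__main__":
    # Illustrative, single-process sketch of the vocab-parallel lookup math
    # (no distributed setup required, so it does not exercise the class above):
    # simulate tp_size=2 by splitting a reference embedding table into two
    # shards and reproduce a dense lookup from the masked partial lookups.
    # All names and sizes below are arbitrary and for demonstration only.
    torch.manual_seed(0)
    vocab_size, hidden_dim, tp_size = 10, 4, 2
    full_weight = torch.randn(vocab_size, hidden_dim)
    token_ids = torch.randint(0, vocab_size, (3, 5))

    per_shard = math.ceil(vocab_size / tp_size)
    output = torch.zeros(*token_ids.shape, hidden_dim)
    for rank in range(tp_size):
        start = rank * per_shard
        end = min(start + per_shard, vocab_size)
        shard = full_weight[start:end]
        mask = (token_ids >= start) & (token_ids < end)
        local_ids = token_ids - start
        local_ids[~mask] = 0  # out-of-shard tokens hit row 0 of the shard
        partial = F.embedding(local_ids, shard)
        partial[~mask] = 0  # zero them so they do not pollute the sum
        output += partial  # stands in for dist.all_reduce across ranks

    assert torch.equal(output, F.embedding(token_ids, full_weight))
    print("vocab-parallel lookup matches dense lookup:", output.shape)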