import torch
import torch.nn as nn


class PartiallyFrozenEmbedding(nn.Module):
    """Wrap an existing `nn.Embedding` module and split the embedding into:
    - A frozen embedding for indices [0..freeze_until_idx-1].
    - A trainable embedding for indices [freeze_until_idx..vocab_size-1].

    This should work seamlessly with both ZeRO-2 and ZeRO-3.
    """

    def __init__(self, original_embedding: nn.Embedding, freeze_until_idx: int):
        """
        :param original_embedding: An instance of nn.Embedding (the original embedding layer).
        :param freeze_until_idx: The index up to which the embedding rows are frozen (exclusive).
            The row at freeze_until_idx itself stays trainable.
        """
        super().__init__()
        self.freeze_until_idx = freeze_until_idx
        self.original_vocab_size = original_embedding.num_embeddings
        self.embedding_dim = original_embedding.embedding_dim

        # Split the original embedding into frozen and trainable parts
        self.embedding_frozen = nn.Embedding(
            freeze_until_idx,
            self.embedding_dim,
            dtype=original_embedding.weight.dtype,
            device=original_embedding.weight.device,
        )
        self.embedding_trainable = nn.Embedding(
            self.original_vocab_size - freeze_until_idx,
            self.embedding_dim,
            dtype=original_embedding.weight.dtype,
            device=original_embedding.weight.device,
        )

        # Copy weights from the original embedding into the frozen and trainable parts
        with torch.no_grad():
            self.embedding_frozen.weight.copy_(original_embedding.weight[:freeze_until_idx])
            self.embedding_trainable.weight.copy_(original_embedding.weight[freeze_until_idx:])

        # Freeze the frozen embedding
        self.embedding_frozen.weight.requires_grad = False

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the split embedding wrapper.

        :param input_ids: Tensor of shape [batch_size, seq_len] with indices
            in [0..original_vocab_size-1].
        """
        # Masks to separate frozen and trainable indices
        # (bsz, seq_len)
        mask_frozen = input_ids < self.freeze_until_idx
        mask_trainable = ~mask_frozen

        # Output tensor for embedding results
        batch_size, seq_len = input_ids.shape
        embeddings = torch.zeros(
            batch_size,
            seq_len,
            self.embedding_dim,
            device=input_ids.device,
            dtype=self.embedding_frozen.weight.dtype,
        )

        # Handle frozen embedding
        if mask_frozen.any():
            frozen_ids = input_ids[mask_frozen]
            frozen_emb = self.embedding_frozen(frozen_ids)
            embeddings[mask_frozen] = frozen_emb

        # Handle trainable embedding
        if mask_trainable.any():
            # Adjust trainable IDs to the local index space of the trainable embedding
            trainable_ids = input_ids[mask_trainable] - self.freeze_until_idx
            trainable_emb = self.embedding_trainable(trainable_ids)
            embeddings[mask_trainable] = trainable_emb

        return embeddings

    def to_unsplit(self) -> nn.Embedding:
        """Merge the frozen and trainable parts back into a single nn.Embedding."""
        unsplit_embedding = nn.Embedding(
            self.original_vocab_size,
            self.embedding_dim,
            dtype=self.embedding_frozen.weight.dtype,
            device=self.embedding_frozen.weight.device,
        )
        with torch.no_grad():
            unsplit_embedding.weight[: self.freeze_until_idx].copy_(self.embedding_frozen.weight)
            unsplit_embedding.weight[self.freeze_until_idx :].copy_(self.embedding_trainable.weight)
        return unsplit_embedding


class PartiallyFrozenLinear(nn.Module):
    """A wrapper around nn.Linear that freezes the first `freeze_until_idx` rows of the weight matrix."""

    def __init__(self, original_linear: nn.Linear, freeze_until_idx: int):
        """
        :param original_linear: The original nn.Linear layer.
        :param freeze_until_idx: The index up to which the rows of the weight matrix
            are frozen (exclusive).
""" super().__init__() assert original_linear.bias is None, "Currently only support linear module without bias" self.freeze_until_idx = freeze_until_idx self.input_dim = original_linear.in_features self.output_dim = original_linear.out_features # Create frozen and trainable linear layers self.linear_frozen = nn.Linear( self.input_dim, freeze_until_idx, bias=False, dtype=original_linear.weight.dtype, device=original_linear.weight.device, ) self.linear_trainable = nn.Linear( self.input_dim, self.output_dim - freeze_until_idx, bias=False, dtype=original_linear.weight.dtype, device=original_linear.weight.device, ) # Copy weights from the original linear layer with torch.no_grad(): self.linear_frozen.weight.copy_(original_linear.weight[:freeze_until_idx]) self.linear_trainable.weight.copy_(original_linear.weight[freeze_until_idx:]) # Freeze the frozen linear layer self.linear_frozen.weight.requires_grad = False def forward(self, input_tensor): # input_tensor: (bsz, seq_len, hidden_state_dim) frozen_output = self.linear_frozen(input_tensor) trainable_output = self.linear_trainable(input_tensor) return torch.cat((frozen_output, trainable_output), dim=-1) def to_unsplit(self) -> nn.Linear: unsplit_linear = nn.Linear( self.input_dim, self.output_dim, bias=False, dtype=self.linear_frozen.weight.dtype, device=self.linear_frozen.weight.device, ) # Copy weights from the frozen and trainable layers into the unsplit linear layer with torch.no_grad(): unsplit_linear.weight[: self.freeze_until_idx].copy_(self.linear_frozen.weight) unsplit_linear.weight[self.freeze_until_idx :].copy_(self.linear_trainable.weight) return unsplit_linear