import torch
import torch.nn as nn
class PartiallyFrozenEmbedding(nn.Module):
    """Wrapper that splits an existing `nn.Embedding` into two parts:
    - A frozen embedding for indices [0..freeze_until_idx-1].
    - A trainable embedding for indices [freeze_until_idx..vocab_size-1].
    This should work seamlessly with both ZeRO-2 and ZeRO-3.
    """
    def __init__(self, original_embedding: nn.Embedding, freeze_until_idx: int):
        """
        :param original_embedding: An instance of nn.Embedding (the original embedding layer).
        :param freeze_until_idx: The index up to which embedding rows are frozen (exclusive); the row at freeze_until_idx itself remains trainable.
        """
        super().__init__()
        self.freeze_until_idx = freeze_until_idx
        self.original_vocab_size = original_embedding.num_embeddings
        self.embedding_dim = original_embedding.embedding_dim

        # Split the original embedding into frozen and trainable parts
        self.embedding_frozen = nn.Embedding(
            freeze_until_idx,
            self.embedding_dim,
            dtype=original_embedding.weight.dtype,
            device=original_embedding.weight.device,
        )
        self.embedding_trainable = nn.Embedding(
            self.original_vocab_size - freeze_until_idx,
            self.embedding_dim,
            dtype=original_embedding.weight.dtype,
            device=original_embedding.weight.device,
        )

        # Copy weights from the original embedding into the frozen and trainable parts
        with torch.no_grad():
            self.embedding_frozen.weight.copy_(original_embedding.weight[:freeze_until_idx])
            self.embedding_trainable.weight.copy_(original_embedding.weight[freeze_until_idx:])

        # Freeze the frozen embedding
        self.embedding_frozen.weight.requires_grad = False
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the split embedding wrapper.
        :param input_ids: Tensor of shape [batch_size, seq_len] with indices in [0..original_vocab_size-1].
        """
        # Masks to separate frozen and trainable indices
        # (bsz, seq_len)
        mask_frozen = input_ids < self.freeze_until_idx
        mask_trainable = ~mask_frozen

        # Output tensor for embedding results
        batch_size, seq_len = input_ids.shape
        embeddings = torch.zeros(
            batch_size,
            seq_len,
            self.embedding_dim,
            device=input_ids.device,
            dtype=self.embedding_frozen.weight.dtype,
        )

        # Handle frozen embedding
        if mask_frozen.any():
            frozen_ids = input_ids[mask_frozen]
            frozen_emb = self.embedding_frozen(frozen_ids)
            embeddings[mask_frozen] = frozen_emb

        # Handle trainable embedding
        if mask_trainable.any():
            # Shift trainable IDs into the local index space of the trainable embedding
            trainable_ids = input_ids[mask_trainable] - self.freeze_until_idx
            trainable_emb = self.embedding_trainable(trainable_ids)
            embeddings[mask_trainable] = trainable_emb

        return embeddings
    def to_unsplit(self) -> nn.Embedding:
        unsplit_embedding = nn.Embedding(
            self.original_vocab_size,
            self.embedding_dim,
            dtype=self.embedding_frozen.weight.dtype,
            device=self.embedding_frozen.weight.device,
        )
        with torch.no_grad():
            unsplit_embedding.weight[: self.freeze_until_idx].copy_(self.embedding_frozen.weight)
            unsplit_embedding.weight[self.freeze_until_idx :].copy_(self.embedding_trainable.weight)
        return unsplit_embedding
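

# Usage sketch (illustrative only; vocab size, embedding dim, and freeze index
# are arbitrary example values): the wrapper should reproduce the original
# embedding's outputs, and to_unsplit() should round-trip the weights.
def _demo_partially_frozen_embedding():
    original = nn.Embedding(10, 4)
    wrapped = PartiallyFrozenEmbedding(original, freeze_until_idx=6)
    input_ids = torch.tensor([[0, 5, 6, 9]])
    # Outputs match the original embedding for both frozen and trainable ids
    assert torch.allclose(wrapped(input_ids), original(input_ids))
    # Re-merging the two parts recovers the original weight matrix
    assert torch.allclose(wrapped.to_unsplit().weight, original.weight)
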
class PartiallyFrozenLinear(nn.Module):
    """A wrapper around nn.Linear that partially freezes rows of the weight matrix."""

    def __init__(self, original_linear: nn.Linear, freeze_until_idx: int):
        """
        :param original_linear: The original nn.Linear layer.
        :param freeze_until_idx: The index up to which rows of the weight matrix are frozen (exclusive).
        """
        super().__init__()
        assert original_linear.bias is None, "Currently only supports linear modules without bias"
        self.freeze_until_idx = freeze_until_idx
        self.input_dim = original_linear.in_features
        self.output_dim = original_linear.out_features
        # Create frozen and trainable linear layers
        self.linear_frozen = nn.Linear(
            self.input_dim,
            freeze_until_idx,
            bias=False,
            dtype=original_linear.weight.dtype,
            device=original_linear.weight.device,
        )
        self.linear_trainable = nn.Linear(
            self.input_dim,
            self.output_dim - freeze_until_idx,
            bias=False,
            dtype=original_linear.weight.dtype,
            device=original_linear.weight.device,
        )

        # Copy weights from the original linear layer
        with torch.no_grad():
            self.linear_frozen.weight.copy_(original_linear.weight[:freeze_until_idx])
            self.linear_trainable.weight.copy_(original_linear.weight[freeze_until_idx:])

        # Freeze the frozen linear layer
        self.linear_frozen.weight.requires_grad = False
    def forward(self, input_tensor):
        # input_tensor: (bsz, seq_len, hidden_state_dim)
        frozen_output = self.linear_frozen(input_tensor)
        trainable_output = self.linear_trainable(input_tensor)
        return torch.cat((frozen_output, trainable_output), dim=-1)
    def to_unsplit(self) -> nn.Linear:
        unsplit_linear = nn.Linear(
            self.input_dim,
            self.output_dim,
            bias=False,
            dtype=self.linear_frozen.weight.dtype,
            device=self.linear_frozen.weight.device,
        )
        # Copy weights from the frozen and trainable layers into the unsplit linear layer
        with torch.no_grad():
            unsplit_linear.weight[: self.freeze_until_idx].copy_(self.linear_frozen.weight)
            unsplit_linear.weight[self.freeze_until_idx :].copy_(self.linear_trainable.weight)
        return unsplit_linear
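

# Usage sketch (illustrative only; in/out features and freeze index are arbitrary
# example values): the split projection matches the original bias-free linear,
# while gradients flow only to the trainable rows at or above freeze_until_idx.
def _demo_partially_frozen_linear():
    original = nn.Linear(8, 10, bias=False)
    wrapped = PartiallyFrozenLinear(original, freeze_until_idx=6)
    hidden = torch.randn(2, 3, 8)  # (bsz, seq_len, hidden_state_dim)
    # Concatenating the frozen and trainable outputs matches the original projection
    assert torch.allclose(wrapped(hidden), original(hidden))
    # Only the trainable slice accumulates gradients; the frozen slice stays untouched
    wrapped(hidden).sum().backward()
    assert wrapped.linear_frozen.weight.grad is None
    assert wrapped.linear_trainable.weight.grad is not None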