# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy

import torch
from torch.nn import CrossEntropyLoss

from peft.utils.integrations import gather_params_ctx


class CPTEmbedding(torch.nn.Module):
"""
CPTEmbedding is a custom embedding layer designed for Context-aware Prompt Tuning (CPT) in PEFT. It initializes
embeddings, applies prompt-specific projections, and computes loss using label masks.
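
    Example (a minimal sketch with illustrative values; the layer only reads the config attributes referenced in
    this module, so a plain `types.SimpleNamespace` stands in for a full CPT config here):

        ```py
        >>> import types
        >>> import torch
        >>> cfg = types.SimpleNamespace(
        ...     num_virtual_tokens=4,
        ...     token_dim=16,
        ...     cpt_token_ids=[0, 1, 2, 3],
        ...     cpt_tokens_type_mask=[1, 2, 3, 4],
        ...     opt_projection_epsilon=0.1,
        ...     opt_projection_format_epsilon=0.1,
        ...     inference_mode=False,
        ... )
        >>> word_embeddings = torch.nn.Embedding(100, 16)
        >>> cpt_embedding = CPTEmbedding(cfg, word_embeddings)
        >>> cpt_embedding(torch.arange(4)).shape
        torch.Size([4, 16])
        ```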
"""

    def __init__(self, config, word_embeddings):
        """
        Initializes the CPTEmbedding module.

        Args:
config (Namespace):
Configuration object containing model hyperparameters and CPT-specific settings.
word_embeddings (torch.nn.Embedding):
The base word embedding layer used to initialize CPT embeddings.
"""
super().__init__()
self.config = copy.deepcopy(config)
num_virtual_tokens = config.num_virtual_tokens

        # Initialize embeddings with virtual token dimensions
        self.embedding = torch.nn.Embedding(num_virtual_tokens, config.token_dim)

        # Initialize embeddings using text-based prompt tuning, if configured
if not config.inference_mode:
assert config.num_virtual_tokens == len(config.cpt_token_ids)
init_token_ids = torch.LongTensor(config.cpt_token_ids).to(word_embeddings.weight.device)
with gather_params_ctx(word_embeddings.parameters()):
word_embedding_weights = word_embeddings(init_token_ids).detach().clone()
word_embedding_weights = word_embedding_weights.to(torch.float32)
self.embedding.weight = torch.nn.Parameter(word_embedding_weights)

        # Initialize delta embedding with zero weights
        self.delta_embedding = torch.nn.Embedding(num_virtual_tokens, config.token_dim)
        self.delta_embedding.weight.data = torch.zeros_like(self.delta_embedding.weight).to(torch.float32)

        # Apply hook for backward gradient updates
self.set_updated_tokens()

    def forward(self, indices):
        """
        Computes the prompt embeddings and applies delta adjustments.

        Args:
indices (torch.Tensor):
Indices of the tokens to be embedded.

        Returns:
torch.Tensor:
Sum of prompt embeddings and delta embeddings.
"""
with torch.no_grad():
prompt_embeddings = self.embedding(indices)
self.delta_embedding.weight.data = self.get_projection() # Apply epsilon-based projection
delta_prompt_embeddings = self.delta_embedding(indices)
return prompt_embeddings + delta_prompt_embeddings

    def set_updated_tokens(self):
"""
Sets up a backward hook to selectively update token gradients based on the CPT token type mask.
"""
tensor_ICL_mask = torch.Tensor(self.config.cpt_tokens_type_mask).long()
mask_input_template = torch.remainder(tensor_ICL_mask, 4) == 1
mask_input = torch.remainder(tensor_ICL_mask, 4) == 2
mask_output_template = torch.remainder(tensor_ICL_mask, 4) == 3
mask = mask_input_template | mask_input | mask_output_template
mask = mask.view(-1, 1)

        def backward_hook(grad):
            grad = grad * mask.to(grad.device)  # Apply mask to gradients
            return grad

        self.delta_embedding.weight.register_hook(backward_hook)

    def get_epsilon(self):
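        """
        Builds a per-token epsilon tensor bounding the norm of each delta embedding row. Template tokens (type mask
        % 4 in {1, 3}) use the format epsilon, input tokens (% 4 == 2) use the input epsilon, and every other token
        gets a tiny positive value so its delta is effectively frozen; both epsilons are scaled by
        sqrt(token_dim / 2048).
        """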
cpt_tokens_type_mask = self.config.cpt_tokens_type_mask
MIN_VALUE = 1e-10

        # Calculate normalized epsilon values for input, output, and format tokens
normalized_format_eps = self.config.opt_projection_format_epsilon * torch.sqrt(
torch.Tensor([self.config.token_dim / 2048])
)
normalized_input_eps = self.config.opt_projection_epsilon * torch.sqrt(
torch.Tensor([self.config.token_dim / 2048])
)
epsilon = torch.ones_like(torch.Tensor(cpt_tokens_type_mask)).to(torch.float32) * MIN_VALUE
cpt_tokens_type_mask = torch.Tensor(cpt_tokens_type_mask).long()
epsilon[(cpt_tokens_type_mask > 0) & (torch.remainder(cpt_tokens_type_mask, 4) == 1)] = normalized_format_eps
epsilon[(cpt_tokens_type_mask > 0) & (torch.remainder(cpt_tokens_type_mask, 4) == 3)] = normalized_format_eps
epsilon[(cpt_tokens_type_mask > 0) & (torch.remainder(cpt_tokens_type_mask, 4) == 2)] = normalized_input_eps
return epsilon

    def get_projection(self):
"""
Applies epsilon-based projection to the delta embeddings to control their norm.
"""
# Apply projection to control delta embedding norm
with torch.no_grad():
new_embeddings_weights = self.delta_embedding.weight.clone().to(self.delta_embedding.weight.device)
token_norm = torch.norm(new_embeddings_weights, p=2, dim=1)
projection_mask = token_norm > 0
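            # Row-wise L2-ball projection: delta <- delta * epsilon / max(||delta||, epsilon), so every nonzero
            # delta row ends up with norm at most its epsilon bound.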
if torch.any(projection_mask):
epsilon = self.get_epsilon().to(self.delta_embedding.weight.device)
new_embeddings_weights[projection_mask] *= (
epsilon[projection_mask] / (token_norm[projection_mask].clamp(min=epsilon[projection_mask]))
).view(-1, 1)
return new_embeddings_weights

    @staticmethod
    def calculate_loss(base_model_output, labels, cpt_type_mask, config):
        """
        Computes the loss for CPT models with optional exponential decay.

        Args:
base_model_output (ModelOutput):
Output from the base model containing logits.
labels (torch.Tensor):
Ground-truth labels for the input tokens.
cpt_type_mask (torch.Tensor):
Token type mask used for filtering valid loss terms.
config (Namespace):
Configuration object containing loss-related hyperparameters.

        Returns:
ModelOutput:
The base model output with computed loss.
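
        Example (a rough sketch with illustrative shapes and mask values, assuming the logits arrive in a
        `transformers` `CausalLMOutput`):

            ```py
            >>> import types
            >>> import torch
            >>> from transformers.modeling_outputs import CausalLMOutput
            >>> output = CausalLMOutput(logits=torch.randn(1, 6, 10))
            >>> labels = torch.tensor([[-100, -100, 4, 5, -100, 7]])
            >>> cpt_type_mask = torch.tensor([[1, 2, 4, 4, 3, 8]])
            >>> cfg = types.SimpleNamespace(opt_weighted_loss_type="decay", opt_loss_decay_factor=0.95)
            >>> output = CPTEmbedding.calculate_loss(output, labels, cpt_type_mask, cfg)
            >>> output.loss.shape
            torch.Size([])
            ```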
"""
device = base_model_output.logits.device
lm_logits = base_model_output.logits
labels = labels.to(device)

        # Shift logits and labels for token prediction
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_cpt_type_mask = cpt_type_mask[..., 1:].contiguous()
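        # Positions labeled -100 are ignored by the loss; keep a boolean mask of the valid label positions for the
        # weighted mean below.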
shift_labels_bool = (shift_labels.clone().detach() != -100).bool()
batch_size, seq_length, vocab_size = shift_logits.shape

        # Compute cross-entropy loss
loss_fct = CrossEntropyLoss(reduction="none", ignore_index=-100)
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
loss = loss.view(batch_size, seq_length)

        # Apply exponential decay weights to the loss
shift_labels_weights = shift_labels_bool.clone().detach().float()
for i in range(batch_size):
idx_labels = (shift_cpt_type_mask[i] > 0) & (shift_cpt_type_mask[i] % 4 == 0)
labels_ids = shift_cpt_type_mask[i][idx_labels].unique()
exponential_decay = torch.ones_like(shift_cpt_type_mask[i]).to(device=device).float()
decay_value = 1
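            # Walk over the label groups from last to first so the final example keeps full weight and earlier
            # in-context examples are scaled by successive powers of config.opt_loss_decay_factor (applied only
            # when opt_weighted_loss_type == "decay").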
for label_mask_idx in torch.flip(labels_ids, [0]):
exponential_decay[shift_cpt_type_mask[i] == label_mask_idx] = decay_value
decay_value *= config.opt_loss_decay_factor
if config.opt_weighted_loss_type == "decay":
shift_labels_weights[i] *= exponential_decay

        # Compute the weighted mean loss
loss = (loss[shift_labels_bool] * shift_labels_weights[shift_labels_bool]).mean()
base_model_output.loss = loss
return base_model_output