# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Literal, Optional

from peft.config import PromptLearningConfig
from peft.utils import PeftType


@dataclass
class CPTConfig(PromptLearningConfig):
"""
CPT Configuration class extending PeftConfig for Context-aware Prompt Tuning (CPT).
This class introduces additional parameters required for CPT, such as:
- Token type masks
- Prompt tuning initialization
- Loss weighting
- Projection settings
For more details, see the paper: https://arxiv.org/abs/2410.17222
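
    Example (illustrative sketch only; the token ids and type-mask values below are hypothetical and
    would normally be derived from a tokenizer and the chosen CPT context template):

        >>> from peft import CPTConfig, get_peft_model
        >>> config = CPTConfig(
        ...     cpt_token_ids=[101, 2023, 2003, 102],  # hypothetical token ids for the context prompt
        ...     cpt_mask=[1, 1, 1, 1],
        ...     cpt_tokens_type_mask=[1, 2, 2, 3],  # hypothetical per-token type labels
        ...     opt_weighted_loss_type="decay",
        ...     opt_loss_decay_factor=0.95,
        ... )
        >>> model = get_peft_model(base_model, config)  # base_model: a causal LM loaded beforehand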
"""
# Token-related configurations
cpt_token_ids: Optional[list[int]] = field(
default=None, metadata={"help": "Tensor of token IDs used for CPT prompts."}
)
cpt_mask: Optional[list[int]] = field(default=None, metadata={"help": "Tensor mask applied to CPT tokens."})
cpt_tokens_type_mask: Optional[list[int]] = field(
default=None, metadata={"help": "Mask indicating the type of each CPT token."}
)
# Loss-related configurations
opt_weighted_loss_type: Optional[Literal["none", "decay"]] = field(
default="none", metadata={"help": "Type of weighted loss: 'none' or 'decay'."}
)
opt_loss_decay_factor: Optional[float] = field(
default=1.0, metadata={"help": "Factor for exponential decay in loss weighting."}
)
# Projection-related configurations
opt_projection_epsilon: Optional[float] = field(
default=0.1, metadata={"help": "Epsilon value for input projection."}
)
opt_projection_format_epsilon: Optional[float] = field(
default=0.1, metadata={"help": "Epsilon value for format projection."}
)
# Tokenizer configuration
tokenizer_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "The tokenizer to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`"
},
)

    # Need to define CPT-specific static attributes
    is_prompt_learning = True  # Indicates that CPT is a prompt-learning method.

    def __post_init__(self):
        """
        Post-initialization hook to set additional attributes after the config is initialized.
        """
        # CPT-specific static attributes
        self.is_prompt_learning = True  # Indicates that CPT is a prompt-learning method.
        self.num_layers = None  # Number of layers (optional, not always required).
        self.token_dim = None  # Dimension of token embeddings.
        self.num_attention_heads = None  # Number of attention heads (if applicable).
        self.num_transformer_submodules = 1  # Number of transformer submodules used.
        self.peft_type = PeftType.CPT  # Specifies that the PEFT type is CPT.
        self.task_type = "CAUSAL_LM"  # Ensures task type is causal language modeling.

        if self.cpt_token_ids is None:
            self.cpt_token_ids = [0]
        self.num_virtual_tokens = len(self.cpt_token_ids)

        if self.cpt_mask is None:
            self.cpt_mask = [1 for _ in self.cpt_token_ids]

        if self.cpt_tokens_type_mask is None:
            self.cpt_tokens_type_mask = [1 for _ in self.cpt_token_ids]
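
        # Illustrative note on the defaulting above: e.g. cpt_token_ids=[5, 7, 9] with no masks
        # supplied yields num_virtual_tokens=3, cpt_mask=[1, 1, 1] and cpt_tokens_type_mask=[1, 1, 1].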

        if not (
            len(self.cpt_token_ids) == len(self.cpt_mask) == len(self.cpt_tokens_type_mask) == self.num_virtual_tokens
        ):
            raise ValueError("cpt_token_ids, cpt_mask and cpt_tokens_type_mask must have the same length.")