from dataclasses import dataclass, field
from typing import Literal, Optional

from peft.config import PromptLearningConfig
from peft.utils import PeftType


@dataclass
class CPTConfig(PromptLearningConfig):
    """
    CPT configuration class extending PromptLearningConfig for Context-aware Prompt Tuning (CPT).

    This class introduces additional parameters required for CPT, such as:
    - Token type masks
    - Prompt tuning initialization
    - Loss weighting
    - Projection settings

    For more details, see the paper: https://arxiv.org/abs/2410.17222
    """

    # Token-related configurations
    cpt_token_ids: Optional[list[int]] = field(
        default=None, metadata={"help": "Tensor of token IDs used for CPT prompts."}
    )
    cpt_mask: Optional[list[int]] = field(default=None, metadata={"help": "Tensor mask applied to CPT tokens."})
    cpt_tokens_type_mask: Optional[list[int]] = field(
        default=None, metadata={"help": "Mask indicating the type of each CPT token."}
    )

    # Loss-related configurations
    opt_weighted_loss_type: Optional[Literal["none", "decay"]] = field(
        default="none", metadata={"help": "Type of weighted loss: 'none' or 'decay'."}
    )
    opt_loss_decay_factor: Optional[float] = field(
        default=1.0, metadata={"help": "Factor for exponential decay in loss weighting."}
    )

    # Projection-related configurations
    opt_projection_epsilon: Optional[float] = field(
        default=0.1, metadata={"help": "Epsilon value for input projection."}
    )
    opt_projection_format_epsilon: Optional[float] = field(
        default=0.1, metadata={"help": "Epsilon value for format projection."}
    )

    # Tokenizer configuration
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The tokenizer to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`"
        },
    )

    # CPT is a prompt-learning method.
    is_prompt_learning = True

    def __post_init__(self):
        """
        Post-initialization hook to set additional attributes after the config is initialized.
        """
        self.is_prompt_learning = True  # CPT is a prompt-learning method.
        # The following are populated from the base model config when the PEFT model is created.
        self.num_layers = None
        self.token_dim = None
        self.num_attention_heads = None
        self.num_transformer_submodules = 1  # CPT uses a single transformer submodule (decoder-only models).
        self.peft_type = PeftType.CPT  # Marks this config as a CPT method.
        self.task_type = "CAUSAL_LM"  # CPT supports causal language modeling only.

        # Default to a single placeholder token if no CPT token IDs were provided.
        if self.cpt_token_ids is None:
            self.cpt_token_ids = [0]

        self.num_virtual_tokens = len(self.cpt_token_ids)

        # Default masks: attend to every CPT token and assign each the same token type.
        if self.cpt_mask is None:
            self.cpt_mask = [1 for _ in self.cpt_token_ids]

        if self.cpt_tokens_type_mask is None:
            self.cpt_tokens_type_mask = [1 for _ in self.cpt_token_ids]

        if not (
            len(self.cpt_token_ids) == len(self.cpt_mask) == len(self.cpt_tokens_type_mask) == self.num_virtual_tokens
        ):
            raise ValueError("cpt_token_ids, cpt_mask and cpt_tokens_type_mask must have the same length.")
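

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the library source: it shows how
# CPTConfig might be instantiated directly and which attributes __post_init__
# derives from `cpt_token_ids`. The token IDs below are hypothetical
# placeholders; in practice they would come from tokenizing the in-context
# training examples.
if __name__ == "__main__":
    example_config = CPTConfig(
        cpt_token_ids=[101, 2023, 2003, 102],  # hypothetical tokenizer output
        opt_weighted_loss_type="decay",
        opt_loss_decay_factor=0.95,
    )
    # One virtual token per context token, so num_virtual_tokens == 4 here.
    print(example_config.num_virtual_tokens)
    # Masks default to all ones, matching the length of cpt_token_ids.
    print(example_config.cpt_mask, example_config.cpt_tokens_type_mask)
    # PEFT type and task type are fixed by __post_init__.
    print(example_config.peft_type, example_config.task_type)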