|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import warnings |
|
|
|
import torch |
|
from transformers.pytorch_utils import Conv1D |
|
|
|
from peft.import_utils import is_bnb_4bit_available, is_bnb_available |
|
from peft.tuners.lora import LoraConfig, LoraModel |
|
from peft.tuners.tuners_utils import BaseTunerLayer |
|
from peft.utils import ( |
|
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, |
|
_freeze_adapter, |
|
_get_submodules, |
|
get_auto_gptq_quant_linear, |
|
get_quantization_config, |
|
) |
|
from peft.utils.integrations import gather_params_ctx |
|
|
|
from .gptq import SVDQuantLinear |
|
from .layer import AdaLoraLayer, RankAllocator, SVDLinear |
|
|
|
|
|
class AdaLoraModel(LoraModel): |
|
""" |
|
    Creates an AdaLoRA (Adaptive LoRA) model from a pretrained transformers model. Paper:
|
https://openreview.net/forum?id=lq62uWRJjiY |
|
|
|
Args: |
|
model ([`transformers.PreTrainedModel`]): The model to be adapted. |
|
config ([`AdaLoraConfig`]): The configuration of the AdaLora model. |
|
adapter_name (`str`): The name of the adapter, defaults to `"default"`. |
|
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): |
|
Create empty adapter weights on meta device. Useful to speed up the loading process. |
|
|
|
Returns: |
|
`torch.nn.Module`: The AdaLora model. |
|
|
|
Example:: |
|
|
|
        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import LoraConfig, AdaLoraModel, AdaLoraConfig
        >>> config = AdaLoraConfig(
        ...     peft_type="ADALORA",
        ...     task_type="SEQ_2_SEQ_LM",
        ...     init_r=12,
        ...     lora_alpha=32,
        ...     target_modules=["q", "v"],
        ...     lora_dropout=0.01,
        ... )
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> model = AdaLoraModel(model, config, "default")
|
|
|
**Attributes**: |
|
- **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted. |
|
    - **peft_config** ([`AdaLoraConfig`]) -- The configuration of the AdaLora model.
|
""" |
|
|
|
|
|
|
|
    def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
        super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
|
|
|
        # AdaLoRA supports only one trainable adapter at a time; count adapters that are not in inference mode.
        trainable_mode_counter = 0
        for config in self.peft_config.values():
            if not config.inference_mode:
                trainable_mode_counter += 1

        if trainable_mode_counter > 1:
|
raise ValueError( |
|
"AdaLoraModel supports only 1 trainable adapter. " |
|
"When using multiple adapters, set inference_mode to True for all adapters except the one you want to train." |
|
) |
|
|
|
if self.peft_config[adapter_name].inference_mode: |
|
_freeze_adapter(self.model, adapter_name) |
|
else: |
|
self.trainable_adapter_name = adapter_name |
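            # The RankAllocator implements AdaLoRA's adaptive budget: it scores each SVD triplet's importance and prunes ranks during training.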
|
self.rankallocator = RankAllocator(self.model, self.peft_config[adapter_name], self.trainable_adapter_name) |
|
|
|
def _check_new_adapter_config(self, config: LoraConfig) -> None: |
|
""" |
|
A helper method to check the config when a new adapter is being added. |
|
|
|
Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. |
|
|
|
""" |
|
super()._check_new_adapter_config(config) |
|
|
|
        trainable_mode_counter = 0
        for config_ in self.peft_config.values():
            if not config_.inference_mode:
                trainable_mode_counter += 1

        if trainable_mode_counter > 1:
|
raise ValueError( |
|
f"{self.__class__.__name__} supports only 1 trainable adapter. " |
|
"When using multiple adapters, set inference_mode to True for all adapters except the one " |
|
"you want to train." |
|
) |
|
|
|
def _create_and_replace( |
|
self, |
|
lora_config, |
|
adapter_name, |
|
target, |
|
target_name, |
|
parent, |
|
current_key, |
|
): |
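        # Arguments used to create the AdaLora SVD layer that replaces (or updates) the target module.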
|
kwargs = { |
|
"r": lora_config.init_r, |
|
"lora_alpha": lora_config.lora_alpha, |
|
"lora_dropout": lora_config.lora_dropout, |
|
"fan_in_fan_out": lora_config.fan_in_fan_out, |
|
"init_lora_weights": lora_config.init_lora_weights, |
|
"loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False), |
|
"loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False), |
|
} |
|
        if (kwargs["loaded_in_8bit"] or kwargs["loaded_in_4bit"]) and not is_bnb_available():
            raise ImportError(
                "To use AdaLora with 8-bit or 4-bit quantization, please install the `bitsandbytes` package. "
                "You can install it with `pip install bitsandbytes`."
            )
|
|
|
quantization_config = get_quantization_config(self.model, method="gptq") |
|
if quantization_config is not None: |
|
kwargs["gptq_quantization_config"] = quantization_config |
|
|
|
|
|
if not isinstance(target, AdaLoraLayer): |
|
new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs) |
|
            if adapter_name not in self.active_adapters:
                # Adding an additional adapter: it is not automatically trainable.
                new_module.requires_grad_(False)
|
self._replace_module(parent, target_name, new_module, target) |
|
else: |
|
target.update_layer( |
|
adapter_name, |
|
lora_config.init_r, |
|
lora_config.lora_alpha, |
|
lora_config.lora_dropout, |
|
lora_config.init_lora_weights, |
|
) |
|
|
|
@staticmethod |
|
def _create_new_module(lora_config, adapter_name, target, **kwargs): |
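        # Pick the SVD layer implementation that matches the base layer: bitsandbytes 8-bit/4-bit, AutoGPTQ, plain nn.Linear, or transformers Conv1D.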
|
|
|
if is_bnb_available(): |
|
import bitsandbytes as bnb |
|
|
|
from .bnb import SVDLinear8bitLt |
|
if is_bnb_4bit_available(): |
|
from .bnb import SVDLinear4bit |
|
|
|
gptq_quantization_config = kwargs.get("gptq_quantization_config", None) |
|
AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config) |
|
|
|
loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) |
|
loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) |
|
|
|
if isinstance(target, BaseTunerLayer): |
|
target_base_layer = target.get_base_layer() |
|
else: |
|
target_base_layer = target |
|
|
|
if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt): |
|
kwargs.update( |
|
{ |
|
"has_fp16_weights": target_base_layer.state.has_fp16_weights, |
|
"threshold": target_base_layer.state.threshold, |
|
"index": target_base_layer.index, |
|
} |
|
) |
|
new_module = SVDLinear8bitLt(target, adapter_name, **kwargs) |
|
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit): |
|
fourbit_kwargs = kwargs.copy() |
|
fourbit_kwargs.update( |
|
{ |
|
"compute_dtype": target_base_layer.compute_dtype, |
|
"compress_statistics": target_base_layer.weight.compress_statistics, |
|
"quant_type": target_base_layer.weight.quant_type, |
|
} |
|
) |
|
new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs) |
|
elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear): |
|
new_module = SVDQuantLinear(target, adapter_name, **kwargs) |
|
else: |
|
if isinstance(target_base_layer, torch.nn.Linear): |
|
if kwargs["fan_in_fan_out"]: |
|
warnings.warn( |
|
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " |
|
"Setting fan_in_fan_out to False." |
|
) |
|
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False |
|
elif isinstance(target_base_layer, Conv1D): |
|
if not kwargs["fan_in_fan_out"]: |
|
warnings.warn( |
|
"fan_in_fan_out is set to False but the target module is `Conv1D`. " |
|
"Setting fan_in_fan_out to True." |
|
) |
|
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True |
|
else: |
|
raise ValueError( |
|
f"Target module {target} is not supported. " |
|
f"Currently, only `torch.nn.Linear` and `Conv1D` are supported." |
|
) |
|
new_module = SVDLinear(target, adapter_name, **kwargs) |
|
|
|
return new_module |
|
|
|
@staticmethod |
|
def _prepare_adapter_config(peft_config, model_config): |
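        # If no target modules are given, fall back to the defaults registered for this model architecture.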
|
if peft_config.target_modules is None: |
|
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING: |
|
raise ValueError("Please specify `target_modules` in `peft_config`") |
|
peft_config.target_modules = TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING[ |
|
model_config["model_type"] |
|
] |
|
return peft_config |
|
|
|
def __getattr__(self, name: str): |
|
"""Forward missing attributes to the wrapped module.""" |
|
try: |
|
return super().__getattr__(name) |
|
except AttributeError: |
|
if name == "model": |
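                # Re-raise for "model" itself, otherwise getattr(self.model, name) below would recurse forever.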
|
raise |
|
return getattr(self.model, name) |
|
|
|
def forward(self, *args, **kwargs): |
|
outputs = self.model.forward(*args, **kwargs) |
|
|
|
if (getattr(outputs, "loss", None) is not None) and isinstance(outputs.loss, torch.Tensor): |
|
|
|
orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight |
|
|
|
if orth_reg_weight <= 0: |
|
                raise ValueError("orth_reg_weight should be greater than 0.")
|
|
|
regu_loss = 0 |
|
num_param = 0 |
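            # Orthogonal regularization from the AdaLoRA paper: penalize ||A A^T - I||_F for lora_A and ||B^T B - I||_F for lora_B so the rank-r factors stay near-orthonormal.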
|
for n, p in self.model.named_parameters(): |
|
if ("lora_A" in n or "lora_B" in n) and self.trainable_adapter_name in n: |
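                    # A parameter of shape [0] is sharded (e.g. DeepSpeed ZeRO-3), so gather it before computing the product.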
|
if p.shape == torch.Size([0]): |
|
with gather_params_ctx(p, fwd_module=self): |
|
para_cov = p @ p.T if "lora_A" in n else p.T @ p |
|
else: |
|
para_cov = p @ p.T if "lora_A" in n else p.T @ p |
|
I = torch.eye(*para_cov.size(), out=torch.empty_like(para_cov)) |
|
I.requires_grad = False |
|
num_param += 1 |
|
regu_loss += torch.norm(para_cov - I, p="fro") |
|
if num_param > 0: |
|
regu_loss = regu_loss / num_param |
|
else: |
|
regu_loss = 0 |
|
outputs.loss += orth_reg_weight * regu_loss |
|
return outputs |
|
|
|
def resize_modules_by_rank_pattern(self, rank_pattern, adapter_name): |
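        """Resize each AdaLora layer to the rank kept in `rank_pattern` and copy over the retained lora_E/lora_A rows and lora_B columns."""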
|
lora_config = self.peft_config[adapter_name] |
|
for name, rank_idx in rank_pattern.items(): |
|
if isinstance(rank_idx, list): |
|
rank = sum(rank_idx) |
|
elif isinstance(rank_idx, torch.Tensor): |
|
rank_idx = rank_idx.view(-1) |
|
rank = rank_idx.sum().item() |
|
else: |
|
raise ValueError("Unexpected type of rank_idx") |
|
key = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1]) |
|
_, target, _ = _get_submodules(self.model, key) |
|
lora_E_weights = target.lora_E[adapter_name][rank_idx] |
|
lora_A_weights = target.lora_A[adapter_name][rank_idx] |
|
lora_B_weights = target.lora_B[adapter_name][:, rank_idx] |
|
ranknum = target.ranknum[adapter_name] |
|
target.update_layer( |
|
adapter_name, |
|
rank, |
|
lora_config.lora_alpha, |
|
lora_config.lora_dropout, |
|
lora_config.init_lora_weights, |
|
) |
|
with torch.no_grad(): |
|
if rank > 0: |
|
target.lora_E[adapter_name].copy_(lora_E_weights) |
|
target.lora_A[adapter_name].copy_(lora_A_weights) |
|
target.lora_B[adapter_name].copy_(lora_B_weights) |
|
|
|
target.ranknum[adapter_name].copy_(ranknum) |
|
|
|
def resize_state_dict_by_rank_pattern(self, rank_pattern, state_dict, adapter_name): |
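        """Prune a saved state dict according to `rank_pattern`, keeping only the retained rows of lora_E/lora_A and columns of lora_B."""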
|
for name, rank_idx in rank_pattern.items(): |
|
rank = sum(rank_idx) |
|
prefix = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1]) |
|
for layer in ["lora_E", "lora_A", "lora_B"]: |
|
key = f"base_model.model.{prefix}.{layer}.{adapter_name}" |
|
if layer != "lora_B": |
|
state_dict[key] = ( |
|
state_dict[key][rank_idx] if rank != state_dict[key].shape[0] else state_dict[key] |
|
) |
|
else: |
|
state_dict[key] = ( |
|
state_dict[key][:, rank_idx] if rank != state_dict[key].shape[1] else state_dict[key] |
|
) |
|
return state_dict |
|
|
|
def update_and_allocate(self, global_step): |
|
""" |
|
        This method updates the AdaLoRA budget and mask.
|
|
|
This should be called in every training step after `loss.backward()` and before `zero_grad()`. |
|
|
|
        `tinit`, `tfinal` and `deltaT` are handled within the method.
|
|
|
Args: |
|
            global_step (`int`): The current training step; it is used to calculate the AdaLoRA budget.
|
|
|
Example: |
|
|
|
```python |
|
>>> loss = model(**input).loss |
|
>>> loss.backward() |
|
>>> optimizer.step() |
|
>>> model.base_model.update_and_allocate(i_step) |
|
>>> optimizer.zero_grad() |
|
``` |
|
""" |
|
lora_config = self.peft_config[self.trainable_adapter_name] |
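        # The allocation schedule runs until `total_step - tfinal`; while it is active the RankAllocator may re-assign ranks at every step.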
|
|
|
if global_step < lora_config.total_step - lora_config.tfinal: |
|
_, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step) |
|
if rank_pattern: |
|
lora_config.rank_pattern = rank_pattern |
|
|
|
        # At the last step of the allocation schedule, freeze the final rank pattern and reset importance scores.
        elif global_step == lora_config.total_step - lora_config.tfinal:
            _, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step, force_mask=True)
            lora_config.rank_pattern = rank_pattern
            self.rankallocator.reset_ipt()

        # After the schedule ends, keep applying the fixed rank pattern.
        elif global_step > lora_config.total_step - lora_config.tfinal:
            self.rankallocator.mask_using_rank_pattern(self.model, lora_config.rank_pattern)

        else:
            return None
|
|
|
def add_weighted_adapter(self, *args, **kwargs): |
|
"""This method is not supported for AdaLoRA, use LoRA instead.""" |
|
raise TypeError(f"{self.__class__.__name__} does not support add_weighted_adapter method.") |
|
|