# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings
from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge


class BoneLayer(BaseTunerLayer):
    # All names of layers that may contain (trainable) adapter weights
    adapter_layer_names = ("bone_block",)
    # All names of other parameters that may contain adapter-related parameters
    other_param_names = ("bone_r",)

    def __init__(self, base_layer: nn.Module, **kwargs) -> None:
        self.base_layer = base_layer
        self.bone_r = {}
        self.bone_block = nn.ParameterDict({})
        # Mark the weight as unmerged
        self._disable_adapters = False
        self.merged_adapters = []
        self.kwargs = kwargs

        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
            self.in_features, self.out_features = base_layer.in_features, base_layer.out_features
        else:
            raise ValueError(f"Unsupported layer type {type(base_layer)}")

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        init_weights: Union[bool, str],
        **kwargs,
    ) -> None:
        """Internal function to create bone adapter

        Args:
            adapter_name (`str`): Name for the adapter to add.
            r (`int`): Rank for the added adapter.
            init_weights (`Union[bool, str]`): Whether to initialize weights; pass `"bat"` to use the Bat variant.
        """
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        self.bone_r[adapter_name] = r

        # Determine shape of Bone weights
        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
            self.bone_block[adapter_name] = nn.Parameter(torch.zeros(r, self.out_features), requires_grad=True)
        else:
            raise TypeError(f"Bone is not implemented for base layers of type {type(base_layer).__name__}")

        # Initialize weights
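        #   - init_weights == "bat": reshape the adapter into (out_features // r, r, r) blocks; this
        #     requires both in_features and out_features to be divisible by r.
        #   - init_weights truthy (and not "bat"): keep the zero init from above, so the new adapter
        #     starts out as a no-op.
        #   - init_weights falsy: Kaiming-uniform random init of the (r, out_features) block.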
        if init_weights == "bat":
            if self.in_features % r != 0 or self.out_features % r != 0:
                raise ValueError("The weight matrix must be fully divisible into [r, r] blocks.")
            self.reset_bat_parameters(adapter_name, r)
        elif init_weights:
            self.reset_bone_parameters(adapter_name, r)
        else:
            self.reset_bone_parameters_random(adapter_name)

        # Move new weights to device
        self._move_adapter_to_device_of_base_layer(adapter_name)
        self.set_adapter(self.active_adapters)

    def reset_bone_parameters(self, adapter_name: str, r):
        self.bone_block[adapter_name] = nn.Parameter(torch.zeros(r, self.out_features), requires_grad=True)

    def reset_bat_parameters(self, adapter_name: str, r):
        self.bone_block[adapter_name] = nn.Parameter(torch.zeros(self.out_features // r, r, r), requires_grad=True)

    def reset_bone_parameters_random(self, adapter_name: str):
        nn.init.kaiming_uniform_(self.bone_block[adapter_name], a=math.sqrt(5))

    def scale_layer(self, scale: float) -> None:
        if scale == 1:
            return

        for active_adapter in self.active_adapters:
            if active_adapter not in self.bone_block.keys():
                continue

            warnings.warn("Scaling operation for Bone not supported! Automatically set scale to 1.")

    def unscale_layer(self, scale=None) -> None:
        for active_adapter in self.active_adapters:
            if active_adapter not in self.bone_block.keys():
                continue

            warnings.warn("Unscaling operation for Bone not supported! Keeping scale at 1.")


class BoneLinear(nn.Module, BoneLayer):
    """
    Bone implemented in a dense layer.
    """

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        init_weights: Union[bool, str] = True,
        **kwargs,
    ) -> None:
        super().__init__()
        BoneLayer.__init__(self, base_layer, **kwargs)
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, init_weights, **kwargs)
        self.bone_fn = init_weights

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        Args:
            safe_merge (`bool`, *optional*):
                If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If `None`, all active adapters will be merged.
                Defaults to `None`.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        for active_adapter in adapter_names:
            if active_adapter in self.bone_block.keys():
                base_layer = self.get_base_layer()
                if safe_merge:
                    # Note that safe_merge will be slower than the normal merge
                    # because of the copy operation.
                    orig_weight = base_layer.weight.data.clone()
                    if self.bone_fn == "bat":
                        delta_weight = self.get_delta_weight(active_adapter, orig_weight)
                        orig_weight += delta_weight
                    else:
                        # Operate on the copy: get_delta_weight_bone may modify its input in place, and
                        # safe_merge must not touch the base weights before the NaN check has passed.
                        delta_weight = self.get_delta_weight_bone(active_adapter, orig_weight)
                        orig_weight = delta_weight

                    if not torch.isfinite(orig_weight).all():
                        raise ValueError(
                            f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                        )

                    self.base_layer.weight.data = orig_weight
                else:
                    if self.bone_fn == "bat":
                        delta_weight = self.get_delta_weight(active_adapter, self.base_layer.weight.data)
                        self.base_layer.weight.data += delta_weight
                    else:
                        delta_weight = self.get_delta_weight_bone(active_adapter, self.base_layer.weight.data)
                        self.base_layer.weight.data = delta_weight
                self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter in self.bone_block.keys():
                orig_weight = self.get_base_layer().weight.data.clone()
                if self.bone_fn == "bat":
                    delta_weight = self.get_delta_weight(active_adapter, orig_weight, re=True)
                else:
                    delta_weight = self.get_delta_weight_bone(active_adapter, orig_weight, re=True)
                self.get_base_layer().weight.data = delta_weight

    def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
        """
        Compute the delta weight for the given adapter (Bat variant).

        Args:
            adapter (str):
                The name of the adapter for which the delta weight should be computed.
            orig_weight (`torch.Tensor`):
                The base layer weight to which the adapter update is applied.
            re (`bool`, *optional*):
                If `True`, invert the update (used when unmerging). Defaults to `False`.
        """
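        # Bat math: the base weight W (out_features, in_features) is viewed as a grid of (r, r) blocks,
        # and bone_block holds one (r, r) matrix B_i per output block row. With re=False this returns the
        # per-block delta W_ij @ B_i + B_i, so the merged block becomes W_ij @ (I + B_i) + B_i. With
        # re=True it returns the restored base weight, (W'_ij - B_i) @ (I + B_i)^-1, undoing a merge.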
        device = self.bone_block[adapter].device
        dtype = self.bone_block[adapter].dtype

        # In case users want to merge the adapter weights that are in (b)float16 while being on CPU, we
        # need to cast the weights to float32, perform the merge and then cast back to (b)float16 because
        # some CPUs have slow bf16/fp16 matmuls.
        cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)

        weight_bone = self.bone_block[adapter]

        if cast_to_fp32:
            weight_bone = weight_bone.float()

        r = weight_bone.size(-1)
        if re:
            o = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
            one = torch.eye(weight_bone.size(-1)).to(weight_bone.device)
            inv_I_plus_b = torch.inverse(one + weight_bone)
            w = (o - weight_bone) @ inv_I_plus_b
            output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape)
        else:
            w = (
                orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
                @ weight_bone
                + weight_bone
            )
            output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape)

        if cast_to_fp32:
            output_tensor = output_tensor.to(dtype=dtype)

            # cast back the weights
            self.bone_block[adapter].data = weight_bone.to(dtype)

        return output_tensor

    def get_delta_weight_bone(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
        """
        Compute the updated base weight for the given adapter (plain Bone variant).

        Args:
            adapter (str):
                The name of the adapter for which the delta weight should be computed.
            orig_weight (`torch.Tensor`):
                The base layer weight to which the adapter update is applied.
            re (`bool`, *optional*):
                If `True`, invert the update (used when unmerging). Defaults to `False`.
        """
        device = self.bone_block[adapter].device
        dtype = self.bone_block[adapter].dtype

        # In case users want to merge the adapter weights that are in (b)float16 while being on CPU, we
        # need to cast the weights to float32, perform the merge and then cast back to (b)float16 because
        # some CPUs have slow bf16/fp16 matmuls.
        cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)

        weight_bone = self.bone_block[adapter]

        if cast_to_fp32:
            weight_bone = weight_bone.float()
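
        # Plain Bone math: bone_block has shape (r, out_features) and its transpose is added to every
        # group of r consecutive input columns of W, i.e. W[:, j * r + k] += B[k, :] for every group j.
        # If in_features is not divisible by r, the trailing columns only receive the first
        # (in_features % r) columns of B^T. With re=True the same quantity is subtracted (unmerge).
        # Unlike the Bat path, this returns the full updated weight rather than a delta.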
        in_features = orig_weight.size(-1)
        r = weight_bone.size(0)
        if in_features % r != 0:
            last_size = in_features % r
            n_block = in_features // r
            n_block_size = n_block * r
            if re:
                orig_weight[:, :n_block_size] = (
                    (orig_weight[:, :n_block_size].reshape(-1, n_block, r).permute(1, 2, 0) - weight_bone)
                    .permute(2, 0, 1)
                    .reshape(*orig_weight[:, :n_block_size].shape)
                )
                orig_weight[:, n_block_size:] = (
                    orig_weight[:, n_block_size:] - (weight_bone.transpose(0, 1))[:, :last_size]
                )
            else:
                orig_weight[:, :n_block_size] = (
                    (orig_weight[:, :n_block_size].reshape(-1, n_block, r).permute(1, 2, 0) + weight_bone)
                    .permute(2, 0, 1)
                    .reshape(*orig_weight[:, :n_block_size].shape)
                )
                orig_weight[:, n_block_size:] = (
                    orig_weight[:, n_block_size:] + (weight_bone.transpose(0, 1))[:, :last_size]
                )
            output_tensor = orig_weight
        else:
            if re:
                w = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) - weight_bone
                output_tensor = w.permute(2, 0, 1).reshape(*orig_weight.shape)
            else:
                w = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) + weight_bone
                output_tensor = w.permute(2, 0, 1).reshape(*orig_weight.shape)

        if cast_to_fp32:
            output_tensor = output_tensor.to(dtype=dtype)

            # cast back the weights
            self.bone_block[adapter].data = weight_bone.to(dtype)

        return output_tensor

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            if self.bone_fn == "bat":
                orig_weight = self.base_layer.weight.data.clone()
                for active_adapter in self.active_adapters:
                    if active_adapter not in self.bone_block.keys():
                        continue
                    delta_weight = self.get_delta_weight(active_adapter, orig_weight)
                    orig_weight = orig_weight + delta_weight

                result = F.linear(input=x, weight=orig_weight, bias=self.base_layer.bias)
            else:
                result = self.base_layer(x, *args, **kwargs)
                for active_adapter in self.active_adapters:
                    if active_adapter not in self.bone_block.keys():
                        continue
                    bone = self.bone_block[active_adapter]
                    r = bone.size(0)
                    if x.size(-1) % r != 0:
                        padding_size = (r - x.size(-1) % r) % r
                        x = F.pad(x, (0, padding_size))
                    result = result + torch.sum(x.reshape(*x.shape[:-1], x.size(-1) // r, r), dim=-2) @ bone

        result = result.to(previous_dtype)
        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "bone." + rep
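

# ---------------------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API surface): wrap a plain nn.Linear
# in BoneLinear and merge the adapter back into the base weights. The adapter name "default" and the
# rank r=4 below are arbitrary example values.
# ---------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    base = nn.Linear(16, 16)
    layer = BoneLinear(base, adapter_name="default", r=4, init_weights=True)

    x = torch.randn(2, 16)
    out = layer(x)  # zero-initialized Bone block -> output matches the base layer

    layer.merge()  # fold the (currently zero) update into base.weight
    merged_out = layer(x)
    assert torch.allclose(out, merged_out, atol=1e-6)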