File size: 14,211 Bytes
3b609b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge


class BoneLayer(BaseTunerLayer):
    """
    Shared adapter bookkeeping for Bone layers.

    Holds the per-adapter rank (`bone_r`) and the trainable Bone parameters (`bone_block`) for a wrapped `nn.Linear`
    base layer, and implements adapter creation / initialization.
    """

    # All names of layers that may contain (trainable) adapter weights
    adapter_layer_names = ("bone_block",)
    # All names of other parameters that may contain adapter-related parameters
    other_param_names = ("bone_r",)

    def __init__(self, base_layer: nn.Module, **kwargs) -> None:
        """
        Args:
            base_layer (`nn.Module`): The layer being adapted. Only `nn.Linear` is supported.

        Raises:
            ValueError: If the (innermost) base layer is not an `nn.Linear`.
        """
        self.base_layer = base_layer
        self.bone_r = {}
        self.bone_block = nn.ParameterDict({})
        # Mark the weight as unmerged
        self._disable_adapters = False
        self.merged_adapters = []
        self.kwargs = kwargs

        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
            self.in_features, self.out_features = base_layer.in_features, base_layer.out_features
        else:
            raise ValueError(f"Unsupported layer type {type(base_layer)}")

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        init_weights: Union[bool, str],
        **kwargs,
    ) -> None:
        """Internal function to create bone adapter

        All argument validation happens before any adapter state is written, so a rejected call leaves the layer
        unchanged.

        Args:
            adapter_name (`str`): Name for the adapter to add.
            r (`int`): Rank for the added adapter.
            init_weights (`bool` or `str`):
                `"bat"` creates `[out_features // r, r, r]` zero blocks (requires both feature sizes to be divisible
                by `r`); any other truthy value creates a zero `[r, out_features]` matrix; a falsy value creates an
                `[r, out_features]` matrix initialized with Kaiming-uniform values.

        Raises:
            ValueError: If `r` is not positive, or if `init_weights == "bat"` and the weight cannot be split into
                `[r, r]` blocks.
            TypeError: If the base layer is not an `nn.Linear`.
        """
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        base_layer = self.get_base_layer()
        if not isinstance(base_layer, nn.Linear):
            raise TypeError(f"Bone is not implemented for base layers of type {type(base_layer).__name__}")

        if init_weights == "bat" and (self.in_features % r != 0 or self.out_features % r != 0):
            raise ValueError("The weight matrix must be fully divisible into [r, r] blocks.")

        self.bone_r[adapter_name] = r
        # Default shape of the Bone weights; the "bat" initializer below replaces it with a block tensor.
        self.bone_block[adapter_name] = nn.Parameter(torch.zeros(r, self.out_features), requires_grad=True)

        # Initialize weights
        if init_weights == "bat":
            self.reset_bat_parameters(adapter_name, r)
        elif init_weights:
            self.reset_bone_parameters(adapter_name, r)
        else:
            self.reset_bone_parameters_random(adapter_name)
        # Move new weights to device
        self._move_adapter_to_device_of_base_layer(adapter_name)
        self.set_adapter(self.active_adapters)

    def reset_bone_parameters(self, adapter_name: str, r):
        """Replace the adapter's weights with a zero `[r, out_features]` matrix."""
        self.bone_block[adapter_name] = nn.Parameter(torch.zeros(r, self.out_features), requires_grad=True)

    def reset_bat_parameters(self, adapter_name: str, r):
        """Replace the adapter's weights with `out_features // r` zero `[r, r]` blocks."""
        self.bone_block[adapter_name] = nn.Parameter(torch.zeros(self.out_features // r, r, r), requires_grad=True)

    def reset_bone_parameters_random(self, adapter_name: str):
        """Fill the adapter's existing weights in place with Kaiming-uniform values."""
        nn.init.kaiming_uniform_(self.bone_block[adapter_name], a=math.sqrt(5))

    def scale_layer(self, scale: float) -> None:
        """Scaling is not supported for Bone; warns for every active Bone adapter unless `scale == 1`."""
        if scale == 1:
            return

        for active_adapter in self.active_adapters:
            if active_adapter not in self.bone_block.keys():
                continue

            warnings.warn("Scaling operation for Bone not supported! Automatically set scale to 1.")

    def unscale_layer(self, scale=None) -> None:
        """Unscaling is not supported for Bone; warns for every active Bone adapter."""
        for active_adapter in self.active_adapters:
            if active_adapter not in self.bone_block.keys():
                continue

            warnings.warn("Unscaling operation for Bone not supported! Keeping scale at 1.")


class BoneLinear(nn.Module, BoneLayer):
    """
    Bone implemented in a dense layer.

    The variant is chosen by `init_weights` (stored as `self.bone_fn`):

    * `"bat"`: for every `[r, r]` block `W_ij` of the base weight (row-block `i`, column-block `j`), the term
      `W_ij @ B_i + B_i` is added, where `B_i` is one of the `out_features // r` learned `[r, r]` matrices.
    * anything else: the learned `[r, out_features]` matrix `B` is tiled along the input dimension and added to the
      weight, i.e. input feature `k` receives `B[k % r, :]`.
    """

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        init_weights: Union[bool, str] = True,
        **kwargs,
    ) -> None:
        super().__init__()
        BoneLayer.__init__(self, base_layer, **kwargs)
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, init_weights, **kwargs)
        # The variant is stored per layer; merge/unmerge/forward dispatch on it.
        self.bone_fn = init_weights

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        Args:
            safe_merge (`bool`, *optional*):
                If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If `None`, all active adapters will be merged.
                Defaults to `None`.

        Raises:
            ValueError: If `safe_merge` is set and the merged weights contain non-finite values. The base weights are
                left untouched in that case.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        base_layer = self.get_base_layer()
        for active_adapter in adapter_names:
            if active_adapter not in self.bone_block.keys():
                continue

            if safe_merge:
                # Note that safe_merge will be slower than the normal merge
                # because of the copy operation. Every computation below only touches the copy, so the base
                # weights stay intact if the check fails.
                orig_weight = base_layer.weight.data.clone()
                if self.bone_fn == "bat":
                    orig_weight += self.get_delta_weight(active_adapter, orig_weight)
                else:
                    orig_weight = self.get_delta_weight_bone(active_adapter, orig_weight)

                if not torch.isfinite(orig_weight).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )

                base_layer.weight.data = orig_weight
            elif self.bone_fn == "bat":
                base_layer.weight.data += self.get_delta_weight(active_adapter, base_layer.weight.data)
            else:
                base_layer.weight.data = self.get_delta_weight_bone(active_adapter, base_layer.weight.data)
            self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.

        Adapters are undone in reverse merge order.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        base_layer = self.get_base_layer()
        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self.bone_block.keys():
                continue
            orig_weight = base_layer.weight.data.clone()
            if self.bone_fn == "bat":
                restored = self.get_delta_weight(active_adapter, orig_weight, re=True)
            else:
                restored = self.get_delta_weight_bone(active_adapter, orig_weight, re=True)

            base_layer.weight.data = restored

    def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
        """
        Apply or undo the "bat" block transform of the given adapter.

        Args:
            adapter (str):
                The name of the adapter whose `[out_features // r, r, r]` blocks should be used.
            orig_weight (torch.Tensor):
                A `[out_features, in_features]` weight. It is not modified.
            re (bool):
                If `False`, return the term `W_ij @ B_i + B_i` for every block, i.e. the amount to add to
                `orig_weight`. If `True`, treat `orig_weight` as an already merged weight `W'` and return the
                recovered original weight `(W'_ij - B_i) @ inv(I + B_i)`.

        Returns:
            torch.Tensor: A tensor shaped like `orig_weight`, in `orig_weight.dtype`.
        """
        weight_bone = self.bone_block[adapter]
        # In case users want to merge adapter weights that are in (b)float16 while being on CPU, compute in float32
        # because some CPUs have slow bf16/fp16 matmuls.
        if weight_bone.device.type == "cpu" and weight_bone.dtype in (torch.float16, torch.bfloat16):
            weight_bone = weight_bone.float()

        if re:
            # torch.inverse does not support half precision, so undo the transform in at least float32.
            calc_dtype = torch.promote_types(weight_bone.dtype, torch.float32)
        else:
            calc_dtype = weight_bone.dtype
        weight_bone = weight_bone.to(calc_dtype)

        r = weight_bone.size(-1)
        out_features, in_features = orig_weight.shape
        # [out, in] -> [in // r, out // r, r, r]: entry [j, i] is the (row-block i, column-block j) sub-matrix, so
        # the batched matmul below applies B_i to every column block j of row-block i.
        blocks = (
            orig_weight.to(calc_dtype)
            .reshape(out_features // r, r, in_features // r, r)
            .permute(2, 0, 1, 3)
        )

        if re:
            identity = torch.eye(r, dtype=calc_dtype, device=weight_bone.device)
            w = (blocks - weight_bone) @ torch.inverse(identity + weight_bone)
        else:
            w = blocks @ weight_bone + weight_bone

        output_tensor = w.permute(1, 2, 0, 3).reshape(out_features, in_features)
        # Return in the weight's dtype so merging/unmerging never changes the base layer's dtype.
        return output_tensor.to(orig_weight.dtype)

    def get_delta_weight_bone(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
        """
        Return `orig_weight` with the tiled Bone matrix of the given adapter added (or subtracted when `re=True`).

        Column `k` of the weight is shifted by `B[k % r, :]`, where `B` is the adapter's `[r, out_features]` matrix.
        If `in_features` is not a multiple of `r`, the last partial tile only uses the first `in_features % r` rows
        of `B`, mirroring the zero padding done in `forward`.

        Args:
            adapter (str):
                The name of the adapter whose matrix should be used.
            orig_weight (torch.Tensor):
                A `[out_features, in_features]` weight. It is not modified.
            re (bool):
                If `True`, subtract instead of add (used to unmerge).

        Returns:
            torch.Tensor: The full merged (or unmerged) weight, shaped like `orig_weight`, in `orig_weight.dtype`.
        """
        weight_bone = self.bone_block[adapter]
        # In case users want to merge adapter weights that are in (b)float16 while being on CPU, compute in float32
        # because some CPUs have slow bf16/fp16 ops.
        if weight_bone.device.type == "cpu" and weight_bone.dtype in (torch.float16, torch.bfloat16):
            weight_bone = weight_bone.float()

        in_features = orig_weight.size(-1)
        r = weight_bone.size(0)
        n_tiles = (in_features + r - 1) // r
        # [r, out] -> [out, r * n_tiles] -> trimmed to [out, in]; column k equals B[k % r, :].
        tiled = weight_bone.transpose(0, 1).repeat(1, n_tiles)[:, :in_features]

        weight = orig_weight.to(weight_bone.dtype)
        output_tensor = weight - tiled if re else weight + tiled
        # Return in the weight's dtype so merging/unmerging never changes the base layer's dtype.
        return output_tensor.to(orig_weight.dtype)

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Run the base layer with all active (unmerged) Bone adapters applied; the output keeps `x`'s dtype."""
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        elif self.bone_fn == "bat":
            base_layer = self.get_base_layer()
            # `.data` keeps the base weight out of the autograd graph; gradients only reach the Bone blocks.
            weight = base_layer.weight.data
            for active_adapter in self.active_adapters:
                if active_adapter not in self.bone_block.keys():
                    continue
                weight = weight + self.get_delta_weight(active_adapter, weight)

            result = F.linear(input=x, weight=weight, bias=base_layer.bias)
        else:
            result = self.base_layer(x, *args, **kwargs)
            for active_adapter in self.active_adapters:
                if active_adapter not in self.bone_block.keys():
                    continue
                bone = self.bone_block[active_adapter]
                r = bone.size(0)
                # Zero-pad the features to a multiple of r (per adapter, so `x` itself is left untouched).
                padding_size = (r - x.size(-1) % r) % r
                x_bone = F.pad(x, (0, padding_size)) if padding_size else x
                x_bone = x_bone.to(bone.dtype)
                # Sum the r-sized feature chunks, then project with the [r, out_features] Bone matrix.
                summed = torch.sum(x_bone.reshape(*x_bone.shape[:-1], x_bone.size(-1) // r, r), dim=-2)
                result = result + summed @ bone

        result = result.to(previous_dtype)
        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "bone." + rep