import math
from typing import Any, Optional, Set, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from peft.tuners.lycoris_utils import LycorisLayer


class LoKrLayer(nn.Module, LycorisLayer):
    # All names of layers that may contain adapter weights
    adapter_layer_names = (
        "lokr_w1",
        "lokr_w1_a",
        "lokr_w1_b",
        "lokr_w2",
        "lokr_w2_a",
        "lokr_w2_b",
        "lokr_t2",
    )

    def __init__(self, base_layer: nn.Module) -> None:
        super().__init__()
        LycorisLayer.__init__(self, base_layer)

        # LoKr adapter weights
        self.lokr_w1 = nn.ParameterDict({})
        self.lokr_w1_a = nn.ParameterDict({})
        self.lokr_w1_b = nn.ParameterDict({})
        self.lokr_w2 = nn.ParameterDict({})
        self.lokr_w2_a = nn.ParameterDict({})
        self.lokr_w2_b = nn.ParameterDict({})
        self.lokr_t2 = nn.ParameterDict({})

    @property
    def _available_adapters(self) -> Set[str]:
        return {
            *self.lokr_w1,
            *self.lokr_w1_a,
            *self.lokr_w1_b,
            *self.lokr_w2,
            *self.lokr_w2_a,
            *self.lokr_w2_b,
            *self.lokr_t2,
        }

    def create_adapter_parameters(
        self,
        adapter_name: str,
        r: int,
        shape,
        use_w1: bool,
        use_w2: bool,
        use_effective_conv2d: bool,
    ):
        # Left Kronecker factor: either a full matrix or a rank-r decomposition (w1_a @ w1_b)
        if use_w1:
            self.lokr_w1[adapter_name] = nn.Parameter(torch.empty(shape[0][0], shape[1][0]))
        else:
            self.lokr_w1_a[adapter_name] = nn.Parameter(torch.empty(shape[0][0], r))
            self.lokr_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][0]))

        # Right Kronecker factor
        if len(shape) == 4:
            # Conv2d
            if use_w2:
                self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *shape[2:]))
            elif use_effective_conv2d:
                self.lokr_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], shape[3]))
                self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0][1]))
                self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1]))
            else:
                self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r))
                self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1] * shape[2] * shape[3]))
        else:
            # Linear
            if use_w2:
                self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))
            else:
                self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r))
                self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1]))

    def reset_adapter_parameters(self, adapter_name: str):
        # Zero-init the left Kronecker factor so that the initial delta weight is zero
        if adapter_name in self.lokr_w1:
            nn.init.zeros_(self.lokr_w1[adapter_name])
        else:
            nn.init.zeros_(self.lokr_w1_a[adapter_name])
            nn.init.kaiming_uniform_(self.lokr_w1_b[adapter_name], a=math.sqrt(5))

        if adapter_name in self.lokr_w2:
            nn.init.kaiming_uniform_(self.lokr_w2[adapter_name], a=math.sqrt(5))
        else:
            nn.init.kaiming_uniform_(self.lokr_w2_a[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lokr_w2_b[adapter_name], a=math.sqrt(5))

        if adapter_name in self.lokr_t2:
            nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5))

    def reset_adapter_parameters_random(self, adapter_name: str):
        if adapter_name in self.lokr_w1:
            nn.init.kaiming_uniform_(self.lokr_w1[adapter_name], a=math.sqrt(5))
        else:
            nn.init.kaiming_uniform_(self.lokr_w1_a[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lokr_w1_b[adapter_name], a=math.sqrt(5))

        if adapter_name in self.lokr_w2:
            nn.init.kaiming_uniform_(self.lokr_w2[adapter_name], a=math.sqrt(5))
        else:
            nn.init.kaiming_uniform_(self.lokr_w2_a[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lokr_w2_b[adapter_name], a=math.sqrt(5))

        if adapter_name in self.lokr_t2:
            nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5))

    def reset_adapter_parameters_lycoris_way(self, adapter_name):
        # LyCORIS-style init: zero the right Kronecker factor instead of the left one
        if adapter_name in self.lokr_w1:
            nn.init.kaiming_uniform_(self.lokr_w1[adapter_name], a=math.sqrt(5))
        else:
            nn.init.kaiming_uniform_(self.lokr_w1_a[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lokr_w1_b[adapter_name], a=math.sqrt(5))

        if adapter_name in self.lokr_w2:
            nn.init.zeros_(self.lokr_w2[adapter_name])
        else:
            nn.init.zeros_(self.lokr_w2_b[adapter_name])
            nn.init.kaiming_uniform_(self.lokr_w2_a[adapter_name], a=math.sqrt(5))

        if adapter_name in self.lokr_t2:
            nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5))

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        alpha: float,
        rank_dropout: float,
        module_dropout: float,
        init_weights: Union[bool, str],
        use_effective_conv2d: bool = False,
        decompose_both: bool = False,
        decompose_factor: int = -1,
        **kwargs,
    ) -> None:
        """Internal function to create a LoKr adapter.

        Args:
            adapter_name (`str`): Name of the adapter to add.
            r (`int`): Rank of the added adapter.
            alpha (`float`): Alpha of the added adapter (the adapter output is scaled by `alpha / r`).
            rank_dropout (`float`): The dropout probability for the rank dimension during training.
            module_dropout (`float`): The dropout probability for disabling the adapter during training.
            init_weights (`bool` or `str`): How to initialize the adapter weights: `True` zero-initializes the
                left Kronecker factor, `"lycoris"` zero-initializes the right factor (LyCORIS scheme), and
                `False` uses a fully random initialization.
            use_effective_conv2d (`bool`): Use parameter-effective decomposition for Conv2d layers with
                kernel size > 1.
            decompose_both (`bool`): Perform rank decomposition of the left Kronecker product matrix as well.
            decompose_factor (`int`): Kronecker product decomposition factor.
        """
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        self.r[adapter_name] = r
        self.alpha[adapter_name] = alpha
        self.scaling[adapter_name] = alpha / r
        self.rank_dropout[adapter_name] = rank_dropout
        self.module_dropout[adapter_name] = module_dropout
        self.rank_dropout_scale[adapter_name] = kwargs["rank_dropout_scale"]
        base_layer = self.get_base_layer()

        # Determine the shapes of the two Kronecker factors from the base layer
        if isinstance(base_layer, nn.Linear):
            in_dim, out_dim = base_layer.in_features, base_layer.out_features

            in_m, in_n = factorization(in_dim, decompose_factor)
            out_l, out_k = factorization(out_dim, decompose_factor)
            shape = ((out_l, out_k), (in_m, in_n))

            use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
            use_w2 = not (r < max(shape[0][1], shape[1][1]) / 2)
            use_effective_conv2d = False
        elif isinstance(base_layer, nn.Conv2d):
            in_dim, out_dim = base_layer.in_channels, base_layer.out_channels
            k_size = base_layer.kernel_size

            in_m, in_n = factorization(in_dim, decompose_factor)
            out_l, out_k = factorization(out_dim, decompose_factor)
            shape = ((out_l, out_k), (in_m, in_n), *k_size)

            use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
            use_w2 = r >= max(shape[0][1], shape[1][1]) / 2
            use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1)
        else:
            raise TypeError(f"LoKr is not implemented for base layers of type {type(base_layer).__name__}")

        # Create the adapter weights with the shapes determined above
        self.create_adapter_parameters(adapter_name, r, shape, use_w1, use_w2, use_effective_conv2d)

        # Initialize the adapter weights
        if init_weights:
            if init_weights == "lycoris":
                self.reset_adapter_parameters_lycoris_way(adapter_name)
            else:
                self.reset_adapter_parameters(adapter_name)
        else:
            self.reset_adapter_parameters_random(adapter_name)

        # Move the new weights to the device of the base layer and activate the adapter
        self._move_adapter_to_device_of_base_layer(adapter_name)
        self.set_adapter(self.active_adapters)
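
    # Shape sketch for the Kronecker reconstruction below (illustrative numbers, not tied to any
    # particular model): for a base nn.Linear(768, 768) with decompose_factor=-1, factorization(768)
    # yields (24, 32) for both dimensions, so w1 has shape (24, 24) and w2 has shape (32, 32)
    # (each optionally factored into a rank-r pair), and kron(w1, w2) is the full (768, 768) delta.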

    def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
        # Left Kronecker factor
        if adapter_name in self.lokr_w1:
            w1 = self.lokr_w1[adapter_name]
        else:
            w1 = self.lokr_w1_a[adapter_name] @ self.lokr_w1_b[adapter_name]

        # Right Kronecker factor
        if adapter_name in self.lokr_w2:
            w2 = self.lokr_w2[adapter_name]
        elif adapter_name in self.lokr_t2:
            w2 = make_weight_cp(self.lokr_t2[adapter_name], self.lokr_w2_a[adapter_name], self.lokr_w2_b[adapter_name])
        else:
            w2 = self.lokr_w2_a[adapter_name] @ self.lokr_w2_b[adapter_name]

        # Combine the factors into the full delta weight
        weight = make_kron(w1, w2, self.scaling[adapter_name])
        weight = weight.reshape(self.get_base_layer().weight.shape)

        # Perform rank dropout during training, i.e. drop random rows of the delta weight
        rank_dropout = self.rank_dropout[adapter_name]
        if self.training and rank_dropout:
            drop = (torch.rand(weight.size(0)) > rank_dropout).float()
            drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
            if self.rank_dropout_scale[adapter_name]:
                drop /= drop.mean()
            weight *= drop

        return weight

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)

            # Execute all the active adapters
            for active_adapter in self.active_adapters:
                if active_adapter not in self._available_adapters:
                    continue

                module_dropout = self.module_dropout[active_adapter]

                # Apply module dropout: during training, skip the adapter with probability `module_dropout`
                if (not self.training) or (self.training and torch.rand(1) > module_dropout):
                    result = result + self._get_delta_activations(active_adapter, x, *args, **kwargs)

        result = result.to(previous_dtype)
        return result


class Linear(LoKrLayer):
    """LoKr implemented in a Linear layer"""

    def __init__(
        self,
        base_layer: nn.Module,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
        adapter_name: str = "default",
        r: int = 0,
        alpha: float = 0.0,
        rank_dropout: float = 0.0,
        module_dropout: float = 0.0,
        init_weights: bool = True,
        **kwargs,
    ):
        super().__init__(base_layer)

        # Create adapter and set it active
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, **kwargs)

    def _get_delta_activations(
        self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
    ) -> torch.Tensor:
        delta_weight = self.get_delta_weight(adapter_name)
        # Don't add a bias here, because the bias is already included in the output of the base layer
        return F.linear(input, delta_weight)

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lokr." + rep


class Conv2d(LoKrLayer):
    """LoKr implemented in a Conv2d layer"""

    def __init__(
        self,
        base_layer: nn.Module,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
        adapter_name: str = "default",
        r: int = 0,
        alpha: float = 0.0,
        rank_dropout: float = 0.0,
        module_dropout: float = 0.0,
        use_effective_conv2d: bool = False,
        init_weights: bool = True,
        **kwargs,
    ):
        super().__init__(base_layer)

        # Create adapter and set it active
        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs
        )

    def _get_delta_activations(
        self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
    ) -> torch.Tensor:
        delta_weight = self.get_delta_weight(adapter_name)
        # Don't add a bias here, because the bias is already included in the output of the base layer
        base_layer = self.get_base_layer()
        return F.conv2d(
            input,
            delta_weight,
            stride=base_layer.stride,
            padding=base_layer.padding,
            dilation=base_layer.dilation,
            groups=base_layer.groups,
        )

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lokr." + rep


def factorization(dimension: int, factor: int = -1) -> Tuple[int, int]:
    """Factorizes the provided number into the product of two numbers.

    Args:
        dimension (`int`): The number that needs to be factorized.
        factor (`int`, optional):
            Factorization divider. The algorithm will try to output two numbers, one of which is as close to the
            factor as possible. If -1 is provided, the algorithm searches for divisors near the square root of the
            dimension. Defaults to -1.

    Returns:
        Tuple[`int`, `int`]: A tuple of two numbers whose product is equal to the provided number. When `factor=-1`,
        the first number is always less than or equal to the second.

    Example:
        ```py
        >>> factorization(256, factor=-1)
        (16, 16)

        >>> factorization(128, factor=-1)
        (8, 16)

        >>> factorization(127, factor=-1)
        (1, 127)

        >>> factorization(128, factor=4)
        (4, 32)
        ```
    """
    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
        return m, n
    if factor == -1:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n


def make_weight_cp(t, wa, wb):
    # Rebuild the right Kronecker factor from the effective Conv2d decomposition: contract the core
    # tensor t (r, r, k1, k2) with wa (r, out) and wb (r, in) to get an (out, in, k1, k2) kernel
    rebuild2 = torch.einsum("i j k l, i p, j r -> p r k l", t, wa, wb)
    return rebuild2


def make_kron(w1, w2, scale=1.0):
    # Kronecker product of the two factors; for Conv2d weights, broadcast w1 over the kernel dims
    if len(w2.shape) == 4:
        w1 = w1.unsqueeze(2).unsqueeze(2)
    w2 = w2.contiguous()
    rebuild = torch.kron(w1, w2)

    return rebuild * scale
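

# A minimal sanity-check sketch (illustrative only; the sizes below are assumptions, not taken from
# any particular model). It shows how the helpers above combine into a full delta weight for a
# hypothetical 768 x 768 Linear layer.
if __name__ == "__main__":
    in_m, in_n = factorization(768, -1)  # (24, 32)
    out_l, out_k = factorization(768, -1)  # (24, 32)

    rank = 8
    w1 = torch.randn(out_l, in_m)  # left Kronecker factor, kept as a full matrix
    w2 = torch.randn(out_k, rank) @ torch.randn(rank, in_n)  # right factor, rank-decomposed
    delta = make_kron(w1, w2, scale=1.0)

    # kron of (24, 24) and (32, 32) -> (768, 768), matching the base Linear weight shape
    assert delta.shape == (out_l * out_k, in_m * in_n)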