# svft/svft_layers.py
import torch
import torch.nn as nn
from tqdm import tqdm
from torch_sparse import SparseTensor

def create_orthonormal_matrix(A):
    # Returns Q from the reduced QR decomposition of A: a matrix with
    # orthonormal columns of shape (A.shape[0], min(A.shape)).
    Q, _ = torch.linalg.qr(A)
    return Q
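
# Example (illustrative): create_orthonormal_matrix(torch.randn(64, 16)) returns
# a (64, 16) matrix whose columns are orthonormal.
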
def get_target_modules_list(model, target_modules):
target_names = []
for n, _ in model.named_modules():
if any(t in n for t in target_modules):
target_names.append(n)
return target_names

def replace_svft_with_fused_linear(model, target_modules_list):
    print("Replacing SVFT layers with new Linear layers")
    # Skip the inner svft_layer modules themselves; only their parent
    # LinearWithSVFT wrappers are fused back into plain Linear layers.
    target_modules_list = [l for l in target_modules_list if "svft_layer" not in l]
    for target_path in tqdm(reversed(target_modules_list), total=len(target_modules_list)):
        parent_path = target_path[: target_path.rfind(".")] if "." in target_path else ""
        target_name = target_path.split(".")[-1]
        parent = model.get_submodule(parent_path) if parent_path else model
        target = model.get_submodule(target_path)
        # v has shape (rank, in_dim) and u has shape (out_dim, rank).
        in_dim = target.svft_layer.v.shape[1]
        out_dim = target.svft_layer.u.shape[0]
        lin = torch.nn.Linear(in_dim, out_dim, bias=target.bias is not None)
        if target.bias is not None:
            lin.bias.data = target.bias.data
        lin.weight.data = target.merge_and_unload()
        setattr(parent, target_name, lin)

def create_and_replace_modules(model, target_modules_list, create_fn):
    print("Replacing Linear layers with SVFT layers")
    # Iterate in reverse (children before parents) so that replacing a module
    # never invalidates a submodule path that has yet to be processed.
    for target_path in tqdm(reversed(target_modules_list), total=len(target_modules_list)):
        parent_path = target_path[: target_path.rfind(".")] if "." in target_path else ""
        target_name = target_path.split(".")[-1]
        parent = model.get_submodule(parent_path) if parent_path else model
        target = model.get_submodule(target_path)
        setattr(parent, target_name, create_fn(target))
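
# A minimal usage sketch (module-name substrings like "q_proj"/"v_proj" are
# illustrative and depend on the model being adapted; LinearWithSVFT is
# defined later in this file):
#
#   targets = get_target_modules_list(model, ["q_proj", "v_proj"])
#   create_and_replace_modules(
#       model, targets,
#       create_fn=lambda lin: LinearWithSVFT(lin, off_diag=0, pattern="banded"),
#   )
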
class SVFTLayer(nn.Module):
    def __init__(self, u, s, v, off_diag, pattern="banded", rank=None, fill_orthonormal=False):
        """
        @inputs:
            u: torch.Tensor. Left singular vectors of the pre-trained weight matrix
            s: torch.Tensor. Singular values of the pre-trained weight matrix
            v: torch.Tensor. Right singular vectors (V^T) of the pre-trained weight matrix
            off_diag: int. Number of off-diagonals on each side of the main diagonal used to populate matrix M (as referred to in the main paper)
            pattern: str. Choices: "banded", "random", "top_k". Using "banded" with off_diag=0 simulates SVFT-plain
            rank: int. Constrains how many singular vectors and values are used
            fill_orthonormal: bool. Whether to pad a truncated basis with random orthonormal vectors
        """
super().__init__()
self.off_diag = off_diag
rank = s.shape[0] if rank is None else min(s.shape[0], rank)
self.n = rank
diff_rank = s.shape[0] - rank
        if fill_orthonormal:
            # Pad the truncated SVD basis with random orthonormal directions and
            # zero singular values so the factors keep their full-rank shapes.
            Q_u = torch.randn_like(u)
            torch.nn.init.orthogonal_(Q_u)
            Q_v = torch.randn_like(v)
            torch.nn.init.orthogonal_(Q_v)
            u = torch.cat([u[:, :rank], Q_u[:, :diff_rank]], dim=1)
            v = torch.cat([v[:rank, :], Q_v[:diff_rank, :]], dim=0)
            s = torch.cat([s[:rank], torch.zeros(diff_rank, device=s.device)], dim=0)
            self.n = s.shape[0]
else:
s = s[:rank]
u = u[:, :rank]
v = v[:rank, :]
        self.u = nn.Parameter(u.clone().detach().contiguous(), requires_grad=False)
        # Frozen singular values; the edge index records their coordinates on the
        # main diagonal of the sparse (n, n) matrix.
        s_pre = s.cpu().detach().clone().contiguous()
        self.s_pre_edge_index = torch.sparse.spdiags(s_pre, torch.LongTensor([0]), (self.n, self.n)).coalesce().indices()
        self.s_pre = nn.Parameter(s_pre, requires_grad=False)
if pattern=="banded":
diags = 2*self.off_diag + 1
offsets_positive = torch.arange(0, self.off_diag+1)
offsets_negative = torch.arange(-1, -self.off_diag-1, -1)
self.offsets = torch.cat([offsets_positive, offsets_negative])
self.s_edge_index = torch.sparse.spdiags(torch.randn([diags, self.n]), self.offsets, (self.n, self.n)).coalesce().indices()
self.s = torch.nn.Parameter(torch.zeros(self.s_edge_index.shape[1]), requires_grad=True)
elif pattern=="random":
print("Random pattern")
k = self.n*(2*self.off_diag+1) - self.off_diag*(self.off_diag+1)
rows = torch.randint(0, self.n, (k,))
cols = torch.randint(0, self.n, (k,))
self.s_edge_index = torch.stack([rows, cols])
self.s = torch.nn.Parameter(torch.zeros(k), requires_grad=True)
elif pattern=="top_k":
if u.shape == v.shape:
coeffs = u@v.T
else:
coeffs = u if u.shape[0]==u.shape[1] else v
k = self.n*(2*self.off_diag+1) - self.off_diag*(self.off_diag+1)
# Flatten the tensor to 1D
flattened_tensor = coeffs.contiguous().view(-1)
_, top_indices_flat = torch.topk(flattened_tensor, k)
num_rows, num_cols = coeffs.size()
rows = top_indices_flat // num_cols
cols = top_indices_flat % num_cols
self.s_edge_index = torch.stack([rows, cols])
self.s = torch.nn.Parameter(torch.zeros(k), requires_grad=True)
torch.nn.init.kaiming_normal_(self.s[None, :])
self.s.squeeze()
self.register_buffer('s_pre_row', self.s_pre_edge_index[0])
self.register_buffer('s_pre_col', self.s_pre_edge_index[1])
self.register_buffer('s_row', self.s_edge_index[0])
self.register_buffer('s_col', self.s_edge_index[1])
self.gate = nn.Parameter(torch.tensor([0.], dtype=torch.float32), requires_grad=True)
self.v = nn.Parameter(v.clone().detach().contiguous(), requires_grad=False)

    def forward(self, x):
        return x @ self.get_weights()

    def get_weights(self):
        # Returns W^T, where W = U @ (S + sigmoid(gate) * M) @ V^T and M is the
        # learned sparse update; the gate smoothly scales M's contribution.
        s = SparseTensor(row=self.s_row, col=self.s_col, value=self.s * torch.sigmoid(self.gate))
        s_pre = SparseTensor(row=self.s_pre_row, col=self.s_pre_col, value=self.s_pre)
        del_s = s_pre + s
        weight = (del_s @ self.v).T
        weight = weight @ self.u.T
        return weight

    def merge_and_unload(self):
        # Dense (out_dim, in_dim) weight in the layout expected by nn.Linear.
        return self.get_weights().T.contiguous()
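
# A minimal sketch (shapes are illustrative): build an SVFTLayer directly from
# the SVD of a weight matrix and recover the adapted dense weight.
#
#   W = torch.randn(16, 32)
#   U, S, Vh = torch.linalg.svd(W, full_matrices=False)
#   layer = SVFTLayer(U, S, Vh, off_diag=1, pattern="banded")
#   y = layer(torch.randn(4, 32))          # -> shape (4, 16)
#   W_adapted = layer.merge_and_unload()   # -> shape (16, 32)
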
class LinearWithSVFT(nn.Module):
    def __init__(self, linear, off_diag, pattern="banded", rank=None, fill_orthonormal=False):
        """
        @inputs:
            linear: nn.Linear. Linear layer that is to be adapted
            off_diag: int. Number of off-diagonals on each side used when pattern is "banded";
                      for the remaining patterns, an equivalent number of learnable parameters is used
            pattern: str. Sparsity pattern for M; see SVFTLayer
            rank: int. SVD rank
            fill_orthonormal: bool. Whether to pad a truncated basis with random orthonormal vectors
        """
        super().__init__()
        self.bias = linear.bias
        # If linear.weight lives on the GPU, the SVD is computed there, which is
        # significantly faster than on the CPU.
        u, s, vh = torch.linalg.svd(linear.weight, full_matrices=False)
        self.svft_layer = SVFTLayer(u,
                                    s,
                                    vh,
                                    off_diag=off_diag,
                                    pattern=pattern,
                                    rank=rank,
                                    fill_orthonormal=fill_orthonormal)

def forward(self, x):
if self.bias is not None:
return self.svft_layer(x) + self.bias
else:
return self.svft_layer(x)
def merge_and_unload(self):
return self.svft_layer.merge_and_unload()
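
# A minimal sketch: adapt a single nn.Linear and fuse the update back into a
# dense weight (the off_diag/pattern values are illustrative).
#
#   base = nn.Linear(32, 16)
#   adapted = LinearWithSVFT(base, off_diag=0, pattern="banded")
#   out = adapted(torch.randn(4, 32))      # same output shape as base
#   fused = adapted.merge_and_unload()     # dense (16, 32) weight
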

def freeze_model(model, exclude_list=None):
    '''Freeze all model parameters except those whose names contain a substring from exclude_list.'''
    if exclude_list is None:
        exclude_list = []
    for n, p in model.named_parameters():
        if not any(e in n for e in exclude_list):
            p.requires_grad = False
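
if __name__ == "__main__":
    # Smoke test, a minimal sketch: the toy Sequential model below is purely
    # illustrative and not part of SVFT itself. It adapts both Linear layers,
    # freezes everything except the SVFT update (s) and gate, runs one
    # forward/backward pass, and fuses the adapters back into Linear layers.
    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8))
    targets = get_target_modules_list(model, ["0", "2"])
    create_and_replace_modules(
        model,
        targets,
        create_fn=lambda lin: LinearWithSVFT(lin, off_diag=0, pattern="banded"),
    )
    # "svft_layer.s" also matches "svft_layer.s_pre", which is already frozen
    # (it is created with requires_grad=False), so only s and gate train.
    freeze_model(model, exclude_list=["svft_layer.s", "svft_layer.gate"])
    loss = model(torch.randn(4, 32)).sum()
    loss.backward()
    replace_svft_with_fused_linear(model, get_target_modules_list(model, ["0", "2"]))
    print(model)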