import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch_sparse import SparseTensor
def create_orthonormal_matrix(A):
    # Returns the reduced QR factor Q of A; its columns form an orthonormal
    # basis, so its shape is (A.shape[0], min(A.shape)).
    Q, _ = torch.linalg.qr(A)  # torch.qr is deprecated in favor of torch.linalg.qr
    return Q
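
# A minimal usage sketch (illustrative values, not part of the original API):
# for a tall 8x3 matrix the reduced QR factor Q has orthonormal columns,
# so Q^T Q is the identity.
def _example_orthonormal_basis():
    A = torch.randn(8, 3)
    Q = create_orthonormal_matrix(A)
    assert Q.shape == (8, 3)
    assert torch.allclose(Q.T @ Q, torch.eye(3), atol=1e-5)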
def get_target_modules_list(model, target_modules):
target_names = []
for n, _ in model.named_modules():
if any(t in n for t in target_modules):
target_names.append(n)
return target_names
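
# Usage sketch (hypothetical names): collect every submodule whose qualified
# name contains one of the given substrings. In an nn.Sequential the children
# are simply named "0", "1", "2"; in a transformer one would typically match
# substrings such as "q_proj" or "v_proj" instead.
def _example_get_targets():
    toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
    return get_target_modules_list(toy, ["0", "2"])  # -> ["0", "2"]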
def replace_svft_with_fused_linear(model, target_modules_list):
print("Replacing SVFT layers with new Linear layers")
    # Skip the inner `svft_layer` submodules; only the wrapping modules are fused back.
    target_modules_list = [l for l in target_modules_list if "svft_layer" not in l]
for target_path in tqdm(reversed(target_modules_list), total=len(target_modules_list)):
parent_path = target_path[: target_path.rfind(".")] if "." in target_path else ""
target_name = target_path.split(".")[-1]
parent = model.get_submodule(parent_path) if parent_path else model
target = model.get_submodule(target_path)
in_dim = target.svft_layer.v.shape[1]
out_dim = target.svft_layer.u.shape[0]
        # Recreate a plain Linear and copy in the fused (merged) weights.
        if target.bias is None:
            lin = torch.nn.Linear(in_dim, out_dim, bias=False)
        else:
            lin = torch.nn.Linear(in_dim, out_dim, bias=True)
            lin.bias.data = target.bias.data
        lin.weight.data = target.merge_and_unload()
        setattr(parent, target_name, lin)
def create_and_replace_modules(model, target_modules_list, create_fn):
print("Replacing Linear layers with SVFT layers")
for target_path in tqdm(reversed(target_modules_list), total=len(target_modules_list)):
parent_path = target_path[: target_path.rfind(".")] if "." in target_path else ""
target_name = target_path.split(".")[-1]
parent = model.get_submodule(parent_path) if parent_path else model
target = model.get_submodule(target_path)
        setattr(parent, target_name, create_fn(target))
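
# Usage sketch: wrap every matched Linear in a LinearWithSVFT (defined below).
# The target substrings and the off_diag/pattern values are illustrative
# choices, not prescriptions from the paper.
def _example_adapt_model(model, target_substrings):
    targets = get_target_modules_list(model, target_substrings)
    create_and_replace_modules(
        model,
        targets,
        create_fn=lambda lin: LinearWithSVFT(lin, off_diag=0, pattern="banded"),
    )
    return targets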
class SVFTLayer(nn.Module):
def __init__(self, u, s, v, off_diag, pattern="banded", rank=None, fill_orthonormal=False):
"""
@inputs:
u: torch.Tensor. Left singular vectors of pre-trained weight matrix
s: torch.Tensor. Singular values of pre-trained weight matrix
v: torch.Tensor. Right singular vectors of pre-trained weight matrix
off_diag: int. Total off-diagonals to be used to populate matrix M (as referred in main paper)
pattern: str. Choices: "banded", "random", "top_k". Using "banded" with off_diag=1 simulates SVFT-plain
rank: int. Constraints how many singular vectors and values to use.
fill_orthonormal: bool. To determine if random orthonormal basis should be used
"""
super().__init__()
self.off_diag = off_diag
rank = s.shape[0] if rank is None else min(s.shape[0], rank)
self.n = rank
diff_rank = s.shape[0] - rank
        if fill_orthonormal:
            # Keep the top-`rank` singular vectors and pad the remaining
            # `diff_rank` directions with random orthonormal vectors whose
            # singular values are initialized to zero.
            Q_u = torch.randn_like(u)
            torch.nn.init.orthogonal_(Q_u)
            Q_v = torch.randn_like(v)
            torch.nn.init.orthogonal_(Q_v)
            u = torch.cat([u[:, :rank], Q_u[:, :diff_rank]], dim=1)
            v = torch.cat([v[:rank, :], Q_v[:diff_rank, :]], dim=0)
            s = torch.cat([s[:rank], torch.zeros(diff_rank, device=s.device)], dim=0)
            self.n = s.shape[0]
else:
s = s[:rank]
u = u[:, :rank]
v = v[:rank, :]
        # Frozen left singular vectors.
        self.u = nn.Parameter(u.clone().detach().contiguous(), requires_grad=False)
        # Frozen pre-trained singular values, stored along with their diagonal indices.
        s_pre = s.cpu().detach().clone().contiguous()
        self.s_pre_edge_index = torch.sparse.spdiags(s_pre, torch.LongTensor([0]), (self.n, self.n)).coalesce().indices()
        self.s_pre = nn.Parameter(s_pre, requires_grad=False)
if pattern=="banded":
diags = 2*self.off_diag + 1
offsets_positive = torch.arange(0, self.off_diag+1)
offsets_negative = torch.arange(-1, -self.off_diag-1, -1)
self.offsets = torch.cat([offsets_positive, offsets_negative])
self.s_edge_index = torch.sparse.spdiags(torch.randn([diags, self.n]), self.offsets, (self.n, self.n)).coalesce().indices()
self.s = torch.nn.Parameter(torch.zeros(self.s_edge_index.shape[1]), requires_grad=True)
elif pattern=="random":
print("Random pattern")
k = self.n*(2*self.off_diag+1) - self.off_diag*(self.off_diag+1)
rows = torch.randint(0, self.n, (k,))
cols = torch.randint(0, self.n, (k,))
self.s_edge_index = torch.stack([rows, cols])
self.s = torch.nn.Parameter(torch.zeros(k), requires_grad=True)
elif pattern=="top_k":
if u.shape == v.shape:
coeffs = u@v.T
else:
coeffs = u if u.shape[0]==u.shape[1] else v
k = self.n*(2*self.off_diag+1) - self.off_diag*(self.off_diag+1)
# Flatten the tensor to 1D
flattened_tensor = coeffs.contiguous().view(-1)
_, top_indices_flat = torch.topk(flattened_tensor, k)
num_rows, num_cols = coeffs.size()
rows = top_indices_flat // num_cols
cols = top_indices_flat % num_cols
self.s_edge_index = torch.stack([rows, cols])
self.s = torch.nn.Parameter(torch.zeros(k), requires_grad=True)
        # Kaiming-init the learnable coefficients (this runs for every pattern,
        # overwriting the zeros above); the `[None, :]` view gives the 1-D
        # parameter the 2-D shape the initializer expects.
        torch.nn.init.kaiming_normal_(self.s[None, :])
        # Register the sparse indices as buffers so they follow the module
        # across devices (e.g. .to(device) / .cuda()).
        self.register_buffer('s_pre_row', self.s_pre_edge_index[0])
        self.register_buffer('s_pre_col', self.s_pre_edge_index[1])
        self.register_buffer('s_row', self.s_edge_index[0])
        self.register_buffer('s_col', self.s_edge_index[1])
        # Learnable scalar gate; sigmoid(gate) scales the learned update.
        self.gate = nn.Parameter(torch.tensor([0.], dtype=torch.float32), requires_grad=True)
        # Frozen right singular vectors.
        self.v = nn.Parameter(v.clone().detach().contiguous(), requires_grad=False)
self.v = nn.Parameter(v.clone().detach().contiguous(), requires_grad=False)
def forward(self, x):
x = x @ self.get_weights()
return x
    def get_weights(self):
        # Learned sparse update M, scaled by a learnable gate in (0, 1).
        s = SparseTensor(row=self.s_row, col=self.s_col, value=self.s * torch.sigmoid(self.gate))
        # Frozen pre-trained singular values on the diagonal.
        s_pre = SparseTensor(row=self.s_pre_row, col=self.s_pre_col, value=self.s_pre)
        del_s = s_pre + s
        # Returns (U (S_pre + M) V)^T, so forward can compute x @ weight.
        weight = (del_s @ self.v).T
        weight = weight @ self.u.T
        return weight
def merge_and_unload(self):
return self.get_weights().T.contiguous()
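
# Construction sketch: build an SVFTLayer from the SVD of a random weight
# matrix and sanity-check the shapes. Note that at initialization the layer
# does not reproduce W exactly: `s` is Kaiming-initialized and the gate lets
# sigmoid(0) = 0.5 of it through.
def _example_svft_layer():
    W = torch.randn(16, 32)  # (out_dim, in_dim)
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    layer = SVFTLayer(U, S, Vh, off_diag=0, pattern="banded")
    x = torch.randn(4, 32)
    assert layer(x).shape == (4, 16)
    assert layer.merge_and_unload().shape == W.shape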
class LinearWithSVFT(nn.Module):
def __init__(self, linear, off_diag, pattern="banded", rank=None, fill_orthonormal=False):
"""
@inputs:
linear: torch.Tensor. Linear Layer that has to adapted
off_diag: int. total number off diagonals to be used if pattern is 'banded'
for remaining patterns, equivalent number of learnable parameters are learnt
rank: SVD rank
fill_orthonormal: bool. To determine if random orthonormal basis should be used
"""
super().__init__()
        self.bias = linear.bias
        # If linear.weight is already on the GPU, the SVD computation is significantly faster.
        U, S, Vh = torch.linalg.svd(linear.weight, full_matrices=False)
        self.svft_layer = SVFTLayer(U, S, Vh,
                                    off_diag=off_diag,
                                    pattern=pattern,
                                    rank=rank,
                                    fill_orthonormal=fill_orthonormal)
def forward(self, x):
if self.bias is not None:
return self.svft_layer(x) + self.bias
else:
return self.svft_layer(x)
def merge_and_unload(self):
return self.svft_layer.merge_and_unload()
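
# Wrapping sketch: adapt a single Linear layer, then check that the forward
# pass and the fused weight keep the original shapes.
def _example_linear_with_svft():
    lin = nn.Linear(32, 16, bias=True)
    adapted = LinearWithSVFT(lin, off_diag=0, pattern="banded")
    x = torch.randn(4, 32)
    assert adapted(x).shape == (4, 16)
    assert adapted.merge_and_unload().shape == lin.weight.shape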
def freeze_model(model, exclude_list=None):
    '''Freeze all model parameters except those whose names contain
    a substring from exclude_list.'''
    if exclude_list is None:
        exclude_list = []
    for n, p in model.named_parameters():
        if not any(e in n for e in exclude_list):
            p.requires_grad = False
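
# End-to-end sketch (all module names illustrative): adapt a toy model,
# freeze everything except the SVFT coefficients and gates, (train), then
# fuse the adapters back into plain Linear layers for deployment.
def _example_end_to_end():
    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8))
    targets = get_target_modules_list(model, ["0", "2"])
    create_and_replace_modules(
        model,
        targets,
        create_fn=lambda lin: LinearWithSVFT(lin, off_diag=0, pattern="banded"),
    )
    # Only `s` and `gate` inside the SVFT layers stay trainable
    # (s_pre matches the substring too, but it is already non-trainable).
    freeze_model(model, exclude_list=["svft_layer.s", "svft_layer.gate"])
    # ... training loop would go here ...
    replace_svft_with_fused_linear(model, targets)
    assert isinstance(model[0], nn.Linear)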