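"""Utilities for permuting and linearly interpolating the weights of
LLaMA-style decoder checkpoints (Hugging Face ``model.layers.<i>.*`` key
naming), e.g. for permutation-aligned model-merging experiments.

``permute_*`` applies a permutation to the hidden (residual-stream) and/or
MLP-intermediate dimensions; ``avg_*`` interpolates two checkpoints.
"""
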
import copy

def permute_model(model, tmp_model, mlp_permutation, emb_permutation, n_blocks=32):
    """Apply a hidden-dim permutation (emb_permutation) and an
    MLP-intermediate-dim permutation (mlp_permutation) to every layer of
    `model`, writing the result into `tmp_model`."""
    permute_embedding_layer(model, tmp_model, emb_permutation)
    for i in range(n_blocks):
        permute_transformer_block(tmp_model, i, tmp_model, mlp_permutation, emb_permutation)
    permute_output_layer(tmp_model, tmp_model, emb_permutation)
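
# Illustrative sketch (not part of the original utilities): apply random
# permutations and check that the permuted model computes the same logits.
# The checkpoint name is a placeholder; any LLaMA-style causal LM works.
def _demo_permutation_equivalence():
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint (assumption)
    model = AutoModelForCausalLM.from_pretrained(name)
    permuted = copy.deepcopy(model)  # output slot with the same architecture
    mlp_perm = torch.randperm(model.config.intermediate_size)
    emb_perm = torch.randperm(model.config.hidden_size)
    permute_model(model, permuted, mlp_perm, emb_perm,
                  n_blocks=model.config.num_hidden_layers)

    tok = AutoTokenizer.from_pretrained(name)
    batch = tok("The quick brown fox", return_tensors="pt")
    with torch.no_grad():
        ref = model(**batch).logits
        out = permuted(**batch).logits
    # Should be ~0 up to floating-point noise.
    print("max |logit diff|:", (ref - out).abs().max().item())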

def permute_transformer_block(model, i, tmp_model, mlp_permutation, emb_permutation):
    """Permute block i so that its inputs and outputs follow the permuted
    residual stream (emb_permutation) and its MLP hidden units follow
    mlp_permutation. Column slices permute an input dimension, row slices
    an output dimension."""
    weights = model.state_dict()
    p = f"model.layers.{i}"
    # Attention: q/k/v read the permuted residual stream (columns = input dim),
    # o_proj writes back to it (rows = output dim).
    for name in ("q_proj", "k_proj", "v_proj"):
        key = f"{p}.self_attn.{name}.weight"
        weights[key] = weights[key][:, emb_permutation]
    key = f"{p}.self_attn.o_proj.weight"
    weights[key] = weights[key][emb_permutation]
    # MLP: rows of gate/up and columns of down carry the intermediate dim;
    # columns of gate/up and rows of down carry the residual-stream dim.
    for name in ("gate_proj", "up_proj"):
        key = f"{p}.mlp.{name}.weight"
        weights[key] = weights[key][mlp_permutation][:, emb_permutation]
    key = f"{p}.mlp.down_proj.weight"
    weights[key] = weights[key][emb_permutation][:, mlp_permutation]
    # The RMSNorm scales are elementwise over the hidden dimension, so they
    # must follow the same permutation for the block to stay equivalent.
    for name in ("input_layernorm", "post_attention_layernorm"):
        key = f"{p}.{name}.weight"
        weights[key] = weights[key][emb_permutation]
    tmp_model.load_state_dict(weights)

def permute_embedding_layer(model, tmp_model, emb_permutation):
    """Permute the hidden (embedding) dimension of the token embeddings."""
    weights = model.state_dict()
    weights["model.embed_tokens.weight"] = weights["model.embed_tokens.weight"][:, emb_permutation]
    tmp_model.load_state_dict(weights)

def permute_output_layer(model, tmp_model, emb_permutation):
    """Permute the hidden dimension read by the LM head, and the final
    RMSNorm scale, which is elementwise over that dimension."""
    weights = model.state_dict()
    weights["lm_head.weight"] = weights["lm_head.weight"][:, emb_permutation]
    weights["model.norm.weight"] = weights["model.norm.weight"][emb_permutation]
    tmp_model.load_state_dict(weights)

def permute_mlp_block(model, i, tmp_model, mlp_permutation):
    """Permute only the MLP-intermediate dimension of block i. This is
    function-preserving on its own, since the permutation is undone by the
    matching column permutation of down_proj."""
    weights = model.state_dict()
    p = f"model.layers.{i}.mlp"
    for name in ("gate_proj", "up_proj"):
        key = f"{p}.{name}.weight"
        weights[key] = weights[key][mlp_permutation]
    key = f"{p}.down_proj.weight"
    weights[key] = weights[key][:, mlp_permutation]
    tmp_model.load_state_dict(weights)

def avg_mlp_block(model0, model1, i, tmp_model, alpha=0.5):
    """Write alpha * model0 + (1 - alpha) * model1 for the MLP weights of
    block i into tmp_model."""
    weights0 = model0.state_dict()
    weights1 = model1.state_dict()
    for name in ("gate_proj", "up_proj", "down_proj"):
        key = f"model.layers.{i}.mlp.{name}.weight"
        weights0[key] = alpha * weights0[key] + (1 - alpha) * weights1[key]
    tmp_model.load_state_dict(weights0)

def avg_transformer_block(model0, model1, i, tmp_model, alpha=0.5, attn=True):
    """Interpolate the MLP weights of block i, and its attention weights
    when attn=True, writing the result into tmp_model."""
    weights0 = model0.state_dict()
    weights1 = model1.state_dict()
    keys = [f"model.layers.{i}.mlp.{name}.weight"
            for name in ("gate_proj", "up_proj", "down_proj")]
    if attn:
        keys += [f"model.layers.{i}.self_attn.{name}.weight"
                 for name in ("q_proj", "k_proj", "v_proj", "o_proj")]
    for key in keys:
        weights0[key] = alpha * weights0[key] + (1 - alpha) * weights1[key]
    tmp_model.load_state_dict(weights0)

def avg_embedding_layer(model0, model1, tmp_model, alpha=0.5):
    """Interpolate the token-embedding weights."""
    weights0 = model0.state_dict()
    weights1 = model1.state_dict()
    key = "model.embed_tokens.weight"
    weights0[key] = alpha * weights0[key] + (1 - alpha) * weights1[key]
    tmp_model.load_state_dict(weights0)

def avg_output_layer(model0, model1, tmp_model, alpha=0.5):
    """Interpolate the LM head and the final RMSNorm scale."""
    weights0 = model0.state_dict()
    weights1 = model1.state_dict()
    for key in ("lm_head.weight", "model.norm.weight"):
        weights0[key] = alpha * weights0[key] + (1 - alpha) * weights1[key]
    tmp_model.load_state_dict(weights0)

def avg_model(model0, model1, tmp_model, alpha=0.5, n_blocks=32, attn=True, emb=True):
    """Interpolate all selected weights of model0 and model1 into tmp_model.

    model1 is deep-copied first so the loop stays correct even when
    tmp_model aliases model1: each block reads the running result from
    tmp_model and the still-untouched weights from the copy.
    """
    model1 = copy.deepcopy(model1)
    if emb:
        avg_embedding_layer(model0, model1, tmp_model, alpha=alpha)
    else:
        tmp_model.load_state_dict(model0.state_dict())
    for i in range(n_blocks):
        avg_transformer_block(tmp_model, model1, i, tmp_model, alpha=alpha, attn=attn)
    if emb:
        avg_output_layer(tmp_model, model1, tmp_model, alpha=alpha)
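
# Illustrative sketch (an assumption, not from the original file): averaging
# two fine-tunes of the same architecture into a third, pre-built module.
def _demo_average(model0, model1):
    merged = copy.deepcopy(model0)  # output slot with the same architecture
    avg_model(model0, model1, merged, alpha=0.5,
              n_blocks=model0.config.num_hidden_layers)
    return merged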

def get_mlp_weights(model, i):
    # First MLP projection of block i (rows = intermediate units), e.g. as
    # input to a unit-matching algorithm. gate_proj matches the LLaMA-style
    # keys used throughout this file.
    return model.state_dict()[f"model.layers.{i}.mlp.gate_proj.weight"]

def get_emb_weights(model):
    # Token-embedding matrix (columns = hidden units).
    return model.state_dict()["model.embed_tokens.weight"]