File size: 10,935 Bytes
de071e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import copy
import torch
from scipy.stats import ortho_group


def permute_model(model, tmp_model, mlp_permutation, emb_permutation, n_blocks=32):
    """Write a permuted copy of ``model``'s weights into ``tmp_model``.

    The embedding layer is permuted first, reading from ``model`` and loading
    the result into ``tmp_model``. Every later step reads from and writes to
    ``tmp_model`` itself, so the permutations accumulate there.

    Args:
        model: Source model; only read via ``state_dict()``.
        tmp_model: Destination model that receives the permuted weights.
        mlp_permutation: Index sequence applied along the MLP intermediate
            axis of each block.
        emb_permutation: Index sequence applied along the hidden axis.
        n_blocks: Number of transformer blocks to process (indices
            ``0 .. n_blocks - 1``).
    """
    permute_embedding_layer(model, tmp_model, emb_permutation)

    for layer_idx in range(n_blocks):
        permute_transformer_block(
            tmp_model,
            layer_idx,
            tmp_model,
            mlp_permutation,
            emb_permutation,
        )

    permute_output_layer(tmp_model, tmp_model, emb_permutation)


def permute_transformer_block(model, i, tmp_model, mlp_permutation, emb_permutation):
    """Permute the weights of transformer block ``i`` and load them into ``tmp_model``.

    The full state dict of ``model`` is fetched, the entries belonging to
    block ``i`` are re-indexed, and the whole dict is loaded into
    ``tmp_model``.

    Args:
        model: Model whose ``state_dict()`` supplies the source weights.
        i: Index of the block (used in the ``model.layers.{i}.`` key prefix).
        tmp_model: Model that receives the resulting state dict.
        mlp_permutation: Index sequence applied to rows of ``gate_proj`` /
            ``up_proj`` and columns of ``down_proj``.
        emb_permutation: Index sequence applied to columns of q/k/v/gate/up
            projections, rows of ``o_proj`` / ``down_proj``, and both
            layer-norm weight vectors.
    """
    state = model.state_dict()
    layer = f"model.layers.{i}"

    def weight_key(name):
        return f"{layer}.{name}.weight"

    # Attention: q/k/v are indexed along their second axis, o_proj along its first.
    for proj in ("q_proj", "k_proj", "v_proj"):
        key = weight_key(f"self_attn.{proj}")
        state[key] = state[key][:, emb_permutation]
    o_key = weight_key("self_attn.o_proj")
    state[o_key] = state[o_key][emb_permutation]

    gate_key = weight_key("mlp.gate_proj")
    up_key = weight_key("mlp.up_proj")
    down_key = weight_key("mlp.down_proj")

    # MLP, first pass: permutation of the intermediate axis.
    state[gate_key] = state[gate_key][mlp_permutation]
    state[up_key] = state[up_key][mlp_permutation]
    state[down_key] = state[down_key][:, mlp_permutation]

    # MLP, second pass: permutation of the hidden axis.
    state[gate_key] = state[gate_key][:, emb_permutation]
    state[up_key] = state[up_key][:, emb_permutation]
    state[down_key] = state[down_key][emb_permutation]

    # Layer-norm weights are 1-D and are indexed directly.
    for norm in ("input_layernorm", "post_attention_layernorm"):
        key = weight_key(norm)
        state[key] = state[key][emb_permutation]

    tmp_model.load_state_dict(state)


def permute_embedding_layer(model, tmp_model, emb_permutation):
    """Permute the columns of the token embedding and load into ``tmp_model``.

    Args:
        model: Model whose ``state_dict()`` supplies the source weights.
        tmp_model: Model that receives the full, updated state dict.
        emb_permutation: Index sequence applied along the second axis of
            ``model.embed_tokens.weight``.
    """
    embed_key = "model.embed_tokens.weight"

    state = model.state_dict()
    state[embed_key] = state[embed_key][:, emb_permutation]

    tmp_model.load_state_dict(state)


def permute_output_layer(model, tmp_model, emb_permutation):
    """Permute the LM head and final norm weights and load into ``tmp_model``.

    Args:
        model: Model whose ``state_dict()`` supplies the source weights.
        tmp_model: Model that receives the full, updated state dict.
        emb_permutation: Index sequence applied along the second axis of
            ``lm_head.weight`` and along the only axis of
            ``model.norm.weight``.
    """
    head_key = "lm_head.weight"
    norm_key = "model.norm.weight"

    state = model.state_dict()
    permuted_head = state[head_key][:, emb_permutation]
    state[head_key] = permuted_head
    permuted_norm = state[norm_key][emb_permutation]
    state[norm_key] = permuted_norm

    tmp_model.load_state_dict(state)


def permute_mlp_block(model, i, tmp_model, mlp_permutation):
    """Permute the intermediate axis of block ``i``'s MLP and load into ``tmp_model``.

    Args:
        model: Model whose ``state_dict()`` supplies the source weights.
        i: Index of the block (used in the ``model.layers.{i}.mlp.`` prefix).
        tmp_model: Model that receives the full, updated state dict.
        mlp_permutation: Index sequence applied to the rows of ``gate_proj``
            and ``up_proj`` and to the columns of ``down_proj``.
    """
    state = model.state_dict()
    mlp_prefix = f"model.layers.{i}.mlp"

    # gate/up are indexed along their first axis ...
    for proj in ("gate_proj", "up_proj"):
        key = f"{mlp_prefix}.{proj}.weight"
        state[key] = state[key][mlp_permutation]

    # ... and down_proj along its second axis.
    down_key = f"{mlp_prefix}.down_proj.weight"
    state[down_key] = state[down_key][:, mlp_permutation]

    tmp_model.load_state_dict(state)


def avg_mlp_block(model0, model1, i, tmp_model, alpha=0.5):
    """Interpolate the MLP weights of block ``i`` and load into ``tmp_model``.

    For each of ``gate_proj``, ``up_proj`` and ``down_proj`` the result is
    ``alpha * model0 + (1 - alpha) * model1``. All other entries are taken
    unchanged from ``model0``'s state dict.

    Args:
        model0: Model weighted by ``alpha``; also supplies every entry that
            is not interpolated.
        model1: Model weighted by ``1 - alpha``.
        i: Index of the block (used in the ``model.layers.{i}.mlp.`` prefix).
        tmp_model: Model that receives the resulting state dict.
        alpha: Interpolation coefficient applied to ``model0``.
    """
    state_a = model0.state_dict()
    state_b = model1.state_dict()

    for proj in ("gate_proj", "up_proj", "down_proj"):
        key = f"model.layers.{i}.mlp.{proj}.weight"
        state_a[key] = alpha * state_a[key] + (1 - alpha) * state_b[key]

    tmp_model.load_state_dict(state_a)


def avg_transformer_block(model0, model1, i, tmp_model, alpha=0.5, attn=True):
    """Interpolate the weights of transformer block ``i`` and load into ``tmp_model``.

    Each selected entry becomes ``alpha * model0 + (1 - alpha) * model1``.
    The MLP projections and both layer-norm weights are always interpolated;
    the attention projections are interpolated only when ``attn is True``.
    All other entries are taken unchanged from ``model0``'s state dict.

    Args:
        model0: Model weighted by ``alpha``; also supplies every entry that
            is not interpolated.
        model1: Model weighted by ``1 - alpha``.
        i: Index of the block (used in the ``model.layers.{i}.`` prefix).
        tmp_model: Model that receives the resulting state dict.
        alpha: Interpolation coefficient applied to ``model0``.
        attn: Attention projections are included only for the exact value
            ``True`` (identity check, not truthiness).
    """
    state_a = model0.state_dict()
    state_b = model1.state_dict()

    param_names = []
    if attn is True:
        param_names.extend(
            f"self_attn.{proj}.weight" for proj in ("q_proj", "k_proj", "v_proj", "o_proj")
        )
    param_names.extend(f"mlp.{proj}.weight" for proj in ("gate_proj", "up_proj", "down_proj"))
    param_names.extend(("input_layernorm.weight", "post_attention_layernorm.weight"))

    for name in param_names:
        key = f"model.layers.{i}.{name}"
        state_a[key] = alpha * state_a[key] + (1 - alpha) * state_b[key]

    tmp_model.load_state_dict(state_a)


def avg_embedding_layer(model0, model1, tmp_model, alpha=0.5):
    """Interpolate the token embeddings of two models and load into ``tmp_model``.

    ``model.embed_tokens.weight`` becomes
    ``alpha * model0 + (1 - alpha) * model1``; every other entry is taken
    unchanged from ``model0``'s state dict.

    Args:
        model0: Model weighted by ``alpha``.
        model1: Model weighted by ``1 - alpha``.
        tmp_model: Model that receives the resulting state dict.
        alpha: Interpolation coefficient applied to ``model0``.
    """
    embed_key = "model.embed_tokens.weight"

    state_a = model0.state_dict()
    state_b = model1.state_dict()

    state_a[embed_key] = alpha * state_a[embed_key] + (1 - alpha) * state_b[embed_key]

    tmp_model.load_state_dict(state_a)


def avg_output_layer(model0, model1, tmp_model, alpha=0.5):
    """Interpolate the LM head and final norm of two models and load into ``tmp_model``.

    ``lm_head.weight`` and ``model.norm.weight`` each become
    ``alpha * model0 + (1 - alpha) * model1``; every other entry is taken
    unchanged from ``model0``'s state dict.

    Args:
        model0: Model weighted by ``alpha``.
        model1: Model weighted by ``1 - alpha``.
        tmp_model: Model that receives the resulting state dict.
        alpha: Interpolation coefficient applied to ``model0``.
    """
    state_a = model0.state_dict()
    state_b = model1.state_dict()

    for key in ("lm_head.weight", "model.norm.weight"):
        state_a[key] = alpha * state_a[key] + (1 - alpha) * state_b[key]

    tmp_model.load_state_dict(state_a)


def avg_model(model0, model1, tmp_model, alpha=0.5, n_blocks=32, attn=True, emb=True):
    """Interpolate two models block by block, writing the result into ``tmp_model``.

    ``tmp_model`` is first seeded either with the interpolated embedding
    layer (when ``emb is True``) or with a plain copy of ``model0``'s state.
    Each transformer block is then interpolated between the current contents
    of ``tmp_model`` and ``model1``. The output layer is interpolated only
    when ``emb is True``.

    Args:
        model0: Model weighted by ``alpha``.
        model1: Model weighted by ``1 - alpha``. A deep copy is taken up
            front, so the caller's object is never read after that point.
        tmp_model: Model that receives the interpolated weights.
        alpha: Interpolation coefficient applied to ``model0``.
        n_blocks: Number of transformer blocks to process.
        attn: Forwarded to ``avg_transformer_block``.
        emb: Embedding and output layers are interpolated only for the exact
            value ``True`` (identity check, not truthiness).
    """
    # Work on a private copy so that loading into tmp_model cannot change
    # the weights being averaged in, even if tmp_model aliases model1.
    model1 = copy.deepcopy(model1)
    average_embeddings = emb is True

    if average_embeddings:
        avg_embedding_layer(model0, model1, tmp_model, alpha=alpha)
    else:
        tmp_model.load_state_dict(model0.state_dict())

    for layer_idx in range(n_blocks):
        avg_transformer_block(tmp_model, model1, layer_idx, tmp_model, alpha=alpha, attn=attn)

    if average_embeddings:
        avg_output_layer(tmp_model, model1, tmp_model, alpha=alpha)


def get_mlp_weights(model, i):
    """Return the ``mlp.gate_proj.weight`` entry of block ``i`` from ``model``'s state dict."""
    gate_key = f"model.layers.{i}.mlp.gate_proj.weight"
    state = model.state_dict()
    return state[gate_key]


def get_emb_weights(model):
    """Return the ``model.embed_tokens.weight`` entry from ``model``'s state dict."""
    state = model.state_dict()
    return state["model.embed_tokens.weight"]


def rotate_model(model, num_layers=32, hidden_dim=4096, device="cuda"):
    """Rotate the hidden axis of ``model``'s weights by a random orthogonal matrix, in place.

    A ``hidden_dim x hidden_dim`` orthogonal matrix ``R`` is sampled with
    ``scipy.stats.ortho_group``. Weights are then rewritten as follows and
    loaded back into ``model``:

    * embedding: ``E @ R``
    * q/k/v, gate/up and ``lm_head``: ``(W * norm_weight) @ R``, where
      ``norm_weight`` is the weight of the norm feeding that projection
      (input norm, post-attention norm, final norm respectively)
    * ``o_proj`` and ``down_proj``: ``R.T @ W``
    * every norm weight: replaced by ones (its scale was folded into the
      projections above)

    NOTE(review): folding the norm scale into the following projection
    presumably relies on the norms being scale-only (no bias, e.g. RMSNorm)
    and on ``lm_head`` not sharing storage with the embedding -- confirm for
    the model family in use.

    Args:
        model: Model exposing ``to``, ``state_dict`` and ``load_state_dict``
            with Llama-style parameter names.
        num_layers: Number of transformer blocks to process.
        hidden_dim: Size of the hidden axis; dimension of the rotation.
        device: Device the model is moved to before rotating. Defaults to
            ``"cuda"``, matching the previous hard-coded behavior.
    """
    model.to(device)

    weights = model.state_dict()
    weights_rotated = model.state_dict()

    embed = weights["model.embed_tokens.weight"]

    # Build the rotation in the dtype/device of the weights so every matmul
    # below has matching operands (bf16 models behave as before; fp16/fp32
    # models no longer hit a dtype mismatch).
    rotation = torch.tensor(ortho_group.rvs(dim=hidden_dim), dtype=embed.dtype).to(embed.device)

    weights_rotated["model.embed_tokens.weight"] = embed @ rotation

    for i in range(num_layers):
        layer = f"model.layers.{i}"
        input_norm_key = f"{layer}.input_layernorm.weight"
        post_norm_key = f"{layer}.post_attention_layernorm.weight"
        input_norm = weights[input_norm_key]
        post_norm = weights[post_norm_key]

        weights_rotated[input_norm_key] = torch.ones_like(input_norm)
        weights_rotated[post_norm_key] = torch.ones_like(post_norm)

        # ``W * norm`` scales the columns of W, which equals
        # ``W @ torch.diag(norm)`` without materializing the dense diagonal.
        for proj in ("q_proj", "k_proj", "v_proj"):
            key = f"{layer}.self_attn.{proj}.weight"
            weights_rotated[key] = (weights[key] * input_norm) @ rotation
        o_key = f"{layer}.self_attn.o_proj.weight"
        weights_rotated[o_key] = rotation.T @ weights[o_key]

        for proj in ("gate_proj", "up_proj"):
            key = f"{layer}.mlp.{proj}.weight"
            weights_rotated[key] = (weights[key] * post_norm) @ rotation
        down_key = f"{layer}.mlp.down_proj.weight"
        weights_rotated[down_key] = rotation.T @ weights[down_key]

    final_norm = weights["model.norm.weight"]
    weights_rotated["model.norm.weight"] = torch.ones_like(final_norm)
    weights_rotated["lm_head.weight"] = (weights["lm_head.weight"] * final_norm) @ rotation

    model.load_state_dict(weights_rotated)