File size: 7,497 Bytes
de071e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import copy


def permute_model(model, tmp_model, mlp_permutation, emb_permutation, n_blocks=32):
    """Write a permuted copy of ``model``'s weights into ``tmp_model``.

    The embedding step reads from ``model`` and loads into ``tmp_model``;
    every later step reads from and writes back to ``tmp_model``, so the
    permutations accumulate there.

    Args:
        model: Source module providing the initial ``state_dict()``.
        tmp_model: Destination module that receives the permuted weights.
        mlp_permutation: Index sequence forwarded to each block permutation.
        emb_permutation: Index sequence applied to the embedding, every
            block and the output head.
        n_blocks: Number of block indices (``0 .. n_blocks - 1``) to process.
    """
    # Step 1: seed tmp_model from model, with the embedding columns reordered.
    permute_embedding_layer(model, tmp_model, emb_permutation)

    # Step 2: rewrite each block of tmp_model in place.
    for block_index in range(n_blocks):
        permute_transformer_block(
            tmp_model,
            block_index,
            tmp_model,
            mlp_permutation,
            emb_permutation,
        )

    # Step 3: finally reorder the output head's columns, again in place.
    permute_output_layer(tmp_model, tmp_model, emb_permutation)


def permute_transformer_block(model, i, tmp_model, mlp_permutation, emb_permutation):
    """Permute the projection weights of block ``i`` and load them into ``tmp_model``.

    Starting from ``model.state_dict()``:

    * ``self_attn.{q,k,v}_proj.weight``: second axis reindexed by
      ``emb_permutation``.
    * ``self_attn.o_proj.weight``: first axis reindexed by ``emb_permutation``.
    * ``mlp.{gate,up}_proj.weight``: first axis reindexed by
      ``mlp_permutation``, then second axis by ``emb_permutation``.
    * ``mlp.down_proj.weight``: second axis reindexed by ``mlp_permutation``,
      then first axis by ``emb_permutation``.

    Every other entry of the state dict is passed through unchanged.

    NOTE(review): no normalisation weights are touched here; if the model has
    per-block norm parameters indexed by the embedding dimension they stay in
    the original order -- confirm that this is intended.

    Args:
        model: Module whose state dict is read.
        i: Block index used to build the ``model.layers.<i>.`` key prefix.
        tmp_model: Module that receives the full, modified state dict.
        mlp_permutation: Index sequence for the MLP hidden axis.
        emb_permutation: Index sequence for the embedding axis.
    """
    state = model.state_dict()
    attn_prefix = f"model.layers.{i}.self_attn."
    mlp_prefix = f"model.layers.{i}.mlp."

    # Attention: q/k/v are reindexed along axis 1, o_proj along axis 0.
    for proj in ("q_proj", "k_proj", "v_proj"):
        key = attn_prefix + proj + ".weight"
        state[key] = state[key][:, emb_permutation]
    o_key = attn_prefix + "o_proj.weight"
    state[o_key] = state[o_key][emb_permutation]

    # MLP: apply the hidden-unit permutation first, then the embedding one.
    for proj in ("gate_proj", "up_proj"):
        key = mlp_prefix + proj + ".weight"
        hidden_permuted = state[key][mlp_permutation]
        state[key] = hidden_permuted[:, emb_permutation]
    down_key = mlp_prefix + "down_proj.weight"
    hidden_permuted = state[down_key][:, mlp_permutation]
    state[down_key] = hidden_permuted[emb_permutation]

    tmp_model.load_state_dict(state)


def permute_embedding_layer(model, tmp_model, emb_permutation):
    """Reorder the second axis of the token-embedding matrix.

    Reads ``model.state_dict()``, replaces ``model.embed_tokens.weight`` with
    a copy whose second axis is indexed by ``emb_permutation``, and loads the
    complete state dict into ``tmp_model``. All other entries are unchanged.

    Args:
        model: Module whose state dict is read.
        tmp_model: Module that receives the modified state dict.
        emb_permutation: Index sequence applied along axis 1.
    """
    embedding_key = "model.embed_tokens.weight"
    state = model.state_dict()
    original_embedding = state[embedding_key]
    state[embedding_key] = original_embedding[:, emb_permutation]
    tmp_model.load_state_dict(state)


def permute_output_layer(model, tmp_model, emb_permutation):
    """Reorder the second axis of the output head's weight matrix.

    Reads ``model.state_dict()``, replaces ``lm_head.weight`` with a copy
    whose second axis is indexed by ``emb_permutation``, and loads the
    complete state dict into ``tmp_model``. All other entries are unchanged.

    Args:
        model: Module whose state dict is read.
        tmp_model: Module that receives the modified state dict.
        emb_permutation: Index sequence applied along axis 1.
    """
    head_key = "lm_head.weight"
    state = model.state_dict()
    original_head = state[head_key]
    state[head_key] = original_head[:, emb_permutation]
    tmp_model.load_state_dict(state)


def permute_mlp_block(model, i, tmp_model, mlp_permutation):
    """Permute the MLP hidden axis of block ``i`` and load into ``tmp_model``.

    ``gate_proj`` and ``up_proj`` weights are reindexed along their first
    axis, ``down_proj`` along its second axis, all with ``mlp_permutation``.
    The remaining entries of ``model.state_dict()`` are passed through
    unchanged.

    Args:
        model: Module whose state dict is read.
        i: Block index used to build the ``model.layers.<i>.mlp.`` prefix.
        tmp_model: Module that receives the modified state dict.
        mlp_permutation: Index sequence for the MLP hidden axis.
    """
    state = model.state_dict()
    prefix = f"model.layers.{i}.mlp."

    # gate/up are reindexed along axis 0.
    for proj in ("gate_proj", "up_proj"):
        key = prefix + proj + ".weight"
        state[key] = state[key][mlp_permutation]

    # down_proj is reindexed along axis 1.
    down_key = prefix + "down_proj.weight"
    state[down_key] = state[down_key][:, mlp_permutation]

    tmp_model.load_state_dict(state)


def avg_mlp_block(model0, model1, i, tmp_model, alpha=0.5):
    """Blend the MLP weights of block ``i`` from two models into ``tmp_model``.

    For ``gate_proj``, ``up_proj`` and ``down_proj`` the result is
    ``alpha * model0 + (1 - alpha) * model1``. Every other entry loaded into
    ``tmp_model`` comes from ``model0``'s state dict unchanged.

    Args:
        model0: Module weighted by ``alpha``; also supplies all other weights.
        model1: Module weighted by ``1 - alpha``.
        i: Block index used to build the ``model.layers.<i>.mlp.`` prefix.
        tmp_model: Module that receives the resulting state dict.
        alpha: Interpolation coefficient applied to ``model0``.
    """
    blended = model0.state_dict()
    other = model1.state_dict()

    for proj in ("gate_proj", "up_proj", "down_proj"):
        key = f"model.layers.{i}.mlp.{proj}.weight"
        blended[key] = alpha * blended[key] + (1 - alpha) * other[key]

    tmp_model.load_state_dict(blended)


def avg_transformer_block(model0, model1, i, tmp_model, alpha=0.5, attn=True):
    """Blend block ``i`` of two models and load the result into ``tmp_model``.

    The MLP projections (``gate_proj``, ``up_proj``, ``down_proj``) are always
    interpolated as ``alpha * model0 + (1 - alpha) * model1``. The attention
    projections (``q_proj``, ``k_proj``, ``v_proj``, ``o_proj``) are
    interpolated the same way only when ``attn`` is truthy. Every other entry
    loaded into ``tmp_model`` comes from ``model0``'s state dict unchanged.

    Args:
        model0: Module weighted by ``alpha``; also supplies all other weights.
        model1: Module weighted by ``1 - alpha``.
        i: Block index used to build the ``model.layers.<i>.`` key prefix.
        tmp_model: Module that receives the resulting state dict.
        alpha: Interpolation coefficient applied to ``model0``.
        attn: Whether to interpolate the attention projections as well. Any
            truthy value enables it (not only the ``True`` singleton), so
            flags such as ``1`` or ``numpy.bool_(True)`` behave as expected.
    """
    weights0 = model0.state_dict()
    weights1 = model1.state_dict()
    prefix = f"model.layers.{i}."

    # Collect the sub-keys to blend; attention comes first when enabled.
    targets = []
    if attn:
        targets.extend(
            f"self_attn.{proj}.weight" for proj in ("q_proj", "k_proj", "v_proj", "o_proj")
        )
    targets.extend(f"mlp.{proj}.weight" for proj in ("gate_proj", "up_proj", "down_proj"))

    for suffix in targets:
        key = prefix + suffix
        weights0[key] = alpha * weights0[key] + (1 - alpha) * weights1[key]

    tmp_model.load_state_dict(weights0)


def avg_embedding_layer(model0, model1, tmp_model, alpha=0.5):
    """Blend the token-embedding matrices of two models into ``tmp_model``.

    ``model.embed_tokens.weight`` becomes
    ``alpha * model0 + (1 - alpha) * model1``; every other entry loaded into
    ``tmp_model`` comes from ``model0``'s state dict unchanged.

    Args:
        model0: Module weighted by ``alpha``; also supplies all other weights.
        model1: Module weighted by ``1 - alpha``.
        tmp_model: Module that receives the resulting state dict.
        alpha: Interpolation coefficient applied to ``model0``.
    """
    key = "model.embed_tokens.weight"
    blended = model0.state_dict()
    other = model1.state_dict()

    blended[key] = alpha * blended[key] + (1 - alpha) * other[key]

    tmp_model.load_state_dict(blended)


def avg_output_layer(model0, model1, tmp_model, alpha=0.5):
    """Blend the output head and final norm weights into ``tmp_model``.

    Both ``lm_head.weight`` and ``model.norm.weight`` become
    ``alpha * model0 + (1 - alpha) * model1``; every other entry loaded into
    ``tmp_model`` comes from ``model0``'s state dict unchanged.

    Args:
        model0: Module weighted by ``alpha``; also supplies all other weights.
        model1: Module weighted by ``1 - alpha``.
        tmp_model: Module that receives the resulting state dict.
        alpha: Interpolation coefficient applied to ``model0``.
    """
    blended = model0.state_dict()
    other = model1.state_dict()

    for key in ("lm_head.weight", "model.norm.weight"):
        blended[key] = alpha * blended[key] + (1 - alpha) * other[key]

    tmp_model.load_state_dict(blended)


def avg_model(model0, model1, tmp_model, alpha=0.5, n_blocks=32, attn=True, emb=True):
    """Interpolate two models into ``tmp_model``.

    ``tmp_model`` is first seeded from ``model0`` (with the token embedding
    blended when ``emb`` is truthy), then each block ``0 .. n_blocks - 1`` is
    blended in place against ``model1``, and finally the output head and
    final norm are blended when ``emb`` is truthy. Blending uses
    ``alpha * model0 + (1 - alpha) * model1``.

    Args:
        model0: Module weighted by ``alpha``; supplies every non-blended weight.
        model1: Module weighted by ``1 - alpha``. A deep copy is used
            internally, so the passed-in object is only read.
        tmp_model: Module that receives the interpolated weights.
        alpha: Interpolation coefficient applied to ``model0``.
        n_blocks: Number of block indices to process.
        attn: Whether attention projections are blended in each block. Any
            truthy value enables it.
        emb: Whether the embedding, output head and final norm are blended.
            Any truthy value enables it.
    """
    # Work on a private copy of model1 -- presumably so the second operand
    # stays stable even if tmp_model aliases it; confirm against callers.
    model1 = copy.deepcopy(model1)

    # Normalise the flags once: an identity test against ``True`` would
    # silently treat truthy values such as 1 or numpy.bool_(True) as "off".
    blend_embeddings = bool(emb)
    blend_attention = bool(attn)

    if blend_embeddings:
        avg_embedding_layer(model0, model1, tmp_model, alpha=alpha)
    else:
        # Nothing to blend yet; just seed tmp_model with model0's weights.
        tmp_model.load_state_dict(model0.state_dict())

    for i in range(n_blocks):
        avg_transformer_block(tmp_model, model1, i, tmp_model, alpha=alpha, attn=blend_attention)

    if blend_embeddings:
        avg_output_layer(tmp_model, model1, tmp_model, alpha=alpha)


def get_mlp_weights(model, i):
    """Return the ``intermediate.dense`` weight of block ``i`` from the state dict.

    NOTE(review): this key (``model.layers.<i>.intermediate.dense.weight``)
    does not follow the ``mlp.{gate,up,down}_proj`` naming used by the other
    helpers in this file -- confirm it matches the architecture this is
    called with; a missing key raises ``KeyError``.

    Args:
        model: Module whose ``state_dict()`` is queried.
        i: Block index used to build the key.
    """
    key = f"model.layers.{i}.intermediate.dense.weight"
    state = model.state_dict()
    return state[key]


def get_emb_weights(model):
    """Return the ``model.embed_tokens.weight`` entry of ``model.state_dict()``.

    Args:
        model: Module whose ``state_dict()`` is queried.
    """
    state = model.state_dict()
    embedding = state["model.embed_tokens.weight"]
    return embedding