SmartDazi committed on
Commit a5efb83 · verified · 1 Parent(s): 86930ad

Upload folder using huggingface_hub

Modelfile ADDED
@@ -0,0 +1,14 @@
+ # ollama modelfile auto-generated by llamafactory
+
+ FROM .
+
+ TEMPLATE """<s>{{ if .System }}<|im_start|>system
+ {{ .System }}<|im_end|>
+ {{ end }}{{ range .Messages }}{{ if eq .Role "user" }}<|im_start|>user
+ {{ .Content }}<|im_end|>
+ <|im_start|>assistant
+ {{ else if eq .Role "assistant" }}{{ .Content }}<|im_end|>
+ {{ end }}{{ end }}"""
+
+ PARAMETER stop "<|im_end|>"
+ PARAMETER num_ctx 4096
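The TEMPLATE above emits a ChatML-style prompt. A minimal Python sketch of the string it renders for one system message and one user turn is shown below; the helper name build_prompt is hypothetical, and the real rendering is performed by Ollama's Go template engine, so this only illustrates the resulting layout.

def build_prompt(system: str, messages: list) -> str:
    # Mirrors the Modelfile TEMPLATE: optional system turn, then user/assistant turns.
    parts = ["<s>"]
    if system:
        parts.append(f"<|im_start|>system\n{system}<|im_end|>\n")
    for msg in messages:
        if msg["role"] == "user":
            parts.append(f"<|im_start|>user\n{msg['content']}<|im_end|>\n<|im_start|>assistant\n")
        elif msg["role"] == "assistant":
            parts.append(f"{msg['content']}<|im_end|>\n")
    return "".join(parts)

print(build_prompt("You are a helpful assistant.", [{"role": "user", "content": "Hello"}]))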
added_tokens.json ADDED
@@ -0,0 +1,10 @@
+ {
+ "<|execute_end|>": 73444,
+ "<|execute_start|>": 73443,
+ "<|fim_middle|>": 73446,
+ "<|fim_prefix|>": 73445,
+ "<|fim_suffix|>": 73447,
+ "<|im_end|>": 73440,
+ "<|im_start|>": 73441,
+ "<|tool_call|>": 73442
+ }
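A quick way to sanity-check these ids against the Modelfile stop token and the eos_token_id in config.json (<|im_end|> = 73440) is to load the uploaded tokenizer with transformers. The local path "." is a placeholder for wherever this folder is downloaded, not something defined by the commit itself.

from transformers import AutoTokenizer

# "." is a placeholder path; assumes the full tokenizer files from this upload are present locally.
tok = AutoTokenizer.from_pretrained(".", trust_remote_code=True)
print(tok.convert_tokens_to_ids("<|im_end|>"))    # expected 73440, matching eos_token_id in config.json
print(tok.convert_tokens_to_ids("<|im_start|>"))  # expected 73441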
compressed_attention.py ADDED
@@ -0,0 +1,1401 @@
1
+ # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Any, Tuple, Union
16
+ from collections import Counter
17
+ import torch
18
+ import triton
19
+ import triton.language as tl
20
+ import warnings
21
+ from torch import nn
22
+ def is_hopper_gpu():
23
+ if torch.cuda.is_available():
24
+ device_capability = torch.cuda.get_device_capability()
25
+ major, minor = device_capability
26
+ return major == 9
27
+ return False
28
+ def get_compressed_seqlens(
29
+ cu_seqlens: torch.Tensor, kernel_size: int, kernel_stride: int
30
+ ):
31
+ # compute seqlens after compression
32
+ seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
33
+ y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
34
+ # corner case, if sequence_length < kernel_size, no compression for this sequence
35
+ y_seqlens[seqlens < kernel_size] = 0
36
+ y_cu_seqlens = torch.zeros(
37
+ y_seqlens.shape[0] + 1, dtype=torch.int32, device=cu_seqlens.device
38
+ )
39
+ y_cu_seqlens[1:] = torch.cumsum(y_seqlens, dim=0)
40
+ return y_seqlens, y_cu_seqlens
41
+
42
+
43
+ def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
44
+ """
45
+ Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton.
46
+
47
+ Args:
48
+ head_dim (int): Size of the head dimension.
49
+ block_size (int): Size of the block in the attention matrix.
50
+ is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU.
51
+
52
+ Returns:
53
+ tuple: (num_warps, num_stages) recommended values.
54
+ """
55
+ # Determine if head_dim and block_size exceed 64
56
+ head_large = head_dim > 64
57
+ block_large = block_size > 64
58
+
59
+ if is_hopper_gpu:
60
+ # Hopper GPU recommendations
61
+ if head_large and block_large:
62
+ num_warps = 8
63
+ num_stages = 3
64
+ elif head_large or block_large:
65
+ num_warps = 4
66
+ num_stages = 3
67
+ else:
68
+ num_warps = 2
69
+ num_stages = 2
70
+ else:
71
+ # Ampere GPU recommendations
72
+ if head_large and block_large:
73
+ num_warps = 8
74
+ num_stages = 3
75
+ elif head_large or block_large:
76
+ num_warps = 8
77
+ num_stages = 3
78
+ else:
79
+ num_warps = 2
80
+ num_stages = 2
81
+ return num_warps, num_stages
82
+
83
+
84
+ IS_HOPPER_GPU = is_hopper_gpu()
85
+
86
+
87
+ @triton.jit
88
+ def forward_kernel(
89
+ q_ptr, # Q: n x h x d
90
+ k_ptr, # K: n x h x d
91
+ v_ptr, # V: n x h x d
92
+ o_ptr, # O: n x h x d
93
+ lse_ptr, # LSE: h x n
94
+ # size and stride of compression
95
+ kernel_size,
96
+ kernel_stride,
97
+ # seqlens
98
+ cu_seqlens_q,
99
+ cu_seqlens_k,
100
+ # shape
101
+ NUM_KV_HEADS,
102
+ NUM_SHARE_Q_HEADS,
103
+ HEAD_DIM,
104
+ # sm_scale
105
+ sm_scale,
106
+ # stride
107
+ stride_qn,
108
+ stride_qh,
109
+ stride_qd,
110
+ stride_kn,
111
+ stride_kh,
112
+ stride_kd,
113
+ stride_vn,
114
+ stride_vh,
115
+ stride_vd,
116
+ stride_on,
117
+ stride_oh,
118
+ stride_od,
119
+ stride_lh,
120
+ stride_ln,
121
+ # META parameters
122
+ BLOCK_SIZE_Q: tl.constexpr, # q block size
123
+ BLOCK_SIZE_K: tl.constexpr, # k block size
124
+ BLOCK_SIZE_D: tl.constexpr,
125
+ ):
126
+ qk_scale = sm_scale * 1.44269504
127
+ # get batch id and head id
128
+ pid_b = tl.program_id(0)
129
+ pid_h = tl.program_id(1)
130
+ pid_q = tl.program_id(2)
131
+ pid_kh = pid_h // NUM_SHARE_Q_HEADS
132
+ # get q k start and len after rmpad
133
+ q_start = tl.load(cu_seqlens_q + pid_b)
134
+ q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
135
+ k_start = tl.load(cu_seqlens_k + pid_b)
136
+ k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
137
+ # skip the first kernel_size - 1 queries, because they do not attend to any compressed keys
138
+ q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
139
+ if q_start_in_seq >= q_len:
140
+ return
141
+ # init qkv pointer
142
+ q_ptrs = tl.make_block_ptr(
143
+ base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
144
+ shape=(q_len, HEAD_DIM),
145
+ strides=(stride_qn, stride_qd),
146
+ offsets=(q_start_in_seq, 0),
147
+ block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
148
+ order=(1, 0),
149
+ )
150
+ k_ptrs = tl.make_block_ptr(
151
+ base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
152
+ shape=(HEAD_DIM, k_len),
153
+ strides=(stride_kd, stride_kn),
154
+ offsets=(0, 0),
155
+ block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
156
+ order=(0, 1),
157
+ )
158
+ v_ptrs = tl.make_block_ptr(
159
+ base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
160
+ shape=(k_len, HEAD_DIM),
161
+ strides=(stride_vn, stride_vd),
162
+ offsets=(0, 0),
163
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
164
+ order=(1, 0),
165
+ )
166
+ # load q
167
+ q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
168
+ # init statistics
169
+ off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
170
+ off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
171
+ m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
172
+ lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
173
+ acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32)
174
+ # attention
175
+ lo = 0
176
+ hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
177
+ for i in range(lo, hi, BLOCK_SIZE_K):
178
+ i = tl.multiple_of(i, BLOCK_SIZE_K)
179
+ # load k
180
+ k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero")
181
+ # compute qk
182
+ qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
183
+ qk += tl.where(
184
+ off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")
185
+ )
186
+ qk += tl.dot(q, k) * qk_scale
187
+ # compute m_ij and l_ij
188
+ m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
189
+ p = tl.exp2(qk - m_ij[:, None])
190
+ l_ij = tl.sum(p, axis=1)
191
+ # scale acc_o
192
+ acc_o_scale = tl.exp2(m_i - m_ij)
193
+ acc_o = acc_o * acc_o_scale[:, None]
194
+ # load v and update acc_o
195
+ v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
196
+ p = p.to(v.dtype)
197
+ acc_o += tl.dot(p, v)
198
+ # update statistics
199
+ m_i = m_ij
200
+ lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)
201
+ # update ptrs
202
+ k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K))
203
+ v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))
204
+ # final scale
205
+ acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
206
+ # save output
207
+ o_ptrs = tl.make_block_ptr(
208
+ base=o_ptr + q_start * stride_on + pid_h * stride_oh,
209
+ shape=(q_len, HEAD_DIM),
210
+ strides=(stride_on, stride_od),
211
+ offsets=(q_start_in_seq, 0),
212
+ block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
213
+ order=(1, 0),
214
+ )
215
+ tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
216
+ # save lse
217
+ l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln
218
+ tl.store(l_ptrs, lse_i, mask=off_q < q_len)
219
+
220
+
221
+ @triton.jit
222
+ def backward_sum_o_do(
223
+ o_ptr, # O: n x h x d
224
+ do_ptr, # dO: n x h x d
225
+ delta_ptr, # D: h x n
226
+ o_len,
227
+ HEAD_DIM,
228
+ stride_on,
229
+ stride_oh,
230
+ stride_od,
231
+ stride_don,
232
+ stride_doh,
233
+ stride_dod,
234
+ stride_dh,
235
+ stride_dn,
236
+ BLOCK_SIZE_O: tl.constexpr,
237
+ BLOCK_SIZE_D: tl.constexpr,
238
+ ):
239
+ pid_n = tl.program_id(0)
240
+ pid_h = tl.program_id(1)
241
+ off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
242
+ off_d = tl.arange(0, BLOCK_SIZE_D)
243
+ o = tl.load(
244
+ o_ptr
245
+ + off_n[:, None] * stride_on
246
+ + pid_h * stride_oh
247
+ + off_d[None, :] * stride_od,
248
+ mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
249
+ other=0,
250
+ ).to(tl.float32)
251
+ do = tl.load(
252
+ do_ptr
253
+ + off_n[:, None] * stride_don
254
+ + pid_h * stride_doh
255
+ + off_d[None, :] * stride_dod,
256
+ mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
257
+ other=0,
258
+ ).to(tl.float32)
259
+ delta = tl.sum(o * do, axis=1)
260
+ tl.store(
261
+ delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len
262
+ )
263
+
264
+
265
+ @triton.jit
266
+ def backward_dkdv(
267
+ q_ptr, # Q: n x qh x d
268
+ k_ptr, # K: n x kh x d
269
+ v_ptr, # V: n x kh x d
270
+ lse_ptr, # LSE: qh x n
271
+ d_ptr, # Delta: qh x n
272
+ do_ptr,
273
+ dk_ptr, # DK: sh x n x kh x d
274
+ dv_ptr, # DV: sh x n x kh x d
275
+ kernel_size,
276
+ kernel_stride,
277
+ # seqlens
278
+ cu_seqlens_q,
279
+ cu_seqlens_k,
280
+ # shape
281
+ NUM_KV_HEADS,
282
+ NUM_SHARE_Q_HEADS,
283
+ HEAD_DIM,
284
+ # sm_scale
285
+ sm_scale,
286
+ # stride
287
+ stride_qn,
288
+ stride_qh,
289
+ stride_qd,
290
+ stride_kn,
291
+ stride_kh,
292
+ stride_kd,
293
+ stride_vn,
294
+ stride_vh,
295
+ stride_vd,
296
+ stride_lh,
297
+ stride_ln,
298
+ stride_dh,
299
+ stride_dn,
300
+ stride_don,
301
+ stride_doh,
302
+ stride_dod,
303
+ stride_dks,
304
+ stride_dkn,
305
+ stride_dkh,
306
+ stride_dkd,
307
+ stride_dvs,
308
+ stride_dvn,
309
+ stride_dvh,
310
+ stride_dvd,
311
+ # META parameters
312
+ BLOCK_SIZE_Q: tl.constexpr, # q block size
313
+ BLOCK_SIZE_K: tl.constexpr, # k block size
314
+ BLOCK_SIZE_D: tl.constexpr,
315
+ ):
316
+ qk_scale = sm_scale * 1.44269504
317
+ # get batch id and head id
318
+ pid_b = tl.program_id(0)
319
+ pid_h = tl.program_id(1)
320
+ pid_kh = pid_h // NUM_SHARE_Q_HEADS
321
+ pid_sh = pid_h % NUM_SHARE_Q_HEADS
322
+ pid_k = tl.program_id(2)
323
+ # get q k start and len after rmpad
324
+ q_start = tl.load(cu_seqlens_q + pid_b)
325
+ q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
326
+ k_start = tl.load(cu_seqlens_k + pid_b)
327
+ k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
328
+ if BLOCK_SIZE_K * pid_k >= k_len:
329
+ return
330
+ # init pointers
331
+ k_ptrs = tl.make_block_ptr(
332
+ base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
333
+ shape=(k_len, HEAD_DIM),
334
+ strides=(stride_kn, stride_kd),
335
+ offsets=(pid_k * BLOCK_SIZE_K, 0),
336
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
337
+ order=(1, 0),
338
+ )
339
+ dk_ptrs = tl.make_block_ptr(
340
+ base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
341
+ shape=(k_len, HEAD_DIM),
342
+ strides=(stride_dkn, stride_dkd),
343
+ offsets=(pid_k * BLOCK_SIZE_K, 0),
344
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
345
+ order=(1, 0),
346
+ )
347
+ v_ptrs = tl.make_block_ptr(
348
+ base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
349
+ shape=(k_len, HEAD_DIM),
350
+ strides=(stride_vn, stride_vd),
351
+ offsets=(pid_k * BLOCK_SIZE_K, 0),
352
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
353
+ order=(1, 0),
354
+ )
355
+ dv_ptrs = tl.make_block_ptr(
356
+ base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
357
+ shape=(k_len, HEAD_DIM),
358
+ strides=(stride_dvn, stride_dvd),
359
+ offsets=(pid_k * BLOCK_SIZE_K, 0),
360
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
361
+ order=(1, 0),
362
+ )
363
+ # offsets
364
+ off_q = tl.arange(0, BLOCK_SIZE_Q)
365
+ off_k = (
366
+ pid_k * BLOCK_SIZE_K * kernel_stride
367
+ + tl.arange(0, BLOCK_SIZE_K) * kernel_stride
368
+ + kernel_size
369
+ - 1
370
+ )
371
+ # load k v and keep in SRAM
372
+ k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
373
+ v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
374
+ # init dk dv
375
+ dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
376
+ dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
377
+ q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1
378
+ q_ptrs = tl.make_block_ptr(
379
+ base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
380
+ shape=(HEAD_DIM, q_len),
381
+ strides=(stride_qd, stride_qn),
382
+ offsets=(0, q_lo),
383
+ block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
384
+ order=(0, 1),
385
+ )
386
+ do_ptrs = tl.make_block_ptr(
387
+ base=do_ptr + q_start * stride_don + pid_h * stride_doh,
388
+ shape=(HEAD_DIM, q_len),
389
+ strides=(stride_dod, stride_don),
390
+ offsets=(0, q_lo),
391
+ block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
392
+ order=(0, 1),
393
+ )
394
+ d_ptrs = tl.make_block_ptr(
395
+ base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
396
+ shape=(1, q_len),
397
+ strides=(0, stride_dn),
398
+ offsets=(0, q_lo),
399
+ block_shape=(1, BLOCK_SIZE_Q),
400
+ order=(1, 0),
401
+ )
402
+ lse_ptrs = tl.make_block_ptr(
403
+ base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
404
+ shape=(1, q_len),
405
+ strides=(0, stride_ln),
406
+ offsets=(0, q_lo),
407
+ block_shape=(1, BLOCK_SIZE_Q),
408
+ order=(0, 1),
409
+ )
410
+ # loop for q blocks
411
+ for i in range(q_lo, q_len, BLOCK_SIZE_Q):
412
+ # load
413
+ q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
414
+ do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
415
+ lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
416
+ d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
417
+ # compute qk
418
+ # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
419
+ qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], float(0.0), float("-inf"))
420
+ qk += tl.dot(k, q) * qk_scale
421
+ # compute p, ds
422
+ # [BLOCK_SIZE_K, BLOCK_SIZE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
423
+ p = tl.exp2(qk - lse)
424
+ # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
425
+ dp = tl.dot(v, do)
426
+ ds = sm_scale * p * (dp - d)
427
+ # cast dtype
428
+ p = p.to(do.dtype)
429
+ ds = ds.to(q.dtype)
430
+ # update dk and dv
431
+ # [BLOCK_SIZE_K, BLOCK_SIZE_Q] @ [BLOCK_SIZE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM]
432
+ dk += tl.dot(ds, tl.trans(q))
433
+ dv += tl.dot(p, tl.trans(do))
434
+ # increment pointers
435
+ q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q))
436
+ do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q))
437
+ lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q))
438
+ d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q))
439
+ # save dk dv
440
+ tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
441
+ tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
442
+
443
+
444
+ @triton.jit
445
+ def backward_dq(
446
+ q_ptr, # Q: n x qh x d
447
+ k_ptr, # K: n x kh x d
448
+ v_ptr, # V: n x kh x d
449
+ lse_ptr, # LSE: qh x n
450
+ d_ptr, # Delta: qh x n
451
+ do_ptr,
452
+ dq_ptr,
453
+ kernel_size,
454
+ kernel_stride,
455
+ # seqlens
456
+ cu_seqlens_q,
457
+ cu_seqlens_k,
458
+ # shape
459
+ NUM_KV_HEADS,
460
+ NUM_SHARE_Q_HEADS,
461
+ HEAD_DIM,
462
+ # sm_scale
463
+ sm_scale,
464
+ # stride
465
+ stride_qn,
466
+ stride_qh,
467
+ stride_qd,
468
+ stride_kn,
469
+ stride_kh,
470
+ stride_kd,
471
+ stride_vn,
472
+ stride_vh,
473
+ stride_vd,
474
+ stride_lh,
475
+ stride_ln,
476
+ stride_dh,
477
+ stride_dn,
478
+ stride_don,
479
+ stride_doh,
480
+ stride_dod,
481
+ stride_dqn,
482
+ stride_dqh,
483
+ stride_dqd,
484
+ # META parameters
485
+ BLOCK_SIZE_Q: tl.constexpr, # q block size
486
+ BLOCK_SIZE_K: tl.constexpr, # k block size
487
+ BLOCK_SIZE_D: tl.constexpr,
488
+ ):
489
+ qk_scale = sm_scale * 1.44269504
490
+ # get batch id and head id
491
+ pid_b = tl.program_id(0)
492
+ pid_h = tl.program_id(1)
493
+ pid_q = tl.program_id(2)
494
+ pid_kh = pid_h // NUM_SHARE_Q_HEADS
495
+ # get q k start and len after rmpad
496
+ q_start = tl.load(cu_seqlens_q + pid_b)
497
+ q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
498
+ k_start = tl.load(cu_seqlens_k + pid_b)
499
+ k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
500
+ # skip the first kernel_size - 1 queries, because they do not attend to any compressed keys
501
+ q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
502
+ if q_start_in_seq >= q_len:
503
+ return
504
+ # init pointers
505
+ q_ptrs = tl.make_block_ptr(
506
+ base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
507
+ shape=(q_len, HEAD_DIM),
508
+ strides=(stride_qn, stride_qd),
509
+ offsets=(q_start_in_seq, 0),
510
+ block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
511
+ order=(1, 0),
512
+ )
513
+ dq_ptrs = tl.make_block_ptr(
514
+ base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh,
515
+ shape=(q_len, HEAD_DIM),
516
+ strides=(stride_dqn, stride_dqd),
517
+ offsets=(q_start_in_seq, 0),
518
+ block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
519
+ order=(1, 0),
520
+ )
521
+ k_ptrs = tl.make_block_ptr(
522
+ base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
523
+ shape=(k_len, HEAD_DIM),
524
+ strides=(stride_kn, stride_kd),
525
+ offsets=(0, 0),
526
+ block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
527
+ order=(1, 0),
528
+ )
529
+ v_ptrs = tl.make_block_ptr(
530
+ base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
531
+ shape=(HEAD_DIM, k_len),
532
+ strides=(stride_vd, stride_vn),
533
+ offsets=(0, 0),
534
+ block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
535
+ order=(0, 1),
536
+ )
537
+ do_ptrs = tl.make_block_ptr(
538
+ base=do_ptr + q_start * stride_don + pid_h * stride_doh,
539
+ shape=(q_len, HEAD_DIM),
540
+ strides=(stride_don, stride_dod),
541
+ offsets=(q_start_in_seq, 0),
542
+ block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
543
+ order=(1, 0),
544
+ )
545
+ d_ptrs = tl.make_block_ptr(
546
+ base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
547
+ shape=(q_len, 1),
548
+ strides=(stride_dn, stride_dh),
549
+ offsets=(q_start_in_seq, 0),
550
+ block_shape=(BLOCK_SIZE_Q, 1),
551
+ order=(0, 1),
552
+ )
553
+ lse_ptrs = tl.make_block_ptr(
554
+ base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
555
+ shape=(q_len, 1),
556
+ strides=(stride_ln, stride_lh),
557
+ offsets=(q_start_in_seq, 0),
558
+ block_shape=(BLOCK_SIZE_Q, 1),
559
+ order=(0, 1),
560
+ )
561
+ # offsets
562
+ off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
563
+ off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
564
+ # load q, do, lse, delta, and keep in SRAM
565
+ q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
566
+ do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
567
+ lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
568
+ d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
569
+ # init dq
570
+ dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32)
571
+ lo = 0
572
+ hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
573
+ for i in range(lo, hi, BLOCK_SIZE_K):
574
+ # load
575
+ k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
576
+ v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
577
+ # compute qk
578
+ qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
579
+ qk += tl.where(
580
+ off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")
581
+ )
582
+ qk += tl.dot(q, tl.trans(k)) * qk_scale
583
+ # compute p, ds
584
+ p = tl.exp2(qk - lse)
585
+ dp = tl.dot(do, v)
586
+ ds = sm_scale * p * (dp - d)
587
+ # cast dtype
588
+ ds = ds.to(q.dtype)
589
+ # update dq
590
+ dq += tl.dot(ds, k)
591
+ # increment pointers
592
+ k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0))
593
+ v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K))
594
+ # save dq
595
+ tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))
596
+
597
+
598
+ def _compressed_attention_fwd(
599
+ q: torch.Tensor,
600
+ k: torch.Tensor,
601
+ v: torch.Tensor,
602
+ kernel_size: int,
603
+ kernel_stride: int,
604
+ cu_seqlens_q: torch.Tensor,
605
+ cu_seqlens_k: torch.Tensor,
606
+ max_seqlen_q: torch.Tensor,
607
+ max_seqlen_k: torch.Tensor,
608
+ sm_scale: float,
609
+ ):
610
+ # dtype check
611
+ assert k.dtype == q.dtype and v.dtype == q.dtype
612
+ assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
613
+ # shape
614
+ q_len, num_q_heads, head_dim = q.shape
615
+ k_len, num_k_heads, head_dim = k.shape
616
+ v_len, num_v_heads, head_dim = v.shape
617
+ batch_size = cu_seqlens_q.shape[0] - 1
618
+ assert k_len == v_len and q_len > k_len
619
+ # gqa
620
+ assert num_k_heads == num_v_heads
621
+ assert num_q_heads % num_k_heads == 0
622
+ num_share_q_heads = num_q_heads // num_k_heads
623
+ # output tensor
624
+ o = torch.zeros_like(q)
625
+ lse = torch.full(
626
+ (num_q_heads, q_len),
627
+ fill_value=-torch.inf,
628
+ dtype=torch.float32,
629
+ device=q.device,
630
+ )
631
+ # launch kernel
632
+ grid = lambda META: (
633
+ batch_size,
634
+ num_q_heads,
635
+ triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
636
+ )
637
+ BLOCK_SIZE_Q = 128
638
+ BLOCK_SIZE_K = 128
639
+ BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
640
+ num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
641
+ forward_kernel[grid](
642
+ q,
643
+ k,
644
+ v,
645
+ o,
646
+ lse,
647
+ kernel_size,
648
+ kernel_stride,
649
+ cu_seqlens_q,
650
+ cu_seqlens_k,
651
+ num_k_heads,
652
+ num_share_q_heads,
653
+ head_dim,
654
+ sm_scale,
655
+ q.stride(0),
656
+ q.stride(1),
657
+ q.stride(2),
658
+ k.stride(0),
659
+ k.stride(1),
660
+ k.stride(2),
661
+ v.stride(0),
662
+ v.stride(1),
663
+ v.stride(2),
664
+ o.stride(0),
665
+ o.stride(1),
666
+ o.stride(2),
667
+ lse.stride(0),
668
+ lse.stride(1),
669
+ BLOCK_SIZE_Q=BLOCK_SIZE_Q,
670
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
671
+ BLOCK_SIZE_D=BLOCK_SIZE_D,
672
+ num_warps=num_warps,
673
+ num_stages=num_stages,
674
+ )
675
+ return o, lse
676
+
677
+
678
+ def _compressed_attention_bwd(
679
+ o: torch.Tensor,
680
+ do: torch.Tensor,
681
+ lse: torch.Tensor,
682
+ q: torch.Tensor,
683
+ k: torch.Tensor,
684
+ v: torch.Tensor,
685
+ kernel_size: int,
686
+ kernel_stride: int,
687
+ cu_seqlens_q: torch.Tensor,
688
+ cu_seqlens_k: torch.Tensor,
689
+ max_seqlen_q: torch.Tensor,
690
+ max_seqlen_k: torch.Tensor,
691
+ sm_scale: float,
692
+ ):
693
+ q_len, num_q_heads, head_dim = q.shape
694
+ k_len, num_k_heads, head_dim = k.shape
695
+ v_len, num_v_heads, head_dim = v.shape
696
+ o_len, num_o_heads, head_dim = o.shape
697
+ num_share_q_heads = num_q_heads // num_k_heads
698
+ # compute D
699
+ delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)
700
+ grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads)
701
+ BLOCK_SIZE_O = 256
702
+ BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
703
+ num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU)
704
+ backward_sum_o_do[grid](
705
+ o,
706
+ do,
707
+ delta,
708
+ o_len,
709
+ head_dim,
710
+ o.stride(0),
711
+ o.stride(1),
712
+ o.stride(2),
713
+ do.stride(0),
714
+ do.stride(1),
715
+ do.stride(2),
716
+ delta.stride(0),
717
+ delta.stride(1),
718
+ BLOCK_SIZE_O=BLOCK_SIZE_O,
719
+ BLOCK_SIZE_D=BLOCK_SIZE_D,
720
+ num_warps=num_warps,
721
+ num_stages=num_stages,
722
+ )
723
+ # compute dk dv
724
+ dk = torch.zeros(
725
+ num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype
726
+ )
727
+ dv = torch.zeros(
728
+ num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype
729
+ )
730
+ batch_size = cu_seqlens_q.shape[0] - 1
731
+ grid = lambda META: (
732
+ batch_size,
733
+ num_q_heads,
734
+ triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
735
+ )
736
+ BLOCK_SIZE_Q = 64
737
+ BLOCK_SIZE_K = 128
738
+ BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
739
+ num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU)
740
+ backward_dkdv[grid](
741
+ q,
742
+ k,
743
+ v,
744
+ lse,
745
+ delta,
746
+ do,
747
+ dk,
748
+ dv,
749
+ kernel_size,
750
+ kernel_stride,
751
+ cu_seqlens_q,
752
+ cu_seqlens_k,
753
+ num_k_heads,
754
+ num_share_q_heads,
755
+ head_dim,
756
+ sm_scale,
757
+ q.stride(0),
758
+ q.stride(1),
759
+ q.stride(2),
760
+ k.stride(0),
761
+ k.stride(1),
762
+ k.stride(2),
763
+ v.stride(0),
764
+ v.stride(1),
765
+ v.stride(2),
766
+ lse.stride(0),
767
+ lse.stride(1),
768
+ delta.stride(0),
769
+ delta.stride(1),
770
+ do.stride(0),
771
+ do.stride(1),
772
+ do.stride(2),
773
+ dk.stride(0),
774
+ dk.stride(1),
775
+ dk.stride(2),
776
+ dk.stride(3),
777
+ dv.stride(0),
778
+ dv.stride(1),
779
+ dv.stride(2),
780
+ dv.stride(3),
781
+ BLOCK_SIZE_Q=BLOCK_SIZE_Q,
782
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
783
+ BLOCK_SIZE_D=BLOCK_SIZE_D,
784
+ num_warps=num_warps,
785
+ num_stages=num_stages,
786
+ )
787
+ dk = dk.sum(0)
788
+ dv = dv.sum(0)
789
+ # compute dq
790
+ dq = torch.zeros_like(q)
791
+ grid = lambda META: (
792
+ batch_size,
793
+ num_q_heads,
794
+ triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
795
+ )
796
+ BLOCK_SIZE_Q = 128
797
+ BLOCK_SIZE_K = 64
798
+ num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
799
+ backward_dq[grid](
800
+ q,
801
+ k,
802
+ v,
803
+ lse,
804
+ delta,
805
+ do,
806
+ dq,
807
+ kernel_size,
808
+ kernel_stride,
809
+ cu_seqlens_q,
810
+ cu_seqlens_k,
811
+ num_k_heads,
812
+ num_share_q_heads,
813
+ head_dim,
814
+ sm_scale,
815
+ q.stride(0),
816
+ q.stride(1),
817
+ q.stride(2),
818
+ k.stride(0),
819
+ k.stride(1),
820
+ k.stride(2),
821
+ v.stride(0),
822
+ v.stride(1),
823
+ v.stride(2),
824
+ lse.stride(0),
825
+ lse.stride(1),
826
+ delta.stride(0),
827
+ delta.stride(1),
828
+ do.stride(0),
829
+ do.stride(1),
830
+ do.stride(2),
831
+ dq.stride(0),
832
+ dq.stride(1),
833
+ dq.stride(2),
834
+ BLOCK_SIZE_Q=BLOCK_SIZE_Q,
835
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
836
+ BLOCK_SIZE_D=BLOCK_SIZE_D,
837
+ num_warps=num_warps,
838
+ num_stages=num_stages,
839
+ )
840
+ return dq, dk, dv
841
+
842
+
843
+ class CompressedAttention(torch.autograd.Function):
844
+ @staticmethod
845
+ def forward(
846
+ ctx,
847
+ q: torch.Tensor,
848
+ k: torch.Tensor,
849
+ v: torch.Tensor,
850
+ kernel_size: int,
851
+ kernel_stride: int,
852
+ cu_seqlens_q: torch.Tensor,
853
+ cu_seqlens_k: torch.Tensor,
854
+ max_seqlen_q: torch.Tensor,
855
+ max_seqlen_k: torch.Tensor,
856
+ sm_scale=None,
857
+ ):
858
+ # dtype check
859
+ assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
860
+ assert q.dtype == k.dtype and k.dtype == v.dtype
861
+ assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
862
+ # softmax scale
863
+ if sm_scale is None:
864
+ sm_scale = 1 / math.sqrt(q.shape[-1])
865
+ o, lse = _compressed_attention_fwd(
866
+ q,
867
+ k,
868
+ v,
869
+ kernel_size,
870
+ kernel_stride,
871
+ cu_seqlens_q,
872
+ cu_seqlens_k,
873
+ max_seqlen_q,
874
+ max_seqlen_k,
875
+ sm_scale,
876
+ )
877
+ ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k)
878
+ ctx.sm_scale = sm_scale
879
+ ctx.max_seqlen_q = max_seqlen_q
880
+ ctx.max_seqlen_k = max_seqlen_k
881
+ ctx.kernel_size = kernel_size
882
+ ctx.kernel_stride = kernel_stride
883
+ return o, lse
884
+
885
+ @staticmethod
886
+ def backward(ctx, do: torch.Tensor, *args) -> Any:
887
+ q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
888
+ max_seqlen_q = ctx.max_seqlen_q
889
+ max_seqlen_k = ctx.max_seqlen_k
890
+ sm_scale = ctx.sm_scale
891
+ kernel_size = ctx.kernel_size
892
+ kernel_stride = ctx.kernel_stride
893
+ dq, dk, dv = _compressed_attention_bwd(
894
+ o,
895
+ do,
896
+ lse,
897
+ q,
898
+ k,
899
+ v,
900
+ kernel_size,
901
+ kernel_stride,
902
+ cu_seqlens_q,
903
+ cu_seqlens_k,
904
+ max_seqlen_q,
905
+ max_seqlen_k,
906
+ sm_scale,
907
+ )
908
+ return dq, dk, dv, None, None, None, None, None, None, None
909
+
910
+
911
+ @triton.jit
912
+ def score_kernel(
913
+ q_ptr,
914
+ k_ptr,
915
+ lse_ptr,
916
+ s_ptr,
917
+ kernel_size,
918
+ kernel_stride,
919
+ # seqlens
920
+ cu_seqlens_q,
921
+ cu_seqlens_k,
922
+ # shape
923
+ NUM_KV_HEADS,
924
+ NUM_SHARE_Q_HEADS,
925
+ HEAD_DIM,
926
+ # sm_scale
927
+ sm_scale,
928
+ # stride
929
+ stride_qn,
930
+ stride_qh,
931
+ stride_qd,
932
+ stride_kn,
933
+ stride_kh,
934
+ stride_kd,
935
+ stride_lh,
936
+ stride_ln,
937
+ stride_sh,
938
+ stride_sq,
939
+ stride_sk,
940
+ # META parameters
941
+ BLOCK_SIZE_Q: tl.constexpr, # q block size
942
+ BLOCK_SIZE_K: tl.constexpr, # k block size
943
+ BLOCK_SIZE_D: tl.constexpr,
944
+ ):
945
+ qk_scale = sm_scale * 1.44269504
946
+ # get batch id and head id
947
+ pid_bkh = tl.program_id(0)
948
+ pid_b = pid_bkh // NUM_KV_HEADS
949
+ pid_kh = pid_bkh % NUM_KV_HEADS
950
+ pid_q = tl.program_id(1)
951
+ pid_k = tl.program_id(2)
952
+ # get q k start and len after rmpad
953
+ q_start = tl.load(cu_seqlens_q + pid_b)
954
+ q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
955
+ k_start = tl.load(cu_seqlens_k + pid_b)
956
+ k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
957
+ if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len:
958
+ return
959
+ # init k pointer and load k
960
+ k_ptrs = tl.make_block_ptr(
961
+ base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
962
+ shape=(HEAD_DIM, k_len),
963
+ strides=(stride_kd, stride_kn),
964
+ offsets=(0, pid_k * BLOCK_SIZE_K),
965
+ block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
966
+ order=(0, 1),
967
+ )
968
+ k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
969
+ # offsets
970
+ off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q
971
+ off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
972
+ causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :]
973
+ # init score
974
+ s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
975
+ # loop over gqa heads
976
+ for h in range(NUM_SHARE_Q_HEADS):
977
+ pid_h = pid_kh * NUM_SHARE_Q_HEADS + h
978
+ q_ptrs = tl.make_block_ptr(
979
+ base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
980
+ shape=(q_len, HEAD_DIM),
981
+ strides=(stride_qn, stride_qd),
982
+ offsets=(pid_q * BLOCK_SIZE_Q, 0),
983
+ block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
984
+ order=(1, 0),
985
+ )
986
+ lse_ptrs = tl.make_block_ptr(
987
+ base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
988
+ shape=(q_len, 1),
989
+ strides=(stride_ln, stride_lh),
990
+ offsets=(pid_q * BLOCK_SIZE_Q, 0),
991
+ block_shape=(BLOCK_SIZE_Q, 1),
992
+ order=(0, 1),
993
+ )
994
+ # load q and lse
995
+ q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
996
+ lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
997
+ # compute qk
998
+ qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
999
+ qk += tl.dot(q, k) * qk_scale
1000
+ # compute score
1001
+ s += tl.where(causal_mask, tl.exp2(qk - lse), 0)
1002
+ # save output
1003
+ s_ptrs = tl.make_block_ptr(
1004
+ base=s_ptr + pid_kh * stride_sh + q_start * stride_sq,
1005
+ shape=(q_len, k_len),
1006
+ strides=(stride_sq, stride_sk),
1007
+ offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K),
1008
+ block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K),
1009
+ order=(1, 0),
1010
+ )
1011
+ tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1))
1012
+
1013
+
1014
+ def _get_attention_score(
1015
+ q: torch.Tensor, # [total_query_len, num_q_heads, head_dim]
1016
+ k: torch.Tensor, # [total_key_len, num_k_heads, head_dim]
1017
+ lse: torch.Tensor, # [num_q_heads, total_query_len]
1018
+ kernel_size: int,
1019
+ kernel_stride: int,
1020
+ cu_seqlens_q: torch.Tensor,
1021
+ cu_seqlens_k: torch.Tensor,
1022
+ max_seqlen_q: int,
1023
+ max_seqlen_k: int,
1024
+ sm_scale: float,
1025
+ ) -> torch.Tensor:
1026
+ # dtype check
1027
+ assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
1028
+ assert q.dtype == k.dtype
1029
+ assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
1030
+ assert (
1031
+ lse.dtype == torch.float32
1032
+ ) # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale)))
1033
+ # shape
1034
+ q_len, num_q_heads, head_dim = q.shape
1035
+ k_len, num_k_heads, head_dim = k.shape
1036
+ batch_size = cu_seqlens_q.shape[0] - 1
1037
+ assert q_len > k_len
1038
+ if sm_scale is None:
1039
+ sm_scale = 1 / math.sqrt(head_dim)
1040
+ # gqa
1041
+ assert num_q_heads % num_k_heads == 0
1042
+ num_share_q_heads = num_q_heads // num_k_heads
1043
+ # init score
1044
+ score = torch.zeros(
1045
+ num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device
1046
+ )
1047
+ # launch kernel
1048
+ grid = lambda META: (
1049
+ batch_size * num_k_heads,
1050
+ triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
1051
+ triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
1052
+ )
1053
+ BLOCK_SIZE_Q = 128
1054
+ BLOCK_SIZE_K = 128
1055
+ BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
1056
+ score_kernel[grid](
1057
+ q,
1058
+ k,
1059
+ lse,
1060
+ score,
1061
+ kernel_size,
1062
+ kernel_stride,
1063
+ cu_seqlens_q,
1064
+ cu_seqlens_k,
1065
+ num_k_heads,
1066
+ num_share_q_heads,
1067
+ head_dim,
1068
+ sm_scale,
1069
+ q.stride(0),
1070
+ q.stride(1),
1071
+ q.stride(2),
1072
+ k.stride(0),
1073
+ k.stride(1),
1074
+ k.stride(2),
1075
+ lse.stride(0),
1076
+ lse.stride(1),
1077
+ score.stride(0),
1078
+ score.stride(1),
1079
+ score.stride(2),
1080
+ BLOCK_SIZE_Q=BLOCK_SIZE_Q,
1081
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
1082
+ BLOCK_SIZE_D=BLOCK_SIZE_D,
1083
+ num_warps=8,
1084
+ num_stages=3,
1085
+ )
1086
+ return score
1087
+
1088
+
1089
+ @triton.jit
1090
+ def _transform_score_kernel(
1091
+ s_ptr, # score, shape: [num_heads, q_len, k_len]
1092
+ bs_ptr, # block wise score: [num_heads, q_len, num_k_block]
1093
+ offs,
1094
+ cu_seqlens_q,
1095
+ # shape
1096
+ num_heads,
1097
+ num_offs,
1098
+ max_k_len,
1099
+ max_blocks,
1100
+ pad_len,
1101
+ # kernel & block size
1102
+ block_size,
1103
+ block_stride, # block_size // kernel_stride
1104
+ init_blocks,
1105
+ local_blocks,
1106
+ # stride
1107
+ stride_sh,
1108
+ stride_sq,
1109
+ stride_sk,
1110
+ stride_bsh,
1111
+ stride_bsq,
1112
+ stride_bsk,
1113
+ BLOCK_SIZE_Q: tl.constexpr,
1114
+ BLOCK_SIZE_K: tl.constexpr,
1115
+ BLOCK_SIZE_O: tl.constexpr,
1116
+ ):
1117
+ pid_bh = tl.program_id(0)
1118
+ pid_b = pid_bh // num_heads
1119
+ pid_h = pid_bh % num_heads
1120
+ pid_q = tl.program_id(1)
1121
+ pid_k = tl.program_id(2)
1122
+ q_start = tl.load(cu_seqlens_q + pid_b)
1123
+ q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
1124
+ k_start = pid_k * BLOCK_SIZE_K
1125
+ if pid_q * BLOCK_SIZE_Q >= q_len:
1126
+ return
1127
+ # load weight
1128
+ off_o = tl.arange(0, BLOCK_SIZE_O)
1129
+ w = tl.load(offs + off_o, mask=off_o < num_offs, other=0)
1130
+ # load score
1131
+ off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
1132
+ off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len
1133
+ off_k = off_k[None, :] + off_o[:, None]
1134
+ s_ptrs = (
1135
+ s_ptr
1136
+ + q_start * stride_sq
1137
+ + pid_h * stride_sh
1138
+ + off_q[:, None, None] * stride_sq
1139
+ + off_k[None, :, :] * stride_sk
1140
+ )
1141
+ # weighted sum, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK]
1142
+ s = tl.load(
1143
+ s_ptrs,
1144
+ mask=(off_q < q_len)[:, None, None] & (off_k >= 0) & (off_k < max_k_len),
1145
+ other=0,
1146
+ )
1147
+ s = s * w[None, :, None]
1148
+ s = tl.max(s, axis=1)
1149
+ # init mask and local mask
1150
+ off_bq = off_q // block_size
1151
+ off_bk = tl.arange(0, BLOCK_SIZE_K)
1152
+
1153
+ s = tl.where(
1154
+ # For local blocks: set to negative infinity (exclude from topk)
1155
+ (off_bq[:, None] >= (off_bk + k_start)[None, :]) & (off_bq[:, None] < (off_bk + k_start)[None, :] + local_blocks),
1156
+ float("-inf"),
1157
+ s,
1158
+ )
1159
+
1160
+ # Keep the original conditions for init_blocks and query location as infinity
1161
+ s = tl.where(
1162
+ (off_bk[None, :] < init_blocks - k_start)
1163
+ # Force blocks where the query is located to have infinite score (always include in topk)
1164
+ | (off_bq[:, None] == (off_bk + k_start)[None, :]),
1165
+ float("inf"),
1166
+ s,
1167
+ )
1168
+ # store block wise score
1169
+ bs_ptrs = (
1170
+ bs_ptr
1171
+ + q_start * stride_bsq
1172
+ + k_start * stride_bsk
1173
+ + pid_h * stride_bsh
1174
+ + off_q[:, None] * stride_bsq
1175
+ + off_bk[None, :] * stride_bsk
1176
+ )
1177
+ tl.store(
1178
+ bs_ptrs,
1179
+ s,
1180
+ mask=(off_q < q_len)[:, None] & (off_bk < max_blocks - k_start)[None, :],
1181
+ )
1182
+
1183
+
1184
+ def transform_score(
1185
+ score: torch.Tensor,
1186
+ kernel_size: int,
1187
+ kernel_stride: int,
1188
+ block_size: int,
1189
+ cu_seqlens_q: torch.Tensor,
1190
+ cu_seqlens_k: torch.Tensor,
1191
+ max_seqlen_q: int,
1192
+ max_seqlen_k: int,
1193
+ init_blocks: int = 1,
1194
+ local_blocks: int = 2,
1195
+ ) -> torch.Tensor:
1196
+ num_k_heads, total_query_len, max_key_len = score.shape
1197
+ batch_size = cu_seqlens_q.shape[0] - 1
1198
+ pad_len = kernel_size // kernel_stride - 1
1199
+ max_blocks = math.ceil(max_seqlen_q / block_size)
1200
+ block_score = torch.zeros(
1201
+ num_k_heads,
1202
+ total_query_len,
1203
+ max_blocks,
1204
+ dtype=torch.float32,
1205
+ device=score.device,
1206
+ )
1207
+ offs = (
1208
+ torch.arange(kernel_size // kernel_stride, device=score.device)[:, None]
1209
+ + torch.arange(block_size // kernel_stride, device=score.device)[None, :]
1210
+ ).view(-1)
1211
+ offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max())
1212
+ num_offs = int(offs.shape[0])
1213
+ BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks))
1214
+ BLOCK_SIZE_O = triton.next_power_of_2(num_offs)
1215
+ BLOCK_SIZE_Q = 8
1216
+ grid = (
1217
+ num_k_heads * batch_size,
1218
+ triton.cdiv(total_query_len, BLOCK_SIZE_Q),
1219
+ triton.cdiv(max_blocks, BLOCK_SIZE_K),
1220
+ )
1221
+ _transform_score_kernel[grid](
1222
+ score,
1223
+ block_score,
1224
+ torch.ones_like(offs, dtype=offs.dtype, device=offs.device),  #! we take the max, so the weights are not needed
1225
+ cu_seqlens_q,
1226
+ num_k_heads,
1227
+ offs.shape[0],
1228
+ max_key_len,
1229
+ max_blocks,
1230
+ pad_len,
1231
+ block_size,
1232
+ block_size // kernel_stride,
1233
+ init_blocks,
1234
+ local_blocks,
1235
+ score.stride(0),
1236
+ score.stride(1),
1237
+ score.stride(2),
1238
+ block_score.stride(0),
1239
+ block_score.stride(1),
1240
+ block_score.stride(2),
1241
+ BLOCK_SIZE_Q=BLOCK_SIZE_Q,
1242
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
1243
+ BLOCK_SIZE_O=BLOCK_SIZE_O,
1244
+ num_warps=8,
1245
+ num_stages=3,
1246
+ )
1247
+ return block_score
1248
+
1249
+
1250
+ def compressed_attention(
1251
+ q: torch.Tensor,
1252
+ k: torch.Tensor,
1253
+ v: torch.Tensor,
1254
+ kernel_size: int,
1255
+ kernel_stride: int,
1256
+ block_size: int,
1257
+ topk: int,
1258
+ cu_seqlens_q: torch.Tensor,
1259
+ cu_seqlens_k: torch.Tensor,
1260
+ max_seqlen_q: int,
1261
+ max_seqlen_k: int,
1262
+ sm_scale: float = None,
1263
+ init_blocks: int = 1,
1264
+ local_blocks: int = 2,
1265
+ parallel_topk_compute: Union[str, bool] = "auto",
1266
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1267
+ """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention.
1268
+
1269
+ Args:
1270
+ q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
1271
+ k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
1272
+ v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
1273
+ kernel_size (int): kernel size in compress_key_value
1274
+ kernel_stride (int): stride of compress_key_value
1275
+ block_size (int): key value block size for topk sparse attention.
1276
+ topk (int): number of blocks for each query.
1277
+ cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
1278
+ cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
1279
+ max_seqlen_q (int): max q len of the batch.
1280
+ max_seqlen_k (int): max k len of the batch.
1281
+ sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
1282
+ init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
1283
+ local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
1284
+ parallel_topk_compute (str, optional): Only set it to False when the sequence length is very long; this avoids a current bug.
1285
+ We'll fix this issue later. Defaults to "auto", which resolves to False when the sequence length is greater than 32k and to True otherwise.
1286
+
1287
+ Returns:
1288
+ Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
1289
+ """
1290
+ if max_seqlen_q is None:
1291
+ max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
1292
+ if max_seqlen_k is None:
1293
+ max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()
1294
+ attn_output, lse = CompressedAttention.apply(
1295
+ q,
1296
+ k,
1297
+ v,
1298
+ kernel_size,
1299
+ kernel_stride,
1300
+ cu_seqlens_q,
1301
+ cu_seqlens_k,
1302
+ max_seqlen_q,
1303
+ max_seqlen_k,
1304
+ sm_scale,
1305
+ )
1306
+
1307
+ # do not select topk index
1308
+ if topk <= 0:
1309
+ warnings.warn("topk <= 0, returned topk_idx will be None")
1310
+ return attn_output, None
1311
+
1312
+ assert topk >= init_blocks #+ local_blocks
1313
+ with torch.no_grad():
1314
+ num_k_heads, num_q_heads = k.shape[1], q.shape[1]
1315
+ num_shared_q_heads = num_q_heads // num_k_heads
1316
+ batch_size = cu_seqlens_q.shape[0] - 1
1317
+ q_idx = torch.cat(
1318
+ [
1319
+ torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device)
1320
+ for i in range(batch_size)
1321
+ ],
1322
+ dim=0,
1323
+ )
1324
+ q_idx = q_idx // block_size
1325
+ # whether to use parallel version
1326
+ if parallel_topk_compute == "auto":
1327
+ parallel_topk_compute = cu_seqlens_q[-1] <= 32768
1328
+ # parallel version
1329
+ if parallel_topk_compute:
1330
+ # recompute score
1331
+ score = _get_attention_score(
1332
+ q,
1333
+ k,
1334
+ lse,
1335
+ kernel_size,
1336
+ kernel_stride,
1337
+ cu_seqlens_q,
1338
+ cu_seqlens_k,
1339
+ max_seqlen_q,
1340
+ max_seqlen_k,
1341
+ sm_scale,
1342
+ )
1343
+ # transform score to block-wise score
1344
+ score = transform_score(
1345
+ score,
1346
+ kernel_size,
1347
+ kernel_stride,
1348
+ block_size,
1349
+ cu_seqlens_q,
1350
+ cu_seqlens_k,
1351
+ max_seqlen_q,
1352
+ max_seqlen_k,
1353
+ init_blocks,
1354
+ local_blocks,
1355
+ )
1356
+ # get topk
1357
+ topk = min(topk, score.shape[-1])
1358
+ topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
1359
+ # print(cu_seqlens_q)
1360
+ # breakpoint()
1361
+ topk_idx[topk_idx >= q_idx[None, :, None]] = -1
1362
+ topk_idx = topk_idx.to(torch.int32)
1363
+ # non-parallel version, avoids some current bugs when the sequence length is very long
1364
+ # FIXME: need to fix later
1365
+ else:
1366
+ topk_idx_list = []
1367
+ for h in range(num_k_heads):
1368
+ # recompute score
1369
+ score = _get_attention_score(
1370
+ q[:, h * num_shared_q_heads : (h + 1) * num_shared_q_heads],
1371
+ k[:, h : h + 1],
1372
+ lse[h * num_shared_q_heads : (h + 1) * num_shared_q_heads],
1373
+ kernel_size,
1374
+ kernel_stride,
1375
+ cu_seqlens_q,
1376
+ cu_seqlens_k,
1377
+ max_seqlen_q,
1378
+ max_seqlen_k,
1379
+ sm_scale,
1380
+ )
1381
+ # transform score to block-wise score
1382
+ score = transform_score(
1383
+ score,
1384
+ kernel_size,
1385
+ kernel_stride,
1386
+ block_size,
1387
+ cu_seqlens_q,
1388
+ cu_seqlens_k,
1389
+ max_seqlen_q,
1390
+ max_seqlen_k,
1391
+ init_blocks,
1392
+ local_blocks,
1393
+ )
1394
+ # get topk
1395
+ topk = min(topk, score.shape[-1])
1396
+ topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
1397
+ topk_idx[topk_idx >= q_idx[None, :, None]] = -1
1398
+ topk_idx = topk_idx.to(torch.int32)
1399
+ topk_idx_list.append(topk_idx)
1400
+ topk_idx = torch.cat(topk_idx_list, dim=0)
1401
+ return attn_output, topk_idx
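A hedged usage sketch for compressed_attention follows. It assumes a CUDA GPU with Triton installed, uses bfloat16 varlen tensors as required by the dtype asserts above, and stands in a simple mean-pooling loop for the compress_key_value module mentioned in the docstring (which is not part of this file); the chosen sizes (kernel_size=32, kernel_stride=16, block_size=64, topk=16) are illustrative only.

import torch
from compressed_attention import compressed_attention, get_compressed_seqlens

torch.manual_seed(0)
device = "cuda"
num_q_heads, num_kv_heads, head_dim = 32, 2, 128
kernel_size, kernel_stride, block_size, topk = 32, 16, 64, 16

# two packed sequences of length 1024 and 2048 (varlen layout, no padding)
seqlens = torch.tensor([1024, 2048], dtype=torch.int32, device=device)
cu_seqlens_q = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device=device)
cu_seqlens_q[1:] = seqlens.cumsum(0)
total_len = int(cu_seqlens_q[-1])

q = torch.randn(total_len, num_q_heads, head_dim, dtype=torch.bfloat16, device=device)
k = torch.randn(total_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device)
v = torch.randn(total_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device)

# stand-in compression: mean-pool each kernel_size window with stride kernel_stride
y_seqlens, cu_seqlens_k = get_compressed_seqlens(cu_seqlens_q, kernel_size, kernel_stride)
ck, cv = [], []
for b in range(len(seqlens)):
    start = int(cu_seqlens_q[b])
    for i in range(int(y_seqlens[b])):
        s = start + i * kernel_stride
        ck.append(k[s : s + kernel_size].mean(0))
        cv.append(v[s : s + kernel_size].mean(0))
ck, cv = torch.stack(ck), torch.stack(cv)

attn_out, topk_idx = compressed_attention(
    q, ck, cv, kernel_size, kernel_stride, block_size, topk,
    cu_seqlens_q, cu_seqlens_k,
    int(seqlens.max()), int(y_seqlens.max()),
)
print(attn_out.shape)  # [total_len, num_q_heads, head_dim]
print(topk_idx.shape)  # [num_kv_heads, total_len, topk]

The returned topk_idx holds block indices per compressed-key head, with -1 marking unused slots, and is what the docstring says should be passed on to topk_sparse_attention.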
config.json ADDED
@@ -0,0 +1,177 @@
1
+ {
2
+ "_name_or_path": "/share_data/data7/fanshengda/mcp-agent/minicpm4_sft/mcp_summary/checkpoint-25000",
3
+ "architectures": [
4
+ "MiniCPMForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_minicpm.MiniCPMConfig",
10
+ "AutoModel": "modeling_minicpm.MiniCPMForCausalLM",
11
+ "AutoModelForCausalLM": "modeling_minicpm.MiniCPMForCausalLM",
12
+ "AutoModelForSeq2SeqLM": "modeling_minicpm.MiniCPMForCausalLM",
13
+ "AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification"
14
+ },
15
+ "bos_token_id": 1,
16
+ "dim_model_base": 256,
17
+ "eos_token_id": [
18
+ 2,
19
+ 73440
20
+ ],
21
+ "hidden_act": "silu",
22
+ "hidden_size": 4096,
23
+ "initializer_range": 0.1,
24
+ "intermediate_size": 16384,
25
+ "max_position_embeddings": 32768,
26
+ "model_type": "minicpm",
27
+ "num_attention_heads": 32,
28
+ "num_hidden_layers": 32,
29
+ "num_key_value_heads": 2,
30
+ "pad_token_id": 2,
31
+ "pretraining_tp": 1,
32
+ "rms_norm_eps": 1e-06,
33
+ "rope_scaling": {
34
+ "long_factor": [
35
+ 0.9977997200264581,
36
+ 1.014658295992452,
37
+ 1.0349680404997148,
38
+ 1.059429246056193,
39
+ 1.0888815016813513,
40
+ 1.1243301355211495,
41
+ 1.166977103606075,
42
+ 1.2182568066927284,
43
+ 1.2798772354275727,
44
+ 1.3538666751582975,
45
+ 1.4426259039919596,
46
+ 1.5489853358570191,
47
+ 1.6762658237220625,
48
+ 1.8283407612492941,
49
+ 2.0096956085876183,
50
+ 2.225478927469756,
51
+ 2.481536379650452,
52
+ 2.784415934557119,
53
+ 3.1413289096347365,
54
+ 3.560047844772632,
55
+ 4.048719380066383,
56
+ 4.615569542115128,
57
+ 5.2684819496549835,
58
+ 6.014438591970396,
59
+ 6.858830049237097,
60
+ 7.804668263503327,
61
+ 8.851768731513417,
62
+ 9.99600492938444,
63
+ 11.228766118181639,
64
+ 12.536757560834843,
65
+ 13.902257701387796,
66
+ 15.303885189125953,
67
+ 16.717837610115794,
68
+ 18.119465097853947,
69
+ 19.484965238406907,
70
+ 20.792956681060105,
71
+ 22.02571786985731,
72
+ 23.16995406772833,
73
+ 24.217054535738416,
74
+ 25.16289275000465,
75
+ 26.007284207271347,
76
+ 26.753240849586767,
77
+ 27.40615325712662,
78
+ 27.973003419175363,
79
+ 28.461674954469114,
80
+ 28.880393889607006,
81
+ 29.237306864684626,
82
+ 29.540186419591297,
83
+ 29.79624387177199,
84
+ 30.01202719065413,
85
+ 30.193382037992453,
86
+ 30.34545697551969,
87
+ 30.47273746338473,
88
+ 30.579096895249787,
89
+ 30.66785612408345,
90
+ 30.741845563814174,
91
+ 30.80346599254902,
92
+ 30.85474569563567,
93
+ 30.897392663720595,
94
+ 30.932841297560394,
95
+ 30.962293553185553,
96
+ 30.986754758742034,
97
+ 31.007064503249293,
98
+ 31.02392307921529
99
+ ],
100
+ "original_max_position_embeddings": 32768,
101
+ "rope_type": "longrope",
102
+ "short_factor": [
103
+ 0.9977997200264581,
104
+ 1.014658295992452,
105
+ 1.0349680404997148,
106
+ 1.059429246056193,
107
+ 1.0888815016813513,
108
+ 1.1243301355211495,
109
+ 1.166977103606075,
110
+ 1.2182568066927284,
111
+ 1.2798772354275727,
112
+ 1.3538666751582975,
113
+ 1.4426259039919596,
114
+ 1.5489853358570191,
115
+ 1.6762658237220625,
116
+ 1.8283407612492941,
117
+ 2.0096956085876183,
118
+ 2.225478927469756,
119
+ 2.481536379650452,
120
+ 2.784415934557119,
121
+ 3.1413289096347365,
122
+ 3.560047844772632,
123
+ 4.048719380066383,
124
+ 4.615569542115128,
125
+ 5.2684819496549835,
126
+ 6.014438591970396,
127
+ 6.858830049237097,
128
+ 7.804668263503327,
129
+ 8.851768731513417,
130
+ 9.99600492938444,
131
+ 11.228766118181639,
132
+ 12.536757560834843,
133
+ 13.902257701387796,
134
+ 15.303885189125953,
135
+ 16.717837610115794,
136
+ 18.119465097853947,
137
+ 19.484965238406907,
138
+ 20.792956681060105,
139
+ 22.02571786985731,
140
+ 23.16995406772833,
141
+ 24.217054535738416,
142
+ 25.16289275000465,
143
+ 26.007284207271347,
144
+ 26.753240849586767,
145
+ 27.40615325712662,
146
+ 27.973003419175363,
147
+ 28.461674954469114,
148
+ 28.880393889607006,
149
+ 29.237306864684626,
150
+ 29.540186419591297,
151
+ 29.79624387177199,
152
+ 30.01202719065413,
153
+ 30.193382037992453,
154
+ 30.34545697551969,
155
+ 30.47273746338473,
156
+ 30.579096895249787,
157
+ 30.66785612408345,
158
+ 30.741845563814174,
159
+ 30.80346599254902,
160
+ 30.85474569563567,
161
+ 30.897392663720595,
162
+ 30.932841297560394,
163
+ 30.962293553185553,
164
+ 30.986754758742034,
165
+ 31.007064503249293,
166
+ 31.02392307921529
167
+ ]
168
+ },
169
+ "rope_theta": 10000.0,
170
+ "scale_depth": 1.4,
171
+ "scale_emb": 12,
172
+ "tie_word_embeddings": false,
173
+ "torch_dtype": "bfloat16",
174
+ "transformers_version": "4.49.0",
175
+ "use_cache": true,
176
+ "vocab_size": 73448
177
+ }
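Because of the auto_map entries, this config (and the model) load through the remote-code path. A minimal sketch, assuming the uploaded folder is available locally at ".":

from transformers import AutoConfig, AutoModelForCausalLM

# "." is a placeholder for the local copy of this repo.
config = AutoConfig.from_pretrained(".", trust_remote_code=True)
print(config.model_type, config.num_key_value_heads, config.vocab_size)  # minicpm 2 73448

model = AutoModelForCausalLM.from_pretrained(".", trust_remote_code=True, torch_dtype="auto")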
configuration_minicpm.py ADDED
@@ -0,0 +1,202 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ MiniCPM model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ MINICPM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
29
+
30
+
31
+ class MiniCPMConfig(PretrainedConfig):
32
+ r"""
33
+ This is the configuration class to store the configuration of a [`MiniCPMModel`]. It is used to instantiate an MiniCPM
34
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
35
+ defaults will yield a similar configuration to that of the MiniCPM-7B.
36
+
37
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
+ documentation from [`PretrainedConfig`] for more information.
39
+
40
+
41
+ Args:
42
+ vocab_size (`int`, *optional*, defaults to 32000):
43
+ Vocabulary size of the MiniCPM model. Defines the number of different tokens that can be represented by the
44
+ `inputs_ids` passed when calling [`MiniCPMModel`]
45
+ hidden_size (`int`, *optional*, defaults to 4096):
46
+ Dimension of the hidden representations.
47
+ intermediate_size (`int`, *optional*, defaults to 11008):
48
+ Dimension of the MLP representations.
49
+ num_hidden_layers (`int`, *optional*, defaults to 32):
50
+ Number of hidden layers in the Transformer decoder.
51
+ num_attention_heads (`int`, *optional*, defaults to 32):
52
+ Number of attention heads for each attention layer in the Transformer decoder.
53
+ num_key_value_heads (`int`, *optional*):
54
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
55
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
56
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
57
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
58
+ by meanpooling all the original heads within that group. For more details checkout [this
59
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
60
+ `num_attention_heads`.
61
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
62
+ The non-linear activation function (function or string) in the decoder.
63
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
64
+ The maximum sequence length that this model might ever be used with. MiniCPM 1 supports up to 2048 tokens,
65
+ MiniCPM 2 up to 4096, CodeMiniCPM up to 16384.
66
+ initializer_range (`float`, *optional*, defaults to 0.02):
67
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
69
+ The epsilon used by the rms normalization layers.
70
+ use_cache (`bool`, *optional*, defaults to `True`):
71
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
72
+ relevant if `config.is_decoder=True`.
73
+ pad_token_id (`int`, *optional*):
74
+ Padding token id.
75
+ bos_token_id (`int`, *optional*, defaults to 1):
76
+ Beginning of stream token id.
77
+ eos_token_id (`int`, *optional*, defaults to 2):
78
+ End of stream token id.
79
+ pretraining_tp (`int`, *optional*, defaults to 1):
80
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
81
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
82
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
83
+ issue](https://github.com/pytorch/pytorch/issues/76232).
84
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
85
+ Whether to tie weight embeddings
86
+ rope_theta (`float`, *optional*, defaults to 10000.0):
87
+ The base period of the RoPE embeddings.
88
+ rope_scaling (`Dict`, *optional*):
89
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
90
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
91
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
92
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
93
+ these scaling strategies behave:
94
+ https://www.reddit.com/r/LocalMiniCPM/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
95
+ experimental feature, subject to breaking API changes in future versions.
96
+ attention_bias (`bool`, *optional*, defaults to `False`):
97
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
98
+ attention_dropout (`float`, *optional*, defaults to 0.0):
99
+ The dropout ratio for the attention probabilities.
100
+
101
+ ```python
102
+ >>> from transformers import MiniCPMModel, MiniCPMConfig
103
+
104
+ >>> # Initializing a MiniCPM minicpm-7b style configuration
105
+ >>> configuration = MiniCPMConfig()
106
+
107
+ >>> # Initializing a model from the minicpm-7b style configuration
108
+ >>> model = MiniCPMModel(configuration)
109
+
110
+ >>> # Accessing the model configuration
111
+ >>> configuration = model.config
112
+ ```"""
113
+
114
+ model_type = "minicpm"
115
+ keys_to_ignore_at_inference = ["past_key_values"]
116
+
117
+ def __init__(
118
+ self,
119
+ vocab_size=32000,
120
+ hidden_size=4096,
121
+ intermediate_size=11008,
122
+ num_hidden_layers=32,
123
+ num_attention_heads=32,
124
+ num_key_value_heads=None,
125
+ hidden_act="silu",
126
+ max_position_embeddings=2048,
127
+ initializer_range=0.02,
128
+ rms_norm_eps=1e-6,
129
+ use_cache=True,
130
+ pad_token_id=None,
131
+ bos_token_id=1,
132
+ eos_token_id=2,
133
+ pretraining_tp=1,
134
+ tie_word_embeddings=True,
135
+ rope_theta=10000.0,
136
+ rope_scaling=None,
137
+ attention_bias=False,
138
+ attention_dropout=0.0,
139
+ scale_emb=1,
140
+ dim_model_base=1,
141
+ scale_depth=1,
142
+ **kwargs,
143
+ ):
144
+ self.vocab_size = vocab_size
145
+ self.max_position_embeddings = max_position_embeddings
146
+ self.hidden_size = hidden_size
147
+ self.intermediate_size = intermediate_size
148
+ self.num_hidden_layers = num_hidden_layers
149
+ self.num_attention_heads = num_attention_heads
150
+
151
+ # for backward compatibility
152
+ if num_key_value_heads is None:
153
+ num_key_value_heads = num_attention_heads
154
+
155
+ self.num_key_value_heads = num_key_value_heads
156
+ self.hidden_act = hidden_act
157
+ self.initializer_range = initializer_range
158
+ self.rms_norm_eps = rms_norm_eps
159
+ self.pretraining_tp = pretraining_tp
160
+ self.use_cache = use_cache
161
+ self.rope_theta = rope_theta
162
+ self.rope_scaling = rope_scaling
163
+ # self._rope_scaling_validation()
164
+ self.attention_bias = attention_bias
165
+ self.attention_dropout = attention_dropout
166
+ self.scale_emb = scale_emb
167
+ self.dim_model_base = dim_model_base
168
+ self.scale_depth = scale_depth
169
+
170
+ super().__init__(
171
+ pad_token_id=pad_token_id,
172
+ bos_token_id=bos_token_id,
173
+ eos_token_id=eos_token_id,
174
+ tie_word_embeddings=tie_word_embeddings,
175
+ **kwargs,
176
+ )
177
+ try:
178
+ import flash_attn
179
+ self._attn_implementation = "flash_attention_2"
180
+ except ImportError:  # flash_attn not installed; keep the default attention implementation
181
+ pass
182
+
183
+ def _rope_scaling_validation(self):
184
+ """
185
+ Validate the `rope_scaling` configuration.
186
+ """
187
+ if self.rope_scaling is None:
188
+ return
189
+
190
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
191
+ raise ValueError(
192
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
193
+ f"got {self.rope_scaling}"
194
+ )
195
+ rope_scaling_type = self.rope_scaling.get("type", None)
196
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
197
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
198
+ raise ValueError(
199
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
200
+ )
201
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
202
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
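As a quick orientation (not part of the uploaded file), the configuration class above is normally consumed through `AutoConfig` with `trust_remote_code=True`, since `MiniCPMConfig` ships inside the repository rather than inside `transformers`. A minimal sketch, assuming a placeholder repo id:

```python
from transformers import AutoConfig

# The repo id below is a placeholder for wherever this checkpoint is published.
config = AutoConfig.from_pretrained(
    "your-namespace/minicpm-checkpoint",
    trust_remote_code=True,  # needed because MiniCPMConfig is custom code in the repo
)
# MiniCPM-specific knobs defined in __init__ above
print(config.scale_emb, config.dim_model_base, config.scale_depth)
```

Note that the constructor also switches `_attn_implementation` to `"flash_attention_2"` whenever `flash_attn` is importable, so the attention backend can differ between environments with and without that package installed.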
generation_config.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 2,
6
+ 73440
7
+ ],
8
+ "pad_token_id": 2,
9
+ "temperature": 0.8,
10
+ "top_p": 0.8,
11
+ "transformers_version": "4.49.0"
12
+ }
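These defaults are what `model.generate` picks up when no explicit `GenerationConfig` is passed: nucleus sampling with `temperature=0.8` and `top_p=0.8`, and generation stopping on either `</s>` (id 2) or `<|im_end|>` (id 73440). A minimal sketch of overriding them at call time (the `model` and `inputs` objects are assumed to exist already):

```python
from transformers import GenerationConfig

# Mirrors generation_config.json, but with a lower temperature; purely illustrative.
gen_cfg = GenerationConfig(
    bos_token_id=1,
    eos_token_id=[2, 73440],  # stop on </s> or <|im_end|>
    pad_token_id=2,
    do_sample=True,
    temperature=0.5,
    top_p=0.8,
)
# outputs = model.generate(**inputs, generation_config=gen_cfg)
```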
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8889121190f1249d223ba24a5cbf50c959a6c4cc2456d5cf3448e5463ddce882
3
+ size 3990806216
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e513ad5fb6f8180cceb278d5ac2ad3d677694cbf84b725187e047774677c50a1
3
+ size 3926008088
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:059c3fa22026b886ad5169ed954f7e336320b605ab8973a6ea16308f7a61a597
3
+ size 3926008120
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:548dc62dd17b6ae7b1d328a35a8d7ecb89a20e5dcbfc83d20c060d1cd45ddeba
3
+ size 3926033016
model-00005-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbc464a995b4d5ad7188c965862202621af0b5c25e67d7078c220dc4a1e46d06
3
+ size 601686144
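The five `model-*.safetensors` entries above are Git LFS pointer files: only the `oid sha256:` digest and the byte `size` are stored in the commit, while the weight data itself lives in LFS storage. A small sketch for checking a downloaded shard against its pointer (file name, hash, and size taken from the pointer above):

```python
import hashlib
import os

def verify_shard(path: str, expected_sha256: str, expected_size: int) -> None:
    """Compare a downloaded shard against the oid/size recorded in its LFS pointer."""
    assert os.path.getsize(path) == expected_size, "size mismatch"
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    assert digest.hexdigest() == expected_sha256, "hash mismatch"

verify_shard(
    "model-00005-of-00005.safetensors",
    "cbc464a995b4d5ad7188c965862202621af0b5c25e67d7078c220dc4a1e46d06",
    601686144,
)
```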
model.safetensors.index.json ADDED
@@ -0,0 +1,298 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 16370507776
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00005-of-00005.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00005.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00005.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00005.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00005.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00005.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
13
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
14
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
15
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
16
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
17
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00005.safetensors",
18
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00005.safetensors",
19
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00005.safetensors",
20
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00005.safetensors",
21
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
22
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
23
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
24
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
25
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
26
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00005.safetensors",
27
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
28
+ "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
29
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
30
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
31
+ "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
32
+ "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
33
+ "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
34
+ "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
35
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00005.safetensors",
36
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
37
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
38
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
39
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
40
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
41
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
42
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
43
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
44
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00005.safetensors",
45
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
46
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
47
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
48
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
49
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
50
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
51
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
52
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
53
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00005.safetensors",
54
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
55
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
56
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
57
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
58
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
59
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
60
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
61
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
62
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00005.safetensors",
63
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
64
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
65
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
66
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
67
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
68
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
69
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
70
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
71
+ "model.layers.15.input_layernorm.weight": "model-00003-of-00005.safetensors",
72
+ "model.layers.15.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
73
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
74
+ "model.layers.15.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
75
+ "model.layers.15.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
76
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
77
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
78
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
79
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
80
+ "model.layers.16.input_layernorm.weight": "model-00003-of-00005.safetensors",
81
+ "model.layers.16.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
82
+ "model.layers.16.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
83
+ "model.layers.16.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
84
+ "model.layers.16.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
85
+ "model.layers.16.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
86
+ "model.layers.16.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
87
+ "model.layers.16.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
88
+ "model.layers.16.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
89
+ "model.layers.17.input_layernorm.weight": "model-00003-of-00005.safetensors",
90
+ "model.layers.17.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
91
+ "model.layers.17.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
92
+ "model.layers.17.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
93
+ "model.layers.17.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
94
+ "model.layers.17.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
95
+ "model.layers.17.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
96
+ "model.layers.17.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
97
+ "model.layers.17.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
98
+ "model.layers.18.input_layernorm.weight": "model-00003-of-00005.safetensors",
99
+ "model.layers.18.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
100
+ "model.layers.18.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
101
+ "model.layers.18.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
102
+ "model.layers.18.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
103
+ "model.layers.18.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
104
+ "model.layers.18.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
105
+ "model.layers.18.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
106
+ "model.layers.18.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
107
+ "model.layers.19.input_layernorm.weight": "model-00003-of-00005.safetensors",
108
+ "model.layers.19.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
109
+ "model.layers.19.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
110
+ "model.layers.19.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
111
+ "model.layers.19.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
112
+ "model.layers.19.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
113
+ "model.layers.19.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
114
+ "model.layers.19.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
115
+ "model.layers.19.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
116
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00005.safetensors",
117
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00005.safetensors",
118
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00005.safetensors",
119
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00005.safetensors",
120
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
121
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
122
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
123
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
124
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
125
+ "model.layers.20.input_layernorm.weight": "model-00003-of-00005.safetensors",
126
+ "model.layers.20.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
127
+ "model.layers.20.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
128
+ "model.layers.20.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
129
+ "model.layers.20.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
130
+ "model.layers.20.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
131
+ "model.layers.20.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
132
+ "model.layers.20.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
133
+ "model.layers.20.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
134
+ "model.layers.21.input_layernorm.weight": "model-00003-of-00005.safetensors",
135
+ "model.layers.21.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
136
+ "model.layers.21.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
137
+ "model.layers.21.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
138
+ "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
139
+ "model.layers.21.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
140
+ "model.layers.21.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
141
+ "model.layers.21.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
142
+ "model.layers.21.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
143
+ "model.layers.22.input_layernorm.weight": "model-00003-of-00005.safetensors",
144
+ "model.layers.22.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
145
+ "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
146
+ "model.layers.22.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
147
+ "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
148
+ "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
149
+ "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
150
+ "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
151
+ "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
152
+ "model.layers.23.input_layernorm.weight": "model-00004-of-00005.safetensors",
153
+ "model.layers.23.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
154
+ "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
155
+ "model.layers.23.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
156
+ "model.layers.23.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
157
+ "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
158
+ "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
159
+ "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
160
+ "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
161
+ "model.layers.24.input_layernorm.weight": "model-00004-of-00005.safetensors",
162
+ "model.layers.24.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
163
+ "model.layers.24.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
164
+ "model.layers.24.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
165
+ "model.layers.24.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
166
+ "model.layers.24.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
167
+ "model.layers.24.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
168
+ "model.layers.24.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
169
+ "model.layers.24.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
170
+ "model.layers.25.input_layernorm.weight": "model-00004-of-00005.safetensors",
171
+ "model.layers.25.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
172
+ "model.layers.25.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
173
+ "model.layers.25.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
174
+ "model.layers.25.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
175
+ "model.layers.25.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
176
+ "model.layers.25.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
177
+ "model.layers.25.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
178
+ "model.layers.25.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
179
+ "model.layers.26.input_layernorm.weight": "model-00004-of-00005.safetensors",
180
+ "model.layers.26.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
181
+ "model.layers.26.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
182
+ "model.layers.26.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
183
+ "model.layers.26.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
184
+ "model.layers.26.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
185
+ "model.layers.26.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
186
+ "model.layers.26.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
187
+ "model.layers.26.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
188
+ "model.layers.27.input_layernorm.weight": "model-00004-of-00005.safetensors",
189
+ "model.layers.27.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
190
+ "model.layers.27.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
191
+ "model.layers.27.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
192
+ "model.layers.27.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
193
+ "model.layers.27.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
194
+ "model.layers.27.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
195
+ "model.layers.27.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
196
+ "model.layers.27.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
197
+ "model.layers.28.input_layernorm.weight": "model-00004-of-00005.safetensors",
198
+ "model.layers.28.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
199
+ "model.layers.28.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
200
+ "model.layers.28.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
201
+ "model.layers.28.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
202
+ "model.layers.28.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
203
+ "model.layers.28.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
204
+ "model.layers.28.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
205
+ "model.layers.28.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
206
+ "model.layers.29.input_layernorm.weight": "model-00004-of-00005.safetensors",
207
+ "model.layers.29.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
208
+ "model.layers.29.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
209
+ "model.layers.29.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
210
+ "model.layers.29.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
211
+ "model.layers.29.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
212
+ "model.layers.29.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
213
+ "model.layers.29.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
214
+ "model.layers.29.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
215
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00005.safetensors",
216
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00005.safetensors",
217
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00005.safetensors",
218
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00005.safetensors",
219
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
220
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
221
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
222
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
223
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
224
+ "model.layers.30.input_layernorm.weight": "model-00004-of-00005.safetensors",
225
+ "model.layers.30.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
226
+ "model.layers.30.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
227
+ "model.layers.30.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
228
+ "model.layers.30.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
229
+ "model.layers.30.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
230
+ "model.layers.30.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
231
+ "model.layers.30.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
232
+ "model.layers.30.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
233
+ "model.layers.31.input_layernorm.weight": "model-00004-of-00005.safetensors",
234
+ "model.layers.31.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
235
+ "model.layers.31.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
236
+ "model.layers.31.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
237
+ "model.layers.31.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
238
+ "model.layers.31.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
239
+ "model.layers.31.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
240
+ "model.layers.31.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
241
+ "model.layers.31.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
242
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00005.safetensors",
243
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00005.safetensors",
244
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00005.safetensors",
245
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00005.safetensors",
246
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
247
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
248
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
249
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
250
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
251
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00005.safetensors",
252
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00005.safetensors",
253
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00005.safetensors",
254
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00005.safetensors",
255
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
256
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
257
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
258
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
259
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
260
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00005.safetensors",
261
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00005.safetensors",
262
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00005.safetensors",
263
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00005.safetensors",
264
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
265
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
266
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
267
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
268
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
269
+ "model.layers.7.input_layernorm.weight": "model-00002-of-00005.safetensors",
270
+ "model.layers.7.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
271
+ "model.layers.7.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
272
+ "model.layers.7.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
273
+ "model.layers.7.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
274
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
275
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
276
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
277
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
278
+ "model.layers.8.input_layernorm.weight": "model-00002-of-00005.safetensors",
279
+ "model.layers.8.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
280
+ "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
281
+ "model.layers.8.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
282
+ "model.layers.8.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
283
+ "model.layers.8.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
284
+ "model.layers.8.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
285
+ "model.layers.8.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
286
+ "model.layers.8.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
287
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00005.safetensors",
288
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
289
+ "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
290
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
291
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
292
+ "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
293
+ "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
294
+ "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
295
+ "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
296
+ "model.norm.weight": "model-00004-of-00005.safetensors"
297
+ }
298
+ }
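`model.safetensors.index.json` maps every parameter name to the shard that contains it; the `total_size` of 16370507776 bytes (roughly 16.4 GB) is the sum of all five shards. Loaders use the `weight_map` to open only the shards they actually need. A sketch of pulling a single tensor out of the sharded checkpoint without materializing the whole model (assumes the shards sit in the working directory and that `safetensors` plus PyTorch are installed):

```python
import json
from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "model.layers.0.self_attn.q_proj.weight"
shard = index["weight_map"][name]  # -> "model-00001-of-00005.safetensors"
with safe_open(shard, framework="pt") as f:
    tensor = f.get_tensor(name)
print(tensor.shape, tensor.dtype)
```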
modeling_minicpm.py ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_end|>",
4
+ "<|im_start|>",
5
+ "<|tool_call|>",
6
+ "<|execute_start|>",
7
+ "<|execute_end|>",
8
+ "<|fim_prefix|>",
9
+ "<|fim_middle|>",
10
+ "<|fim_suffix|>"
11
+ ],
12
+ "bos_token": {
13
+ "content": "<s>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false
18
+ },
19
+ "eos_token": {
20
+ "content": "<|im_end|>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false
25
+ },
26
+ "pad_token": {
27
+ "content": "<|im_end|>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ },
33
+ "unk_token": {
34
+ "content": "<unk>",
35
+ "lstrip": false,
36
+ "normalized": false,
37
+ "rstrip": false,
38
+ "single_word": false
39
+ }
40
+ }
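Per the map above, the tokenizer reports `<|im_end|>` as both the EOS and padding token, keeps `<s>`/`<unk>` from the base SentencePiece model, and registers the chat, tool, and FIM markers as additional special tokens so they are never split. A quick sanity check, using a placeholder repo id:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("your-namespace/minicpm-checkpoint")  # placeholder
print(tok.eos_token, tok.eos_token_id)            # <|im_end|> 73440
print(tok.pad_token)                              # <|im_end|>
print(tok.convert_tokens_to_ids("<|im_start|>"))  # 73441
```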
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb74d51116831c3bf65db812c553f94ab0c88dcf97a5bbb37e3504f6d359c530
3
+ size 1181204
tokenizer_config.json ADDED
@@ -0,0 +1,120 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": null,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "73440": {
31
+ "content": "<|im_end|>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ },
38
+ "73441": {
39
+ "content": "<|im_start|>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": true
45
+ },
46
+ "73442": {
47
+ "content": "<|tool_call|>",
48
+ "lstrip": false,
49
+ "normalized": false,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": true
53
+ },
54
+ "73443": {
55
+ "content": "<|execute_start|>",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": true
61
+ },
62
+ "73444": {
63
+ "content": "<|execute_end|>",
64
+ "lstrip": false,
65
+ "normalized": false,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": true
69
+ },
70
+ "73445": {
71
+ "content": "<|fim_prefix|>",
72
+ "lstrip": false,
73
+ "normalized": false,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": true
77
+ },
78
+ "73446": {
79
+ "content": "<|fim_middle|>",
80
+ "lstrip": false,
81
+ "normalized": false,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": true
85
+ },
86
+ "73447": {
87
+ "content": "<|fim_suffix|>",
88
+ "lstrip": false,
89
+ "normalized": false,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": true
93
+ }
94
+ },
95
+ "additional_special_tokens": [
96
+ "<|im_end|>",
97
+ "<|im_start|>",
98
+ "<|tool_call|>",
99
+ "<|execute_start|>",
100
+ "<|execute_end|>",
101
+ "<|fim_prefix|>",
102
+ "<|fim_middle|>",
103
+ "<|fim_suffix|>"
104
+ ],
105
+ "bos_token": "<s>",
106
+ "chat_template": "{%- macro json_to_python_type(param_name, json_spec) %}\n{%- set basic_type_map = {\n 'string': 'str',\n 'number': 'float',\n 'integer': 'int',\n 'boolean': 'bool',\n 'null': 'None'\n} %}\n\n{%- if json_spec.enum %}\n {{- param_name|title }}\n{%- elif basic_type_map[json_spec.type] is defined %}\n {{- basic_type_map[json_spec.type] }}\n{%- elif json_spec.type == 'array' %}\n {{- 'List[' + json_to_python_type(param_name, json_spec['items']) + ']' }}\n{%- elif json_spec.type == 'object' %}\n {{- 'Dict[str, ' + json_to_python_type(param_name, json_spec.additionalProperties if json_spec.additionalProperties else 'Any') + ']' if not json_spec.properties else param_name|title }}\n{%- elif json_spec.type is iterable %}\n {{- 'Union[' }}\n {%- for t in json_spec.type %}\n {{- json_to_python_type(param_name, {'type': t}) }}\n {{- ', ' if not loop.last }}\n {%- endfor %}\n {{- ']' }}\n{%- else %}\n {{- 'Any' }}\n{%- endif %}\n{%- endmacro %}\n\n{%- macro object_to_fields(json_spec, field_indent) %}\n {%- set o_ns = namespace(f = caller()) %}\n {%- for param_name, param_fields in json_spec.properties|items %}\n {%- if param_fields.enum %}\n {{- '\\n\\nclass ' + param_name|title + '(Enum):\\n' }}\n {%- for enum_option in param_fields.enum %}\n {{- ' enum_' + loop.index0|string + ' = ' + enum_option|tojson + '\\n' }}\n {%- endfor %}\n {%- elif param_fields.type == 'object' and param_fields.properties %}\n {%- call object_to_fields(param_fields, ' ') %}\n {{- '\\n\\nclass ' + param_name|title + '(BaseModel):\\n' }}\n {%- endcall %}\n {%- elif param_fields.type == 'array' and param_fields['items'] and param_fields['items'].type == 'object' and param_fields['items'].properties %}\n {%- call object_to_fields(param_fields['items'], ' ') %}\n {{- '\\n\\nclass ' + param_name|title + '(BaseModel):\\n' }}\n {%- endcall %}\n {%- endif %}\n {%- set param_default = param_fields.default|tojson if param_fields.default is string else param_fields.default|string if param_fields.default is defined else 'None' %}\n {%- set o_ns.f = o_ns.f + field_indent + param_name + ': ' %}\n {%- set o_ns.f = o_ns.f + ('Optional[' + json_to_python_type(param_name, param_fields) + ']' if param_name not in json_spec.required else json_to_python_type(param_name, param_fields)) %}\n {%- if not param_fields.title and not param_fields.description and not param_fields.pattern %}\n {%- set o_ns.f = o_ns.f + (' = ' + param_default if param_name not in json_spec.required else '') %}\n {%- else %}\n {%- set o_ns.f = o_ns.f + (' = Field(...' 
if param_name in json_spec.required else ' = Field(' + param_default) %}\n {%- set o_ns.f = o_ns.f + (', description=' + param_fields.description|tojson if param_fields.description else '') %}\n {%- set o_ns.f = o_ns.f + (', regex=' + param_fields.pattern|tojson if param_fields.pattern else '') %}\n {%- set o_ns.f = o_ns.f + (', title=' + param_fields.title|tojson if param_fields.title else '') %}\n {%- set o_ns.f = o_ns.f + ')' %}\n {%- endif %}\n {%- set o_ns.f = o_ns.f + '\\n' %}\n {%- endfor %}\n {{- o_ns.f }}\n{%- endmacro %}\n\n{%- macro tool_parser(tools) %}\n{%- for tool in tools %}\n {%- if tool.type is not defined or tool.type == 'function' %}\n {%- if tool.function is defined %}\n {%- set tool = tool.function %}\n {%- endif %}\n {%- set tool_params = tool.parameters if tool.parameters is defined else none %}\n {%- call object_to_fields(tool_params, ' ') %}\n {{- '\\n\\ndef ' + tool.name + '(' }}\n {%- if tool_params %}\n {%- for param_name, param_fields in tool_params.properties|items %}\n {%- set param_default = param_fields.default|tojson if param_fields.default is string else param_fields.default|string if param_fields.default is defined else 'None' %}\n {{- ', ' if loop.index0 != 0 }}\n {{- param_name }}\n {{- '=' + param_default if param_name not in tool_params.required }}\n {%- endfor %}\n {%- endif %}\n {{- '):\\n \"\"\"' }}\n {{- tool.description }}\n {{- '\\n\\n Args:\\n' if tool_params else '\\n' }}\n {%- endcall %}\n {{- ' \"\"\"\\n' }}\n {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n\n{%- if messages[0]['role'] == 'system' %}\n {%- set loop_messages = messages[1:] %}\n {%- set system_message = messages[0]['content'] %}\n{%- else %}\n {%- set loop_messages = messages %}\n {%- set system_message = '' %}\n{%- endif %}\n{{- '<|im_start|>system\\n' + system_message if system_message or tools }}\n{%- if tools %}\n {{- '\\n# Functions\\nHere is a list of functions that you can invoke:\\n```python\\nfrom enum import Enum\\nfrom typing import List, Dict, Optional\\nfrom pydantic import BaseModel, Field\\n\\n' }}\n {{- tool_parser(tools) }}\n {{- \"\\n```\\n\\n# Function Call Rule and Output Format\\n- If the user's question can be answered without calling any function, please answer the user's question directly. In this situation, you should return your thought and answer the user's question directly.\\n- If the user cannot be answered without calling any function, and the user does not provide enough information to call functions, please ask the user for more information. In this situation, you should return your thought and ask the user for more information.\\n- If the user's question cannot be answered without calling any function, and the user has provided enough information to call functions to solve it, you should call the functions. 
In this situation, the assistant should return your thought and call the functions.\\n- Use default parameters unless the user has specified otherwise.\\n- You should answer in the following format:\\n\\n<|thought_start|>\\n{explain why the user's question can be answered without calling a function or why you should ask the user for more information or why you should call one or more functions and your plan to solve the user's question.}\\n<|thought_end|>\\n<|tool_call_start|>\\n```python\\nfunc1(params_name=params_value, params_name2=params_value2...)\\nfunc2(params)\\n```\\n<|tool_call_end|>\\n{answer the user's question directly or ask the user for more information}\" }}\n{%- endif %}\n{{- '<|im_end|>\\n' if system_message or tools }}\n{%- for message in loop_messages %}\n {%- set content = message.content %}\n {%- if message.role == 'assistant' and message.tool_calls %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {{- '<|thought_start|>\\n' + message.thought + '\\n<|thought_end|>\\n' if message.thought }}\n {{- '<|tool_call_start|>\\n```python\\n' }}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- tool_call.name + '(' }}\n {%- if tool_call.arguments is defined and tool_call.arguments|length > 0 %}\n {%- for param_name, param_value in tool_call.arguments|items %}\n {{- param_name + '=' + param_value|tojson }}\n {{- ',' if not loop.last }}\n {%- endfor %}\n {%- endif %}\n {{- ')\\n' }}\n {%- endfor %}\n {{- '```\\n<|tool_call_end|>\\n' }}\n {{- content if content and not content.startswith('<|tool_call_start|>') }}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == 'assistant' and message.thought %}\n {{- '<|im_start|>' + message.role + '\\n' + '<|thought_start|>\\n' + message.thought + '\\n<|thought_end|>\\n' + content + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endfor %}\n\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}",
107
+ "clean_up_tokenization_spaces": false,
108
+ "eos_token": "<|im_end|>",
109
+ "extra_special_tokens": {},
110
+ "legacy": true,
111
+ "model_max_length": 1000000000000000019884624838656,
112
+ "pad_token": "<|im_end|>",
113
+ "padding_side": "left",
114
+ "sp_model_kwargs": {},
115
+ "spaces_between_special_tokens": false,
116
+ "split_special_tokens": false,
117
+ "tokenizer_class": "LlamaTokenizer",
118
+ "unk_token": "<unk>",
119
+ "use_default_system_prompt": false
120
+ }
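The embedded `chat_template` renders conversations in the `<|im_start|>role ... <|im_end|>` format and, when tools are supplied, injects a Python-style function catalogue plus the `<|thought_start|>`/`<|tool_call_start|>` output protocol into the system turn. A minimal sketch of exercising it without tools (placeholder repo id):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("your-namespace/minicpm-checkpoint")  # placeholder
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hello!<|im_end|>
# <|im_start|>assistant
```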