Upload folder using huggingface_hub
Browse files- Modelfile +14 -0
- added_tokens.json +10 -0
- compressed_attention.py +1401 -0
- config.json +177 -0
- configuration_minicpm.py +202 -0
- generation_config.json +12 -0
- model-00001-of-00005.safetensors +3 -0
- model-00002-of-00005.safetensors +3 -0
- model-00003-of-00005.safetensors +3 -0
- model-00004-of-00005.safetensors +3 -0
- model-00005-of-00005.safetensors +3 -0
- model.safetensors.index.json +298 -0
- modeling_minicpm.py +0 -0
- special_tokens_map.json +40 -0
- tokenizer.json +0 -0
- tokenizer.model +3 -0
- tokenizer_config.json +120 -0
Modelfile
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ollama modelfile auto-generated by llamafactory
|
2 |
+
|
3 |
+
FROM .
|
4 |
+
|
5 |
+
TEMPLATE """<s>{{ if .System }}<|im_start|>system
|
6 |
+
{{ .System }}<|im_end|>
|
7 |
+
{{ end }}{{ range .Messages }}{{ if eq .Role "user" }}<|im_start|>user
|
8 |
+
{{ .Content }}<|im_end|>
|
9 |
+
<|im_start|>assistant
|
10 |
+
{{ else if eq .Role "assistant" }}{{ .Content }}<|im_end|>
|
11 |
+
{{ end }}{{ end }}"""
|
12 |
+
|
13 |
+
PARAMETER stop "<|im_end|>"
|
14 |
+
PARAMETER num_ctx 4096
|
added_tokens.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"<|execute_end|>": 73444,
|
3 |
+
"<|execute_start|>": 73443,
|
4 |
+
"<|fim_middle|>": 73446,
|
5 |
+
"<|fim_prefix|>": 73445,
|
6 |
+
"<|fim_suffix|>": 73447,
|
7 |
+
"<|im_end|>": 73440,
|
8 |
+
"<|im_start|>": 73441,
|
9 |
+
"<|tool_call|>": 73442
|
10 |
+
}
|
compressed_attention.py
ADDED
@@ -0,0 +1,1401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import math
|
15 |
+
from typing import Any, Tuple, Union
|
16 |
+
from collections import Counter
|
17 |
+
import torch
|
18 |
+
import triton
|
19 |
+
import triton.language as tl
|
20 |
+
import warnings
|
21 |
+
from torch import nn
|
22 |
+
def is_hopper_gpu():
|
23 |
+
if torch.cuda.is_available():
|
24 |
+
device_capability = torch.cuda.get_device_capability()
|
25 |
+
major, minor = device_capability
|
26 |
+
return major == 9
|
27 |
+
return False
|
28 |
+
def get_compressed_seqlens(
|
29 |
+
cu_seqlens: torch.Tensor, kernel_size: int, kernel_stride: int
|
30 |
+
):
|
31 |
+
# compute seqlens after compression
|
32 |
+
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
33 |
+
y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
|
34 |
+
# corner case, if sequence_length < kernel_size, no compression for this sequence
|
35 |
+
y_seqlens[seqlens < kernel_size] = 0
|
36 |
+
y_cu_seqlens = torch.zeros(
|
37 |
+
y_seqlens.shape[0] + 1, dtype=torch.int32, device=cu_seqlens.device
|
38 |
+
)
|
39 |
+
y_cu_seqlens[1:] = torch.cumsum(y_seqlens, dim=0)
|
40 |
+
return y_seqlens, y_cu_seqlens
|
41 |
+
|
42 |
+
|
43 |
+
def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
|
44 |
+
"""
|
45 |
+
Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
head_dim (int): Size of the head dimension.
|
49 |
+
block_size (int): Size of the block in the attention matrix.
|
50 |
+
is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU.
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
tuple: (num_warps, num_stages) recommended values.
|
54 |
+
"""
|
55 |
+
# Determine if head_dim and block_size exceed 64
|
56 |
+
head_large = head_dim > 64
|
57 |
+
block_large = block_size > 64
|
58 |
+
|
59 |
+
if is_hopper_gpu:
|
60 |
+
# Hopper GPU recommendations
|
61 |
+
if head_large and block_large:
|
62 |
+
num_warps = 8
|
63 |
+
num_stages = 3
|
64 |
+
elif head_large or block_large:
|
65 |
+
num_warps = 4
|
66 |
+
num_stages = 3
|
67 |
+
else:
|
68 |
+
num_warps = 2
|
69 |
+
num_stages = 2
|
70 |
+
else:
|
71 |
+
# Ampere GPU recommendations
|
72 |
+
if head_large and block_large:
|
73 |
+
num_warps = 8
|
74 |
+
num_stages = 3
|
75 |
+
elif head_large or block_large:
|
76 |
+
num_warps = 8
|
77 |
+
num_stages = 3
|
78 |
+
else:
|
79 |
+
num_warps = 2
|
80 |
+
num_stages = 2
|
81 |
+
return num_warps, num_stages
|
82 |
+
|
83 |
+
|
84 |
+
IS_HOPPER_GPU = is_hopper_gpu()
|
85 |
+
|
86 |
+
|
87 |
+
@triton.jit
|
88 |
+
def forward_kernel(
|
89 |
+
q_ptr, # Q: n x h x d
|
90 |
+
k_ptr, # K: n x h x d
|
91 |
+
v_ptr, # V: n x h x d
|
92 |
+
o_ptr, # O: n x h x d
|
93 |
+
lse_ptr, # LSE: h x n
|
94 |
+
# size and stride at compresstion
|
95 |
+
kernel_size,
|
96 |
+
kernel_stride,
|
97 |
+
# seqlens
|
98 |
+
cu_seqlens_q,
|
99 |
+
cu_seqlens_k,
|
100 |
+
# shape
|
101 |
+
NUM_KV_HEADS,
|
102 |
+
NUM_SHARE_Q_HEADS,
|
103 |
+
HEAD_DIM,
|
104 |
+
# sm_scale
|
105 |
+
sm_scale,
|
106 |
+
# stride
|
107 |
+
stride_qn,
|
108 |
+
stride_qh,
|
109 |
+
stride_qd,
|
110 |
+
stride_kn,
|
111 |
+
stride_kh,
|
112 |
+
stride_kd,
|
113 |
+
stride_vn,
|
114 |
+
stride_vh,
|
115 |
+
stride_vd,
|
116 |
+
stride_on,
|
117 |
+
stride_oh,
|
118 |
+
stride_od,
|
119 |
+
stride_lh,
|
120 |
+
stride_ln,
|
121 |
+
# META parameters
|
122 |
+
BLOCK_SIZE_Q: tl.constexpr, # q block size
|
123 |
+
BLOCK_SIZE_K: tl.constexpr, # k block size
|
124 |
+
BLOCK_SIZE_D: tl.constexpr,
|
125 |
+
):
|
126 |
+
qk_scale = sm_scale * 1.44269504
|
127 |
+
# get batch id and head id
|
128 |
+
pid_b = tl.program_id(0)
|
129 |
+
pid_h = tl.program_id(1)
|
130 |
+
pid_q = tl.program_id(2)
|
131 |
+
pid_kh = pid_h // NUM_SHARE_Q_HEADS
|
132 |
+
# get q k start and len after rmpad
|
133 |
+
q_start = tl.load(cu_seqlens_q + pid_b)
|
134 |
+
q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
|
135 |
+
k_start = tl.load(cu_seqlens_k + pid_b)
|
136 |
+
k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
|
137 |
+
# skip first kernel_size query block, because they do no attend to any keys
|
138 |
+
q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
|
139 |
+
if q_start_in_seq >= q_len:
|
140 |
+
return
|
141 |
+
# init qkv pointer
|
142 |
+
q_ptrs = tl.make_block_ptr(
|
143 |
+
base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
|
144 |
+
shape=(q_len, HEAD_DIM),
|
145 |
+
strides=(stride_qn, stride_qd),
|
146 |
+
offsets=(q_start_in_seq, 0),
|
147 |
+
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
|
148 |
+
order=(1, 0),
|
149 |
+
)
|
150 |
+
k_ptrs = tl.make_block_ptr(
|
151 |
+
base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
|
152 |
+
shape=(HEAD_DIM, k_len),
|
153 |
+
strides=(stride_kd, stride_kn),
|
154 |
+
offsets=(0, 0),
|
155 |
+
block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
|
156 |
+
order=(0, 1),
|
157 |
+
)
|
158 |
+
v_ptrs = tl.make_block_ptr(
|
159 |
+
base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
|
160 |
+
shape=(k_len, HEAD_DIM),
|
161 |
+
strides=(stride_vn, stride_vd),
|
162 |
+
offsets=(0, 0),
|
163 |
+
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
|
164 |
+
order=(1, 0),
|
165 |
+
)
|
166 |
+
# load q
|
167 |
+
q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
|
168 |
+
# init statistics
|
169 |
+
off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
|
170 |
+
off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
|
171 |
+
m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
|
172 |
+
lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
|
173 |
+
acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32)
|
174 |
+
# attention
|
175 |
+
lo = 0
|
176 |
+
hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
|
177 |
+
for i in range(lo, hi, BLOCK_SIZE_K):
|
178 |
+
i = tl.multiple_of(i, BLOCK_SIZE_K)
|
179 |
+
# load k
|
180 |
+
k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero")
|
181 |
+
# compute qk
|
182 |
+
qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
|
183 |
+
qk += tl.where(
|
184 |
+
off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")
|
185 |
+
)
|
186 |
+
qk += tl.dot(q, k) * qk_scale
|
187 |
+
# compute m_ij and l_ij
|
188 |
+
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
|
189 |
+
p = tl.exp2(qk - m_ij[:, None])
|
190 |
+
l_ij = tl.sum(p, axis=1)
|
191 |
+
# scale acc_o
|
192 |
+
acc_o_scale = tl.exp2(m_i - m_ij)
|
193 |
+
acc_o = acc_o * acc_o_scale[:, None]
|
194 |
+
# load v and update acc_o
|
195 |
+
v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
|
196 |
+
p = p.to(v.dtype)
|
197 |
+
acc_o += tl.dot(p, v)
|
198 |
+
# update statistics
|
199 |
+
m_i = m_ij
|
200 |
+
lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)
|
201 |
+
# update ptrs
|
202 |
+
k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K))
|
203 |
+
v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))
|
204 |
+
# final scale
|
205 |
+
acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
|
206 |
+
# save output
|
207 |
+
o_ptrs = tl.make_block_ptr(
|
208 |
+
base=o_ptr + q_start * stride_on + pid_h * stride_oh,
|
209 |
+
shape=(q_len, HEAD_DIM),
|
210 |
+
strides=(stride_on, stride_od),
|
211 |
+
offsets=(q_start_in_seq, 0),
|
212 |
+
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
|
213 |
+
order=(1, 0),
|
214 |
+
)
|
215 |
+
tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
|
216 |
+
# save lse
|
217 |
+
l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln
|
218 |
+
tl.store(l_ptrs, lse_i, mask=off_q < q_len)
|
219 |
+
|
220 |
+
|
221 |
+
@triton.jit
|
222 |
+
def backward_sum_o_do(
|
223 |
+
o_ptr, # O: n x h x d
|
224 |
+
do_ptr, # dO: n x h x d
|
225 |
+
delta_ptr, # D: h x n
|
226 |
+
o_len,
|
227 |
+
HEAD_DIM,
|
228 |
+
stride_on,
|
229 |
+
stride_oh,
|
230 |
+
stride_od,
|
231 |
+
stride_don,
|
232 |
+
stride_doh,
|
233 |
+
stride_dod,
|
234 |
+
stride_dh,
|
235 |
+
stride_dn,
|
236 |
+
BLOCK_SIZE_O: tl.constexpr,
|
237 |
+
BLOCK_SIZE_D: tl.constexpr,
|
238 |
+
):
|
239 |
+
pid_n = tl.program_id(0)
|
240 |
+
pid_h = tl.program_id(1)
|
241 |
+
off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
|
242 |
+
off_d = tl.arange(0, BLOCK_SIZE_D)
|
243 |
+
o = tl.load(
|
244 |
+
o_ptr
|
245 |
+
+ off_n[:, None] * stride_on
|
246 |
+
+ pid_h * stride_oh
|
247 |
+
+ off_d[None, :] * stride_od,
|
248 |
+
mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
|
249 |
+
other=0,
|
250 |
+
).to(tl.float32)
|
251 |
+
do = tl.load(
|
252 |
+
do_ptr
|
253 |
+
+ off_n[:, None] * stride_don
|
254 |
+
+ pid_h * stride_doh
|
255 |
+
+ off_d[None, :] * stride_dod,
|
256 |
+
mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
|
257 |
+
other=0,
|
258 |
+
).to(tl.float32)
|
259 |
+
delta = tl.sum(o * do, axis=1)
|
260 |
+
tl.store(
|
261 |
+
delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len
|
262 |
+
)
|
263 |
+
|
264 |
+
|
265 |
+
@triton.jit
|
266 |
+
def backward_dkdv(
|
267 |
+
q_ptr, # Q: n x qh x d
|
268 |
+
k_ptr, # K: n x kh x d
|
269 |
+
v_ptr, # V: n x kh x d
|
270 |
+
lse_ptr, # LSE: qh x n
|
271 |
+
d_ptr, # Delta: qh x n
|
272 |
+
do_ptr,
|
273 |
+
dk_ptr, # DK: sh x n x kh x d
|
274 |
+
dv_ptr, # DV: sh x n x kh x d
|
275 |
+
kernel_size,
|
276 |
+
kernel_stride,
|
277 |
+
# seqlens
|
278 |
+
cu_seqlens_q,
|
279 |
+
cu_seqlens_k,
|
280 |
+
# shape
|
281 |
+
NUM_KV_HEADS,
|
282 |
+
NUM_SHARE_Q_HEADS,
|
283 |
+
HEAD_DIM,
|
284 |
+
# sm_scale
|
285 |
+
sm_scale,
|
286 |
+
# stride
|
287 |
+
stride_qn,
|
288 |
+
stride_qh,
|
289 |
+
stride_qd,
|
290 |
+
stride_kn,
|
291 |
+
stride_kh,
|
292 |
+
stride_kd,
|
293 |
+
stride_vn,
|
294 |
+
stride_vh,
|
295 |
+
stride_vd,
|
296 |
+
stride_lh,
|
297 |
+
stride_ln,
|
298 |
+
stride_dh,
|
299 |
+
stride_dn,
|
300 |
+
stride_don,
|
301 |
+
stride_doh,
|
302 |
+
stride_dod,
|
303 |
+
stride_dks,
|
304 |
+
stride_dkn,
|
305 |
+
stride_dkh,
|
306 |
+
stride_dkd,
|
307 |
+
stride_dvs,
|
308 |
+
stride_dvn,
|
309 |
+
stride_dvh,
|
310 |
+
stride_dvd,
|
311 |
+
# META parameters
|
312 |
+
BLOCK_SIZE_Q: tl.constexpr, # q block size
|
313 |
+
BLOCK_SIZE_K: tl.constexpr, # k block size
|
314 |
+
BLOCK_SIZE_D: tl.constexpr,
|
315 |
+
):
|
316 |
+
qk_scale = sm_scale * 1.44269504
|
317 |
+
# get batch id and head id
|
318 |
+
pid_b = tl.program_id(0)
|
319 |
+
pid_h = tl.program_id(1)
|
320 |
+
pid_kh = pid_h // NUM_SHARE_Q_HEADS
|
321 |
+
pid_sh = pid_h % NUM_SHARE_Q_HEADS
|
322 |
+
pid_k = tl.program_id(2)
|
323 |
+
# get q k start and len after rmpad
|
324 |
+
q_start = tl.load(cu_seqlens_q + pid_b)
|
325 |
+
q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
|
326 |
+
k_start = tl.load(cu_seqlens_k + pid_b)
|
327 |
+
k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
|
328 |
+
if BLOCK_SIZE_K * pid_k >= k_len:
|
329 |
+
return
|
330 |
+
# init pointers
|
331 |
+
k_ptrs = tl.make_block_ptr(
|
332 |
+
base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
|
333 |
+
shape=(k_len, HEAD_DIM),
|
334 |
+
strides=(stride_kn, stride_kd),
|
335 |
+
offsets=(pid_k * BLOCK_SIZE_K, 0),
|
336 |
+
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
|
337 |
+
order=(1, 0),
|
338 |
+
)
|
339 |
+
dk_ptrs = tl.make_block_ptr(
|
340 |
+
base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
|
341 |
+
shape=(k_len, HEAD_DIM),
|
342 |
+
strides=(stride_dkn, stride_dkd),
|
343 |
+
offsets=(pid_k * BLOCK_SIZE_K, 0),
|
344 |
+
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
|
345 |
+
order=(1, 0),
|
346 |
+
)
|
347 |
+
v_ptrs = tl.make_block_ptr(
|
348 |
+
base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
|
349 |
+
shape=(k_len, HEAD_DIM),
|
350 |
+
strides=(stride_vn, stride_vd),
|
351 |
+
offsets=(pid_k * BLOCK_SIZE_K, 0),
|
352 |
+
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
|
353 |
+
order=(1, 0),
|
354 |
+
)
|
355 |
+
dv_ptrs = tl.make_block_ptr(
|
356 |
+
base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
|
357 |
+
shape=(k_len, HEAD_DIM),
|
358 |
+
strides=(stride_dvn, stride_dvd),
|
359 |
+
offsets=(pid_k * BLOCK_SIZE_K, 0),
|
360 |
+
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
|
361 |
+
order=(1, 0),
|
362 |
+
)
|
363 |
+
# offsets
|
364 |
+
off_q = tl.arange(0, BLOCK_SIZE_Q)
|
365 |
+
off_k = (
|
366 |
+
pid_k * BLOCK_SIZE_K * kernel_stride
|
367 |
+
+ tl.arange(0, BLOCK_SIZE_K) * kernel_stride
|
368 |
+
+ kernel_size
|
369 |
+
- 1
|
370 |
+
)
|
371 |
+
# load k v and keep in SRAM
|
372 |
+
k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
|
373 |
+
v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
|
374 |
+
# init dk dv
|
375 |
+
dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
|
376 |
+
dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
|
377 |
+
q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1
|
378 |
+
q_ptrs = tl.make_block_ptr(
|
379 |
+
base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
|
380 |
+
shape=(HEAD_DIM, q_len),
|
381 |
+
strides=(stride_qd, stride_qn),
|
382 |
+
offsets=(0, q_lo),
|
383 |
+
block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
|
384 |
+
order=(0, 1),
|
385 |
+
)
|
386 |
+
do_ptrs = tl.make_block_ptr(
|
387 |
+
base=do_ptr + q_start * stride_don + pid_h * stride_doh,
|
388 |
+
shape=(HEAD_DIM, q_len),
|
389 |
+
strides=(stride_dod, stride_don),
|
390 |
+
offsets=(0, q_lo),
|
391 |
+
block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
|
392 |
+
order=(0, 1),
|
393 |
+
)
|
394 |
+
d_ptrs = tl.make_block_ptr(
|
395 |
+
base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
|
396 |
+
shape=(1, q_len),
|
397 |
+
strides=(0, stride_dn),
|
398 |
+
offsets=(0, q_lo),
|
399 |
+
block_shape=(1, BLOCK_SIZE_Q),
|
400 |
+
order=(1, 0),
|
401 |
+
)
|
402 |
+
lse_ptrs = tl.make_block_ptr(
|
403 |
+
base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
|
404 |
+
shape=(1, q_len),
|
405 |
+
strides=(0, stride_ln),
|
406 |
+
offsets=(0, q_lo),
|
407 |
+
block_shape=(1, BLOCK_SIZE_Q),
|
408 |
+
order=(0, 1),
|
409 |
+
)
|
410 |
+
# loop for q blocks
|
411 |
+
for i in range(q_lo, q_len, BLOCK_SIZE_Q):
|
412 |
+
# load
|
413 |
+
q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
|
414 |
+
do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
|
415 |
+
lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
|
416 |
+
d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
|
417 |
+
# compute qk
|
418 |
+
# [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q]
|
419 |
+
qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], float(0.0), float("-inf"))
|
420 |
+
qk += tl.dot(k, q) * qk_scale
|
421 |
+
# compute p, ds
|
422 |
+
# [BLOCK_SIZE_K, BLOCK_SIE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q]
|
423 |
+
p = tl.exp2(qk - lse)
|
424 |
+
# [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q]
|
425 |
+
dp = tl.dot(v, do)
|
426 |
+
ds = sm_scale * p * (dp - d)
|
427 |
+
# cast dtype
|
428 |
+
p = p.to(do.dtype)
|
429 |
+
ds = ds.to(q.dtype)
|
430 |
+
# update dk and dv
|
431 |
+
# [BLOCK_SIZE_K, BLOCK_SIE_Q] @ [BLOCK_SIE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM]
|
432 |
+
dk += tl.dot(ds, tl.trans(q))
|
433 |
+
dv += tl.dot(p, tl.trans(do))
|
434 |
+
# increment pointers
|
435 |
+
q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q))
|
436 |
+
do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q))
|
437 |
+
lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q))
|
438 |
+
d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q))
|
439 |
+
# save dk dv
|
440 |
+
tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
|
441 |
+
tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
|
442 |
+
|
443 |
+
|
444 |
+
@triton.jit
|
445 |
+
def backward_dq(
|
446 |
+
q_ptr, # Q: n x qh x d
|
447 |
+
k_ptr, # K: n x kh x d
|
448 |
+
v_ptr, # V: n x kh x d
|
449 |
+
lse_ptr, # LSE: qh x n
|
450 |
+
d_ptr, # Delta: qh x n
|
451 |
+
do_ptr,
|
452 |
+
dq_ptr,
|
453 |
+
kernel_size,
|
454 |
+
kernel_stride,
|
455 |
+
# seqlens
|
456 |
+
cu_seqlens_q,
|
457 |
+
cu_seqlens_k,
|
458 |
+
# shape
|
459 |
+
NUM_KV_HEADS,
|
460 |
+
NUM_SHARE_Q_HEADS,
|
461 |
+
HEAD_DIM,
|
462 |
+
# sm_scale
|
463 |
+
sm_scale,
|
464 |
+
# stride
|
465 |
+
stride_qn,
|
466 |
+
stride_qh,
|
467 |
+
stride_qd,
|
468 |
+
stride_kn,
|
469 |
+
stride_kh,
|
470 |
+
stride_kd,
|
471 |
+
stride_vn,
|
472 |
+
stride_vh,
|
473 |
+
stride_vd,
|
474 |
+
stride_lh,
|
475 |
+
stride_ln,
|
476 |
+
stride_dh,
|
477 |
+
stride_dn,
|
478 |
+
stride_don,
|
479 |
+
stride_doh,
|
480 |
+
stride_dod,
|
481 |
+
stride_dqn,
|
482 |
+
stride_dqh,
|
483 |
+
stride_dqd,
|
484 |
+
# META parameters
|
485 |
+
BLOCK_SIZE_Q: tl.constexpr, # q block size
|
486 |
+
BLOCK_SIZE_K: tl.constexpr, # k block size
|
487 |
+
BLOCK_SIZE_D: tl.constexpr,
|
488 |
+
):
|
489 |
+
qk_scale = sm_scale * 1.44269504
|
490 |
+
# get batch id and head id
|
491 |
+
pid_b = tl.program_id(0)
|
492 |
+
pid_h = tl.program_id(1)
|
493 |
+
pid_q = tl.program_id(2)
|
494 |
+
pid_kh = pid_h // NUM_SHARE_Q_HEADS
|
495 |
+
# get q k start and len after rmpad
|
496 |
+
q_start = tl.load(cu_seqlens_q + pid_b)
|
497 |
+
q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
|
498 |
+
k_start = tl.load(cu_seqlens_k + pid_b)
|
499 |
+
k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
|
500 |
+
# skip first kernel_size query block, because they do no attend to any keys
|
501 |
+
q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
|
502 |
+
if q_start_in_seq >= q_len:
|
503 |
+
return
|
504 |
+
# init pointers
|
505 |
+
q_ptrs = tl.make_block_ptr(
|
506 |
+
base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
|
507 |
+
shape=(q_len, HEAD_DIM),
|
508 |
+
strides=(stride_qn, stride_qd),
|
509 |
+
offsets=(q_start_in_seq, 0),
|
510 |
+
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
|
511 |
+
order=(1, 0),
|
512 |
+
)
|
513 |
+
dq_ptrs = tl.make_block_ptr(
|
514 |
+
base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh,
|
515 |
+
shape=(q_len, HEAD_DIM),
|
516 |
+
strides=(stride_dqn, stride_dqd),
|
517 |
+
offsets=(q_start_in_seq, 0),
|
518 |
+
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
|
519 |
+
order=(1, 0),
|
520 |
+
)
|
521 |
+
k_ptrs = tl.make_block_ptr(
|
522 |
+
base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
|
523 |
+
shape=(k_len, HEAD_DIM),
|
524 |
+
strides=(stride_kn, stride_kd),
|
525 |
+
offsets=(0, 0),
|
526 |
+
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
|
527 |
+
order=(1, 0),
|
528 |
+
)
|
529 |
+
v_ptrs = tl.make_block_ptr(
|
530 |
+
base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
|
531 |
+
shape=(HEAD_DIM, k_len),
|
532 |
+
strides=(stride_vd, stride_vn),
|
533 |
+
offsets=(0, 0),
|
534 |
+
block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
|
535 |
+
order=(0, 1),
|
536 |
+
)
|
537 |
+
do_ptrs = tl.make_block_ptr(
|
538 |
+
base=do_ptr + q_start * stride_don + pid_h * stride_doh,
|
539 |
+
shape=(q_len, HEAD_DIM),
|
540 |
+
strides=(stride_don, stride_dod),
|
541 |
+
offsets=(q_start_in_seq, 0),
|
542 |
+
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
|
543 |
+
order=(1, 0),
|
544 |
+
)
|
545 |
+
d_ptrs = tl.make_block_ptr(
|
546 |
+
base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
|
547 |
+
shape=(q_len, 1),
|
548 |
+
strides=(stride_dn, stride_dh),
|
549 |
+
offsets=(q_start_in_seq, 0),
|
550 |
+
block_shape=(BLOCK_SIZE_Q, 1),
|
551 |
+
order=(0, 1),
|
552 |
+
)
|
553 |
+
lse_ptrs = tl.make_block_ptr(
|
554 |
+
base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
|
555 |
+
shape=(q_len, 1),
|
556 |
+
strides=(stride_ln, stride_lh),
|
557 |
+
offsets=(q_start_in_seq, 0),
|
558 |
+
block_shape=(BLOCK_SIZE_Q, 1),
|
559 |
+
order=(0, 1),
|
560 |
+
)
|
561 |
+
# offsets
|
562 |
+
off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
|
563 |
+
off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
|
564 |
+
# load q, do, lse, delta, and keep in SRAM
|
565 |
+
q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
|
566 |
+
do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
|
567 |
+
lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
|
568 |
+
d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
|
569 |
+
# init dq
|
570 |
+
dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32)
|
571 |
+
lo = 0
|
572 |
+
hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
|
573 |
+
for i in range(lo, hi, BLOCK_SIZE_K):
|
574 |
+
# load
|
575 |
+
k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
|
576 |
+
v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
|
577 |
+
# compute qk
|
578 |
+
qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
|
579 |
+
qk += tl.where(
|
580 |
+
off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")
|
581 |
+
)
|
582 |
+
qk += tl.dot(q, tl.trans(k)) * qk_scale
|
583 |
+
# compute p, ds
|
584 |
+
p = tl.exp2(qk - lse)
|
585 |
+
dp = tl.dot(do, v)
|
586 |
+
ds = sm_scale * p * (dp - d)
|
587 |
+
# cast dtype
|
588 |
+
ds = ds.to(q.dtype)
|
589 |
+
# update dq
|
590 |
+
dq += tl.dot(ds, k)
|
591 |
+
# increment pointers
|
592 |
+
k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0))
|
593 |
+
v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K))
|
594 |
+
# save dq
|
595 |
+
tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))
|
596 |
+
|
597 |
+
|
598 |
+
def _compressed_attention_fwd(
|
599 |
+
q: torch.Tensor,
|
600 |
+
k: torch.Tensor,
|
601 |
+
v: torch.Tensor,
|
602 |
+
kernel_size: int,
|
603 |
+
kernel_stride: int,
|
604 |
+
cu_seqlens_q: torch.Tensor,
|
605 |
+
cu_seqlens_k: torch.Tensor,
|
606 |
+
max_seqlen_q: torch.Tensor,
|
607 |
+
max_seqlen_k: torch.Tensor,
|
608 |
+
sm_scale: float,
|
609 |
+
):
|
610 |
+
# dtype check
|
611 |
+
assert k.dtype == q.dtype and v.dtype == q.dtype
|
612 |
+
assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
|
613 |
+
# shape
|
614 |
+
q_len, num_q_heads, head_dim = q.shape
|
615 |
+
k_len, num_k_heads, head_dim = k.shape
|
616 |
+
v_len, num_v_heads, head_dim = v.shape
|
617 |
+
batch_size = cu_seqlens_q.shape[0] - 1
|
618 |
+
assert k_len == v_len and q_len > k_len
|
619 |
+
# gqa
|
620 |
+
assert num_k_heads == num_v_heads
|
621 |
+
assert num_q_heads % num_k_heads == 0
|
622 |
+
num_share_q_heads = num_q_heads // num_k_heads
|
623 |
+
# output tensor
|
624 |
+
o = torch.zeros_like(q)
|
625 |
+
lse = torch.full(
|
626 |
+
(num_q_heads, q_len),
|
627 |
+
fill_value=-torch.inf,
|
628 |
+
dtype=torch.float32,
|
629 |
+
device=q.device,
|
630 |
+
)
|
631 |
+
# launch kernel
|
632 |
+
grid = lambda META: (
|
633 |
+
batch_size,
|
634 |
+
num_q_heads,
|
635 |
+
triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
|
636 |
+
)
|
637 |
+
BLOCK_SIZE_Q = 128
|
638 |
+
BLOCK_SIZE_K = 128
|
639 |
+
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
|
640 |
+
num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
|
641 |
+
forward_kernel[grid](
|
642 |
+
q,
|
643 |
+
k,
|
644 |
+
v,
|
645 |
+
o,
|
646 |
+
lse,
|
647 |
+
kernel_size,
|
648 |
+
kernel_stride,
|
649 |
+
cu_seqlens_q,
|
650 |
+
cu_seqlens_k,
|
651 |
+
num_k_heads,
|
652 |
+
num_share_q_heads,
|
653 |
+
head_dim,
|
654 |
+
sm_scale,
|
655 |
+
q.stride(0),
|
656 |
+
q.stride(1),
|
657 |
+
q.stride(2),
|
658 |
+
k.stride(0),
|
659 |
+
k.stride(1),
|
660 |
+
k.stride(2),
|
661 |
+
v.stride(0),
|
662 |
+
v.stride(1),
|
663 |
+
v.stride(2),
|
664 |
+
o.stride(0),
|
665 |
+
o.stride(1),
|
666 |
+
o.stride(2),
|
667 |
+
lse.stride(0),
|
668 |
+
lse.stride(1),
|
669 |
+
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
|
670 |
+
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
671 |
+
BLOCK_SIZE_D=BLOCK_SIZE_D,
|
672 |
+
num_warps=num_warps,
|
673 |
+
num_stages=num_stages,
|
674 |
+
)
|
675 |
+
return o, lse
|
676 |
+
|
677 |
+
|
678 |
+
def _compressed_attention_bwd(
|
679 |
+
o: torch.Tensor,
|
680 |
+
do: torch.Tensor,
|
681 |
+
lse: torch.Tensor,
|
682 |
+
q: torch.Tensor,
|
683 |
+
k: torch.Tensor,
|
684 |
+
v: torch.Tensor,
|
685 |
+
kernel_size: int,
|
686 |
+
kernel_stride: int,
|
687 |
+
cu_seqlens_q: torch.Tensor,
|
688 |
+
cu_seqlens_k: torch.Tensor,
|
689 |
+
max_seqlen_q: torch.Tensor,
|
690 |
+
max_seqlen_k: torch.Tensor,
|
691 |
+
sm_scale: float,
|
692 |
+
):
|
693 |
+
q_len, num_q_heads, head_dim = q.shape
|
694 |
+
k_len, num_k_heads, head_dim = k.shape
|
695 |
+
v_len, num_v_heads, head_dim = v.shape
|
696 |
+
o_len, num_o_heads, head_dim = o.shape
|
697 |
+
num_share_q_heads = num_q_heads // num_k_heads
|
698 |
+
# compute D
|
699 |
+
delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)
|
700 |
+
grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads)
|
701 |
+
BLOCK_SIZE_O = 256
|
702 |
+
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
|
703 |
+
num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU)
|
704 |
+
backward_sum_o_do[grid](
|
705 |
+
o,
|
706 |
+
do,
|
707 |
+
delta,
|
708 |
+
o_len,
|
709 |
+
head_dim,
|
710 |
+
o.stride(0),
|
711 |
+
o.stride(1),
|
712 |
+
o.stride(2),
|
713 |
+
do.stride(0),
|
714 |
+
do.stride(1),
|
715 |
+
do.stride(2),
|
716 |
+
delta.stride(0),
|
717 |
+
delta.stride(1),
|
718 |
+
BLOCK_SIZE_O=BLOCK_SIZE_O,
|
719 |
+
BLOCK_SIZE_D=BLOCK_SIZE_D,
|
720 |
+
num_warps=num_warps,
|
721 |
+
num_stages=num_stages,
|
722 |
+
)
|
723 |
+
# compute dk dv
|
724 |
+
dk = torch.zeros(
|
725 |
+
num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype
|
726 |
+
)
|
727 |
+
dv = torch.zeros(
|
728 |
+
num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype
|
729 |
+
)
|
730 |
+
batch_size = cu_seqlens_q.shape[0] - 1
|
731 |
+
grid = lambda META: (
|
732 |
+
batch_size,
|
733 |
+
num_q_heads,
|
734 |
+
triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
|
735 |
+
)
|
736 |
+
BLOCK_SIZE_Q = 64
|
737 |
+
BLOCK_SIZE_K = 128
|
738 |
+
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
|
739 |
+
num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU)
|
740 |
+
backward_dkdv[grid](
|
741 |
+
q,
|
742 |
+
k,
|
743 |
+
v,
|
744 |
+
lse,
|
745 |
+
delta,
|
746 |
+
do,
|
747 |
+
dk,
|
748 |
+
dv,
|
749 |
+
kernel_size,
|
750 |
+
kernel_stride,
|
751 |
+
cu_seqlens_q,
|
752 |
+
cu_seqlens_k,
|
753 |
+
num_k_heads,
|
754 |
+
num_share_q_heads,
|
755 |
+
head_dim,
|
756 |
+
sm_scale,
|
757 |
+
q.stride(0),
|
758 |
+
q.stride(1),
|
759 |
+
q.stride(2),
|
760 |
+
k.stride(0),
|
761 |
+
k.stride(1),
|
762 |
+
k.stride(2),
|
763 |
+
v.stride(0),
|
764 |
+
v.stride(1),
|
765 |
+
v.stride(2),
|
766 |
+
lse.stride(0),
|
767 |
+
lse.stride(1),
|
768 |
+
delta.stride(0),
|
769 |
+
delta.stride(1),
|
770 |
+
do.stride(0),
|
771 |
+
do.stride(1),
|
772 |
+
do.stride(2),
|
773 |
+
dk.stride(0),
|
774 |
+
dk.stride(1),
|
775 |
+
dk.stride(2),
|
776 |
+
dk.stride(3),
|
777 |
+
dv.stride(0),
|
778 |
+
dv.stride(1),
|
779 |
+
dv.stride(2),
|
780 |
+
dv.stride(3),
|
781 |
+
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
|
782 |
+
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
783 |
+
BLOCK_SIZE_D=BLOCK_SIZE_D,
|
784 |
+
num_warps=num_warps,
|
785 |
+
num_stages=num_stages,
|
786 |
+
)
|
787 |
+
dk = dk.sum(0)
|
788 |
+
dv = dv.sum(0)
|
789 |
+
# compute dq
|
790 |
+
dq = torch.zeros_like(q)
|
791 |
+
grid = lambda META: (
|
792 |
+
batch_size,
|
793 |
+
num_q_heads,
|
794 |
+
triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
|
795 |
+
)
|
796 |
+
BLOCK_SIZE_Q = 128
|
797 |
+
BLOCK_SIZE_K = 64
|
798 |
+
num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
|
799 |
+
backward_dq[grid](
|
800 |
+
q,
|
801 |
+
k,
|
802 |
+
v,
|
803 |
+
lse,
|
804 |
+
delta,
|
805 |
+
do,
|
806 |
+
dq,
|
807 |
+
kernel_size,
|
808 |
+
kernel_stride,
|
809 |
+
cu_seqlens_q,
|
810 |
+
cu_seqlens_k,
|
811 |
+
num_k_heads,
|
812 |
+
num_share_q_heads,
|
813 |
+
head_dim,
|
814 |
+
sm_scale,
|
815 |
+
q.stride(0),
|
816 |
+
q.stride(1),
|
817 |
+
q.stride(2),
|
818 |
+
k.stride(0),
|
819 |
+
k.stride(1),
|
820 |
+
k.stride(2),
|
821 |
+
v.stride(0),
|
822 |
+
v.stride(1),
|
823 |
+
v.stride(2),
|
824 |
+
lse.stride(0),
|
825 |
+
lse.stride(1),
|
826 |
+
delta.stride(0),
|
827 |
+
delta.stride(1),
|
828 |
+
do.stride(0),
|
829 |
+
do.stride(1),
|
830 |
+
do.stride(2),
|
831 |
+
dq.stride(0),
|
832 |
+
dq.stride(1),
|
833 |
+
dq.stride(2),
|
834 |
+
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
|
835 |
+
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
836 |
+
BLOCK_SIZE_D=BLOCK_SIZE_D,
|
837 |
+
num_warps=num_warps,
|
838 |
+
num_stages=num_stages,
|
839 |
+
)
|
840 |
+
return dq, dk, dv
|
841 |
+
|
842 |
+
|
843 |
+
class CompressedAttention(torch.autograd.Function):
|
844 |
+
@staticmethod
|
845 |
+
def forward(
|
846 |
+
ctx,
|
847 |
+
q: torch.Tensor,
|
848 |
+
k: torch.Tensor,
|
849 |
+
v: torch.Tensor,
|
850 |
+
kernel_size: int,
|
851 |
+
kernel_stride: int,
|
852 |
+
cu_seqlens_q: torch.Tensor,
|
853 |
+
cu_seqlens_k: torch.Tensor,
|
854 |
+
max_seqlen_q: torch.Tensor,
|
855 |
+
max_seqlen_k: torch.Tensor,
|
856 |
+
sm_scale=None,
|
857 |
+
):
|
858 |
+
# dtype check
|
859 |
+
assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
|
860 |
+
assert q.dtype == k.dtype and k.dtype == v.dtype
|
861 |
+
assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
|
862 |
+
# softmax scale
|
863 |
+
if sm_scale is None:
|
864 |
+
sm_scale = 1 / math.sqrt(q.shape[-1])
|
865 |
+
o, lse = _compressed_attention_fwd(
|
866 |
+
q,
|
867 |
+
k,
|
868 |
+
v,
|
869 |
+
kernel_size,
|
870 |
+
kernel_stride,
|
871 |
+
cu_seqlens_q,
|
872 |
+
cu_seqlens_k,
|
873 |
+
max_seqlen_q,
|
874 |
+
max_seqlen_k,
|
875 |
+
sm_scale,
|
876 |
+
)
|
877 |
+
ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k)
|
878 |
+
ctx.sm_scale = sm_scale
|
879 |
+
ctx.max_seqlen_q = max_seqlen_q
|
880 |
+
ctx.max_seqlen_k = max_seqlen_k
|
881 |
+
ctx.kernel_size = kernel_size
|
882 |
+
ctx.kernel_stride = kernel_stride
|
883 |
+
return o, lse
|
884 |
+
|
885 |
+
@staticmethod
|
886 |
+
def backward(ctx, do: torch.Tensor, *args) -> Any:
|
887 |
+
q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
|
888 |
+
max_seqlen_q = ctx.max_seqlen_q
|
889 |
+
max_seqlen_k = ctx.max_seqlen_k
|
890 |
+
sm_scale = ctx.sm_scale
|
891 |
+
kernel_size = ctx.kernel_size
|
892 |
+
kernel_stride = ctx.kernel_stride
|
893 |
+
dq, dk, dv = _compressed_attention_bwd(
|
894 |
+
o,
|
895 |
+
do,
|
896 |
+
lse,
|
897 |
+
q,
|
898 |
+
k,
|
899 |
+
v,
|
900 |
+
kernel_size,
|
901 |
+
kernel_stride,
|
902 |
+
cu_seqlens_q,
|
903 |
+
cu_seqlens_k,
|
904 |
+
max_seqlen_q,
|
905 |
+
max_seqlen_k,
|
906 |
+
sm_scale,
|
907 |
+
)
|
908 |
+
return dq, dk, dv, None, None, None, None, None, None, None
|
909 |
+
|
910 |
+
|
911 |
+
@triton.jit
|
912 |
+
def score_kernel(
|
913 |
+
q_ptr,
|
914 |
+
k_ptr,
|
915 |
+
lse_ptr,
|
916 |
+
s_ptr,
|
917 |
+
kernel_size,
|
918 |
+
kernel_stride,
|
919 |
+
# seqlens
|
920 |
+
cu_seqlens_q,
|
921 |
+
cu_seqlens_k,
|
922 |
+
# shape
|
923 |
+
NUM_KV_HEADS,
|
924 |
+
NUM_SHARE_Q_HEADS,
|
925 |
+
HEAD_DIM,
|
926 |
+
# sm_scale
|
927 |
+
sm_scale,
|
928 |
+
# stride
|
929 |
+
stride_qn,
|
930 |
+
stride_qh,
|
931 |
+
stride_qd,
|
932 |
+
stride_kn,
|
933 |
+
stride_kh,
|
934 |
+
stride_kd,
|
935 |
+
stride_lh,
|
936 |
+
stride_ln,
|
937 |
+
stride_sh,
|
938 |
+
stride_sq,
|
939 |
+
stride_sk,
|
940 |
+
# META parameters
|
941 |
+
BLOCK_SIZE_Q: tl.constexpr, # q block size
|
942 |
+
BLOCK_SIZE_K: tl.constexpr, # k block size
|
943 |
+
BLOCK_SIZE_D: tl.constexpr,
|
944 |
+
):
|
945 |
+
qk_scale = sm_scale * 1.44269504
|
946 |
+
# get batch id and head id
|
947 |
+
pid_bkh = tl.program_id(0)
|
948 |
+
pid_b = pid_bkh // NUM_KV_HEADS
|
949 |
+
pid_kh = pid_bkh % NUM_KV_HEADS
|
950 |
+
pid_q = tl.program_id(1)
|
951 |
+
pid_k = tl.program_id(2)
|
952 |
+
# get q k start and len after rmpad
|
953 |
+
q_start = tl.load(cu_seqlens_q + pid_b)
|
954 |
+
q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
|
955 |
+
k_start = tl.load(cu_seqlens_k + pid_b)
|
956 |
+
k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
|
957 |
+
if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len:
|
958 |
+
return
|
959 |
+
# init k pointer and load k
|
960 |
+
k_ptrs = tl.make_block_ptr(
|
961 |
+
base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
|
962 |
+
shape=(HEAD_DIM, k_len),
|
963 |
+
strides=(stride_kd, stride_kn),
|
964 |
+
offsets=(0, pid_k * BLOCK_SIZE_K),
|
965 |
+
block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
|
966 |
+
order=(0, 1),
|
967 |
+
)
|
968 |
+
k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
|
969 |
+
# offsets
|
970 |
+
off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q
|
971 |
+
off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
|
972 |
+
causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :]
|
973 |
+
# init score
|
974 |
+
s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
|
975 |
+
# loop over gqa heads
|
976 |
+
for h in range(NUM_SHARE_Q_HEADS):
|
977 |
+
pid_h = pid_kh * NUM_SHARE_Q_HEADS + h
|
978 |
+
q_ptrs = tl.make_block_ptr(
|
979 |
+
base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
|
980 |
+
shape=(q_len, HEAD_DIM),
|
981 |
+
strides=(stride_qn, stride_qd),
|
982 |
+
offsets=(pid_q * BLOCK_SIZE_Q, 0),
|
983 |
+
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
|
984 |
+
order=(1, 0),
|
985 |
+
)
|
986 |
+
lse_ptrs = tl.make_block_ptr(
|
987 |
+
base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
|
988 |
+
shape=(q_len, 1),
|
989 |
+
strides=(stride_ln, stride_lh),
|
990 |
+
offsets=(pid_q * BLOCK_SIZE_Q, 0),
|
991 |
+
block_shape=(BLOCK_SIZE_Q, 1),
|
992 |
+
order=(0, 1),
|
993 |
+
)
|
994 |
+
# load q and lse
|
995 |
+
q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
|
996 |
+
lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
|
997 |
+
# compute qk
|
998 |
+
qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
|
999 |
+
qk += tl.dot(q, k) * qk_scale
|
1000 |
+
# compute score
|
1001 |
+
s += tl.where(causal_mask, tl.exp2(qk - lse), 0)
|
1002 |
+
# save output
|
1003 |
+
s_ptrs = tl.make_block_ptr(
|
1004 |
+
base=s_ptr + pid_kh * stride_sh + q_start * stride_sq,
|
1005 |
+
shape=(q_len, k_len),
|
1006 |
+
strides=(stride_sq, stride_sk),
|
1007 |
+
offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K),
|
1008 |
+
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K),
|
1009 |
+
order=(1, 0),
|
1010 |
+
)
|
1011 |
+
tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1))
|
1012 |
+
|
1013 |
+
|
1014 |
+
def _get_attention_score(
|
1015 |
+
q: torch.Tensor, # [total_query_len, num_q_heads, head_dim]
|
1016 |
+
k: torch.Tensor, # [total_key_len, num_k_heads, head_dim]
|
1017 |
+
lse: torch.Tensor, # [num_q_heads, total_query_len]
|
1018 |
+
kernel_size: int,
|
1019 |
+
kernel_stride: int,
|
1020 |
+
cu_seqlens_q: torch.Tensor,
|
1021 |
+
cu_seqlens_k: torch.Tensor,
|
1022 |
+
max_seqlen_q: int,
|
1023 |
+
max_seqlen_k: int,
|
1024 |
+
sm_scale: float,
|
1025 |
+
) -> torch.Tensor:
|
1026 |
+
# dtype check
|
1027 |
+
assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
|
1028 |
+
assert q.dtype == k.dtype
|
1029 |
+
assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
|
1030 |
+
assert (
|
1031 |
+
lse.dtype == torch.float32
|
1032 |
+
) # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale)))
|
1033 |
+
# shape
|
1034 |
+
q_len, num_q_heads, head_dim = q.shape
|
1035 |
+
k_len, num_k_heads, head_dim = k.shape
|
1036 |
+
batch_size = cu_seqlens_q.shape[0] - 1
|
1037 |
+
assert q_len > k_len
|
1038 |
+
if sm_scale is None:
|
1039 |
+
sm_scale = 1 / math.sqrt(head_dim)
|
1040 |
+
# gqa
|
1041 |
+
assert num_q_heads % num_k_heads == 0
|
1042 |
+
num_share_q_heads = num_q_heads // num_k_heads
|
1043 |
+
# init score
|
1044 |
+
score = torch.zeros(
|
1045 |
+
num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device
|
1046 |
+
)
|
1047 |
+
# launch kernel
|
1048 |
+
grid = lambda META: (
|
1049 |
+
batch_size * num_k_heads,
|
1050 |
+
triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
|
1051 |
+
triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
|
1052 |
+
)
|
1053 |
+
BLOCK_SIZE_Q = 128
|
1054 |
+
BLOCK_SIZE_K = 128
|
1055 |
+
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
|
1056 |
+
score_kernel[grid](
|
1057 |
+
q,
|
1058 |
+
k,
|
1059 |
+
lse,
|
1060 |
+
score,
|
1061 |
+
kernel_size,
|
1062 |
+
kernel_stride,
|
1063 |
+
cu_seqlens_q,
|
1064 |
+
cu_seqlens_k,
|
1065 |
+
num_k_heads,
|
1066 |
+
num_share_q_heads,
|
1067 |
+
head_dim,
|
1068 |
+
sm_scale,
|
1069 |
+
q.stride(0),
|
1070 |
+
q.stride(1),
|
1071 |
+
q.stride(2),
|
1072 |
+
k.stride(0),
|
1073 |
+
k.stride(1),
|
1074 |
+
k.stride(2),
|
1075 |
+
lse.stride(0),
|
1076 |
+
lse.stride(1),
|
1077 |
+
score.stride(0),
|
1078 |
+
score.stride(1),
|
1079 |
+
score.stride(2),
|
1080 |
+
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
|
1081 |
+
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
1082 |
+
BLOCK_SIZE_D=BLOCK_SIZE_D,
|
1083 |
+
num_warps=8,
|
1084 |
+
num_stages=3,
|
1085 |
+
)
|
1086 |
+
return score
|
1087 |
+
|
1088 |
+
|
1089 |
+
@triton.jit
|
1090 |
+
def _transform_score_kernel(
|
1091 |
+
s_ptr, # score, shape: [num_heads, q_len, k_len]
|
1092 |
+
bs_ptr, # block wise score: [num_heads, q_len, num_k_block]
|
1093 |
+
offs,
|
1094 |
+
cu_seqlens_q,
|
1095 |
+
# shape
|
1096 |
+
num_heads,
|
1097 |
+
num_offs,
|
1098 |
+
max_k_len,
|
1099 |
+
max_blocks,
|
1100 |
+
pad_len,
|
1101 |
+
# kernel & block size
|
1102 |
+
block_size,
|
1103 |
+
block_stride, # block_size // kernel_stride
|
1104 |
+
init_blocks,
|
1105 |
+
local_blocks,
|
1106 |
+
# stride
|
1107 |
+
stride_sh,
|
1108 |
+
stride_sq,
|
1109 |
+
stride_sk,
|
1110 |
+
stride_bsh,
|
1111 |
+
stride_bsq,
|
1112 |
+
stride_bsk,
|
1113 |
+
BLOCK_SIZE_Q: tl.constexpr,
|
1114 |
+
BLOCK_SIZE_K: tl.constexpr,
|
1115 |
+
BLOCK_SIZE_O: tl.constexpr,
|
1116 |
+
):
|
1117 |
+
pid_bh = tl.program_id(0)
|
1118 |
+
pid_b = pid_bh // num_heads
|
1119 |
+
pid_h = pid_bh % num_heads
|
1120 |
+
pid_q = tl.program_id(1)
|
1121 |
+
pid_k = tl.program_id(2)
|
1122 |
+
q_start = tl.load(cu_seqlens_q + pid_b)
|
1123 |
+
q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
|
1124 |
+
k_start = pid_k * BLOCK_SIZE_K
|
1125 |
+
if pid_q * BLOCK_SIZE_Q >= q_len:
|
1126 |
+
return
|
1127 |
+
# load weight
|
1128 |
+
off_o = tl.arange(0, BLOCK_SIZE_O)
|
1129 |
+
w = tl.load(offs + off_o, mask=off_o < num_offs, other=0)
|
1130 |
+
# load score
|
1131 |
+
off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
|
1132 |
+
off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len
|
1133 |
+
off_k = off_k[None, :] + off_o[:, None]
|
1134 |
+
s_ptrs = (
|
1135 |
+
s_ptr
|
1136 |
+
+ q_start * stride_sq
|
1137 |
+
+ pid_h * stride_sh
|
1138 |
+
+ off_q[:, None, None] * stride_sq
|
1139 |
+
+ off_k[None, :, :] * stride_sk
|
1140 |
+
)
|
1141 |
+
# weighted sum, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK]
|
1142 |
+
s = tl.load(
|
1143 |
+
s_ptrs,
|
1144 |
+
mask=(off_q < q_len)[:, None, None] & (off_k >= 0) & (off_k < max_k_len),
|
1145 |
+
other=0,
|
1146 |
+
)
|
1147 |
+
s = s * w[None, :, None]
|
1148 |
+
s = tl.max(s, axis=1)
|
1149 |
+
# init mask and local mask
|
1150 |
+
off_bq = off_q // block_size
|
1151 |
+
off_bk = tl.arange(0, BLOCK_SIZE_K)
|
1152 |
+
|
1153 |
+
s = tl.where(
|
1154 |
+
# For local blocks: set to negative infinity (exclude from topk)
|
1155 |
+
(off_bq[:, None] >= (off_bk + k_start)[None, :]) & (off_bq[:, None] < (off_bk + k_start)[None, :] + local_blocks),
|
1156 |
+
float("-inf"),
|
1157 |
+
s,
|
1158 |
+
)
|
1159 |
+
|
1160 |
+
# Keep the original conditions for init_blocks and query location as infinity
|
1161 |
+
s = tl.where(
|
1162 |
+
(off_bk[None, :] < init_blocks - k_start)
|
1163 |
+
# Force blocks where the query is located to have infinite score (always include in topk)
|
1164 |
+
| (off_bq[:, None] == (off_bk + k_start)[None, :]),
|
1165 |
+
float("inf"),
|
1166 |
+
s,
|
1167 |
+
)
|
1168 |
+
# store block wise score
|
1169 |
+
bs_ptrs = (
|
1170 |
+
bs_ptr
|
1171 |
+
+ q_start * stride_bsq
|
1172 |
+
+ k_start * stride_bsk
|
1173 |
+
+ pid_h * stride_bsh
|
1174 |
+
+ off_q[:, None] * stride_bsq
|
1175 |
+
+ off_bk[None, :] * stride_bsk
|
1176 |
+
)
|
1177 |
+
tl.store(
|
1178 |
+
bs_ptrs,
|
1179 |
+
s,
|
1180 |
+
mask=(off_q < q_len)[:, None] & (off_bk < max_blocks - k_start)[None, :],
|
1181 |
+
)
|
1182 |
+
|
1183 |
+
|
1184 |
+
def transform_score(
|
1185 |
+
score: torch.Tensor,
|
1186 |
+
kernel_size: int,
|
1187 |
+
kernel_stride: int,
|
1188 |
+
block_size: int,
|
1189 |
+
cu_seqlens_q: torch.Tensor,
|
1190 |
+
cu_seqlens_k: torch.Tensor,
|
1191 |
+
max_seqlen_q: int,
|
1192 |
+
max_seqlen_k: int,
|
1193 |
+
init_blocks: int = 1,
|
1194 |
+
local_blocks: int = 2,
|
1195 |
+
) -> torch.Tensor:
|
1196 |
+
num_k_heads, total_query_len, max_key_len = score.shape
|
1197 |
+
batch_size = cu_seqlens_q.shape[0] - 1
|
1198 |
+
pad_len = kernel_size // kernel_stride - 1
|
1199 |
+
max_blocks = math.ceil(max_seqlen_q / block_size)
|
1200 |
+
block_score = torch.zeros(
|
1201 |
+
num_k_heads,
|
1202 |
+
total_query_len,
|
1203 |
+
max_blocks,
|
1204 |
+
dtype=torch.float32,
|
1205 |
+
device=score.device,
|
1206 |
+
)
|
1207 |
+
offs = (
|
1208 |
+
torch.arange(kernel_size // kernel_stride, device=score.device)[:, None]
|
1209 |
+
+ torch.arange(block_size // kernel_stride, device=score.device)[None, :]
|
1210 |
+
).view(-1)
|
1211 |
+
offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max())
|
1212 |
+
num_offs = int(offs.shape[0])
|
1213 |
+
BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks))
|
1214 |
+
BLOCK_SIZE_O = triton.next_power_of_2(num_offs)
|
1215 |
+
BLOCK_SIZE_Q = 8
|
1216 |
+
grid = (
|
1217 |
+
num_k_heads * batch_size,
|
1218 |
+
triton.cdiv(total_query_len, BLOCK_SIZE_Q),
|
1219 |
+
triton.cdiv(max_blocks, BLOCK_SIZE_K),
|
1220 |
+
)
|
1221 |
+
_transform_score_kernel[grid](
|
1222 |
+
score,
|
1223 |
+
block_score,
|
1224 |
+
torch.ones_like(offs, dtype=offs.dtype,device=offs.device), #! 为了max 就不用wieght了
|
1225 |
+
cu_seqlens_q,
|
1226 |
+
num_k_heads,
|
1227 |
+
offs.shape[0],
|
1228 |
+
max_key_len,
|
1229 |
+
max_blocks,
|
1230 |
+
pad_len,
|
1231 |
+
block_size,
|
1232 |
+
block_size // kernel_stride,
|
1233 |
+
init_blocks,
|
1234 |
+
local_blocks,
|
1235 |
+
score.stride(0),
|
1236 |
+
score.stride(1),
|
1237 |
+
score.stride(2),
|
1238 |
+
block_score.stride(0),
|
1239 |
+
block_score.stride(1),
|
1240 |
+
block_score.stride(2),
|
1241 |
+
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
|
1242 |
+
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
1243 |
+
BLOCK_SIZE_O=BLOCK_SIZE_O,
|
1244 |
+
num_warps=8,
|
1245 |
+
num_stages=3,
|
1246 |
+
)
|
1247 |
+
return block_score
|
1248 |
+
|
1249 |
+
|
1250 |
+
def compressed_attention(
|
1251 |
+
q: torch.Tensor,
|
1252 |
+
k: torch.Tensor,
|
1253 |
+
v: torch.Tensor,
|
1254 |
+
kernel_size: int,
|
1255 |
+
kernel_stride: int,
|
1256 |
+
block_size: int,
|
1257 |
+
topk: int,
|
1258 |
+
cu_seqlens_q: torch.Tensor,
|
1259 |
+
cu_seqlens_k: torch.Tensor,
|
1260 |
+
max_seqlen_q: int,
|
1261 |
+
max_seqlen_k: int,
|
1262 |
+
sm_scale: float = None,
|
1263 |
+
init_blocks: int = 1,
|
1264 |
+
local_blocks: int = 2,
|
1265 |
+
parallel_topk_compute: Union[str, bool] = "auto",
|
1266 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
1267 |
+
"""Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention.
|
1268 |
+
|
1269 |
+
Args:
|
1270 |
+
q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
|
1271 |
+
k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
|
1272 |
+
v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
|
1273 |
+
kernel_size (int): kernel size in compress_key_value
|
1274 |
+
kernel_stride (int): stride of compress_key_value
|
1275 |
+
block_size (int): key value block size for topk sparse attention.
|
1276 |
+
topk (int): number of blocks for each query.
|
1277 |
+
cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
|
1278 |
+
cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
|
1279 |
+
max_seqlen_q (int): max q len of the batch.
|
1280 |
+
max_seqlen_k (int): max k len of the batch.
|
1281 |
+
sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
|
1282 |
+
init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
|
1283 |
+
local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
|
1284 |
+
parallel_topk_compute (str, optional): Only set it to False when the sequence length is too long. This can avoid a current bug.
|
1285 |
+
We'll fix this issue later. Defaults to auto, it will be set to False when the sequence length is greater than 32k and True otherwise.
|
1286 |
+
|
1287 |
+
Returns:
|
1288 |
+
Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
|
1289 |
+
"""
|
1290 |
+
if max_seqlen_q is None:
|
1291 |
+
max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
|
1292 |
+
if max_seqlen_k is None:
|
1293 |
+
max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()
|
1294 |
+
attn_output, lse = CompressedAttention.apply(
|
1295 |
+
q,
|
1296 |
+
k,
|
1297 |
+
v,
|
1298 |
+
kernel_size,
|
1299 |
+
kernel_stride,
|
1300 |
+
cu_seqlens_q,
|
1301 |
+
cu_seqlens_k,
|
1302 |
+
max_seqlen_q,
|
1303 |
+
max_seqlen_k,
|
1304 |
+
sm_scale,
|
1305 |
+
)
|
1306 |
+
|
1307 |
+
# do not select topk index
|
1308 |
+
if topk <= 0:
|
1309 |
+
warnings.warn("topk <= 0, returned topk_idx will be None")
|
1310 |
+
return attn_output, None
|
1311 |
+
|
1312 |
+
assert topk >= init_blocks #+ local_blocks
|
1313 |
+
with torch.no_grad():
|
1314 |
+
num_k_heads, num_q_heads = k.shape[1], q.shape[1]
|
1315 |
+
num_shared_q_heads = num_q_heads // num_k_heads
|
1316 |
+
batch_size = cu_seqlens_q.shape[0] - 1
|
1317 |
+
q_idx = torch.cat(
|
1318 |
+
[
|
1319 |
+
torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device)
|
1320 |
+
for i in range(batch_size)
|
1321 |
+
],
|
1322 |
+
dim=0,
|
1323 |
+
)
|
1324 |
+
q_idx = q_idx // block_size
|
1325 |
+
# whether to use parallel version
|
1326 |
+
if parallel_topk_compute == "auto":
|
1327 |
+
parallel_topk_compute = cu_seqlens_q[-1] <= 32768
|
1328 |
+
# parallel version
|
1329 |
+
if parallel_topk_compute:
|
1330 |
+
# recompute score
|
1331 |
+
score = _get_attention_score(
|
1332 |
+
q,
|
1333 |
+
k,
|
1334 |
+
lse,
|
1335 |
+
kernel_size,
|
1336 |
+
kernel_stride,
|
1337 |
+
cu_seqlens_q,
|
1338 |
+
cu_seqlens_k,
|
1339 |
+
max_seqlen_q,
|
1340 |
+
max_seqlen_k,
|
1341 |
+
sm_scale,
|
1342 |
+
)
|
1343 |
+
# transform score to block-wise score
|
1344 |
+
score = transform_score(
|
1345 |
+
score,
|
1346 |
+
kernel_size,
|
1347 |
+
kernel_stride,
|
1348 |
+
block_size,
|
1349 |
+
cu_seqlens_q,
|
1350 |
+
cu_seqlens_k,
|
1351 |
+
max_seqlen_q,
|
1352 |
+
max_seqlen_k,
|
1353 |
+
init_blocks,
|
1354 |
+
local_blocks,
|
1355 |
+
)
|
1356 |
+
# get topk
|
1357 |
+
topk = min(topk, score.shape[-1])
|
1358 |
+
topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
|
1359 |
+
# print(cu_seqlens_q)
|
1360 |
+
# breakpoint()
|
1361 |
+
topk_idx[topk_idx >= q_idx[None, :, None]] = -1
|
1362 |
+
topk_idx = topk_idx.to(torch.int32)
|
1363 |
+
# non parallel version, avoid some current bugs when sequence length is too long
|
1364 |
+
# FIXME: need to fix later
|
1365 |
+
else:
|
1366 |
+
topk_idx_list = []
|
1367 |
+
for h in range(num_k_heads):
|
1368 |
+
# recompute score
|
1369 |
+
score = _get_attention_score(
|
1370 |
+
q[:, h * num_shared_q_heads : (h + 1) * num_shared_q_heads],
|
1371 |
+
k[:, h : h + 1],
|
1372 |
+
lse[h * num_shared_q_heads : (h + 1) * num_shared_q_heads],
|
1373 |
+
kernel_size,
|
1374 |
+
kernel_stride,
|
1375 |
+
cu_seqlens_q,
|
1376 |
+
cu_seqlens_k,
|
1377 |
+
max_seqlen_q,
|
1378 |
+
max_seqlen_k,
|
1379 |
+
sm_scale,
|
1380 |
+
)
|
1381 |
+
# transform score to block-wise score
|
1382 |
+
score = transform_score(
|
1383 |
+
score,
|
1384 |
+
kernel_size,
|
1385 |
+
kernel_stride,
|
1386 |
+
block_size,
|
1387 |
+
cu_seqlens_q,
|
1388 |
+
cu_seqlens_k,
|
1389 |
+
max_seqlen_q,
|
1390 |
+
max_seqlen_k,
|
1391 |
+
init_blocks,
|
1392 |
+
local_blocks,
|
1393 |
+
)
|
1394 |
+
# get topk
|
1395 |
+
topk = min(topk, score.shape[-1])
|
1396 |
+
topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
|
1397 |
+
topk_idx[topk_idx >= q_idx[None, :, None]] = -1
|
1398 |
+
topk_idx = topk_idx.to(torch.int32)
|
1399 |
+
topk_idx_list.append(topk_idx)
|
1400 |
+
topk_idx = torch.cat(topk_idx_list, dim=0)
|
1401 |
+
return attn_output, topk_idx
|
config.json
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "/share_data/data7/fanshengda/mcp-agent/minicpm4_sft/mcp_summary/checkpoint-25000",
|
3 |
+
"architectures": [
|
4 |
+
"MiniCPMForCausalLM"
|
5 |
+
],
|
6 |
+
"attention_bias": false,
|
7 |
+
"attention_dropout": 0.0,
|
8 |
+
"auto_map": {
|
9 |
+
"AutoConfig": "configuration_minicpm.MiniCPMConfig",
|
10 |
+
"AutoModel": "modeling_minicpm.MiniCPMForCausalLM",
|
11 |
+
"AutoModelForCausalLM": "modeling_minicpm.MiniCPMForCausalLM",
|
12 |
+
"AutoModelForSeq2SeqLM": "modeling_minicpm.MiniCPMForCausalLM",
|
13 |
+
"AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification"
|
14 |
+
},
|
15 |
+
"bos_token_id": 1,
|
16 |
+
"dim_model_base": 256,
|
17 |
+
"eos_token_id": [
|
18 |
+
2,
|
19 |
+
73440
|
20 |
+
],
|
21 |
+
"hidden_act": "silu",
|
22 |
+
"hidden_size": 4096,
|
23 |
+
"initializer_range": 0.1,
|
24 |
+
"intermediate_size": 16384,
|
25 |
+
"max_position_embeddings": 32768,
|
26 |
+
"model_type": "minicpm",
|
27 |
+
"num_attention_heads": 32,
|
28 |
+
"num_hidden_layers": 32,
|
29 |
+
"num_key_value_heads": 2,
|
30 |
+
"pad_token_id": 2,
|
31 |
+
"pretraining_tp": 1,
|
32 |
+
"rms_norm_eps": 1e-06,
|
33 |
+
"rope_scaling": {
|
34 |
+
"long_factor": [
|
35 |
+
0.9977997200264581,
|
36 |
+
1.014658295992452,
|
37 |
+
1.0349680404997148,
|
38 |
+
1.059429246056193,
|
39 |
+
1.0888815016813513,
|
40 |
+
1.1243301355211495,
|
41 |
+
1.166977103606075,
|
42 |
+
1.2182568066927284,
|
43 |
+
1.2798772354275727,
|
44 |
+
1.3538666751582975,
|
45 |
+
1.4426259039919596,
|
46 |
+
1.5489853358570191,
|
47 |
+
1.6762658237220625,
|
48 |
+
1.8283407612492941,
|
49 |
+
2.0096956085876183,
|
50 |
+
2.225478927469756,
|
51 |
+
2.481536379650452,
|
52 |
+
2.784415934557119,
|
53 |
+
3.1413289096347365,
|
54 |
+
3.560047844772632,
|
55 |
+
4.048719380066383,
|
56 |
+
4.615569542115128,
|
57 |
+
5.2684819496549835,
|
58 |
+
6.014438591970396,
|
59 |
+
6.858830049237097,
|
60 |
+
7.804668263503327,
|
61 |
+
8.851768731513417,
|
62 |
+
9.99600492938444,
|
63 |
+
11.228766118181639,
|
64 |
+
12.536757560834843,
|
65 |
+
13.902257701387796,
|
66 |
+
15.303885189125953,
|
67 |
+
16.717837610115794,
|
68 |
+
18.119465097853947,
|
69 |
+
19.484965238406907,
|
70 |
+
20.792956681060105,
|
71 |
+
22.02571786985731,
|
72 |
+
23.16995406772833,
|
73 |
+
24.217054535738416,
|
74 |
+
25.16289275000465,
|
75 |
+
26.007284207271347,
|
76 |
+
26.753240849586767,
|
77 |
+
27.40615325712662,
|
78 |
+
27.973003419175363,
|
79 |
+
28.461674954469114,
|
80 |
+
28.880393889607006,
|
81 |
+
29.237306864684626,
|
82 |
+
29.540186419591297,
|
83 |
+
29.79624387177199,
|
84 |
+
30.01202719065413,
|
85 |
+
30.193382037992453,
|
86 |
+
30.34545697551969,
|
87 |
+
30.47273746338473,
|
88 |
+
30.579096895249787,
|
89 |
+
30.66785612408345,
|
90 |
+
30.741845563814174,
|
91 |
+
30.80346599254902,
|
92 |
+
30.85474569563567,
|
93 |
+
30.897392663720595,
|
94 |
+
30.932841297560394,
|
95 |
+
30.962293553185553,
|
96 |
+
30.986754758742034,
|
97 |
+
31.007064503249293,
|
98 |
+
31.02392307921529
|
99 |
+
],
|
100 |
+
"original_max_position_embeddings": 32768,
|
101 |
+
"rope_type": "longrope",
|
102 |
+
"short_factor": [
|
103 |
+
0.9977997200264581,
|
104 |
+
1.014658295992452,
|
105 |
+
1.0349680404997148,
|
106 |
+
1.059429246056193,
|
107 |
+
1.0888815016813513,
|
108 |
+
1.1243301355211495,
|
109 |
+
1.166977103606075,
|
110 |
+
1.2182568066927284,
|
111 |
+
1.2798772354275727,
|
112 |
+
1.3538666751582975,
|
113 |
+
1.4426259039919596,
|
114 |
+
1.5489853358570191,
|
115 |
+
1.6762658237220625,
|
116 |
+
1.8283407612492941,
|
117 |
+
2.0096956085876183,
|
118 |
+
2.225478927469756,
|
119 |
+
2.481536379650452,
|
120 |
+
2.784415934557119,
|
121 |
+
3.1413289096347365,
|
122 |
+
3.560047844772632,
|
123 |
+
4.048719380066383,
|
124 |
+
4.615569542115128,
|
125 |
+
5.2684819496549835,
|
126 |
+
6.014438591970396,
|
127 |
+
6.858830049237097,
|
128 |
+
7.804668263503327,
|
129 |
+
8.851768731513417,
|
130 |
+
9.99600492938444,
|
131 |
+
11.228766118181639,
|
132 |
+
12.536757560834843,
|
133 |
+
13.902257701387796,
|
134 |
+
15.303885189125953,
|
135 |
+
16.717837610115794,
|
136 |
+
18.119465097853947,
|
137 |
+
19.484965238406907,
|
138 |
+
20.792956681060105,
|
139 |
+
22.02571786985731,
|
140 |
+
23.16995406772833,
|
141 |
+
24.217054535738416,
|
142 |
+
25.16289275000465,
|
143 |
+
26.007284207271347,
|
144 |
+
26.753240849586767,
|
145 |
+
27.40615325712662,
|
146 |
+
27.973003419175363,
|
147 |
+
28.461674954469114,
|
148 |
+
28.880393889607006,
|
149 |
+
29.237306864684626,
|
150 |
+
29.540186419591297,
|
151 |
+
29.79624387177199,
|
152 |
+
30.01202719065413,
|
153 |
+
30.193382037992453,
|
154 |
+
30.34545697551969,
|
155 |
+
30.47273746338473,
|
156 |
+
30.579096895249787,
|
157 |
+
30.66785612408345,
|
158 |
+
30.741845563814174,
|
159 |
+
30.80346599254902,
|
160 |
+
30.85474569563567,
|
161 |
+
30.897392663720595,
|
162 |
+
30.932841297560394,
|
163 |
+
30.962293553185553,
|
164 |
+
30.986754758742034,
|
165 |
+
31.007064503249293,
|
166 |
+
31.02392307921529
|
167 |
+
]
|
168 |
+
},
|
169 |
+
"rope_theta": 10000.0,
|
170 |
+
"scale_depth": 1.4,
|
171 |
+
"scale_emb": 12,
|
172 |
+
"tie_word_embeddings": false,
|
173 |
+
"torch_dtype": "bfloat16",
|
174 |
+
"transformers_version": "4.49.0",
|
175 |
+
"use_cache": true,
|
176 |
+
"vocab_size": 73448
|
177 |
+
}
|
configuration_minicpm.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
""" MiniCPM model configuration"""
|
21 |
+
|
22 |
+
from transformers.configuration_utils import PretrainedConfig
|
23 |
+
from transformers.utils import logging
|
24 |
+
|
25 |
+
|
26 |
+
logger = logging.get_logger(__name__)
|
27 |
+
|
28 |
+
MINICPM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
29 |
+
|
30 |
+
|
31 |
+
class MiniCPMConfig(PretrainedConfig):
|
32 |
+
r"""
|
33 |
+
This is the configuration class to store the configuration of a [`MiniCPMModel`]. It is used to instantiate an MiniCPM
|
34 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
35 |
+
defaults will yield a similar configuration to that of the MiniCPM-7B.
|
36 |
+
|
37 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
38 |
+
documentation from [`PretrainedConfig`] for more information.
|
39 |
+
|
40 |
+
|
41 |
+
Args:
|
42 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
43 |
+
Vocabulary size of the MiniCPM model. Defines the number of different tokens that can be represented by the
|
44 |
+
`inputs_ids` passed when calling [`MiniCPMModel`]
|
45 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
46 |
+
Dimension of the hidden representations.
|
47 |
+
intermediate_size (`int`, *optional*, defaults to 11008):
|
48 |
+
Dimension of the MLP representations.
|
49 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
50 |
+
Number of hidden layers in the Transformer decoder.
|
51 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
52 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
53 |
+
num_key_value_heads (`int`, *optional*):
|
54 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
55 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
56 |
+
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
57 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
58 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
59 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
60 |
+
`num_attention_heads`.
|
61 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
62 |
+
The non-linear activation function (function or string) in the decoder.
|
63 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
64 |
+
The maximum sequence length that this model might ever be used with. MiniCPM 1 supports up to 2048 tokens,
|
65 |
+
MiniCPM 2 up to 4096, CodeMiniCPM up to 16384.
|
66 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
67 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
68 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
69 |
+
The epsilon used by the rms normalization layers.
|
70 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
71 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
72 |
+
relevant if `config.is_decoder=True`.
|
73 |
+
pad_token_id (`int`, *optional*):
|
74 |
+
Padding token id.
|
75 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
76 |
+
Beginning of stream token id.
|
77 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
78 |
+
End of stream token id.
|
79 |
+
pretraining_tp (`int`, *optional*, defaults to 1):
|
80 |
+
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
81 |
+
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
|
82 |
+
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
|
83 |
+
issue](https://github.com/pytorch/pytorch/issues/76232).
|
84 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
85 |
+
Whether to tie weight embeddings
|
86 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
87 |
+
The base period of the RoPE embeddings.
|
88 |
+
rope_scaling (`Dict`, *optional*):
|
89 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
90 |
+
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
91 |
+
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
92 |
+
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
93 |
+
these scaling strategies behave:
|
94 |
+
https://www.reddit.com/r/LocalMiniCPM/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
95 |
+
experimental feature, subject to breaking API changes in future versions.
|
96 |
+
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
97 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
98 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
99 |
+
The dropout ratio for the attention probabilities.
|
100 |
+
|
101 |
+
```python
|
102 |
+
>>> from transformers import MiniCPMModel, MiniCPMConfig
|
103 |
+
|
104 |
+
>>> # Initializing a MiniCPM minicpm-7b style configuration
|
105 |
+
>>> configuration = MiniCPMConfig()
|
106 |
+
|
107 |
+
>>> # Initializing a model from the minicpm-7b style configuration
|
108 |
+
>>> model = MiniCPMModel(configuration)
|
109 |
+
|
110 |
+
>>> # Accessing the model configuration
|
111 |
+
>>> configuration = model.config
|
112 |
+
```"""
|
113 |
+
|
114 |
+
model_type = "minicpm"
|
115 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
116 |
+
|
117 |
+
def __init__(
|
118 |
+
self,
|
119 |
+
vocab_size=32000,
|
120 |
+
hidden_size=4096,
|
121 |
+
intermediate_size=11008,
|
122 |
+
num_hidden_layers=32,
|
123 |
+
num_attention_heads=32,
|
124 |
+
num_key_value_heads=None,
|
125 |
+
hidden_act="silu",
|
126 |
+
max_position_embeddings=2048,
|
127 |
+
initializer_range=0.02,
|
128 |
+
rms_norm_eps=1e-6,
|
129 |
+
use_cache=True,
|
130 |
+
pad_token_id=None,
|
131 |
+
bos_token_id=1,
|
132 |
+
eos_token_id=2,
|
133 |
+
pretraining_tp=1,
|
134 |
+
tie_word_embeddings=True,
|
135 |
+
rope_theta=10000.0,
|
136 |
+
rope_scaling=None,
|
137 |
+
attention_bias=False,
|
138 |
+
attention_dropout=0.0,
|
139 |
+
scale_emb=1,
|
140 |
+
dim_model_base=1,
|
141 |
+
scale_depth=1,
|
142 |
+
**kwargs,
|
143 |
+
):
|
144 |
+
self.vocab_size = vocab_size
|
145 |
+
self.max_position_embeddings = max_position_embeddings
|
146 |
+
self.hidden_size = hidden_size
|
147 |
+
self.intermediate_size = intermediate_size
|
148 |
+
self.num_hidden_layers = num_hidden_layers
|
149 |
+
self.num_attention_heads = num_attention_heads
|
150 |
+
|
151 |
+
# for backward compatibility
|
152 |
+
if num_key_value_heads is None:
|
153 |
+
num_key_value_heads = num_attention_heads
|
154 |
+
|
155 |
+
self.num_key_value_heads = num_key_value_heads
|
156 |
+
self.hidden_act = hidden_act
|
157 |
+
self.initializer_range = initializer_range
|
158 |
+
self.rms_norm_eps = rms_norm_eps
|
159 |
+
self.pretraining_tp = pretraining_tp
|
160 |
+
self.use_cache = use_cache
|
161 |
+
self.rope_theta = rope_theta
|
162 |
+
self.rope_scaling = rope_scaling
|
163 |
+
# self._rope_scaling_validation()
|
164 |
+
self.attention_bias = attention_bias
|
165 |
+
self.attention_dropout = attention_dropout
|
166 |
+
self.scale_emb = scale_emb
|
167 |
+
self.dim_model_base = dim_model_base
|
168 |
+
self.scale_depth = scale_depth
|
169 |
+
|
170 |
+
super().__init__(
|
171 |
+
pad_token_id=pad_token_id,
|
172 |
+
bos_token_id=bos_token_id,
|
173 |
+
eos_token_id=eos_token_id,
|
174 |
+
tie_word_embeddings=tie_word_embeddings,
|
175 |
+
**kwargs,
|
176 |
+
)
|
177 |
+
try:
|
178 |
+
import flash_attn
|
179 |
+
self._attn_implementation = "flash_attention_2"
|
180 |
+
except:
|
181 |
+
pass
|
182 |
+
|
183 |
+
def _rope_scaling_validation(self):
|
184 |
+
"""
|
185 |
+
Validate the `rope_scaling` configuration.
|
186 |
+
"""
|
187 |
+
if self.rope_scaling is None:
|
188 |
+
return
|
189 |
+
|
190 |
+
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
191 |
+
raise ValueError(
|
192 |
+
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
|
193 |
+
f"got {self.rope_scaling}"
|
194 |
+
)
|
195 |
+
rope_scaling_type = self.rope_scaling.get("type", None)
|
196 |
+
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
197 |
+
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
198 |
+
raise ValueError(
|
199 |
+
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
200 |
+
)
|
201 |
+
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
202 |
+
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
generation_config.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token_id": 1,
|
3 |
+
"do_sample": true,
|
4 |
+
"eos_token_id": [
|
5 |
+
2,
|
6 |
+
73440
|
7 |
+
],
|
8 |
+
"pad_token_id": 2,
|
9 |
+
"temperature": 0.8,
|
10 |
+
"top_p": 0.8,
|
11 |
+
"transformers_version": "4.49.0"
|
12 |
+
}
|
model-00001-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8889121190f1249d223ba24a5cbf50c959a6c4cc2456d5cf3448e5463ddce882
|
3 |
+
size 3990806216
|
model-00002-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e513ad5fb6f8180cceb278d5ac2ad3d677694cbf84b725187e047774677c50a1
|
3 |
+
size 3926008088
|
model-00003-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:059c3fa22026b886ad5169ed954f7e336320b605ab8973a6ea16308f7a61a597
|
3 |
+
size 3926008120
|
model-00004-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:548dc62dd17b6ae7b1d328a35a8d7ecb89a20e5dcbfc83d20c060d1cd45ddeba
|
3 |
+
size 3926033016
|
model-00005-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cbc464a995b4d5ad7188c965862202621af0b5c25e67d7078c220dc4a1e46d06
|
3 |
+
size 601686144
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 16370507776
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"lm_head.weight": "model-00005-of-00005.safetensors",
|
7 |
+
"model.embed_tokens.weight": "model-00001-of-00005.safetensors",
|
8 |
+
"model.layers.0.input_layernorm.weight": "model-00001-of-00005.safetensors",
|
9 |
+
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00005.safetensors",
|
10 |
+
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00005.safetensors",
|
11 |
+
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00005.safetensors",
|
12 |
+
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
|
13 |
+
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
|
14 |
+
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
|
15 |
+
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
|
16 |
+
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
|
17 |
+
"model.layers.1.input_layernorm.weight": "model-00001-of-00005.safetensors",
|
18 |
+
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00005.safetensors",
|
19 |
+
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00005.safetensors",
|
20 |
+
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00005.safetensors",
|
21 |
+
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
|
22 |
+
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
|
23 |
+
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
|
24 |
+
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
|
25 |
+
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
|
26 |
+
"model.layers.10.input_layernorm.weight": "model-00002-of-00005.safetensors",
|
27 |
+
"model.layers.10.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
|
28 |
+
"model.layers.10.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
|
29 |
+
"model.layers.10.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
|
30 |
+
"model.layers.10.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
|
31 |
+
"model.layers.10.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
|
32 |
+
"model.layers.10.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
|
33 |
+
"model.layers.10.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
|
34 |
+
"model.layers.10.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
|
35 |
+
"model.layers.11.input_layernorm.weight": "model-00002-of-00005.safetensors",
|
36 |
+
"model.layers.11.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
|
37 |
+
"model.layers.11.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
|
38 |
+
"model.layers.11.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
|
39 |
+
"model.layers.11.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
|
40 |
+
"model.layers.11.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
|
41 |
+
"model.layers.11.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
|
42 |
+
"model.layers.11.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
|
43 |
+
"model.layers.11.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
|
44 |
+
"model.layers.12.input_layernorm.weight": "model-00002-of-00005.safetensors",
|
45 |
+
"model.layers.12.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
|
46 |
+
"model.layers.12.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
|
47 |
+
"model.layers.12.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
|
48 |
+
"model.layers.12.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
|
49 |
+
"model.layers.12.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
|
50 |
+
"model.layers.12.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
|
51 |
+
"model.layers.12.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
|
52 |
+
"model.layers.12.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
|
53 |
+
"model.layers.13.input_layernorm.weight": "model-00002-of-00005.safetensors",
|
54 |
+
"model.layers.13.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
|
55 |
+
"model.layers.13.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
|
56 |
+
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
|
57 |
+
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
|
58 |
+
"model.layers.13.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
|
59 |
+
"model.layers.13.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
|
60 |
+
"model.layers.13.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
|
61 |
+
"model.layers.13.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
|
62 |
+
"model.layers.14.input_layernorm.weight": "model-00002-of-00005.safetensors",
|
63 |
+
"model.layers.14.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
|
64 |
+
"model.layers.14.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
|
65 |
+
"model.layers.14.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
|
66 |
+
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
|
67 |
+
"model.layers.14.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
|
68 |
+
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
|
69 |
+
"model.layers.14.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
|
70 |
+
"model.layers.14.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
|
71 |
+
"model.layers.15.input_layernorm.weight": "model-00003-of-00005.safetensors",
|
72 |
+
"model.layers.15.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
|
73 |
+
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
|
74 |
+
"model.layers.15.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
|
75 |
+
"model.layers.15.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
|
76 |
+
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
|
77 |
+
"model.layers.15.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
|
78 |
+
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
|
79 |
+
"model.layers.15.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
|
80 |
+
"model.layers.16.input_layernorm.weight": "model-00003-of-00005.safetensors",
|
81 |
+
"model.layers.16.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
|
82 |
+
"model.layers.16.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
|
83 |
+
"model.layers.16.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
|
84 |
+
"model.layers.16.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
|
85 |
+
"model.layers.16.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
|
86 |
+
"model.layers.16.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
|
87 |
+
"model.layers.16.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
|
88 |
+
"model.layers.16.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
|
89 |
+
"model.layers.17.input_layernorm.weight": "model-00003-of-00005.safetensors",
|
90 |
+
"model.layers.17.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
|
91 |
+
"model.layers.17.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
|
92 |
+
"model.layers.17.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
|
93 |
+
"model.layers.17.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
|
94 |
+
"model.layers.17.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
|
95 |
+
"model.layers.17.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
|
96 |
+
"model.layers.17.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
|
97 |
+
"model.layers.17.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
|
98 |
+
"model.layers.18.input_layernorm.weight": "model-00003-of-00005.safetensors",
|
99 |
+
"model.layers.18.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
|
100 |
+
"model.layers.18.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
|
101 |
+
"model.layers.18.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
|
102 |
+
"model.layers.18.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
|
103 |
+
"model.layers.18.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
|
104 |
+
"model.layers.18.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
|
105 |
+
"model.layers.18.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
|
106 |
+
"model.layers.18.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
|
107 |
+
"model.layers.19.input_layernorm.weight": "model-00003-of-00005.safetensors",
|
108 |
+
"model.layers.19.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
|
109 |
+
"model.layers.19.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
|
110 |
+
"model.layers.19.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
|
111 |
+
"model.layers.19.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
|
112 |
+
"model.layers.19.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
|
113 |
+
"model.layers.19.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
|
114 |
+
"model.layers.19.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
|
115 |
+
"model.layers.19.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
|
116 |
+
"model.layers.2.input_layernorm.weight": "model-00001-of-00005.safetensors",
|
117 |
+
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00005.safetensors",
|
118 |
+
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00005.safetensors",
|
119 |
+
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00005.safetensors",
|
120 |
+
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
|
121 |
+
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
|
122 |
+
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
|
123 |
+
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
|
124 |
+
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
|
125 |
+
"model.layers.20.input_layernorm.weight": "model-00003-of-00005.safetensors",
|
126 |
+
"model.layers.20.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
|
127 |
+
"model.layers.20.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
|
128 |
+
"model.layers.20.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
|
129 |
+
"model.layers.20.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
|
130 |
+
"model.layers.20.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
|
131 |
+
"model.layers.20.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
|
132 |
+
"model.layers.20.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
|
133 |
+
"model.layers.20.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
|
134 |
+
"model.layers.21.input_layernorm.weight": "model-00003-of-00005.safetensors",
|
135 |
+
"model.layers.21.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
|
136 |
+
"model.layers.21.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
|
137 |
+
"model.layers.21.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
|
138 |
+
"model.layers.21.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
|
139 |
+
"model.layers.21.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
|
140 |
+
"model.layers.21.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
|
141 |
+
"model.layers.21.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
|
142 |
+
"model.layers.21.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
|
143 |
+
"model.layers.22.input_layernorm.weight": "model-00003-of-00005.safetensors",
|
144 |
+
"model.layers.22.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
|
145 |
+
"model.layers.22.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
|
146 |
+
"model.layers.22.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
|
147 |
+
"model.layers.22.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
|
148 |
+
"model.layers.22.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
|
149 |
+
"model.layers.22.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
|
150 |
+
"model.layers.22.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
|
151 |
+
"model.layers.22.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
|
152 |
+
"model.layers.23.input_layernorm.weight": "model-00004-of-00005.safetensors",
|
153 |
+
"model.layers.23.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
|
154 |
+
"model.layers.23.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
|
155 |
+
"model.layers.23.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
|
156 |
+
"model.layers.23.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
|
157 |
+
"model.layers.23.self_attn.k_proj.weight": "model-00003-of-00005.safetensors",
|
158 |
+
"model.layers.23.self_attn.o_proj.weight": "model-00003-of-00005.safetensors",
|
159 |
+
"model.layers.23.self_attn.q_proj.weight": "model-00003-of-00005.safetensors",
|
160 |
+
"model.layers.23.self_attn.v_proj.weight": "model-00003-of-00005.safetensors",
|
161 |
+
"model.layers.24.input_layernorm.weight": "model-00004-of-00005.safetensors",
|
162 |
+
"model.layers.24.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
|
163 |
+
"model.layers.24.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
|
164 |
+
"model.layers.24.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
|
165 |
+
"model.layers.24.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
|
166 |
+
"model.layers.24.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
|
167 |
+
"model.layers.24.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
|
168 |
+
"model.layers.24.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
|
169 |
+
"model.layers.24.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
|
170 |
+
"model.layers.25.input_layernorm.weight": "model-00004-of-00005.safetensors",
|
171 |
+
"model.layers.25.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
|
172 |
+
"model.layers.25.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
|
173 |
+
"model.layers.25.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
|
174 |
+
"model.layers.25.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
|
175 |
+
"model.layers.25.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
|
176 |
+
"model.layers.25.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
|
177 |
+
"model.layers.25.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
|
178 |
+
"model.layers.25.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
|
179 |
+
"model.layers.26.input_layernorm.weight": "model-00004-of-00005.safetensors",
|
180 |
+
"model.layers.26.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
|
181 |
+
"model.layers.26.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
|
182 |
+
"model.layers.26.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
|
183 |
+
"model.layers.26.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
|
184 |
+
"model.layers.26.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
|
185 |
+
"model.layers.26.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
|
186 |
+
"model.layers.26.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
|
187 |
+
"model.layers.26.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
|
188 |
+
"model.layers.27.input_layernorm.weight": "model-00004-of-00005.safetensors",
|
189 |
+
"model.layers.27.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
|
190 |
+
"model.layers.27.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
|
191 |
+
"model.layers.27.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
|
192 |
+
"model.layers.27.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
|
193 |
+
"model.layers.27.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
|
194 |
+
"model.layers.27.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
|
195 |
+
"model.layers.27.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
|
196 |
+
"model.layers.27.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
|
197 |
+
"model.layers.28.input_layernorm.weight": "model-00004-of-00005.safetensors",
|
198 |
+
"model.layers.28.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
|
199 |
+
"model.layers.28.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
|
200 |
+
"model.layers.28.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
|
201 |
+
"model.layers.28.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
|
202 |
+
"model.layers.28.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
|
203 |
+
"model.layers.28.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
|
204 |
+
"model.layers.28.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
|
205 |
+
"model.layers.28.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
|
206 |
+
"model.layers.29.input_layernorm.weight": "model-00004-of-00005.safetensors",
|
207 |
+
"model.layers.29.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
|
208 |
+
"model.layers.29.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
|
209 |
+
"model.layers.29.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
|
210 |
+
"model.layers.29.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
|
211 |
+
"model.layers.29.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
|
212 |
+
"model.layers.29.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
|
213 |
+
"model.layers.29.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
|
214 |
+
"model.layers.29.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
|
215 |
+
"model.layers.3.input_layernorm.weight": "model-00001-of-00005.safetensors",
|
216 |
+
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00005.safetensors",
|
217 |
+
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00005.safetensors",
|
218 |
+
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00005.safetensors",
|
219 |
+
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
|
220 |
+
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
|
221 |
+
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
|
222 |
+
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
|
223 |
+
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
|
224 |
+
"model.layers.30.input_layernorm.weight": "model-00004-of-00005.safetensors",
|
225 |
+
"model.layers.30.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
|
226 |
+
"model.layers.30.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
|
227 |
+
"model.layers.30.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
|
228 |
+
"model.layers.30.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
|
229 |
+
"model.layers.30.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
|
230 |
+
"model.layers.30.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
|
231 |
+
"model.layers.30.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
|
232 |
+
"model.layers.30.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
|
233 |
+
"model.layers.31.input_layernorm.weight": "model-00004-of-00005.safetensors",
|
234 |
+
"model.layers.31.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
|
235 |
+
"model.layers.31.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
|
236 |
+
"model.layers.31.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
|
237 |
+
"model.layers.31.post_attention_layernorm.weight": "model-00004-of-00005.safetensors",
|
238 |
+
"model.layers.31.self_attn.k_proj.weight": "model-00004-of-00005.safetensors",
|
239 |
+
"model.layers.31.self_attn.o_proj.weight": "model-00004-of-00005.safetensors",
|
240 |
+
"model.layers.31.self_attn.q_proj.weight": "model-00004-of-00005.safetensors",
|
241 |
+
"model.layers.31.self_attn.v_proj.weight": "model-00004-of-00005.safetensors",
|
242 |
+
"model.layers.4.input_layernorm.weight": "model-00001-of-00005.safetensors",
|
243 |
+
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00005.safetensors",
|
244 |
+
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00005.safetensors",
|
245 |
+
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00005.safetensors",
|
246 |
+
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
|
247 |
+
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
|
248 |
+
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
|
249 |
+
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
|
250 |
+
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
|
251 |
+
"model.layers.5.input_layernorm.weight": "model-00001-of-00005.safetensors",
|
252 |
+
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00005.safetensors",
|
253 |
+
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00005.safetensors",
|
254 |
+
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00005.safetensors",
|
255 |
+
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
|
256 |
+
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
|
257 |
+
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
|
258 |
+
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
|
259 |
+
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
|
260 |
+
"model.layers.6.input_layernorm.weight": "model-00001-of-00005.safetensors",
|
261 |
+
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00005.safetensors",
|
262 |
+
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00005.safetensors",
|
263 |
+
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00005.safetensors",
|
264 |
+
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
|
265 |
+
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
|
266 |
+
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
|
267 |
+
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
|
268 |
+
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
|
269 |
+
"model.layers.7.input_layernorm.weight": "model-00002-of-00005.safetensors",
|
270 |
+
"model.layers.7.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
|
271 |
+
"model.layers.7.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
|
272 |
+
"model.layers.7.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
|
273 |
+
"model.layers.7.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
|
274 |
+
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00005.safetensors",
|
275 |
+
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00005.safetensors",
|
276 |
+
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00005.safetensors",
|
277 |
+
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00005.safetensors",
|
278 |
+
"model.layers.8.input_layernorm.weight": "model-00002-of-00005.safetensors",
|
279 |
+
"model.layers.8.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
|
280 |
+
"model.layers.8.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
|
281 |
+
"model.layers.8.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
|
282 |
+
"model.layers.8.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
|
283 |
+
"model.layers.8.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
|
284 |
+
"model.layers.8.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
|
285 |
+
"model.layers.8.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
|
286 |
+
"model.layers.8.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
|
287 |
+
"model.layers.9.input_layernorm.weight": "model-00002-of-00005.safetensors",
|
288 |
+
"model.layers.9.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
|
289 |
+
"model.layers.9.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
|
290 |
+
"model.layers.9.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
|
291 |
+
"model.layers.9.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
|
292 |
+
"model.layers.9.self_attn.k_proj.weight": "model-00002-of-00005.safetensors",
|
293 |
+
"model.layers.9.self_attn.o_proj.weight": "model-00002-of-00005.safetensors",
|
294 |
+
"model.layers.9.self_attn.q_proj.weight": "model-00002-of-00005.safetensors",
|
295 |
+
"model.layers.9.self_attn.v_proj.weight": "model-00002-of-00005.safetensors",
|
296 |
+
"model.norm.weight": "model-00004-of-00005.safetensors"
|
297 |
+
}
|
298 |
+
}
|
modeling_minicpm.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
special_tokens_map.json
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"additional_special_tokens": [
|
3 |
+
"<|im_end|>",
|
4 |
+
"<|im_start|>",
|
5 |
+
"<|tool_call|>",
|
6 |
+
"<|execute_start|>",
|
7 |
+
"<|execute_end|>",
|
8 |
+
"<|fim_prefix|>",
|
9 |
+
"<|fim_middle|>",
|
10 |
+
"<|fim_suffix|>"
|
11 |
+
],
|
12 |
+
"bos_token": {
|
13 |
+
"content": "<s>",
|
14 |
+
"lstrip": false,
|
15 |
+
"normalized": false,
|
16 |
+
"rstrip": false,
|
17 |
+
"single_word": false
|
18 |
+
},
|
19 |
+
"eos_token": {
|
20 |
+
"content": "<|im_end|>",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false
|
25 |
+
},
|
26 |
+
"pad_token": {
|
27 |
+
"content": "<|im_end|>",
|
28 |
+
"lstrip": false,
|
29 |
+
"normalized": false,
|
30 |
+
"rstrip": false,
|
31 |
+
"single_word": false
|
32 |
+
},
|
33 |
+
"unk_token": {
|
34 |
+
"content": "<unk>",
|
35 |
+
"lstrip": false,
|
36 |
+
"normalized": false,
|
37 |
+
"rstrip": false,
|
38 |
+
"single_word": false
|
39 |
+
}
|
40 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bb74d51116831c3bf65db812c553f94ab0c88dcf97a5bbb37e3504f6d359c530
|
3 |
+
size 1181204
|
tokenizer_config.json
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": true,
|
3 |
+
"add_eos_token": false,
|
4 |
+
"add_prefix_space": null,
|
5 |
+
"added_tokens_decoder": {
|
6 |
+
"0": {
|
7 |
+
"content": "<unk>",
|
8 |
+
"lstrip": false,
|
9 |
+
"normalized": false,
|
10 |
+
"rstrip": false,
|
11 |
+
"single_word": false,
|
12 |
+
"special": true
|
13 |
+
},
|
14 |
+
"1": {
|
15 |
+
"content": "<s>",
|
16 |
+
"lstrip": false,
|
17 |
+
"normalized": false,
|
18 |
+
"rstrip": false,
|
19 |
+
"single_word": false,
|
20 |
+
"special": true
|
21 |
+
},
|
22 |
+
"2": {
|
23 |
+
"content": "</s>",
|
24 |
+
"lstrip": false,
|
25 |
+
"normalized": false,
|
26 |
+
"rstrip": false,
|
27 |
+
"single_word": false,
|
28 |
+
"special": true
|
29 |
+
},
|
30 |
+
"73440": {
|
31 |
+
"content": "<|im_end|>",
|
32 |
+
"lstrip": false,
|
33 |
+
"normalized": false,
|
34 |
+
"rstrip": false,
|
35 |
+
"single_word": false,
|
36 |
+
"special": true
|
37 |
+
},
|
38 |
+
"73441": {
|
39 |
+
"content": "<|im_start|>",
|
40 |
+
"lstrip": false,
|
41 |
+
"normalized": false,
|
42 |
+
"rstrip": false,
|
43 |
+
"single_word": false,
|
44 |
+
"special": true
|
45 |
+
},
|
46 |
+
"73442": {
|
47 |
+
"content": "<|tool_call|>",
|
48 |
+
"lstrip": false,
|
49 |
+
"normalized": false,
|
50 |
+
"rstrip": false,
|
51 |
+
"single_word": false,
|
52 |
+
"special": true
|
53 |
+
},
|
54 |
+
"73443": {
|
55 |
+
"content": "<|execute_start|>",
|
56 |
+
"lstrip": false,
|
57 |
+
"normalized": false,
|
58 |
+
"rstrip": false,
|
59 |
+
"single_word": false,
|
60 |
+
"special": true
|
61 |
+
},
|
62 |
+
"73444": {
|
63 |
+
"content": "<|execute_end|>",
|
64 |
+
"lstrip": false,
|
65 |
+
"normalized": false,
|
66 |
+
"rstrip": false,
|
67 |
+
"single_word": false,
|
68 |
+
"special": true
|
69 |
+
},
|
70 |
+
"73445": {
|
71 |
+
"content": "<|fim_prefix|>",
|
72 |
+
"lstrip": false,
|
73 |
+
"normalized": false,
|
74 |
+
"rstrip": false,
|
75 |
+
"single_word": false,
|
76 |
+
"special": true
|
77 |
+
},
|
78 |
+
"73446": {
|
79 |
+
"content": "<|fim_middle|>",
|
80 |
+
"lstrip": false,
|
81 |
+
"normalized": false,
|
82 |
+
"rstrip": false,
|
83 |
+
"single_word": false,
|
84 |
+
"special": true
|
85 |
+
},
|
86 |
+
"73447": {
|
87 |
+
"content": "<|fim_suffix|>",
|
88 |
+
"lstrip": false,
|
89 |
+
"normalized": false,
|
90 |
+
"rstrip": false,
|
91 |
+
"single_word": false,
|
92 |
+
"special": true
|
93 |
+
}
|
94 |
+
},
|
95 |
+
"additional_special_tokens": [
|
96 |
+
"<|im_end|>",
|
97 |
+
"<|im_start|>",
|
98 |
+
"<|tool_call|>",
|
99 |
+
"<|execute_start|>",
|
100 |
+
"<|execute_end|>",
|
101 |
+
"<|fim_prefix|>",
|
102 |
+
"<|fim_middle|>",
|
103 |
+
"<|fim_suffix|>"
|
104 |
+
],
|
105 |
+
"bos_token": "<s>",
|
106 |
+
"chat_template": "{%- macro json_to_python_type(param_name, json_spec) %}\n{%- set basic_type_map = {\n 'string': 'str',\n 'number': 'float',\n 'integer': 'int',\n 'boolean': 'bool',\n 'null': 'None'\n} %}\n\n{%- if json_spec.enum %}\n {{- param_name|title }}\n{%- elif basic_type_map[json_spec.type] is defined %}\n {{- basic_type_map[json_spec.type] }}\n{%- elif json_spec.type == 'array' %}\n {{- 'List[' + json_to_python_type(param_name, json_spec['items']) + ']' }}\n{%- elif json_spec.type == 'object' %}\n {{- 'Dict[str, ' + json_to_python_type(param_name, json_spec.additionalProperties if json_spec.additionalProperties else 'Any') + ']' if not json_spec.properties else param_name|title }}\n{%- elif json_spec.type is iterable %}\n {{- 'Union[' }}\n {%- for t in json_spec.type %}\n {{- json_to_python_type(param_name, {'type': t}) }}\n {{- ', ' if not loop.last }}\n {%- endfor %}\n {{- ']' }}\n{%- else %}\n {{- 'Any' }}\n{%- endif %}\n{%- endmacro %}\n\n{%- macro object_to_fields(json_spec, field_indent) %}\n {%- set o_ns = namespace(f = caller()) %}\n {%- for param_name, param_fields in json_spec.properties|items %}\n {%- if param_fields.enum %}\n {{- '\\n\\nclass ' + param_name|title + '(Enum):\\n' }}\n {%- for enum_option in param_fields.enum %}\n {{- ' enum_' + loop.index0|string + ' = ' + enum_option|tojson + '\\n' }}\n {%- endfor %}\n {%- elif param_fields.type == 'object' and param_fields.properties %}\n {%- call object_to_fields(param_fields, ' ') %}\n {{- '\\n\\nclass ' + param_name|title + '(BaseModel):\\n' }}\n {%- endcall %}\n {%- elif param_fields.type == 'array' and param_fields['items'] and param_fields['items'].type == 'object' and param_fields['items'].properties %}\n {%- call object_to_fields(param_fields['items'], ' ') %}\n {{- '\\n\\nclass ' + param_name|title + '(BaseModel):\\n' }}\n {%- endcall %}\n {%- endif %}\n {%- set param_default = param_fields.default|tojson if param_fields.default is string else param_fields.default|string if param_fields.default is defined else 'None' %}\n {%- set o_ns.f = o_ns.f + field_indent + param_name + ': ' %}\n {%- set o_ns.f = o_ns.f + ('Optional[' + json_to_python_type(param_name, param_fields) + ']' if param_name not in json_spec.required else json_to_python_type(param_name, param_fields)) %}\n {%- if not param_fields.title and not param_fields.description and not param_fields.pattern %}\n {%- set o_ns.f = o_ns.f + (' = ' + param_default if param_name not in json_spec.required else '') %}\n {%- else %}\n {%- set o_ns.f = o_ns.f + (' = Field(...' if param_name in json_spec.required else ' = Field(' + param_default) %}\n {%- set o_ns.f = o_ns.f + (', description=' + param_fields.description|tojson if param_fields.description else '') %}\n {%- set o_ns.f = o_ns.f + (', regex=' + param_fields.pattern|tojson if param_fields.pattern else '') %}\n {%- set o_ns.f = o_ns.f + (', title=' + param_fields.title|tojson if param_fields.title else '') %}\n {%- set o_ns.f = o_ns.f + ')' %}\n {%- endif %}\n {%- set o_ns.f = o_ns.f + '\\n' %}\n {%- endfor %}\n {{- o_ns.f }}\n{%- endmacro %}\n\n{%- macro tool_parser(tools) %}\n{%- for tool in tools %}\n {%- if tool.type is not defined or tool.type == 'function' %}\n {%- if tool.function is defined %}\n {%- set tool = tool.function %}\n {%- endif %}\n {%- set tool_params = tool.parameters if tool.parameters is defined else none %}\n {%- call object_to_fields(tool_params, ' ') %}\n {{- '\\n\\ndef ' + tool.name + '(' }}\n {%- if tool_params %}\n {%- for param_name, param_fields in tool_params.properties|items %}\n {%- set param_default = param_fields.default|tojson if param_fields.default is string else param_fields.default|string if param_fields.default is defined else 'None' %}\n {{- ', ' if loop.index0 != 0 }}\n {{- param_name }}\n {{- '=' + param_default if param_name not in tool_params.required }}\n {%- endfor %}\n {%- endif %}\n {{- '):\\n \"\"\"' }}\n {{- tool.description }}\n {{- '\\n\\n Args:\\n' if tool_params else '\\n' }}\n {%- endcall %}\n {{- ' \"\"\"\\n' }}\n {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n\n{%- if messages[0]['role'] == 'system' %}\n {%- set loop_messages = messages[1:] %}\n {%- set system_message = messages[0]['content'] %}\n{%- else %}\n {%- set loop_messages = messages %}\n {%- set system_message = '' %}\n{%- endif %}\n{{- '<|im_start|>system\\n' + system_message if system_message or tools }}\n{%- if tools %}\n {{- '\\n# Functions\\nHere is a list of functions that you can invoke:\\n```python\\nfrom enum import Enum\\nfrom typing import List, Dict, Optional\\nfrom pydantic import BaseModel, Field\\n\\n' }}\n {{- tool_parser(tools) }}\n {{- \"\\n```\\n\\n# Function Call Rule and Output Format\\n- If the user's question can be answered without calling any function, please answer the user's question directly. In this situation, you should return your thought and answer the user's question directly.\\n- If the user cannot be answered without calling any function, and the user does not provide enough information to call functions, please ask the user for more information. In this situation, you should return your thought and ask the user for more information.\\n- If the user's question cannot be answered without calling any function, and the user has provided enough information to call functions to solve it, you should call the functions. In this situation, the assistant should return your thought and call the functions.\\n- Use default parameters unless the user has specified otherwise.\\n- You should answer in the following format:\\n\\n<|thought_start|>\\n{explain why the user's question can be answered without calling a function or why you should ask the user for more information or why you should call one or more functions and your plan to solve the user's question.}\\n<|thought_end|>\\n<|tool_call_start|>\\n```python\\nfunc1(params_name=params_value, params_name2=params_value2...)\\nfunc2(params)\\n```\\n<|tool_call_end|>\\n{answer the user's question directly or ask the user for more information}\" }}\n{%- endif %}\n{{- '<|im_end|>\\n' if system_message or tools }}\n{%- for message in loop_messages %}\n {%- set content = message.content %}\n {%- if message.role == 'assistant' and message.tool_calls %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {{- '<|thought_start|>\\n' + message.thought + '\\n<|thought_end|>\\n' if message.thought }}\n {{- '<|tool_call_start|>\\n```python\\n' }}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- tool_call.name + '(' }}\n {%- if tool_call.arguments is defined and tool_call.arguments|length > 0 %}\n {%- for param_name, param_value in tool_call.arguments|items %}\n {{- param_name + '=' + param_value|tojson }}\n {{- ',' if not loop.last }}\n {%- endfor %}\n {%- endif %}\n {{- ')\\n' }}\n {%- endfor %}\n {{- '```\\n<|tool_call_end|>\\n' }}\n {{- content if content and not content.startswith('<|tool_call_start|>') }}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == 'assistant' and message.thought %}\n {{- '<|im_start|>' + message.role + '\\n' + '<|thought_start|>\\n' + message.thought + '\\n<|thought_end|>\\n' + content + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endfor %}\n\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}",
|
107 |
+
"clean_up_tokenization_spaces": false,
|
108 |
+
"eos_token": "<|im_end|>",
|
109 |
+
"extra_special_tokens": {},
|
110 |
+
"legacy": true,
|
111 |
+
"model_max_length": 1000000000000000019884624838656,
|
112 |
+
"pad_token": "<|im_end|>",
|
113 |
+
"padding_side": "left",
|
114 |
+
"sp_model_kwargs": {},
|
115 |
+
"spaces_between_special_tokens": false,
|
116 |
+
"split_special_tokens": false,
|
117 |
+
"tokenizer_class": "LlamaTokenizer",
|
118 |
+
"unk_token": "<unk>",
|
119 |
+
"use_default_system_prompt": false
|
120 |
+
}
|