---
license: llama3
language:
- en
base_model:
- meta-llama/Meta-Llama-3-8B-Instruct
tags:
- custom_generate
---
# SepCache - Native Sparse Attention Cache

## Table of Contents

- [1. Abstract](#1-abstract)
- [2. Usage](#2-usage)
  - [2.1 Sample Base Model](#21-sample-base-model)
  - [2.2 Quick Start](#22-quick-start)
    - [2.2.1 Environment Setup](#221-environment-setup)
    - [2.2.2 A Simple Example](#222-a-simple-example)
    - [2.2.3 Frequently-Used Parameters](#223-frequently-used-parameters)
    - [2.2.4 Update Function](#224-update-function)
    - [2.2.5 Monkey Patch Demo](#225-monkey-patch-demo)
    - [2.2.6 Downstream Task Evaluation](#226-downstream-task-evaluation)
    - [2.2.7 The Detailed Signature of `generate` Function](#227-the-detailed-signature-of-generate-function)
- [3. Adaptation for Other Models](#3-adaptation-for-other-models)
  - [3.1 Method 1 - Monkey Patching](#31-method-1---monkey-patching)
  - [3.2 Method 2 - Direct Code Modification](#32-method-2---direct-code-modification)
  - [3.3 Important Note](#33-important-note)
- [4. Other Advanced Usage](#4-other-advanced-usage)
- [5. Citation](#5-citation)

---

## 1. Abstract
`SepCache` is a simple yet effective, native sparse attention `Cache` class proposed in the [`SepLLM` paper (ICML 2025)](https://icml.cc/virtual/2025/poster/45536), and it closely aligns with the semantic distribution of natural language. In the training phase, `SepLLM` condenses each segment's information into the KV of the separator token that divides it. In the inference phase, the corresponding `SepCache` only needs to store the KVs of initial tokens, separator tokens, and recent tokens for generation.

Notably, `SepCache` also delivers strong performance across many tasks in training-free scenarios. Moreover, `SepLLM` (or simply `SepCache`) is the **most suitable baseline method for sparse attention mechanisms and KV compression/management**, as it is the natively sparse attention mechanism that best aligns with the natural semantic distribution of language.

See more details and advanced usage at https://github.com/HKUDS/SepLLM.

![image](https://hackmd.io/_uploads/r1POJoR4yg.png)

## 2. Usage

### 2.1 Sample Base Model

We recommend using models from the **Llama 3 series**. Our example model is based on `meta-llama/Meta-Llama-3-8B-Instruct`, for which we have already prepared a targeted `monkey patch`.

For other models, using `SepCache` requires minor modifications to the corresponding `modeling_xxx.py` file or writing a **custom monkey patch**. These changes are **very simple** -- you only need to pass arguments like `input_ids` to the `update` function of `SepCache` when calling it.

We will provide a detailed guide later on how to modify your `modeling_xxx.py` file or `monkey patch` file to adapt `SepCache` to any model.

### 2.2 Quick Start

#### 2.2.1 Environment Setup
You need to install `transformers>=4.53.0,<4.54.0`, and we recommend using `lm_eval>=0.4.9` for running evaluations. We suggest managing your Python environment with `conda` for better dependency control.

```bash
conda create -n sepcache python=3.10
conda activate sepcache
pip install "transformers>=4.53.0,<4.54.0"
pip install lm_eval==0.4.9
```
#### 2.2.2 A Simple Example
You can use `SepCache` by specifying `custom_generate="transformers-community/sep_cache"` or `custom_generate="Gausson/sep_cache"` when calling the `generate` function. In our demo, we have already prepared a sample monkey patch for the `Llama 3 series` models and provided some common parameters for initializing `SepCache`.

```python
# requires `transformers>=4.53.0,<4.54.0`
from transformers import AutoModelForCausalLM, AutoTokenizer

# Preparing model, tokenizer, and model inputs
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto")


messages = [{"role": "user", "content": "Tell me a story about a cat."}]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)


# Using SepCache for generation
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    monkey_patch_verbose = True, # To see which functions are actually being monkey patched for `SepCache`.

    # Using SepCache
    custom_generate="transformers-community/sep_cache", ## Alternatively, you can use `Gausson/sep_cache`
    trust_remote_code=True,

    # SepCache arguments
    init_cache_size = 4,
    sep_cache_size = 128,
    local_size = 256, 
    cache_size = 512,    
    USE_MAX_SEP_CACHE = True,
    model_type = 'llama'
)

print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
assert "sepcache" in str(type(gen_out.past_key_values)).lower()
```

It is worth noting that you normally must specify the `separator_token_ids: List[int]` and `PADDING_ID: int` parameters when initializing `SepCache`. We did not do so in the example above because, for convenience, we set `model_type="llama"`, in which case `separator_token_ids` and `PADDING_ID` are filled in automatically.

However, when you use a tokenizer for a model outside the Llama 3 series, you need to specify `separator_token_ids` and `PADDING_ID` explicitly according to that tokenizer. The following example is based on the values obtained from a Llama 3 series tokenizer.
```python
# Using SepCache for generation
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    monkey_patch_verbose = True, # To see which functions are actually being monkey patched for `SepCache`.

    # Using SepCache
    custom_generate="transformers-community/sep_cache", ## Alternatively, you can use `Gausson/sep_cache`
    trust_remote_code=True,

    # SepCache arguments
    init_cache_size = 4,
    sep_cache_size = 128,
    local_size = 256, 
    cache_size = 512,    
    USE_MAX_SEP_CACHE = True,    
    separator_token_ids = [128000, 13, 11, 30, 0, 26, 25, 198, 220, 662, 1174, 949, 758, 2652, 551, 720, 256,262],
    PADDING_ID = 128009
)
```
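
If you are unsure which ids your tokenizer assigns, you can derive them from the tokenizer itself. The sketch below is illustrative only; the separator set is an assumption mirroring SepLLM's typical punctuation/newline choices, so adjust it for your model:

```python
# Derive candidate separator token ids from the tokenizer loaded above.
separators = [".", ",", "?", "!", ";", ":", "\n", " .", " ,", " ?", " !", " ;", " :"]

separator_token_ids = set()
for sep in separators:
    ids = tokenizer.encode(sep, add_special_tokens=False)
    if len(ids) == 1:  # keep only separators that map to a single token
        separator_token_ids.add(ids[0])

separator_token_ids = sorted(separator_token_ids)
PADDING_ID = tokenizer.eos_token_id  # e.g., the "<|endoftext|>"-style token id
print(separator_token_ids, PADDING_ID)
```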


#### 2.2.3 Frequently-Used Parameters

Below, we provide explanations and examples for the most commonly used parameters when initializing `SepCache`. These parameters can be passed through the `generate` function.

```
`SepCache` stores the Key and Value states as lists of tensors, two lists for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.

Frequently-Used Parameters:

    `init_cache_size: Union[int, List]`:
        The maximum number of KVs to be stored for initial tokens.
        In the paper, the hyperparameter `a` is an abbreviated alias for `init_cache_size`.                
            
    `sep_cache_size: Union[int, List]`:
        The maximum number of KVs to be stored for separator tokens.
        In the paper, the hyperparameter `s` is an abbreviated alias for `sep_cache_size`.

    `local_size: Union[int, List]`: 
        The maximum number of KVs to be stored for local tokens (i.e., sliding window).
        In the paper, the hyperparameter `w` is an abbreviated alias for `local_size`.

    `cache_size: Union[int, List]`:    
        The maximum number of KVs to be stored for all the tokens, i.e., the size for the whole KV cache.  
        In the paper, the hyperparameter `c` is an abbreviated alias for `cache_size`.

    Concerning these four parameters above:
        When a list is passed (its length must be `layer_num`), it represents different values for each layer. 
        When an integer is passed, it means the setting is the same for all layers.
    
    
    `USE_MAX_SEP_CACHE: bool`: 
        If True, at most `sep_cache_size` separators' KVs are kept.
        If this limit is exceeded, the KVs of older separator tokens are dropped, keeping only the most recent `sep_cache_size` separators' KVs.
      
    `separator_token_ids: List[int]`:
        The token ids of the separator tokens for the current model's tokenizer.            
        We have some examples, such as the Llama-3 series models, where setting `model_type='llama'` allows you 
            to skip setting `separator_token_ids` and `PADDING_ID` (SepCache will auto-fill them).

    `PADDING_ID: int`:
        The token id of the padding token. You can just set `PADDING_ID` to the id of "<|endoftext|>" token of the tokenizer for the pretrained model.  
```
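
For instance, a minimal sketch assuming a 32-layer model (e.g., `Meta-Llama-3-8B-Instruct`); either form is valid, and the list form sets a different budget per layer:

```python
num_layers = 32  # e.g., Meta-Llama-3-8B-Instruct has 32 decoder layers

# Same budgets for every layer:
uniform_cfg = dict(init_cache_size=4, sep_cache_size=128, local_size=256, cache_size=512)

# Per-layer budgets (each list must have length `layer_num`, i.e., 32 here):
per_layer_cfg = dict(
    init_cache_size=[4] * num_layers,
    sep_cache_size=[128] * num_layers,
    local_size=[256] * num_layers,
    cache_size=[512] * num_layers,
)
```
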
Important Note: 
- When `cache_size` and `local_size` are set to infinity (i.e., sufficiently large positive integers) and `USE_MAX_SEP_CACHE` is `False`, `SepCache` degenerates into a regular cache. 
- You must always ensure that `init_cache_size` + `sep_cache_size` + `local_size` + `left_padding_offset` < `cache_size`. Here, `left_padding_offset` denotes the number of padding tokens in the batch record with the most left-padding; it can only be determined at runtime.
- To guarantee this inequality always holds at runtime, leave a sufficient margin when setting the first three sizes, i.e., make `init_cache_size` + `sep_cache_size` + `local_size` < `cache_size` (`a` + `s` + `w` < `c` in the [SepLLM paper - ICML 2025](https://arxiv.org/abs/2412.12094)) with room to spare for `left_padding_offset`.

**More Important Note: In practice, there is no need to do positional encoding (PE) shifting as in [StreamingLLM](https://github.com/mit-han-lab/streaming-llm/) if the actual sequence length does not exceed the pretrained maximum PE length (which is the case for most downstream tasks). So, for most basic usage, just keep `APPLY_PE_SHIFT=False` (the default) and `APPLY_PES_INSIDE=False` at initialization.**


#### 2.2.4 Update Function
After initialization, another key point to note is that when using the `update` function of `SepCache` to update the **keys/values** and the **past token ids** (which `SepCache` requires), you must also provide the current `input_ids`.
```python
key_states, value_states = past_key_values.update(
    key_states=key_states,
    value_states=value_states,
    input_ids=input_ids,        # required by SepCache
    layer_idx=layer_idx,
    PREFILLING_FLAG=q_len > 1,  # `q_len` is the sequence length of the current `query_states`
)
```
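
The tensors returned by `update` hold the compressed KV cache that the attention layer should attend over. As a minimal, hedged sketch of the surrounding computation (variable names such as `query_states` follow the usual Transformers conventions and are assumptions here; see Section 2.2.5 for a real patch):

```python
import torch.nn.functional as F

# Attend over the compressed KV cache returned by `SepCache.update`.
# `query_states`: [batch_size, num_heads, q_len, head_dim]
attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,    # compressed keys from SepCache
    value_states,  # compressed values from SepCache
    is_causal=q_len > 1,  # causal masking during prefilling; none needed for single-token decoding
)
```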


#### 2.2.5 Monkey Patch Demo
To accommodate the `update` function of `SepCache` described in [`2.2.4 Update Function`](#224-update-function), the patched attention `forward` must pass the current `input_ids` to `update`. It is worth noting that during the prefilling stage, the shape of the `input_ids` tensor is `[batch_size, seq_len]`, while during the decoding stage of auto-regressive models, it is `[batch_size, 1]`.


In our `custom_generate/generate.py` file, we provide the `monkey_patching` function, which works by replacing the `forward` function in all the related instances of the `XXXAttention` class (for example, in the Llama 3 series model, it would be `LlamaAttention`) with our customized forward function (specified by the `model_atten_forward` parameter of the `monkey_patching` function).
```python
def monkey_patching(model_obj,
                    model_atten_forward,  ## The `forward` function used for patching.
                    possible_inner_model_names: List[str] = ["model", "transformer", "gpt_neox"],  ## In the `XXXForCausalLM` class, the possible attribute names for the inner model, e.g., "model", "transformer", "gpt_neox", etc.
                    possible_layers_names: List[str] = ["layers", "h"],  ## In the `XXXModel` class, the possible attribute names for the decoder layers, e.g., "layers", "h", etc.
                    atten_attr_name_pattern_list: List[str] = ["attention", "self_attn"],  ## In the `XXXDecoderLayer` class, the possible attribute name patterns for the self-attention module, e.g., "attention", "self_attn", etc.
                    atten_attr_name_pattern_exclude: List[str] = ["norm", "layer"],  ## In the `XXXDecoderLayer` class, the attribute name patterns to exclude. Some attributes such as "post_attention_norm" match "attention", but we only want to replace the `forward` function of `XXXAttention` itself, so patterns like "norm" must be excluded to target the correct `forward` function.
                    verbose=True):
    
    """
    This `monkey_patching` function is to
        - find the `forward` function of the `XXXAttention` class.
        - replace all the related `forward` functions of the instances of `XXXAttention` class with `model_atten_forward`.
    """
    
    ## To avoid the argument check failure, i.e., let "sepllm_kwargs" pass the check.
    transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs    

    ## Get inner model obj
    inner_model_type = PreTrainedModel
    inner_model = find_inner_attribute(model_obj, possible_inner_model_names, inner_model_type)
    
    ## Get the decoder layers (`nn.ModuleList`) obj
    layers_type = nn.ModuleList
    model_layers = find_inner_attribute(inner_model, possible_layers_names, layers_type)
    
    ## Replace all the related `forward` functions of XXXAttention class's instances.
    for i, decoder_layer in enumerate(model_layers):
        self_attn_module = find_attribute_name(decoder_layer, atten_attr_name_pattern_list, atten_attr_name_pattern_exclude, nn.Module)
        result = monkey_patch_by_class_path(self_attn_module, model_atten_forward)
        if verbose:
            decoder_class_name = get_importable_class_path(decoder_layer)
            print(f"For Layer {i}'s `{decoder_class_name}`: {result}")

    return model_layers
```

The `monkey_patching` function primarily does three things:
- Precisely locate the `forward` function of all instances of the `XXXAttention` class.
- Replace the `forward` function with the `model_atten_forward` function you provide.
- Return the decoder layers found during the process (typically an `nn.ModuleList`). This return value (`model_layers`) is later used only to determine the number of layers in the model, via `len(model_layers)`.

In addition, the `monkey_patching` function replaces `transformers.generation.GenerationMixin._validate_model_kwargs` with our `_validate_model_kwargs` to bypass some argument checks, since we pass an extra `sepllm_kwargs` parameter that wraps the `input_ids` for eventual delivery to `SepCache`'s `update` function.
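
For reference, helper functions such as `find_inner_attribute` live in the `custom_generate/generate.py` file. A minimal sketch of what such a helper plausibly does (an illustrative assumption, not the actual implementation):

```python
def find_inner_attribute_sketch(obj, candidate_names, expected_type):
    """Return the first attribute of `obj` whose name is in `candidate_names`
    and whose value is an instance of `expected_type`."""
    for name in candidate_names:
        attr = getattr(obj, name, None)
        if isinstance(attr, expected_type):
            return attr
    raise AttributeError(f"No attribute among {candidate_names} of type {expected_type.__name__}")
```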


**Please ensure that the `monkey_patching` function accurately locates and replaces the `forward` function of the `XXXAttention` class. The current `monkey_patching` is designed for the `Llama 3 series` models; for other models, you need to modify `monkey_patching` accordingly so that it targets and replaces the correct function!** You can monitor the monkey patching process by setting `verbose=True` in the `monkey_patching` function (or `monkey_patch_verbose=True` in the `generate` function).


```python
def truncate_input_ids_4_autoregression(input_ids, key_states):
    ## During decoding, `input_ids` may still hold the whole sequence while the
    ## incoming `key_states` cover only the newly generated token(s); keep only
    ## the trailing ids that match the keys' sequence length.
    if input_ids.shape[-1] != key_states.shape[-2]:
        assert input_ids.shape[-1] >= key_states.shape[-2]
        truncated_input_ids = input_ids[..., -key_states.shape[-2]:]
        return truncated_input_ids
    else:
        return input_ids
```
The `truncate_input_ids_4_autoregression` function in the `custom_generate/generate.py` file is used to shape the `input_ids` tensor to `[batch_size, 1]` during decoding.
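
A quick, illustrative shape check of that behavior:

```python
import torch

input_ids = torch.randint(0, 1000, (2, 10))  # [batch_size=2, seq_len=10]
key_states = torch.randn(2, 8, 1, 64)        # [batch, heads, seq_len=1, head_dim]

truncated = truncate_input_ids_4_autoregression(input_ids, key_states)
print(truncated.shape)  # torch.Size([2, 1]) -- only the latest token id is kept
```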

#### 2.2.6 Downstream Task Evaluation
We recommend using `lm_eval==0.4.9` for downstream task evaluation. You can pass model-related parameters via `--model_args` and generation-related parameters (including those required for initializing `SepCache`) via `--gen_kwargs`. Notably, you typically need to pass the `separator_token_ids` list as a string formatted like `"id1;id2;id3"` (as shown in the example below).
```bash
lm_eval --model hf \
    --model_args pretrained=meta-llama/Meta-Llama-3-8B-Instruct,attn_implementation=flash_attention_2 \
    --tasks gsm8k_cot \
    --gen_kwargs custom_generate=transformers-community/sep_cache,trust_remote_code=True,monkey_patch_verbose=True,init_cache_size=4,sep_cache_size=128,local_size=256,cache_size=512,separator_token_ids="128000;13;11;30;0;26;25;198;220;662;1174;949;758;2652;551;720;256;262",PADDING_ID=128009 \
    --device cuda:0 \
    --batch_size 80 2>&1 | tee log.txt
```
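
For reference, the semicolon-separated string above encodes the same list passed in [2.2.2](#222-a-simple-example); the two forms are interchangeable:

```python
id_string = "128000;13;11;30;0;26;25;198;220;662;1174;949;758;2652;551;720;256;262"
separator_token_ids = [int(tok) for tok in id_string.split(";")]
print(separator_token_ids[:4])  # [128000, 13, 11, 30]
```
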
Note: `SepCache` is typically used in combination with `Flash Attention` to maximize generation efficiency.

<img width="1022" height="248" alt="1752618213617" src="https://github.com/user-attachments/assets/87e2e745-9677-4101-895e-dd6fc7b6039d" />

#### 2.2.7 The Detailed Signature of `generate` Function
Here is the detailed signature of our customized `generate` function for `SepCache`, located in the `custom_generate/generate.py` file:

```python
def generate(model,                          
            ## For SepCache                              
            init_cache_size: Union[int, List] = 4,        
            sep_cache_size: Union[int, List] = 128,
            local_size: Union[int, List]=256, 
            cache_size: Union[int, List]=512,    
            SEP_ACCUMULATION: bool = True,
            USE_MAX_SEP_CACHE: bool = False,
            SEP_PADDING_IN_BATCH: bool = False,
            separator_token_ids: List[int] = None, ## required for initialization if `model_type` is not provided.
            PADDING_ID: int = None, ## required for initialization if `model_type` is not provided.

            ## For inheritance & initialization states
            past_tok_ids: List[torch.Tensor] = None,  ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache.                
            key_cache: List[torch.Tensor] = None,          
            value_cache: List[torch.Tensor] = None,

            ## For debugging
            PRINT_KV_RATIO_INSIDE: bool = False,
            print_KV_inside_per_steps: int = 1000,   
            _seen_tokens: int = 0, 
            _kept_kv_ratio: List[Tuple[int]] = None,
            
            ### For positional encoding shifting
            APPLY_PE_SHIFT: bool = False,
            APPLY_PES_INSIDE: bool = False,
            _shifted_position_ids:  List[torch.Tensor] = None,
            _rope_unsqueeze_dim: int = 1, ## The unsqueeze_dim when applying RoPE.
            _rope_seq_dim: int=1, ## The seq_len dimension for the `cos` or `sin` tensors.
            pe_scaling_factor:float = 1.0,
            pe_dim:int=128, ## The number of dims for positional encoding. Typically, just set the `head_dim` to this.
            max_position_embeddings: int = 8192, 
            base: int=10000,  ## The base for RoPE.               
            
            ## For basic transformer architecture
            k_seq_dim: int=2, ## The dimension for seq_len in key tensors
            v_seq_dim: int=2, ## The dimension for seq_len in value tensors
            layer_num: int = None, ## required for initialization

            model_type: str = 'llama',  ## The model type for running the example. choose from ['llama', 'pythia','falcon'].
            device = None,   

            ## For verbosity of monkey patching
            monkey_patch_verbose: bool = False,

            **kwargs
             ):
             ...
```

## 3. Adaptation for Other Models

Adapting `SepCache` to other models is simple. There are two approaches:


### 3.1 Method 1 - Monkey Patching
- Modify the `monkey_patching` function to correctly locate and target the `forward` function of your model's `XXXAttention` class (e.g., `LlamaAttention` for Llama 3).
- Write your custom `model_atten_forward` function and use `monkey_patching` to replace the `forward` function of all `XXXAttention` instances with it. The key modification is passing `input_ids` to `SepCache`'s `update` function, as sketched below.
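
To make that key modification concrete, below is a hypothetical skeleton of just the SepCache-specific part of such a forward function. The `sepllm_kwargs` dict and its `"input_ids"` key are assumptions based on Section 2.2.5; projections, RoPE, and the attention computation itself are elided:

```python
from typing import Tuple
import torch

def update_sep_cache_sketch(
    past_key_value,              # the SepCache instance
    key_states: torch.Tensor,    # [batch_size, num_heads, q_len, head_dim]
    value_states: torch.Tensor,  # [batch_size, num_heads, q_len, head_dim]
    layer_idx: int,
    sepllm_kwargs: dict,         # assumed wrapper carrying the current `input_ids`
) -> Tuple[torch.Tensor, torch.Tensor]:
    q_len = key_states.shape[-2]
    return past_key_value.update(
        key_states=key_states,
        value_states=value_states,
        input_ids=sepllm_kwargs["input_ids"],  # the extra argument SepCache requires
        layer_idx=layer_idx,
        PREFILLING_FLAG=q_len > 1,
    )
```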

### 3.2 Method 2 - Direct Code Modification (Recommended for Simplicity)
Simply edit your `modeling_xxx.py` file to implement:

- Initialize `past_key_values` as a `SepCache` instance at the appropriate location (e.g., in the `forward` function of the `XXXForCausalLM` or `XXXModel` class), as sketched after this list.
- Modify the `forward` function of the `XXXAttention` class to pass `input_ids` to `SepCache`'s `update` function.
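
A hedged sketch of the first step; the constructor arguments mirror the `generate` signature in [2.2.7](#227-the-detailed-signature-of-generate-function), but the exact `SepCache` import path and constructor may differ, so treat this as an assumption:

```python
# Inside, e.g., XXXModel.forward, before the decoder layers run
# (assumes `SepCache` has been imported from this repo's custom_generate code):
if past_key_values is None or not isinstance(past_key_values, SepCache):
    past_key_values = SepCache(
        init_cache_size=4,
        sep_cache_size=128,
        local_size=256,
        cache_size=512,
        USE_MAX_SEP_CACHE=True,
        model_type="llama",          # auto-fills separator_token_ids and PADDING_ID for Llama 3
        layer_num=len(self.layers),  # required for initialization
    )
```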

### 3.3 Important Note
The shape of `input_ids` is `[batch_size, seq_len]` during prefilling and `[batch_size, 1]` during generation.

## 4. Other Advanced Usage

Please refer to https://github.com/HKUDS/SepLLM for detailed explanations and further examples.

## 5. Citation
If you find our work helpful, please consider giving us a like ❤️ and citing our paper. We greatly appreciate your support 😄
```
@inproceedings{chen2025sepllm,
  title={{SepLLM: Accelerate Large Language Models by Compressing One Segment into One Separator}},
  author={Chen, Guoxuan and Shi, Han and Li, Jiawei and Gao, Yihang and Ren, Xiaozhe and Chen, Yimeng and Jiang, Xin and Li, Zhenguo and Liu, Weiyang and Huang, Chao},
  booktitle={International Conference on Machine Learning},
  year={2025},
  note={Also available at arXiv:2412.12094}
}
```