roychao19477 committed
Commit 7001051 · Parent: 5b3ab69

Initial clean commit

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. app.py +70 -3
  2. ckpts/SEMamba_advanced.pth +3 -0
  3. mamba_install/.DS_Store +0 -0
  4. mamba_install/AUTHORS +2 -0
  5. mamba_install/LICENSE +201 -0
  6. mamba_install/README.md +182 -0
  7. mamba_install/benchmarks/benchmark_generation_mamba_simple.py +92 -0
  8. mamba_install/csrc/selective_scan/reverse_scan.cuh +401 -0
  9. mamba_install/csrc/selective_scan/selective_scan.cpp +497 -0
  10. mamba_install/csrc/selective_scan/selective_scan.h +101 -0
  11. mamba_install/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu +9 -0
  12. mamba_install/csrc/selective_scan/selective_scan_bwd_bf16_real.cu +9 -0
  13. mamba_install/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu +9 -0
  14. mamba_install/csrc/selective_scan/selective_scan_bwd_fp16_real.cu +9 -0
  15. mamba_install/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu +9 -0
  16. mamba_install/csrc/selective_scan/selective_scan_bwd_fp32_real.cu +9 -0
  17. mamba_install/csrc/selective_scan/selective_scan_bwd_kernel.cuh +531 -0
  18. mamba_install/csrc/selective_scan/selective_scan_common.h +221 -0
  19. mamba_install/csrc/selective_scan/selective_scan_fwd_bf16.cu +10 -0
  20. mamba_install/csrc/selective_scan/selective_scan_fwd_fp16.cu +10 -0
  21. mamba_install/csrc/selective_scan/selective_scan_fwd_fp32.cu +10 -0
  22. mamba_install/csrc/selective_scan/selective_scan_fwd_kernel.cuh +345 -0
  23. mamba_install/csrc/selective_scan/static_switch.h +25 -0
  24. mamba_install/csrc/selective_scan/uninitialized_copy.cuh +69 -0
  25. mamba_install/evals/lm_harness_eval.py +39 -0
  26. mamba_install/mamba_ssm/.DS_Store +0 -0
  27. mamba_install/mamba_ssm/__init__.py +5 -0
  28. mamba_install/mamba_ssm/models/__init__.py +0 -0
  29. mamba_install/mamba_ssm/models/config_mamba.py +15 -0
  30. mamba_install/mamba_ssm/models/mixer_seq_simple.py +264 -0
  31. mamba_install/mamba_ssm/modules/__init__.py +0 -0
  32. mamba_install/mamba_ssm/modules/mamba_simple.py +353 -0
  33. mamba_install/mamba_ssm/ops/__init__.py +0 -0
  34. mamba_install/mamba_ssm/ops/selective_scan_interface.py +357 -0
  35. mamba_install/mamba_ssm/ops/triton/__init__.py +0 -0
  36. mamba_install/mamba_ssm/ops/triton/layernorm.py +635 -0
  37. mamba_install/mamba_ssm/ops/triton/selective_state_update.py +263 -0
  38. mamba_install/mamba_ssm/utils/__init__.py +0 -0
  39. mamba_install/mamba_ssm/utils/generation.py +387 -0
  40. mamba_install/mamba_ssm/utils/hf.py +23 -0
  41. mamba_install/setup.py +284 -0
  42. mamba_install/tests/ops/test_selective_scan.py +247 -0
  43. mamba_install/tests/ops/triton/test_selective_state_update.py +49 -0
  44. mamba_ssm/.DS_Store +0 -0
  45. mamba_ssm/__init__.py +5 -0
  46. mamba_ssm/models/__init__.py +0 -0
  47. mamba_ssm/models/config_mamba.py +15 -0
  48. mamba_ssm/models/mixer_seq_simple.py +264 -0
  49. mamba_ssm/modules/__init__.py +0 -0
  50. mamba_ssm/modules/mamba_simple.py +353 -0
app.py CHANGED
@@ -1,7 +1,74 @@
import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+ import torch
+ import yaml
+ import librosa
+ from huggingface_hub import hf_hub_download
+ from models.stfts import mag_phase_stft, mag_phase_istft
+ from models.generator import SEMamba
+ from models.pcs400 import cal_pcs
+
+ # download model files from your HF repo
+ ckpt = hf_hub_download("rc19477/Speech_Enhancement_Mamba",
+                        "ckpts/SEMamba_advanced.pth")
+ cfg_f = hf_hub_download("rc19477/Speech_Enhancement_Mamba",
+                         "recipes/SEMamba_advanced.yaml")
+
+ # load config
+ with open(cfg_f) as f:
+     cfg = yaml.safe_load(f)
+
+ stft_cfg = cfg["stft_cfg"]
+ model_cfg = cfg["model_cfg"]
+ sr = stft_cfg["sampling_rate"]
+ n_fft = stft_cfg["n_fft"]
+ hop_size = stft_cfg["hop_size"]
+ win_size = stft_cfg["win_size"]
+ compress_ff = model_cfg["compress_factor"]
+
+ # init model
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = SEMamba(cfg).to(device)
+ sdict = torch.load(ckpt, map_location=device)
+ model.load_state_dict(sdict["generator"])
+ model.eval()
+
+ def enhance(audio, do_pcs):
+     orig_sr, wav_np = audio
+     # 1) resample to 16 kHz if needed
+     if orig_sr != sr:
+         wav_np = librosa.resample(wav_np, orig_sr=orig_sr, target_sr=sr)
+     wav = torch.from_numpy(wav_np).float().to(device)
+
+     # normalize to unit RMS (undone after enhancement)
+     norm = torch.sqrt(len(wav) / torch.sum(wav**2))
+     wav = (wav * norm).unsqueeze(0)
+
+     # STFT → model → ISTFT
+     amp, pha, _ = mag_phase_stft(wav, n_fft, hop_size, win_size, compress_ff)
+     amp_g, pha_g = model(amp, pha)
+     out = mag_phase_istft(amp_g, pha_g, n_fft, hop_size, win_size, compress_ff)
+     out = (out / norm).squeeze().cpu().numpy()
+
+     # optional PCS filter
+     if do_pcs:
+         out = cal_pcs(out)
+
+     # 2) resample back to the original rate
+     if orig_sr != sr:
+         out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)
+
+     return orig_sr, out
+
+ demo = gr.Interface(
+     fn=enhance,
+     inputs=[
+         gr.Audio(source="upload", type="numpy", label="Noisy wav"),
+         gr.Checkbox(label="Apply PCS post-processing", value=False),
+     ],
+     outputs=gr.Audio(type="numpy", label="Enhanced wav"),
+     title="SEMamba Speech Enhancement",
+     description="Upload a noisy WAV; tick **Apply PCS** for the pcs400 filter.",
+ )
+
+
demo.launch()
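For reference, a minimal offline usage sketch (not part of the commit): it reuses the `enhance()` helper defined in the new `app.py` on a local file instead of going through the Gradio widget. The `soundfile` dependency and the `noisy.wav` path are illustrative assumptions, and a mono input is assumed.

```
import soundfile as sf  # assumed extra dependency, not pinned by the commit

wav_np, file_sr = sf.read("noisy.wav", dtype="float32")   # hypothetical mono input file
out_sr, enhanced = enhance((file_sr, wav_np), do_pcs=False)
sf.write("enhanced.wav", enhanced, out_sr)
```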
ckpts/SEMamba_advanced.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f68a1aaa2b5cdf6a4f8ef87e1534edd83c523135ba0ecaddeadce6f35c8c4142
+ size 9127253
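Since the checkpoint is stored via Git LFS, the pointer above records only the blob's SHA-256 and size. A small sketch (added here for illustration, not in the repo) that checks a downloaded file against the `oid` field:

```
import hashlib

def sha256_of(path):
    # Hash the file in 1 MiB chunks to avoid loading it all at once.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

# `ckpt` is the path returned by hf_hub_download(...) in app.py above;
# expected oid: f68a1aaa2b5cdf6a4f8ef87e1534edd83c523135ba0ecaddeadce6f35c8c4142
```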
mamba_install/.DS_Store ADDED
Binary file (8.2 kB).
 
mamba_install/AUTHORS ADDED
@@ -0,0 +1,2 @@
+ Tri Dao, tri@tridao.me
+ Albert Gu, agu@andrew.cmu.edu
mamba_install/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2023 Tri Dao, Albert Gu
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
mamba_install/README.md ADDED
@@ -0,0 +1,182 @@
1
+ # This repository serves as a backup, cloned from the official Mamba Repository:
2
+ https://github.com/state-spaces/mamba/tree/a07faffa36a7b89e754b5de972418475bcdd77b6
3
+
4
+ ===
5
+ # Mamba
6
+
7
+ ![Mamba](assets/selection.png "Selective State Space")
8
+ > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
9
+ > Albert Gu*, Tri Dao*\
10
+ > Paper: https://arxiv.org/abs/2312.00752
11
+
12
+ ## About
13
+
14
+ Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
15
+ It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
16
+ with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
17
+
18
+ ## Installation
19
+
20
+ - [Option] `pip install causal-conv1d>=1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
21
+ - `pip install mamba-ssm`: the core Mamba package.
22
+
23
+ It can also be built from source with `pip install .` from this repository.
24
+
25
+ If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`.
26
+
27
+ Other requirements:
28
+ - Linux
29
+ - NVIDIA GPU
30
+ - PyTorch 1.12+
31
+ - CUDA 11.6+
32
+
33
+ ## Usage
34
+
35
+ We expose several levels of interface with the Mamba model.
36
+
37
+ ### Selective SSM
38
+
39
+ Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).
40
+
41
+ Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).
42
+
43
+ ### Mamba Block
44
+
45
+ The main module of this repository is the Mamba architecture block wrapping the selective SSM.
46
+
47
+ Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).
48
+
49
+ Usage:
50
+ ```
51
+ import torch
52
+ from mamba_ssm import Mamba
53
+
54
+ batch, length, dim = 2, 64, 16
55
+ x = torch.randn(batch, length, dim).to("cuda")
56
+ model = Mamba(
57
+ # This module uses roughly 3 * expand * d_model^2 parameters
58
+ d_model=dim, # Model dimension d_model
59
+ d_state=16, # SSM state expansion factor
60
+ d_conv=4, # Local convolution width
61
+ expand=2, # Block expansion factor
62
+ ).to("cuda")
63
+ y = model(x)
64
+ assert y.shape == x.shape
65
+ ```
66
+
67
+ ### Mamba Language Model
68
+
69
+ Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
70
+
71
+ Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).
72
+
73
+ This is an example of how to integrate Mamba into an end-to-end neural network.
74
+ This example is used in the generation scripts below.
75
+
76
+
77
+
78
+ ## Pretrained Models
79
+
80
+ Pretrained models are uploaded to
81
+ [Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
82
+ `mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj`
83
+ (trained on 600B tokens on the SlimPajama dataset).
84
+
85
+
86
+ The models will be autodownloaded by the generation script below.
87
+
88
+ These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models:
89
+
90
+ | Parameters | Layers | Model dim. |
91
+ |------------|--------|------------|
92
+ | 130M | 24 | 768 |
93
+ | 370M | 48 | 1024 |
94
+ | 790M | 48 | 1536 |
95
+ | 1.4B | 48 | 2048 |
96
+ | 2.8B | 64 | 2560 |
97
+
98
+ (The layer count of Mamba doubles that of a Transformer with similar size, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
99
+
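As a rough back-of-the-envelope on the parameter-count comment in the usage snippet above (`~3 * expand * d_model^2` per block, an estimate): for the 130M model with `d_model = 768`, `expand = 2`, and 24 layers, that gives about 3 · 2 · 768² ≈ 3.5M parameters per block and ≈ 85M across the blocks; the token embedding (tied with the LM head) and the remaining SSM, convolution, and norm parameters account for most of the rest.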
100
+ Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.).
101
+ Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models.
102
+
103
+
104
+ ## Evaluations
105
+
106
+ To run zero-shot evaluations of models (corresponding to Table 3 of the paper),
107
+ we use the
108
+ [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor)
109
+ library.
110
+
111
+ 1. Pull the `lm-evaluation-harness` repo by `git submodule update --init
112
+ --recursive`. We use the `big-refactor` branch.
113
+ 2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness`.
114
+ On Python 3.10 you might need to manually install the latest version of `promptsource`: `pip install git+https://github.com/bigscience-workshop/promptsource.git`.
115
+ 3. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
116
+ ```
117
+ python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
118
+ python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
119
+ ```
120
+
121
+ To reproduce the results on the `mamba-2.8b-slimpj` model reported in the blogposts:
122
+ ```
123
+ python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 64
124
+ python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 64
125
+ ```
126
+
127
+ Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.
128
+
129
+ ## Inference
130
+
131
+ The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
132
+ 1. autoloads a model from the Hugging Face Hub,
133
+ 2. generates completions of a user-specified prompt,
134
+ 3. benchmarks the inference speed of this generation.
135
+
136
+ Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature.
137
+
138
+ ### Examples
139
+
140
+ To test generation latency (e.g. batch size = 1) with different sampling strategies:
141
+
142
+ ```
143
+ python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
144
+ python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
145
+ python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
146
+ ```
147
+
148
+ To test generation throughput with random prompts (e.g. large batch size):
149
+ ```
150
+ python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128
151
+ python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128
152
+ ```
153
+
154
+
155
+ ## Troubleshooting
156
+
157
+ ### Precision
158
+ Our models were trained using PyTorch [AMP](https://pytorch.org/docs/stable/amp.html) for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary.
159
+ On the other hand, other frameworks like DeepSpeed store parameters in float16 and upcasts when necessary (e.g. for optimizer accumulation).
160
+
161
+ We've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you are experiencing instabilities,
162
+ as a first step please try a framework storing parameters in fp32 (such as AMP).
163
+
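A minimal sketch of that recommendation (an illustration added here, not from the repository): keep the master weights in fp32 and let autocast handle the casts during the forward pass. The tiny linear model and synthetic data below are stand-ins.

```
import torch
import torch.nn as nn

model = nn.Linear(16, 16).cuda()                      # stand-in model; parameters stay in fp32
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):                                    # dummy training steps
    x = torch.randn(8, 16, device="cuda")
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(x).pow(2).mean()                 # compute runs in fp16, weights remain fp32
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```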
164
+ ### Initialization
165
+ Some parts of the model have initializations inherited from prior work on S4 models.
166
+ For [example](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L102), the $\Delta$ parameter has a targeted range by initializing the bias of its linear projection.
167
+ However, some frameworks may have post-initialization hooks (e.g. setting all bias terms in `nn.Linear` modules to zero).
168
+ If this is the case, you may have to add custom logic (e.g. this [line](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L104) turns off re-initializing in our trainer, but would be a no-op in any other framework)
169
+ that is specific to the training framework.
170
+
171
+
172
+ ## Citation
173
+
174
+ If you use this codebase, or otherwise found our work valuable, please cite Mamba:
175
+ ```
176
+ @article{mamba,
177
+ title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
178
+ author={Gu, Albert and Dao, Tri},
179
+ journal={arXiv preprint arXiv:2312.00752},
180
+ year={2023}
181
+ }
182
+ ```
mamba_install/benchmarks/benchmark_generation_mamba_simple.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import argparse
4
+ import time
5
+ import json
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange
11
+
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM
13
+
14
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
15
+
16
+
17
+ parser = argparse.ArgumentParser(description="Generation benchmarking")
18
+ parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
19
+ parser.add_argument("--prompt", type=str, default=None)
20
+ parser.add_argument("--promptlen", type=int, default=100)
21
+ parser.add_argument("--genlen", type=int, default=100)
22
+ parser.add_argument("--temperature", type=float, default=1.0)
23
+ parser.add_argument("--topk", type=int, default=1)
24
+ parser.add_argument("--topp", type=float, default=1.0)
25
+ parser.add_argument("--minp", type=float, default=0.0)
26
+ parser.add_argument("--repetition-penalty", type=float, default=1.0)
27
+ parser.add_argument("--batch", type=int, default=1)
28
+ args = parser.parse_args()
29
+
30
+ repeats = 3
31
+ device = "cuda"
32
+ dtype = torch.float16
33
+
34
+ print(f"Loading model {args.model_name}")
35
+ is_mamba = args.model_name.startswith("state-spaces/mamba-")
36
+ if is_mamba:
37
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
38
+ model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
39
+ else:
40
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
41
+ model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype)
42
+ model.eval()
43
+ print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
44
+
45
+ torch.random.manual_seed(0)
46
+ if args.prompt is None:
47
+ input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
48
+ attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
49
+ else:
50
+ tokens = tokenizer(args.prompt, return_tensors="pt")
51
+ input_ids = tokens.input_ids.to(device=device)
52
+ attn_mask = tokens.attention_mask.to(device=device)
53
+ max_length = input_ids.shape[1] + args.genlen
54
+
55
+ if is_mamba:
56
+ fn = lambda: model.generate(
57
+ input_ids=input_ids,
58
+ max_length=max_length,
59
+ cg=True,
60
+ return_dict_in_generate=True,
61
+ output_scores=True,
62
+ enable_timing=False,
63
+ temperature=args.temperature,
64
+ top_k=args.topk,
65
+ top_p=args.topp,
66
+ min_p=args.minp,
67
+ repetition_penalty=args.repetition_penalty,
68
+ )
69
+ else:
70
+ fn = lambda: model.generate(
71
+ input_ids=input_ids,
72
+ attention_mask=attn_mask,
73
+ max_length=max_length,
74
+ return_dict_in_generate=True,
75
+ pad_token_id=tokenizer.eos_token_id,
76
+ do_sample=True,
77
+ temperature=args.temperature,
78
+ top_k=args.topk,
79
+ top_p=args.topp,
80
+ repetition_penalty=args.repetition_penalty,
81
+ )
82
+ out = fn()
83
+ if args.prompt is not None:
84
+ print(tokenizer.batch_decode(out.sequences.tolist()))
85
+
86
+ torch.cuda.synchronize()
87
+ start = time.time()
88
+ for _ in range(repeats):
89
+ fn()
90
+ torch.cuda.synchronize()
91
+ print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
92
+ print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")
mamba_install/csrc/selective_scan/reverse_scan.cuh ADDED
@@ -0,0 +1,401 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <cub/config.cuh>
8
+
9
+ #include <cub/util_ptx.cuh>
10
+ #include <cub/util_type.cuh>
11
+ #include <cub/block/block_raking_layout.cuh>
12
+ // #include <cub/detail/uninitialized_copy.cuh>
13
+ #include "uninitialized_copy.cuh"
14
+
15
+ /**
16
+ * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned.
17
+ */
18
+ template <
19
+ int LENGTH,
20
+ typename T,
21
+ typename ReductionOp>
22
+ __device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
23
+ static_assert(LENGTH > 0);
24
+ T retval = input[LENGTH - 1];
25
+ #pragma unroll
26
+ for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }
27
+ return retval;
28
+ }
29
+
30
+ /**
31
+ * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
32
+ */
33
+ template <
34
+ int LENGTH,
35
+ typename T,
36
+ typename ScanOp>
37
+ __device__ __forceinline__ T ThreadReverseScanInclusive(
38
+ const T (&input)[LENGTH],
39
+ T (&output)[LENGTH],
40
+ ScanOp scan_op,
41
+ const T postfix)
42
+ {
43
+ T inclusive = postfix;
44
+ #pragma unroll
45
+ for (int i = LENGTH - 1; i >= 0; --i) {
46
+ inclusive = scan_op(inclusive, input[i]);
47
+ output[i] = inclusive;
48
+ }
49
+ }
50
+
51
+ /**
52
+ * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
53
+ */
54
+ template <
55
+ int LENGTH,
56
+ typename T,
57
+ typename ScanOp>
58
+ __device__ __forceinline__ T ThreadReverseScanExclusive(
59
+ const T (&input)[LENGTH],
60
+ T (&output)[LENGTH],
61
+ ScanOp scan_op,
62
+ const T postfix)
63
+ {
64
+ // Careful, output maybe be aliased to input
65
+ T exclusive = postfix;
66
+ T inclusive;
67
+ #pragma unroll
68
+ for (int i = LENGTH - 1; i >= 0; --i) {
69
+ inclusive = scan_op(exclusive, input[i]);
70
+ output[i] = exclusive;
71
+ exclusive = inclusive;
72
+ }
73
+ return inclusive;
74
+ }
75
+
76
+
77
+ /**
78
+ * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
79
+ *
80
+ * LOGICAL_WARP_THREADS must be a power-of-two
81
+ */
82
+ template <
83
+ typename T, ///< Data type being scanned
84
+ int LOGICAL_WARP_THREADS ///< Number of threads per logical warp
85
+ >
86
+ struct WarpReverseScan {
87
+ //---------------------------------------------------------------------
88
+ // Constants and type definitions
89
+ //---------------------------------------------------------------------
90
+
91
+ /// Whether the logical warp size and the PTX warp size coincide
92
+ static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0));
93
+ /// The number of warp scan steps
94
+ static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
95
+ static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);
96
+
97
+
98
+ //---------------------------------------------------------------------
99
+ // Thread fields
100
+ //---------------------------------------------------------------------
101
+
102
+ /// Lane index in logical warp
103
+ unsigned int lane_id;
104
+
105
+ /// Logical warp index in 32-thread physical warp
106
+ unsigned int warp_id;
107
+
108
+ /// 32-thread physical warp member mask of logical warp
109
+ unsigned int member_mask;
110
+
111
+ //---------------------------------------------------------------------
112
+ // Construction
113
+ //---------------------------------------------------------------------
114
+
115
+ /// Constructor
116
+ explicit __device__ __forceinline__
117
+ WarpReverseScan()
118
+ : lane_id(cub::LaneId())
119
+ , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
120
+ , member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
121
+ {
122
+ if (!IS_ARCH_WARP) {
123
+ lane_id = lane_id % LOGICAL_WARP_THREADS;
124
+ }
125
+ }
126
+
127
+
128
+ /// Broadcast
129
+ __device__ __forceinline__ T Broadcast(
130
+ T input, ///< [in] The value to broadcast
131
+ int src_lane) ///< [in] Which warp lane is to do the broadcasting
132
+ {
133
+ return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
134
+ }
135
+
136
+
137
+ /// Inclusive scan
138
+ template <typename ScanOpT>
139
+ __device__ __forceinline__ void InclusiveReverseScan(
140
+ T input, ///< [in] Calling thread's input item.
141
+ T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
142
+ ScanOpT scan_op) ///< [in] Binary scan operator
143
+ {
144
+ inclusive_output = input;
145
+ #pragma unroll
146
+ for (int STEP = 0; STEP < STEPS; STEP++) {
147
+ int offset = 1 << STEP;
148
+ T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
149
+ inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
150
+ );
151
+ // Perform scan op if from a valid peer
152
+ inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
153
+ ? inclusive_output : scan_op(temp, inclusive_output);
154
+ }
155
+ }
156
+
157
+ /// Exclusive scan
158
+ // Get exclusive from inclusive
159
+ template <typename ScanOpT>
160
+ __device__ __forceinline__ void ExclusiveReverseScan(
161
+ T input, ///< [in] Calling thread's input item.
162
+ T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
163
+ ScanOpT scan_op, ///< [in] Binary scan operator
164
+ T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
165
+ {
166
+ T inclusive_output;
167
+ InclusiveReverseScan(input, inclusive_output, scan_op);
168
+ warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);
169
+ // initial value unknown
170
+ exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
171
+ inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
172
+ );
173
+ }
174
+
175
+ /**
176
+ * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last <em>warp-lane</em> is undefined.
177
+ */
178
+ template <typename ScanOpT>
179
+ __device__ __forceinline__ void ReverseScan(
180
+ T input, ///< [in] Calling thread's input item.
181
+ T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item.
182
+ T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item.
183
+ ScanOpT scan_op) ///< [in] Binary scan operator
184
+ {
185
+ InclusiveReverseScan(input, inclusive_output, scan_op);
186
+ // initial value unknown
187
+ exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
188
+ inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
189
+ );
190
+ }
191
+
192
+ };
193
+
194
+ /**
195
+ * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
196
+ */
197
+ template <
198
+ typename T, ///< Data type being scanned
199
+ int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension
200
+ bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
201
+ >
202
+ struct BlockReverseScan {
203
+ //---------------------------------------------------------------------
204
+ // Types and constants
205
+ //---------------------------------------------------------------------
206
+
207
+ /// Constants
208
+ /// The thread block size in threads
209
+ static constexpr int BLOCK_THREADS = BLOCK_DIM_X;
210
+
211
+ /// Layout type for padded thread block raking grid
212
+ using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
213
+ // The number of reduction elements is not a multiple of the number of raking threads for now
214
+ static_assert(BlockRakingLayout::UNGUARDED);
215
+
216
+ /// Number of raking threads
217
+ static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
218
+ /// Number of raking elements per warp synchronous raking thread
219
+ static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
220
+ /// Cooperative work can be entirely warp synchronous
221
+ static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));
222
+
223
+ /// WarpReverseScan utility type
224
+ using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;
225
+
226
+ /// Shared memory storage layout type
227
+ struct _TempStorage {
228
+ typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid
229
+ };
230
+
231
+
232
+ /// Alias wrapper allowing storage to be unioned
233
+ struct TempStorage : cub::Uninitialized<_TempStorage> {};
234
+
235
+
236
+ //---------------------------------------------------------------------
237
+ // Per-thread fields
238
+ //---------------------------------------------------------------------
239
+
240
+ // Thread fields
241
+ _TempStorage &temp_storage;
242
+ unsigned int linear_tid;
243
+ T cached_segment[SEGMENT_LENGTH];
244
+
245
+
246
+ //---------------------------------------------------------------------
247
+ // Utility methods
248
+ //---------------------------------------------------------------------
249
+
250
+ /// Performs upsweep raking reduction, returning the aggregate
251
+ template <typename ScanOp>
252
+ __device__ __forceinline__ T Upsweep(ScanOp scan_op) {
253
+ T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
254
+ // Read data into registers
255
+ #pragma unroll
256
+ for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
257
+ T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
258
+ #pragma unroll
259
+ for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
260
+ raking_partial = scan_op(raking_partial, cached_segment[i]);
261
+ }
262
+ return raking_partial;
263
+ }
264
+
265
+
266
+ /// Performs exclusive downsweep raking scan
267
+ template <typename ScanOp>
268
+ __device__ __forceinline__ void ExclusiveDownsweep(
269
+ ScanOp scan_op,
270
+ T raking_partial)
271
+ {
272
+ T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
273
+ // Read data back into registers
274
+ if (!MEMOIZE) {
275
+ #pragma unroll
276
+ for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
277
+ }
278
+ ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
279
+ // Write data back to smem
280
+ #pragma unroll
281
+ for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
282
+ }
283
+
284
+
285
+ //---------------------------------------------------------------------
286
+ // Constructors
287
+ //---------------------------------------------------------------------
288
+
289
+ /// Constructor
290
+ __device__ __forceinline__ BlockReverseScan(
291
+ TempStorage &temp_storage)
292
+ :
293
+ temp_storage(temp_storage.Alias()),
294
+ linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
295
+ {}
296
+
297
+
298
+ /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
299
+ template <
300
+ typename ScanOp,
301
+ typename BlockPostfixCallbackOp>
302
+ __device__ __forceinline__ void ExclusiveReverseScan(
303
+ T input, ///< [in] Calling thread's input item
304
+ T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input)
305
+ ScanOp scan_op, ///< [in] Binary scan operator
306
+ BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
307
+ {
308
+ if (WARP_SYNCHRONOUS) {
309
+ // Short-circuit directly to warp-synchronous scan
310
+ T block_aggregate;
311
+ WarpReverseScan warp_scan;
312
+ warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
313
+ // Obtain warp-wide postfix in lane0, then broadcast to other lanes
314
+ T block_postfix = block_postfix_callback_op(block_aggregate);
315
+ block_postfix = warp_scan.Broadcast(block_postfix, 0);
316
+ exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output);
317
+ } else {
318
+ // Place thread partial into shared memory raking grid
319
+ T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);
320
+ detail::uninitialized_copy(placement_ptr, input);
321
+ cub::CTA_SYNC();
322
+ // Reduce parallelism down to just raking threads
323
+ if (linear_tid < RAKING_THREADS) {
324
+ WarpReverseScan warp_scan;
325
+ // Raking upsweep reduction across shared partials
326
+ T upsweep_partial = Upsweep(scan_op);
327
+ // Warp-synchronous scan
328
+ T exclusive_partial, block_aggregate;
329
+ warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);
330
+ // Obtain block-wide postfix in lane0, then broadcast to other lanes
331
+ T block_postfix = block_postfix_callback_op(block_aggregate);
332
+ block_postfix = warp_scan.Broadcast(block_postfix, 0);
333
+ // Update postfix with warpscan exclusive partial
334
+ T downsweep_postfix = linear_tid == RAKING_THREADS - 1
335
+ ? block_postfix : scan_op(block_postfix, exclusive_partial);
336
+ // Exclusive raking downsweep scan
337
+ ExclusiveDownsweep(scan_op, downsweep_postfix);
338
+ }
339
+ cub::CTA_SYNC();
340
+ // Grab thread postfix from shared memory
341
+ exclusive_output = *placement_ptr;
342
+
343
+ // // Compute warp scan in each warp.
344
+ // // The exclusive output from the last lane in each warp is invalid.
345
+ // T inclusive_output;
346
+ // WarpReverseScan warp_scan;
347
+ // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op);
348
+
349
+ // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid.
350
+ // T block_aggregate;
351
+ // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate);
352
+
353
+ // // Apply warp postfix to our lane's partial
354
+ // if (warp_id != 0) {
355
+ // exclusive_output = scan_op(warp_postfix, exclusive_output);
356
+ // if (lane_id == 0) { exclusive_output = warp_postfix; }
357
+ // }
358
+
359
+ // // Use the first warp to determine the thread block postfix, returning the result in lane0
360
+ // if (warp_id == 0) {
361
+ // T block_postfix = block_postfix_callback_op(block_aggregate);
362
+ // if (lane_id == 0) {
363
+ // // Share the postfix with all threads
364
+ // detail::uninitialized_copy(&temp_storage.block_postfix,
365
+ // block_postfix);
366
+
367
+ // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0
368
+ // }
369
+ // }
370
+
371
+ // cub::CTA_SYNC();
372
+
373
+ // // Incorporate thread block postfix into outputs
374
+ // T block_postfix = temp_storage.block_postfix;
375
+ // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); }
376
+ }
377
+ }
378
+
379
+
380
+ /**
381
+ * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
382
+ */
383
+ template <
384
+ int ITEMS_PER_THREAD,
385
+ typename ScanOp,
386
+ typename BlockPostfixCallbackOp>
387
+ __device__ __forceinline__ void InclusiveReverseScan(
388
+ T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items
389
+ T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input)
390
+ ScanOp scan_op, ///< [in] Binary scan functor
391
+ BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.
392
+ {
393
+ // Reduce consecutive thread items in registers
394
+ T thread_postfix = ThreadReverseReduce(input, scan_op);
395
+ // Exclusive thread block-scan
396
+ ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);
397
+ // Inclusive scan in registers with postfix as seed
398
+ ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);
399
+ }
400
+
401
+ };
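To clarify the postfix-scan semantics that `ThreadReverseScanExclusive` and `InclusiveReverseScan` implement above with warp shuffles, here is a tiny plain-Python reference of a reverse (right-to-left) exclusive scan seeded with a postfix value. It is an explanatory sketch added here, not code from the repository.

```
def reverse_scan_exclusive(xs, op, postfix):
    # Scan from the end toward the start; each output excludes its own element.
    out = [None] * len(xs)
    acc = postfix
    for i in range(len(xs) - 1, -1, -1):
        out[i] = acc          # exclusive value for position i
        acc = op(acc, xs[i])  # fold element i into the running postfix
    return out, acc           # acc is the inclusive aggregate over all elements

# reverse_scan_exclusive([1, 2, 3], lambda a, b: a + b, 0) -> ([5, 3, 0], 6)
```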
mamba_install/csrc/selective_scan/selective_scan.cpp ADDED
@@ -0,0 +1,497 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <c10/cuda/CUDAGuard.h>
7
+ #include <torch/extension.h>
8
+ #include <vector>
9
+
10
+ #include "selective_scan.h"
11
+
12
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
13
+
14
+ #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
15
+ if (ITYPE == at::ScalarType::Half) { \
16
+ using input_t = at::Half; \
17
+ __VA_ARGS__(); \
18
+ } else if (ITYPE == at::ScalarType::BFloat16) { \
19
+ using input_t = at::BFloat16; \
20
+ __VA_ARGS__(); \
21
+ } else if (ITYPE == at::ScalarType::Float) { \
22
+ using input_t = float; \
23
+ __VA_ARGS__(); \
24
+ } else { \
25
+ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
26
+ }
27
+
28
+ #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
29
+ if (WTYPE == at::ScalarType::Half) { \
30
+ using weight_t = at::Half; \
31
+ __VA_ARGS__(); \
32
+ } else if (WTYPE == at::ScalarType::BFloat16) { \
33
+ using weight_t = at::BFloat16; \
34
+ __VA_ARGS__(); \
35
+ } else if (WTYPE == at::ScalarType::Float) { \
36
+ using weight_t = float; \
37
+ __VA_ARGS__(); \
38
+ } else { \
39
+ AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
40
+ }
41
+
42
+ #define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \
43
+ if (WTYPE == at::ScalarType::Float) { \
44
+ using weight_t = float; \
45
+ __VA_ARGS__(); \
46
+ } else if (WTYPE == at::ScalarType::ComplexFloat) { \
47
+ using weight_t = c10::complex<float>; \
48
+ __VA_ARGS__(); \
49
+ } else { \
50
+ AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
51
+ }
52
+
53
+ template<typename input_t, typename weight_t>
54
+ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
55
+
56
+ template <typename input_t, typename weight_t>
57
+ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream);
58
+
59
+ void set_ssm_params_fwd(SSMParamsBase &params,
60
+ // sizes
61
+ const size_t batch,
62
+ const size_t dim,
63
+ const size_t seqlen,
64
+ const size_t dstate,
65
+ const size_t n_groups,
66
+ const size_t n_chunks,
67
+ const bool is_variable_B,
68
+ const bool is_variable_C,
69
+ // device pointers
70
+ const at::Tensor u,
71
+ const at::Tensor delta,
72
+ const at::Tensor A,
73
+ const at::Tensor B,
74
+ const at::Tensor C,
75
+ const at::Tensor out,
76
+ const at::Tensor z,
77
+ const at::Tensor out_z,
78
+ void* D_ptr,
79
+ void* delta_bias_ptr,
80
+ void* x_ptr,
81
+ bool has_z,
82
+ bool delta_softplus) {
83
+
84
+ // Reset the parameters
85
+ memset(&params, 0, sizeof(params));
86
+
87
+ params.batch = batch;
88
+ params.dim = dim;
89
+ params.seqlen = seqlen;
90
+ params.dstate = dstate;
91
+ params.n_groups = n_groups;
92
+ params.n_chunks = n_chunks;
93
+ params.dim_ngroups_ratio = dim / n_groups;
94
+
95
+ params.delta_softplus = delta_softplus;
96
+
97
+ params.is_variable_B = is_variable_B;
98
+ params.is_variable_C = is_variable_C;
99
+
100
+ // Set the pointers and strides.
101
+ params.u_ptr = u.data_ptr();
102
+ params.delta_ptr = delta.data_ptr();
103
+ params.A_ptr = A.data_ptr();
104
+ params.B_ptr = B.data_ptr();
105
+ params.C_ptr = C.data_ptr();
106
+ params.D_ptr = D_ptr;
107
+ params.delta_bias_ptr = delta_bias_ptr;
108
+ params.out_ptr = out.data_ptr();
109
+ params.x_ptr = x_ptr;
110
+ params.z_ptr = has_z ? z.data_ptr() : nullptr;
111
+ params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
112
+ // All stride are in elements, not bytes.
113
+ params.A_d_stride = A.stride(0);
114
+ params.A_dstate_stride = A.stride(1);
115
+ if (!is_variable_B) {
116
+ params.B_d_stride = B.stride(0);
117
+ } else {
118
+ params.B_batch_stride = B.stride(0);
119
+ params.B_group_stride = B.stride(1);
120
+ }
121
+ params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
122
+ if (!is_variable_C) {
123
+ params.C_d_stride = C.stride(0);
124
+ } else {
125
+ params.C_batch_stride = C.stride(0);
126
+ params.C_group_stride = C.stride(1);
127
+ }
128
+ params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
129
+ params.u_batch_stride = u.stride(0);
130
+ params.u_d_stride = u.stride(1);
131
+ params.delta_batch_stride = delta.stride(0);
132
+ params.delta_d_stride = delta.stride(1);
133
+ if (has_z) {
134
+ params.z_batch_stride = z.stride(0);
135
+ params.z_d_stride = z.stride(1);
136
+ params.out_z_batch_stride = out_z.stride(0);
137
+ params.out_z_d_stride = out_z.stride(1);
138
+ }
139
+ params.out_batch_stride = out.stride(0);
140
+ params.out_d_stride = out.stride(1);
141
+ }
142
+
143
+ void set_ssm_params_bwd(SSMParamsBwd &params,
144
+ // sizes
145
+ const size_t batch,
146
+ const size_t dim,
147
+ const size_t seqlen,
148
+ const size_t dstate,
149
+ const size_t n_groups,
150
+ const size_t n_chunks,
151
+ const bool is_variable_B,
152
+ const bool is_variable_C,
153
+ // device pointers
154
+ const at::Tensor u,
155
+ const at::Tensor delta,
156
+ const at::Tensor A,
157
+ const at::Tensor B,
158
+ const at::Tensor C,
159
+ const at::Tensor z,
160
+ const at::Tensor out,
161
+ const at::Tensor out_z,
162
+ void* D_ptr,
163
+ void* delta_bias_ptr,
164
+ void* x_ptr,
165
+ const at::Tensor dout,
166
+ const at::Tensor du,
167
+ const at::Tensor ddelta,
168
+ const at::Tensor dA,
169
+ const at::Tensor dB,
170
+ const at::Tensor dC,
171
+ const at::Tensor dz,
172
+ void* dD_ptr,
173
+ void* ddelta_bias_ptr,
174
+ bool has_z,
175
+ bool delta_softplus,
176
+ bool recompute_out_z) {
177
+ // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
178
+ set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
179
+ u, delta, A, B, C, has_z ? out : dout,
180
+ has_z ? z : dout,
181
+ // If not recompute_out_z, pass dout instead of out_z.
182
+ // This won't be used by the bwd kernel
183
+ recompute_out_z ? out_z : dout,
184
+ D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
185
+ if (!recompute_out_z) { params.out_z_ptr = nullptr; }
186
+
187
+ // Set the pointers and strides.
188
+ params.dout_ptr = dout.data_ptr();
189
+ params.du_ptr = du.data_ptr();
190
+ params.dA_ptr = dA.data_ptr();
191
+ params.dB_ptr = dB.data_ptr();
192
+ params.dC_ptr = dC.data_ptr();
193
+ params.dD_ptr = dD_ptr;
194
+ params.ddelta_ptr = ddelta.data_ptr();
195
+ params.ddelta_bias_ptr = ddelta_bias_ptr;
196
+ params.dz_ptr = has_z ? dz.data_ptr() : nullptr;
197
+ // All strides are in elements, not bytes.
198
+ params.dout_batch_stride = dout.stride(0);
199
+ params.dout_d_stride = dout.stride(1);
200
+ params.dA_d_stride = dA.stride(0);
201
+ params.dA_dstate_stride = dA.stride(1);
202
+ if (!is_variable_B) {
203
+ params.dB_d_stride = dB.stride(0);
204
+ } else {
205
+ params.dB_batch_stride = dB.stride(0);
206
+ params.dB_group_stride = dB.stride(1);
207
+ }
208
+ params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2);
209
+ if (!is_variable_C) {
210
+ params.dC_d_stride = dC.stride(0);
211
+ } else {
212
+ params.dC_batch_stride = dC.stride(0);
213
+ params.dC_group_stride = dC.stride(1);
214
+ }
215
+ params.dC_dstate_stride = !is_variable_C ? dC.stride(1) : dC.stride(2);
216
+ params.du_batch_stride = du.stride(0);
217
+ params.du_d_stride = du.stride(1);
218
+ params.ddelta_batch_stride = ddelta.stride(0);
219
+ params.ddelta_d_stride = ddelta.stride(1);
220
+ if (has_z) {
221
+ params.dz_batch_stride = dz.stride(0);
222
+ params.dz_d_stride = dz.stride(1);
223
+ }
224
+ }
225
+
226
+ std::vector<at::Tensor>
227
+ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
228
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
229
+ const c10::optional<at::Tensor> &D_,
230
+ const c10::optional<at::Tensor> &z_,
231
+ const c10::optional<at::Tensor> &delta_bias_,
232
+ bool delta_softplus) {
233
+ auto input_type = u.scalar_type();
234
+ auto weight_type = A.scalar_type();
235
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
236
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
237
+
238
+ const bool is_variable_B = B.dim() >= 3;
239
+ const bool is_variable_C = C.dim() >= 3;
240
+ const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
241
+
242
+ TORCH_CHECK(delta.scalar_type() == input_type);
243
+ TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
244
+ TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
245
+
246
+ TORCH_CHECK(u.is_cuda());
247
+ TORCH_CHECK(delta.is_cuda());
248
+ TORCH_CHECK(A.is_cuda());
249
+ TORCH_CHECK(B.is_cuda());
250
+ TORCH_CHECK(C.is_cuda());
251
+
252
+ TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
253
+ TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
254
+
255
+ const auto sizes = u.sizes();
256
+ const int batch_size = sizes[0];
257
+ const int dim = sizes[1];
258
+ const int seqlen = sizes[2];
259
+ const int dstate = A.size(1);
260
+ const int n_groups = is_variable_B ? B.size(1) : 1;
261
+
262
+ TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
263
+
264
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
265
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
266
+ CHECK_SHAPE(A, dim, dstate);
267
+ if (!is_variable_B) {
268
+ CHECK_SHAPE(B, dim, dstate);
269
+ } else {
270
+ CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
271
+ TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
272
+ }
273
+ if (!is_variable_C) {
274
+ CHECK_SHAPE(C, dim, dstate);
275
+ } else {
276
+ CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
277
+ TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
278
+ }
279
+
280
+ if (D_.has_value()) {
281
+ auto D = D_.value();
282
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
283
+ TORCH_CHECK(D.is_cuda());
284
+ TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
285
+ CHECK_SHAPE(D, dim);
286
+ }
287
+
288
+ if (delta_bias_.has_value()) {
289
+ auto delta_bias = delta_bias_.value();
290
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
291
+ TORCH_CHECK(delta_bias.is_cuda());
292
+ TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
293
+ CHECK_SHAPE(delta_bias, dim);
294
+ }
295
+
296
+ at::Tensor z, out_z;
297
+ const bool has_z = z_.has_value();
298
+ if (has_z) {
299
+ z = z_.value();
300
+ TORCH_CHECK(z.scalar_type() == input_type);
301
+ TORCH_CHECK(z.is_cuda());
302
+ TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
303
+ CHECK_SHAPE(z, batch_size, dim, seqlen);
304
+ out_z = torch::empty_like(z);
305
+ }
306
+
307
+ const int n_chunks = (seqlen + 2048 - 1) / 2048;
308
+ // const int n_chunks = (seqlen + 1024 - 1) / 1024;
309
+ // at::Tensor out = torch::empty_like(u);
310
+ // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
311
+ at::Tensor out = torch::empty_like(delta);
312
+ at::Tensor x;
313
+ x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
314
+
315
+ SSMParamsBase params;
316
+ set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
317
+ u, delta, A, B, C, out, z, out_z,
318
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
319
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
320
+ x.data_ptr(),
321
+ has_z,
322
+ delta_softplus);
323
+
324
+ // Otherwise the kernel will be launched from cuda:0 device
325
+ // Cast to char to avoid compiler warning about narrowing
326
+ at::cuda::CUDAGuard device_guard{(char)u.get_device()};
327
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
328
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
329
+ DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] {
330
+ selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
331
+ });
332
+ });
333
+ std::vector<at::Tensor> result = {out, x};
334
+ if (has_z) { result.push_back(out_z); }
335
+ return result;
336
+ }
337
+
338
+ std::vector<at::Tensor>
339
+ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
340
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
341
+ const c10::optional<at::Tensor> &D_,
342
+ const c10::optional<at::Tensor> &z_,
343
+ const c10::optional<at::Tensor> &delta_bias_,
344
+ const at::Tensor &dout,
345
+ const c10::optional<at::Tensor> &x_,
346
+ const c10::optional<at::Tensor> &out_,
347
+ c10::optional<at::Tensor> &dz_,
348
+ bool delta_softplus,
349
+ bool recompute_out_z) {
350
+ auto input_type = u.scalar_type();
351
+ auto weight_type = A.scalar_type();
352
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
353
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
354
+
355
+ const bool is_variable_B = B.dim() >= 3;
356
+ const bool is_variable_C = C.dim() >= 3;
357
+ const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
358
+
359
+ TORCH_CHECK(delta.scalar_type() == input_type);
360
+ TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
361
+ TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
362
+ TORCH_CHECK(dout.scalar_type() == input_type);
363
+
364
+ TORCH_CHECK(u.is_cuda());
365
+ TORCH_CHECK(delta.is_cuda());
366
+ TORCH_CHECK(A.is_cuda());
367
+ TORCH_CHECK(B.is_cuda());
368
+ TORCH_CHECK(C.is_cuda());
369
+ TORCH_CHECK(dout.is_cuda());
370
+
371
+ TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
372
+ TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
373
+ TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);
374
+
375
+ const auto sizes = u.sizes();
376
+ const int batch_size = sizes[0];
377
+ const int dim = sizes[1];
378
+ const int seqlen = sizes[2];
379
+ const int dstate = A.size(1);
380
+ const int n_groups = is_variable_B ? B.size(1) : 1;
381
+
382
+ TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
383
+
384
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
385
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
386
+ CHECK_SHAPE(A, dim, dstate);
387
+ if (!is_variable_B) {
388
+ CHECK_SHAPE(B, dim, dstate);
389
+ } else {
390
+ CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
391
+ TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
392
+ }
393
+ if (!is_variable_C) {
394
+ CHECK_SHAPE(C, dim, dstate);
395
+ } else {
396
+ CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
397
+ TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
398
+ }
399
+ CHECK_SHAPE(dout, batch_size, dim, seqlen);
400
+
401
+ if (D_.has_value()) {
402
+ auto D = D_.value();
403
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
404
+ TORCH_CHECK(D.is_cuda());
405
+ TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
406
+ CHECK_SHAPE(D, dim);
407
+ }
408
+
409
+ if (delta_bias_.has_value()) {
410
+ auto delta_bias = delta_bias_.value();
411
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
412
+ TORCH_CHECK(delta_bias.is_cuda());
413
+ TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
414
+ CHECK_SHAPE(delta_bias, dim);
415
+ }
416
+
417
+ at::Tensor z, out, dz, out_z;
418
+ const bool has_z = z_.has_value();
419
+ if (has_z) {
420
+ z = z_.value();
421
+ TORCH_CHECK(z.scalar_type() == input_type);
422
+ TORCH_CHECK(z.is_cuda());
423
+ TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
424
+ CHECK_SHAPE(z, batch_size, dim, seqlen);
425
+
426
+ TORCH_CHECK(out_.has_value());
427
+ out = out_.value();
428
+ TORCH_CHECK(out.scalar_type() == input_type);
429
+ TORCH_CHECK(out.is_cuda());
430
+ TORCH_CHECK(out.stride(-1) == 1 || out.size(-1) == 1);
431
+ CHECK_SHAPE(out, batch_size, dim, seqlen);
432
+
433
+ if (dz_.has_value()) {
434
+ dz = dz_.value();
435
+ TORCH_CHECK(dz.scalar_type() == input_type);
436
+ TORCH_CHECK(dz.is_cuda());
437
+ TORCH_CHECK(dz.stride(-1) == 1 || dz.size(-1) == 1);
438
+ CHECK_SHAPE(dz, batch_size, dim, seqlen);
439
+ } else {
440
+ dz = torch::empty_like(z);
441
+ }
442
+ if (recompute_out_z) {
443
+ out_z = torch::empty_like(out);
444
+ }
445
+ }
446
+
447
+ const int n_chunks = (seqlen + 2048 - 1) / 2048;
448
+ // const int n_chunks = (seqlen + 1024 - 1) / 1024;
449
+ if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
450
+ if (x_.has_value()) {
451
+ auto x = x_.value();
452
+ TORCH_CHECK(x.scalar_type() == weight_type);
453
+ TORCH_CHECK(x.is_cuda());
454
+ TORCH_CHECK(x.is_contiguous());
455
+ CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
456
+ }
457
+
458
+ at::Tensor du = torch::empty_like(u);
459
+ at::Tensor ddelta = torch::empty_like(delta);
460
+ at::Tensor dA = torch::zeros_like(A);
461
+ at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32));
462
+ at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32));
463
+ at::Tensor dD;
464
+ if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
465
+ at::Tensor ddelta_bias;
466
+ if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
467
+
468
+ SSMParamsBwd params;
469
+ set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
470
+ u, delta, A, B, C, z, out, out_z,
471
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
472
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
473
+ x_.has_value() ? x_.value().data_ptr() : nullptr,
474
+ dout, du, ddelta, dA, dB, dC, dz,
475
+ D_.has_value() ? dD.data_ptr() : nullptr,
476
+ delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
477
+ has_z, delta_softplus, recompute_out_z);
478
+
479
+ // Otherwise the kernel will be launched from cuda:0 device
480
+ // Cast to char to avoid compiler warning about narrowing
481
+ at::cuda::CUDAGuard device_guard{(char)u.get_device()};
482
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
483
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
484
+ DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] {
485
+ selective_scan_bwd_cuda<input_t, weight_t>(params, stream);
486
+ });
487
+ });
488
+ std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
489
+ if (has_z) { result.push_back(dz); }
490
+ if (recompute_out_z) { result.push_back(out_z); }
491
+ return result;
492
+ }
493
+
494
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
495
+ m.def("fwd", &selective_scan_fwd, "Selective scan forward");
496
+ m.def("bwd", &selective_scan_bwd, "Selective scan backward");
497
+ }
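The two bindings registered above, `fwd` and `bwd`, are the only entry points the Python side calls into. Below is a minimal sketch of how the forward binding could be invoked, assuming the compiled extension is importable as `selective_scan_cuda` (that module name is not shown in this diff and is an assumption) and that a CUDA device is available; shapes follow the CHECK_SHAPE checks in selective_scan_fwd.

import torch

try:
    import selective_scan_cuda  # assumed name of the built extension; not confirmed by this diff
except ImportError:
    selective_scan_cuda = None

if selective_scan_cuda is not None and torch.cuda.is_available():
    batch, dim, seqlen, dstate = 2, 4, 64, 16
    u = torch.randn(batch, dim, seqlen, device="cuda")      # (batch, dim, seqlen)
    delta = torch.rand(batch, dim, seqlen, device="cuda")   # (batch, dim, seqlen)
    A = -torch.rand(dim, dstate, device="cuda")             # real weights, (dim, dstate)
    B = torch.randn(dim, dstate, device="cuda")             # 2-D, so is_variable_B is false
    C = torch.randn(dim, dstate, device="cuda")             # 2-D, so is_variable_C is false
    D = torch.randn(dim, device="cuda")                     # optional skip connection
    delta_bias = torch.randn(dim, device="cuda")            # optional bias added to delta
    # fwd(u, delta, A, B, C, D?, z?, delta_bias?, delta_softplus) returns [out, x] (+ out_z when z is given)
    out, x = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, True)
    print(out.shape)  # (batch, dim, seqlen)
    print(x.shape)    # (batch, dim, n_chunks, 2 * dstate)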
mamba_install/csrc/selective_scan/selective_scan.h ADDED
@@ -0,0 +1,101 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
8
+
9
+ struct SSMScanParamsBase {
10
+ using index_t = uint32_t;
11
+
12
+ int batch, seqlen, n_chunks;
13
+ index_t a_batch_stride;
14
+ index_t b_batch_stride;
15
+ index_t out_batch_stride;
16
+
17
+ // Common data pointers.
18
+ void *__restrict__ a_ptr;
19
+ void *__restrict__ b_ptr;
20
+ void *__restrict__ out_ptr;
21
+ void *__restrict__ x_ptr;
22
+ };
23
+
24
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
25
+
26
+ struct SSMParamsBase {
27
+ using index_t = uint32_t;
28
+
29
+ int batch, dim, seqlen, dstate, n_groups, n_chunks;
30
+ int dim_ngroups_ratio;
31
+ bool is_variable_B;
32
+ bool is_variable_C;
33
+
34
+ bool delta_softplus;
35
+
36
+ index_t A_d_stride;
37
+ index_t A_dstate_stride;
38
+ index_t B_batch_stride;
39
+ index_t B_d_stride;
40
+ index_t B_dstate_stride;
41
+ index_t B_group_stride;
42
+ index_t C_batch_stride;
43
+ index_t C_d_stride;
44
+ index_t C_dstate_stride;
45
+ index_t C_group_stride;
46
+ index_t u_batch_stride;
47
+ index_t u_d_stride;
48
+ index_t delta_batch_stride;
49
+ index_t delta_d_stride;
50
+ index_t z_batch_stride;
51
+ index_t z_d_stride;
52
+ index_t out_batch_stride;
53
+ index_t out_d_stride;
54
+ index_t out_z_batch_stride;
55
+ index_t out_z_d_stride;
56
+
57
+ // Common data pointers.
58
+ void *__restrict__ A_ptr;
59
+ void *__restrict__ B_ptr;
60
+ void *__restrict__ C_ptr;
61
+ void *__restrict__ D_ptr;
62
+ void *__restrict__ u_ptr;
63
+ void *__restrict__ delta_ptr;
64
+ void *__restrict__ delta_bias_ptr;
65
+ void *__restrict__ out_ptr;
66
+ void *__restrict__ x_ptr;
67
+ void *__restrict__ z_ptr;
68
+ void *__restrict__ out_z_ptr;
69
+ };
70
+
71
+ struct SSMParamsBwd: public SSMParamsBase {
72
+ index_t dout_batch_stride;
73
+ index_t dout_d_stride;
74
+ index_t dA_d_stride;
75
+ index_t dA_dstate_stride;
76
+ index_t dB_batch_stride;
77
+ index_t dB_group_stride;
78
+ index_t dB_d_stride;
79
+ index_t dB_dstate_stride;
80
+ index_t dC_batch_stride;
81
+ index_t dC_group_stride;
82
+ index_t dC_d_stride;
83
+ index_t dC_dstate_stride;
84
+ index_t du_batch_stride;
85
+ index_t du_d_stride;
86
+ index_t dz_batch_stride;
87
+ index_t dz_d_stride;
88
+ index_t ddelta_batch_stride;
89
+ index_t ddelta_d_stride;
90
+
91
+ // Common data pointers.
92
+ void *__restrict__ dout_ptr;
93
+ void *__restrict__ dA_ptr;
94
+ void *__restrict__ dB_ptr;
95
+ void *__restrict__ dC_ptr;
96
+ void *__restrict__ dD_ptr;
97
+ void *__restrict__ du_ptr;
98
+ void *__restrict__ dz_ptr;
99
+ void *__restrict__ ddelta_ptr;
100
+ void *__restrict__ ddelta_bias_ptr;
101
+ };
mamba_install/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
mamba_install/csrc/selective_scan/selective_scan_bwd_bf16_real.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
mamba_install/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
mamba_install/csrc/selective_scan/selective_scan_bwd_fp16_real.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
mamba_install/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
mamba_install/csrc/selective_scan/selective_scan_bwd_fp32_real.cu ADDED
@@ -0,0 +1,9 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);
mamba_install/csrc/selective_scan/selective_scan_bwd_kernel.cuh ADDED
@@ -0,0 +1,531 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <c10/util/BFloat16.h>
8
+ #include <c10/util/Half.h>
9
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
+ #include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex
11
+
12
+ #include <cub/block/block_load.cuh>
13
+ #include <cub/block/block_store.cuh>
14
+ #include <cub/block/block_scan.cuh>
15
+ #include <cub/block/block_reduce.cuh>
16
+
17
+ #include "selective_scan.h"
18
+ #include "selective_scan_common.h"
19
+ #include "reverse_scan.cuh"
20
+ #include "static_switch.h"
21
+
22
+ template<typename scalar_t> __device__ __forceinline__ scalar_t conj(scalar_t x);
23
+ template<> __device__ __forceinline__ float conj<float>(float x) { return x; }
24
+ template<> __device__ __forceinline__ complex_t conj<complex_t>(complex_t x) { return std::conj(x); }
25
+
26
+ template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kIsVariableB_, bool kIsVariableC_,
27
+ bool kDeltaSoftplus_, bool kHasZ_, typename input_t_, typename weight_t_>
28
+ struct Selective_Scan_bwd_kernel_traits {
29
+ static_assert(kNItems_ % 4 == 0);
30
+ using input_t = input_t_;
31
+ using weight_t = weight_t_;
32
+ static constexpr int kNThreads = kNThreads_;
33
+ static constexpr int kNItems = kNItems_;
34
+ static constexpr int kNBytes = sizeof(input_t);
35
+ static_assert(kNBytes == 2 || kNBytes == 4);
36
+ static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
37
+ static_assert(kNItems % kNElts == 0);
38
+ static constexpr int kNLoads = kNItems / kNElts;
39
+ static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
40
+ static constexpr bool kIsEvenLen = kIsEvenLen_;
41
+ static constexpr bool kIsVariableB = kIsVariableB_;
42
+ static constexpr bool kIsVariableC = kIsVariableC_;
43
+ static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
44
+ static constexpr bool kHasZ = kHasZ_;
45
+ // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
46
+ // For complex this would lead to massive register spilling, so we keep it at 2.
47
+ static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2;
48
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
49
+ using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
50
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
51
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
52
+ using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
53
+ using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
54
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
55
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
56
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
57
+ using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
58
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
59
+ using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
60
+ using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
61
+ using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
62
+ using BlockReduceComplexT = cub::BlockReduce<complex_t, kNThreads>;
63
+ using BlockExchangeT = cub::BlockExchange<float, kNThreads, !kIsComplex ? kNItems : kNItems * 2>;
64
+ static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
65
+ sizeof(typename BlockLoadVecT::TempStorage),
66
+ (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
67
+ (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
68
+ sizeof(typename BlockStoreT::TempStorage),
69
+ sizeof(typename BlockStoreVecT::TempStorage)});
70
+ static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage);
71
+ static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
72
+ static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
73
+ };
74
+
75
+ template<typename Ktraits>
76
+ __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
77
+ void selective_scan_bwd_kernel(SSMParamsBwd params) {
78
+ constexpr bool kIsComplex = Ktraits::kIsComplex;
79
+ constexpr bool kIsVariableB = Ktraits::kIsVariableB;
80
+ constexpr bool kIsVariableC = Ktraits::kIsVariableC;
81
+ constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
82
+ constexpr bool kHasZ = Ktraits::kHasZ;
83
+ constexpr int kNThreads = Ktraits::kNThreads;
84
+ constexpr int kNItems = Ktraits::kNItems;
85
+ using input_t = typename Ktraits::input_t;
86
+ using weight_t = typename Ktraits::weight_t;
87
+ using scan_t = typename Ktraits::scan_t;
88
+
89
+ // Shared memory.
90
+ extern __shared__ char smem_[];
91
+ // cast to lvalue reference of expected type
92
+ // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
93
+ // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
94
+ // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
95
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
96
+ auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
97
+ auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
98
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
99
+ auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
100
+ auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
101
+ auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
102
+ auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
103
+ auto& smem_reduce_complex = *reinterpret_cast<typename Ktraits::BlockReduceComplexT::TempStorage*>(&smem_reduce);
104
+ auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
105
+ auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
106
+ weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
107
+ scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 * MAX_DSTATE + kNThreads);
108
+ weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + MAX_DSTATE);
109
+ weight_t *smem_dbc = reinterpret_cast<weight_t *>(smem_da + MAX_DSTATE);
110
+
111
+ const int batch_id = blockIdx.x;
112
+ const int dim_id = blockIdx.y;
113
+ const int group_id = dim_id / (params.dim_ngroups_ratio);
114
+ input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
115
+ + dim_id * params.u_d_stride;
116
+ input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
117
+ + dim_id * params.delta_d_stride;
118
+ input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
119
+ + dim_id * params.dout_d_stride;
120
+ weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
121
+ weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * params.B_d_stride;
122
+ input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
123
+ weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * params.C_d_stride;
124
+ input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
125
+ weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id * params.dA_d_stride;
126
+ weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
127
+ + (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride);
128
+ weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
129
+ + (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride);
130
+ float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id;
131
+ float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id];
132
+ float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id;
133
+ float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
134
+ scan_t *x = params.x_ptr == nullptr
135
+ ? nullptr
136
+ : reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
137
+ float dD_val = 0;
138
+ float ddelta_bias_val = 0;
139
+
140
+ constexpr int kChunkSize = kNThreads * kNItems;
141
+ u += (params.n_chunks - 1) * kChunkSize;
142
+ delta += (params.n_chunks - 1) * kChunkSize;
143
+ dout += (params.n_chunks - 1) * kChunkSize;
144
+ Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
145
+ Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
146
+ for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
147
+ input_t u_vals[kNItems];
148
+ input_t delta_vals_load[kNItems];
149
+ input_t dout_vals_load[kNItems];
150
+ __syncthreads();
151
+ load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
152
+ u -= kChunkSize;
153
+ __syncthreads();
154
+ load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
155
+ // Will reload delta at the same location if kDeltaSoftplus
156
+ if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
157
+ __syncthreads();
158
+ load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
159
+ dout -= kChunkSize;
160
+
161
+ float dout_vals[kNItems], delta_vals[kNItems];
162
+ #pragma unroll
163
+ for (int i = 0; i < kNItems; ++i) {
164
+ dout_vals[i] = float(dout_vals_load[i]);
165
+ delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
166
+ if constexpr (kDeltaSoftplus) {
167
+ delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
168
+ }
169
+ }
170
+
171
+ if constexpr (kHasZ) {
172
+ input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
173
+ + dim_id * params.z_d_stride + chunk * kChunkSize;
174
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
175
+ + dim_id * params.out_d_stride + chunk * kChunkSize;
176
+ input_t *dz = reinterpret_cast<input_t *>(params.dz_ptr) + batch_id * params.dz_batch_stride
177
+ + dim_id * params.dz_d_stride + chunk * kChunkSize;
178
+ input_t z_vals[kNItems], out_vals[kNItems];
179
+ __syncthreads();
180
+ load_input<Ktraits>(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
181
+ __syncthreads();
182
+ load_input<Ktraits>(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize);
183
+ float dz_vals[kNItems], z_silu_vals[kNItems];
184
+ #pragma unroll
185
+ for (int i = 0; i < kNItems; ++i) {
186
+ float z_val = z_vals[i];
187
+ float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val));
188
+ z_silu_vals[i] = z_val * z_sigmoid_val;
189
+ dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val
190
+ * (1.0f + z_val * (1.0f - z_sigmoid_val));
191
+ dout_vals[i] *= z_silu_vals[i];
192
+ }
193
+ __syncthreads();
194
+ store_output<Ktraits>(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize);
195
+ if (params.out_z_ptr != nullptr) { // Recompute and store out_z
196
+ float out_z_vals[kNItems];
197
+ #pragma unroll
198
+ for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; }
199
+ // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
200
+ // printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]);
201
+ // }
202
+ input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
203
+ + dim_id * params.out_z_d_stride + chunk * kChunkSize;
204
+ __syncthreads();
205
+ store_output<Ktraits>(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize);
206
+ }
207
+ }
208
+
209
+ float du_vals[kNItems];
210
+ #pragma unroll
211
+ for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }
212
+ #pragma unroll
213
+ for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }
214
+
215
+ float ddelta_vals[kNItems] = {0};
216
+ __syncthreads();
217
+ for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
218
+ const weight_t A_val = A[state_idx * params.A_dstate_stride];
219
+ // Multiply the real part of A by LOG2E so we can use exp2f instead of expf.
220
+ weight_t A_scaled;
221
+ constexpr float kLog2e = M_LOG2E;
222
+ if constexpr (!kIsComplex) {
223
+ A_scaled = A_val * kLog2e;
224
+ } else {
225
+ A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_);
226
+ }
227
+ weight_t B_val, C_val;
228
+ weight_t B_vals[kNItems], C_vals[kNItems];
229
+ if constexpr (!kIsVariableB) {
230
+ B_val = B[state_idx * params.B_dstate_stride];
231
+ } else {
232
+ load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
233
+ smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
234
+ }
235
+ if constexpr (!kIsVariableC) {
236
+ C_val = C[state_idx * params.C_dstate_stride];
237
+ } else {
238
+ auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
239
+ load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
240
+ smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
241
+ }
242
+ // const weight_t A_val = smem_a[state_idx];
243
+ scan_t thread_data[kNItems], thread_reverse_data[kNItems];
244
+ if constexpr (!kIsComplex) {
245
+ #pragma unroll
246
+ for (int i = 0; i < kNItems; ++i) {
247
+ const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
248
+ thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
249
+ if (i == 0) {
250
+ smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
251
+ } else {
252
+ thread_reverse_data[i - 1].x = delta_a_exp;
253
+ }
254
+ thread_reverse_data[i].y = dout_vals[i] *
255
+ (!kIsVariableC
256
+ ? (!kIsVariableB ? B_val * C_val : C_val)
257
+ : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
258
+ }
259
+ __syncthreads();
260
+ thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
261
+ ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
262
+ : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
263
+ // Initialize running total
264
+ scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
265
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
266
+ Ktraits::BlockScanT(smem_scan).InclusiveScan(
267
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
268
+ );
269
+ scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);
270
+ SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
271
+ Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
272
+ thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
273
+ );
274
+ if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
275
+ weight_t dA_val = 0, dBC_val = 0;
276
+ weight_t dB_vals[kNItems], dC_vals[kNItems];
277
+ #pragma unroll
278
+ for (int i = 0; i < kNItems; ++i) {
279
+ const float dx = thread_reverse_data[i].y;
280
+ const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i];
281
+ du_vals[i] += ddelta_u * delta_vals[i];
282
+ const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
283
+ ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
284
+ dA_val += dx * delta_vals[i] * a;
285
+ if constexpr (!kIsVariableB || !kIsVariableC) {
286
+ if constexpr (!kIsVariableB) { // dBC_val is dB_val
287
+ dBC_val += dout_vals[i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]);
288
+ } else { // dBC_val is dC_val
289
+ dBC_val += dout_vals[i] * thread_data[i].y;
290
+ }
291
+ }
292
+ if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
293
+ if constexpr (kIsVariableC) {
294
+ dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y);
295
+ }
296
+ }
297
+ // Block-exchange to make the atomicAdds coalesced; otherwise they're much slower.
298
+ if constexpr (kIsVariableB || kIsVariableC) {
299
+ if constexpr (kIsVariableB) {
300
+ Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
301
+ }
302
+ if constexpr (kIsVariableC) {
303
+ auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
304
+ Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
305
+ }
306
+ const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
307
+ weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
308
+ weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;
309
+ #pragma unroll
310
+ for (int i = 0; i < kNItems; ++i) {
311
+ if (i * kNThreads < seqlen_remaining) {
312
+ if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
313
+ if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
314
+ }
315
+ }
316
+ }
317
+ if constexpr (!kIsVariableB || !kIsVariableC) {
318
+ float2 dA_dBC_val = make_float2(dA_val, dBC_val);
319
+ dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
320
+ dA_val = dA_dBC_val.x;
321
+ if (threadIdx.x == 0) {
322
+ smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx];
323
+ }
324
+ } else {
325
+ dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
326
+ }
327
+ if (threadIdx.x == 0) {
328
+ smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
329
+ }
330
+ } else {
331
+ #pragma unroll
332
+ for (int i = 0; i < kNItems; ++i) {
333
+ // PyTorch's implementation of complex exp (which calls Thrust) is very slow.
334
+ complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);
335
+ weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);
336
+ thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
337
+ if (i == 0) {
338
+ smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
339
+ } else {
340
+ thread_reverse_data[i - 1].x = delta_a_exp.real_;
341
+ thread_reverse_data[i - 1].y = -delta_a_exp.imag_;
342
+ }
343
+ complex_t dout_BC = 2 * dout_vals[i]
344
+ * conj(!kIsVariableC
345
+ ? (!kIsVariableB ? B_val * C_val : C_val)
346
+ : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
347
+ thread_reverse_data[i].z = dout_BC.real_;
348
+ thread_reverse_data[i].w = dout_BC.imag_;
349
+ }
350
+ __syncthreads();
351
+ complex_t delta_a_exp = threadIdx.x == kNThreads - 1
352
+ ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
353
+ : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
354
+ thread_reverse_data[kNItems - 1].x = delta_a_exp.real_;
355
+ thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_;
356
+ // Initialize running total
357
+ scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
358
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
359
+ Ktraits::BlockScanT(smem_scan).InclusiveScan(
360
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
361
+ );
362
+ scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
363
+ SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
364
+ Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
365
+ thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
366
+ );
367
+ if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
368
+ weight_t dA_val = 0, dBC_val = 0;
369
+ weight_t dB_vals[kNItems], dC_vals[kNItems];
370
+ #pragma unroll
371
+ for (int i = 0; i < kNItems; ++i) {
372
+ complex_t x = complex_t(thread_data[i].z, thread_data[i].w);
373
+ complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w);
374
+ float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_;
375
+ if constexpr (!kIsVariableB || !kIsVariableC) {
376
+ if constexpr (!kIsVariableB) { // dBC_val is dB_val
377
+ dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]);
378
+ } else { // dBC_val is dC_val
379
+ dBC_val += (2 * dout_vals[i]) * conj(x);
380
+ }
381
+ }
382
+ const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]));
383
+ du_vals[i] += ddelta_u * delta_vals[i];
384
+ ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_;
385
+ dA_val += delta_vals[i] * dx * a_conj;
386
+ if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
387
+ if constexpr (kIsVariableC) {
388
+ dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x);
389
+ }
390
+ }
391
+ // Block-exchange to make the atomicAdds coalesced; otherwise they're much slower.
392
+ if constexpr (kIsVariableB || kIsVariableC) {
393
+ float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2];
394
+ if constexpr (kIsVariableB) {
395
+ #pragma unroll
396
+ for (int i = 0; i < kNItems; ++i) {
397
+ dB_vals_f[i * 2] = dB_vals[i].real_;
398
+ dB_vals_f[i * 2 + 1] = dB_vals[i].imag_;
399
+ }
400
+ Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f);
401
+ }
402
+ if constexpr (kIsVariableC) {
403
+ #pragma unroll
404
+ for (int i = 0; i < kNItems; ++i) {
405
+ dC_vals_f[i * 2] = dC_vals[i].real_;
406
+ dC_vals_f[i * 2 + 1] = dC_vals[i].imag_;
407
+ }
408
+ auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
409
+ Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f);
410
+ }
411
+ const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x;
412
+ float *dB_cur = reinterpret_cast<float *>(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
413
+ float *dC_cur = reinterpret_cast<float *>(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
414
+ #pragma unroll
415
+ for (int i = 0; i < kNItems * 2; ++i) {
416
+ if (i * kNThreads < seqlen_remaining) {
417
+ if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); }
418
+ if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); }
419
+ }
420
+ }
421
+ }
422
+ if constexpr (!kIsVariableB || !kIsVariableC) {
423
+ float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_);
424
+ dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
425
+ dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y);
426
+ dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w);
427
+ if (threadIdx.x == 0) {
428
+ smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx];
429
+ }
430
+ } else {
431
+ dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);
432
+ }
433
+ if (threadIdx.x == 0) {
434
+ smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
435
+ }
436
+ }
437
+ }
438
+
439
+ if constexpr (kDeltaSoftplus) {
440
+ __syncthreads();
441
+ input_t delta_vals_load[kNItems];
442
+ load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
443
+ delta -= kChunkSize;
444
+ #pragma unroll
445
+ for (int i = 0; i < kNItems; ++i) {
446
+ float delta_val = float(delta_vals_load[i]) + delta_bias;
447
+ float delta_val_neg_exp = expf(-delta_val);
448
+ ddelta_vals[i] = delta_val <= 20.f
449
+ ? ddelta_vals[i] / (1.f + delta_val_neg_exp)
450
+ : ddelta_vals[i];
451
+ }
452
+ }
453
+ for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }
454
+
455
+ input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
456
+ + dim_id * params.du_d_stride + chunk * kChunkSize;
457
+ input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
458
+ + dim_id * params.ddelta_d_stride + chunk * kChunkSize;
459
+ __syncthreads();
460
+ store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);
461
+ __syncthreads();
462
+ store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);
463
+
464
+ Bvar -= kChunkSize * (!kIsComplex ? 1 : 2);
465
+ Cvar -= kChunkSize * (!kIsComplex ? 1 : 2);
466
+ }
467
+ if (params.dD_ptr != nullptr) {
468
+ dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
469
+ if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
470
+ }
471
+ if (params.ddelta_bias_ptr != nullptr) {
472
+ __syncthreads();
473
+ ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
474
+ if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
475
+ }
476
+ for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
477
+ gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]);
478
+ weight_t dBC_val;
479
+ if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; }
480
+ if constexpr (!kIsVariableB) {
481
+ gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]),
482
+ !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val);
483
+ }
484
+ if constexpr (!kIsVariableC) {
485
+ gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]),
486
+ !kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val);
487
+ }
488
+ }
489
+ }
490
+
491
+ template<int kNThreads, int kNItems, typename input_t, typename weight_t>
492
+ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
493
+ BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
494
+ BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
495
+ BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
496
+ BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
497
+ BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, [&] {
498
+ using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
499
+ // using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
500
+ // TODO: check this
501
+ constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
502
+ // printf("smem_size = %d\n", kSmemSize);
503
+ dim3 grid(params.batch, params.dim);
504
+ auto kernel = &selective_scan_bwd_kernel<Ktraits>;
505
+ if (kSmemSize >= 48 * 1024) {
506
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
507
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
508
+ }
509
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
510
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
511
+ });
512
+ });
513
+ });
514
+ });
515
+ });
516
+ }
517
+
518
+ template<typename input_t, typename weight_t>
519
+ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
520
+ if (params.seqlen <= 128) {
521
+ selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
522
+ } else if (params.seqlen <= 256) {
523
+ selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
524
+ } else if (params.seqlen <= 512) {
525
+ selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
526
+ } else if (params.seqlen <= 1024) {
527
+ selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
528
+ } else {
529
+ selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
530
+ }
531
+ }
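The seqlen-based dispatch in selective_scan_bwd_cuda above picks kNThreads and kNItems so that a single chunk of kNThreads * kNItems timesteps covers min(seqlen, 2048) elements, consistent with n_chunks = ceil(seqlen / 2048) in selective_scan.cpp. A minimal Python sketch of that selection logic, for illustration only:

# Illustrative sketch mirroring the branches of selective_scan_bwd_cuda;
# it only reproduces the (kNThreads, kNItems) choice, not the kernel itself.
def launch_config(seqlen):
    if seqlen <= 128:
        return 32, 4
    elif seqlen <= 256:
        return 32, 8
    elif seqlen <= 512:
        return 32, 16
    elif seqlen <= 1024:
        return 64, 16
    return 128, 16

for seqlen in (100, 300, 900, 4096):
    threads, items = launch_config(seqlen)
    chunk = threads * items                  # kChunkSize in the kernel
    n_chunks = (seqlen + 2048 - 1) // 2048   # same formula as selective_scan.cpp
    print(seqlen, (threads, items), chunk, n_chunks)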
mamba_install/csrc/selective_scan/selective_scan_common.h ADDED
@@ -0,0 +1,221 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <cuda_bf16.h>
8
+ #include <cuda_fp16.h>
9
+ #include <c10/util/complex.h> // For scalar_value_type
10
+
11
+ #define MAX_DSTATE 256
12
+
13
+ using complex_t = c10::complex<float>;
14
+
15
+ inline __device__ float2 operator+(const float2 & a, const float2 & b){
16
+ return {a.x + b.x, a.y + b.y};
17
+ }
18
+
19
+ inline __device__ float3 operator+(const float3 &a, const float3 &b) {
20
+ return {a.x + b.x, a.y + b.y, a.z + b.z};
21
+ }
22
+
23
+ inline __device__ float4 operator+(const float4 & a, const float4 & b){
24
+ return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
25
+ }
26
+
27
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
28
+
29
+ template<int BYTES> struct BytesToType {};
30
+
31
+ template<> struct BytesToType<16> {
32
+ using Type = uint4;
33
+ static_assert(sizeof(Type) == 16);
34
+ };
35
+
36
+ template<> struct BytesToType<8> {
37
+ using Type = uint64_t;
38
+ static_assert(sizeof(Type) == 8);
39
+ };
40
+
41
+ template<> struct BytesToType<4> {
42
+ using Type = uint32_t;
43
+ static_assert(sizeof(Type) == 4);
44
+ };
45
+
46
+ template<> struct BytesToType<2> {
47
+ using Type = uint16_t;
48
+ static_assert(sizeof(Type) == 2);
49
+ };
50
+
51
+ template<> struct BytesToType<1> {
52
+ using Type = uint8_t;
53
+ static_assert(sizeof(Type) == 1);
54
+ };
55
+
56
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
57
+
58
+ template<typename scalar_t, int N>
59
+ struct Converter{
60
+ static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
61
+ #pragma unroll
62
+ for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
63
+ }
64
+ };
65
+
66
+ template<int N>
67
+ struct Converter<at::Half, N>{
68
+ static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
69
+ static_assert(N % 2 == 0);
70
+ auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
71
+ auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
72
+ #pragma unroll
73
+ for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
74
+ }
75
+ };
76
+
77
+ #if __CUDA_ARCH__ >= 800
78
+ template<int N>
79
+ struct Converter<at::BFloat16, N>{
80
+ static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
81
+ static_assert(N % 2 == 0);
82
+ auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
83
+ auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
84
+ #pragma unroll
85
+ for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
86
+ }
87
+ };
88
+ #endif
89
+
90
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
91
+
92
+ // From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
93
+ // and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
94
+ __device__ __forceinline__ complex_t cexp2f(complex_t z) {
95
+ float t = exp2f(z.real_);
96
+ float c, s;
97
+ sincosf(z.imag_, &s, &c);
98
+ return complex_t(c * t, s * t);
99
+ }
100
+
101
+ __device__ __forceinline__ complex_t cexpf(complex_t z) {
102
+ float t = expf(z.real_);
103
+ float c, s;
104
+ sincosf(z.imag_, &s, &c);
105
+ return complex_t(c * t, s * t);
106
+ }
107
+
108
+ template<typename scalar_t> struct SSMScanOp;
109
+
110
+ template<>
111
+ struct SSMScanOp<float> {
112
+ __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
113
+ return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
114
+ }
115
+ };
116
+
117
+ template<>
118
+ struct SSMScanOp<complex_t> {
119
+ __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
120
+ complex_t a0 = complex_t(ab0.x, ab0.y);
121
+ complex_t b0 = complex_t(ab0.z, ab0.w);
122
+ complex_t a1 = complex_t(ab1.x, ab1.y);
123
+ complex_t b1 = complex_t(ab1.z, ab1.w);
124
+ complex_t out_a = a1 * a0;
125
+ complex_t out_b = a1 * b0 + b1;
126
+ return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
127
+ }
128
+ };
129
+
130
+ // A stateful callback functor that maintains a running prefix to be applied
131
+ // during consecutive scan operations.
132
+ template <typename scalar_t> struct SSMScanPrefixCallbackOp {
133
+ using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
134
+ scan_t running_prefix;
135
+ // Constructor
136
+ __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
137
+ // Callback operator to be entered by the first warp of threads in the block.
138
+ // Thread-0 is responsible for returning a value for seeding the block-wide scan.
139
+ __device__ scan_t operator()(scan_t block_aggregate) {
140
+ scan_t old_prefix = running_prefix;
141
+ running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
142
+ return old_prefix;
143
+ }
144
+ };
145
+
146
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
147
+
148
+ template<typename Ktraits>
149
+ inline __device__ void load_input(typename Ktraits::input_t *u,
150
+ typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
151
+ typename Ktraits::BlockLoadT::TempStorage &smem_load,
152
+ int seqlen) {
153
+ if constexpr (Ktraits::kIsEvenLen) {
154
+ auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
155
+ using vec_t = typename Ktraits::vec_t;
156
+ Ktraits::BlockLoadVecT(smem_load_vec).Load(
157
+ reinterpret_cast<vec_t*>(u),
158
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
159
+ );
160
+ } else {
161
+ Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
162
+ }
163
+ }
164
+
165
+ template<typename Ktraits>
166
+ inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
167
+ typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
168
+ typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
169
+ int seqlen) {
170
+ constexpr int kNItems = Ktraits::kNItems;
171
+ if constexpr (!Ktraits::kIsComplex) {
172
+ typename Ktraits::input_t B_vals_load[kNItems];
173
+ if constexpr (Ktraits::kIsEvenLen) {
174
+ auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
175
+ using vec_t = typename Ktraits::vec_t;
176
+ Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
177
+ reinterpret_cast<vec_t*>(Bvar),
178
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
179
+ );
180
+ } else {
181
+ Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
182
+ }
183
+ // #pragma unroll
184
+ // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
185
+ Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
186
+ } else {
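+ // Complex weights are stored as interleaved (real, imag) scalars, so load 2*kNItems values
+ // and repack them into complex_t below.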
187
+ typename Ktraits::input_t B_vals_load[kNItems * 2];
188
+ if constexpr (Ktraits::kIsEvenLen) {
189
+ auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
190
+ using vec_t = typename Ktraits::vec_t;
191
+ Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
192
+ reinterpret_cast<vec_t*>(Bvar),
193
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
194
+ );
195
+ } else {
196
+ Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
197
+ }
198
+ #pragma unroll
199
+ for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
200
+ }
201
+ }
202
+
203
+ template<typename Ktraits>
204
+ inline __device__ void store_output(typename Ktraits::input_t *out,
205
+ const float (&out_vals)[Ktraits::kNItems],
206
+ typename Ktraits::BlockStoreT::TempStorage &smem_store,
207
+ int seqlen) {
208
+ typename Ktraits::input_t write_vals[Ktraits::kNItems];
209
+ #pragma unroll
210
+ for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
211
+ if constexpr (Ktraits::kIsEvenLen) {
212
+ auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
213
+ using vec_t = typename Ktraits::vec_t;
214
+ Ktraits::BlockStoreVecT(smem_store_vec).Store(
215
+ reinterpret_cast<vec_t*>(out),
216
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
217
+ );
218
+ } else {
219
+ Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
220
+ }
221
+ }
mamba_install/csrc/selective_scan/selective_scan_fwd_bf16.cu ADDED
@@ -0,0 +1,10 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_fwd_kernel.cuh"
8
+
9
+ template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
10
+ template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream);
mamba_install/csrc/selective_scan/selective_scan_fwd_fp16.cu ADDED
@@ -0,0 +1,10 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_fwd_kernel.cuh"
8
+
9
+ template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
10
+ template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream);
mamba_install/csrc/selective_scan/selective_scan_fwd_fp32.cu ADDED
@@ -0,0 +1,10 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_fwd_kernel.cuh"
8
+
9
+ template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
10
+ template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream);
mamba_install/csrc/selective_scan/selective_scan_fwd_kernel.cuh ADDED
@@ -0,0 +1,345 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <c10/util/BFloat16.h>
8
+ #include <c10/util/Half.h>
9
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
+
11
+ #include <cub/block/block_load.cuh>
12
+ #include <cub/block/block_store.cuh>
13
+ #include <cub/block/block_scan.cuh>
14
+
15
+ #include "selective_scan.h"
16
+ #include "selective_scan_common.h"
17
+ #include "static_switch.h"
18
+
19
+ template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
20
+ bool kIsVariableB_, bool kIsVariableC_,
21
+ bool kHasZ_, typename input_t_, typename weight_t_>
22
+ struct Selective_Scan_fwd_kernel_traits {
23
+ static_assert(kNItems_ % 4 == 0);
24
+ using input_t = input_t_;
25
+ using weight_t = weight_t_;
26
+ static constexpr int kNThreads = kNThreads_;
27
+ // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
28
+ static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
29
+ static constexpr int kNItems = kNItems_;
30
+ static constexpr int kNRows = kNRows_;
31
+ static constexpr int kNBytes = sizeof(input_t);
32
+ static_assert(kNBytes == 2 || kNBytes == 4);
33
+ static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
34
+ static_assert(kNItems % kNElts == 0);
35
+ static constexpr int kNLoads = kNItems / kNElts;
36
+ static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
37
+ static constexpr bool kIsEvenLen = kIsEvenLen_;
38
+ static constexpr bool kIsVariableB = kIsVariableB_;
39
+ static constexpr bool kIsVariableC = kIsVariableC_;
40
+ static constexpr bool kHasZ = kHasZ_;
41
+
42
+ static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
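+ // Direct (non-transposed) I/O is only used when the sequence length is a multiple of the chunk size
+ // and each thread issues exactly one vectorized load/store.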
43
+
44
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
45
+ using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
46
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
47
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
48
+ !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
49
+ using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
50
+ using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2,
51
+ !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
52
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
53
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
54
+ !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
55
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
56
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
57
+ using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
58
+ static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
59
+ sizeof(typename BlockLoadVecT::TempStorage),
60
+ (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
61
+ (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
62
+ sizeof(typename BlockStoreT::TempStorage),
63
+ sizeof(typename BlockStoreVecT::TempStorage)});
64
+ static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
65
+ };
66
+
67
+ template<typename Ktraits>
68
+ __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
69
+ void selective_scan_fwd_kernel(SSMParamsBase params) {
70
+ constexpr bool kIsComplex = Ktraits::kIsComplex;
71
+ constexpr bool kIsVariableB = Ktraits::kIsVariableB;
72
+ constexpr bool kIsVariableC = Ktraits::kIsVariableC;
73
+ constexpr bool kHasZ = Ktraits::kHasZ;
74
+ constexpr int kNThreads = Ktraits::kNThreads;
75
+ constexpr int kNItems = Ktraits::kNItems;
76
+ constexpr int kNRows = Ktraits::kNRows;
77
+ constexpr bool kDirectIO = Ktraits::kDirectIO;
78
+ using input_t = typename Ktraits::input_t;
79
+ using weight_t = typename Ktraits::weight_t;
80
+ using scan_t = typename Ktraits::scan_t;
81
+
82
+ // Shared memory.
83
+ extern __shared__ char smem_[];
84
+ // cast to lvalue reference of expected type
85
+ // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
86
+ // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
87
+ // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
88
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
89
+ auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
90
+ auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
91
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
92
+ auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
93
+ // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
94
+ // weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
95
+ scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
96
+
97
+ const int batch_id = blockIdx.x;
98
+ const int dim_id = blockIdx.y;
99
+ const int group_id = dim_id / (params.dim_ngroups_ratio);
100
+ input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
101
+ + dim_id * kNRows * params.u_d_stride;
102
+ input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
103
+ + dim_id * kNRows * params.delta_d_stride;
104
+ weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
105
+ weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
106
+ input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
107
+ weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
108
+ input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
109
+ scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
110
+
111
+ float D_val[kNRows] = {0};
112
+ if (params.D_ptr != nullptr) {
113
+ #pragma unroll
114
+ for (int r = 0; r < kNRows; ++r) {
115
+ D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
116
+ }
117
+ }
118
+ float delta_bias[kNRows] = {0};
119
+ if (params.delta_bias_ptr != nullptr) {
120
+ #pragma unroll
121
+ for (int r = 0; r < kNRows; ++r) {
122
+ delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
123
+ }
124
+ }
125
+
126
+ // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
127
+ // smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
128
+ // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
129
+ // }
130
+
131
+ constexpr int kChunkSize = kNThreads * kNItems;
132
+ for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
133
+ input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
134
+ __syncthreads();
135
+ #pragma unroll
136
+ for (int r = 0; r < kNRows; ++r) {
137
+ if constexpr (!kDirectIO) {
138
+ if (r > 0) { __syncthreads(); }
139
+ }
140
+ load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
141
+ if constexpr (!kDirectIO) { __syncthreads(); }
142
+ load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
143
+ }
144
+ u += kChunkSize;
145
+ delta += kChunkSize;
146
+
147
+ float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
148
+ #pragma unroll
149
+ for (int r = 0; r < kNRows; ++r) {
150
+ #pragma unroll
151
+ for (int i = 0; i < kNItems; ++i) {
152
+ float u_val = float(u_vals[r][i]);
153
+ delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
154
+ if (params.delta_softplus) {
155
+ delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
156
+ }
157
+ delta_u_vals[r][i] = delta_vals[r][i] * u_val;
158
+ out_vals[r][i] = D_val[r] * u_val;
159
+ }
160
+ }
161
+
162
+ __syncthreads();
163
+ for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
164
+ weight_t A_val[kNRows];
165
+ #pragma unroll
166
+ for (int r = 0; r < kNRows; ++r) {
167
+ A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
168
+ // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
169
+ constexpr float kLog2e = M_LOG2E;
170
+ if constexpr (!kIsComplex) {
171
+ A_val[r] *= kLog2e;
172
+ } else {
173
+ A_val[r].real_ *= kLog2e;
174
+ }
175
+ }
176
+ // This variable holds B * C if both B and C are constant across seqlen. If only B varies
177
+ // across seqlen, this holds C. If only C varies across seqlen, this holds B.
178
+ // If both B and C vary, this is unused.
179
+ weight_t BC_val[kNRows];
180
+ weight_t B_vals[kNItems], C_vals[kNItems];
181
+ if constexpr (kIsVariableB) {
182
+ load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
183
+ smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
184
+ if constexpr (!kIsVariableC) {
185
+ #pragma unroll
186
+ for (int r = 0; r < kNRows; ++r) {
187
+ BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
188
+ }
189
+ }
190
+ }
191
+ if constexpr (kIsVariableC) {
192
+ auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
193
+ load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
194
+ smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
195
+ if constexpr (!kIsVariableB) {
196
+ #pragma unroll
197
+ for (int r = 0; r < kNRows; ++r) {
198
+ BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
199
+ }
200
+ }
201
+ }
202
+ if constexpr (!kIsVariableB && !kIsVariableC) {
203
+ #pragma unroll
204
+ for (int r = 0; r < kNRows; ++r) {
205
+ BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
206
+ }
207
+ }
208
+
209
+ #pragma unroll
210
+ for (int r = 0; r < kNRows; ++r) {
211
+ if (r > 0) { __syncthreads(); } // Scan could be using the same smem
212
+ scan_t thread_data[kNItems];
213
+ #pragma unroll
214
+ for (int i = 0; i < kNItems; ++i) {
215
+ if constexpr (!kIsComplex) {
216
+ thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
217
+ !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
218
+ if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
219
+ if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
220
+ thread_data[i] = make_float2(1.f, 0.f);
221
+ }
222
+ }
223
+ } else {
224
+ // Pytorch's implementation of complex exp (which calls thrust) is very slow
225
+ complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]);
226
+ weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];
227
+ thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
228
+ if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
229
+ if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
230
+ thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f);
231
+ }
232
+ }
233
+ }
234
+ }
235
+ // Initialize running total
236
+ scan_t running_prefix;
237
+ if constexpr (!kIsComplex) {
238
+ // With BLOCK_SCAN_WARP_SCANS, lane 0 of every warp (not just thread 0) needs to read
239
+ running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
240
+ // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
241
+ } else {
242
+ running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);
243
+ // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
244
+ }
245
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
246
+ Ktraits::BlockScanT(smem_scan).InclusiveScan(
247
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
248
+ );
249
+ // There's a syncthreads in the scan op, so we don't need to sync here.
250
+ // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
251
+ if (threadIdx.x == 0) {
252
+ smem_running_prefix[state_idx] = prefix_op.running_prefix;
253
+ x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
254
+ }
255
+ #pragma unroll
256
+ for (int i = 0; i < kNItems; ++i) {
257
+ const weight_t C_val = !kIsVariableC
258
+ ? BC_val[r]
259
+ : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
260
+ if constexpr (!kIsComplex) {
261
+ out_vals[r][i] += thread_data[i].y * C_val;
262
+ } else {
263
+ out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2;
264
+ }
265
+ }
266
+ }
267
+ }
268
+
269
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
270
+ + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
271
+ __syncthreads();
272
+ #pragma unroll
273
+ for (int r = 0; r < kNRows; ++r) {
274
+ if constexpr (!kDirectIO) {
275
+ if (r > 0) { __syncthreads(); }
276
+ }
277
+ store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
278
+ }
279
+
280
+ if constexpr (kHasZ) {
281
+ input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
282
+ + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
283
+ input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
284
+ + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
285
+ #pragma unroll
286
+ for (int r = 0; r < kNRows; ++r) {
287
+ input_t z_vals[kNItems];
288
+ __syncthreads();
289
+ load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
290
+ #pragma unroll
291
+ for (int i = 0; i < kNItems; ++i) {
292
+ float z_val = z_vals[i];
293
+ out_vals[r][i] *= z_val / (1 + expf(-z_val));
294
+ }
295
+ __syncthreads();
296
+ store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
297
+ }
298
+ }
299
+
300
+ Bvar += kChunkSize * (!kIsComplex ? 1 : 2);
301
+ Cvar += kChunkSize * (!kIsComplex ? 1 : 2);
302
+ }
303
+ }
304
+
305
+ template<int kNThreads, int kNItems, typename input_t, typename weight_t>
306
+ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
307
+ // Only kNRows == 1 is tested for now; this is no different from the previous version, where each block
308
+ // processed 1 row.
309
+ constexpr int kNRows = 1;
310
+ BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
311
+ BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
312
+ BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
313
+ BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
314
+ using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
315
+ // constexpr int kSmemSize = Ktraits::kSmemSize;
316
+ constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
317
+ // printf("smem_size = %d\n", kSmemSize);
318
+ dim3 grid(params.batch, params.dim / kNRows);
319
+ auto kernel = &selective_scan_fwd_kernel<Ktraits>;
320
+ if (kSmemSize >= 48 * 1024) {
321
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
322
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
323
+ }
324
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
325
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
326
+ });
327
+ });
328
+ });
329
+ });
330
+ }
331
+
332
+ template<typename input_t, typename weight_t>
333
+ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
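+ // Each chunk covers kNThreads * kNItems elements: 32*4=128, 32*8=256, 32*16=512, 64*16=1024, 128*16=2048.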
334
+ if (params.seqlen <= 128) {
335
+ selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
336
+ } else if (params.seqlen <= 256) {
337
+ selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
338
+ } else if (params.seqlen <= 512) {
339
+ selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
340
+ } else if (params.seqlen <= 1024) {
341
+ selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
342
+ } else {
343
+ selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
344
+ }
345
+ }
mamba_install/csrc/selective_scan/static_switch.h ADDED
@@ -0,0 +1,25 @@
1
+ // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2
+ // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3
+
4
+ #pragma once
5
+
6
+ /// @param COND - a boolean expression to switch by
7
+ /// @param CONST_NAME - a name given for the constexpr bool variable.
8
+ /// @param ... - code to execute for true and false
9
+ ///
10
+ /// Usage:
11
+ /// ```
12
+ /// BOOL_SWITCH(flag, BoolConst, [&] {
13
+ /// some_function<BoolConst>(...);
14
+ /// });
15
+ /// ```
16
+ #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17
+ [&] { \
18
+ if (COND) { \
19
+ constexpr bool CONST_NAME = true; \
20
+ return __VA_ARGS__(); \
21
+ } else { \
22
+ constexpr bool CONST_NAME = false; \
23
+ return __VA_ARGS__(); \
24
+ } \
25
+ }()
mamba_install/csrc/selective_scan/uninitialized_copy.cuh ADDED
@@ -0,0 +1,69 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Redistribution and use in source and binary forms, with or without
5
+ * modification, are permitted provided that the following conditions are met:
6
+ * * Redistributions of source code must retain the above copyright
7
+ * notice, this list of conditions and the following disclaimer.
8
+ * * Redistributions in binary form must reproduce the above copyright
9
+ * notice, this list of conditions and the following disclaimer in the
10
+ * documentation and/or other materials provided with the distribution.
11
+ * * Neither the name of the NVIDIA CORPORATION nor the
12
+ * names of its contributors may be used to endorse or promote products
13
+ * derived from this software without specific prior written permission.
14
+ *
15
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18
+ * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
+ *
26
+ ******************************************************************************/
27
+
28
+ #pragma once
29
+
30
+ #include <cub/config.cuh>
31
+
32
+ #include <cuda/std/type_traits>
33
+
34
+
35
+ namespace detail
36
+ {
37
+
38
+ #if defined(_NVHPC_CUDA)
39
+ template <typename T, typename U>
40
+ __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
41
+ {
42
+ // NVBug 3384810
43
+ new (ptr) T(::cuda::std::forward<U>(val));
44
+ }
45
+ #else
46
+ template <typename T,
47
+ typename U,
48
+ typename ::cuda::std::enable_if<
49
+ ::cuda::std::is_trivially_copyable<T>::value,
50
+ int
51
+ >::type = 0>
52
+ __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
53
+ {
54
+ *ptr = ::cuda::std::forward<U>(val);
55
+ }
56
+
57
+ template <typename T,
58
+ typename U,
59
+ typename ::cuda::std::enable_if<
60
+ !::cuda::std::is_trivially_copyable<T>::value,
61
+ int
62
+ >::type = 0>
63
+ __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
64
+ {
65
+ new (ptr) T(::cuda::std::forward<U>(val));
66
+ }
67
+ #endif
68
+
69
+ } // namespace detail
mamba_install/evals/lm_harness_eval.py ADDED
@@ -0,0 +1,39 @@
1
+ import torch
2
+
3
+ import transformers
4
+ from transformers import AutoTokenizer
5
+
6
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
7
+
8
+ from lm_eval.api.model import LM
9
+ from lm_eval.models.huggingface import HFLM
10
+ from lm_eval.api.registry import register_model
11
+ from lm_eval.__main__ import cli_evaluate
12
+
13
+
14
+ @register_model("mamba")
15
+ class MambaEvalWrapper(HFLM):
16
+
17
+ AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
18
+
19
+ def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda",
20
+ dtype=torch.float16):
21
+ LM.__init__(self)
22
+ self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)
23
+ self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
24
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
25
+ self.vocab_size = self.tokenizer.vocab_size
26
+ self._batch_size = int(batch_size) if batch_size is not None else 64
27
+ self._max_length = max_length
28
+ self._device = torch.device(device)
29
+
30
+ @property
31
+ def batch_size(self):
32
+ return self._batch_size
33
+
34
+ def _model_generate(self, context, max_length, stop, **generation_kwargs):
35
+ raise NotImplementedError()
36
+
37
+
38
+ if __name__ == "__main__":
39
+ cli_evaluate()
mamba_install/mamba_ssm/.DS_Store ADDED
Binary file (6.15 kB).
mamba_install/mamba_ssm/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ __version__ = "1.2.2"
2
+
3
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
4
+ from mamba_ssm.modules.mamba_simple import Mamba
5
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
mamba_install/mamba_ssm/models/__init__.py ADDED
File without changes
mamba_install/mamba_ssm/models/config_mamba.py ADDED
@@ -0,0 +1,15 @@
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class MambaConfig:
6
+
7
+ d_model: int = 2560
8
+ n_layer: int = 64
9
+ vocab_size: int = 50277
10
+ ssm_cfg: dict = field(default_factory=dict)
11
+ rms_norm: bool = True
12
+ residual_in_fp32: bool = True
13
+ fused_add_norm: bool = True
14
+ pad_vocab_size_multiple: int = 8
15
+ tie_embeddings: bool = True
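+ # Note: these defaults correspond to the 2.8B-parameter Mamba configuration (d_model=2560, n_layer=64).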
mamba_install/mamba_ssm/models/mixer_seq_simple.py ADDED
@@ -0,0 +1,264 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+
3
+ import math
4
+ from functools import partial
5
+ import json
6
+ import os
7
+
8
+ from collections import namedtuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from mamba_ssm.models.config_mamba import MambaConfig
14
+ from mamba_ssm.modules.mamba_simple import Mamba, Block
15
+ from mamba_ssm.utils.generation import GenerationMixin
16
+ from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
17
+
18
+ try:
19
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
20
+ except ImportError:
21
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
22
+
23
+
24
+ def create_block(
25
+ d_model,
26
+ ssm_cfg=None,
27
+ norm_epsilon=1e-5,
28
+ rms_norm=False,
29
+ residual_in_fp32=False,
30
+ fused_add_norm=False,
31
+ layer_idx=None,
32
+ device=None,
33
+ dtype=None,
34
+ ):
35
+ if ssm_cfg is None:
36
+ ssm_cfg = {}
37
+ factory_kwargs = {"device": device, "dtype": dtype}
38
+ mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
39
+ norm_cls = partial(
40
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
41
+ )
42
+ block = Block(
43
+ d_model,
44
+ mixer_cls,
45
+ norm_cls=norm_cls,
46
+ fused_add_norm=fused_add_norm,
47
+ residual_in_fp32=residual_in_fp32,
48
+ )
49
+ block.layer_idx = layer_idx
50
+ return block
51
+
52
+
53
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
54
+ def _init_weights(
55
+ module,
56
+ n_layer,
57
+ initializer_range=0.02, # Now only used for embedding layer.
58
+ rescale_prenorm_residual=True,
59
+ n_residuals_per_layer=1, # Change to 2 if we have MLP
60
+ ):
61
+ if isinstance(module, nn.Linear):
62
+ if module.bias is not None:
63
+ if not getattr(module.bias, "_no_reinit", False):
64
+ nn.init.zeros_(module.bias)
65
+ elif isinstance(module, nn.Embedding):
66
+ nn.init.normal_(module.weight, std=initializer_range)
67
+
68
+ if rescale_prenorm_residual:
69
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
70
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
71
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
72
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
73
+ #
74
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
75
+ for name, p in module.named_parameters():
76
+ if name in ["out_proj.weight", "fc2.weight"]:
77
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
78
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
79
+ # We need to reinit p since this code could be called multiple times
80
+ # Having just p *= scale would repeatedly scale it down
81
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
82
+ with torch.no_grad():
83
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
84
+
85
+
86
+ class MixerModel(nn.Module):
87
+ def __init__(
88
+ self,
89
+ d_model: int,
90
+ n_layer: int,
91
+ vocab_size: int,
92
+ ssm_cfg=None,
93
+ norm_epsilon: float = 1e-5,
94
+ rms_norm: bool = False,
95
+ initializer_cfg=None,
96
+ fused_add_norm=False,
97
+ residual_in_fp32=False,
98
+ device=None,
99
+ dtype=None,
100
+ ) -> None:
101
+ factory_kwargs = {"device": device, "dtype": dtype}
102
+ super().__init__()
103
+ self.residual_in_fp32 = residual_in_fp32
104
+
105
+ self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
106
+
107
+ # We change the order of residual and layer norm:
108
+ # Instead of LN -> Attn / MLP -> Add, we do:
109
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
110
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
111
+ # This is for performance reason: we can fuse add + layer_norm.
112
+ self.fused_add_norm = fused_add_norm
113
+ if self.fused_add_norm:
114
+ if layer_norm_fn is None or rms_norm_fn is None:
115
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
116
+
117
+ self.layers = nn.ModuleList(
118
+ [
119
+ create_block(
120
+ d_model,
121
+ ssm_cfg=ssm_cfg,
122
+ norm_epsilon=norm_epsilon,
123
+ rms_norm=rms_norm,
124
+ residual_in_fp32=residual_in_fp32,
125
+ fused_add_norm=fused_add_norm,
126
+ layer_idx=i,
127
+ **factory_kwargs,
128
+ )
129
+ for i in range(n_layer)
130
+ ]
131
+ )
132
+
133
+ self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
134
+ d_model, eps=norm_epsilon, **factory_kwargs
135
+ )
136
+
137
+ self.apply(
138
+ partial(
139
+ _init_weights,
140
+ n_layer=n_layer,
141
+ **(initializer_cfg if initializer_cfg is not None else {}),
142
+ )
143
+ )
144
+
145
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
146
+ return {
147
+ i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
148
+ for i, layer in enumerate(self.layers)
149
+ }
150
+
151
+ def forward(self, input_ids, inference_params=None):
152
+ hidden_states = self.embedding(input_ids)
153
+ residual = None
154
+ for layer in self.layers:
155
+ hidden_states, residual = layer(
156
+ hidden_states, residual, inference_params=inference_params
157
+ )
158
+ if not self.fused_add_norm:
159
+ residual = (hidden_states + residual) if residual is not None else hidden_states
160
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
161
+ else:
162
+ # Set prenorm=False here since we don't need the residual
163
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
164
+ hidden_states = fused_add_norm_fn(
165
+ hidden_states,
166
+ self.norm_f.weight,
167
+ self.norm_f.bias,
168
+ eps=self.norm_f.eps,
169
+ residual=residual,
170
+ prenorm=False,
171
+ residual_in_fp32=self.residual_in_fp32,
172
+ )
173
+ return hidden_states
174
+
175
+
176
+ class MambaLMHeadModel(nn.Module, GenerationMixin):
177
+
178
+ def __init__(
179
+ self,
180
+ config: MambaConfig,
181
+ initializer_cfg=None,
182
+ device=None,
183
+ dtype=None,
184
+ ) -> None:
185
+ self.config = config
186
+ d_model = config.d_model
187
+ n_layer = config.n_layer
188
+ vocab_size = config.vocab_size
189
+ ssm_cfg = config.ssm_cfg
190
+ rms_norm = config.rms_norm
191
+ residual_in_fp32 = config.residual_in_fp32
192
+ fused_add_norm = config.fused_add_norm
193
+ pad_vocab_size_multiple = config.pad_vocab_size_multiple
194
+ factory_kwargs = {"device": device, "dtype": dtype}
195
+
196
+ super().__init__()
197
+ if vocab_size % pad_vocab_size_multiple != 0:
198
+ vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
199
+ self.backbone = MixerModel(
200
+ d_model=d_model,
201
+ n_layer=n_layer,
202
+ vocab_size=vocab_size,
203
+ ssm_cfg=ssm_cfg,
204
+ rms_norm=rms_norm,
205
+ initializer_cfg=initializer_cfg,
206
+ fused_add_norm=fused_add_norm,
207
+ residual_in_fp32=residual_in_fp32,
208
+ **factory_kwargs,
209
+ )
210
+ self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
211
+
212
+ # Initialize weights and apply final processing
213
+ self.apply(
214
+ partial(
215
+ _init_weights,
216
+ n_layer=n_layer,
217
+ **(initializer_cfg if initializer_cfg is not None else {}),
218
+ )
219
+ )
220
+ self.tie_weights()
221
+
222
+ def tie_weights(self):
223
+ if self.config.tie_embeddings:
224
+ self.lm_head.weight = self.backbone.embedding.weight
225
+
226
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
227
+ return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
228
+
229
+ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
230
+ """
231
+ "position_ids" is just to be compatible with Transformer generation. We don't use it.
232
+ num_last_tokens: if > 0, only return the logits for the last n tokens
233
+ """
234
+ hidden_states = self.backbone(input_ids, inference_params=inference_params)
235
+ if num_last_tokens > 0:
236
+ hidden_states = hidden_states[:, -num_last_tokens:]
237
+ lm_logits = self.lm_head(hidden_states)
238
+ CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
239
+ return CausalLMOutput(logits=lm_logits)
240
+
241
+ @classmethod
242
+ def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
243
+ config_data = load_config_hf(pretrained_model_name)
244
+ config = MambaConfig(**config_data)
245
+ model = cls(config, device=device, dtype=dtype, **kwargs)
246
+ model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
247
+ return model
248
+
249
+ def save_pretrained(self, save_directory):
250
+ """
251
+ Minimal implementation of save_pretrained for MambaLMHeadModel.
252
+ Save the model and its configuration file to a directory.
253
+ """
254
+ # Ensure save_directory exists
255
+ os.makedirs(save_directory, exist_ok=True)
256
+
257
+ # Save the model's state_dict
258
+ model_path = os.path.join(save_directory, 'pytorch_model.bin')
259
+ torch.save(self.state_dict(), model_path)
260
+
261
+ # Save the configuration of the model
262
+ config_path = os.path.join(save_directory, 'config.json')
263
+ with open(config_path, 'w') as f:
264
+ json.dump(self.config.__dict__, f)
mamba_install/mamba_ssm/modules/__init__.py ADDED
File without changes
mamba_install/mamba_ssm/modules/mamba_simple.py ADDED
@@ -0,0 +1,353 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+ from einops import rearrange, repeat
12
+
13
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
14
+
15
+ try:
16
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
17
+ except ImportError:
18
+ causal_conv1d_fn, causal_conv1d_update = None, None
19
+
20
+ try:
21
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
22
+ except ImportError:
23
+ selective_state_update = None
24
+
25
+ try:
26
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
27
+ except ImportError:
28
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
29
+
30
+
31
+ class Mamba(nn.Module):
32
+ def __init__(
33
+ self,
34
+ d_model,
35
+ d_state=16,
36
+ d_conv=4,
37
+ expand=2,
38
+ dt_rank="auto",
39
+ dt_min=0.001,
40
+ dt_max=0.1,
41
+ dt_init="random",
42
+ dt_scale=1.0,
43
+ dt_init_floor=1e-4,
44
+ conv_bias=True,
45
+ bias=False,
46
+ use_fast_path=True, # Fused kernel options
47
+ layer_idx=None,
48
+ device=None,
49
+ dtype=None,
50
+ ):
51
+ factory_kwargs = {"device": device, "dtype": dtype}
52
+ super().__init__()
53
+ self.d_model = d_model
54
+ self.d_state = d_state
55
+ self.d_conv = d_conv
56
+ self.expand = expand
57
+ self.d_inner = int(self.expand * self.d_model)
58
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
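+ # e.g. with the default d_model=2560 and expand=2: d_inner = 5120 and dt_rank = ceil(2560 / 16) = 160.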
59
+ self.use_fast_path = use_fast_path
60
+ self.layer_idx = layer_idx
61
+
62
+ self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
63
+
64
+ self.conv1d = nn.Conv1d(
65
+ in_channels=self.d_inner,
66
+ out_channels=self.d_inner,
67
+ bias=conv_bias,
68
+ kernel_size=d_conv,
69
+ groups=self.d_inner,
70
+ padding=d_conv - 1,
71
+ **factory_kwargs,
72
+ )
73
+
74
+ self.activation = "silu"
75
+ self.act = nn.SiLU()
76
+
77
+ self.x_proj = nn.Linear(
78
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
79
+ )
80
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
81
+
82
+ # Initialize special dt projection to preserve variance at initialization
83
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
84
+ if dt_init == "constant":
85
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
86
+ elif dt_init == "random":
87
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
88
+ else:
89
+ raise NotImplementedError
90
+
91
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
92
+ dt = torch.exp(
93
+ torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
94
+ + math.log(dt_min)
95
+ ).clamp(min=dt_init_floor)
96
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
97
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
98
+ with torch.no_grad():
99
+ self.dt_proj.bias.copy_(inv_dt)
100
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
101
+ self.dt_proj.bias._no_reinit = True
102
+
103
+ # S4D real initialization
104
+ A = repeat(
105
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
106
+ "n -> d n",
107
+ d=self.d_inner,
108
+ ).contiguous()
109
+ A_log = torch.log(A) # Keep A_log in fp32
110
+ self.A_log = nn.Parameter(A_log)
111
+ self.A_log._no_weight_decay = True
112
+
113
+ # D "skip" parameter
114
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
115
+ self.D._no_weight_decay = True
116
+
117
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
118
+
119
+ def forward(self, hidden_states, inference_params=None):
120
+ """
121
+ hidden_states: (B, L, D)
122
+ Returns: same shape as hidden_states
123
+ """
124
+ batch, seqlen, dim = hidden_states.shape
125
+
126
+ conv_state, ssm_state = None, None
127
+ if inference_params is not None:
128
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
129
+ if inference_params.seqlen_offset > 0:
130
+ # The states are updated inplace
131
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
132
+ return out
133
+
134
+ # We do matmul and transpose BLH -> HBL at the same time
135
+ xz = rearrange(
136
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
137
+ "d (b l) -> b d l",
138
+ l=seqlen,
139
+ )
140
+ if self.in_proj.bias is not None:
141
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
142
+
143
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
144
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
145
+ if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
146
+ out = mamba_inner_fn(
147
+ xz,
148
+ self.conv1d.weight,
149
+ self.conv1d.bias,
150
+ self.x_proj.weight,
151
+ self.dt_proj.weight,
152
+ self.out_proj.weight,
153
+ self.out_proj.bias,
154
+ A,
155
+ None, # input-dependent B
156
+ None, # input-dependent C
157
+ self.D.float(),
158
+ delta_bias=self.dt_proj.bias.float(),
159
+ delta_softplus=True,
160
+ )
161
+ else:
162
+ x, z = xz.chunk(2, dim=1)
163
+ # Compute short convolution
164
+ if conv_state is not None:
165
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
166
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
167
+ conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
168
+ if causal_conv1d_fn is None:
169
+ x = self.act(self.conv1d(x)[..., :seqlen])
170
+ else:
171
+ assert self.activation in ["silu", "swish"]
172
+ x = causal_conv1d_fn(
173
+ x=x,
174
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
175
+ bias=self.conv1d.bias,
176
+ activation=self.activation,
177
+ )
178
+
179
+ # We're careful here about the layout, to avoid extra transposes.
180
+ # We want dt to have d as the slowest moving dimension
181
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
182
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
183
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
184
+ dt = self.dt_proj.weight @ dt.t()
185
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
186
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
187
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
188
+ assert self.activation in ["silu", "swish"]
189
+ y = selective_scan_fn(
190
+ x,
191
+ dt,
192
+ A,
193
+ B,
194
+ C,
195
+ self.D.float(),
196
+ z=z,
197
+ delta_bias=self.dt_proj.bias.float(),
198
+ delta_softplus=True,
199
+ return_last_state=ssm_state is not None,
200
+ )
201
+ if ssm_state is not None:
202
+ y, last_state = y
203
+ ssm_state.copy_(last_state)
204
+ y = rearrange(y, "b d l -> b l d")
205
+ out = self.out_proj(y)
206
+ return out
207
+
208
+ def step(self, hidden_states, conv_state, ssm_state):
209
+ dtype = hidden_states.dtype
210
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
211
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
212
+ x, z = xz.chunk(2, dim=-1) # (B D)
213
+
214
+ # Conv step
215
+ if causal_conv1d_update is None:
216
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
217
+ conv_state[:, :, -1] = x
218
+ x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
219
+ if self.conv1d.bias is not None:
220
+ x = x + self.conv1d.bias
221
+ x = self.act(x).to(dtype=dtype)
222
+ else:
223
+ x = causal_conv1d_update(
224
+ x,
225
+ conv_state,
226
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
227
+ self.conv1d.bias,
228
+ self.activation,
229
+ )
230
+
231
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
232
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
233
+ # Don't add dt_bias here
234
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
235
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
236
+
237
+ # SSM step
238
+ if selective_state_update is None:
239
+ # Discretize A and B
240
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
241
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
242
+ dB = torch.einsum("bd,bn->bdn", dt, B)
243
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
244
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
245
+ y = y + self.D.to(dtype) * x
246
+ y = y * self.act(z) # (B D)
247
+ else:
248
+ y = selective_state_update(
249
+ ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
250
+ )
251
+
252
+ out = self.out_proj(y)
253
+ return out.unsqueeze(1), conv_state, ssm_state
254
+
255
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
256
+ device = self.out_proj.weight.device
257
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
258
+ conv_state = torch.zeros(
259
+ batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
260
+ )
261
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
262
+ # ssm_dtype = torch.float32
263
+ ssm_state = torch.zeros(
264
+ batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
265
+ )
266
+ return conv_state, ssm_state
267
+
268
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
269
+ assert self.layer_idx is not None
270
+ if self.layer_idx not in inference_params.key_value_memory_dict:
271
+ batch_shape = (batch_size,)
272
+ conv_state = torch.zeros(
273
+ batch_size,
274
+ self.d_model * self.expand,
275
+ self.d_conv,
276
+ device=self.conv1d.weight.device,
277
+ dtype=self.conv1d.weight.dtype,
278
+ )
279
+ ssm_state = torch.zeros(
280
+ batch_size,
281
+ self.d_model * self.expand,
282
+ self.d_state,
283
+ device=self.dt_proj.weight.device,
284
+ dtype=self.dt_proj.weight.dtype,
285
+ # dtype=torch.float32,
286
+ )
287
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
288
+ else:
289
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
290
+ # TODO: What if batch size changes between generation, and we reuse the same states?
291
+ if initialize_states:
292
+ conv_state.zero_()
293
+ ssm_state.zero_()
294
+ return conv_state, ssm_state
295
+
296
+
297
+ class Block(nn.Module):
298
+ def __init__(
299
+ self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
300
+ ):
301
+ """
302
+ Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
303
+
304
+ This Block has a slightly different structure compared to a regular
305
+ prenorm Transformer block.
306
+ The standard block is: LN -> MHA/MLP -> Add.
307
+ [Ref: https://arxiv.org/abs/2002.04745]
308
+ Here we have: Add -> LN -> Mixer, returning both
309
+ the hidden_states (output of the mixer) and the residual.
310
+ This is purely for performance reasons, as we can fuse add and LayerNorm.
311
+ The residual needs to be provided (except for the very first block).
312
+ """
313
+ super().__init__()
314
+ self.residual_in_fp32 = residual_in_fp32
315
+ self.fused_add_norm = fused_add_norm
316
+ self.mixer = mixer_cls(dim)
317
+ self.norm = norm_cls(dim)
318
+ if self.fused_add_norm:
319
+ assert RMSNorm is not None, "RMSNorm import fails"
320
+ assert isinstance(
321
+ self.norm, (nn.LayerNorm, RMSNorm)
322
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
323
+
324
+ def forward(
325
+ self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
326
+ ):
327
+ r"""Pass the input through the encoder layer.
328
+
329
+ Args:
330
+ hidden_states: the sequence to the encoder layer (required).
331
+ residual: hidden_states = Mixer(LN(residual))
332
+ """
333
+ if not self.fused_add_norm:
334
+ residual = (hidden_states + residual) if residual is not None else hidden_states
335
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
336
+ if self.residual_in_fp32:
337
+ residual = residual.to(torch.float32)
338
+ else:
339
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
340
+ hidden_states, residual = fused_add_norm_fn(
341
+ hidden_states,
342
+ self.norm.weight,
343
+ self.norm.bias,
344
+ residual=residual,
345
+ prenorm=True,
346
+ residual_in_fp32=self.residual_in_fp32,
347
+ eps=self.norm.eps,
348
+ )
349
+ hidden_states = self.mixer(hidden_states, inference_params=inference_params)
350
+ return hidden_states, residual
351
+
352
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
353
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
mamba_install/mamba_ssm/ops/__init__.py ADDED
File without changes
mamba_install/mamba_ssm/ops/selective_scan_interface.py ADDED
@@ -0,0 +1,357 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.cuda.amp import custom_bwd, custom_fwd
6
+
7
+ from einops import rearrange, repeat
8
+
9
+ try:
10
+ from causal_conv1d import causal_conv1d_fn
11
+ import causal_conv1d_cuda
12
+ except ImportError:
13
+ causal_conv1d_fn = None
14
+ causal_conv1d_cuda = None
15
+
16
+ import selective_scan_cuda
17
+
18
+
19
+ class SelectiveScanFn(torch.autograd.Function):
20
+
21
+ @staticmethod
22
+ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
23
+ return_last_state=False):
24
+ if u.stride(-1) != 1:
25
+ u = u.contiguous()
26
+ if delta.stride(-1) != 1:
27
+ delta = delta.contiguous()
28
+ if D is not None:
29
+ D = D.contiguous()
30
+ if B.stride(-1) != 1:
31
+ B = B.contiguous()
32
+ if C.stride(-1) != 1:
33
+ C = C.contiguous()
34
+ if z is not None and z.stride(-1) != 1:
35
+ z = z.contiguous()
36
+ if B.dim() == 3:
37
+ B = rearrange(B, "b dstate l -> b 1 dstate l")
38
+ ctx.squeeze_B = True
39
+ if C.dim() == 3:
40
+ C = rearrange(C, "b dstate l -> b 1 dstate l")
41
+ ctx.squeeze_C = True
42
+ out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
43
+ ctx.delta_softplus = delta_softplus
44
+ ctx.has_z = z is not None
45
+ last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
46
+ if not ctx.has_z:
47
+ ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
48
+ return out if not return_last_state else (out, last_state)
49
+ else:
50
+ ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
51
+ out_z = rest[0]
52
+ return out_z if not return_last_state else (out_z, last_state)
53
+
54
+ @staticmethod
55
+ def backward(ctx, dout, *args):
56
+ if not ctx.has_z:
57
+ u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
58
+ z = None
59
+ out = None
60
+ else:
61
+ u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
62
+ if dout.stride(-1) != 1:
63
+ dout = dout.contiguous()
64
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
65
+ # backward of selective_scan_cuda with the backward of chunk).
66
+ # Here we just pass in None and dz will be allocated in the C++ code.
67
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
68
+ u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
69
+ False # option to recompute out_z, not used here
70
+ )
71
+ dz = rest[0] if ctx.has_z else None
72
+ dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
73
+ dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
74
+ return (du, ddelta, dA, dB, dC,
75
+ dD if D is not None else None,
76
+ dz,
77
+ ddelta_bias if delta_bias is not None else None,
78
+ None,
79
+ None)
80
+
81
+
82
+ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
83
+ return_last_state=False):
84
+ """if return_last_state is True, returns (out, last_state)
85
+ last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
86
+ not considered in the backward pass.
87
+ """
88
+ return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
89
+
90
+
91
+ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
92
+ return_last_state=False):
93
+ """
94
+ u: r(B D L)
95
+ delta: r(B D L)
96
+ A: c(D N) or r(D N)
97
+ B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
98
+ C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
99
+ D: r(D)
100
+ z: r(B D L)
101
+ delta_bias: r(D), fp32
102
+
103
+ out: r(B D L)
104
+ last_state (optional): r(B D dstate) or c(B D dstate)
105
+ """
106
+ dtype_in = u.dtype
107
+ u = u.float()
108
+ delta = delta.float()
109
+ if delta_bias is not None:
110
+ delta = delta + delta_bias[..., None].float()
111
+ if delta_softplus:
112
+ delta = F.softplus(delta)
113
+ batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
114
+ is_variable_B = B.dim() >= 3
115
+ is_variable_C = C.dim() >= 3
116
+ if A.is_complex():
117
+ if is_variable_B:
118
+ B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
119
+ if is_variable_C:
120
+ C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
121
+ else:
122
+ B = B.float()
123
+ C = C.float()
124
+ x = A.new_zeros((batch, dim, dstate))
125
+ ys = []
126
+ deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
127
+ if not is_variable_B:
128
+ deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
129
+ else:
130
+ if B.dim() == 3:
131
+ deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
132
+ else:
133
+ B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
134
+ deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
135
+ if is_variable_C and C.dim() == 4:
136
+ C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
137
+ last_state = None
138
+ for i in range(u.shape[2]):
139
+ x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
140
+ if not is_variable_C:
141
+ y = torch.einsum('bdn,dn->bd', x, C)
142
+ else:
143
+ if C.dim() == 3:
144
+ y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
145
+ else:
146
+ y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
147
+ if i == u.shape[2] - 1:
148
+ last_state = x
149
+ if y.is_complex():
150
+ y = y.real * 2
151
+ ys.append(y)
152
+ y = torch.stack(ys, dim=2) # (batch dim L)
153
+ out = y if D is None else y + u * rearrange(D, "d -> d 1")
154
+ if z is not None:
155
+ out = out * F.silu(z)
156
+ out = out.to(dtype=dtype_in)
157
+ return out if not return_last_state else (out, last_state)
158
+
159
+
160
+ class MambaInnerFn(torch.autograd.Function):
161
+
162
+ @staticmethod
163
+ @custom_fwd
164
+ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
165
+ out_proj_weight, out_proj_bias,
166
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
167
+ C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
168
+ """
169
+ xz: (batch, dim, seqlen)
170
+ """
171
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
172
+ assert checkpoint_lvl in [0, 1]
173
+ L = xz.shape[-1]
174
+ delta_rank = delta_proj_weight.shape[1]
175
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
176
+ if torch.is_autocast_enabled():
177
+ x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
178
+ delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
179
+ out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
180
+ out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
181
+ if out_proj_bias is not None else None)
182
+ if xz.stride(-1) != 1:
183
+ xz = xz.contiguous()
184
+ conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
185
+ x, z = xz.chunk(2, dim=1)
186
+ conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
187
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
188
+ x, conv1d_weight, conv1d_bias, None, None, None, True
189
+ )
190
+ # We're being very careful here about the layout, to avoid extra transposes.
191
+ # We want delta to have d as the slowest moving dimension
192
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
193
+ x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
194
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
195
+ ctx.is_variable_B = B is None
196
+ ctx.is_variable_C = C is None
197
+ ctx.B_proj_bias_is_None = B_proj_bias is None
198
+ ctx.C_proj_bias_is_None = C_proj_bias is None
199
+ if B is None: # variable B
200
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
201
+ if B_proj_bias is not None:
202
+ B = B + B_proj_bias.to(dtype=B.dtype)
203
+ if not A.is_complex():
204
+ # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
205
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
206
+ else:
207
+ B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
208
+ else:
209
+ if B.stride(-1) != 1:
210
+ B = B.contiguous()
211
+ if C is None: # variable C
212
+ C = x_dbl[:, -d_state:] # (bl dstate)
213
+ if C_proj_bias is not None:
214
+ C = C + C_proj_bias.to(dtype=C.dtype)
215
+ if not A.is_complex():
216
+ # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
217
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
218
+ else:
219
+ C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
220
+ else:
221
+ if C.stride(-1) != 1:
222
+ C = C.contiguous()
223
+ if D is not None:
224
+ D = D.contiguous()
225
+ out, scan_intermediates, out_z = selective_scan_cuda.fwd(
226
+ conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
227
+ )
228
+ ctx.delta_softplus = delta_softplus
229
+ ctx.out_proj_bias_is_None = out_proj_bias is None
230
+ ctx.checkpoint_lvl = checkpoint_lvl
231
+ if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
232
+ conv1d_out, delta = None, None
233
+ ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
234
+ delta_proj_weight, out_proj_weight, conv1d_out, delta,
235
+ A, B, C, D, delta_bias, scan_intermediates, out)
236
+ return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
237
+
238
+ @staticmethod
239
+ @custom_bwd
240
+ def backward(ctx, dout):
241
+ # dout: (batch, seqlen, dim)
242
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
243
+ (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
244
+ conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
245
+ L = xz.shape[-1]
246
+ delta_rank = delta_proj_weight.shape[1]
247
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
248
+ x, z = xz.chunk(2, dim=1)
249
+ if dout.stride(-1) != 1:
250
+ dout = dout.contiguous()
251
+ if ctx.checkpoint_lvl == 1:
252
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
253
+ x, conv1d_weight, conv1d_bias, None, None, None, True
254
+ )
255
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
256
+ "d (b l) -> b d l", l = L)
257
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
258
+ # backward of selective_scan_cuda with the backward of chunk).
259
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen)
260
+ dx, dz = dxz.chunk(2, dim=1)
261
+ dout = rearrange(dout, "b l e -> e (b l)")
262
+ dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
263
+ dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
264
+ conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
265
+ ctx.delta_softplus,
266
+ True # option to recompute out_z
267
+ )
268
+ dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
269
+ dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
270
+ dD = dD if D is not None else None
271
+ dx_dbl = torch.empty_like(x_dbl)
272
+ dB_proj_bias = None
273
+ if ctx.is_variable_B:
274
+ if not A.is_complex():
275
+ dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
276
+ else:
277
+ dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
278
+ dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
279
+ dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
280
+ dB = None
281
+ dC_proj_bias = None
282
+ if ctx.is_variable_C:
283
+ if not A.is_complex():
284
+ dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
285
+ else:
286
+ dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
287
+ dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
288
+ dx_dbl[:, -d_state:] = dC # (bl d)
289
+ dC = None
290
+ ddelta = rearrange(ddelta, "b d l -> d (b l)")
291
+ ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
292
+ dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
293
+ dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
294
+ dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
295
+ dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
296
+ dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
297
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
298
+ # backward of conv1d with the backward of chunk).
299
+ dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
300
+ x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
301
+ )
302
+ dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
303
+ dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
304
+ return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
305
+ dout_proj_weight, dout_proj_bias,
306
+ dA, dB, dC, dD,
307
+ ddelta_bias if delta_bias is not None else None,
308
+ dB_proj_bias, dC_proj_bias, None)
309
+
310
+
311
+ def mamba_inner_fn(
312
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
313
+ out_proj_weight, out_proj_bias,
314
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
315
+ C_proj_bias=None, delta_softplus=True
316
+ ):
317
+ return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
318
+ out_proj_weight, out_proj_bias,
319
+ A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
320
+
321
+
322
+ def mamba_inner_ref(
323
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
324
+ out_proj_weight, out_proj_bias,
325
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
326
+ C_proj_bias=None, delta_softplus=True
327
+ ):
328
+ assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
329
+ L = xz.shape[-1]
330
+ delta_rank = delta_proj_weight.shape[1]
331
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
332
+ x, z = xz.chunk(2, dim=1)
333
+ x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
334
+ # We're being very careful here about the layout, to avoid extra transposes.
335
+ # We want delta to have d as the slowest moving dimension
336
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
337
+ x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
338
+ delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
339
+ delta = rearrange(delta, "d (b l) -> b d l", l=L)
340
+ if B is None: # variable B
341
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
342
+ if B_proj_bias is not None:
343
+ B = B + B_proj_bias.to(dtype=B.dtype)
344
+ if not A.is_complex():
345
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
346
+ else:
347
+ B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
348
+ if C is None: # variable C
349
+ C = x_dbl[:, -d_state:] # (bl d)
350
+ if C_proj_bias is not None:
351
+ C = C + C_proj_bias.to(dtype=C.dtype)
352
+ if not A.is_complex():
353
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
354
+ else:
355
+ C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
356
+ y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
357
+ return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
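A minimal sketch of how the reference path above can be exercised, assuming the package is importable as mamba_ssm.ops.selective_scan_interface and that the compiled selective_scan_cuda extension is present (the module imports it at load time); the tensor sizes below are illustrative and follow the selective_scan_ref docstring.

    # Hypothetical sketch: exercise selective_scan_ref with the docstring shapes.
    import torch
    from mamba_ssm.ops.selective_scan_interface import selective_scan_ref

    batch, dim, seqlen, dstate = 2, 4, 16, 8
    u = torch.randn(batch, dim, seqlen)        # r(B D L)
    delta = torch.rand(batch, dim, seqlen)     # r(B D L)
    A = -torch.rand(dim, dstate)               # r(D N), negative for a stable scan
    B = torch.randn(batch, dstate, seqlen)     # r(B N L): variable (input-dependent) B
    C = torch.randn(batch, dstate, seqlen)     # r(B N L): variable C
    D = torch.randn(dim)                       # r(D) skip connection

    out, last_state = selective_scan_ref(
        u, delta, A, B, C, D=D, delta_softplus=True, return_last_state=True
    )
    print(out.shape, last_state.shape)         # (B, D, L) and (B, D, N)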
mamba_install/mamba_ssm/ops/triton/__init__.py ADDED
File without changes
mamba_install/mamba_ssm/ops/triton/layernorm.py ADDED
@@ -0,0 +1,635 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+ # Implement residual + layer_norm / rms_norm.
3
+
4
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
+ # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
+
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch.cuda.amp import custom_fwd, custom_bwd
14
+
15
+ import triton
16
+ import triton.language as tl
17
+
18
+
19
+ def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
20
+ dtype = x.dtype
21
+ if upcast:
22
+ weight = weight.float()
23
+ bias = bias.float() if bias is not None else None
24
+ if upcast:
25
+ x = x.float()
26
+ residual = residual.float() if residual is not None else residual
27
+ if residual is not None:
28
+ x = (x + residual).to(x.dtype)
29
+ out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
30
+ dtype
31
+ )
32
+ return out if not prenorm else (out, x)
33
+
34
+
35
+ def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
36
+ dtype = x.dtype
37
+ if upcast:
38
+ weight = weight.float()
39
+ bias = bias.float() if bias is not None else None
40
+ if upcast:
41
+ x = x.float()
42
+ residual = residual.float() if residual is not None else residual
43
+ if residual is not None:
44
+ x = (x + residual).to(x.dtype)
45
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
46
+ out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
47
+ out = out.to(dtype)
48
+ return out if not prenorm else (out, x)
49
+
50
+
51
+ @triton.autotune(
52
+ configs=[
53
+ triton.Config({}, num_warps=1),
54
+ triton.Config({}, num_warps=2),
55
+ triton.Config({}, num_warps=4),
56
+ triton.Config({}, num_warps=8),
57
+ triton.Config({}, num_warps=16),
58
+ triton.Config({}, num_warps=32),
59
+ ],
60
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
61
+ )
62
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
63
+ # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
64
+ @triton.jit
65
+ def _layer_norm_fwd_1pass_kernel(
66
+ X, # pointer to the input
67
+ Y, # pointer to the output
68
+ W, # pointer to the weights
69
+ B, # pointer to the biases
70
+ RESIDUAL, # pointer to the residual
71
+ RESIDUAL_OUT, # pointer to the residual
72
+ Mean, # pointer to the mean
73
+ Rstd, # pointer to the 1/std
74
+ stride_x_row, # how much to increase the pointer when moving by 1 row
75
+ stride_y_row,
76
+ stride_res_row,
77
+ stride_res_out_row,
78
+ N, # number of columns in X
79
+ eps, # epsilon to avoid division by zero
80
+ IS_RMS_NORM: tl.constexpr,
81
+ BLOCK_N: tl.constexpr,
82
+ HAS_RESIDUAL: tl.constexpr,
83
+ STORE_RESIDUAL_OUT: tl.constexpr,
84
+ HAS_BIAS: tl.constexpr,
85
+ ):
86
+ # Map the program id to the row of X and Y it should compute.
87
+ row = tl.program_id(0)
88
+ X += row * stride_x_row
89
+ Y += row * stride_y_row
90
+ if HAS_RESIDUAL:
91
+ RESIDUAL += row * stride_res_row
92
+ if STORE_RESIDUAL_OUT:
93
+ RESIDUAL_OUT += row * stride_res_out_row
94
+ # Compute mean and variance
95
+ cols = tl.arange(0, BLOCK_N)
96
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
97
+ if HAS_RESIDUAL:
98
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
99
+ x += residual
100
+ if STORE_RESIDUAL_OUT:
101
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
102
+ if not IS_RMS_NORM:
103
+ mean = tl.sum(x, axis=0) / N
104
+ tl.store(Mean + row, mean)
105
+ xbar = tl.where(cols < N, x - mean, 0.0)
106
+ var = tl.sum(xbar * xbar, axis=0) / N
107
+ else:
108
+ xbar = tl.where(cols < N, x, 0.0)
109
+ var = tl.sum(xbar * xbar, axis=0) / N
110
+ rstd = 1 / tl.sqrt(var + eps)
111
+ tl.store(Rstd + row, rstd)
112
+ # Normalize and apply linear transformation
113
+ mask = cols < N
114
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
115
+ if HAS_BIAS:
116
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
117
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
118
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
119
+ # Write output
120
+ tl.store(Y + cols, y, mask=mask)
121
+
122
+
123
+ def _layer_norm_fwd(
124
+ x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
125
+ ):
126
+ if residual is not None:
127
+ residual_dtype = residual.dtype
128
+ M, N = x.shape
129
+ assert x.stride(-1) == 1
130
+ if residual is not None:
131
+ assert residual.stride(-1) == 1
132
+ assert residual.shape == (M, N)
133
+ assert weight.shape == (N,)
134
+ assert weight.stride(-1) == 1
135
+ if bias is not None:
136
+ assert bias.stride(-1) == 1
137
+ assert bias.shape == (N,)
138
+ # allocate output
139
+ y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
140
+ assert y.stride(-1) == 1
141
+ if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
142
+ residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
143
+ assert residual_out.stride(-1) == 1
144
+ else:
145
+ residual_out = None
146
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
147
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
148
+ # Less than 64KB per feature: enqueue fused kernel
149
+ MAX_FUSED_SIZE = 65536 // x.element_size()
150
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
151
+ if N > BLOCK_N:
152
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
153
+ # heuristics for number of warps
154
+ with torch.cuda.device(x.device.index):
155
+ _layer_norm_fwd_1pass_kernel[(M,)](
156
+ x,
157
+ y,
158
+ weight,
159
+ bias,
160
+ residual,
161
+ residual_out,
162
+ mean,
163
+ rstd,
164
+ x.stride(0),
165
+ y.stride(0),
166
+ residual.stride(0) if residual is not None else 0,
167
+ residual_out.stride(0) if residual_out is not None else 0,
168
+ N,
169
+ eps,
170
+ is_rms_norm,
171
+ BLOCK_N,
172
+ residual is not None,
173
+ residual_out is not None,
174
+ bias is not None,
175
+ )
176
+ # residual_out is None if residual is None and residual_dtype == input_dtype
177
+ return y, mean, rstd, residual_out if residual_out is not None else x
178
+
179
+
180
+ @triton.autotune(
181
+ configs=[
182
+ triton.Config({}, num_warps=1),
183
+ triton.Config({}, num_warps=2),
184
+ triton.Config({}, num_warps=4),
185
+ triton.Config({}, num_warps=8),
186
+ triton.Config({}, num_warps=16),
187
+ triton.Config({}, num_warps=32),
188
+ ],
189
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
190
+ )
191
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
192
+ # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
193
+ # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
194
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
195
+ @triton.jit
196
+ def _layer_norm_bwd_kernel(
197
+ X, # pointer to the input
198
+ W, # pointer to the weights
199
+ B, # pointer to the biases
200
+ Y, # pointer to the output to be recomputed
201
+ DY, # pointer to the output gradient
202
+ DX, # pointer to the input gradient
203
+ DW, # pointer to the partial sum of weights gradient
204
+ DB, # pointer to the partial sum of biases gradient
205
+ DRESIDUAL,
206
+ DRESIDUAL_IN,
207
+ Mean, # pointer to the mean
208
+ Rstd, # pointer to the 1/std
209
+ stride_x_row, # how much to increase the pointer when moving by 1 row
210
+ stride_y_row,
211
+ stride_dy_row,
212
+ stride_dx_row,
213
+ stride_dres_row,
214
+ stride_dres_in_row,
215
+ M, # number of rows in X
216
+ N, # number of columns in X
217
+ eps, # epsilon to avoid division by zero
218
+ rows_per_program,
219
+ IS_RMS_NORM: tl.constexpr,
220
+ BLOCK_N: tl.constexpr,
221
+ HAS_DRESIDUAL: tl.constexpr,
222
+ STORE_DRESIDUAL: tl.constexpr,
223
+ HAS_BIAS: tl.constexpr,
224
+ RECOMPUTE_OUTPUT: tl.constexpr,
225
+ ):
226
+ # Map the program id to the elements of X, DX, and DY it should compute.
227
+ row_block_id = tl.program_id(0)
228
+ row_start = row_block_id * rows_per_program
229
+ cols = tl.arange(0, BLOCK_N)
230
+ mask = cols < N
231
+ X += row_start * stride_x_row
232
+ if HAS_DRESIDUAL:
233
+ DRESIDUAL += row_start * stride_dres_row
234
+ if STORE_DRESIDUAL:
235
+ DRESIDUAL_IN += row_start * stride_dres_in_row
236
+ DY += row_start * stride_dy_row
237
+ DX += row_start * stride_dx_row
238
+ if RECOMPUTE_OUTPUT:
239
+ Y += row_start * stride_y_row
240
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
241
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
242
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
243
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
244
+ if HAS_BIAS:
245
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
246
+ row_end = min((row_block_id + 1) * rows_per_program, M)
247
+ for row in range(row_start, row_end):
248
+ # Load data to SRAM
249
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
250
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
251
+ if not IS_RMS_NORM:
252
+ mean = tl.load(Mean + row)
253
+ rstd = tl.load(Rstd + row)
254
+ # Compute dx
255
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
256
+ xhat = tl.where(mask, xhat, 0.0)
257
+ if RECOMPUTE_OUTPUT:
258
+ y = xhat * w + b if HAS_BIAS else xhat * w
259
+ tl.store(Y + cols, y, mask=mask)
260
+ wdy = w * dy
261
+ dw += dy * xhat
262
+ if HAS_BIAS:
263
+ db += dy
264
+ if not IS_RMS_NORM:
265
+ c1 = tl.sum(xhat * wdy, axis=0) / N
266
+ c2 = tl.sum(wdy, axis=0) / N
267
+ dx = (wdy - (xhat * c1 + c2)) * rstd
268
+ else:
269
+ c1 = tl.sum(xhat * wdy, axis=0) / N
270
+ dx = (wdy - xhat * c1) * rstd
271
+ if HAS_DRESIDUAL:
272
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
273
+ dx += dres
274
+ # Write dx
275
+ if STORE_DRESIDUAL:
276
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
277
+ tl.store(DX + cols, dx, mask=mask)
278
+
279
+ X += stride_x_row
280
+ if HAS_DRESIDUAL:
281
+ DRESIDUAL += stride_dres_row
282
+ if STORE_DRESIDUAL:
283
+ DRESIDUAL_IN += stride_dres_in_row
284
+ if RECOMPUTE_OUTPUT:
285
+ Y += stride_y_row
286
+ DY += stride_dy_row
287
+ DX += stride_dx_row
288
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
289
+ if HAS_BIAS:
290
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
291
+
292
+
293
+ def _layer_norm_bwd(
294
+ dy,
295
+ x,
296
+ weight,
297
+ bias,
298
+ eps,
299
+ mean,
300
+ rstd,
301
+ dresidual=None,
302
+ has_residual=False,
303
+ is_rms_norm=False,
304
+ x_dtype=None,
305
+ recompute_output=False,
306
+ ):
307
+ M, N = x.shape
308
+ assert x.stride(-1) == 1
309
+ assert dy.stride(-1) == 1
310
+ assert dy.shape == (M, N)
311
+ if dresidual is not None:
312
+ assert dresidual.stride(-1) == 1
313
+ assert dresidual.shape == (M, N)
314
+ assert weight.shape == (N,)
315
+ assert weight.stride(-1) == 1
316
+ if bias is not None:
317
+ assert bias.stride(-1) == 1
318
+ assert bias.shape == (N,)
319
+ # allocate output
320
+ dx = (
321
+ torch.empty_like(x)
322
+ if x_dtype is None
323
+ else torch.empty(M, N, dtype=x_dtype, device=x.device)
324
+ )
325
+ dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
326
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
327
+
328
+ # Less than 64KB per feature: enqueue fused kernel
329
+ MAX_FUSED_SIZE = 65536 // x.element_size()
330
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
331
+ if N > BLOCK_N:
332
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
333
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
334
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
335
+ _db = (
336
+ torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
337
+ if bias is not None
338
+ else None
339
+ )
340
+ rows_per_program = math.ceil(M / sm_count)
341
+ grid = (sm_count,)
342
+ with torch.cuda.device(x.device.index):
343
+ _layer_norm_bwd_kernel[grid](
344
+ x,
345
+ weight,
346
+ bias,
347
+ y,
348
+ dy,
349
+ dx,
350
+ _dw,
351
+ _db,
352
+ dresidual,
353
+ dresidual_in,
354
+ mean,
355
+ rstd,
356
+ x.stride(0),
357
+ 0 if not recompute_output else y.stride(0),
358
+ dy.stride(0),
359
+ dx.stride(0),
360
+ dresidual.stride(0) if dresidual is not None else 0,
361
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
362
+ M,
363
+ N,
364
+ eps,
365
+ rows_per_program,
366
+ is_rms_norm,
367
+ BLOCK_N,
368
+ dresidual is not None,
369
+ dresidual_in is not None,
370
+ bias is not None,
371
+ )
372
+ dw = _dw.sum(0).to(weight.dtype)
373
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
374
+ # Don't need to compute dresidual_in separately in this case
375
+ if has_residual and dx.dtype == x.dtype:
376
+ dresidual_in = dx
377
+ return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
378
+
379
+
380
+ class LayerNormFn(torch.autograd.Function):
381
+ @staticmethod
382
+ def forward(
383
+ ctx,
384
+ x,
385
+ weight,
386
+ bias,
387
+ residual=None,
388
+ eps=1e-6,
389
+ prenorm=False,
390
+ residual_in_fp32=False,
391
+ is_rms_norm=False,
392
+ ):
393
+ x_shape_og = x.shape
394
+ # reshape input data into 2D tensor
395
+ x = x.reshape(-1, x.shape[-1])
396
+ if x.stride(-1) != 1:
397
+ x = x.contiguous()
398
+ if residual is not None:
399
+ assert residual.shape == x_shape_og
400
+ residual = residual.reshape(-1, residual.shape[-1])
401
+ if residual.stride(-1) != 1:
402
+ residual = residual.contiguous()
403
+ weight = weight.contiguous()
404
+ if bias is not None:
405
+ bias = bias.contiguous()
406
+ residual_dtype = (
407
+ residual.dtype
408
+ if residual is not None
409
+ else (torch.float32 if residual_in_fp32 else None)
410
+ )
411
+ y, mean, rstd, residual_out = _layer_norm_fwd(
412
+ x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
413
+ )
414
+ ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
415
+ ctx.x_shape_og = x_shape_og
416
+ ctx.eps = eps
417
+ ctx.is_rms_norm = is_rms_norm
418
+ ctx.has_residual = residual is not None
419
+ ctx.prenorm = prenorm
420
+ ctx.x_dtype = x.dtype
421
+ y = y.reshape(x_shape_og)
422
+ return y if not prenorm else (y, residual_out.reshape(x_shape_og))
423
+
424
+ @staticmethod
425
+ def backward(ctx, dy, *args):
426
+ x, weight, bias, mean, rstd = ctx.saved_tensors
427
+ dy = dy.reshape(-1, dy.shape[-1])
428
+ if dy.stride(-1) != 1:
429
+ dy = dy.contiguous()
430
+ assert dy.shape == x.shape
431
+ if ctx.prenorm:
432
+ dresidual = args[0]
433
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
434
+ if dresidual.stride(-1) != 1:
435
+ dresidual = dresidual.contiguous()
436
+ assert dresidual.shape == x.shape
437
+ else:
438
+ dresidual = None
439
+ dx, dw, db, dresidual_in = _layer_norm_bwd(
440
+ dy,
441
+ x,
442
+ weight,
443
+ bias,
444
+ ctx.eps,
445
+ mean,
446
+ rstd,
447
+ dresidual,
448
+ ctx.has_residual,
449
+ ctx.is_rms_norm,
450
+ x_dtype=ctx.x_dtype,
451
+ )
452
+ return (
453
+ dx.reshape(ctx.x_shape_og),
454
+ dw,
455
+ db,
456
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
457
+ None,
458
+ None,
459
+ None,
460
+ None,
461
+ )
462
+
463
+
464
+ def layer_norm_fn(
465
+ x,
466
+ weight,
467
+ bias,
468
+ residual=None,
469
+ eps=1e-6,
470
+ prenorm=False,
471
+ residual_in_fp32=False,
472
+ is_rms_norm=False,
473
+ ):
474
+ return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)
475
+
476
+
477
+ def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
478
+ return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
479
+
480
+
481
+ class RMSNorm(torch.nn.Module):
482
+ def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
483
+ factory_kwargs = {"device": device, "dtype": dtype}
484
+ super().__init__()
485
+ self.eps = eps
486
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
487
+ self.register_parameter("bias", None)
488
+ self.reset_parameters()
489
+
490
+ def reset_parameters(self):
491
+ torch.nn.init.ones_(self.weight)
492
+
493
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
494
+ return rms_norm_fn(
495
+ x,
496
+ self.weight,
497
+ self.bias,
498
+ residual=residual,
499
+ eps=self.eps,
500
+ prenorm=prenorm,
501
+ residual_in_fp32=residual_in_fp32,
502
+ )
503
+
504
+
505
+ class LayerNormLinearFn(torch.autograd.Function):
506
+ @staticmethod
507
+ @custom_fwd
508
+ def forward(
509
+ ctx,
510
+ x,
511
+ norm_weight,
512
+ norm_bias,
513
+ linear_weight,
514
+ linear_bias,
515
+ residual=None,
516
+ eps=1e-6,
517
+ prenorm=False,
518
+ residual_in_fp32=False,
519
+ is_rms_norm=False,
520
+ ):
521
+ x_shape_og = x.shape
522
+ # reshape input data into 2D tensor
523
+ x = x.reshape(-1, x.shape[-1])
524
+ if x.stride(-1) != 1:
525
+ x = x.contiguous()
526
+ if residual is not None:
527
+ assert residual.shape == x_shape_og
528
+ residual = residual.reshape(-1, residual.shape[-1])
529
+ if residual.stride(-1) != 1:
530
+ residual = residual.contiguous()
531
+ norm_weight = norm_weight.contiguous()
532
+ if norm_bias is not None:
533
+ norm_bias = norm_bias.contiguous()
534
+ residual_dtype = (
535
+ residual.dtype
536
+ if residual is not None
537
+ else (torch.float32 if residual_in_fp32 else None)
538
+ )
539
+ y, mean, rstd, residual_out = _layer_norm_fwd(
540
+ x,
541
+ norm_weight,
542
+ norm_bias,
543
+ eps,
544
+ residual,
545
+ out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
546
+ residual_dtype=residual_dtype,
547
+ is_rms_norm=is_rms_norm,
548
+ )
549
+ y = y.reshape(x_shape_og)
550
+ dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
551
+ linear_weight = linear_weight.to(dtype)
552
+ linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
553
+ out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
554
+ # We don't store y, will be recomputed in the backward pass to save memory
555
+ ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
556
+ ctx.x_shape_og = x_shape_og
557
+ ctx.eps = eps
558
+ ctx.is_rms_norm = is_rms_norm
559
+ ctx.has_residual = residual is not None
560
+ ctx.prenorm = prenorm
561
+ ctx.x_dtype = x.dtype
562
+ ctx.linear_bias_is_none = linear_bias is None
563
+ return out if not prenorm else (out, residual_out.reshape(x_shape_og))
564
+
565
+ @staticmethod
566
+ @custom_bwd
567
+ def backward(ctx, dout, *args):
568
+ x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
569
+ dout = dout.reshape(-1, dout.shape[-1])
570
+ dy = F.linear(dout, linear_weight.t())
571
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
572
+ if dy.stride(-1) != 1:
573
+ dy = dy.contiguous()
574
+ assert dy.shape == x.shape
575
+ if ctx.prenorm:
576
+ dresidual = args[0]
577
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
578
+ if dresidual.stride(-1) != 1:
579
+ dresidual = dresidual.contiguous()
580
+ assert dresidual.shape == x.shape
581
+ else:
582
+ dresidual = None
583
+ dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
584
+ dy,
585
+ x,
586
+ norm_weight,
587
+ norm_bias,
588
+ ctx.eps,
589
+ mean,
590
+ rstd,
591
+ dresidual,
592
+ ctx.has_residual,
593
+ ctx.is_rms_norm,
594
+ x_dtype=ctx.x_dtype,
595
+ recompute_output=True,
596
+ )
597
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
598
+ return (
599
+ dx.reshape(ctx.x_shape_og),
600
+ dnorm_weight,
601
+ dnorm_bias,
602
+ dlinear_weight,
603
+ dlinear_bias,
604
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
605
+ None,
606
+ None,
607
+ None,
608
+ None,
609
+ )
610
+
611
+
612
+ def layer_norm_linear_fn(
613
+ x,
614
+ norm_weight,
615
+ norm_bias,
616
+ linear_weight,
617
+ linear_bias,
618
+ residual=None,
619
+ eps=1e-6,
620
+ prenorm=False,
621
+ residual_in_fp32=False,
622
+ is_rms_norm=False,
623
+ ):
624
+ return LayerNormLinearFn.apply(
625
+ x,
626
+ norm_weight,
627
+ norm_bias,
628
+ linear_weight,
629
+ linear_bias,
630
+ residual,
631
+ eps,
632
+ prenorm,
633
+ residual_in_fp32,
634
+ is_rms_norm,
635
+ )
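A minimal sketch of the reference normalization paths above, assuming triton is installed (this module imports it at load time) even though layer_norm_ref and rms_norm_ref themselves are plain PyTorch and run on CPU; the fused layer_norm_fn / rms_norm_fn paths additionally need a CUDA device. Shapes are illustrative.

    # Hypothetical sketch: the *_ref functions are plain PyTorch.
    import torch
    from mamba_ssm.ops.triton.layernorm import layer_norm_ref, rms_norm_ref

    x = torch.randn(3, 64)
    residual = torch.randn(3, 64)
    weight = torch.ones(64)
    bias = torch.zeros(64)

    y = layer_norm_ref(x, weight, bias, residual=residual, eps=1e-6)
    # prenorm=True also returns the (x + residual) stream fed to the next block
    y_rms, pre = rms_norm_ref(x, weight, bias=None, residual=residual,
                              eps=1e-6, prenorm=True, upcast=True)
    print(y.shape, y_rms.shape, pre.shape)     # all torch.Size([3, 64])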
mamba_install/mamba_ssm/ops/triton/selective_state_update.py ADDED
@@ -0,0 +1,263 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+
16
+ @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
17
+ @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
18
+ @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
19
+ @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
20
+ @triton.jit
21
+ def _selective_scan_update_kernel(
22
+ # Pointers to matrices
23
+ state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
24
+ # Matrix dimensions
25
+ batch, nheads, dim, dstate, nheads_ngroups_ratio,
26
+ # Strides
27
+ stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,
28
+ stride_x_batch, stride_x_head, stride_x_dim,
29
+ stride_dt_batch, stride_dt_head, stride_dt_dim,
30
+ stride_dt_bias_head, stride_dt_bias_dim,
31
+ stride_A_head, stride_A_dim, stride_A_dstate,
32
+ stride_B_batch, stride_B_group, stride_B_dstate,
33
+ stride_C_batch, stride_C_group, stride_C_dstate,
34
+ stride_D_head, stride_D_dim,
35
+ stride_z_batch, stride_z_head, stride_z_dim,
36
+ stride_out_batch, stride_out_head, stride_out_dim,
37
+ # Meta-parameters
38
+ DT_SOFTPLUS: tl.constexpr,
39
+ TIE_HDIM: tl.constexpr,
40
+ BLOCK_SIZE_M: tl.constexpr,
41
+ HAS_DT_BIAS: tl.constexpr,
42
+ HAS_D: tl.constexpr,
43
+ HAS_Z: tl.constexpr,
44
+ BLOCK_SIZE_DSTATE: tl.constexpr,
45
+ ):
46
+ pid_m = tl.program_id(axis=0)
47
+ pid_b = tl.program_id(axis=1)
48
+ pid_h = tl.program_id(axis=2)
49
+ state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
50
+ x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
51
+ dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
52
+ if HAS_DT_BIAS:
53
+ dt_bias_ptr += pid_h * stride_dt_bias_head
54
+ A_ptr += pid_h * stride_A_head
55
+ B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
56
+ C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
57
+ if HAS_Z:
58
+ z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
59
+ out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
60
+
61
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
62
+ offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
63
+ state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
64
+ x_ptrs = x_ptr + offs_m * stride_x_dim
65
+ dt_ptrs = dt_ptr + offs_m * stride_dt_dim
66
+ if HAS_DT_BIAS:
67
+ dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
68
+ if HAS_D:
69
+ D_ptr += pid_h * stride_D_head
70
+ A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
71
+ B_ptrs = B_ptr + offs_n * stride_B_dstate
72
+ C_ptrs = C_ptr + offs_n * stride_C_dstate
73
+ if HAS_D:
74
+ D_ptrs = D_ptr + offs_m * stride_D_dim
75
+ if HAS_Z:
76
+ z_ptrs = z_ptr + offs_m * stride_z_dim
77
+ out_ptrs = out_ptr + offs_m * stride_out_dim
78
+
79
+ state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
80
+ x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
81
+ if not TIE_HDIM:
82
+ dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
83
+ if HAS_DT_BIAS:
84
+ dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
85
+ if DT_SOFTPLUS:
86
+ dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
87
+ A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
88
+ dA = tl.exp(A * dt[:, None])
89
+ else:
90
+ dt = tl.load(dt_ptr).to(tl.float32)
91
+ if HAS_DT_BIAS:
92
+ dt += tl.load(dt_bias_ptr).to(tl.float32)
93
+ if DT_SOFTPLUS:
94
+ dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
95
+ A = tl.load(A_ptr).to(tl.float32)
96
+ dA = tl.exp(A * dt) # scalar, not a matrix
97
+
98
+ B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
99
+ C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
100
+ if HAS_D:
101
+ D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
102
+ if HAS_Z:
103
+ z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
104
+
105
+ if not TIE_HDIM:
106
+ dB = B[None, :] * dt[:, None]
107
+ else:
108
+ dB = B * dt # vector of size (dstate,)
109
+ state = state * dA + dB * x[:, None]
110
+ tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
111
+ out = tl.sum(state * C[None, :], axis=1)
112
+ if HAS_D:
113
+ out += x * D
114
+ if HAS_Z:
115
+ out *= z * tl.sigmoid(z)
116
+ tl.store(out_ptrs, out, mask=offs_m < dim)
117
+
118
+
119
+ def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
120
+ """
121
+ Argument:
122
+ state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
123
+ x: (batch, dim) or (batch, nheads, dim)
124
+ dt: (batch, dim) or (batch, nheads, dim)
125
+ A: (dim, dstate) or (nheads, dim, dstate)
126
+ B: (batch, dstate) or (batch, ngroups, dstate)
127
+ C: (batch, dstate) or (batch, ngroups, dstate)
128
+ D: (dim,) or (nheads, dim)
129
+ z: (batch, dim) or (batch, nheads, dim)
130
+ dt_bias: (dim,) or (nheads, dim)
131
+ Return:
132
+ out: (batch, dim) or (batch, nheads, dim)
133
+ """
134
+ has_heads = state.dim() > 3
135
+ if state.dim() == 3:
136
+ state = state.unsqueeze(1)
137
+ if x.dim() == 2:
138
+ x = x.unsqueeze(1)
139
+ if dt.dim() == 2:
140
+ dt = dt.unsqueeze(1)
141
+ if A.dim() == 2:
142
+ A = A.unsqueeze(0)
143
+ if B.dim() == 2:
144
+ B = B.unsqueeze(1)
145
+ if C.dim() == 2:
146
+ C = C.unsqueeze(1)
147
+ if D is not None and D.dim() == 1:
148
+ D = D.unsqueeze(0)
149
+ if z is not None and z.dim() == 2:
150
+ z = z.unsqueeze(1)
151
+ if dt_bias is not None and dt_bias.dim() == 1:
152
+ dt_bias = dt_bias.unsqueeze(0)
153
+ batch, nheads, dim, dstate = state.shape
154
+ assert x.shape == (batch, nheads, dim)
155
+ assert dt.shape == x.shape
156
+ assert A.shape == (nheads, dim, dstate)
157
+ ngroups = B.shape[1]
158
+ assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
159
+ assert B.shape == (batch, ngroups, dstate)
160
+ assert C.shape == B.shape
161
+ if D is not None:
162
+ assert D.shape == (nheads, dim)
163
+ if z is not None:
164
+ assert z.shape == x.shape
165
+ if dt_bias is not None:
166
+ assert dt_bias.shape == (nheads, dim)
167
+ out = torch.empty_like(x)
168
+ grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
169
+ z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
170
+ # We don't want autotune since it will overwrite the state
171
+ # We instead tune by hand.
172
+ BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
173
+ else ((16, 4) if dstate <= 32 else
174
+ ((8, 4) if dstate <= 64 else
175
+ ((4, 4) if dstate <= 128 else
176
+ ((4, 8))))))
177
+ tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0
178
+ with torch.cuda.device(x.device.index):
179
+ _selective_scan_update_kernel[grid](
180
+ state, x, dt, dt_bias, A, B, C, D, z, out,
181
+ batch, nheads, dim, dstate, nheads // ngroups,
182
+ state.stride(0), state.stride(1), state.stride(2), state.stride(3),
183
+ x.stride(0), x.stride(1), x.stride(2),
184
+ dt.stride(0), dt.stride(1), dt.stride(2),
185
+ *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
186
+ A.stride(0), A.stride(1), A.stride(2),
187
+ B.stride(0), B.stride(1), B.stride(2),
188
+ C.stride(0), C.stride(1), C.stride(2),
189
+ *(D.stride(0), D.stride(1)) if D is not None else 0,
190
+ z_strides[0], z_strides[1], z_strides[2],
191
+ out.stride(0), out.stride(1), out.stride(2),
192
+ dt_softplus,
193
+ tie_hdim,
194
+ BLOCK_SIZE_M,
195
+ num_warps=num_warps,
196
+ )
197
+ if not has_heads:
198
+ out = out.squeeze(1)
199
+ return out
200
+
201
+
202
+ def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
203
+ """
204
+ Argument:
205
+ state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
206
+ x: (batch, dim) or (batch, nheads, dim)
207
+ dt: (batch, dim) or (batch, nheads, dim)
208
+ A: (dim, dstate) or (nheads, dim, dstate)
209
+ B: (batch, dstate) or (batch, ngroups, dstate)
210
+ C: (batch, dstate) or (batch, ngroups, dstate)
211
+ D: (dim,) or (nheads, dim)
212
+ z: (batch, dim) or (batch, nheads, dim)
213
+ dt_bias: (dim,) or (nheads, dim)
214
+ Return:
215
+ out: (batch, dim) or (batch, nheads, dim)
216
+ """
217
+ has_heads = state.dim() > 3
218
+ if state.dim() == 3:
219
+ state = state.unsqueeze(1)
220
+ if x.dim() == 2:
221
+ x = x.unsqueeze(1)
222
+ if dt.dim() == 2:
223
+ dt = dt.unsqueeze(1)
224
+ if A.dim() == 2:
225
+ A = A.unsqueeze(0)
226
+ if B.dim() == 2:
227
+ B = B.unsqueeze(1)
228
+ if C.dim() == 2:
229
+ C = C.unsqueeze(1)
230
+ if D is not None and D.dim() == 1:
231
+ D = D.unsqueeze(0)
232
+ if z is not None and z.dim() == 2:
233
+ z = z.unsqueeze(1)
234
+ if dt_bias is not None and dt_bias.dim() == 1:
235
+ dt_bias = dt_bias.unsqueeze(0)
236
+ batch, nheads, dim, dstate = state.shape
237
+ assert x.shape == (batch, nheads, dim)
238
+ assert dt.shape == x.shape
239
+ assert A.shape == (nheads, dim, dstate)
240
+ ngroups = B.shape[1]
241
+ assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
242
+ assert B.shape == (batch, ngroups, dstate)
243
+ assert C.shape == B.shape
244
+ if D is not None:
245
+ assert D.shape == (nheads, dim)
246
+ if z is not None:
247
+ assert z.shape == x.shape
248
+ if dt_bias is not None:
249
+ assert dt_bias.shape == (nheads, dim)
250
+ dt = dt + dt_bias
251
+ dt = F.softplus(dt) if dt_softplus else dt
252
+ dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate)
253
+ B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
254
+ C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
255
+ dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate)
256
+ state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1")) # (batch, nheads, dim, dstate)
257
+ out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
258
+ if D is not None:
259
+ out += (x * D).to(out.dtype)
260
+ out = (out if z is None else out * F.silu(z)).to(x.dtype)
261
+ if not has_heads:
262
+ out = out.squeeze(1)
263
+ return out
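A minimal sketch of a single decoding step with the reference implementation above, assuming triton is installed (the module imports it at load time); selective_state_update_ref itself is plain PyTorch and uses the (batch, dim, dstate) layout from its docstring, while the Triton kernel path needs a CUDA device. Sizes are illustrative.

    # Hypothetical sketch: one recurrent step; the SSM state is updated in place.
    import torch
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update_ref

    batch, dim, dstate = 2, 8, 16
    state = torch.zeros(batch, dim, dstate)    # running state, modified in place
    x = torch.randn(batch, dim)
    dt = torch.rand(batch, dim)
    A = -torch.rand(dim, dstate)               # negative so the state decays
    B = torch.randn(batch, dstate)
    C = torch.randn(batch, dstate)
    D = torch.ones(dim)

    out = selective_state_update_ref(state, x, dt, A, B, C, D=D, dt_softplus=True)
    print(out.shape)                           # torch.Size([2, 8])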
mamba_install/mamba_ssm/utils/__init__.py ADDED
File without changes
mamba_install/mamba_ssm/utils/generation.py ADDED
@@ -0,0 +1,387 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+ import gc
3
+ import time
4
+ from collections import namedtuple
5
+ from dataclasses import dataclass, field
6
+ from functools import partial
7
+ from typing import Callable, Optional, Sequence, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from einops import rearrange, repeat
12
+ from torch import Tensor
13
+ from torch.profiler import ProfilerActivity, profile, record_function
14
+ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
15
+
16
+
17
+ @dataclass
18
+ class InferenceParams:
19
+ """Inference parameters that are passed to the main model in order
20
+ to efficiently calculate and store the context during inference."""
21
+
22
+ max_seqlen: int
23
+ max_batch_size: int
24
+ seqlen_offset: int = 0
25
+ batch_size_offset: int = 0
26
+ key_value_memory_dict: dict = field(default_factory=dict)
27
+ lengths_per_sample: Optional[Tensor] = None
28
+
29
+ def reset(self, max_seqlen, max_batch_size):
30
+ self.max_seqlen = max_seqlen
31
+ self.max_batch_size = max_batch_size
32
+ self.seqlen_offset = 0
33
+ if self.lengths_per_sample is not None:
34
+ self.lengths_per_sample.zero_()
35
+
36
+
37
+ def modify_logits_for_min_p_filtering(logits, min_p):
38
+ """Set the logits for none min_p values to -inf. Done in-place."""
39
+ if min_p <= 0.0 or min_p >= 1.0:
40
+ return
41
+ indices_to_remove = logits < min_p
42
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
43
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
44
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
45
+ def modify_logits_for_top_k_filtering(logits, top_k):
46
+ """Set the logits for none top-k values to -inf. Done in-place."""
47
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
48
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
49
+
50
+
51
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
52
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
53
+ def modify_logits_for_top_p_filtering(logits, top_p):
54
+ """Set the logits for none top-p values to -inf. Done in-place."""
55
+ if top_p <= 0.0 or top_p >= 1.0:
56
+ return
57
+ # First sort and calculate cumulative sum of probabilities.
58
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
59
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
60
+ # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
61
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
62
+ # scatter sorted tensors to original indexing
63
+ indices_to_remove = sorted_indices_to_remove.scatter(
64
+ 1, sorted_indices, sorted_indices_to_remove
65
+ )
66
+ logits.masked_fill_(indices_to_remove, float("-inf"))
67
+
68
+
69
+ def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
70
+ """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
71
+ logits: (batch_size, vocab_size)
72
+ prev_output_tokens: (batch_size, seq_len)
73
+ """
74
+ if repetition_penalty == 1.0:
75
+ return logits
76
+ score = torch.gather(logits, 1, prev_output_tokens)
77
+ # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
78
+ score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
79
+ logits.scatter_(1, prev_output_tokens, score)
80
+ return logits
81
+
82
+
83
+ def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
84
+ """Sample from top-k logits.
85
+ Arguments:
86
+ logits: Tensor of shape (batch_size, vocab_size)
87
+ """
88
+ if top_k == 1: # Short-circuit for greedy decoding
89
+ return logits.argmax(dim=-1)
90
+ else:
91
+ if top_p > 0.0:
92
+ assert top_p <= 1.0, "top-p should be in (0, 1]."
93
+ if top_k > 0:
94
+ top_k = min(top_k, logits.size(-1)) # Safety check
95
+ logits_top, indices = torch.topk(logits, top_k, dim=-1)
96
+ if temperature != 1.0:
97
+ logits_top /= temperature
98
+ modify_logits_for_top_p_filtering(logits_top, top_p)
99
+ return indices[
100
+ torch.arange(indices.shape[0], device=indices.device),
101
+ torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
102
+ ]
103
+ else:
104
+ if min_p > 0.0:
105
+ logits_top = logits.clone()
106
+ max_prob = logits_top[..., 0].item()
107
+ min_prob = max_prob * min_p
108
+ modify_logits_for_min_p_filtering(logits_top, min_p)
109
+ if temperature != 1.0:
110
+ logits_top /= temperature
111
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
112
+ # Clone so that when we modify for top_p we don't change the original logits
113
+ logits_top = logits / temperature if temperature != 1.0 else logits.clone()
114
+ modify_logits_for_top_p_filtering(logits_top, top_p)
115
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
116
+ dim=-1
117
+ )
118
+
119
+
120
+ @torch.inference_mode()
121
+ def decode(
122
+ input_ids,
123
+ model,
124
+ max_length,
125
+ top_k=1,
126
+ top_p=0.0,
127
+ min_p=0.0,
128
+ temperature=1.0,
129
+ repetition_penalty=1.0,
130
+ eos_token_id=None,
131
+ teacher_outputs=None,
132
+ vocab_size=None,
133
+ cg=False,
134
+ enable_timing=False,
135
+ streamer: Optional[TextStreamer] = None
136
+ ):
137
+ """Decoding, either greedy or with top-k or top-p sampling.
138
+ If top-k = 0, don't limit the number of candidates (pure sampling).
139
+ Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
140
+ then top-p.
141
+ We assume that all sequences in the same batch have the same length.
142
+
143
+ Arguments:
144
+ input_ids: (batch, seq_len)
145
+ max_length: int
146
+ teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
147
+ logits, the next token is taken from the teacher_outputs. Useful for testing.
148
+ Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
149
+ sequences: (batch, max_length)
150
+ scores: tuples of (batch, vocab_size)
151
+ """
152
+ if streamer is not None:
153
+ streamer.put(input_ids.cpu())
154
+
155
+ batch_size, seqlen_og = input_ids.shape
156
+ teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
157
+ if cg:
158
+ if not hasattr(model, "_decoding_cache"):
159
+ model._decoding_cache = None
160
+ model._decoding_cache = update_graph_cache(
161
+ model,
162
+ model._decoding_cache,
163
+ batch_size,
164
+ seqlen_og,
165
+ max_length,
166
+ )
167
+ inference_params = model._decoding_cache.inference_params
168
+ inference_params.reset(max_length, batch_size)
169
+ else:
170
+ inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
171
+
172
+ def get_logits(input_ids, inference_params):
173
+ decoding = inference_params.seqlen_offset > 0
174
+ if decoding:
175
+ position_ids = torch.full(
176
+ (batch_size, 1),
177
+ inference_params.seqlen_offset,
178
+ dtype=torch.long,
179
+ device=input_ids.device,
180
+ )
181
+ else:
182
+ position_ids = None
183
+ if not cg or not decoding:
184
+ logits = model(
185
+ input_ids,
186
+ position_ids=position_ids,
187
+ inference_params=inference_params,
188
+ num_last_tokens=1,
189
+ ).logits.squeeze(dim=1)
190
+ else:
191
+ logits = model._decoding_cache.run(
192
+ input_ids, position_ids, inference_params.seqlen_offset
193
+ ).squeeze(dim=1)
194
+ return logits[..., :vocab_size] if vocab_size is not None else logits
195
+
196
+ def sample_tokens(logits, inference_params):
197
+ if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
198
+ token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
199
+ else:
200
+ token = teacher_outputs[:, inference_params.seqlen_offset]
201
+ # return rearrange(token, "b -> b 1")
202
+ return token.unsqueeze(1)
203
+
204
+ def should_stop(current_token, inference_params):
205
+ if inference_params.seqlen_offset == 0:
206
+ return False
207
+ if eos_token_id is not None and (current_token == eos_token_id).all():
208
+ return True
209
+ if inference_params.seqlen_offset >= max_length - 1:
210
+ return True
211
+ return False
212
+
213
+ start = torch.cuda.Event(enable_timing=enable_timing)
214
+ end = torch.cuda.Event(enable_timing=enable_timing)
215
+
216
+ if enable_timing:
217
+ start.record()
218
+ scores, sequences = [], [input_ids]
219
+ sequences_cat = input_ids
220
+ while not should_stop(sequences[-1], inference_params):
221
+ scores.append(get_logits(sequences[-1], inference_params))
222
+ inference_params.seqlen_offset += sequences[-1].shape[1]
223
+ if repetition_penalty == 1.0:
224
+ sampled_tokens = sample_tokens(scores[-1], inference_params)
225
+ else:
226
+ logits = modify_logit_for_repetition_penalty(
227
+ scores[-1].clone(), sequences_cat, repetition_penalty
228
+ )
229
+ sampled_tokens = sample_tokens(logits, inference_params)
230
+ sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
231
+ sequences.append(sampled_tokens)
232
+ if streamer is not None:
233
+ streamer.put(sampled_tokens.cpu())
234
+ if streamer is not None:
235
+ streamer.end()
236
+ if enable_timing:
237
+ end.record()
238
+ torch.cuda.synchronize()
239
+ print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
240
+ output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
241
+ return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
242
+
243
+
244
+ class GenerationMixin:
245
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
246
+ raise NotImplementedError
247
+
248
+ def generate(
249
+ self,
250
+ input_ids,
251
+ max_length,
252
+ top_k=1,
253
+ top_p=0.0,
254
+ min_p=0.0,
255
+ temperature=1.0,
256
+ return_dict_in_generate=False,
257
+ output_scores=False,
258
+ **kwargs,
259
+ ):
260
+ output = decode(
261
+ input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature, **kwargs
262
+ )
263
+ if not output_scores:
264
+ output.scores = None
265
+ return output if return_dict_in_generate else output.sequences
266
+
267
+
268
+ @dataclass
269
+ class DecodingCGCache:
270
+ max_batch_size: int = 0
271
+ max_seqlen: int = 0
272
+ device = None
273
+ dtype = None
274
+ callables: dict = field(default_factory=dict)
275
+ mempool = None
276
+ inference_params: Optional[InferenceParams] = None
277
+ run: Optional[Callable] = None
278
+
279
+
280
+ @torch.inference_mode()
281
+ def update_graph_cache(
282
+ model,
283
+ cache,
284
+ batch_size,
285
+ seqlen_og,
286
+ max_seqlen,
287
+ decoding_seqlens=(1,),
288
+ dtype=None,
289
+ n_warmups=2,
290
+ ):
291
+ if cache is None:
292
+ cache = DecodingCGCache()
293
+ param_example = next(iter(model.parameters()))
294
+ device = param_example.device
295
+ if dtype is None:
296
+ dtype = param_example.dtype
297
+ if (
298
+ (device, dtype) != (cache.device, cache.dtype)
299
+ or batch_size > cache.max_batch_size
300
+ or max_seqlen > cache.max_seqlen
301
+ ): # Invalidate the cache
302
+ cache.callables = {}
303
+ cache.mempool = None
304
+ cache.inference_params = None
305
+ gc.collect()
306
+ cache.device, cache.dtype = device, dtype
307
+ cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
308
+ assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
309
+ inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
310
+ lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
311
+ cache.inference_params = InferenceParams(
312
+ max_seqlen=max_seqlen,
313
+ max_batch_size=batch_size,
314
+ seqlen_offset=seqlen_og,
315
+ key_value_memory_dict=inf_cache,
316
+ lengths_per_sample=lengths_per_sample,
317
+ )
318
+ cache.mempool = torch.cuda.graphs.graph_pool_handle()
319
+ for decoding_seqlen in decoding_seqlens:
320
+ if (batch_size, decoding_seqlen) not in cache.callables:
321
+ cache.callables[batch_size, decoding_seqlen] = capture_graph(
322
+ model,
323
+ cache.inference_params,
324
+ batch_size,
325
+ max_seqlen,
326
+ decoding_seqlen=decoding_seqlen,
327
+ mempool=cache.mempool,
328
+ n_warmups=n_warmups,
329
+ )
330
+
331
+ def dispatch(input_ids, position_ids, seqlen):
332
+ batch_size, decoding_seqlen = input_ids.shape[:2]
333
+ return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
334
+
335
+ cache.run = dispatch
336
+ cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
337
+ return cache
338
+
339
+
340
+ def capture_graph(
341
+ model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
342
+ ):
343
+ device = next(iter(model.parameters())).device
344
+ input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
345
+ position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
346
+ seqlen_offset_og = inference_params.seqlen_offset
347
+ inference_params.seqlen_offset = max_seqlen - decoding_seqlen
348
+ inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
349
+
350
+ # Warmup before capture
351
+ s = torch.cuda.Stream()
352
+ s.wait_stream(torch.cuda.current_stream())
353
+ with torch.cuda.stream(s):
354
+ for _ in range(n_warmups):
355
+ logits = model(
356
+ input_ids,
357
+ position_ids=position_ids,
358
+ inference_params=inference_params,
359
+ num_last_tokens=decoding_seqlen,
360
+ ).logits
361
+ s.synchronize()
362
+ # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
363
+ # which requires that graph launches and non-captured launches not overlap (I think,
364
+ # that's how I interpret the documentation). I'm not sure if this is required.
365
+ if torch.distributed.is_initialized():
366
+ torch.distributed.barrier()
367
+ torch.cuda.current_stream().wait_stream(s)
368
+ # Captures the graph
369
+ # To allow capture, automatically sets a side stream as the current stream in the context
370
+ graph = torch.cuda.CUDAGraph()
371
+ with torch.cuda.graph(graph, pool=mempool):
372
+ logits = model(
373
+ input_ids,
374
+ position_ids=position_ids,
375
+ inference_params=inference_params,
376
+ num_last_tokens=decoding_seqlen,
377
+ ).logits
378
+
379
+ def run(new_input_ids, new_position_ids, seqlen):
380
+ inference_params.lengths_per_sample[:] = seqlen
381
+ input_ids.copy_(new_input_ids)
382
+ position_ids.copy_(new_position_ids)
383
+ graph.replay()
384
+ return logits.clone()
385
+
386
+ inference_params.seqlen_offset = seqlen_offset_og
387
+ return run
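
For orientation, a minimal usage sketch of the decoding path above, driven through `GenerationMixin.generate`. The checkpoint and tokenizer names are illustrative assumptions, not part of this commit; `cg=True` exercises the CUDA-graph cache built by `update_graph_cache`.

```python
# Hedged sketch: assumes a CUDA device, the mamba_ssm package installed, and a
# pretrained checkpoint; "state-spaces/mamba-130m" and the GPT-NeoX tokenizer are
# placeholders for illustration only.
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)

input_ids = tokenizer("Mamba is a", return_tensors="pt").input_ids.to("cuda")
out = model.generate(
    input_ids,
    max_length=64,
    top_k=1,                      # top_k=1 gives greedy decoding (GreedySearchDecoderOnlyOutput)
    cg=True,                      # replay the per-token step through the captured CUDA graph
    return_dict_in_generate=True,
    output_scores=True,
)
print(tokenizer.batch_decode(out.sequences))
```

Raising `top_k`/`top_p` or `temperature` switches to sampling, in which case a `SampleDecoderOnlyOutput` is returned instead.
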
mamba_install/mamba_ssm/utils/hf.py ADDED
@@ -0,0 +1,23 @@
1
+ import json
2
+
3
+ import torch
4
+
5
+ from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
6
+ from transformers.utils.hub import cached_file
7
+
8
+
9
+ def load_config_hf(model_name):
10
+ resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
11
+ return json.load(open(resolved_archive_file))
12
+
13
+
14
+ def load_state_dict_hf(model_name, device=None, dtype=None):
15
+ # If not fp32, then we don't want to load directly to the GPU
16
+ mapped_device = "cpu" if dtype not in [torch.float32, None] else device
17
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
18
+ state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
19
+ # Convert dtype before moving to GPU to save memory
20
+ if dtype is not None:
21
+ state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
22
+ state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
23
+ return state_dict
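
A short sketch of how these two helpers combine (essentially what `MambaLMHeadModel.from_pretrained` later in this commit does); the checkpoint name is an illustrative placeholder:

```python
# Hedged sketch: assumes a Hub checkpoint that ships a config.json whose keys match
# MambaConfig and a pytorch_model.bin; "state-spaces/mamba-130m" is a placeholder.
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

config = MambaConfig(**load_config_hf("state-spaces/mamba-130m"))
state_dict = load_state_dict_hf("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)
print(config.d_model, config.n_layer, len(state_dict))
```
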
mamba_install/setup.py ADDED
@@ -0,0 +1,284 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+ import sys
3
+ import warnings
4
+ import os
5
+ import re
6
+ import ast
7
+ from pathlib import Path
8
+ from packaging.version import parse, Version
9
+ import platform
10
+ import shutil
11
+
12
+ from setuptools import setup, find_packages
13
+ import subprocess
14
+
15
+ import urllib.request
16
+ import urllib.error
17
+ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
18
+
19
+ import torch
20
+ from torch.utils.cpp_extension import (
21
+ BuildExtension,
22
+ CppExtension,
23
+ CUDAExtension,
24
+ CUDA_HOME,
25
+ )
26
+
27
+
28
+ with open("README.md", "r", encoding="utf-8") as fh:
29
+ long_description = fh.read()
30
+
31
+
32
+ # The ninja build does not work unless include_dirs are absolute paths
33
+ this_dir = os.path.dirname(os.path.abspath(__file__))
34
+
35
+ PACKAGE_NAME = "mamba_ssm"
36
+
37
+ BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}"
38
+
39
+ # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
40
+ # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
41
+ FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
42
+ SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
43
+ # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
44
+ FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE"
45
+
46
+
47
+ def get_platform():
48
+ """
49
+ Returns the platform name as used in wheel filenames.
50
+ """
51
+ if sys.platform.startswith("linux"):
52
+ return "linux_x86_64"
53
+ elif sys.platform == "darwin":
54
+ mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
55
+ return f"macosx_{mac_version}_x86_64"
56
+ elif sys.platform == "win32":
57
+ return "win_amd64"
58
+ else:
59
+ raise ValueError("Unsupported platform: {}".format(sys.platform))
60
+
61
+
62
+ def get_cuda_bare_metal_version(cuda_dir):
63
+ raw_output = subprocess.check_output(
64
+ [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
65
+ )
66
+ output = raw_output.split()
67
+ release_idx = output.index("release") + 1
68
+ bare_metal_version = parse(output[release_idx].split(",")[0])
69
+
70
+ return raw_output, bare_metal_version
71
+
72
+
73
+ def check_if_cuda_home_none(global_option: str) -> None:
74
+ if CUDA_HOME is not None:
75
+ return
76
+ # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
77
+ # in that case.
78
+ warnings.warn(
79
+ f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
80
+ "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
81
+ "only images whose names contain 'devel' will provide nvcc."
82
+ )
83
+
84
+
85
+ def append_nvcc_threads(nvcc_extra_args):
86
+ return nvcc_extra_args + ["--threads", "4"]
87
+
88
+
89
+ cmdclass = {}
90
+ ext_modules = []
91
+
92
+ if not SKIP_CUDA_BUILD:
93
+ print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
94
+ TORCH_MAJOR = int(torch.__version__.split(".")[0])
95
+ TORCH_MINOR = int(torch.__version__.split(".")[1])
96
+
97
+ check_if_cuda_home_none(PACKAGE_NAME)
98
+ # Check, if CUDA11 is installed for compute capability 8.0
99
+ cc_flag = []
100
+ if CUDA_HOME is not None:
101
+ _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
102
+ if bare_metal_version < Version("11.6"):
103
+ raise RuntimeError(
104
+ f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
105
+ "Note: make sure nvcc has a supported version by running nvcc -V."
106
+ )
107
+
108
+ cc_flag.append("-gencode")
109
+ cc_flag.append("arch=compute_53,code=sm_53")
110
+ cc_flag.append("-gencode")
111
+ cc_flag.append("arch=compute_62,code=sm_62")
112
+ cc_flag.append("-gencode")
113
+ cc_flag.append("arch=compute_70,code=sm_70")
114
+ cc_flag.append("-gencode")
115
+ cc_flag.append("arch=compute_72,code=sm_72")
116
+ cc_flag.append("-gencode")
117
+ cc_flag.append("arch=compute_80,code=sm_80")
118
+ cc_flag.append("-gencode")
119
+ cc_flag.append("arch=compute_87,code=sm_87")
120
+ if bare_metal_version >= Version("11.8"):
121
+ cc_flag.append("-gencode")
122
+ cc_flag.append("arch=compute_90,code=sm_90")
123
+
124
+ # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
125
+ # torch._C._GLIBCXX_USE_CXX11_ABI
126
+ # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
127
+ if FORCE_CXX11_ABI:
128
+ torch._C._GLIBCXX_USE_CXX11_ABI = True
129
+
130
+ ext_modules.append(
131
+ CUDAExtension(
132
+ name="selective_scan_cuda",
133
+ sources=[
134
+ "csrc/selective_scan/selective_scan.cpp",
135
+ "csrc/selective_scan/selective_scan_fwd_fp32.cu",
136
+ "csrc/selective_scan/selective_scan_fwd_fp16.cu",
137
+ "csrc/selective_scan/selective_scan_fwd_bf16.cu",
138
+ "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
139
+ "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
140
+ "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
141
+ "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
142
+ "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
143
+ "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
144
+ ],
145
+ extra_compile_args={
146
+ "cxx": ["-O3", "-std=c++17"],
147
+ "nvcc": append_nvcc_threads(
148
+ [
149
+ "-O3",
150
+ "-std=c++17",
151
+ "-U__CUDA_NO_HALF_OPERATORS__",
152
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
153
+ "-U__CUDA_NO_BFLOAT16_OPERATORS__",
154
+ "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
155
+ "-U__CUDA_NO_BFLOAT162_OPERATORS__",
156
+ "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
157
+ "--expt-relaxed-constexpr",
158
+ "--expt-extended-lambda",
159
+ "--use_fast_math",
160
+ "--ptxas-options=-v",
161
+ "-lineinfo",
162
+ ]
163
+ + cc_flag
164
+ ),
165
+ },
166
+ include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
167
+ )
168
+ )
169
+
170
+
171
+ def get_package_version():
172
+ with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f:
173
+ version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
174
+ public_version = ast.literal_eval(version_match.group(1))
175
+ local_version = os.environ.get("MAMBA_LOCAL_VERSION")
176
+ if local_version:
177
+ return f"{public_version}+{local_version}"
178
+ else:
179
+ return str(public_version)
180
+
181
+
182
+ def get_wheel_url():
183
+ # Determine the version numbers that will be used to determine the correct wheel
184
+ # We're using the CUDA version used to build torch, not the one currently installed
185
+ # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
186
+ torch_cuda_version = parse(torch.version.cuda)
187
+ torch_version_raw = parse(torch.__version__)
188
+ # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
189
+ # to save CI time. Minor versions should be compatible.
190
+ torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
191
+ python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
192
+ platform_name = get_platform()
193
+ mamba_ssm_version = get_package_version()
194
+ # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
195
+ cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
196
+ torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
197
+ cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
198
+
199
+ # Determine wheel URL based on CUDA version, torch version, python version and OS
200
+ wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
201
+ wheel_url = BASE_WHEEL_URL.format(
202
+ tag_name=f"v{mamba_ssm_version}", wheel_name=wheel_filename
203
+ )
204
+ return wheel_url, wheel_filename
205
+
206
+
207
+ class CachedWheelsCommand(_bdist_wheel):
208
+ """
209
+ The CachedWheelsCommand plugs into the default bdist_wheel command, which is run by pip when it cannot
210
+ find an existing wheel (which is currently the case for all installs). We use
211
+ the environment parameters to detect whether there is already a pre-built version of a compatible
212
+ wheel available and, if so, short-circuit the standard full build pipeline.
213
+ """
214
+
215
+ def run(self):
216
+ if FORCE_BUILD:
217
+ return super().run()
218
+
219
+ wheel_url, wheel_filename = get_wheel_url()
220
+ print("Guessing wheel URL: ", wheel_url)
221
+ try:
222
+ urllib.request.urlretrieve(wheel_url, wheel_filename)
223
+
224
+ # Make the archive
225
+ # Lifted from the root wheel processing command
226
+ # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
227
+ if not os.path.exists(self.dist_dir):
228
+ os.makedirs(self.dist_dir)
229
+
230
+ impl_tag, abi_tag, plat_tag = self.get_tag()
231
+ archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
232
+
233
+ wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
234
+ print("Raw wheel path", wheel_path)
235
+ shutil.move(wheel_filename, wheel_path)
236
+ except urllib.error.HTTPError:
237
+ print("Precompiled wheel not found. Building from source...")
238
+ # If the wheel could not be downloaded, build from source
239
+ super().run()
240
+
241
+
242
+ setup(
243
+ name=PACKAGE_NAME,
244
+ version=get_package_version(),
245
+ packages=find_packages(
246
+ exclude=(
247
+ "build",
248
+ "csrc",
249
+ "include",
250
+ "tests",
251
+ "dist",
252
+ "docs",
253
+ "benchmarks",
254
+ "mamba_ssm.egg-info",
255
+ )
256
+ ),
257
+ author="Tri Dao, Albert Gu",
258
+ author_email="tri@tridao.me, agu@cs.cmu.edu",
259
+ description="Mamba state-space model",
260
+ long_description=long_description,
261
+ long_description_content_type="text/markdown",
262
+ url="https://github.com/state-spaces/mamba",
263
+ classifiers=[
264
+ "Programming Language :: Python :: 3",
265
+ "License :: OSI Approved :: BSD License",
266
+ "Operating System :: Unix",
267
+ ],
268
+ ext_modules=ext_modules,
269
+ cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
270
+ if ext_modules
271
+ else {
272
+ "bdist_wheel": CachedWheelsCommand,
273
+ },
274
+ python_requires=">=3.7",
275
+ install_requires=[
276
+ "torch",
277
+ "packaging",
278
+ "ninja",
279
+ "einops",
280
+ "triton",
281
+ "transformers",
282
+ # "causal_conv1d>=1.2.0",
283
+ ],
284
+ )
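
To make the wheel-caching logic concrete: `MAMBA_FORCE_BUILD=TRUE` skips the download attempt entirely, `MAMBA_SKIP_CUDA_BUILD=TRUE` skips compiling the extension, and otherwise `get_wheel_url` guesses a prebuilt wheel name from the environment. A sketch of the filename it composes, with illustrative version numbers (CPython 3.10, torch 2.1 built against CUDA 12.x, Linux, C++11 ABI disabled):

```python
# Illustrative only: with mamba_ssm 1.2.2 and the environment described above,
# get_wheel_url() would look for this wheel under the GitHub release tag v1.2.2
# before falling back to a source build.
wheel_name = "mamba_ssm-1.2.2+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
url = f"https://github.com/state-spaces/mamba/releases/download/v1.2.2/{wheel_name}"
print(url)
```
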
mamba_install/tests/ops/test_selective_scan.py ADDED
@@ -0,0 +1,247 @@
1
+ # Copyright (C) 2023, Tri Dao.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import pytest
8
+
9
+ from einops import rearrange
10
+
11
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
12
+ from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref
13
+
14
+
15
+ # @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
16
+ @pytest.mark.parametrize('wtype', [torch.float32])
17
+ # @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
18
+ @pytest.mark.parametrize('itype', [torch.float32])
19
+ # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
20
+ @pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
21
+ # @pytest.mark.parametrize('seqlen', [128])
22
+ # @pytest.mark.parametrize("return_last_state", [False, True])
23
+ @pytest.mark.parametrize("return_last_state", [True])
24
+ # @pytest.mark.parametrize('has_delta_bias', [False, True])
25
+ @pytest.mark.parametrize('has_delta_bias', [True])
26
+ # @pytest.mark.parametrize('delta_softplus', [False, True])
27
+ @pytest.mark.parametrize('delta_softplus', [True])
28
+ # @pytest.mark.parametrize('has_z', [False, True])
29
+ @pytest.mark.parametrize('has_z', [True])
30
+ # @pytest.mark.parametrize('has_D', [False, True])
31
+ @pytest.mark.parametrize('has_D', [True])
32
+ @pytest.mark.parametrize("varBC_groups", [1, 2])
33
+ # @pytest.mark.parametrize("varBC_groups", [1])
34
+ # @pytest.mark.parametrize("is_variable_C", [False, True])
35
+ @pytest.mark.parametrize("is_variable_C", [True])
36
+ # @pytest.mark.parametrize("is_variable_B", [False, True])
37
+ @pytest.mark.parametrize("is_variable_B", [True])
38
+ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias,
39
+ delta_softplus, return_last_state, seqlen, itype, wtype):
40
+ if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
41
+ pytest.skip() # This config is not applicable
42
+ device = 'cuda'
43
+ rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
44
+ if itype == torch.bfloat16:
45
+ rtol, atol = 3e-2, 5e-2
46
+ rtolw, atolw = (1e-3, 1e-3)
47
+ if has_z: # If we have z, the errors on the weights seem higher
48
+ rtolw = max(rtolw, rtol)
49
+ atolw = max(atolw, atol)
50
+ # set seed
51
+ torch.random.manual_seed(0)
52
+ batch_size = 2
53
+ dim = 4
54
+ dstate = 8
55
+ is_complex = wtype == torch.complex64
56
+ A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
57
+ if not is_variable_B:
58
+ B_shape = (dim, dstate)
59
+ elif varBC_groups == 1:
60
+ B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
61
+ else:
62
+ B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
63
+ B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype,
64
+ requires_grad=True)
65
+ if not is_variable_C:
66
+ C_shape = (dim, dstate)
67
+ elif varBC_groups == 1:
68
+ C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
69
+ else:
70
+ C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
71
+ C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype,
72
+ requires_grad=True)
73
+ if has_D:
74
+ D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
75
+ else:
76
+ D = None
77
+ if has_z:
78
+ z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
79
+ else:
80
+ z = None
81
+ if has_delta_bias:
82
+ delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
83
+ else:
84
+ delta_bias = None
85
+ u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
86
+ delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_()
87
+ A_ref = A.detach().clone().requires_grad_()
88
+ B_ref = B.detach().clone().requires_grad_()
89
+ C_ref = C.detach().clone().requires_grad_()
90
+ D_ref = D.detach().clone().requires_grad_() if D is not None else None
91
+ z_ref = z.detach().clone().requires_grad_() if z is not None else None
92
+ u_ref = u.detach().clone().requires_grad_()
93
+ delta_ref = delta.detach().clone().requires_grad_()
94
+ delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
95
+ out, *rest = selective_scan_fn(
96
+ u, delta, A, B, C, D, z=z,
97
+ delta_bias=delta_bias, delta_softplus=delta_softplus,
98
+ return_last_state=return_last_state
99
+ )
100
+ if return_last_state:
101
+ state = rest[0]
102
+ out_ref, *rest = selective_scan_ref(
103
+ u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref,
104
+ delta_bias=delta_bias_ref, delta_softplus=delta_softplus,
105
+ return_last_state=return_last_state
106
+ )
107
+ if return_last_state:
108
+ state_ref = rest[0]
109
+ # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
110
+ # dt_u = delta * u
111
+
112
+ print(f'Output max diff: {(out - out_ref).abs().max().item()}')
113
+ print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
114
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
115
+ if return_last_state:
116
+ print(f'State max diff: {(state - state_ref).abs().max().item()}')
117
+ assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
118
+
119
+ g = torch.randn_like(out)
120
+ out_ref.backward(g)
121
+ out.backward(g)
122
+
123
+ print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}')
124
+ print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}')
125
+ print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
126
+ print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
127
+ print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
128
+ if has_D:
129
+ print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
130
+ if has_z:
131
+ print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}')
132
+ if has_delta_bias:
133
+ print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
134
+
135
+ assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
136
+ assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
137
+ assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
138
+ assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
139
+ atol=atolw if not is_variable_B else atol)
140
+ assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
141
+ atol=atolw if not is_variable_C else atol)
142
+ if has_D:
143
+ assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
144
+ if has_z:
145
+ assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw)
146
+ if has_delta_bias:
147
+ assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
148
+
149
+
150
+ @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
151
+ # @pytest.mark.parametrize('wtype', [torch.complex64])
152
+ # @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
153
+ @pytest.mark.parametrize('itype', [torch.float32])
154
+ # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
155
+ @pytest.mark.parametrize('seqlen', [128])
156
+ @pytest.mark.parametrize("is_variable_C", [False, True])
157
+ # @pytest.mark.parametrize("is_variable_C", [False])
158
+ @pytest.mark.parametrize("is_variable_B", [False, True])
159
+ # @pytest.mark.parametrize("is_variable_B", [True])
160
+ def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype):
161
+ device = 'cuda'
162
+ rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
163
+ if itype == torch.bfloat16:
164
+ rtol, atol = 3e-2, 5e-2
165
+ rtolw, atolw = (1e-3, 1e-3)
166
+ # If we have z, the errors on the weights seem higher
167
+ rtolw = max(rtolw, rtol)
168
+ atolw = max(atolw, atol)
169
+ # set seed
170
+ torch.random.manual_seed(0)
171
+ batch_size = 2
172
+ dim = 768
173
+ dstate = 8
174
+ dt_rank = 48
175
+ is_complex = wtype == torch.complex64
176
+ xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True)
177
+ conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True)
178
+ conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
179
+ x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate
180
+ * (1 if not is_complex else 2),
181
+ dim, device=device, dtype=itype, requires_grad=True)
182
+ delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True)
183
+ out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True)
184
+ out_proj_bias = None
185
+ A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
186
+ B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
187
+ if not is_variable_B else None)
188
+ C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
189
+ if not is_variable_C else None)
190
+ D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
191
+ delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
192
+ B_proj_bias = None
193
+ C_proj_bias = None
194
+ xz_ref = xz.detach().clone().requires_grad_()
195
+ conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_()
196
+ conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_()
197
+ x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_()
198
+ delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_()
199
+ out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_()
200
+ out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_()
201
+ if out_proj_bias is not None else None)
202
+ A_ref = A.detach().clone().requires_grad_()
203
+ B_ref = B.detach().clone().requires_grad_() if B is not None else None
204
+ C_ref = C.detach().clone().requires_grad_() if C is not None else None
205
+ D_ref = D.detach().clone().requires_grad_()
206
+ delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
207
+ out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
208
+ out_proj_weight, out_proj_bias,
209
+ A, B, C, D, delta_bias=delta_bias, delta_softplus=True)
210
+ out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref,
211
+ delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref,
212
+ A_ref, B_ref, C_ref, D_ref,
213
+ delta_bias=delta_bias_ref, delta_softplus=True)
214
+ # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
215
+ # dt_u = delta * u
216
+
217
+ print(f'Output max diff: {(out - out_ref).abs().max().item()}')
218
+ print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
219
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
220
+
221
+ g = torch.randn_like(out)
222
+ out_ref.backward(g)
223
+ out.backward(g)
224
+
225
+ print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}')
226
+ print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
227
+ if not is_variable_B:
228
+ print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
229
+ if not is_variable_C:
230
+ print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
231
+ print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
232
+ print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
233
+ print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}')
234
+ print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}')
235
+ print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}')
236
+ print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}')
237
+ print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}')
238
+
239
+ # assert torch.allclose(xz.grad, xz_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
240
+ # assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
241
+ # assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
242
+ # assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
243
+ # atol=atolw if not is_variable_B else atol)
244
+ # assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
245
+ # atol=atolw if not is_variable_C else atol)
246
+ # assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
247
+ # assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
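
For readers who want the semantics without tracing `selective_scan_ref`: the kernel under test implements the discretized SSM recurrence h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * u_t with output y_t = C_t * h_t + D * u_t. Below is a minimal sketch of that recurrence, assuming real-valued `A`, per-timestep `B`/`C` as in the `varBC_groups == 1` case, and no `z` gating, `delta_bias`, or softplus.

```python
# Hedged reference sketch (not the tested implementation): real A, time-varying B/C,
# no z gating, no delta_bias, no delta_softplus, no last-state return.
import torch

def naive_selective_scan(u, delta, A, B, C, D=None):
    """u, delta: (batch, dim, seqlen); A: (dim, dstate); B, C: (batch, dstate, seqlen); D: (dim,)."""
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    h = torch.zeros(batch, dim, dstate, dtype=torch.float32, device=u.device)
    ys = []
    for t in range(seqlen):
        dA = torch.exp(delta[:, :, t].unsqueeze(-1) * A)      # discretized state matrix, (batch, dim, dstate)
        dBu = delta[:, :, t].unsqueeze(-1) * B[:, None, :, t] * u[:, :, t].unsqueeze(-1)
        h = dA * h + dBu                                       # state recurrence
        ys.append((h * C[:, None, :, t]).sum(dim=-1))          # project the state with C_t
    y = torch.stack(ys, dim=-1)                                # (batch, dim, seqlen)
    return y if D is None else y + D.unsqueeze(-1) * u
```
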
mamba_install/tests/ops/triton/test_selective_state_update.py ADDED
@@ -0,0 +1,49 @@
1
+ # Copyright (C) 2023, Tri Dao.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import pytest
8
+
9
+ from einops import rearrange
10
+
11
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref
12
+
13
+
14
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
15
+ # @pytest.mark.parametrize('itype', [torch.float16])
16
+ @pytest.mark.parametrize("has_z", [False, True])
17
+ # @pytest.mark.parametrize('has_z', [True])
18
+ @pytest.mark.parametrize("dstate", [16, 32, 64])
19
+ # @pytest.mark.parametrize("dstate", [16])
20
+ @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
21
+ # @pytest.mark.parametrize("dim", [2048])
22
+ def test_selective_state_update(dim, dstate, has_z, itype):
23
+ device = "cuda"
24
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
25
+ if itype == torch.bfloat16:
26
+ rtol, atol = 1e-2, 5e-2
27
+ # set seed
28
+ torch.random.manual_seed(0)
29
+ batch_size = 2
30
+ state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
31
+ x = torch.randn(batch_size, dim, device=device, dtype=itype)
32
+ dt = torch.randn(batch_size, dim, device=device, dtype=itype)
33
+ dt_bias = torch.rand(dim, device=device) - 4.0
34
+ A = -torch.rand(dim, dstate, device=device) - 1.0
35
+ B = torch.randn(batch_size, dstate, device=device)
36
+ C = torch.randn(batch_size, dstate, device=device)
37
+ D = torch.randn(dim, device=device)
38
+ if has_z:
39
+ z = torch.randn_like(x)
40
+ else:
41
+ z = None
42
+ state_ref = state.detach().clone()
43
+ out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
44
+ out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
45
+
46
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
47
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
48
+ assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
49
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
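
For reference, the kernel tested here performs a single decoding-step version of the same recurrence (compare `Mamba.step` later in this commit, which falls back to exactly this math when the Triton kernel is unavailable). A minimal sketch under the shapes used above:

```python
# Hedged sketch of one SSM state update; state: (batch, dim, dstate), x/dt: (batch, dim),
# A: (dim, dstate), B/C: (batch, dstate), D/dt_bias: (dim,), z: (batch, dim) or None.
import torch
import torch.nn.functional as F

def naive_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)
    dA = torch.exp(dt.unsqueeze(-1) * A)             # (batch, dim, dstate)
    dB = dt.unsqueeze(-1) * B.unsqueeze(1)           # (batch, dim, dstate)
    state.copy_(state * dA + dB * x.unsqueeze(-1))   # in-place update, like the Triton kernel
    out = (state * C.unsqueeze(1)).sum(dim=-1)       # (batch, dim)
    if D is not None:
        out = out + D * x
    if z is not None:
        out = out * F.silu(z)
    return out
```
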
mamba_ssm/.DS_Store ADDED
Binary file (6.15 kB). View file
 
mamba_ssm/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ __version__ = "1.2.2"
2
+
3
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
4
+ from mamba_ssm.modules.mamba_simple import Mamba
5
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
mamba_ssm/models/__init__.py ADDED
File without changes
mamba_ssm/models/config_mamba.py ADDED
@@ -0,0 +1,15 @@
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class MambaConfig:
6
+
7
+ d_model: int = 2560
8
+ n_layer: int = 64
9
+ vocab_size: int = 50277
10
+ ssm_cfg: dict = field(default_factory=dict)
11
+ rms_norm: bool = True
12
+ residual_in_fp32: bool = True
13
+ fused_add_norm: bool = True
14
+ pad_vocab_size_multiple: int = 8
15
+ tie_embeddings: bool = True
mamba_ssm/models/mixer_seq_simple.py ADDED
@@ -0,0 +1,264 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+
3
+ import math
4
+ from functools import partial
5
+ import json
6
+ import os
7
+
8
+ from collections import namedtuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from mamba_ssm.models.config_mamba import MambaConfig
14
+ from mamba_ssm.modules.mamba_simple import Mamba, Block
15
+ from mamba_ssm.utils.generation import GenerationMixin
16
+ from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
17
+
18
+ try:
19
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
20
+ except ImportError:
21
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
22
+
23
+
24
+ def create_block(
25
+ d_model,
26
+ ssm_cfg=None,
27
+ norm_epsilon=1e-5,
28
+ rms_norm=False,
29
+ residual_in_fp32=False,
30
+ fused_add_norm=False,
31
+ layer_idx=None,
32
+ device=None,
33
+ dtype=None,
34
+ ):
35
+ if ssm_cfg is None:
36
+ ssm_cfg = {}
37
+ factory_kwargs = {"device": device, "dtype": dtype}
38
+ mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
39
+ norm_cls = partial(
40
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
41
+ )
42
+ block = Block(
43
+ d_model,
44
+ mixer_cls,
45
+ norm_cls=norm_cls,
46
+ fused_add_norm=fused_add_norm,
47
+ residual_in_fp32=residual_in_fp32,
48
+ )
49
+ block.layer_idx = layer_idx
50
+ return block
51
+
52
+
53
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
54
+ def _init_weights(
55
+ module,
56
+ n_layer,
57
+ initializer_range=0.02, # Now only used for embedding layer.
58
+ rescale_prenorm_residual=True,
59
+ n_residuals_per_layer=1, # Change to 2 if we have MLP
60
+ ):
61
+ if isinstance(module, nn.Linear):
62
+ if module.bias is not None:
63
+ if not getattr(module.bias, "_no_reinit", False):
64
+ nn.init.zeros_(module.bias)
65
+ elif isinstance(module, nn.Embedding):
66
+ nn.init.normal_(module.weight, std=initializer_range)
67
+
68
+ if rescale_prenorm_residual:
69
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
70
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
71
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
72
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
73
+ #
74
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
75
+ for name, p in module.named_parameters():
76
+ if name in ["out_proj.weight", "fc2.weight"]:
77
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
78
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
79
+ # We need to reinit p since this code could be called multiple times
80
+ # Having just p *= scale would repeatedly scale it down
81
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
82
+ with torch.no_grad():
83
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
84
+
85
+
86
+ class MixerModel(nn.Module):
87
+ def __init__(
88
+ self,
89
+ d_model: int,
90
+ n_layer: int,
91
+ vocab_size: int,
92
+ ssm_cfg=None,
93
+ norm_epsilon: float = 1e-5,
94
+ rms_norm: bool = False,
95
+ initializer_cfg=None,
96
+ fused_add_norm=False,
97
+ residual_in_fp32=False,
98
+ device=None,
99
+ dtype=None,
100
+ ) -> None:
101
+ factory_kwargs = {"device": device, "dtype": dtype}
102
+ super().__init__()
103
+ self.residual_in_fp32 = residual_in_fp32
104
+
105
+ self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
106
+
107
+ # We change the order of residual and layer norm:
108
+ # Instead of LN -> Attn / MLP -> Add, we do:
109
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
110
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
111
+ # This is for performance reason: we can fuse add + layer_norm.
112
+ self.fused_add_norm = fused_add_norm
113
+ if self.fused_add_norm:
114
+ if layer_norm_fn is None or rms_norm_fn is None:
115
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
116
+
117
+ self.layers = nn.ModuleList(
118
+ [
119
+ create_block(
120
+ d_model,
121
+ ssm_cfg=ssm_cfg,
122
+ norm_epsilon=norm_epsilon,
123
+ rms_norm=rms_norm,
124
+ residual_in_fp32=residual_in_fp32,
125
+ fused_add_norm=fused_add_norm,
126
+ layer_idx=i,
127
+ **factory_kwargs,
128
+ )
129
+ for i in range(n_layer)
130
+ ]
131
+ )
132
+
133
+ self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
134
+ d_model, eps=norm_epsilon, **factory_kwargs
135
+ )
136
+
137
+ self.apply(
138
+ partial(
139
+ _init_weights,
140
+ n_layer=n_layer,
141
+ **(initializer_cfg if initializer_cfg is not None else {}),
142
+ )
143
+ )
144
+
145
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
146
+ return {
147
+ i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
148
+ for i, layer in enumerate(self.layers)
149
+ }
150
+
151
+ def forward(self, input_ids, inference_params=None):
152
+ hidden_states = self.embedding(input_ids)
153
+ residual = None
154
+ for layer in self.layers:
155
+ hidden_states, residual = layer(
156
+ hidden_states, residual, inference_params=inference_params
157
+ )
158
+ if not self.fused_add_norm:
159
+ residual = (hidden_states + residual) if residual is not None else hidden_states
160
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
161
+ else:
162
+ # Set prenorm=False here since we don't need the residual
163
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
164
+ hidden_states = fused_add_norm_fn(
165
+ hidden_states,
166
+ self.norm_f.weight,
167
+ self.norm_f.bias,
168
+ eps=self.norm_f.eps,
169
+ residual=residual,
170
+ prenorm=False,
171
+ residual_in_fp32=self.residual_in_fp32,
172
+ )
173
+ return hidden_states
174
+
175
+
176
+ class MambaLMHeadModel(nn.Module, GenerationMixin):
177
+
178
+ def __init__(
179
+ self,
180
+ config: MambaConfig,
181
+ initializer_cfg=None,
182
+ device=None,
183
+ dtype=None,
184
+ ) -> None:
185
+ self.config = config
186
+ d_model = config.d_model
187
+ n_layer = config.n_layer
188
+ vocab_size = config.vocab_size
189
+ ssm_cfg = config.ssm_cfg
190
+ rms_norm = config.rms_norm
191
+ residual_in_fp32 = config.residual_in_fp32
192
+ fused_add_norm = config.fused_add_norm
193
+ pad_vocab_size_multiple = config.pad_vocab_size_multiple
194
+ factory_kwargs = {"device": device, "dtype": dtype}
195
+
196
+ super().__init__()
197
+ if vocab_size % pad_vocab_size_multiple != 0:
198
+ vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
199
+ self.backbone = MixerModel(
200
+ d_model=d_model,
201
+ n_layer=n_layer,
202
+ vocab_size=vocab_size,
203
+ ssm_cfg=ssm_cfg,
204
+ rms_norm=rms_norm,
205
+ initializer_cfg=initializer_cfg,
206
+ fused_add_norm=fused_add_norm,
207
+ residual_in_fp32=residual_in_fp32,
208
+ **factory_kwargs,
209
+ )
210
+ self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
211
+
212
+ # Initialize weights and apply final processing
213
+ self.apply(
214
+ partial(
215
+ _init_weights,
216
+ n_layer=n_layer,
217
+ **(initializer_cfg if initializer_cfg is not None else {}),
218
+ )
219
+ )
220
+ self.tie_weights()
221
+
222
+ def tie_weights(self):
223
+ if self.config.tie_embeddings:
224
+ self.lm_head.weight = self.backbone.embedding.weight
225
+
226
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
227
+ return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
228
+
229
+ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
230
+ """
231
+ "position_ids" is just to be compatible with Transformer generation. We don't use it.
232
+ num_last_tokens: if > 0, only return the logits for the last n tokens
233
+ """
234
+ hidden_states = self.backbone(input_ids, inference_params=inference_params)
235
+ if num_last_tokens > 0:
236
+ hidden_states = hidden_states[:, -num_last_tokens:]
237
+ lm_logits = self.lm_head(hidden_states)
238
+ CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
239
+ return CausalLMOutput(logits=lm_logits)
240
+
241
+ @classmethod
242
+ def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
243
+ config_data = load_config_hf(pretrained_model_name)
244
+ config = MambaConfig(**config_data)
245
+ model = cls(config, device=device, dtype=dtype, **kwargs)
246
+ model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
247
+ return model
248
+
249
+ def save_pretrained(self, save_directory):
250
+ """
251
+ Minimal implementation of save_pretrained for MambaLMHeadModel.
252
+ Save the model and its configuration file to a directory.
253
+ """
254
+ # Ensure save_directory exists
255
+ os.makedirs(save_directory, exist_ok=True)
256
+
257
+ # Save the model's state_dict
258
+ model_path = os.path.join(save_directory, 'pytorch_model.bin')
259
+ torch.save(self.state_dict(), model_path)
260
+
261
+ # Save the configuration of the model
262
+ config_path = os.path.join(save_directory, 'config.json')
263
+ with open(config_path, 'w') as f:
264
+ json.dump(self.config.__dict__, f)
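
Tying the pieces together, a minimal sketch that builds a toy `MambaConfig`, instantiates `MambaLMHeadModel`, and runs a forward pass. The tiny hyperparameters are illustrative only, and a CUDA device is assumed since the fused kernels are GPU-only:

```python
# Hedged sketch: toy configuration, CUDA device, Triton / causal-conv1d kernels importable.
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

config = MambaConfig(d_model=256, n_layer=4, vocab_size=1000)   # 1000 is already a multiple of 8
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)

input_ids = torch.randint(0, 1000, (2, 128), device="cuda")
logits = model(input_ids).logits        # CausalLMOutput namedtuple -> (batch, seqlen, vocab_size)
print(logits.shape)                     # torch.Size([2, 128, 1000])
```
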
mamba_ssm/modules/__init__.py ADDED
File without changes
mamba_ssm/modules/mamba_simple.py ADDED
@@ -0,0 +1,353 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+ from einops import rearrange, repeat
12
+
13
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
14
+
15
+ try:
16
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
17
+ except ImportError:
18
+ causal_conv1d_fn, causal_conv1d_update = None, None
19
+
20
+ try:
21
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
22
+ except ImportError:
23
+ selective_state_update = None
24
+
25
+ try:
26
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
27
+ except ImportError:
28
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
29
+
30
+
31
+ class Mamba(nn.Module):
32
+ def __init__(
33
+ self,
34
+ d_model,
35
+ d_state=16,
36
+ d_conv=4,
37
+ expand=2,
38
+ dt_rank="auto",
39
+ dt_min=0.001,
40
+ dt_max=0.1,
41
+ dt_init="random",
42
+ dt_scale=1.0,
43
+ dt_init_floor=1e-4,
44
+ conv_bias=True,
45
+ bias=False,
46
+ use_fast_path=True, # Fused kernel options
47
+ layer_idx=None,
48
+ device=None,
49
+ dtype=None,
50
+ ):
51
+ factory_kwargs = {"device": device, "dtype": dtype}
52
+ super().__init__()
53
+ self.d_model = d_model
54
+ self.d_state = d_state
55
+ self.d_conv = d_conv
56
+ self.expand = expand
57
+ self.d_inner = int(self.expand * self.d_model)
58
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
59
+ self.use_fast_path = use_fast_path
60
+ self.layer_idx = layer_idx
61
+
62
+ self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
63
+
64
+ self.conv1d = nn.Conv1d(
65
+ in_channels=self.d_inner,
66
+ out_channels=self.d_inner,
67
+ bias=conv_bias,
68
+ kernel_size=d_conv,
69
+ groups=self.d_inner,
70
+ padding=d_conv - 1,
71
+ **factory_kwargs,
72
+ )
73
+
74
+ self.activation = "silu"
75
+ self.act = nn.SiLU()
76
+
77
+ self.x_proj = nn.Linear(
78
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
79
+ )
80
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
81
+
82
+ # Initialize special dt projection to preserve variance at initialization
83
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
84
+ if dt_init == "constant":
85
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
86
+ elif dt_init == "random":
87
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
88
+ else:
89
+ raise NotImplementedError
90
+
91
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
92
+ dt = torch.exp(
93
+ torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
94
+ + math.log(dt_min)
95
+ ).clamp(min=dt_init_floor)
96
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
97
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
98
+ with torch.no_grad():
99
+ self.dt_proj.bias.copy_(inv_dt)
100
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
101
+ self.dt_proj.bias._no_reinit = True
102
+
103
+ # S4D real initialization
104
+ A = repeat(
105
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
106
+ "n -> d n",
107
+ d=self.d_inner,
108
+ ).contiguous()
109
+ A_log = torch.log(A) # Keep A_log in fp32
110
+ self.A_log = nn.Parameter(A_log)
111
+ self.A_log._no_weight_decay = True
112
+
113
+ # D "skip" parameter
114
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
115
+ self.D._no_weight_decay = True
116
+
117
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
118
+
119
+ def forward(self, hidden_states, inference_params=None):
120
+ """
121
+ hidden_states: (B, L, D)
122
+ Returns: same shape as hidden_states
123
+ """
124
+ batch, seqlen, dim = hidden_states.shape
125
+
126
+ conv_state, ssm_state = None, None
127
+ if inference_params is not None:
128
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
129
+ if inference_params.seqlen_offset > 0:
130
+ # The states are updated inplace
131
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
132
+ return out
133
+
134
+ # We do matmul and transpose BLH -> HBL at the same time
135
+ xz = rearrange(
136
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
137
+ "d (b l) -> b d l",
138
+ l=seqlen,
139
+ )
140
+ if self.in_proj.bias is not None:
141
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
142
+
143
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
144
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
145
+ if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
146
+ out = mamba_inner_fn(
147
+ xz,
148
+ self.conv1d.weight,
149
+ self.conv1d.bias,
150
+ self.x_proj.weight,
151
+ self.dt_proj.weight,
152
+ self.out_proj.weight,
153
+ self.out_proj.bias,
154
+ A,
155
+ None, # input-dependent B
156
+ None, # input-dependent C
157
+ self.D.float(),
158
+ delta_bias=self.dt_proj.bias.float(),
159
+ delta_softplus=True,
160
+ )
161
+ else:
162
+ x, z = xz.chunk(2, dim=1)
163
+ # Compute short convolution
164
+ if conv_state is not None:
165
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
166
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
167
+ conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
168
+ if causal_conv1d_fn is None:
169
+ x = self.act(self.conv1d(x)[..., :seqlen])
170
+ else:
171
+ assert self.activation in ["silu", "swish"]
172
+ x = causal_conv1d_fn(
173
+ x=x,
174
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
175
+ bias=self.conv1d.bias,
176
+ activation=self.activation,
177
+ )
178
+
179
+ # We're careful here about the layout, to avoid extra transposes.
180
+ # We want dt to have d as the slowest moving dimension
181
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
182
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
183
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
184
+ dt = self.dt_proj.weight @ dt.t()
185
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
186
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
187
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
188
+ assert self.activation in ["silu", "swish"]
189
+ y = selective_scan_fn(
190
+ x,
191
+ dt,
192
+ A,
193
+ B,
194
+ C,
195
+ self.D.float(),
196
+ z=z,
197
+ delta_bias=self.dt_proj.bias.float(),
198
+ delta_softplus=True,
199
+ return_last_state=ssm_state is not None,
200
+ )
201
+ if ssm_state is not None:
202
+ y, last_state = y
203
+ ssm_state.copy_(last_state)
204
+ y = rearrange(y, "b d l -> b l d")
205
+ out = self.out_proj(y)
206
+ return out
207
+
208
+ def step(self, hidden_states, conv_state, ssm_state):
209
+ dtype = hidden_states.dtype
210
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
211
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
212
+ x, z = xz.chunk(2, dim=-1) # (B D)
213
+
214
+ # Conv step
215
+ if causal_conv1d_update is None:
216
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
217
+ conv_state[:, :, -1] = x
218
+ x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
219
+ if self.conv1d.bias is not None:
220
+ x = x + self.conv1d.bias
221
+ x = self.act(x).to(dtype=dtype)
222
+ else:
223
+ x = causal_conv1d_update(
224
+ x,
225
+ conv_state,
226
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
227
+ self.conv1d.bias,
228
+ self.activation,
229
+ )
230
+
231
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
232
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
233
+ # Don't add dt_bias here
234
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
235
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
236
+
237
+ # SSM step
238
+ if selective_state_update is None:
239
+ # Discretize A and B
240
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
241
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
242
+ dB = torch.einsum("bd,bn->bdn", dt, B)
243
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
244
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
245
+ y = y + self.D.to(dtype) * x
246
+ y = y * self.act(z) # (B D)
247
+ else:
248
+ y = selective_state_update(
249
+ ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
250
+ )
251
+
252
+ out = self.out_proj(y)
253
+ return out.unsqueeze(1), conv_state, ssm_state
254
+
255
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
256
+ device = self.out_proj.weight.device
257
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
258
+ conv_state = torch.zeros(
259
+ batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
260
+ )
261
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
262
+ # ssm_dtype = torch.float32
263
+ ssm_state = torch.zeros(
264
+ batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
265
+ )
266
+ return conv_state, ssm_state
267
+
268
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
269
+ assert self.layer_idx is not None
270
+ if self.layer_idx not in inference_params.key_value_memory_dict:
271
+ batch_shape = (batch_size,)
272
+ conv_state = torch.zeros(
273
+ batch_size,
274
+ self.d_model * self.expand,
275
+ self.d_conv,
276
+ device=self.conv1d.weight.device,
277
+ dtype=self.conv1d.weight.dtype,
278
+ )
279
+ ssm_state = torch.zeros(
280
+ batch_size,
281
+ self.d_model * self.expand,
282
+ self.d_state,
283
+ device=self.dt_proj.weight.device,
284
+ dtype=self.dt_proj.weight.dtype,
285
+ # dtype=torch.float32,
286
+ )
287
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
288
+ else:
289
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
290
+ # TODO: What if batch size changes between generation, and we reuse the same states?
291
+ if initialize_states:
292
+ conv_state.zero_()
293
+ ssm_state.zero_()
294
+ return conv_state, ssm_state
295
+
296
+
297
+ class Block(nn.Module):
298
+ def __init__(
299
+ self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
300
+ ):
301
+ """
302
+ Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.
303
+
304
+ This Block has a slightly different structure compared to a regular
305
+ prenorm Transformer block.
306
+ The standard block is: LN -> MHA/MLP -> Add.
307
+ [Ref: https://arxiv.org/abs/2002.04745]
308
+ Here we have: Add -> LN -> Mixer, returning both
309
+ the hidden_states (output of the mixer) and the residual.
310
+ This is purely for performance reasons, as we can fuse add and LayerNorm.
311
+ The residual needs to be provided (except for the very first block).
312
+ """
313
+ super().__init__()
314
+ self.residual_in_fp32 = residual_in_fp32
315
+ self.fused_add_norm = fused_add_norm
316
+ self.mixer = mixer_cls(dim)
317
+ self.norm = norm_cls(dim)
318
+ if self.fused_add_norm:
319
+ assert RMSNorm is not None, "RMSNorm import fails"
320
+ assert isinstance(
321
+ self.norm, (nn.LayerNorm, RMSNorm)
322
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
323
+
324
+ def forward(
325
+ self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
326
+ ):
327
+ r"""Pass the input through the encoder layer.
328
+
329
+ Args:
330
+ hidden_states: the sequence to the encoder layer (required).
331
+ residual: hidden_states = Mixer(LN(residual))
332
+ """
333
+ if not self.fused_add_norm:
334
+ residual = (hidden_states + residual) if residual is not None else hidden_states
335
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
336
+ if self.residual_in_fp32:
337
+ residual = residual.to(torch.float32)
338
+ else:
339
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
340
+ hidden_states, residual = fused_add_norm_fn(
341
+ hidden_states,
342
+ self.norm.weight,
343
+ self.norm.bias,
344
+ residual=residual,
345
+ prenorm=True,
346
+ residual_in_fp32=self.residual_in_fp32,
347
+ eps=self.norm.eps,
348
+ )
349
+ hidden_states = self.mixer(hidden_states, inference_params=inference_params)
350
+ return hidden_states, residual
351
+
352
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
353
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
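
As a closing note on the `dt_proj.bias` initialization above: it uses the softplus inverse `dt + log(-expm1(-dt))`, so that `F.softplus(dt_proj.bias)` lands back in `[dt_min, dt_max]`. A small self-contained check using the default values from `Mamba.__init__`:

```python
# Hedged sketch verifying the inverse-softplus round trip used for dt_proj.bias.
import math
import torch
import torch.nn.functional as F

dt_min, dt_max, dt_init_floor = 0.001, 0.1, 1e-4          # defaults from Mamba.__init__
dt = torch.exp(
    torch.rand(16) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
).clamp(min=dt_init_floor)
inv_dt = dt + torch.log(-torch.expm1(-dt))                 # inverse of softplus
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-6)
print(dt.min().item(), dt.max().item())                    # both inside [dt_min, dt_max]
```
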