#include <torch/library.h>

#include "pytorch_shim.h"
#include "registration.h"
#include "torch_binding.h"

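// Registers the flash-attention custom ops with the PyTorch dispatcher under
// TORCH_EXTENSION_NAME. Schema notation: "Tensor!" marks an argument the op
// may mutate in place, "Tensor?" marks an optional argument, and "Tensor!?"
// combines the two. make_pytorch_shim (see pytorch_shim.h) adapts the mha_*
// C++ signatures to the form the dispatcher expects.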
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
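  // mha_fwd: attention forward. The optional arguments cover variable-length
  // batches (cu_seqlens_*/seqused_*), paged KV caches (page_table), appended
  // KV (k_new/v_new), rotary embeddings, FP8 descale factors, sliding
  // windows, and split-KV execution (num_splits, scheduler_metadata).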
  ops.def("fwd(Tensor!  q,"
          "    Tensor   k,"
          "    Tensor   v,"
          "    Tensor?  k_new,"
          "    Tensor?  v_new,"
          "    Tensor?  q_v,"
          "    Tensor!? out,"
          "    Tensor?  cu_seqlens_q,"
          "    Tensor?  cu_seqlens_k,"
          "    Tensor?  cu_seqlens_k_new,"
          "    Tensor?  seqused_q,"
          "    Tensor?  seqused_k,"
          "    int?     max_seqlen_q,"
          "    int?     max_seqlen_k,"
          "    Tensor?  page_table,"
          "    Tensor?  kv_batch_idx,"
          "    Tensor?  leftpad_k,"
          "    Tensor?  rotary_cos,"
          "    Tensor?  rotary_sin,"
          "    Tensor?  seqlens_rotary,"
          "    Tensor?  q_descale,"
          "    Tensor?  k_descale,"
          "    Tensor?  v_descale,"
          "    float    softmax_scale,"
          "    bool     is_causal,"
          "    int      window_size_left,"
          "    int      window_size_right,"
          "    float    softcap,"
          "    bool     is_rotary_interleaved,"
          "    Tensor?  scheduler_metadata,"
          "    int      num_splits,"
          "    bool?    pack_gqa,"
          "    int      sm_margin,"
          "    Tensor?  s_aux_) -> Tensor[]");
  ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));

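  // mha_bwd: attention backward; when dq/dk/dv are provided they are written
  // in place (hence "Tensor!?" in the schema).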
  ops.def("bwd(Tensor   dout,"
          "    Tensor   q,"
          "    Tensor   k,"
          "    Tensor   v,"
          "    Tensor   out,"
          "    Tensor   softmax_lse,"
          "    Tensor!? dq,"
          "    Tensor!? dk,"
          "    Tensor!? dv,"
          "    Tensor?  cu_seqlens_q,"
          "    Tensor?  cu_seqlens_k,"
          "    Tensor?  seqused_q,"
          "    Tensor?  seqused_k,"
          "    int?     max_seqlen_q,"
          "    int?     max_seqlen_k,"
          "    float    softmax_scale,"
          "    bool     is_causal,"
          "    int      window_size_left,"
          "    int      window_size_right,"
          "    float    softcap,"
          "    bool     deterministic,"
          "    int      sm_margin) -> Tensor[]");
  ops.impl("bwd", torch::kCUDA, make_pytorch_shim(&mha_bwd));

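  // mha_combine: reduces partial outputs and their log-sum-exp values (e.g.
  // from split-KV forward passes) into a single attention output.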
  ops.def("fwd_combine(Tensor       out_partial,"
          "            Tensor       lse_partial,"
          "            Tensor!?     out,"
          "            ScalarType?  out_dtype) -> Tensor[]");
  ops.impl("fwd_combine", torch::kCUDA, make_pytorch_shim(&mha_combine));

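  // mha_fwd_get_scheduler_metadata: precomputes tile-scheduler metadata for a
  // given problem shape; the returned tensor can be passed back to fwd
  // through its scheduler_metadata argument.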
  ops.def("get_scheduler_metadata("
        "    int      batch_size,"
        "    int      max_seqlen_q,"
        "    int      max_seqlen_k,"
        "    int      num_heads,"
        "    int      num_heads_k,"
        "    int      headdim,"
        "    int      headdim_v,"
        "    ScalarType qkv_dtype,"
        "    Tensor   seqused_k,"
        "    Tensor?  cu_seqlens_q,"
        "    Tensor?  cu_seqlens_k,"
        "    Tensor?  cu_seqlens_k_new,"
        "    Tensor?  seqused_q,"
        "    Tensor?  leftpad_k,"
        "    int?     page_size,"
        "    int      max_seqlen_k_new," // 0 means we're not appending new KV
        "    bool     is_causal,"
        "    int      window_size_left,"
        "    int      window_size_right,"
        "    bool     has_softcap,"
        "    int      num_splits,"
        "    bool?    pack_gqa,"
        "    int      sm_margin) -> Tensor");
  ops.impl("get_scheduler_metadata", torch::kCUDA, make_pytorch_shim(&mha_fwd_get_scheduler_metadata));
}

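// REGISTER_EXTENSION (from registration.h) emits the module init hook so the
// compiled library is importable as a Python extension.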
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)