#include <torch/library.h>
#include "pytorch_shim.h"
#include "registration.h"
#include "torch_binding.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
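  // Attention forward pass. The optional arguments cover variable-length
  // batches (cu_seqlens_*/seqused_*/max_seqlen_*), paged KV caches
  // (page_table), appending new KV (k_new/v_new), rotary embeddings, and
  // FP8 descale factors. Returns a list of tensors, headed by the attention
  // output and the softmax log-sum-exp.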
ops.def("fwd(Tensor! q,"
" Tensor k,"
" Tensor v,"
" Tensor? k_new,"
" Tensor? v_new,"
" Tensor? q_v,"
" Tensor!? out,"
" Tensor? cu_seqlens_q,"
" Tensor? cu_seqlens_k,"
" Tensor? cu_seqlens_k_new,"
" Tensor? seqused_q,"
" Tensor? seqused_k,"
" int? max_seqlen_q,"
" int? max_seqlen_k,"
" Tensor? page_table,"
" Tensor? kv_batch_idx,"
" Tensor? leftpad_k,"
" Tensor? rotary_cos,"
" Tensor? rotary_sin,"
" Tensor? seqlens_rotary,"
" Tensor? q_descale,"
" Tensor? k_descale,"
" Tensor? v_descale,"
" float softmax_scale,"
" bool is_causal,"
" int window_size_left,"
" int window_size_right,"
" float softcap,"
" bool is_rotary_interleaved,"
" Tensor? scheduler_metadata,"
" int num_splits,"
" bool? pack_gqa,"
" int sm_margin,"
" Tensor? s_aux_) -> Tensor[]");
ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
ops.def("bwd(Tensor dout,"
" Tensor q,"
" Tensor k,"
" Tensor v,"
" Tensor out,"
" Tensor softmax_lse,"
" Tensor!? dq,"
" Tensor!? dk,"
" Tensor!? dv,"
" Tensor? cu_seqlens_q,"
" Tensor? cu_seqlens_k,"
" Tensor? seqused_q,"
" Tensor? seqused_k,"
" int? max_seqlen_q,"
" int? max_seqlen_k,"
" float softmax_scale,"
" bool is_causal,"
" int window_size_left,"
" int window_size_right,"
" float softcap,"
" bool deterministic,"
" int sm_margin) -> Tensor[]");
ops.impl("bwd", torch::kCUDA, make_pytorch_shim(&mha_bwd));
ops.def("fwd_combine(Tensor out_partial,"
" Tensor lse_partial,"
" Tensor!? out,"
" ScalarType? out_dtype) -> Tensor[]");
ops.impl("fwd_combine", torch::kCUDA, make_pytorch_shim(&mha_combine));
ops.def("get_scheduler_metadata("
" int batch_size,"
" int max_seqlen_q,"
" int max_seqlen_k,"
" int num_heads,"
" int num_heads_k,"
" int headdim,"
" int headdim_v,"
" ScalarType qkv_dtype,"
" Tensor seqused_k,"
" Tensor? cu_seqlens_q,"
" Tensor? cu_seqlens_k,"
" Tensor? cu_seqlens_k_new,"
" Tensor? seqused_q,"
" Tensor? leftpad_k,"
" int? page_size,"
" int max_seqlen_k_new," // 0 means we're not appending new KV
" bool is_causal,"
" int window_size_left,"
" int window_size_right,"
" bool has_softcap,"
" int num_splits,"
" bool? pack_gqa,"
" int sm_margin) -> Tensor");
ops.impl("get_scheduler_metadata", torch::kCUDA, make_pytorch_shim(&mha_fwd_get_scheduler_metadata));
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
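
// A hedged sketch of Python-side usage, assuming the extension builds as a
// module named "_flash_attn3"; the actual name comes from TORCH_EXTENSION_NAME
// at build time. The schema declares no defaults, so every argument must be
// supplied. Argument conventions (-1 disables sliding windows, softcap 0.0
// disables soft-capping, num_splits 0 lets the kernel choose) follow the
// FlashAttention API and are assumptions here, not guarantees of this file.
//
//   import torch
//   import _flash_attn3  # importing registers torch.ops._flash_attn3.*
//
//   # (batch, seqlen, num_heads, head_dim) bf16 CUDA tensors
//   q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
//   k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
//   v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
//
//   out, softmax_lse, *rest = torch.ops._flash_attn3.fwd(
//       q, k, v,
//       None, None, None, None,        # k_new, v_new, q_v, out
//       None, None, None, None, None,  # cu_seqlens_q/k/k_new, seqused_q/k
//       None, None,                    # max_seqlen_q, max_seqlen_k
//       None, None, None,              # page_table, kv_batch_idx, leftpad_k
//       None, None, None,              # rotary_cos, rotary_sin, seqlens_rotary
//       None, None, None,              # q_descale, k_descale, v_descale
//       q.shape[-1] ** -0.5, True,     # softmax_scale, is_causal
//       -1, -1, 0.0, False,            # window_size_left/right, softcap, is_rotary_interleaved
//       None, 0, None, 0, None)        # scheduler_metadata, num_splits, pack_gqa, sm_margin, s_aux_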