#include <torch/library.h>

#include "pytorch_shim.h"
#include "registration.h"
#include "torch_binding.h"

// Registers the FlashAttention custom ops with the PyTorch dispatcher.
//
// Schema annotation conventions (PyTorch custom-op schema language):
//   "Tensor!"  — the op may mutate this tensor in place.
//   "Tensor?"  — optional tensor (may be None on the Python side).
//   "Tensor!?" — optional tensor that is mutated when provided.
// `make_pytorch_shim` adapts the project's native function signatures
// (declared in torch_binding.h) to the calling convention the dispatcher
// expects; implementations are registered for the CUDA backend only.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Attention forward pass. Supports paged KV-cache (page_table), in-place
  // cache append (k_new/v_new), rotary embedding, variable-length batches
  // (cu_seqlens_* / seqused_*), FP8 descaling, sliding-window/causal
  // masking, softcap, and split-KV scheduling.
  ops.def("fwd(Tensor! q,"
          " Tensor k,"
          " Tensor v,"
          " Tensor? k_new,"
          " Tensor? v_new,"
          " Tensor? q_v,"
          " Tensor!? out,"
          " Tensor? cu_seqlens_q,"
          " Tensor? cu_seqlens_k,"
          " Tensor? cu_seqlens_k_new,"
          " Tensor? seqused_q,"
          " Tensor? seqused_k,"
          " int? max_seqlen_q,"
          " int? max_seqlen_k,"
          " Tensor? page_table,"
          " Tensor? kv_batch_idx,"
          " Tensor? leftpad_k,"
          " Tensor? rotary_cos,"
          " Tensor? rotary_sin,"
          " Tensor? seqlens_rotary,"
          " Tensor? q_descale,"
          " Tensor? k_descale,"
          " Tensor? v_descale,"
          " float softmax_scale,"
          " bool is_causal,"
          " int window_size_left,"
          " int window_size_right,"
          " float softcap,"
          " bool is_rotary_interleaved,"
          " Tensor? scheduler_metadata,"
          " int num_splits,"
          " bool? pack_gqa,"
          " int sm_margin,"
          " Tensor? s_aux_) -> Tensor[]");
  ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));

  // Attention backward pass: computes dq/dk/dv from dout given the saved
  // forward outputs (out, softmax_lse). Gradient buffers are optional and
  // written in place when supplied.
  ops.def("bwd(Tensor dout,"
          " Tensor q,"
          " Tensor k,"
          " Tensor v,"
          " Tensor out,"
          " Tensor softmax_lse,"
          " Tensor!? dq,"
          " Tensor!? dk,"
          " Tensor!? dv,"
          " Tensor? cu_seqlens_q,"
          " Tensor? cu_seqlens_k,"
          " Tensor? seqused_q,"
          " Tensor? seqused_k,"
          " int? max_seqlen_q,"
          " int? max_seqlen_k,"
          " float softmax_scale,"
          " bool is_causal,"
          " int window_size_left,"
          " int window_size_right,"
          " float softcap,"
          " bool deterministic,"
          " int sm_margin) -> Tensor[]");
  ops.impl("bwd", torch::kCUDA, make_pytorch_shim(&mha_bwd));

  // Combines per-split partial attention outputs and log-sum-exp values
  // (produced by split-KV forward) into the final output.
  ops.def("fwd_combine(Tensor out_partial,"
          " Tensor lse_partial,"
          " Tensor!? out,"
          " ScalarType? out_dtype) -> Tensor[]");
  ops.impl("fwd_combine", torch::kCUDA, make_pytorch_shim(&mha_combine));

  // Precomputes scheduler metadata for the forward kernel so that a
  // subsequent fwd call (via its scheduler_metadata argument) can skip
  // scheduling work — useful with CUDA graphs / repeated identical shapes.
  ops.def("get_scheduler_metadata("
          " int batch_size,"
          " int max_seqlen_q,"
          " int max_seqlen_k,"
          " int num_heads,"
          " int num_heads_k,"
          " int headdim,"
          " int headdim_v,"
          " ScalarType qkv_dtype,"
          " Tensor seqused_k,"
          " Tensor? cu_seqlens_q,"
          " Tensor? cu_seqlens_k,"
          " Tensor? cu_seqlens_k_new,"
          " Tensor? seqused_q,"
          " Tensor? leftpad_k,"
          " int? page_size,"
          " int max_seqlen_k_new,"
          " bool is_causal,"
          " int window_size_left,"
          " int window_size_right,"
          " bool has_softcap,"
          " int num_splits,"
          " bool? pack_gqa,"
          " int sm_margin) -> Tensor");
  ops.impl("get_scheduler_metadata", torch::kCUDA,
           make_pytorch_shim(&mha_fwd_get_scheduler_metadata));
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|