#include #include "pytorch_shim.h" #include "registration.h" #include "torch_binding.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("fwd(Tensor! q," " Tensor k," " Tensor v," " Tensor? k_new," " Tensor? v_new," " Tensor? q_v," " Tensor!? out," " Tensor? cu_seqlens_q," " Tensor? cu_seqlens_k," " Tensor? cu_seqlens_k_new," " Tensor? seqused_q," " Tensor? seqused_k," " int? max_seqlen_q," " int? max_seqlen_k," " Tensor? page_table," " Tensor? kv_batch_idx," " Tensor? leftpad_k," " Tensor? rotary_cos," " Tensor? rotary_sin," " Tensor? seqlens_rotary," " Tensor? q_descale," " Tensor? k_descale," " Tensor? v_descale," " float softmax_scale," " bool is_causal," " int window_size_left," " int window_size_right," " float softcap," " bool is_rotary_interleaved," " Tensor? scheduler_metadata," " int num_splits," " bool? pack_gqa," " int sm_margin," " Tensor? s_aux_) -> Tensor[]"); ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd)); ops.def("bwd(Tensor dout," " Tensor q," " Tensor k," " Tensor v," " Tensor out," " Tensor softmax_lse," " Tensor!? dq," " Tensor!? dk," " Tensor!? dv," " Tensor? cu_seqlens_q," " Tensor? cu_seqlens_k," " Tensor? seqused_q," " Tensor? seqused_k," " int? max_seqlen_q," " int? max_seqlen_k," " float softmax_scale," " bool is_causal," " int window_size_left," " int window_size_right," " float softcap," " bool deterministic," " int sm_margin) -> Tensor[]"); ops.impl("bwd", torch::kCUDA, make_pytorch_shim(&mha_bwd)); ops.def("fwd_combine(Tensor out_partial," " Tensor lse_partial," " Tensor!? out," " ScalarType? out_dtype) -> Tensor[]"); ops.impl("fwd_combine", torch::kCUDA, make_pytorch_shim(&mha_combine)); ops.def("get_scheduler_metadata(" " int batch_size," " int max_seqlen_q," " int max_seqlen_k," " int num_heads," " int num_heads_k," " int headdim," " int headdim_v," " ScalarType qkv_dtype," " Tensor seqused_k," " Tensor? cu_seqlens_q," " Tensor? cu_seqlens_k," " Tensor? cu_seqlens_k_new," " Tensor? seqused_q," " Tensor? leftpad_k," " int? page_size," " int max_seqlen_k_new," // 0 means we're not appending new KV " bool is_causal," " int window_size_left," " int window_size_right," " bool has_softcap," " int num_splits," " bool? pack_gqa," " int sm_margin) -> Tensor"); ops.impl("get_scheduler_metadata", torch::kCUDA, make_pytorch_shim(&mha_fwd_get_scheduler_metadata)); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME)