// vllm-flash-attn3: torch-ext/torch_binding.cpp
#include <torch/library.h>
#include "pytorch_shim.h"
#include "registration.h"
#include "torch_binding.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
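// Forward attention. The optional tensors cover variable-length batches
// (cu_seqlens_*/seqused_*), paged KV caches (page_table/kv_batch_idx),
// in-place appends to the KV cache (k_new/v_new), rotary embeddings,
// FP8 descaling factors, sliding-window and causal masking, softcap,
// and split-KV scheduling (num_splits/scheduler_metadata).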
ops.def("fwd(Tensor! q,"
" Tensor k,"
" Tensor v,"
" Tensor? k_new,"
" Tensor? v_new,"
" Tensor? q_v,"
" Tensor!? out,"
" Tensor? cu_seqlens_q,"
" Tensor? cu_seqlens_k,"
" Tensor? cu_seqlens_k_new,"
" Tensor? seqused_q,"
" Tensor? seqused_k,"
" int? max_seqlen_q,"
" int? max_seqlen_k,"
" Tensor? page_table,"
" Tensor? kv_batch_idx,"
" Tensor? leftpad_k,"
" Tensor? rotary_cos,"
" Tensor? rotary_sin,"
" Tensor? seqlens_rotary,"
" Tensor? q_descale,"
" Tensor? k_descale,"
" Tensor? v_descale,"
" float softmax_scale,"
" bool is_causal,"
" int window_size_left,"
" int window_size_right,"
" float softcap,"
" bool is_rotary_interleaved,"
" Tensor? scheduler_metadata,"
" int num_splits,"
" bool? pack_gqa,"
" int sm_margin,"
" Tensor? s_aux_) -> Tensor[]");
ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
ops.def("bwd(Tensor dout,"
" Tensor q,"
" Tensor k,"
" Tensor v,"
" Tensor out,"
" Tensor softmax_lse,"
" Tensor!? dq,"
" Tensor!? dk,"
" Tensor!? dv,"
" Tensor? cu_seqlens_q,"
" Tensor? cu_seqlens_k,"
" Tensor? seqused_q,"
" Tensor? seqused_k,"
" int? max_seqlen_q,"
" int? max_seqlen_k,"
" float softmax_scale,"
" bool is_causal,"
" int window_size_left,"
" int window_size_right,"
" float softcap,"
" bool deterministic,"
" int sm_margin) -> Tensor[]");
ops.impl("bwd", torch::kCUDA, make_pytorch_shim(&mha_bwd));
ops.def("fwd_combine(Tensor out_partial,"
" Tensor lse_partial,"
" Tensor!? out,"
" ScalarType? out_dtype) -> Tensor[]");
ops.impl("fwd_combine", torch::kCUDA, make_pytorch_shim(&mha_combine));
ops.def("get_scheduler_metadata("
" int batch_size,"
" int max_seqlen_q,"
" int max_seqlen_k,"
" int num_heads,"
" int num_heads_k,"
" int headdim,"
" int headdim_v,"
" ScalarType qkv_dtype,"
" Tensor seqused_k,"
" Tensor? cu_seqlens_q,"
" Tensor? cu_seqlens_k,"
" Tensor? cu_seqlens_k_new,"
" Tensor? seqused_q,"
" Tensor? leftpad_k,"
" int? page_size,"
" int max_seqlen_k_new," // 0 means we're not appending new KV
" bool is_causal,"
" int window_size_left,"
" int window_size_right,"
" bool has_softcap,"
" int num_splits,"
" bool? pack_gqa,"
" int sm_margin) -> Tensor");
ops.impl("get_scheduler_metadata", torch::kCUDA, make_pytorch_shim(&mha_fwd_get_scheduler_metadata));
}
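
// REGISTER_EXTENSION emits the PyInit_* entry point for
// TORCH_EXTENSION_NAME so the shared library is importable as a Python
// module and the ops above resolve under torch.ops.<extension name>.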
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
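
// A minimal calling sketch from Python. The namespace `_vllm_fa3_C` below
// is hypothetical (the real one is set by TORCH_EXTENSION_NAME at build
// time), and the return layout is assumed to be [out, softmax_lse, ...]
// per the Tensor[] schema. No schema argument declares a default, so all
// optionals must be passed explicitly (as None):
//
//   import torch
//
//   fa3 = torch.ops._vllm_fa3_C  # hypothetical extension name
//   q = torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
//   k = torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
//   v = torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
//   out, softmax_lse, *rest = fa3.fwd(
//       q, k, v,
//       k_new=None, v_new=None, q_v=None, out=None,
//       cu_seqlens_q=None, cu_seqlens_k=None, cu_seqlens_k_new=None,
//       seqused_q=None, seqused_k=None,
//       max_seqlen_q=None, max_seqlen_k=None,
//       page_table=None, kv_batch_idx=None, leftpad_k=None,
//       rotary_cos=None, rotary_sin=None, seqlens_rotary=None,
//       q_descale=None, k_descale=None, v_descale=None,
//       softmax_scale=q.shape[-1] ** -0.5,
//       is_causal=True,
//       window_size_left=-1, window_size_right=-1,  # -1: no sliding window
//       softcap=0.0, is_rotary_interleaved=False,
//       scheduler_metadata=None, num_splits=0,      # 0: choose splits automatically
//       pack_gqa=None, sm_margin=0, s_aux_=None)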