// Operator registration for the attention forward ops exposed by this
// extension. Each op is declared with a schema string (`ops.def`) and then
// bound to its CUDA implementation (`ops.impl`).

#include <torch/library.h>

#include "registration.h"

#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Forward pass over q/k/v tensors.
  // NOTE(review): the schema below must stay in sync with the C++ signature
  // of `mha_fwd` (presumably declared in torch_binding.h) -- confirm when
  // changing either side.
  ops.def(
      "mha_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, "
      "Tensor? alibi_slopes_, float p_dropout, float softmax_scale, "
      "bool is_causal, int window_size_left, int window_size_right, "
      "float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
  // Only a CUDA dispatch entry is registered for this op here.
  ops.impl("mha_fwd", torch::kCUDA, &mha_fwd);

  // Variable-length variant: takes `cu_seqlens_q` / `cu_seqlens_k` tensors
  // and `max_seqlen_q` / `max_seqlen_k` integers in addition to q/k/v.
  // NOTE(review): unlike `mha_fwd`, this schema has no `alibi_slopes_`
  // argument -- confirm that matches the `mha_varlen_fwd` declaration.
  ops.def(
      "mha_varlen_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, "
      "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
      "int max_seqlen_q, int max_seqlen_k, "
      "float p_dropout, float softmax_scale, bool is_causal, "
      "int window_size_left, int window_size_right, "
      "float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
  // Only a CUDA dispatch entry is registered for this op here.
  ops.impl("mha_varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
}

// Module-level registration hook; the macro is presumably provided by
// registration.h -- TODO confirm.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)