#include <torch/library.h>
#include "registration.h"
#include "torch_binding.h"
// TODO: Register the remaining functions from the reference pybind11 module
// below (mha_bwd, mha_varlen_bwd, mha_fwd_kvcache) with the dispatcher.
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// m.doc() = "FlashAttention";
// m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
// m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass (variable length)");
// m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
// m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass (variable length)");
// m.def("fwd_kvcache", &FLASH_NAMESPACE::mha_fwd_kvcache, "Forward pass, with KV-cache");
// }
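// Register the FlashAttention forward ops with the PyTorch dispatcher:
// ops.def declares each op's schema string, ops.impl binds its CUDA implementation.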
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("mha_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
ops.impl("mha_fwd", torch::kCUDA, &mha_fwd);
ops.def("mha_varlen_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
ops.impl("mha_varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
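// Illustrative usage sketch only: the registered ops become callable from
// Python through torch.ops under the namespace chosen by TORCH_EXTENSION_NAME
// at build time (shown here, as an assumption, as "flash_attn"):
//
//   outs = torch.ops.flash_attn.mha_fwd(
//       q, k, v, None, None, 0.0, softmax_scale, True, -1, -1, 0.0, False, None)
//
// The positional arguments must follow the schema string declared above.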