|
#pragma once |
|
|
|
#include <optional> |
|
#include <vector> |
|
|
|
#include <torch/torch.h> |
|
|
|
std::vector<at::Tensor> |
|
mha_fwd(at::Tensor &q, |
|
const at::Tensor &k, |
|
const at::Tensor &v, |
|
std::optional<const at::Tensor> &k_new_, |
|
std::optional<const at::Tensor> &v_new_, |
|
std::optional<const at::Tensor> &q_v_, |
|
std::optional<at::Tensor> &out_, |
|
std::optional<const at::Tensor> &cu_seqlens_q_, |
|
std::optional<const at::Tensor> &cu_seqlens_k_, |
|
std::optional<const at::Tensor> &cu_seqlens_k_new_, |
|
std::optional<const at::Tensor> &seqused_q_, |
|
std::optional<const at::Tensor> &seqused_k_, |
|
std::optional<int> max_seqlen_q_, |
|
|
|
std::optional<int> max_seqlen_k_, |
|
std::optional<const at::Tensor> &page_table_, |
|
std::optional<const at::Tensor> &kv_batch_idx_, |
|
std::optional<const at::Tensor> &leftpad_k_, |
|
std::optional<const at::Tensor> &rotary_cos_, |
|
std::optional<const at::Tensor> &rotary_sin_, |
|
std::optional<const at::Tensor> &seqlens_rotary_, |
|
std::optional<at::Tensor> &q_descale_, |
|
std::optional<at::Tensor> &k_descale_, |
|
std::optional<at::Tensor> &v_descale_, |
|
float const softmax_scale, |
|
bool is_causal, |
|
int window_size_left, |
|
int window_size_right, |
|
float const softcap, |
|
bool const is_rotary_interleaved, |
|
std::optional<at::Tensor> &scheduler_metadata_, |
|
int num_splits, |
|
std::optional<bool> pack_gqa_, |
|
int const sm_margin, |
|
std::optional<const at::Tensor> &s_aux_ |
|
); |
|
|
|
std::vector<at::Tensor> mha_bwd( |
|
const at::Tensor &dout, |
|
const at::Tensor &q, |
|
const at::Tensor &k, |
|
const at::Tensor &v, |
|
const at::Tensor &out, |
|
const at::Tensor &softmax_lse, |
|
std::optional<at::Tensor> &dq_, |
|
std::optional<at::Tensor> &dk_, |
|
std::optional<at::Tensor> &dv_, |
|
std::optional<const at::Tensor> &cu_seqlens_q_, |
|
std::optional<const at::Tensor> &cu_seqlens_k_, |
|
std::optional<const at::Tensor> &seqused_q_, |
|
std::optional<const at::Tensor> &seqused_k_, |
|
std::optional<int> max_seqlen_q_, |
|
std::optional<int> max_seqlen_k_, |
|
float const softmax_scale, |
|
bool is_causal, |
|
int window_size_left, |
|
int window_size_right, |
|
float const softcap, |
|
bool const deterministic, |
|
int const sm_margin); |
|
|
|
std::vector<at::Tensor> |
|
mha_combine(const at::Tensor &out_partial, |
|
const at::Tensor &lse_partial, |
|
std::optional<at::Tensor> out_, |
|
std::optional<at::ScalarType> out_dtype_ |
|
); |
|
|
|
/// @brief Declaration of `mha_fwd_get_scheduler_metadata`; the definition is
///        not in this header.
///
/// NOTE(review): by name, the returned tensor is presumably intended to be
/// passed as `mha_fwd`'s `scheduler_metadata_` argument — confirm against
/// the definition.
///
/// Conventions visible in this signature:
///  - problem sizes (batch, sequence lengths, head counts, head dims) and the
///    dtype are plain values;
///  - `seqused_k` is required (const reference), while the other tensor
///    inputs are optional and taken by non-const lvalue reference, so the
///    caller must pass named `std::optional<const at::Tensor>` objects;
///  - `page_size` and `pack_gqa_` are optionals taken by value (note that
///    `page_size` carries no trailing `_`, unlike the other optionals);
///  - `has_softcap` is a flag here, whereas `mha_fwd` takes a `softcap` float.
///
/// @return A single `at::Tensor`; layout defined by the implementation.
at::Tensor mha_fwd_get_scheduler_metadata(
    int batch_size,
    int max_seqlen_q,
    int max_seqlen_k,
    int num_heads,
    int num_heads_k,
    int headdim,
    int headdim_v,
    at::ScalarType qkv_dtype,
    const at::Tensor &seqused_k,
    std::optional<const at::Tensor> &cu_seqlens_q_,
    std::optional<const at::Tensor> &cu_seqlens_k_,
    std::optional<const at::Tensor> &cu_seqlens_k_new_,
    std::optional<const at::Tensor> &seqused_q_,
    std::optional<const at::Tensor> &leftpad_k_,
    std::optional<int> page_size,
    int max_seqlen_k_new,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    bool has_softcap,
    int num_splits,
    std::optional<bool> pack_gqa_,
    const int sm_margin);
|
|
|
|