kernel
File size: 898 Bytes
a7165c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#pragma once

#include <torch/torch.h>

/// Multi-head attention forward entry point (declaration only; the
/// definition lives in a translation unit not visible from this header).
///
/// NOTE(review): the roles described below are inferred from parameter names
/// (the set closely resembles FlashAttention's `mha_fwd`). Only the tensor
/// shapes come from the original author's annotations. Confirm every other
/// claim against the definition.
///
/// All tensor parameters are spelled `at::Tensor`. `torch::Tensor` is the
/// same type, so this declaration still matches the existing definition.
///
/// @param q             batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
///                      (presumably the query tensor).
/// @param k             batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
///                      (presumably the key tensor).
/// @param v             batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
///                      (presumably the value tensor).
/// @param out_          Optional tensor shaped like q:
///                      batch_size x seqlen_q x num_heads x round_multiple(head_size, 8).
///                      Presumably a preallocated output buffer, with a fresh
///                      tensor allocated when empty — confirm.
/// @param alibi_slopes_ Optional slopes, shaped either num_heads or
///                      batch_size x num_heads.
/// @param p_dropout     Presumably the dropout probability — TODO confirm valid range.
/// @param softmax_scale Presumably the scale applied to attention scores before
///                      softmax — verify the expected value with callers.
/// @param is_causal     Presumably selects causal masking.
/// @param window_size_left  Presumably the left extent of a local-attention
///                          window. NOTE(review): negative values may mean
///                          "unbounded" — confirm.
/// @param window_size_right Presumably the right extent of the local-attention
///                          window; same caveat as window_size_left.
/// @param softcap       Presumably a soft-capping value for attention logits
///                      (possibly disabled at 0) — confirm.
/// @param return_softmax Presumably requests that softmax / attention
///                      probabilities be returned as well.
/// @param gen_          Optional generator, presumably used for dropout RNG.
/// @return A vector of tensors whose order and contents are defined by the
///         implementation (not visible here) — check the definition before
///         indexing into it.
std::vector<at::Tensor>
mha_fwd(const at::Tensor &q,                              // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
        const at::Tensor &k,                              // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
        const at::Tensor &v,                              // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
        const c10::optional<at::Tensor> &out_,            // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
        const c10::optional<at::Tensor> &alibi_slopes_,   // num_heads or batch_size x num_heads
        const double p_dropout,
        const double softmax_scale,
        bool is_causal,
        const int64_t window_size_left,
        const int64_t window_size_right,
        const double softcap,
        const bool return_softmax,
        const c10::optional<at::Generator> gen_);