#pragma once

#include <cute/tensor.hpp>

#include "cutlass/fast_math.h"

#include "utils.h"

namespace flash {

using namespace cute;

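// PackGQAManager handles the global-memory Q loads and O/LSE stores when query heads that
// share a KV head are packed along the M (row) dimension ("PackGQA"). A packed row index is
// split back into (query position, head-within-group) with one FastDivmod, and the resulting
// global-memory pointers are shared among the threads of a row group via warp shuffles.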
template <int kBlockM, int kHeadDim, int NumThreads, typename Element>
struct PackGQAManager {

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static constexpr int kGmemElemsPerStore = kGmemElemsPerLoad;
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");

    // Each group of kGmemThreadsPerRow threads covers kBlockKGmem elements (128, 64, or 32 bytes)
    // of a row per copy, depending on what the row size in bytes is divisible by.
    static constexpr int kBytePerRow = kHeadDim * sizeof(Element);
    static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
    static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow");

    // Threads that handle the same row must lie within one warp, since row pointers are
    // exchanged with __shfl_sync in compute_ptr.
    static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp");

    // Q is loaded gmem -> smem with cp.async; the ZFILL atom zero-fills predicated-off columns.
    using GmemCopyAtomCpAsync = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<uint128_t>, Element>;
    using GmemLayoutAtom = Layout<Shape <Int<NumThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopyQCpAsync = decltype(
        make_tiled_copy(GmemCopyAtomCpAsync{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));

    // O is stored registers -> gmem with vectorized copies that assume 128-bit alignment.
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerStore>>>{}));

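    // Precompute one global-memory row pointer per (row group, iteration). The packed row index
    // (m_block * kBlockM + row) is split by FastDivmod into m_idx = idx / qhead_per_khead and
    // h_idx = idx % qhead_per_khead. The divmod work is spread across the NumThreadsPerRow
    // threads of a row group; consumers later fetch the pointer they need with __shfl_sync.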
    template <int NumThreadsPerRow=kGmemThreadsPerRow, typename Engine, typename Layout, typename TensorC>
    CUTLASS_DEVICE
    static auto
    compute_ptr(Tensor<Engine, Layout> &tensor, TensorC const &tRows,
                cutlass::FastDivmod const &qhead_per_khead_divmod, int const thread_idx, int const m_block) {
        static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size(tRows)), NumThreadsPerRow);
        using TensorType = typename Engine::value_type;
        Tensor tPrPtr = make_tensor<TensorType const*>(Shape<Int<NumPtrPerThread>>{});
        #pragma unroll
        for (int i = 0; i < NumPtrPerThread; ++i) {
            int const row = i * NumThreads + get<0>(tRows(thread_idx % NumThreadsPerRow));
            int const idx = m_block * kBlockM + row;
            int m_idx, h_idx;
            m_idx = qhead_per_khead_divmod.divmod(h_idx, idx);
            tPrPtr[i] = &tensor(make_coord(make_coord(h_idx, m_idx)));
        }
        return tPrPtr;
    }

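    // Load a (kBlockM, kHeadDim) tile of Q from gmem to smem with cp.async. Each row of the tile
    // may come from a different (head, position) pair, so the row pointers precomputed by
    // compute_ptr are broadcast within the row group with __shfl_sync; columns beyond the actual
    // head dimension are zero-filled via the predicate tQpQ.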
    template <typename TensormQ, typename TensorsQ>
    CUTLASS_DEVICE
    static void
    load_Q(TensormQ const &mQ,  // ((qhead_per_khead, seqlen_q), headdim)
           TensorsQ &sQ,        // (kBlockM, kHeadDim)
           cutlass::FastDivmod const &qhead_per_khead_divmod,
           int const thread_idx, int const seqlen_q, int const m_block
           )
    {
        GmemTiledCopyQCpAsync gmem_tiled_copy_Q_cp_async;

        auto gmem_thr_copy_Q_cp_async = gmem_tiled_copy_Q_cp_async.get_thread_slice(thread_idx);
        Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));
        Tensor tQcQ = gmem_thr_copy_Q_cp_async.partition_S(cQ);
        Tensor tQsQ = gmem_thr_copy_Q_cp_async.partition_D(sQ);

        // Predicate on the head-dim (K) coordinate; out-of-bounds columns are zero-filled by cp.async ZFILL.
        Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
        #pragma unroll
        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < size<1>(mQ); }

        Tensor mQ_0 = mQ(_, _0{});
        Tensor tQcQ_row = tQcQ(_0{}, _, _0{});
        Tensor tPrQPtr = compute_ptr(mQ_0, tQcQ_row, qhead_per_khead_divmod, thread_idx, m_block);
        int const qhead_per_khead = qhead_per_khead_divmod.divisor;
        #pragma unroll
        for (int m = 0; m < size<1>(tQsQ); ++m) {
            int idx = m_block * kBlockM + get<0>(tQcQ(_0{}, m, _0{}));
            // Fetch the row pointer computed by the lane responsible for this row.
            Element const* q_ptr = reinterpret_cast<Element const*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrQPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));
            if (idx < seqlen_q * qhead_per_khead) {
                Tensor mQ_cur = make_tensor(make_gmem_ptr(q_ptr), Shape<Int<kHeadDim>>{});
                Tensor mQ_cur_copy = cute::tiled_divide(mQ_cur, Shape<Int<kGmemElemsPerLoad>>{});
                #pragma unroll
                for (int k = 0; k < size<2>(tQsQ); ++k) {
                    int ki = get<1>(tQcQ(_0{}, _0{}, k)) / kGmemElemsPerLoad;
                    cute::copy(gmem_tiled_copy_Q_cp_async.with(tQpQ(k)), mQ_cur_copy(_, ki), tQsQ(_, m, k));
                }
            }
        }
    };

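    // Write the per-row log-sum-exp values from registers to gmem. Only the threads holding the
    // first column of the accumulator write, and each row's destination pointer is again obtained
    // from compute_ptr via __shfl_sync.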
    template <typename TensormLSE, typename TensorsLSE, typename TiledMma>
    CUTLASS_DEVICE
    static void
    store_LSE(TensormLSE &mLSE,  // ((qhead_per_khead, seqlen_o))
              TensorsLSE const &tLSErLSE,  // one LSE value per accumulator row owned by this thread
              TiledMma tiled_mma,
              cutlass::FastDivmod const &qhead_per_khead_divmod,
              int const thread_idx, int const seqlen_o, int const m_block
              )
    {
        Tensor caccO = cute::make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        Tensor taccOcO = thread_mma.partition_C(caccO);
        Tensor taccOcO_row = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()))(_, _0{});
        CUTE_STATIC_ASSERT_V(size(tLSErLSE) == size(taccOcO_row));

        // Threads holding the same accumulator row must be in the same warp so the LSE pointer can be shuffled.
        static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{});
        static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0);
        static_assert(CUTE_STATIC_V(size(tLSErLSE)) <= kMmaThreadsPerRow);
        static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow);

        Tensor tPrLSEPtr = compute_ptr<kMmaThreadsPerRow>(mLSE, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block);
        static_assert(CUTE_STATIC_V(size(tPrLSEPtr)) == 1);
        int const qhead_per_khead = qhead_per_khead_divmod.divisor;
        #pragma unroll
        for (int mi = 0; mi < size(tLSErLSE); ++mi) {
            int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));
            float* ptr_LSE_cur = reinterpret_cast<float*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrLSEPtr[0]), mi % kMmaThreadsPerRow, kMmaThreadsPerRow));
            // Only the thread holding column 0 of this row writes, and only for rows inside the packed sequence.
            if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o * qhead_per_khead) {
                *ptr_LSE_cur = tLSErLSE(mi);
            }
        }
    };

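    // Store the output tile from registers to gmem with vectorized copies, one packed row at a
    // time, masking both the row bound (seqlen_o * qhead_per_khead) and the head-dim bound.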
    template <typename TensormO, typename TensorrO>
    CUTLASS_DEVICE
    static void
    store_O(TensormO &mO,  // ((qhead_per_khead, seqlen_o), headdim)
            TensorrO const &tOrO,  // output fragment laid out to match GmemTiledCopyO's partitioning
            cutlass::FastDivmod const &qhead_per_khead_divmod,
            int const thread_idx, int const seqlen_o, int const m_block
            )
    {
        GmemTiledCopyO gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
        Tensor cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < size<1>(mO); }

        Tensor mO_0 = mO(_, _0{});
        Tensor tOcO_row = tOcO(_0{}, _, _0{});
        Tensor tPrOPtr = compute_ptr(mO_0, tOcO_row, qhead_per_khead_divmod, thread_idx, m_block);
        int const qhead_per_khead = qhead_per_khead_divmod.divisor;
        #pragma unroll
        for (int m = 0; m < size<1>(tOrO); ++m) {
            int idx = m_block * kBlockM + get<0>(tOcO(_0{}, m, _0{}));
            Element* o_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrOPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));
            if (idx < seqlen_o * qhead_per_khead) {
                Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape<Int<kHeadDim>>{});
                Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape<Int<kGmemElemsPerStore>>{});
                #pragma unroll
                for (int k = 0; k < size<2>(tOrO); ++k) {
                    int ki = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerStore;
                    if (tOpO(k)) {
                        cute::copy(gmem_tiled_copy_O, tOrO(_, m, k), mO_cur_copy(_, ki));
                    }
                }
            }
        }
    };

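    // Store the output directly from the MMA accumulator layout to gmem, two elements at a time,
    // without going through shared memory. Row pointers are shuffled from the lanes that computed
    // them, and columns are bounds-checked against the actual head dimension.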
    template <typename TensormO, typename TensorrO, typename TiledMma>
    CUTLASS_DEVICE
    static void
    store_O_direct(TensormO &mO,  // ((qhead_per_khead, seqlen_o), headdim)
                   TensorrO const &tOrO,  // accumulator fragment in the MMA's C layout
                   TiledMma tiled_mma,
                   cutlass::FastDivmod const &qhead_per_khead_divmod,
                   int const thread_idx, int const seqlen_o, int const m_block
                   )
    {
        static constexpr int kGmemElemsPerStoreDirect = 2;
        cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element> gmem_copy_direct;

        // Reshape the accumulator to (row, col) and group columns in pairs for the stores.
        Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));
        Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});

        Tensor caccO = cute::make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        Tensor taccOcO = thread_mma.partition_C(caccO);
        Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));
        Tensor taccOcO_row = taccOcO_rowcol(_, _0{});
        Tensor taccOcO_col = taccOcO_rowcol(_0{}, _);

        static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{});
        static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0);
        static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow);

        Tensor mO_0 = mO(_, _0{});
        Tensor tPrOPtr = compute_ptr<kMmaThreadsPerRow>(mO_0, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block);
        static_assert(CUTE_STATIC_V(size(tPrOPtr)) == 1);

        int const qhead_per_khead = qhead_per_khead_divmod.divisor;
        #pragma unroll
        for (int m = 0; m < size<1>(tOrO_copy); ++m) {
            int row = m_block * kBlockM + get<0>(taccOcO_row(m));
            Element* o_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrOPtr[0]), m % kMmaThreadsPerRow, kMmaThreadsPerRow));
            if (row < seqlen_o * qhead_per_khead) {
                Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape<Int<kHeadDim>>{});
                Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape<Int<kGmemElemsPerStoreDirect>>{});
                #pragma unroll
                for (int k = 0; k < size<2>(tOrO_copy); ++k) {
                    int col = get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect));
                    if (col < size<1>(mO)) {
                        cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), mO_cur_copy(_, col / kGmemElemsPerStoreDirect));
                    }
                }
            }
        }
    };

};

}  // namespace flash