#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/fast_math.h>
#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/builders/sm90_common.inl"
#include "cutlass/epilogue/collective/builders/sm90_common.inl"

#include "seqlen.h"
#include "named_barrier.hpp"
#include "pack_gqa.h"
#include "utils.h"

namespace flash {

using namespace cute;

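// Collective epilogue for the forward pass: converts the O accumulator to the output element type
// and writes it, together with the log-sum-exp (LSE), to global memory. Depending on the template
// flags the store goes through shared memory and TMA, plain vectorized global stores, or a direct
// register-to-gmem path for split-KV partial results, with optional varlen and GQA head packing.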
template <class TileShape_MNK_PV_, class ClusterShape_, class Element_, class ArchTag_,
          int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false>
struct CollectiveEpilogueFwd {

    using TileShape_MNK_PV = TileShape_MNK_PV_;
    using ClusterShape = ClusterShape_;
    using Element = Element_;
    using ElementPartial = float;
    using ArchTag = ArchTag_;
    static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
    static constexpr bool Varlen = Varlen_;
    static constexpr bool PackGQA = PackGQA_;
    static constexpr bool Split = Split_;
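    // Shared memory for O is only skipped in the Split && !Varlen case, where the float partial
    // results are written straight from registers. TMA stores require SM90+ and the "plain" output
    // layout (no varlen, no split, no GQA packing).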
    static constexpr bool Use_smem = !(Split && !Varlen);
    static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA;

    static_assert(ArchTag::kMinComputeCapability >= 80);
    static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1);
    static_assert(sizeof(Element) <= 2);

    static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
    static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{});

    static constexpr bool LargeHeadDimV = kHeadDimV > 256;

    using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;

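    // Non-TMA global stores: each thread writes 128-bit (16-byte) vectors, so a row of kBlockKGmem
    // elements is covered by kGmemThreadsPerRow threads.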
    static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore");

    static constexpr int kBytePerRow = kHeadDimV * sizeof(Element);
    static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore;

    static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0);
    static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow");
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerStore>>>{}));

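    // Shared-memory layout for O: on SM90 use the GMMA/TMA-compatible layout picked by CUTLASS's
    // ss_smem_selector, otherwise a swizzled row-major layout suitable for plain shared-memory stores.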
    using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>());
    using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{})));
    static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));
    static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
    using SmemLayoutAtomO = decltype(
        composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
                    Layout<Shape<_8, Int<kBlockKGmem>>,
                           Stride<Int<kBlockKGmem>, _1>>{}));
    using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{})));
    using SmemLayoutO = std::conditional_t<ArchTag::kMinComputeCapability >= 90, SmemLayoutOTMA, SmemLayoutOSTS>;

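    // O is (seqlen_q, headdim_v, num_heads, batch, num_splits); LSE drops the headdim mode.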
    using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>;
    using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
    using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>;

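    // With PackGQA, the leading (seqlen) mode becomes a nested (qheads_per_khead, seqlen_q) mode and
    // the head mode is indexed by KV head, so all query heads sharing a KV head live in one tile.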
    using ShapeOPacked = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
    using StrideOPacked = std::conditional_t<!PackGQA, StrideO, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t, int64_t>>;

    using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
    using StrideLSEPacked = std::conditional_t<!PackGQA, StrideLSE, cute::Stride<cute::Stride<int64_t, _1>, int64_t, int64_t, int64_t>>;

    using CopyOpR2S = std::conditional_t<
        ArchTag::kMinComputeCapability >= 90,
        decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element>()),
        AutoVectorizingCopyWithAssumedAlignment<128>
    >;
    using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;

    struct TensorStorage : cute::aligned_struct<128> {
        cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0> smem_o;
    };

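    // TMA store descriptor type; collapses to std::nullptr_t when TMA is not used.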
    using TMA_O = std::conditional_t<
        Use_TMA_O,
        decltype(make_tma_copy(
            GmemTiledCopyOTMA{},
            make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
            SmemLayoutOTMA{},
            select<0, 1>(TileShape_MNK_PV{}),
            _1{})),
        std::nullptr_t
    >;

    struct Arguments {
        Element* ptr_O;
        ShapeO const shape_O;
        StrideO const stride_O;
        ElementPartial* ptr_O_partial;
        StrideO const stride_O_partial;
        float* ptr_LSE;
        StrideLSE const stride_LSE;
        float* ptr_LSE_partial;
        StrideLSE const stride_LSE_partial;
        int32_t const nheads_kv;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    struct Params {
        Element* ptr_O;
        ShapeO const shape_O;
        StrideO const stride_O;
        ShapeOPacked const shape_O_packed;
        StrideOPacked const stride_O_packed;
        ElementPartial* ptr_O_partial;
        StrideO const stride_O_partial;
        StrideOPacked const stride_O_partial_packed;
        float* ptr_LSE;
        StrideLSE const stride_LSE;
        ShapeLSEPacked const shape_LSE_packed;
        StrideLSEPacked const stride_LSE_packed;
        float* ptr_LSE_partial;
        StrideLSE const stride_LSE_partial;
        StrideLSEPacked const stride_LSE_partial_packed;
        cutlass::FastDivmod qhead_per_khead_divmod;
        TMA_O tma_store_O;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

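    // Convert host-side Arguments into device-side Params: build the TMA descriptor (if used) and
    // precompute the packed (PackGQA-aware) shapes and strides for O and LSE.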
    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
        TMA_O tma_store_O = [&]{
            if constexpr (Use_TMA_O) {
                return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{});
            } else {
                return nullptr;
            }
        }();

        int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv);
        auto const shape_O_packed = cute::conditional_return<!PackGQA>(
            args.shape_O,
            make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
        );
        auto const stride_O_packed = cute::conditional_return<!PackGQA>(
            args.stride_O,
            make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O))
        );
        auto const stride_O_partial_packed = cute::conditional_return<!PackGQA>(
            args.stride_O_partial,
            make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial))
        );

        auto const shape_LSE_packed = cute::conditional_return<!PackGQA>(
            select<0, 2, 3, 4>(args.shape_O),
            make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
        );
        auto const stride_LSE_packed = cute::conditional_return<!PackGQA>(
            args.stride_LSE,
            make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE))
        );
        auto const stride_LSE_partial_packed = cute::conditional_return<!PackGQA>(
            args.stride_LSE_partial,
            make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial))
        );
        return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed,
                args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed,
                args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed,
                args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed,
                cutlass::FastDivmod(qhead_per_khead),
                tma_store_O, args.cu_seqlens, args.seqused};
    }

    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& params) {
        if constexpr (Use_TMA_O) {
            cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor());
        }
    }

    template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
    CUTLASS_DEVICE void
    store(Params const& params,
          FrgTensorO& tOrO,
          FrgTensorLSE const& lse,
          SharedStorage& shared_storage,
          TiledMma tiled_mma,
          int thread_idx,
          cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
          ) {

        auto [m_block, bidh, bidb, split_idx] = block_coord;
        int num_splits = get<4>(params.shape_O_packed);
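        // With Varlen + Split, the upper 16 bits of split_idx carry a dynamic split count; extract it
        // and keep only the lower 16 bits as the actual split index.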
        if constexpr (Split && Varlen) {
            uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16;
            int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
            num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
            split_idx &= 0x0000FFFF;
        }
        bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);

        Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{});

        static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4);

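        // FP8 kernels with column-major V produce a permuted accumulator. If the float accumulator
        // itself might be stored (the Split path below writes it directly), permute it before the
        // type conversion; otherwise permute the converted fragment afterwards.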
        if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); }
        Tensor tOrO_out = make_tensor_like<Element>(tOrO);
        flash::convert_type_out(tOrO, tOrO_out);
        if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); }

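        // Sync all epilogue threads before touching sO, which may alias shared memory that other
        // warp groups are still reading.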
        flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);

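        // Stage O through shared memory (R2S copy) when smem is used; otherwise signal barrier_O for
        // every CTA in the cluster right away (SM90+), since no shared memory is needed here.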
        if constexpr (Use_smem) {
            auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
            auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
            Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);
            Tensor taccOsO = smem_thr_copy_O.partition_D(sO);

            cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
            if constexpr (Use_TMA_O) {
                cutlass::arch::fence_view_async_shared();
                cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
                                                    cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
            } else {
                flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
            }
        } else {
            if constexpr (ArchTag::kMinComputeCapability >= 90) {
                #pragma unroll
                for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                    shared_storage.pipelines.barrier_O.arrive(cta_id);
                }
            }
        }

        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
        bool is_varlen = Varlen && params.cu_seqlens;
        int offset_o = seqlen_info.offset;
        int seqlen_o = seqlen_info.seqlen;
        int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);

        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);

        Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
        static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
        static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
        Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));
        Tensor taccOcO_row = taccOcO_rowcol(_, _0{});
        CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));

        using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
        using PackGQApartial_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>;

        Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
                                  params.shape_LSE_packed,
                                  !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);

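        // Write LSE (or the partial LSE when splitting). With LargeHeadDimV the warp groups are split
        // along the headdim, so the row statistics are replicated and only warp group 0 needs to write.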
        if (!LargeHeadDimV || warp_group_idx == 0) {
            if constexpr (!PackGQA) {
                #pragma unroll
                for (int mi = 0; mi < size(lse); ++mi) {
                    int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));
                    if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); }
                }
            } else {
                PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
            }
        }

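        // TMA path: the last warp waits until all epilogue threads have arrived at the barrier, then a
        // single elected thread issues the TMA store and, once it completes, signals barrier_O for
        // every CTA in the cluster.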
        if constexpr (Use_TMA_O) {
            Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
            Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{}));
            auto block_tma_O = params.tma_store_O.get_slice(_0{});
            Tensor tOgO = block_tma_O.partition_D(gO);
            Tensor tOsO = block_tma_O.partition_S(sO);
            int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
            if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
                cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
                                                  cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
                if (cute::elect_one_sync()) {
                    cute::copy(params.tma_store_O, tOsO, tOgO);
                    tma_store_arrive();
                    tma_store_wait<0>();
                    #pragma unroll
                    for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                        shared_storage.pipelines.barrier_O.arrive(cta_id);
                    }
                }
            }
        } else {
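            // Non-TMA path. For the final (non-split) output, read O back from smem into registers
            // and write it to gmem with predicated vectorized stores (or via the PackGQA helper).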
            if (!is_split) {
                Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
                Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{}));

                GmemTiledCopyO gmem_tiled_copy_O;
                auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
                Tensor tOsO = gmem_thr_copy_O.partition_S(sO);

                Tensor tOrO = make_fragment_like(tOsO);
                cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
                if constexpr (ArchTag::kMinComputeCapability >= 90) {
                    cutlass::arch::fence_view_async_shared();
                    #pragma unroll
                    for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                        shared_storage.pipelines.barrier_O.arrive(cta_id);
                    }
                }
                if constexpr (!PackGQA) {
                    Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
                    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOsO)));
                    #pragma unroll
                    for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
                    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);

                    flash::copy</*Is_even_MN=*/false, false, false, false>(
                        gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
                    );
                } else {
                    PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
                }
            } else {
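                // Split path: write the float partial accumulator straight from registers to
                // O_partial, two elements per store, with row/column predication.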
                Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
                Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{}));

                if constexpr (Use_smem) {
                    if constexpr (ArchTag::kMinComputeCapability >= 90) {
                        #pragma unroll
                        for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                            shared_storage.pipelines.barrier_O.arrive(cta_id);
                        }
                    }
                }
                if constexpr (!PackGQA) {
                    static constexpr int kGmemElemsPerStoreDirect = 2;
                    cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial> gmem_copy_direct;

                    Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));
                    Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
                    Tensor tOgO = thread_mma.partition_C(gOpartial);
                    Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout()));
                    Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
                    Tensor taccOcO_col = taccOcO_rowcol(_0{}, _);
                    #pragma unroll
                    for (int m = 0; m < size(taccOcO_row); ++m) {
                        if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) {
                            #pragma unroll
                            for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) {
                                if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) {
                                    cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k));
                                }
                            }
                        }
                    }
                } else {
                    PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
                }
            }
        }
    }

    CUTLASS_DEVICE void
    store_tail() {
        // Nothing to do here: the TMA store (when used) is already waited on inside store().
    }

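    // Write a well-defined "empty" result: LSE is set to -INFINITY and, when not splitting, O is
    // zeroed, so downstream consumers never read uninitialized memory.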
    CUTLASS_DEVICE void
    store_zero(
        Params const& params,
        int thread_idx,
        cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
        ) {
        static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
        auto [m_block, bidh, bidb, split_idx] = block_coord;
        int num_splits = get<4>(params.shape_O_packed);
        if constexpr (Split && Varlen) {
            uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16;
            int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
            num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
            split_idx &= 0x0000FFFF;
        }
        bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);

        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
        bool const is_varlen = Varlen && params.cu_seqlens;
        int offset_o = seqlen_info.offset;
        int seqlen_o = seqlen_info.seqlen;
        int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
        Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
                                  params.shape_LSE_packed,
                                  !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
        Tensor gLSE = local_tile(mLSE, Shape<Int<kBlockM>>{}, make_coord(m_block));

        static_assert(kBlockM <= NumEpilogueThreads);
        if (thread_idx < kBlockM) {
            const int row = m_block * kBlockM + thread_idx;
            if constexpr (!PackGQA) {
                if (row < seqlen_o) { mLSE(row) = -INFINITY; }
            } else {
                if (row < seqlen_o * qhead_per_khead) {
                    int m_idx, h_idx;
                    m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row);
                    mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY;
                }
            }
        }

        if (!is_split) {
            Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});

            GmemTiledCopyO gmem_tiled_copy_O;
            auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
            Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
            if constexpr (!PackGQA) {
                Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
                #pragma unroll
                for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
                Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{}));
                Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
                Tensor tOrO = make_fragment_like(tOgO);
                cute::clear(tOrO);

                flash::copy</*Is_even_MN=*/false, false, false, false>(
                    gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
                );
            } else {
                using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
                Tensor tOrO = make_tensor<Element>(make_shape(Shape<_1, Int<kGmemElemsPerStore>>{}, size<1>(tOcO), size<2>(tOcO)));
                cute::clear(tOrO);
                PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
            }
        }
    }

};

}  // namespace flash