|
|
|
|
|
|
|
|
|
#pragma once |
|
|
|
#include <algorithm>
#include <cmath>
#include <vector>
|
|
|
/// Decide whether to use the PackGQA path.
///
/// The comparison models the packed layout as having
/// `seqlen_q * qhead_per_khead` rows and the unpacked layout as having
/// `seqlen_q` rows. Each row count is padded up to a multiple of `blockM`,
/// and the resulting tile fill ratios are compared.
///
/// @param varlen_q         true for variable-length queries; always packs.
/// @param seqlen_q         number of query rows (non-positive -> no packing).
/// @param qhead_per_khead  query heads per KV head (GQA group size).
/// @param blockM           tile size along M; assumed > 0.
/// @return true when packing is expected to use tiles noticeably better.
inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {
    // Variable-length queries always take the packed path; the fill estimate
    // below is not used for them.
    if (varlen_q) { return true; }

    // With no query rows both fill ratios would be 0/0 (NaN). Answer `false`
    // explicitly instead of relying on NaN comparison semantics, which are not
    // guaranteed under fast-math flags.
    if (seqlen_q <= 0) { return false; }

    // Fraction of the blockM-padded rows that hold real data. 64-bit math keeps
    // both the padding step and the packed row count free of int overflow.
    auto const tile_fill = [blockM](long long rows) {
        long long const padded = (rows + blockM - 1) / blockM * blockM;
        return float(rows) / float(padded);
    };

    float const nopack_gqa_efficiency = tile_fill(seqlen_q);
    float const pack_gqa_efficiency = tile_fill(static_cast<long long>(seqlen_q) * qhead_per_khead);

    // Pack only when the unpacked fill is below 90% of the packed fill, i.e.
    // packing has to win by a clear margin.
    return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Choose how many splits to divide the n-block (KV) dimension into.
///
/// Two regimes:
///  * `total_mblocks >= 0.8 * num_SMs`: there is already about one wave of work.
///    Split only when a single KV head exceeds a fixed 50 MiB budget
///    (`size_l2`, presumably an L2-cache size estimate). This also requires
///    non-causal/non-local attention and at least `2 * num_SMs` m-blocks.
///  * otherwise: pick the smallest split count whose wave efficiency is at least
///    85% of the best efficiency among 1..min(max_splits, num_SMs, num_n_blocks).
///
/// @param total_mblocks       m-block work units before splitting.
/// @param num_SMs             number of SMs (work slots per wave).
/// @param num_n_blocks        n-blocks available to split across.
/// @param num_m_blocks        m-blocks, used only in the large-KV-head check.
/// @param size_one_kv_head    size of one KV head, compared against 50 MiB;
///                            presumably in bytes (TODO confirm with caller).
/// @param is_causal_or_local  disables the large-KV-head split when true.
/// @param max_splits          upper bound on the result; assumed >= 1.
/// @return number of splits, in [1, max_splits] when max_splits >= 1.
inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) {
    if (total_mblocks >= 0.8f * num_SMs) {
        int const size_l2 = 50 * 1024 * 1024;
        if (size_one_kv_head > size_l2 && num_m_blocks >= num_SMs * 2 && !is_causal_or_local) {
            // ceil(size_one_kv_head / size_l2), written as (x - 1) / d + 1.
            // The form (x + d - 1) / d overflows int (UB) once x > INT_MAX - d + 1
            // (~1.95 GiB). This form is valid because x > d > 0 here.
            return std::min((size_one_kv_head - 1) / size_l2 + 1, max_splits);
        }
        return 1;
    }

    // Too few n-blocks for splitting to pay off.
    if (num_n_blocks <= 4) { return 1; }

    max_splits = std::min({max_splits, num_SMs, num_n_blocks});

    // Splitting multiplies the work units by num_splits. n_waves is the number of
    // num_SMs-wide waves that work fills, and eff is the occupied fraction of
    // ceil(n_waves) full waves.
    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        float n_waves = float(total_mblocks * num_splits) / num_SMs;
        float eff = n_waves / std::ceil(n_waves);
        if (eff > max_efficiency) { max_efficiency = eff; }
        efficiency.push_back(eff);
    }

    // Prefer fewer splits: return the first count within 85% of the best.
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
            return num_splits;
        }
    }
    return 1;
}
|
|