|
#include <torch/extension.h>
|
|
#include <cuda_runtime.h>
|
|
#include <cublas_v2.h>
|
|
#include <cudnn.h>
|
|
#include <cmath>
|
|
|
|
|
|
// Persistent tiled matrix multiply: C = A * B (+ bias), all float32.
//
// Launch layout
//   * 1-D grid. Block `blockIdx.x` processes output tiles
//     blockIdx.x, blockIdx.x + NUM_SMS, blockIdx.x + 2*NUM_SMS, ...
//   * 2-D thread block of exactly 16 x 16 threads (must match TILE below).
//   * Static shared memory: two 16 x 16 float staging buffers.
//
// Parameters
//   a_ptr, b_ptr, c_ptr   device pointers to A (M x K), B (K x N), C (M x N)
//   bias_ptr              optional device pointer, indexed by output column
//   stride_*              element strides of the corresponding dimension
//   BLOCK_SIZE_M/N        size of one output tile owned by a block per iteration
//   BLOCK_SIZE_K          accepted for interface compatibility; the reduction
//                         dimension is walked in TILE-wide chunks instead
//   GROUP_SIZE_M          number of tile rows grouped together when mapping the
//                         linear tile id to (pid_m, pid_n)
//   NUM_SMS               tile-id stride between iterations of one block
//   HAS_BIAS              add bias_ptr[col] to every output element when true
//                         (and bias_ptr is non-null)
//
// Every loop bound below is uniform across the block, so all threads reach
// every __syncthreads(); only the final store is guarded per thread.
__global__ void matmul_kernel_persistent(
    const float *a_ptr,
    const float *b_ptr,
    float *c_ptr,
    const float *bias_ptr,
    int M, int N, int K,
    int stride_am, int stride_ak,
    int stride_bk, int stride_bn,
    int stride_cm, int stride_cn,
    int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K,
    int GROUP_SIZE_M, int NUM_SMS,
    bool HAS_BIAS)
{
    // Edge length of the shared-memory staging buffers; blockDim must be (TILE, TILE).
    constexpr int TILE = 16;

    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE];

    (void)BLOCK_SIZE_K;

    const int tx = threadIdx.x; // column inside a sub-tile
    const int ty = threadIdx.y; // row inside a sub-tile

    const int num_pid_m = (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
    const int num_pid_n = (N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N;
    const int num_tiles = num_pid_m * num_pid_n;
    const int num_pid_in_group = GROUP_SIZE_M * num_pid_n;
    const bool add_bias = HAS_BIAS && bias_ptr != nullptr;

    for (int tile_id = blockIdx.x; tile_id < num_tiles; tile_id += NUM_SMS)
    {
        // Grouped mapping from the linear tile id to tile coordinates.
        const int group_id = tile_id / num_pid_in_group;
        const int first_pid_m = group_id * GROUP_SIZE_M;
        const int group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M);
        const int pid_m = first_pid_m + (tile_id % group_size_m);
        const int pid_n = (tile_id % num_pid_in_group) / group_size_m;

        // Half-open row/column range of this output tile, clipped to the matrix.
        const int tile_row0 = pid_m * BLOCK_SIZE_M;
        const int tile_col0 = pid_n * BLOCK_SIZE_N;
        const int tile_row_end = min(tile_row0 + BLOCK_SIZE_M, M);
        const int tile_col_end = min(tile_col0 + BLOCK_SIZE_N, N);

        // The 16 x 16 thread block sweeps the output tile one sub-tile at a time.
        for (int row0 = tile_row0; row0 < tile_row_end; row0 += TILE)
        {
            for (int col0 = tile_col0; col0 < tile_col_end; col0 += TILE)
            {
                const int row = row0 + ty;
                const int col = col0 + tx;
                const bool row_ok = row < tile_row_end;
                const bool col_ok = col < tile_col_end;

                float accumulator = 0.0f;

                for (int k0 = 0; k0 < K; k0 += TILE)
                {
                    const int a_k = k0 + tx;
                    const int b_k = k0 + ty;

                    // Out-of-range elements are staged as zero so the dot
                    // product below can always run over the full TILE width.
                    As[ty][tx] = (row_ok && a_k < K)
                                     ? a_ptr[(long long)row * stride_am + (long long)a_k * stride_ak]
                                     : 0.0f;
                    Bs[ty][tx] = (col_ok && b_k < K)
                                     ? b_ptr[(long long)b_k * stride_bk + (long long)col * stride_bn]
                                     : 0.0f;

                    __syncthreads(); // staging buffers fully written

#pragma unroll
                    for (int k = 0; k < TILE; ++k)
                    {
                        accumulator += As[ty][k] * Bs[k][tx];
                    }

                    __syncthreads(); // everyone done reading before the next overwrite
                }

                if (row_ok && col_ok)
                {
                    if (add_bias)
                    {
                        accumulator += bias_ptr[col];
                    }
                    c_ptr[(long long)row * stride_cm + (long long)col * stride_cn] = accumulator;
                }
            }
        }
    }
}
|
|
|
|
|
|
// Row-wise log-softmax: out[r][c] = in[r][c] - max_r - log(sum_c exp(in[r][c] - max_r)).
//
// Launch layout
//   * One block per row (row index = blockIdx.x), 1-D thread block.
//   * blockDim.x must be a power of two and at most 256: the shared scratch
//     buffer holds 256 floats and the tree reduction halves the stride.
//   * Columns are assumed to have unit stride in both input and output
//     (elements are addressed as row_base + col).
//
// Parameters
//   input_ptr / output_ptr   device pointers to the first element of row 0
//   input_row_stride         elements between consecutive input rows
//   output_row_stride        elements between consecutive output rows
//   n_cols                   number of columns per row
//   BLOCK_SIZE               unused inside the kernel (blockDim.x is used)
__global__ void log_softmax_kernel(
    const float *input_ptr,
    float *output_ptr,
    int input_row_stride,
    int output_row_stride,
    int n_cols,
    int BLOCK_SIZE)
{
    (void)BLOCK_SIZE;

    __shared__ float sdata[256];
    __shared__ float max_val;
    __shared__ float sum_exp;

    const int tid = threadIdx.x;

    // 64-bit row offsets so rows beyond the 2^31-element mark are addressed correctly.
    const float *row_in = input_ptr + (long long)blockIdx.x * input_row_stride;
    float *row_out = output_ptr + (long long)blockIdx.x * output_row_stride;

    // Pass 1: row maximum. Each thread scans a strided slice of the row.
    float thread_max = -INFINITY;
    for (int col = tid; col < n_cols; col += blockDim.x)
    {
        thread_max = fmaxf(thread_max, row_in[col]);
    }

    sdata[tid] = thread_max;
    __syncthreads();

    for (int s = blockDim.x / 2; s > 0; s >>= 1)
    {
        if (tid < s)
        {
            sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
        }
        __syncthreads();
    }

    if (tid == 0)
    {
        max_val = sdata[0];
    }
    // Also guarantees every thread has finished reading sdata before it is reused.
    __syncthreads();

    // Pass 2: sum of exp(x - max); subtracting the maximum keeps expf in range.
    float thread_sum = 0.0f;
    for (int col = tid; col < n_cols; col += blockDim.x)
    {
        thread_sum += expf(row_in[col] - max_val);
    }

    sdata[tid] = thread_sum;
    __syncthreads();

    for (int s = blockDim.x / 2; s > 0; s >>= 1)
    {
        if (tid < s)
        {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0)
    {
        sum_exp = sdata[0];
    }
    __syncthreads();

    const float log_sum_exp = logf(sum_exp);

    // Pass 3: write the normalized log-probabilities.
    for (int col = tid; col < n_cols; col += blockDim.x)
    {
        row_out[col] = row_in[col] - max_val - log_sum_exp;
    }
}
|
|
|
|
|
|
// Mean over the middle axis of a logically (M, N, K) tensor, producing (M, K).
//
// Launch layout
//   * 1-D grid of 1-D blocks; one thread per output element. Threads whose
//     flat index is >= M * K exit immediately.
//
// Parameters
//   input_ptr                 device pointer to the (M, N, K) input
//   output_ptr                device pointer to the (M, K) output
//   input_stride0/1/2         element strides of the three input axes
//   output_stride0/1          element strides of the two output axes
//   M, N, K                   logical extents; N is the reduced axis
//   BLOCK_SIZE                unused inside the kernel (blockDim.x is used)
//
// Note: with N == 0 the result is 0.0f / 0, i.e. NaN, for every output element.
__global__ void mean_kernel(
    const float *input_ptr,
    float *output_ptr,
    int input_stride0, int input_stride1, int input_stride2,
    int output_stride0, int output_stride1,
    int M, int N, int K,
    int BLOCK_SIZE)
{
    (void)BLOCK_SIZE;

    // All index arithmetic is done in 64 bits so large tensors do not wrap.
    const long long pid = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    const long long total = (long long)M * K;
    if (pid >= total)
    {
        return;
    }

    const long long m_idx = pid / K;
    const long long k_idx = pid % K;

    // First element of the length-N run this thread reduces.
    const float *src = input_ptr + m_idx * input_stride0 + k_idx * input_stride2;

    float acc = 0.0f;
    for (int n = 0; n < N; ++n)
    {
        acc += src[(long long)n * input_stride1];
    }

    output_ptr[m_idx * output_stride0 + k_idx * output_stride1] = acc / N;
}
|
|
|
|
|
|
// Host launcher for matmul_kernel_persistent: c = a @ b (+ bias).
//
// Arguments
//   a     (M, K) float32 CUDA tensor
//   b     (K, N) float32 CUDA tensor
//   c     (M, N) float32 CUDA tensor, written in place
//   bias  optional; when defined it must be a contiguous float32 CUDA tensor
//         with N elements and is added to every row of the result
//
// The kernel is launched on the default stream of the current CUDA device with
// 16 x 16 thread blocks and at most one block per multiprocessor.
// Raises a c10::Error (via TORCH_CHECK) on invalid arguments or CUDA failures.
void matmul_persistent_cuda(
    torch::Tensor const &a,
    torch::Tensor const &b,
    torch::Tensor &c,
    torch::Tensor const &bias)
{
    TORCH_CHECK(a.is_cuda() && b.is_cuda() && c.is_cuda(),
                "matmul_persistent_cuda: a, b and c must be CUDA tensors");
    TORCH_CHECK(a.scalar_type() == torch::kFloat32 &&
                    b.scalar_type() == torch::kFloat32 &&
                    c.scalar_type() == torch::kFloat32,
                "matmul_persistent_cuda: a, b and c must be float32");
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2,
                "matmul_persistent_cuda: a, b and c must be 2-D");
    TORCH_CHECK(a.size(1) == b.size(0),
                "matmul_persistent_cuda: inner dimensions of a and b do not match");
    TORCH_CHECK(c.size(0) == a.size(0) && c.size(1) == b.size(1),
                "matmul_persistent_cuda: c must have shape (a.size(0), b.size(1))");

    const bool has_bias = bias.defined();
    if (has_bias)
    {
        TORCH_CHECK(bias.is_cuda() && bias.scalar_type() == torch::kFloat32,
                    "matmul_persistent_cuda: bias must be a float32 CUDA tensor");
        TORCH_CHECK(bias.is_contiguous() && bias.numel() == b.size(1),
                    "matmul_persistent_cuda: bias must be contiguous with b.size(1) elements");
    }

    // The kernel takes 32-bit sizes and strides; refuse anything that would be truncated.
    auto fits_int = [](long long v) { return v >= 0 && v <= 2147483647LL; };
    TORCH_CHECK(fits_int(a.size(0)) && fits_int(a.size(1)) && fits_int(b.size(1)) &&
                    fits_int(a.stride(0)) && fits_int(a.stride(1)) &&
                    fits_int(b.stride(0)) && fits_int(b.stride(1)) &&
                    fits_int(c.stride(0)) && fits_int(c.stride(1)),
                "matmul_persistent_cuda: sizes/strides exceed the 32-bit range supported by the kernel");

    const int M = static_cast<int>(a.size(0));
    const int K = static_cast<int>(a.size(1));
    const int N = static_cast<int>(b.size(1));

    // Nothing to compute; a zero-sized grid would be an invalid launch configuration.
    if (M == 0 || N == 0)
    {
        return;
    }

    // Size the persistent grid from the device the kernel will actually run on.
    int device = 0;
    cudaError_t err = cudaGetDevice(&device);
    TORCH_CHECK(err == cudaSuccess,
                "matmul_persistent_cuda: cudaGetDevice failed: ", cudaGetErrorString(err));

    int NUM_SMS = 0;
    err = cudaDeviceGetAttribute(&NUM_SMS, cudaDevAttrMultiProcessorCount, device);
    TORCH_CHECK(err == cudaSuccess && NUM_SMS > 0,
                "matmul_persistent_cuda: could not query multiprocessor count: ",
                cudaGetErrorString(err));

    const int BLOCK_SIZE_M = 128;
    const int BLOCK_SIZE_N = 128;
    const int BLOCK_SIZE_K = 64;
    const int GROUP_SIZE_M = 8;

    const int num_pid_m = (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
    const int num_pid_n = (N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N;
    const int num_tiles = num_pid_m * num_pid_n;
    const int grid_size = num_tiles < NUM_SMS ? num_tiles : NUM_SMS;

    const dim3 block(16, 16);
    const dim3 grid_dim(grid_size);

    matmul_kernel_persistent<<<grid_dim, block>>>(
        a.data_ptr<float>(),
        b.data_ptr<float>(),
        c.data_ptr<float>(),
        has_bias ? bias.data_ptr<float>() : nullptr,
        M, N, K,
        static_cast<int>(a.stride(0)), static_cast<int>(a.stride(1)),
        static_cast<int>(b.stride(0)), static_cast<int>(b.stride(1)),
        static_cast<int>(c.stride(0)), static_cast<int>(c.stride(1)),
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
        GROUP_SIZE_M, NUM_SMS,
        has_bias);

    // Kernel launches do not return a status; surface configuration errors here.
    err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "matmul_persistent_cuda: kernel launch failed: ", cudaGetErrorString(err));
}
|
|
|
|
// Host launcher for log_softmax_kernel: log-softmax over the last dimension.
//
// Arguments
//   input   float32 CUDA tensor with at least one dimension
//   output  float32 CUDA tensor, written in place; it must flatten to the same
//           (rows, cols) shape as input WITHOUT requiring a copy, and its last
//           dimension must have unit stride (the kernel addresses row + col)
//
// One 256-thread block is launched per row on the default stream.
// Raises a c10::Error (via TORCH_CHECK) on invalid arguments or launch failure.
void log_softmax_cuda(
    torch::Tensor const &input,
    torch::Tensor &output)
{
    TORCH_CHECK(input.is_cuda() && output.is_cuda(),
                "log_softmax_cuda: input and output must be CUDA tensors");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32 &&
                    output.scalar_type() == torch::kFloat32,
                "log_softmax_cuda: input and output must be float32");
    TORCH_CHECK(input.dim() >= 1 && output.dim() >= 1,
                "log_softmax_cuda: input and output must have at least one dimension");
    TORCH_CHECK(input.numel() == output.numel(),
                "log_softmax_cuda: input and output must have the same number of elements");

    // Nothing to do; also avoids an invalid zero-sized grid.
    if (input.numel() == 0)
    {
        return;
    }

    auto input_2d = input.reshape({-1, input.size(-1)}).contiguous();
    auto output_2d = output.reshape({-1, output.size(-1)});

    // reshape() silently returns a copy when no view is possible; writing into
    // that copy would leave `output` untouched, so reject the case explicitly.
    TORCH_CHECK(output_2d.data_ptr<float>() == output.data_ptr<float>(),
                "log_softmax_cuda: output cannot be viewed as 2-D without a copy");
    TORCH_CHECK(output_2d.size(0) == input_2d.size(0) &&
                    output_2d.size(1) == input_2d.size(1),
                "log_softmax_cuda: output shape does not match input");
    TORCH_CHECK(output_2d.size(1) == 1 || output_2d.stride(1) == 1,
                "log_softmax_cuda: last dimension of output must have unit stride");

    // The kernel takes 32-bit extents and strides.
    auto fits_int = [](long long v) { return v >= 0 && v <= 2147483647LL; };
    TORCH_CHECK(fits_int(input_2d.size(0)) && fits_int(input_2d.size(1)) &&
                    fits_int(input_2d.stride(0)) && fits_int(output_2d.stride(0)),
                "log_softmax_cuda: sizes/strides exceed the 32-bit range supported by the kernel");

    const int n_rows = static_cast<int>(input_2d.size(0));
    const int n_cols = static_cast<int>(input_2d.size(1));

    // Must stay a power of two <= 256 to match the kernel's shared reduction buffer.
    const int BLOCK_SIZE = 256;

    log_softmax_kernel<<<n_rows, BLOCK_SIZE>>>(
        input_2d.data_ptr<float>(),
        output_2d.data_ptr<float>(),
        static_cast<int>(input_2d.stride(0)),
        static_cast<int>(output_2d.stride(0)),
        n_cols,
        BLOCK_SIZE);

    // Kernel launches do not return a status; surface configuration errors here.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "log_softmax_cuda: kernel launch failed: ", cudaGetErrorString(err));
}
|
|
|
|
// Host launcher for mean_kernel: mean of `input` along dimension `dim`.
//
// Arguments
//   input   float32 CUDA tensor with at least one dimension
//   output  float32 CUDA tensor, written in place; it must hold
//           prod(sizes before dim) * prod(sizes after dim) elements and be
//           reshapeable to that 2-D shape WITHOUT a copy
//   dim     dimension to reduce; negative values count from the end
//
// The input is treated as (M, N, K) with N = input.size(dim); one thread is
// launched per output element in 256-thread blocks on the default stream.
// Raises a c10::Error (via TORCH_CHECK) on invalid arguments or launch failure.
void mean_dim_cuda(
    torch::Tensor const &input,
    torch::Tensor &output,
    int dim)
{
    TORCH_CHECK(input.is_cuda() && output.is_cuda(),
                "mean_dim_cuda: input and output must be CUDA tensors");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32 &&
                    output.scalar_type() == torch::kFloat32,
                "mean_dim_cuda: input and output must be float32");

    const long long ndim = input.dim();
    TORCH_CHECK(ndim >= 1, "mean_dim_cuda: input must have at least one dimension");
    TORCH_CHECK(dim >= -ndim && dim < ndim,
                "mean_dim_cuda: dim ", dim, " is out of range for a ", ndim, "-D tensor");
    if (dim < 0)
    {
        dim += static_cast<int>(ndim); // normalize Python-style negative index
    }

    // Collapse the shape to (outer, reduced, inner).
    long long outer = 1;
    for (int i = 0; i < dim; ++i)
    {
        outer *= input.size(i);
    }
    const long long reduced = input.size(dim);
    long long inner = 1;
    for (int i = dim + 1; i < ndim; ++i)
    {
        inner *= input.size(i);
    }

    const long long total_out = outer * inner;
    TORCH_CHECK(output.numel() == total_out,
                "mean_dim_cuda: output has ", output.numel(),
                " elements but ", total_out, " are required");

    // Nothing to write; also avoids an invalid zero-sized grid.
    if (total_out == 0)
    {
        return;
    }

    auto input_3d = input.reshape({outer, reduced, inner});
    auto output_2d = output.reshape({outer, inner});

    // reshape() silently returns a copy when no view is possible; writing into
    // that copy would leave `output` untouched, so reject the case explicitly.
    TORCH_CHECK(output_2d.data_ptr<float>() == output.data_ptr<float>(),
                "mean_dim_cuda: output cannot be viewed as 2-D without a copy");

    // The kernel takes 32-bit extents and strides.
    auto fits_int = [](long long v) { return v >= 0 && v <= 2147483647LL; };
    TORCH_CHECK(fits_int(total_out) && fits_int(reduced) &&
                    fits_int(input_3d.stride(0)) && fits_int(input_3d.stride(1)) &&
                    fits_int(input_3d.stride(2)) &&
                    fits_int(output_2d.stride(0)) && fits_int(output_2d.stride(1)),
                "mean_dim_cuda: sizes/strides exceed the 32-bit range supported by the kernel");

    const int M = static_cast<int>(outer);
    const int N = static_cast<int>(reduced);
    const int K = static_cast<int>(inner);

    const int BLOCK_SIZE = 256;
    const int grid_size = static_cast<int>((total_out + BLOCK_SIZE - 1) / BLOCK_SIZE);

    mean_kernel<<<grid_size, BLOCK_SIZE>>>(
        input_3d.data_ptr<float>(),
        output_2d.data_ptr<float>(),
        static_cast<int>(input_3d.stride(0)),
        static_cast<int>(input_3d.stride(1)),
        static_cast<int>(input_3d.stride(2)),
        static_cast<int>(output_2d.stride(0)),
        static_cast<int>(output_2d.stride(1)),
        M, N, K,
        BLOCK_SIZE);

    // Kernel launches do not return a status; surface configuration errors here.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "mean_dim_cuda: kernel launch failed: ", cudaGetErrorString(err));
}