import torch

import triton
import triton.language as tl
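
# Fused SwiGLU kernels: the packed input `xy` holds the two projections [x | y]
# along the last dimension, and the op computes out = x * sigmoid(x) * y
# (i.e. silu(x) * y), with forward and backward implemented in Triton.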
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32}),
        triton.Config({'BLOCK_N': 64}),
        triton.Config({'BLOCK_N': 128}),
        triton.Config({'BLOCK_N': 256}),
        triton.Config({'BLOCK_N': 512}),
        triton.Config({'BLOCK_N': 1024}),
    ],
    key=['ncols'],
)
@triton.jit
def _swiglu_fwd_kernel(
    X,
    Y,
    OUT,
    stride_x_row,
    stride_y_row,
    stride_out_row,
    ncols,
    BLOCK_N: tl.constexpr,
):
    # Each program instance handles one row and one BLOCK_N-wide tile of columns.
    row = tl.program_id(0)
    start_col = tl.program_id(1) * BLOCK_N
    X += row * stride_x_row
    Y += row * stride_y_row
    OUT += row * stride_out_row
    cols = start_col + tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
    y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
    # out = silu(x) * y, computed in float32 for accuracy.
    out = x * tl.sigmoid(x) * y
    tl.store(OUT + cols, out, mask=cols < ncols)


def _swiglu_fwd(xy, out=None):
    if xy.stride(-1) != 1:
        xy = xy.contiguous()
    batch_shape = xy.shape[:-1]
    xy = xy.reshape(-1, xy.shape[-1])
    # The packed input holds x in the first half and y in the second half of the last dim.
    x, y = xy.chunk(2, dim=-1)
    if out is None:
        out = torch.empty_like(x)
    else:
        out = out.reshape(-1, out.shape[-1])
        assert out.shape == x.shape
    assert out.stride(-1) == 1
    M, N = x.shape
    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
    with torch.cuda.device(x.device.index):
        _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
    return out.reshape(*batch_shape, out.shape[-1])
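

# Gradient derivation used by _swiglu_bwd_kernel below: with s = sigmoid(x) and
# out = x * s * y,
#     d out / d x = y * (s + x * s * (1 - s)) = y * s * (1 + x * (1 - s))
#     d out / d y = x * s
# so the kernel computes dx = dout * y * s * (1 + x * (1 - s)) and dy = dout * x * s.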
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32}),
        triton.Config({'BLOCK_N': 64}),
        triton.Config({'BLOCK_N': 128}),
        triton.Config({'BLOCK_N': 256}),
        triton.Config({'BLOCK_N': 512}),
        triton.Config({'BLOCK_N': 1024}),
    ],
    key=['ncols'],
)
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
@triton.jit
def _swiglu_bwd_kernel(
    X,
    Y,
    DOUT,
    OUT,
    DX,
    DY,
    stride_x_row,
    stride_y_row,
    stride_dout_row,
    stride_out_row,
    stride_dx_row,
    stride_dy_row,
    ncols,
    BLOCK_N: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Each program instance handles one row and one BLOCK_N-wide tile of columns.
    row = tl.program_id(0)
    start_col = tl.program_id(1) * BLOCK_N
    X += row * stride_x_row
    Y += row * stride_y_row
    DOUT += row * stride_dout_row
    if RECOMPUTE_OUTPUT:
        OUT += row * stride_out_row
    DX += row * stride_dx_row
    DY += row * stride_dy_row
    cols = start_col + tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
    y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
    dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)
    x_sigmoid = tl.sigmoid(x)
    # d(silu(x))/dx = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
    dy = x * x_sigmoid * dout
    tl.store(DX + cols, dx, mask=cols < ncols)
    tl.store(DY + cols, dy, mask=cols < ncols)
    if RECOMPUTE_OUTPUT:
        # Optionally recompute the forward output in the same pass.
        out = x * x_sigmoid * y
        tl.store(OUT + cols, out, mask=cols < ncols)
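

# When recompute_output=True, the backward launch also writes the forward output
# (via RECOMPUTE_OUTPUT above), which is useful when the caller needs `out` again
# but did not keep it around.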
def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
    if xy.stride(-1) != 1:
        xy = xy.contiguous()
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    batch_shape = xy.shape[:-1]
    xy = xy.reshape(-1, xy.shape[-1])
    x, y = xy.chunk(2, dim=-1)
    dout = dout.reshape(-1, dout.shape[-1])
    assert dout.shape == x.shape
    if dxy is None:
        dxy = torch.empty_like(xy)
    else:
        dxy = dxy.reshape(-1, dxy.shape[-1])
        assert dxy.shape == xy.shape
    dx, dy = dxy.chunk(2, dim=-1)
    assert dx.stride(-1) == 1
    assert dy.stride(-1) == 1
    if recompute_output:
        if out is None:
            out = torch.empty_like(x)
        else:
            out = out.reshape(-1, out.shape[-1])
            assert out.shape == x.shape
        assert out.stride(-1) == 1
    M, N = x.shape
    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
    with torch.cuda.device(x.device.index):
        _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,
                                 x.stride(0), y.stride(0), dout.stride(0),
                                 out.stride(0) if recompute_output else 0,
                                 dx.stride(0), dy.stride(0),
                                 N)
    if not recompute_output:
        return dxy.reshape(*batch_shape, dxy.shape[-1])
    else:
        return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])


class SwiGLU(torch.autograd.Function):

    @staticmethod
    def forward(ctx, xy):
        ctx.save_for_backward(xy)
        return _swiglu_fwd(xy)

    @staticmethod
    def backward(ctx, dout):
        xy, = ctx.saved_tensors
        return _swiglu_bwd(xy, dout)


swiglu = SwiGLU.apply
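

# Minimal usage sketch (illustrative; shapes are arbitrary and a CUDA device is
# assumed): checks the fused op against a plain PyTorch reference,
# out = F.silu(x) * y, in both forward and backward.
if __name__ == "__main__":
    import torch.nn.functional as F

    torch.manual_seed(0)
    xy = torch.randn(4, 128, 2048, device="cuda", requires_grad=True)
    dout = torch.randn(4, 128, 1024, device="cuda")

    # Fused Triton path.
    out = swiglu(xy)
    (dxy,) = torch.autograd.grad(out, xy, dout)

    # Pure PyTorch reference.
    x_ref, y_ref = xy.chunk(2, dim=-1)
    out_ref = F.silu(x_ref) * y_ref
    (dxy_ref,) = torch.autograd.grad(out_ref, xy, dout)

    print("fwd max err:", (out - out_ref).abs().max().item())
    print("bwd max err:", (dxy - dxy_ref).abs().max().item())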