wangshuai6
init
56238f0
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from typing import Any
from torch.autograd import Function
from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd
import triton
import triton.language as tl
# Bilinear deformable-sampling kernels.
#
# Both kernels are launched on a 1-D grid of B * H * W * G programs, one per
# (batch, row, column, group) location, and walk that group's C channels in
# chunks of BLOCK_SIZE.  For each of the K sampling points the sample position
# is (wid + dx, hid + dy), with (dx, dy) read from the deformable offsets.
# The input is bilinearly sampled there (corners outside the image contribute
# zero because their loads are masked with other=0) and scaled by a per-point
# weight; the K weighted samples are summed into the output.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
    ],
    key=['B', 'H', 'W', 'G', 'C', 'K'],
)
@triton.jit
def forward_kernel(
    B: tl.constexpr,
    H: tl.constexpr,  # image_size_h
    W: tl.constexpr,  # image_size_w
    G: tl.constexpr,  # num_groups
    C: tl.constexpr,  # num_channels_per_group
    K: tl.constexpr,  # number of sampling points per location
    input_ptr,  # input features [B, H, W, G, C]
    deformable_ptr,  # deformable offsets [B, H, W, G, K, 2], last axis (dx, dy)
    weights_ptr,  # weights [B, H, W, G, K]
    out_ptr,  # out [B, H, W, G, C]
    BLOCK_SIZE: tl.constexpr,  # a micro block to process in the Group
):
    pid = tl.program_id(0)
    wid = pid % W
    hid = pid // W % H
    gid = pid // (W * H) % G
    bid = pid // (W * H * G)
    id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
    # flat index of this (b, h, w, g) location
    common_offset = bid*H*W*G + hid*W*G + wid*G + gid
    batch_base = bid * H * W * G * C
    for block_base in tl.static_range(0, C, BLOCK_SIZE):
        # fp32 accumulator; tl.store casts to out's dtype on write
        buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
        block_offset = tl.arange(0, BLOCK_SIZE) + block_base
        block_mask = (block_offset < C) & id_mask
        for k in tl.static_range(K):
            deformable_offset = (common_offset * K + k) * 2
            x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
            y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
            # `.to(tl.int32)` truncates toward zero; step down by one for negative
            # non-integer coordinates so these are true floors. Without this,
            # samples left of / above the image get corner weights > 1 and < 0.
            floor_x = x.to(tl.int32)
            floor_y = y.to(tl.int32)
            floor_x = tl.where(floor_x > x, floor_x - 1, floor_x)
            floor_y = tl.where(floor_y > y, floor_y - 1, floor_y)
            ceil_x = floor_x + 1
            ceil_y = floor_y + 1
            # top left corner
            tl_weight = (ceil_x - x) * (ceil_y - y)
            tl_block_offset = batch_base + floor_y * W * G * C + floor_x * G * C + gid * C
            tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
            # top right corner
            tr_weight = (x - floor_x) * (ceil_y - y)
            tr_block_offset = batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C
            tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
            # bottom left corner
            bl_weight = (ceil_x - x) * (y - floor_y)
            bl_block_offset = batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C
            bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
            # bottom right corner
            br_weight = (x - floor_x) * (y - floor_y)
            br_block_offset = batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C
            br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
            # per-point dynamic weight
            weights_offset = common_offset*K + k
            weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
            # masked corner loads: out-of-image corners read as 0
            tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0)
            tl_block_input = tl_block_input * tl_weight
            tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0)
            tr_block_input = tr_block_input * tr_weight
            bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0)
            bl_block_input = bl_block_input * bl_weight
            br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0)
            br_block_input = br_block_input * br_weight
            # bilinear sample, scaled by the point's weight
            sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input
            weighted_sampled_input = sampled_input * weight
            buffer = buffer + weighted_sampled_input
        tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask)


# Backward of forward_kernel. Recomputes the same sampling positions (with the
# same true-floor correction) and produces:
#   grad_input      scattered with atomic adds, so it must be zero-initialised
#   grad_weights    sum_c grad[c] * sampled[c]
#   grad_deformable weight * d(sampled . grad)/d(x, y)
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=1),
        triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
    ],
    key=['B', 'H', 'W', 'G', 'C', 'K'],
)
@triton.jit
def backward_kernel(
    B: tl.constexpr,
    H: tl.constexpr,  # image_size_h
    W: tl.constexpr,  # image_size_w
    G: tl.constexpr,  # num_groups
    C: tl.constexpr,  # num_channels_per_group
    K: tl.constexpr,  # number of sampling points per location
    input_ptr,  # input features [B, H, W, G, C]
    deformable_ptr,  # deformable offsets [B, H, W, G, K, 2]
    weights_ptr,  # weights [B, H, W, G, K]
    grad_ptr,  # upstream gradient of out [B, H, W, G, C]
    grad_input_ptr,  # grad w.r.t. input features [B, H, W, G, C] (atomic adds)
    grad_deformable_ptr,  # grad w.r.t. offsets [B, H, W, G, K, 2]
    grad_weights_ptr,  # grad w.r.t. weights [B, H, W, G, K]
    BLOCK_SIZE: tl.constexpr,  # a micro block to process in the Group
):
    pid = tl.program_id(0)
    wid = pid % W
    hid = pid // W % H
    gid = pid // (W * H) % G
    bid = pid // (W * H * G)
    id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
    common_offset = bid*H*W*G + hid*W*G + wid*G + gid
    batch_base = bid * H * W * G * C
    for k in tl.static_range(K):
        weights_offset = common_offset*K + k
        weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
        # Accumulate in fp32 regardless of the gradient dtype (tl.store casts on
        # write), matching forward_kernel's fp32 accumulator.
        dodx = tl.zeros((1,), dtype=tl.float32)
        dody = tl.zeros((1,), dtype=tl.float32)
        dodw = tl.zeros((1,), dtype=tl.float32)
        deformable_offset = (common_offset * K + k) * 2
        x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
        y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
        # Same true-floor correction as forward_kernel (truncation is wrong for x, y < 0).
        floor_x = x.to(tl.int32)
        floor_y = y.to(tl.int32)
        floor_x = tl.where(floor_x > x, floor_x - 1, floor_x)
        floor_y = tl.where(floor_y > y, floor_y - 1, floor_y)
        ceil_x = floor_x + 1
        ceil_y = floor_y + 1
        # Corner weights, base offsets and in-image masks do not depend on the
        # channel block, so compute them once per sampling point.
        tl_weight = (ceil_x - x) * (ceil_y - y)
        tr_weight = (x - floor_x) * (ceil_y - y)
        bl_weight = (ceil_x - x) * (y - floor_y)
        br_weight = (x - floor_x) * (y - floor_y)
        tl_base = batch_base + floor_y * W * G * C + floor_x * G * C + gid * C
        tr_base = batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C
        bl_base = batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C
        br_base = batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C
        tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
        tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
        bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
        br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
        for block_base in tl.static_range(0, C, BLOCK_SIZE):
            block_offset = tl.arange(0, BLOCK_SIZE) + block_base
            block_mask = (block_offset < C) & id_mask
            grad = tl.load(grad_ptr + common_offset*C + block_offset, mask=block_mask, other=0.0)
            # gradient reaching the bilinear sample; split across the corners below
            dods = weight * grad
            # top left: d/dx = -(ceil_y - y), d/dy = -(ceil_x - x)
            tl_block_offset = tl_base + block_offset
            tl_block_input = tl.load(input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, other=0.0)
            tl_dot = tl.sum(tl_block_input * grad, axis=0)
            dodx = dodx - tl_dot * (ceil_y - y)
            dody = dody - tl_dot * (ceil_x - x)
            dodw = dodw + tl_dot * tl_weight
            tl.atomic_add(grad_input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, val=dods * tl_weight)
            # top right: d/dx = +(ceil_y - y), d/dy = -(x - floor_x)
            tr_block_offset = tr_base + block_offset
            tr_block_input = tl.load(input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, other=0.0)
            tr_dot = tl.sum(tr_block_input * grad, axis=0)
            dodx = dodx + tr_dot * (ceil_y - y)
            dody = dody - tr_dot * (x - floor_x)
            dodw = dodw + tr_dot * tr_weight
            tl.atomic_add(grad_input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, val=dods * tr_weight)
            # bottom left: d/dx = -(y - floor_y), d/dy = +(ceil_x - x)
            bl_block_offset = bl_base + block_offset
            bl_block_input = tl.load(input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, other=0.0)
            bl_dot = tl.sum(bl_block_input * grad, axis=0)
            dodx = dodx - bl_dot * (y - floor_y)
            dody = dody + bl_dot * (ceil_x - x)
            dodw = dodw + bl_dot * bl_weight
            tl.atomic_add(grad_input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, val=dods * bl_weight)
            # bottom right: d/dx = +(y - floor_y), d/dy = +(x - floor_x)
            br_block_offset = br_base + block_offset
            br_block_input = tl.load(input_ptr + br_block_offset, mask=br_block_mask & block_mask, other=0.0)
            br_dot = tl.sum(br_block_input * grad, axis=0)
            dodx = dodx + br_dot * (y - floor_y)
            dody = dody + br_dot * (x - floor_x)
            dodw = dodw + br_dot * br_weight
            tl.atomic_add(grad_input_ptr + br_block_offset, mask=br_block_mask & block_mask, val=dods * br_weight)
        # chain rule through the per-point weight
        dodx = dodx * weight
        dody = dody * weight
        tl.store(grad_weights_ptr + weights_offset + tl.arange(0, 1), dodw, mask=id_mask)
        tl.store(grad_deformable_ptr + deformable_offset + tl.arange(0, 1), dodx, mask=id_mask)
        tl.store(grad_deformable_ptr + deformable_offset + 1 + tl.arange(0, 1), dody, mask=id_mask)
class DCNFunction(Function):
    """Autograd wrapper around the Triton deformable bilinear-sampling kernels.

    Expected shapes (the kernels index them as dense row-major memory):
        inputs:      [B, H, W, G, C]
        deformables: [B, H, W, G, K, 2]  per-point (dx, dy) offsets in pixels
        weights:     [B, H, W, G, K]
    The result has the shape of ``inputs``.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx: Any, inputs, deformables, weights) -> Any:
        """Sample ``inputs`` at K offset positions per location and sum with weights.

        Raises:
            ValueError: if the three operands' shapes are inconsistent.
        """
        B, H, W, G, C = inputs.shape
        _, _, _, _, K, _ = deformables.shape
        if tuple(deformables.shape) != (B, H, W, G, K, 2):
            raise ValueError(
                f"deformables must have shape {(B, H, W, G, K, 2)}, "
                f"got {tuple(deformables.shape)}"
            )
        if tuple(weights.shape) != (B, H, W, G, K):
            raise ValueError(
                f"weights must have shape {(B, H, W, G, K)}, "
                f"got {tuple(weights.shape)}"
            )
        # The kernels compute flat row-major offsets, so every operand must be
        # contiguous; otherwise results are silently wrong.
        inputs = inputs.contiguous()
        deformables = deformables.contiguous()
        weights = weights.contiguous()
        out = torch.zeros_like(inputs)
        grid = lambda META: (B * H * W * G,)
        forward_kernel[grid](B, H, W, G, C, K, inputs, deformables, weights, out)
        # save the contiguous copies so backward sees the same layout
        ctx.save_for_backward(inputs, deformables, weights)
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        """Return gradients for (inputs, deformables, weights)."""
        grad_output = grad_outputs[0].contiguous()
        inputs, deformables, weights = ctx.saved_tensors
        B, H, W, G, C = inputs.shape
        _, _, _, _, K, _ = deformables.shape
        # grad_inputs is accumulated with atomic adds, so it must start at zero.
        grad_inputs = torch.zeros_like(inputs)
        grad_deformables = torch.zeros_like(deformables)
        grad_weights = torch.zeros_like(weights)
        grid = lambda META: (B * H * W * G,)
        backward_kernel[grid](
            B, H, W, G, C, K,
            inputs,
            deformables,
            weights,
            grad_output,
            grad_inputs,
            grad_deformables,
            grad_weights,
        )
        return (grad_inputs, grad_deformables, grad_weights)
class MultiScaleDCN(nn.Module):
    """Deformable aggregation layer over channels-last feature maps.

    For every location and each of ``groups`` groups, ``kernels`` sampling
    offsets and weights are predicted from the input by linear layers. The
    ``channels``-wide value projection is bilinearly sampled at those offsets
    by ``DCNFunction``, and the result is projected back to ``in_channels``.

    Args:
        in_channels: size of the input's last (feature) axis.
        groups: number of sampling groups G.
        channels: channels per group C (value projection has G * C outputs).
        kernels: sampling points per group K. The first isqrt(K)**2 points get
            a fixed grid prior spanning [-1, 1]; any remaining points keep the
            random ``torch.randn`` prior.
        deformable_biass: unused; kept for signature compatibility.
        max_scale: multiplier applied after the sigmoid offset scale (default 6).
    """

    def __init__(self, in_channels, groups, channels, kernels, deformable_biass=True, max_scale=6):
        super().__init__()
        self.in_channels = in_channels
        self.groups = groups
        self.channels = channels
        self.kernels = kernels
        self.v = nn.Linear(in_channels, groups * channels, bias=True)
        self.qk_deformables = nn.Linear(in_channels, groups * kernels * 2, bias=True)
        self.qk_scales = nn.Linear(in_channels, groups * kernels, bias=False)
        self.qk_weights = nn.Linear(in_channels, groups * kernels, bias=True)
        self.out = nn.Linear(groups * channels, in_channels)
        # fixed (non-trainable) per-point offset prior, shape (1, 1, 1, 1, K, 2)
        self.deformables_prior = nn.Parameter(torch.randn((1, 1, 1, 1, kernels, 2)), requires_grad=False)
        # trainable per-group pre-sigmoid offset scale, shape (1, 1, 1, G, 1, 1)
        self.deformables_scale = nn.Parameter(torch.ones((1, 1, 1, groups, 1, 1)), requires_grad=True)
        self.max_scale = max_scale
        self._init_weights()

    def _init_weights(self):
        """Zero the offset/scale predictors and seed the priors."""
        zeros_(self.qk_deformables.weight.data)
        zeros_(self.qk_scales.weight.data)
        zeros_(self.qk_deformables.bias.data)
        zeros_(self.qk_weights.weight.data)
        zeros_(self.v.bias.data)
        zeros_(self.out.bias.data)
        num_prior = math.isqrt(self.kernels)
        # Build the grid on the parameter's own device rather than forcing CUDA,
        # so the module can be constructed on CPU-only hosts.
        device = self.deformables_prior.device
        dx = torch.linspace(-1, 1, num_prior, device=device)
        dy = torch.linspace(-1, 1, num_prior, device=device)
        grid_x, grid_y = torch.meshgrid([dx, dy], indexing="xy")
        # row-major (y outer, x inner) list of (x, y) grid points
        dxy = torch.stack((grid_x, grid_y), dim=-1).view(-1, 2)
        self.deformables_prior.data[..., :num_prior * num_prior, :] = dxy
        for i in range(self.groups):
            scale = (i + 1) / self.groups - 0.0001
            # logit, so sigmoid(deformables_scale[i]) starts at (i + 1) / G - 1e-4
            inv_scale = math.log(scale / (1 - scale))
            self.deformables_scale.data[..., i, :, :] = inv_scale

    def forward(self, x):
        """Apply deformable aggregation to ``x`` of shape [B, H, W, in_channels].

        Returns a tensor of shape [B, H, W, in_channels].
        """
        B, H, W, _ = x.shape
        v = self.v(x).view(B, H, W, self.groups, self.channels)
        deformables = self.qk_deformables(x).view(B, H, W, self.groups, self.kernels, 2)
        scale = self.qk_scales(x).view(B, H, W, self.groups, self.kernels, 1) + self.deformables_scale
        # offsets in pixels: (predicted + prior) * sigmoid(scale) * max_scale
        deformables = (deformables + self.deformables_prior) * scale.sigmoid() * self.max_scale
        weights = self.qk_weights(x).view(B, H, W, self.groups, self.kernels)
        out = DCNFunction.apply(v, deformables, weights)
        out = out.view(B, H, W, -1)
        return self.out(out)