import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from typing import Any
from torch.autograd import Function
from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd
import triton
import triton.language as tl
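
# Layout notes (from the shapes used throughout this file): all tensors are
# channels-last; inputs/outputs are [B, H, W, G, C], offsets are
# [B, H, W, G, K, 2] (an (x, y) pair per sampling point), and weights are
# [B, H, W, G, K]. Each Triton program handles one (batch, row, column, group)
# location and iterates over the C channels of its group in BLOCK_SIZE chunks.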
@triton.jit
def forward_kernel(
    B: tl.constexpr,
    H: tl.constexpr,  # image_size_h
    W: tl.constexpr,  # image_size_w
    G: tl.constexpr,  # num_groups
    C: tl.constexpr,  # num_channels_per_group
    K: tl.constexpr,  # kernel size
    input_ptr,  # input features [B, H, W, G, C]
    deformable_ptr,  # deformable offsets [B, H, W, G, K, 2]
    weights_ptr,  # weights [B, H, W, G, K]
    out_ptr,  # out [B, H, W, G, C]
    BLOCK_SIZE: tl.constexpr,  # number of channels processed per inner iteration
):
    pid = tl.program_id(0)
    # decode the flat program id into (batch, row, column, group) indices
    wid = pid % W
    hid = pid // W % H
    gid = pid // (W * H) % G
    bid = pid // (W * H * G)
    id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
    common_offset = bid*H*W*G + hid*W*G + wid*G + gid
    batch_base = bid * H * W * G * C
    for block_base in tl.static_range(0, C, BLOCK_SIZE):
        buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
        block_offset = tl.arange(0, BLOCK_SIZE) + block_base
        block_mask = (block_offset < C) & id_mask
        for k in tl.static_range(K):
            deformable_offset = (common_offset * K + k) * 2
            x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
            y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
            # use a true floor: casting to int32 truncates toward zero, which
            # mis-handles sampling coordinates in (-1, 0)
            floor_x = tl.floor(x).to(tl.int32)
            floor_y = tl.floor(y).to(tl.int32)
            ceil_x = floor_x + 1
            ceil_y = floor_y + 1
            # top-left corner: bilinear weight, gather offset, bounds mask
            tl_weight = (ceil_x - x) * (ceil_y - y)
            tl_block_offset = batch_base + floor_y * W * G * C + floor_x * G * C + gid * C
            tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
            # top-right corner
            tr_weight = (x - floor_x) * (ceil_y - y)
            tr_block_offset = batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C
            tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
            # bottom-left corner
            bl_weight = (ceil_x - x) * (y - floor_y)
            bl_block_offset = batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C
            bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
            # bottom-right corner
            br_weight = (x - floor_x) * (y - floor_y)
            br_block_offset = batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C
            br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
            # load dynamic weight
            weights_offset = common_offset*K + k
            weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
            # load top left
            tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0)
            tl_block_input = tl_block_input * tl_weight
            # load top right
            tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0)
            tr_block_input = tr_block_input * tr_weight
            # load bottom left
            bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0)
            bl_block_input = bl_block_input * bl_weight
            # load bottom right
            br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0)
            br_block_input = br_block_input * br_weight
            # bilinearly sampled value, weighted by the dynamic kernel weight
            sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input
            weighted_sampled_input = sampled_input * weight
            buffer = buffer + weighted_sampled_input
        # store the accumulated channels for this (b, h, w, g) location
        tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask)
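
# The backward kernel recomputes the forward sampling geometry and applies the
# chain rule: d out / d input is scattered with atomic adds into the four
# sampled corners, d out / d offset uses the analytic derivative of the
# bilinear weights, and d out / d weight is the sampled value itself.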
@triton.jit
def backward_kernel(
    B: tl.constexpr,
    H: tl.constexpr,  # image_size_h
    W: tl.constexpr,  # image_size_w
    G: tl.constexpr,  # num_groups
    C: tl.constexpr,  # num_channels_per_group
    K: tl.constexpr,  # kernel size
    input_ptr,  # input features [B, H, W, G, C]
    deformable_ptr,  # deformable offsets [B, H, W, G, K, 2]
    weights_ptr,  # weights [B, H, W, G, K]
    grad_ptr,  # incoming gradient d loss / d out [B, H, W, G, C]
    grad_input_ptr,  # gradient w.r.t. input features [B, H, W, G, C]
    grad_deformable_ptr,  # gradient w.r.t. deformable offsets [B, H, W, G, K, 2]
    grad_weights_ptr,  # gradient w.r.t. weights [B, H, W, G, K]
    BLOCK_SIZE: tl.constexpr,  # number of channels processed per inner iteration
):
    pid = tl.program_id(0)
    wid = pid % W
    hid = pid // W % H
    gid = pid // (W * H) % G
    bid = pid // (W * H * G)
    id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
    common_offset = bid*H*W*G + hid*W*G + wid*G + gid
    batch_base = bid * H * W * G * C
    for k in tl.static_range(K):
        # load dynamic weight and zero the accumulators for this sampling point
        weights_offset = common_offset*K + k
        weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
        dodx = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
        dody = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
        dodw = tl.zeros((1,), dtype=grad_weights_ptr.type.element_ty)
        deformable_offset = (common_offset * K + k)*2
        x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
        y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
        for block_base in tl.static_range(0, C, BLOCK_SIZE):
            block_offset = tl.arange(0, BLOCK_SIZE) + block_base
            block_mask = (block_offset < C) & id_mask
            grad = tl.load(grad_ptr+common_offset*C + block_offset, mask=block_mask, other=0.0)
            # d out / d sampled_value is the dynamic weight
            dods = weight*grad
            # true floor, matching the forward kernel
            floor_x = tl.floor(x).to(tl.int32)
            floor_y = tl.floor(y).to(tl.int32)
            ceil_x = floor_x + 1
            ceil_y = floor_y + 1
            # load top left
            tl_weight = (ceil_x - x) * (ceil_y - y)
            tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) + block_offset
            tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
            tl_block_input = tl.load(input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, other=0.0)
            tl_block_input_dot_grad = tl.sum(tl_block_input*grad, axis=0)
            dodx = dodx + -1 * tl_block_input_dot_grad * (ceil_y - y)
            dody = dody + -1 * tl_block_input_dot_grad * (ceil_x - x)
            dodw = dodw + tl_block_input_dot_grad * tl_weight
            dodtl = dods * tl_weight
            tl.atomic_add(grad_input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, val=dodtl)
            # load top right
            tr_weight = (x - floor_x) * (ceil_y - y)
            tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
            tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
            tr_block_input = tl.load(input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, other=0.0)
            tr_block_input_dot_grad = tl.sum(tr_block_input*grad, axis=0)
            dodx = dodx + 1 * tr_block_input_dot_grad * (ceil_y - y)
            dody = dody + -1 * tr_block_input_dot_grad * (x - floor_x)
            dodw = dodw + tr_block_input_dot_grad*tr_weight
            dodtr = dods * tr_weight
            tl.atomic_add(grad_input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, val=dodtr)
            # load bottom left
            bl_weight = (ceil_x - x) * (y - floor_y)
            bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) + block_offset
            bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
            bl_block_input = tl.load(input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, other=0.0)
            bl_block_input_dot_grad = tl.sum(bl_block_input*grad, axis=0)
            dodx = dodx + -1 * bl_block_input_dot_grad * (y - floor_y)
            dody = dody + 1 * bl_block_input_dot_grad * (ceil_x - x)
            dodw = dodw + bl_block_input_dot_grad*bl_weight
            dodbl = dods * bl_weight
            tl.atomic_add(grad_input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, val=dodbl)
            # load bottom right
            br_weight = (x - floor_x) * (y - floor_y)
            br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
            br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
            br_block_input = tl.load(input_ptr + br_block_offset, mask=br_block_mask & block_mask, other=0.0)
            # the masked load already zeroes out-of-bounds taps
            br_block_input_dot_grad = tl.sum(br_block_input*grad, axis=0)
            dodx = dodx + 1 * br_block_input_dot_grad * (y - floor_y)
            dody = dody + 1 * br_block_input_dot_grad * (x - floor_x)
            dodw = dodw + br_block_input_dot_grad*br_weight
            dodbr = dods * br_weight
            tl.atomic_add(grad_input_ptr + br_block_offset, mask=br_block_mask & block_mask, val=dodbr)
        # chain rule: out = sum_k weight_k * sampled_k, so the offset gradients
        # are scaled by the dynamic weight
        dodx = dodx * weight
        dody = dody * weight
        tl.store(grad_weights_ptr + weights_offset + tl.arange(0, 1), dodw, mask=id_mask)
        tl.store(grad_deformable_ptr + deformable_offset + tl.arange(0, 1), dodx, mask=id_mask)
        tl.store(grad_deformable_ptr + deformable_offset + 1 + tl.arange(0, 1), dody, mask=id_mask)

class DCNFunction(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # kernels accumulate in fp32; cast under autocast
    def forward(ctx: Any, inputs, deformables, weights) -> Any:
        B, H, W, G, C = inputs.shape
        _, _, _, _, K, _ = deformables.shape
        out = torch.zeros_like(inputs)
        grid = lambda META: (B * H * W * G,)
        # BLOCK_SIZE must be a power of two for tl.arange; 64 is masked down
        # when C < 64 and looped over when C > 64
        forward_kernel[grid](B, H, W, G, C, K, inputs, deformables, weights, out, BLOCK_SIZE=64)
        ctx.save_for_backward(inputs, deformables, weights)
        return out
    @staticmethod
    @custom_bwd
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        grad_output = grad_outputs[0].contiguous()
        inputs, deformables, weights = ctx.saved_tensors
        B, H, W, G, C = inputs.shape
        _, _, _, _, K, _ = deformables.shape
        grad_inputs = torch.zeros_like(inputs)
        grad_deformables = torch.zeros_like(deformables)
        grad_weights = torch.zeros_like(weights)
        grid = lambda META: (B * H * W * G,)
        backward_kernel[grid](
            B, H, W, G, C, K,
            inputs,
            deformables,
            weights,
            grad_output,
            grad_inputs,
            grad_deformables,
            grad_weights,
            BLOCK_SIZE=64,
        )
        return (grad_inputs, grad_deformables, grad_weights)
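
# MultiScaleDCN predicts, per pixel and per group: sampling offsets
# (qk_deformables), a log-scale for those offsets (qk_scales), and aggregation
# weights (qk_weights). deformables_prior seeds a fixed regular grid of
# sampling points, and deformables_scale initializes each group to a different
# receptive-field size, which is what makes the operator multi-scale.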
class MultiScaleDCN(nn.Module):
    def __init__(self, in_channels, groups, channels, kernels, deformable_bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.groups = groups
        self.channels = channels
        self.kernels = kernels
        self.v = nn.Linear(in_channels, groups * channels, bias=True)
        self.qk_deformables = nn.Linear(in_channels, groups * kernels * 2, bias=deformable_bias)
        self.qk_scales = nn.Linear(in_channels, groups * kernels, bias=False)
        self.qk_weights = nn.Linear(in_channels, groups*kernels, bias=True)
        self.out = nn.Linear(groups * channels, in_channels)
        self.deformables_prior = nn.Parameter(torch.randn((1, 1, 1, 1, kernels, 2)), requires_grad=False)
        self.deformables_scale = nn.Parameter(torch.ones((1, 1, 1, groups, 1, 1)), requires_grad=True)
        self.max_scale = 6
        self._init_weights()
    def _init_weights(self):
        # zero-init so offsets start at the prior grid and weights start uniform
        zeros_(self.qk_deformables.weight.data)
        zeros_(self.qk_scales.weight.data)
        if self.qk_deformables.bias is not None:
            zeros_(self.qk_deformables.bias.data)
        zeros_(self.qk_weights.weight.data)
        zeros_(self.v.bias.data)
        zeros_(self.out.bias.data)
        # fill the prior with a regular sqrt(K) x sqrt(K) grid in [-1, 1];
        # any remaining points keep their random init. Built on CPU, since the
        # module has not been moved to a device yet at construction time.
        num_prior = int(self.kernels ** 0.5)
        dx = torch.linspace(-1, 1, num_prior)
        dy = torch.linspace(-1, 1, num_prior)
        dxy = torch.meshgrid([dx, dy], indexing="xy")
        dxy = torch.stack(dxy, dim=-1)
        dxy = dxy.view(-1, 2)
        self.deformables_prior.data[..., :num_prior*num_prior, :] = dxy
        # give each group a different initial scale via the inverse sigmoid
        # (logit); the epsilon keeps the logit finite for the last group
        for i in range(self.groups):
            scale = (i+1)/self.groups - 0.0001
            inv_scale = math.log(scale/(1-scale))
            self.deformables_scale.data[..., i, :, :] = inv_scale
    def forward(self, x):
        B, H, W, _ = x.shape  # channels-last input [B, H, W, in_channels]
        v = self.v(x).view(B, H, W, self.groups, self.channels)
        deformables = self.qk_deformables(x).view(B, H, W, self.groups, self.kernels, 2)
        scale = self.qk_scales(x).view(B, H, W, self.groups, self.kernels, 1) + self.deformables_scale
        deformables = (deformables + self.deformables_prior) * scale.sigmoid() * self.max_scale
        weights = self.qk_weights(x).view(B, H, W, self.groups, self.kernels)
        out = DCNFunction.apply(v, deformables, weights)
        out = out.view(B, H, W, -1)
        out = self.out(out)
        return out
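
# A minimal smoke-test sketch, not part of the module itself: the shapes and
# hyper-parameters below are illustrative assumptions, and a CUDA device with
# Triton installed is required since the kernels only run on GPU.
if __name__ == "__main__":
    torch.manual_seed(0)
    # kernels=9 gives a full 3x3 prior grid in _init_weights
    layer = MultiScaleDCN(in_channels=64, groups=4, channels=16, kernels=9).cuda()
    x = torch.randn(2, 32, 32, 64, device="cuda", requires_grad=True)  # [B, H, W, C]
    out = layer(x)
    assert out.shape == x.shape
    out.mean().backward()  # exercises backward_kernel through autograd
    print("output:", tuple(out.shape), "input grad:", tuple(x.grad.shape))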