File size: 13,969 Bytes
56238f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
import torch
import torch.nn as nn
import math

from torch.nn.init import zeros_
from typing import Any
from torch.autograd import Function
from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd

import triton
import triton.language as tl

# Forward pass of the deformable sampling operator.
#
# One program instance handles one (batch, row, column, group) location. For
# that location it accumulates, over the K sampling points, a bilinearly
# interpolated read of `input_ptr` at (column + dx, row + dy), scaled by the
# per-point weight, and writes the C channels of the result to `out_ptr`.
# All flat offsets below are computed by hand, so every tensor is assumed to be
# densely packed in the layouts noted on the parameters.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
        # triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1),
    ],
    key=['B', 'H', 'W', 'G', 'C', 'K'],
)
@triton.jit
def forward_kernel(
        B: tl.constexpr, # batch size
        H: tl.constexpr, # image_size_h
        W: tl.constexpr, # image_size_w
        G: tl.constexpr, # num_groups (indexed by gid)
        C: tl.constexpr, # num_channels_per_group (innermost, blocked by BLOCK_SIZE)
        K: tl.constexpr, # number of sampling points per location
        input_ptr,   # input features [B, H, W, G, C]
        deformable_ptr, # sampling offsets [B, H, W, G, K, 2]; last axis is (dx, dy)
        weights_ptr, # weights [B, H, W, G, K]
        out_ptr, # out [B, H, W, G, C]
        BLOCK_SIZE: tl.constexpr, # number of channels processed per inner block
):
    # Decode the 1-D program id into (bid, gid, hid, wid); wid varies fastest.
    pid = tl.program_id(0)
    wid = pid % W
    hid = pid // W % H
    gid = pid // (W * H) % G
    bid = pid // (W * H * G)

    # The modulo decoding keeps hid/wid/gid in range by construction; the bid
    # check guards against a launch grid larger than B*H*W*G.
    id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
    # Flat index of this location in a [B, H, W, G] tensor.
    common_offset = bid*H*W*G + hid*W*G + wid*G + gid
    # Flat offset of the first element of this batch item in [B, H, W, G, C].
    batch_base = bid * H * W * G * C

    # Walk the C channels in chunks of BLOCK_SIZE.
    for block_base in tl.static_range(0, C, BLOCK_SIZE):
        # float32 accumulator for this channel chunk
        buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
        block_offset = tl.arange(0, BLOCK_SIZE) + block_base
        # masks off the tail lanes when C is not a multiple of BLOCK_SIZE
        block_mask = (block_offset < C) & id_mask
        for k in tl.static_range(K):
            # (dx, dy) of sampling point k are stored as two consecutive scalars
            deformable_offset = (common_offset * K + k) * 2

            # absolute sampling position = own pixel index + learned offset
            x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
            y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid

            # Integer corners surrounding (x, y).
            # NOTE(review): a float -> int32 cast presumably truncates toward
            # zero rather than flooring. If so, for -1 < x < 0 (or y) the
            # "floor" corner is 0, not -1, and the interpolation weights below
            # fall outside [0, 1]. Confirm; any change must be mirrored in
            # backward_kernel so the gradients stay consistent.
            floor_x = x.to(tl.int32)
            floor_y = y.to(tl.int32)
            ceil_x = floor_x + 1
            ceil_y = floor_y + 1

            # top left corner: bilinear weight, flat offset, in-image mask
            tl_weight = (ceil_x - x) * (ceil_y - y)
            tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
            tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)

            # top right corner
            tr_weight = (x - floor_x) * (ceil_y - y)
            tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
            tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
            # bottom left corner
            bl_weight = (ceil_x - x) * (y - floor_y)
            bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
            bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
            # bottom right corner
            br_weight = (x - floor_x) * (y - floor_y)
            br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
            br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)

            # scalar weight of sampling point k at this location
            weights_offset = common_offset*K + k
            weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)



            # Gather the four corners; corners outside the image read as 0.
            tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0)
            tl_block_input = tl_block_input * tl_weight

            # top right
            tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0)
            tr_block_input = tr_block_input * tr_weight
            # bottom left
            bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0)
            bl_block_input = bl_block_input * bl_weight
            # bottom right
            br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0)
            br_block_input = br_block_input * br_weight

            # bilinearly interpolated feature at (x, y)
            sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input

            # accumulate the weighted sample over the K sampling points
            weighted_sampled_input = sampled_input * weight
            buffer = buffer + weighted_sampled_input
        # write this channel chunk of the output for the current location
        tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask)

# Backward pass of the deformable sampling operator.
#
# One program instance handles one (batch, row, column, group) location. For
# each of its K sampling points it
#   * scatters weight * grad * bilinear_weight into `grad_input_ptr` at the
#     four surrounding corners using atomic adds (different locations may hit
#     the same input pixel),
#   * accumulates the gradient w.r.t. the sampling offset (dx, dy) and w.r.t.
#     the per-point weight, and stores them with plain stores (each program
#     writes only its own [b, h, w, g, k] slots).
# The corner/weight arithmetic mirrors forward_kernel and must stay in sync.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=1),
        triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
    ],
    key=['B', 'H', 'W', 'G', 'C', 'K'],
)
@triton.jit
def backward_kernel(
        B: tl.constexpr, # batch size
        H: tl.constexpr, # image_size_h
        W: tl.constexpr, # image_size_w
        G: tl.constexpr, # num_groups
        C: tl.constexpr, # num_channels_per_group
        K: tl.constexpr, # number of sampling points per location
        input_ptr,   # input features [B, H, W, G, C]
        deformable_ptr, # deformable offsets [B, H, W, G, K, 2]
        weights_ptr, # weights [B, H, W, G, K]
        grad_ptr, # gradient w.r.t. the forward output [B, H, W, G, C]
        grad_input_ptr,   # gradient w.r.t. input [B, H, W, G, C]; accumulated atomically
        grad_deformable_ptr, # gradient w.r.t. offsets [B, H, W, G, K, 2]
        grad_weights_ptr, # gradient w.r.t. weights [B, H, W, G, K]
        BLOCK_SIZE: tl.constexpr, # number of channels processed per inner block
):

    # Decode the 1-D program id into (bid, gid, hid, wid); wid varies fastest.
    pid = tl.program_id(0)
    wid = pid % W
    hid = pid // W % H
    gid = pid // (W * H) % G
    bid = pid // (W * H * G)

    # Only the bid check can fail (oversized grid); the others hold by construction.
    id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)

    # Flat index of this location in a [B, H, W, G] tensor.
    common_offset = bid*H*W*G + hid*W*G + wid*G + gid
    # Flat offset of the first element of this batch item in [B, H, W, G, C].
    batch_base = bid * H * W * G * C
    for k in tl.static_range(K):
        # scalar weight of sampling point k at this location
        weights_offset = common_offset*K + k
        weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
        # Single-element accumulators for d(out)/dx, d(out)/dy and d(out)/d(weight),
        # typed after the buffers they are eventually stored into.
        dodx = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
        dody = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
        dodw = tl.zeros((1,), dtype=grad_weights_ptr.type.element_ty)
        # (dx, dy) of sampling point k are stored as two consecutive scalars
        deformable_offset = (common_offset * K + k)*2
        # absolute sampling position = own pixel index + learned offset
        x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
        y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
        # Walk the C channels in chunks of BLOCK_SIZE.
        for block_base in tl.static_range(0, C, BLOCK_SIZE):
            block_offset = tl.arange(0, BLOCK_SIZE) + block_base
            block_mask = (block_offset < C) & id_mask
            # upstream gradient for this location and channel chunk
            grad = tl.load(grad_ptr+common_offset*C + block_offset, mask=block_mask, other=0.0)
            # gradient w.r.t. the interpolated sample (before bilinear weighting)
            dods = weight*grad

            # Integer corners around (x, y). These do not depend on block_base
            # and are recomputed on every channel chunk.
            # NOTE(review): same float -> int32 cast as forward_kernel; it
            # presumably truncates toward zero instead of flooring, which
            # matters for -1 < x < 0 (or y). Keep both kernels consistent.
            floor_x = x.to(tl.int32)
            floor_y = y.to(tl.int32)
            ceil_x = floor_x + 1
            ceil_y = floor_y + 1

            # top left corner: weight (ceil_x - x) * (ceil_y - y)
            tl_weight = (ceil_x - x) * (ceil_y - y)
            tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) + block_offset
            tl_block_mask = ((floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H))
            tl_block_input = tl.load(input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, other=0.0)
            # <input, grad> reduced over the channel chunk
            tl_block_input_dot_grad = tl.sum(tl_block_input*grad, axis=0)
            # d(tl_weight)/dx = -(ceil_y - y), d(tl_weight)/dy = -(ceil_x - x)
            dodx = dodx + -1 * tl_block_input_dot_grad * (ceil_y - y)
            dody = dody + -1 * tl_block_input_dot_grad * (ceil_x - x)
            dodw = dodw + tl_block_input_dot_grad * tl_weight

            # scatter the input gradient to this corner
            dodtl = dods * tl_weight
            tl.atomic_add(grad_input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, val=dodtl)


            # top right corner: weight (x - floor_x) * (ceil_y - y)
            tr_weight = (x - floor_x) * (ceil_y - y)
            tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
            tr_block_mask = ((floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0))
            tr_block_input = tl.load(input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, other=0.0)
            tr_block_input_dot_grad = tl.sum(tr_block_input*grad, axis=0)
            # d(tr_weight)/dx = +(ceil_y - y), d(tr_weight)/dy = -(x - floor_x)
            dodx = dodx + 1 * tr_block_input_dot_grad * (ceil_y - y)
            dody = dody + -1 * tr_block_input_dot_grad * (x - floor_x)
            dodw = dodw + tr_block_input_dot_grad*tr_weight

            dodtr = dods * tr_weight
            tl.atomic_add(grad_input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, val=dodtr)


            # bottom left corner: weight (ceil_x - x) * (y - floor_y)
            bl_weight = (ceil_x - x) * (y - floor_y)
            bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) + block_offset
            bl_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0))
            bl_block_input = tl.load(input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, other=0.0)
            bl_block_input_dot_grad = tl.sum(bl_block_input*grad, axis=0)
            # d(bl_weight)/dx = -(y - floor_y), d(bl_weight)/dy = +(ceil_x - x)
            dodx = dodx + -1 * bl_block_input_dot_grad * (y - floor_y)
            dody = dody + 1 * bl_block_input_dot_grad * (ceil_x - x)
            dodw = dodw + bl_block_input_dot_grad*bl_weight

            dodbl = dods * bl_weight
            tl.atomic_add(grad_input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, val=dodbl)


            # bottom right corner: weight (x - floor_x) * (y - floor_y)
            br_weight = (x - floor_x) * (y - floor_y)
            br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
            br_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0))
            br_block_input = tl.load(input_ptr + br_block_offset, mask=br_block_mask & block_mask, other=0.0)
            # NOTE(review): only this corner multiplies the reduction by its
            # mask; the masked load above already yields 0 for out-of-image
            # reads, so this looks redundant -- confirm before removing.
            br_block_input_dot_grad = tl.sum(br_block_input*grad, axis=0)*br_block_mask

            # d(br_weight)/dx = +(y - floor_y), d(br_weight)/dy = +(x - floor_x)
            dodx = dodx + 1 * br_block_input_dot_grad * (y - floor_y)
            dody = dody + 1 * br_block_input_dot_grad * (x - floor_x)
            dodw = dodw + br_block_input_dot_grad*br_weight

            dodbr = dods * br_weight
            tl.atomic_add(grad_input_ptr + br_block_offset, mask=br_block_mask  & block_mask, val=dodbr)
        # The offset gradients were accumulated without the per-point weight;
        # apply it once after the channel loop.
        dodx = dodx * weight
        dody = dody * weight
        # Plain stores: each (location, k) slot is written by exactly one program.
        tl.store(grad_weights_ptr + weights_offset + tl.arange(0, 1), dodw, mask=id_mask)
        tl.store(grad_deformable_ptr + deformable_offset + tl.arange(0, 1), dodx, mask=id_mask)
        tl.store(grad_deformable_ptr + deformable_offset + 1 + tl.arange(0, 1), dody, mask=id_mask)


class DCNFunction(Function):
    """Autograd wrapper around the Triton deformable-sampling kernels.

    ``forward`` launches ``forward_kernel`` and ``backward`` launches
    ``backward_kernel``, each with one program per (batch, row, column, group)
    location.

    Expected shapes (taken from the unpacking below and the kernel signatures):
        inputs:      [B, H, W, G, C]
        deformables: [B, H, W, G, K, 2]
        weights:     [B, H, W, G, K]
        output:      [B, H, W, G, C]

    The kernels compute flat element offsets by hand, so every tensor handed to
    them must be densely packed in row-major order. Both methods therefore make
    their operands contiguous and allocate result buffers with an explicit
    contiguous memory format.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx: Any, inputs, deformables, weights) -> Any:
        """Sample ``inputs`` at the deformed locations and return the weighted sum.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            inputs: feature tensor of shape [B, H, W, G, C].
            deformables: sampling offsets of shape [B, H, W, G, K, 2].
            weights: per-sampling-point weights of shape [B, H, W, G, K].

        Returns:
            Tensor with the same shape and dtype as ``inputs``.
        """
        # A transposed / sliced / expanded operand would be read with the wrong
        # strides by the kernel; .contiguous() is a no-op when already dense.
        inputs = inputs.contiguous()
        deformables = deformables.contiguous()
        weights = weights.contiguous()

        B, H, W, G, C = inputs.shape
        _, _, _, _, K, _ = deformables.shape

        out = torch.zeros_like(inputs, memory_format=torch.contiguous_format)
        grid = (B * H * W * G,)
        forward_kernel[grid](B, H, W, G, C, K, inputs, deformables, weights, out)

        # Save the dense versions so backward reads the same layout.
        ctx.save_for_backward(inputs, deformables, weights)
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        """Compute gradients w.r.t. ``inputs``, ``deformables`` and ``weights``.

        Args:
            ctx: autograd context populated by ``forward``.
            *grad_outputs: upstream gradients; only the first one (gradient of
                the single forward output) is used.

        Returns:
            Tuple ``(grad_inputs, grad_deformables, grad_weights)`` matching the
            shapes of the corresponding forward arguments.
        """
        grad_output = grad_outputs[0].contiguous()
        inputs, deformables, weights = ctx.saved_tensors
        B, H, W, G, C = inputs.shape
        _, _, _, _, K, _ = deformables.shape

        # grad_inputs must start at zero: the kernel accumulates into it with
        # atomic adds. The other two buffers are fully overwritten.
        grad_inputs = torch.zeros_like(inputs, memory_format=torch.contiguous_format)
        grad_deformables = torch.zeros_like(deformables, memory_format=torch.contiguous_format)
        grad_weights = torch.zeros_like(weights, memory_format=torch.contiguous_format)

        grid = (B * H * W * G,)
        backward_kernel[grid](
            B, H, W, G, C, K,
            inputs,
            deformables,
            weights,
            grad_output,
            grad_inputs,
            grad_deformables,
            grad_weights,
        )
        return (grad_inputs, grad_deformables, grad_weights)


class MultiScaleDCN(nn.Module):
    """Multi-scale deformable sampling layer built on ``DCNFunction``.

    For a channels-last input ``x`` of shape [B, H, W, in_channels] the layer
    projects ``x`` to per-group values, per-group sampling offsets, per-point
    offset scales and per-point weights, samples the values through
    ``DCNFunction`` and projects the result back to ``in_channels``.

    Args:
        in_channels: size of the last dimension of the input and output.
        groups: number of groups; each group has its own initial offset scale.
        channels: number of value channels per group.
        kernels: number of sampling points per location and group.
        deformable_biass: accepted for backward compatibility; it is not read
            anywhere in this class.
    """

    def __init__(self, in_channels, groups, channels, kernels, deformable_biass=True):
        super().__init__()
        self.in_channels = in_channels
        self.groups = groups
        self.channels = channels
        self.kernels = kernels
        self.v = nn.Linear(in_channels, groups * channels, bias=True)
        self.qk_deformables = nn.Linear(in_channels, groups * kernels * 2, bias=True)
        self.qk_scales = nn.Linear(in_channels, groups * kernels, bias=False)
        self.qk_weights = nn.Linear(in_channels, groups*kernels, bias=True)
        self.out = nn.Linear(groups * channels, in_channels)
        # Fixed (non-trainable) base offsets shared by all locations and groups.
        self.deformables_prior = nn.Parameter(torch.randn((1, 1, 1, 1, kernels, 2)), requires_grad=False)
        # Trainable per-group logit of the offset scale (passed through a sigmoid).
        self.deformables_scale = nn.Parameter(torch.ones((1, 1, 1, groups, 1, 1)), requires_grad=True)
        # Upper bound of the offset scale: sigmoid(...) * max_scale.
        self.max_scale = 6
        self._init_weights()

    def _init_weights(self):
        """Initialise projections, the offset prior and the per-group scales.

        * Offset, scale and weight projections start with zero weights (and the
          offset projection with zero bias), so the initial offsets are just
          the scaled prior.
        * The first ``isqrt(kernels) ** 2`` prior points form a regular grid on
          [-1, 1] x [-1, 1]; any remaining points keep their random
          initialisation from ``__init__``.
        * Group ``i`` gets a scale logit whose sigmoid equals
          ``(i + 1) / groups - 0.0001``; the small shift keeps the logit of the
          last group finite.
        """
        with torch.no_grad():
            for tensor in (
                self.qk_deformables.weight,
                self.qk_scales.weight,
                self.qk_deformables.bias,
                self.qk_weights.weight,
                self.v.bias,
                self.out.bias,
            ):
                zeros_(tensor)

            prior = self.deformables_prior
            side = math.isqrt(self.kernels)
            # Build the grid where the parameter lives instead of hard-coding
            # "cuda", so the module can be constructed on CPU-only hosts.
            axis = torch.linspace(-1, 1, side, device=prior.device, dtype=prior.dtype)
            grid_x, grid_y = torch.meshgrid(axis, axis, indexing="xy")
            grid_xy = torch.stack((grid_x, grid_y), dim=-1).reshape(-1, 2)
            prior[..., :side * side, :] = grid_xy

            for i in range(self.groups):
                scale = (i + 1) / self.groups - 0.0001
                inv_scale = math.log(scale / (1 - scale))  # logit(scale)
                self.deformables_scale[..., i, :, :] = inv_scale

    def forward(self, x):
        """Apply the layer to ``x`` of shape [B, H, W, in_channels].

        Returns a tensor of shape [B, H, W, in_channels].
        """
        B, H, W, _ = x.shape
        v = self.v(x).view(B, H, W, self.groups, self.channels)
        deformables = self.qk_deformables(x).view(B, H, W, self.groups, self.kernels, 2)
        # Per-point scale logit = predicted logit + learned per-group logit.
        scale = self.qk_scales(x).view(B, H, W, self.groups, self.kernels, 1) + self.deformables_scale
        # Offsets = (predicted + prior), scaled into (0, max_scale).
        deformables = (deformables + self.deformables_prior) * scale.sigmoid() * self.max_scale
        weights = self.qk_weights(x).view(B, H, W, self.groups, self.kernels)
        out = DCNFunction.apply(v, deformables, weights)
        out = out.view(B, H, W, -1)
        out = self.out(out)
        return out