Commit e205cb0

Author: Pedro Freire

Add Deformable Convolution operation.

This adds the deformable convolution operation, as described in Deformable Convolutional Networks (https://arxiv.org/abs/1703.06211).

- The code is based on https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp; it was refactored to remove redundancies, increase clarity, and adapt it to torchvision.
- The CPU part is a direct copy of the CUDA code; follow-up work could simplify or optimize the CPU path, or share functionality between CPU and CUDA.
- We also add tests (with a non-trivial set of parameters); they could be made more robust by randomizing the parameters and running them multiple times.
1 parent 4b2f8da commit e205cb0

File tree: 9 files changed (+1777, -39 lines)

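For orientation, here is a minimal usage sketch of the operation this commit adds, mirroring the shapes and parameters exercised by the new tests below. This is an illustration, not part of the commit; ops.deform_conv and its keyword arguments are the API as of this commit.

import torch
from torchvision import ops

# Shapes follow the new tests: the offset tensor holds one (dy, dx) pair per
# kernel location and offset group, at every output position.
x = torch.rand(1, 6, 5, 4)                    # (batch, in_channels, in_h, in_w)
weight = torch.randn(2, 3, 3, 2)              # (out_channels, in_channels // n_weight_grps, kern_h, kern_w)
offset = torch.randn(1, 3 * 2 * 3 * 2, 2, 3)  # (batch, n_offset_grps * 2 * kern_h * kern_w, out_h, out_w)

out = ops.deform_conv(x, offset, weight, stride=(2, 1), pad=(1, 0), dilation=(2, 1))
print(out.shape)  # torch.Size([1, 2, 2, 3])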

test/test_ops.py: 143 additions & 36 deletions
@@ -1,15 +1,18 @@
 from __future__ import division
+import math
+from typing import Tuple
+import unittest
+
 import numpy as np
+
 import torch
 from torch.autograd import gradcheck
-
+from torch.nn.modules.utils import _pair
+from torch import Tensor
 from torchvision import ops
 
-from itertools import product
-import unittest
-
 
-class RoIOpTester(object):
+class OpTester(object):
     @classmethod
     def setUpClass(cls):
         cls.dtype = torch.float64
@@ -42,6 +45,14 @@ def test_backward_cuda_contiguous(self):
     def test_backward_cuda_non_contiguous(self):
         self._test_backward(device=torch.device('cuda'), contiguous=False)
 
+    def _test_forward(self, device, contiguous):
+        pass
+
+    def _test_backward(self, device, contiguous):
+        pass
+
+
+class RoIOpTester(OpTester):
     def _test_forward(self, device, contiguous):
         pool_size = 5
         # n_channels % (pool_size ** 2) == 0 required for PS operations.
@@ -79,7 +90,6 @@ def func(z):
 
         self.assertTrue(gradcheck(func, (x,)))
         self.assertTrue(gradcheck(script_func, (x,)))
-        return
 
     def fn(*args, **kwargs):
         pass
@@ -98,7 +108,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
     def get_script_fn(self, rois, pool_size):
         @torch.jit.script
         def script_fn(input, rois, pool_size):
-            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+            # type: (Tensor, Tensor, int) -> Tensor
             return ops.roi_pool(input, rois, pool_size, 1.0)[0]
         return lambda x: script_fn(x, rois, pool_size)
 
@@ -137,7 +147,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
     def get_script_fn(self, rois, pool_size):
         @torch.jit.script
         def script_fn(input, rois, pool_size):
-            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+            # type: (Tensor, Tensor, int) -> Tensor
             return ops.ps_roi_pool(input, rois, pool_size, 1.0)[0]
         return lambda x: script_fn(x, rois, pool_size)
 
@@ -174,29 +184,35 @@ def get_slice(k, block):
         return y
 
 
-def bilinear_interpolate(data, height, width, y, x):
-    if y < -1.0 or y > height or x < -1.0 or x > width:
-        return 0.
+def bilinear_interpolate(data, y, x, snap_border=False):
+    height, width = data.shape
 
-    y = min(max(0, y), height - 1)
-    x = min(max(0, x), width - 1)
+    if snap_border:
+        if -1 < y <= 0:
+            y = 0
+        elif height - 1 <= y < height:
+            y = height - 1
 
-    y_low = int(y)
-    y_high = min(y_low + 1, height - 1)
+        if -1 < x <= 0:
+            x = 0
+        elif width - 1 <= x < width:
+            x = width - 1
 
-    x_low = int(x)
-    x_high = min(x_low + 1, width - 1)
+    y_low = int(math.floor(y))
+    x_low = int(math.floor(x))
+    y_high = y_low + 1
+    x_high = x_low + 1
 
     wy_h = y - y_low
-    wy_l = 1 - wy_h
-
     wx_h = x - x_low
+    wy_l = 1 - wy_h
     wx_l = 1 - wx_h
 
     val = 0
-    for wx, x in zip((wx_l, wx_h), (x_low, x_high)):
-        for wy, y in zip((wy_l, wy_h), (y_low, y_high)):
-            val += wx * wy * data[y * width + x]
+    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
+        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
+            if 0 <= yp < height and 0 <= xp < width:
+                val += wx * wy * data[yp, xp]
     return val
 
 
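The rewritten helper above takes a 2-D tensor and indexes it directly; neighbors that fall outside the image simply contribute nothing, and snap_border=True (used by the RoI tests below) snaps coordinates in the border band onto the edge, matching the old clamping behavior. A small worked check, illustrative only, reusing the helper as defined above:

import torch

data = torch.tensor([[0., 1.],
                     [2., 3.]])

# (0.5, 0.5) sits midway between all four pixels, so each corner gets
# weight 0.25: 0.25 * (0 + 1 + 2 + 3) = 1.5
assert bilinear_interpolate(data, 0.5, 0.5) == 1.5

# Just outside the border, snap_border=True snaps y onto the edge row first,
# so the result is exactly data[0, 0]:
assert bilinear_interpolate(data, -0.5, 0.0, snap_border=True) == 0.0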
@@ -208,7 +224,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
     def get_script_fn(self, rois, pool_size):
         @torch.jit.script
         def script_fn(input, rois, pool_size):
-            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+            # type: (Tensor, Tensor, int) -> Tensor
             return ops.roi_align(input, rois, pool_size, 1.0)[0]
         return lambda x: script_fn(x, rois, pool_size)
 
@@ -242,12 +258,7 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1,
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
-                                val += bilinear_interpolate(
-                                    in_data[batch_idx, channel, :, :].flatten(),
-                                    in_data.size(-2),
-                                    in_data.size(-1),
-                                    y, x
-                                )
+                                val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w
 
                        out_data[r, channel, i, j] = val
@@ -262,7 +273,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
     def get_script_fn(self, rois, pool_size):
         @torch.jit.script
         def script_fn(input, rois, pool_size):
-            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+            # type: (Tensor, Tensor, int) -> Tensor
             return ops.ps_roi_align(input, rois, pool_size, 1.0)[0]
         return lambda x: script_fn(x, rois, pool_size)
 
@@ -298,12 +309,7 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1,
                            y = start_h + (iy + 0.5) * bin_h / grid_h
                            for ix in range(0, grid_w):
                                x = start_w + (ix + 0.5) * bin_w / grid_w
-                                val += bilinear_interpolate(
-                                    in_data[batch_idx, c_in, :, :].flatten(),
-                                    in_data.size(-2),
-                                    in_data.size(-1),
-                                    y, x
-                                )
+                                val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
                        val /= grid_h * grid_w
 
                        out_data[r, c_out, i, j] = val
@@ -367,5 +373,106 @@ def test_nms_cuda(self):
         self.assertTrue(torch.allclose(r_cpu, r_cuda.cpu()), err_msg.format(iou))
 
 
+class DeformConvTester(OpTester, unittest.TestCase):
+    def expected_fn(self, x, offsets, weights, *args, stride=1, pad=0, dilation=1):
+        stride_h, stride_w = _pair(stride)
+        pad_h, pad_w = _pair(pad)
+        dil_h, dil_w = _pair(dilation)
+        weights_h, weights_w = weights.shape[-2:]
+
+        n_batches, n_in_channels, in_h, in_w = x.shape
+        n_out_channels = weights.shape[0]
+
+        out_h = (in_h + 2 * pad_h - (dil_h * (weights_h - 1) + 1)) // stride_h + 1
+        out_w = (in_w + 2 * pad_w - (dil_w * (weights_w - 1) + 1)) // stride_w + 1
+
+        n_offset_grps = offsets.shape[1] // (2 * weights_h * weights_w)
+        in_c_per_offset_grp = n_in_channels // n_offset_grps
+
+        n_weight_grps = n_in_channels // weights.shape[1]
+        in_c_per_weight_grp = weights.shape[1]
+        out_c_per_weight_grp = n_out_channels // n_weight_grps
+
+        out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
+        for b in range(n_batches):
+            for c_out in range(n_out_channels):
+                for i in range(out_h):
+                    for j in range(out_w):
+                        for di in range(weights_h):
+                            for dj in range(weights_w):
+                                for c in range(in_c_per_weight_grp):
+                                    weight_grp = c_out // out_c_per_weight_grp
+                                    c_in = weight_grp * in_c_per_weight_grp + c
+
+                                    offset_grp = c_in // in_c_per_offset_grp
+                                    offset_idx = 2 * (offset_grp * (weights_h * weights_w) + di * weights_w + dj)
+
+                                    pi = stride_h * i - pad_h + dil_h * di + offsets[b, offset_idx, i, j]
+                                    pj = stride_w * j - pad_w + dil_w * dj + offsets[b, offset_idx + 1, i, j]
+
+                                    out[b, c_out, i, j] += (weights[c_out, c, di, dj] *
+                                                            bilinear_interpolate(x[b, c_in, :, :], pi, pj))
+        return out
+
+    def get_fn_args(self, device, contiguous):
+        batch_sz = 1
+        n_in_channels = 6
+        n_out_channels = 2
+        n_weight_grps = 2
+        n_offset_grps = 3
+
+        stride = (2, 1)
+        pad = (1, 0)
+        dilation = (2, 1)
+
+        stride_h, stride_w = stride
+        pad_h, pad_w = pad
+        dil_h, dil_w = dilation
+        weight_h, weight_w = (3, 2)
+        in_h, in_w = (5, 4)
+
+        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
+        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
+
+        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=self.dtype, requires_grad=True)
+
+        offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
+                             device=device, dtype=self.dtype, requires_grad=True)
+
+        weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
+                             device=device, dtype=self.dtype, requires_grad=True)
+
+        if not contiguous:
+            x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
+            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
+            weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
+
+        return x, offset, weight, stride, pad, dilation
+
+    def _test_forward(self, device, contiguous):
+        x, offset, weight, stride, pad, dilation = self.get_fn_args(device, contiguous)
+
+        res = ops.DeformConv(stride=stride, pad=pad, dilation=dilation)(x, offset, weight)
+        expected = self.expected_fn(x, offset, weight, stride=stride, pad=pad, dilation=dilation)
+
+        self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected))
+
+    def _test_backward(self, device, contiguous):
+        x, offset, weight, stride, pad, dilation = self.get_fn_args(device, contiguous)
+
+        def func(x_, offset_, weight_):
+            return ops.deform_conv(x_, offset_, weight_, stride=stride, pad=pad, dilation=dilation)
+
+        gradcheck(func, (x, offset, weight), nondet_tol=1e-5)
+
+        @torch.jit.script
+        def script_func(x_, offset_, weight_, stride_, pad_, dilation_):
+            # type: (Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
+            return ops.deform_conv(x_, offset_, weight_, stride=stride_, pad=pad_, dilation=dilation_)
+
+        gradcheck(lambda z, off, wei: script_func(z, off, wei, stride, pad, dilation),
+                  (x, offset, weight), nondet_tol=1e-5)
+
+
 if __name__ == '__main__':
     unittest.main()
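As a sanity check on the new test's parameters, the output size computed in both get_fn_args and expected_fn follows standard convolution arithmetic; plugging in the test's values:

# out = (in + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1
out_h = (5 + 2 * 1 - (2 * (3 - 1) + 1)) // 2 + 1  # = (7 - 5) // 2 + 1 = 2
out_w = (4 + 2 * 0 - (1 * (2 - 1) + 1)) // 1 + 1  # = (4 - 2) // 1 + 1 = 3

This is why the offset tensor is allocated with spatial size (2, 3): one (dy, dx) pair per offset group and kernel location at every output position.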
