Commit 52b8685

pedrofreire authored and fmassa committed
Add Deformable Convolution operation. (#1586)
* Add Deformable Convolution operation.

  This adds the deformable convolution operation, as described in Deformable Convolutional Networks (https://arxiv.org/abs/1703.06211).

  - The code is based on https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp; it was modified and refactored throughout to remove redundancies, improve clarity, and adapt it to torchvision.
  - The CPU part is a direct copy of the CUDA code; it might make sense to simplify or optimize the CPU code in a follow-up, or to share functionality between CPU and CUDA.
  - We also add tests (with a non-trivial set of parameters); they can be made more robust by randomizing the parameters and executing multiple times.

* Update DeformConv to be more consistent w/ Conv2d:
  - rename some variables and arguments to match Conv2d;
  - add optional bias;
  - add weight, offset and bias as module parameters;
  - remove the n_parallel_imgs parameter;
  - fix __repr__;
  - etc.

  Initialization of weight and bias is the same as in Conv2d, and initialization of offsets to zero is the same as in the paper. This also includes some other small, unrelated fixes/improvements.

* Apply clang-format in DeformConv files.

* Import Optional type annotation.

* Remove offset param from DeformConv2d module. We pass the offset in the forward of DeformConv2d, instead of having an internal parameter. This adds some complexity to creating the module (e.g. you now have to compute the output size to create the offset), but it gives more flexibility.

  We also use make_tuple for tuple creation, in an attempt to fix an error with older compilers.

* Replace abs by std::abs. Old gcc versions were giving wrong results here, because they would resolve abs as int -> int, causing undesired truncation. Replacing abs by std::abs allows the correct float -> float overload to be chosen.

* Reorder declarations for clarity.

* Reorder weight and offset args in deform_conv2d: we place the offset arg before the weight arg, to be more consistent with DeformConv2d.forward(input, offset).

* Replace abs by std::abs in DeformConv_cuda.
1 parent 5b1716a commit 52b8685
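To make the new API concrete, here is a minimal usage sketch assembled from the constructor signature and parameter values that appear in the tests below. Since the offset is now passed to `forward` rather than stored in the module, the caller computes the output size with the usual Conv2d arithmetic and shapes the offset tensor to match; the specific values here mirror the test, not any requirement.

```python
import torch
from torchvision import ops

batch_sz, in_channels, in_h, in_w = 1, 6, 5, 4
out_channels, kernel_size = 2, (3, 2)
stride, padding, dilation = (2, 1), (1, 0), (2, 1)
groups, offset_groups = 2, 3

# The offset is passed to forward(), so the caller must know the output
# size up front; it follows the usual Conv2d arithmetic.
out_h = (in_h + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) // stride[0] + 1
out_w = (in_w + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) // stride[1] + 1

x = torch.rand(batch_sz, in_channels, in_h, in_w)
# One (dy, dx) pair per offset group and kernel location, at each output position.
offset = torch.randn(batch_sz, offset_groups * 2 * kernel_size[0] * kernel_size[1], out_h, out_w)

layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride,
                         padding=padding, dilation=dilation, groups=groups,
                         offset_groups=offset_groups)
out = layer(x, offset)  # shape: (batch_sz, out_channels, out_h, out_w)
```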

File tree

9 files changed: +2555 -39 lines


test/test_ops.py

Lines changed: 157 additions & 36 deletions
```diff
@@ -1,15 +1,18 @@
 from __future__ import division
+import math
+import unittest
+
 import numpy as np
+
 import torch
+from torch import Tensor
 from torch.autograd import gradcheck
-
+from torch.jit.annotations import Tuple
+from torch.nn.modules.utils import _pair
 from torchvision import ops

-from itertools import product
-import unittest
-

-class RoIOpTester(object):
+class OpTester(object):
     @classmethod
     def setUpClass(cls):
         cls.dtype = torch.float64
@@ -42,6 +45,14 @@ def test_backward_cuda_contiguous(self):
     def test_backward_cuda_non_contiguous(self):
         self._test_backward(device=torch.device('cuda'), contiguous=False)

+    def _test_forward(self, device, contiguous):
+        pass
+
+    def _test_backward(self, device, contiguous):
+        pass
+
+
+class RoIOpTester(OpTester):
     def _test_forward(self, device, contiguous):
         pool_size = 5
         # n_channels % (pool_size ** 2) == 0 required for PS opeartions.
@@ -79,7 +90,6 @@ def func(z):

         self.assertTrue(gradcheck(func, (x,)))
         self.assertTrue(gradcheck(script_func, (x,)))
-        return

     def fn(*args, **kwargs):
         pass
@@ -98,7 +108,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
     def get_script_fn(self, rois, pool_size):
         @torch.jit.script
         def script_fn(input, rois, pool_size):
-            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+            # type: (Tensor, Tensor, int) -> Tensor
             return ops.roi_pool(input, rois, pool_size, 1.0)[0]
         return lambda x: script_fn(x, rois, pool_size)

```
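The `# type:` comment rewrites above (and the matching ones in the hunks below) rely on `Tensor` and `Tuple` being imported at module scope, since `torch.jit.script` resolves the names in these Python-2-style type comments at script time. A tiny illustrative sketch of the pattern; the `scale` function is hypothetical, not from the diff:

```python
import torch
from torch import Tensor

@torch.jit.script
def scale(x, factor):
    # type: (Tensor, float) -> Tensor
    # TorchScript reads the "# type:" comment to type the signature;
    # the names used in it must resolve at script time, hence the import.
    return x * factor

print(scale(torch.ones(2), 3.0))  # tensor([3., 3.])
```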

```diff
@@ -137,7 +147,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
     def get_script_fn(self, rois, pool_size):
         @torch.jit.script
         def script_fn(input, rois, pool_size):
-            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+            # type: (Tensor, Tensor, int) -> Tensor
             return ops.ps_roi_pool(input, rois, pool_size, 1.0)[0]
         return lambda x: script_fn(x, rois, pool_size)

@@ -174,29 +184,35 @@ def get_slice(k, block):
     return y


-def bilinear_interpolate(data, height, width, y, x):
-    if y < -1.0 or y > height or x < -1.0 or x > width:
-        return 0.
+def bilinear_interpolate(data, y, x, snap_border=False):
+    height, width = data.shape

-    y = min(max(0, y), height - 1)
-    x = min(max(0, x), width - 1)
+    if snap_border:
+        if -1 < y <= 0:
+            y = 0
+        elif height - 1 <= y < height:
+            y = height - 1

-    y_low = int(y)
-    y_high = min(y_low + 1, height - 1)
+        if -1 < x <= 0:
+            x = 0
+        elif width - 1 <= x < width:
+            x = width - 1

-    x_low = int(x)
-    x_high = min(x_low + 1, width - 1)
+    y_low = int(math.floor(y))
+    x_low = int(math.floor(x))
+    y_high = y_low + 1
+    x_high = x_low + 1

     wy_h = y - y_low
-    wy_l = 1 - wy_h
-
     wx_h = x - x_low
+    wy_l = 1 - wy_h
     wx_l = 1 - wx_h

     val = 0
-    for wx, x in zip((wx_l, wx_h), (x_low, x_high)):
-        for wy, y in zip((wy_l, wy_h), (y_low, y_high)):
-            val += wx * wy * data[y * width + x]
+    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
+        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
+            if 0 <= yp < height and 0 <= xp < width:
+                val += wx * wy * data[yp, xp]
     return val


```
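The rewrite above changes `bilinear_interpolate` to take a 2-D tensor directly, to skip out-of-bounds taps (implicit zero padding) instead of rejecting coordinates, and to clamp near-border coordinates only when `snap_border=True`, which is the behaviour the RoI reference implementations opt into. An illustrative check of those semantics, assuming the function as defined in the hunk above is in scope:

```python
import torch

data = torch.tensor([[1., 2.],
                     [3., 4.]])

# Interior point: ordinary bilinear mix of the four neighbours.
print(bilinear_interpolate(data, 0.5, 0.5))                     # -> 2.5

# Slightly outside the top edge: the out-of-bounds tap is skipped,
# so the value decays towards zero (implicit zero padding).
print(bilinear_interpolate(data, -0.5, 0.0))                    # -> 0.5, i.e. 0.5 * data[0, 0]

# snap_border=True first clamps coordinates in (-1, 0] onto the border,
# reproducing the old clamping behaviour used by the RoI tests.
print(bilinear_interpolate(data, -0.5, 0.0, snap_border=True))  # -> 1.0, i.e. data[0, 0]
```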

```diff
@@ -208,7 +224,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
     def get_script_fn(self, rois, pool_size):
         @torch.jit.script
         def script_fn(input, rois, pool_size):
-            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+            # type: (Tensor, Tensor, int) -> Tensor
             return ops.roi_align(input, rois, pool_size, 1.0)[0]
         return lambda x: script_fn(x, rois, pool_size)

@@ -242,12 +258,7 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r
                             y = start_h + (iy + 0.5) * bin_h / grid_h
                             for ix in range(0, grid_w):
                                 x = start_w + (ix + 0.5) * bin_w / grid_w
-                                val += bilinear_interpolate(
-                                    in_data[batch_idx, channel, :, :].flatten(),
-                                    in_data.size(-2),
-                                    in_data.size(-1),
-                                    y, x
-                                )
+                                val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True)
                         val /= grid_h * grid_w

                         out_data[r, channel, i, j] = val
@@ -262,7 +273,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
     def get_script_fn(self, rois, pool_size):
         @torch.jit.script
         def script_fn(input, rois, pool_size):
-            # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
+            # type: (Tensor, Tensor, int) -> Tensor
             return ops.ps_roi_align(input, rois, pool_size, 1.0)[0]
         return lambda x: script_fn(x, rois, pool_size)

@@ -298,12 +309,7 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1,
                             y = start_h + (iy + 0.5) * bin_h / grid_h
                             for ix in range(0, grid_w):
                                 x = start_w + (ix + 0.5) * bin_w / grid_w
-                                val += bilinear_interpolate(
-                                    in_data[batch_idx, c_in, :, :].flatten(),
-                                    in_data.size(-2),
-                                    in_data.size(-1),
-                                    y, x
-                                )
+                                val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
                         val /= grid_h * grid_w

                         out_data[r, c_out, i, j] = val
@@ -376,5 +382,120 @@ def test_new_empty_tensor(self):
         assert out.dtype == input.dtype


+class DeformConvTester(OpTester, unittest.TestCase):
+    def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
+        stride_h, stride_w = _pair(stride)
+        pad_h, pad_w = _pair(padding)
+        dil_h, dil_w = _pair(dilation)
+        weight_h, weight_w = weight.shape[-2:]
+
+        n_batches, n_in_channels, in_h, in_w = x.shape
+        n_out_channels = weight.shape[0]
+
+        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
+        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
+
+        n_offset_grps = offset.shape[1] // (2 * weight_h * weight_w)
+        in_c_per_offset_grp = n_in_channels // n_offset_grps
+
+        n_weight_grps = n_in_channels // weight.shape[1]
+        in_c_per_weight_grp = weight.shape[1]
+        out_c_per_weight_grp = n_out_channels // n_weight_grps
+
+        out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
+        for b in range(n_batches):
+            for c_out in range(n_out_channels):
+                for i in range(out_h):
+                    for j in range(out_w):
+                        for di in range(weight_h):
+                            for dj in range(weight_w):
+                                for c in range(in_c_per_weight_grp):
+                                    weight_grp = c_out // out_c_per_weight_grp
+                                    c_in = weight_grp * in_c_per_weight_grp + c
+
+                                    offset_grp = c_in // in_c_per_offset_grp
+                                    offset_idx = 2 * (offset_grp * (weight_h * weight_w) + di * weight_w + dj)
+
+                                    pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
+                                    pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]
+
+                                    out[b, c_out, i, j] += (weight[c_out, c, di, dj] *
+                                                            bilinear_interpolate(x[b, c_in, :, :], pi, pj))
+        out += bias.view(1, n_out_channels, 1, 1)
+        return out
+
+    def get_fn_args(self, device, contiguous):
+        batch_sz = 1
+        n_in_channels = 6
+        n_out_channels = 2
+        n_weight_grps = 2
+        n_offset_grps = 3
+
+        stride = (2, 1)
+        pad = (1, 0)
+        dilation = (2, 1)
+
+        stride_h, stride_w = stride
+        pad_h, pad_w = pad
+        dil_h, dil_w = dilation
+        weight_h, weight_w = (3, 2)
+        in_h, in_w = (5, 4)
+
+        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
+        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
+
+        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=self.dtype, requires_grad=True)
+
+        offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
+                             device=device, dtype=self.dtype, requires_grad=True)
+
+        weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
+                             device=device, dtype=self.dtype, requires_grad=True)
+
+        bias = torch.randn(n_out_channels, device=device, dtype=self.dtype, requires_grad=True)
+
+        if not contiguous:
+            x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
+            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
+            weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
+
+        return x, weight, offset, bias, stride, pad, dilation
+
+    def _test_forward(self, device, contiguous):
+        x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous)
+        in_channels = 6
+        out_channels = 2
+        kernel_size = (3, 2)
+        groups = 2
+        offset_groups = 3
+
+        layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
+                                 dilation=dilation, groups=groups, offset_groups=offset_groups).to(device=x.device,
+                                                                                                   dtype=x.dtype)
+        res = layer(x, offset)
+
+        weight = layer.weight.data
+        bias = layer.bias.data
+        expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation)
+
+        self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected))
+
+    def _test_backward(self, device, contiguous):
+        x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous)
+
+        def func(x_, offset_, weight_, bias_):
+            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
+
+        gradcheck(func, (x, offset, weight, bias), nondet_tol=1e-5)
+
+        @torch.jit.script
+        def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
+            # type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
+            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_)
+
+        gradcheck(lambda z, off, wei, bi: script_func(z, off, wei, bi, stride, padding, dilation),
+                  (x, offset, weight, bias), nondet_tol=1e-5)
+
+
 if __name__ == '__main__':
     unittest.main()
```
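As a sanity check on the shape arithmetic shared by `expected_fn` and `get_fn_args` above, here are the numbers worked out for the test's parameters in a standalone sketch:

```python
# Shape arithmetic from DeformConvTester, evaluated for the test parameters.
stride_h, stride_w = 2, 1
pad_h, pad_w = 1, 0
dil_h, dil_w = 2, 1
weight_h, weight_w = 3, 2   # kernel size
in_h, in_w = 5, 4
n_offset_grps = 3

# Standard Conv2d output-size formula, as used twice in the diff.
out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1  # (5 + 2 - 5) // 2 + 1 = 2
out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1  # (4 + 0 - 2) // 1 + 1 = 3

# The offset tensor stores one (dy, dx) pair per offset group and kernel tap
# at every output location, which is why expected_fn recovers n_offset_grps
# as offset.shape[1] // (2 * weight_h * weight_w).
offset_channels = 2 * n_offset_grps * weight_h * weight_w  # 2 * 3 * 3 * 2 = 36

print(out_h, out_w, offset_channels)  # 2 3 36, so offset has shape (1, 36, 2, 3)
```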
