Skip to content

Commit 2285693

Browse files
committed
Add modulation input for DeformConv2D
1 parent 052edce commit 2285693

File tree

8 files changed

+607
-145
lines changed

8 files changed

+607
-145
lines changed

test/test_ops.py

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ def test_new_empty_tensor(self):
458458

459459

460460
class DeformConvTester(OpTester, unittest.TestCase):
461-
def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
461+
def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
462462
stride_h, stride_w = _pair(stride)
463463
pad_h, pad_w = _pair(padding)
464464
dil_h, dil_w = _pair(dilation)
@@ -489,12 +489,17 @@ def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
489489
c_in = weight_grp * in_c_per_weight_grp + c
490490

491491
offset_grp = c_in // in_c_per_offset_grp
492-
offset_idx = 2 * (offset_grp * (weight_h * weight_w) + di * weight_w + dj)
492+
mask_idx = offset_grp * (weight_h * weight_w) + di * weight_w + dj
493+
offset_idx = 2 * mask_idx
493494

494495
pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
495496
pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]
496497

497-
out[b, c_out, i, j] += (weight[c_out, c, di, dj] *
498+
mask_value = 1.0
499+
if mask is not None:
500+
mask_value = mask[b, mask_idx, i, j]
501+
502+
out[b, c_out, i, j] += (mask_value * weight[c_out, c, di, dj] *
498503
bilinear_interpolate(x[b, c_in, :, :], pi, pj))
499504
out += bias.view(1, n_out_channels, 1, 1)
500505
return out
@@ -523,6 +528,9 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype):
523528
offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
524529
device=device, dtype=dtype, requires_grad=True)
525530

531+
mask = torch.randn(batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w,
532+
device=device, dtype=dtype, requires_grad=True)
533+
526534
weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
527535
device=device, dtype=dtype, requires_grad=True)
528536

@@ -531,31 +539,39 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype):
531539
if not contiguous:
532540
x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
533541
offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
542+
mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
534543
weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
535544

536-
return x, weight, offset, bias, stride, pad, dilation
545+
return x, weight, offset, mask, bias, stride, pad, dilation
537546

538547
def _test_forward(self, device, contiguous, dtype=None):
539548
dtype = self.dtype if dtype is None else dtype
540549
for batch_sz in [0, 33]:
541550
self._test_forward_with_batchsize(device, contiguous, batch_sz, dtype)
542551

543552
def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype):
544-
x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
553+
x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
545554
in_channels = 6
546555
out_channels = 2
547556
kernel_size = (3, 2)
548557
groups = 2
558+
tol = 1e-3 if dtype is torch.half else 1e-5
549559

550560
layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
551561
dilation=dilation, groups=groups).to(device=x.device, dtype=dtype)
552-
res = layer(x, offset)
562+
res = layer(x, offset, mask)
553563

554564
weight = layer.weight.data
555565
bias = layer.bias.data
556-
expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation)
566+
expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)
567+
568+
self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
569+
'\nres:\n{}\nexpected:\n{}'.format(res, expected))
570+
571+
# no modulation test
572+
res = layer(x, offset)
573+
expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)
557574

558-
tol = 1e-3 if dtype is torch.half else 1e-5
559575
self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
560576
'\nres:\n{}\nexpected:\n{}'.format(res, expected))
561577

@@ -564,24 +580,45 @@ def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype):
564580
wrong_offset = torch.rand_like(offset[:, :2])
565581
res = layer(x, wrong_offset)
566582

583+
with self.assertRaises(RuntimeError):
584+
wrong_mask = torch.rand_like(mask[:, :2])
585+
res = layer(x, offset, wrong_mask)
586+
567587
def _test_backward(self, device, contiguous):
568588
for batch_sz in [0, 33]:
569589
self._test_backward_with_batchsize(device, contiguous, batch_sz)
570590

571591
def _test_backward_with_batchsize(self, device, contiguous, batch_sz):
572-
x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, self.dtype)
592+
x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, self.dtype)
593+
594+
def func(x_, offset_, mask_, weight_, bias_):
595+
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
596+
padding=padding, dilation=dilation, mask=mask_)
573597

574-
def func(x_, offset_, weight_, bias_):
575-
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
598+
gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5)
599+
600+
def func_no_mask(x_, offset_, weight_, bias_):
601+
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
602+
padding=padding, dilation=dilation, mask=None)
603+
604+
gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5)
605+
606+
@torch.jit.script
607+
def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
608+
# type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
609+
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
610+
padding=pad_, dilation=dilation_, mask=mask_)
576611

577-
gradcheck(func, (x, offset, weight, bias), nondet_tol=1e-5)
612+
gradcheck(lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
613+
(x, offset, mask, weight, bias), nondet_tol=1e-5)
578614

579615
@torch.jit.script
580-
def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
581-
# type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
582-
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_)
616+
def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
617+
# type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
618+
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
619+
padding=pad_, dilation=dilation_, mask=None)
583620

584-
gradcheck(lambda z, off, wei, bi: script_func(z, off, wei, bi, stride, padding, dilation),
621+
gradcheck(lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
585622
(x, offset, weight, bias), nondet_tol=1e-5)
586623

587624
# Test from https://github.com/pytorch/vision/issues/2598
@@ -593,17 +630,19 @@ def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
593630
init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
594631
img = torch.randn(8, 9, 1000, 110)
595632
offset = torch.rand(8, 2 * 3 * 3, 1000, 110)
633+
mask = torch.rand(8, 3 * 3, 1000, 110)
596634

597635
if not contiguous:
598636
img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
599637
offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
638+
mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
600639
weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
601640
else:
602641
weight = init_weight
603642

604643
for d in ["cpu", "cuda"]:
605644

606-
out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1)
645+
out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
607646
out.mean().backward()
608647
if true_cpu_grads is None:
609648
true_cpu_grads = init_weight.grad

0 commit comments

Comments
 (0)