Skip to content

Commit 2c60f5c

Browse files
committed
Fix arguments order of deform_conv2d
1 parent 0707337 commit 2c60f5c

File tree

2 files changed

+16
-20
lines changed

2 files changed

+16
-20
lines changed

test/test_ops.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -567,35 +567,31 @@ def _test_backward_with_batchsize(self, device, contiguous, batch_sz):
567567
x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz)
568568

569569
def func(x_, offset_, mask_, weight_, bias_):
570-
return ops.deform_conv2d(x_, offset_, mask_,
571-
weight_, bias_, stride=stride,
572-
padding=padding, dilation=dilation)
570+
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
571+
padding=padding, dilation=dilation, mask=mask_)
573572

574573
gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5)
575574

576575
def func_no_mask(x_, offset_, weight_, bias_):
577-
return ops.deform_conv2d(x_, offset_, None,
578-
weight_, bias_, stride=stride,
579-
padding=padding, dilation=dilation)
576+
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
577+
padding=padding, dilation=dilation, mask=None)
580578

581579
gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5)
582580

583581
@torch.jit.script
584582
def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
585583
# type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
586-
return ops.deform_conv2d(x_, offset_, mask_,
587-
weight_, bias_, stride=stride_,
588-
padding=pad_, dilation=dilation_)
584+
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
585+
padding=pad_, dilation=dilation_, mask=mask_)
589586

590587
gradcheck(lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
591588
(x, offset, mask, weight, bias), nondet_tol=1e-5)
592589

593590
@torch.jit.script
594591
def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
595592
# type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
596-
return ops.deform_conv2d(x_, offset_, None,
597-
weight_, bias_, stride=stride_,
598-
padding=pad_, dilation=dilation_)
593+
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
594+
padding=pad_, dilation=dilation_, mask=None)
599595

600596
gradcheck(lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
601597
(x, offset, weight, bias), nondet_tol=1e-5)
@@ -621,7 +617,7 @@ def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
621617

622618
for d in ["cpu", "cuda"]:
623619

624-
out = ops.deform_conv2d(img.to(d), offset.to(d), mask.to(d), weight.to(d), padding=1)
620+
out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
625621
out.mean().backward()
626622
if true_cpu_grads is None:
627623
true_cpu_grads = init_weight.grad

torchvision/ops/deform_conv.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
def deform_conv2d(
1313
input: Tensor,
1414
offset: Tensor,
15-
mask: Optional[Tensor],
1615
weight: Tensor,
1716
bias: Optional[Tensor] = None,
1817
stride: Tuple[int, int] = (1, 1),
1918
padding: Tuple[int, int] = (0, 0),
2019
dilation: Tuple[int, int] = (1, 1),
20+
mask: Optional[Tensor] = None,
2121
) -> Tensor:
2222
"""
2323
Performs Deformable Convolution, described in Deformable Convolutional Networks
@@ -27,16 +27,16 @@ def deform_conv2d(
2727
offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
2828
out_height, out_width]): offsets to be applied for each position in the
2929
convolution kernel.
30-
mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width,
31-
out_height, out_width]): masks to be applied for each position in the
32-
convolution kernel.
3330
weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]):
3431
convolution weights, split into groups of size (in_channels // groups)
3532
bias (Tensor[out_channels]): optional bias of shape (out_channels,). Default: None
3633
stride (int or Tuple[int, int]): distance between convolution centers. Default: 1
3734
padding (int or Tuple[int, int]): height/width of padding of zeroes around
3835
each image. Default: 0
3936
dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1
37+
mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width,
38+
out_height, out_width]): masks to be applied for each position in the
39+
convolution kernel.
4040
4141
Returns:
4242
output (Tensor[batch_sz, out_channels, out_h, out_w]): result of convolution
@@ -51,7 +51,7 @@ def deform_conv2d(
5151
>>> # and kernel size of 3, without padding, the output size is 8
5252
>>> offset = torch.rand(4, 2 * kh * kw, 8, 8)
5353
>>> mask = torch.rand(4, kh * kw, 8, 8)
54-
>>> out = deform_conv2d(input, offset, mask, weight)
54+
>>> out = deform_conv2d(input, offset, weight, mask=mask)
5555
>>> print(out.shape)
5656
>>> # returns
5757
>>> torch.Size([4, 5, 8, 8])
@@ -158,8 +158,8 @@ def forward(self, input: Tensor, offset: Tensor, mask: Tensor = None) -> Tensor:
158158
out_height, out_width]): masks to be applied for each position in the
159159
convolution kernel.
160160
"""
161-
return deform_conv2d(input, offset, mask, self.weight, self.bias, stride=self.stride,
162-
padding=self.padding, dilation=self.dilation)
161+
return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride,
162+
padding=self.padding, dilation=self.dilation, mask=mask)
163163

164164
def __repr__(self) -> str:
165165
s = self.__class__.__name__ + '('

0 commit comments

Comments
 (0)