Skip to content

Commit c7c09ca

Browse files
committed
Reorder weight and offset args in deform_conv2d
We place offset arg before the weight arg, to be more consistent with DeformConv2d.forward(input, offset)
1 parent 6467b58 commit c7c09ca

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

test/test_ops.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -474,18 +474,18 @@ def _test_forward(self, device, contiguous):
474474
def _test_backward(self, device, contiguous):
475475
x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous)
476476

477-
def func(x_, weight_, offset_, bias_):
478-
return ops.deform_conv2d(x_, weight_, offset_, bias_, stride=stride, padding=padding, dilation=dilation)
477+
def func(x_, offset_, weight_, bias_):
478+
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
479479

480-
gradcheck(func, (x, weight, offset, bias), nondet_tol=1e-5)
480+
gradcheck(func, (x, offset, weight, bias), nondet_tol=1e-5)
481481

482482
@torch.jit.script
483-
def script_func(x_, weight_, offset_, bias_, stride_, pad_, dilation_):
483+
def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
484484
# type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
485-
return ops.deform_conv2d(x_, weight_, offset_, bias_, stride=stride_, padding=pad_, dilation=dilation_)
485+
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_)
486486

487-
gradcheck(lambda z, wei, off, bi: script_func(z, wei, off, bi, stride, padding, dilation),
488-
(x, weight, offset, bias), nondet_tol=1e-5)
487+
gradcheck(lambda z, off, wei, bi: script_func(z, off, wei, bi, stride, padding, dilation),
488+
(x, offset, weight, bias), nondet_tol=1e-5)
489489

490490

491491
if __name__ == '__main__':

torchvision/ops/deform_conv.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,18 @@
88
from torch.jit.annotations import Optional, Tuple
99

1010

11-
def deform_conv2d(input, weight, offset, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
11+
def deform_conv2d(input, offset, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
1212
# type: (Tensor, Tensor, Tensor, Optional[Tensor], Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
1313
"""
1414
Performs Deformable Convolution, described in Deformable Convolutional Networks
1515
1616
Arguments:
1717
input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
18-
weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]):
19-
convolution weights, split into groups of size (in_channels // groups)
2018
offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
2119
out_height, out_width]): offsets to be applied for each position in the
2220
convolution kernel.
21+
weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]):
22+
convolution weights, split into groups of size (in_channels // groups)
2323
bias (Tensor[out_channels]): optional bias of shape (out_channels,). Default: None
2424
stride (int or Tuple[int, int]): distance between convolution centers. Default: 1
2525
padding (int or Tuple[int, int]): height/width of padding of zeroes around
@@ -105,7 +105,7 @@ def forward(self, input, offset):
105105
out_height, out_width]): offsets to be applied for each position in the
106106
convolution kernel.
107107
"""
108-
return deform_conv2d(input, self.weight, offset, self.bias, stride=self.stride,
108+
return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride,
109109
padding=self.padding, dilation=self.dilation)
110110

111111
def __repr__(self):

0 commit comments

Comments
 (0)