@@ -474,18 +474,18 @@ def _test_forward(self, device, contiguous):
474
474
def _test_backward (self , device , contiguous ):
475
475
x , weight , offset , bias , stride , padding , dilation = self .get_fn_args (device , contiguous )
476
476
477
- def func (x_ , weight_ , offset_ , bias_ ):
478
- return ops .deform_conv2d (x_ , weight_ , offset_ , bias_ , stride = stride , padding = padding , dilation = dilation )
477
+ def func (x_ , offset_ , weight_ , bias_ ):
478
+ return ops .deform_conv2d (x_ , offset_ , weight_ , bias_ , stride = stride , padding = padding , dilation = dilation )
479
479
480
- gradcheck (func , (x , weight , offset , bias ), nondet_tol = 1e-5 )
480
+ gradcheck (func , (x , offset , weight , bias ), nondet_tol = 1e-5 )
481
481
482
482
@torch .jit .script
483
- def script_func (x_ , weight_ , offset_ , bias_ , stride_ , pad_ , dilation_ ):
483
+ def script_func (x_ , offset_ , weight_ , bias_ , stride_ , pad_ , dilation_ ):
484
484
# type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
485
- return ops .deform_conv2d (x_ , weight_ , offset_ , bias_ , stride = stride_ , padding = pad_ , dilation = dilation_ )
485
+ return ops .deform_conv2d (x_ , offset_ , weight_ , bias_ , stride = stride_ , padding = pad_ , dilation = dilation_ )
486
486
487
- gradcheck (lambda z , wei , off , bi : script_func (z , wei , off , bi , stride , padding , dilation ),
488
- (x , weight , offset , bias ), nondet_tol = 1e-5 )
487
+ gradcheck (lambda z , off , wei , bi : script_func (z , off , wei , bi , stride , padding , dilation ),
488
+ (x , offset , weight , bias ), nondet_tol = 1e-5 )
489
489
490
490
491
491
if __name__ == '__main__' :
0 commit comments