@@ -458,7 +458,7 @@ def test_new_empty_tensor(self):
 
 
 class DeformConvTester(OpTester, unittest.TestCase):
-    def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
+    def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
         stride_h, stride_w = _pair(stride)
         pad_h, pad_w = _pair(padding)
         dil_h, dil_w = _pair(dilation)
@@ -489,12 +489,17 @@ def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
                                     c_in = weight_grp * in_c_per_weight_grp + c
 
                                     offset_grp = c_in // in_c_per_offset_grp
-                                    offset_idx = 2 * (offset_grp * (weight_h * weight_w) + di * weight_w + dj)
+                                    mask_idx = offset_grp * (weight_h * weight_w) + di * weight_w + dj
+                                    offset_idx = 2 * mask_idx
 
                                     pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
                                     pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]
 
-                                    out[b, c_out, i, j] += (weight[c_out, c, di, dj] *
+                                    mask_value = 1.0
+                                    if mask is not None:
+                                        mask_value = mask[b, mask_idx, i, j]
+
+                                    out[b, c_out, i, j] += (mask_value * weight[c_out, c, di, dj] *
                                                             bilinear_interpolate(x[b, c_in, :, :], pi, pj))
         out += bias.view(1, n_out_channels, 1, 1)
         return out
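The hunk above fixes a channel-indexing convention worth spelling out: offsets are stored as interleaved (dy, dx) pairs, so each kernel tap's single mask channel maps to a pair of offset channels. A standalone sketch of that mapping (names mirror the test; values illustrative):

```python
# Sketch of the channel-indexing convention used by expected_fn above:
# each kernel tap (di, dj) in offset group `offset_grp` has one mask channel
# and two offset channels (dy at offset_idx, dx at offset_idx + 1).
weight_h, weight_w, offset_grp = 3, 3, 0
for di in range(weight_h):
    for dj in range(weight_w):
        mask_idx = offset_grp * (weight_h * weight_w) + di * weight_w + dj
        offset_idx = 2 * mask_idx
        print(f"tap ({di}, {dj}): mask channel {mask_idx}, "
              f"offset channels {offset_idx} (dy) and {offset_idx + 1} (dx)")
```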
@@ -523,6 +528,9 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype):
         offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
                              device=device, dtype=dtype, requires_grad=True)
 
+        mask = torch.randn(batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w,
+                           device=device, dtype=dtype, requires_grad=True)
+
         weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
                              device=device, dtype=dtype, requires_grad=True)
 
@@ -531,31 +539,39 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype):
         if not contiguous:
             x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
             offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
+            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
             weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
 
-        return x, weight, offset, bias, stride, pad, dilation
+        return x, weight, offset, mask, bias, stride, pad, dilation
 
     def _test_forward(self, device, contiguous, dtype=None):
         dtype = self.dtype if dtype is None else dtype
         for batch_sz in [0, 33]:
             self._test_forward_with_batchsize(device, contiguous, batch_sz, dtype)
 
     def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype):
-        x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
+        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
         in_channels = 6
         out_channels = 2
         kernel_size = (3, 2)
         groups = 2
+        tol = 1e-3 if dtype is torch.half else 1e-5
 
         layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                                  dilation=dilation, groups=groups).to(device=x.device, dtype=dtype)
-        res = layer(x, offset)
+        res = layer(x, offset, mask)
 
         weight = layer.weight.data
         bias = layer.bias.data
-        expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation)
+        expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)
+
+        self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
+                        '\nres:\n{}\nexpected:\n{}'.format(res, expected))
+
+        # no modulation test
+        res = layer(x, offset)
+        expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)
 
-        tol = 1e-3 if dtype is torch.half else 1e-5
         self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
                         '\nres:\n{}\nexpected:\n{}'.format(res, expected))
 
@@ -564,24 +580,45 @@ def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype):
             wrong_offset = torch.rand_like(offset[:, :2])
             res = layer(x, wrong_offset)
 
+        with self.assertRaises(RuntimeError):
+            wrong_mask = torch.rand_like(mask[:, :2])
+            res = layer(x, offset, wrong_mask)
+
     def _test_backward(self, device, contiguous):
         for batch_sz in [0, 33]:
             self._test_backward_with_batchsize(device, contiguous, batch_sz)
 
     def _test_backward_with_batchsize(self, device, contiguous, batch_sz):
-        x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, self.dtype)
+        x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, self.dtype)
+
+        def func(x_, offset_, mask_, weight_, bias_):
+            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
+                                     padding=padding, dilation=dilation, mask=mask_)
 
-        def func(x_, offset_, weight_, bias_):
-            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
+        gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5)
+
+        def func_no_mask(x_, offset_, weight_, bias_):
+            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
+                                     padding=padding, dilation=dilation, mask=None)
+
+        gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5)
+
+        @torch.jit.script
+        def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
+            # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
+            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
+                                     padding=pad_, dilation=dilation_, mask=mask_)
 
-        gradcheck(func, (x, offset, weight, bias), nondet_tol=1e-5)
+        gradcheck(lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
+                  (x, offset, mask, weight, bias), nondet_tol=1e-5)
 
         @torch.jit.script
-        def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
-            # type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
-            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_)
+        def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
+            # type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
+            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
+                                     padding=pad_, dilation=dilation_, mask=None)
 
-        gradcheck(lambda z, off, wei, bi: script_func(z, off, wei, bi, stride, padding, dilation),
+        gradcheck(lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
                   (x, offset, weight, bias), nondet_tol=1e-5)
 
         # Test from https://github.com/pytorch/vision/issues/2598
@@ -593,17 +630,19 @@ def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
         init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
         img = torch.randn(8, 9, 1000, 110)
         offset = torch.rand(8, 2 * 3 * 3, 1000, 110)
+        mask = torch.rand(8, 3 * 3, 1000, 110)
 
         if not contiguous:
             img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
             offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
+            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
             weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
         else:
             weight = init_weight
 
         for d in ["cpu", "cuda"]:
 
-            out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1)
+            out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
             out.mean().backward()
             if true_cpu_grads is None:
                 true_cpu_grads = init_weight.grad
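For context, a minimal usage sketch of the modulated (DCNv2-style) call these tests exercise, from the user's side. Shapes are illustrative, and it assumes a torchvision build that includes the `mask` argument added in this patch:

```python
# Minimal sketch of modulated deformable convolution via torchvision.ops.
import torch
from torchvision import ops

batch, in_c, out_c, kh, kw = 2, 3, 4, 3, 3
x = torch.randn(batch, in_c, 8, 8)
weight = torch.randn(out_c, in_c, kh, kw)

# With stride 1 and no padding, an 8x8 input and a 3x3 kernel give a 6x6
# output; offset and mask must match the output's spatial size.
# offset: one (dy, dx) pair per kernel tap -> 2 * kh * kw channels.
offset = torch.randn(batch, 2 * kh * kw, 6, 6)
# mask: one modulation scalar per kernel tap -> kh * kw channels,
# typically squashed to [0, 1] with a sigmoid as in the DCNv2 paper.
mask = torch.sigmoid(torch.randn(batch, kh * kw, 6, 6))

out = ops.deform_conv2d(x, offset, weight, mask=mask)  # modulated (DCNv2)
out_v1 = ops.deform_conv2d(x, offset, weight)          # mask=None -> plain DCNv1
```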