@@ -507,10 +507,11 @@ def test_Conv2d_groups_nobias(self):
507
507
output2 = m2 (i2 )
508
508
output2 .backward (grad_output [:, 2 :].contiguous ())
509
509
510
- self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ))
510
+ self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ),
511
+ atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
511
512
self .assertEqual (i .grad .data ,
512
513
torch .cat ([i1 .grad .data , i2 .grad .data ], 1 ),
513
- atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
514
+ atol = 1e-1 if dtype == torch . half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
514
515
self .assertEqual (m .weight .grad .data ,
515
516
torch .cat ([m1 .weight .grad .data , m2 .weight .grad .data ], 0 ),
516
517
atol = 1e-1 if dtype == torch .half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
@@ -547,13 +548,14 @@ def test_Conv2d_groups_nobias_v2(self):
547
548
output2 = m2 (i2 )
548
549
output2 .backward (grad_output [:, 8 :].contiguous ())
549
550
550
- self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ))
551
+ self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ),
552
+ atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
551
553
self .assertEqual (i .grad .data ,
552
554
torch .cat ([i1 .grad .data , i2 .grad .data ], 1 ),
553
- atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
555
+ atol = 1e-1 if dtype == torch . half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
554
556
self .assertEqual (m .weight .grad .data ,
555
557
torch .cat ([m1 .weight .grad .data , m2 .weight .grad .data ], 0 ),
556
- atol = 1e -1 if dtype == torch .half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
558
+ atol = 2e -1 if dtype in [ torch .half , torch . bfloat16 ] else dtype2prec_DONTUSE [dtype ], rtol = 0 )
557
559
558
560
# CPU-only test for group conv3d fast implementation using bmm
559
561
# See: https://github.com/pytorch/pytorch/pull/36355
@@ -2073,7 +2075,8 @@ def test_Conv2d_naive_groups(self, device, dtype):
2073
2075
output2 = m2 (i2 )
2074
2076
output2 .backward (grad_output [:, 2 :].contiguous ())
2075
2077
2076
- self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ))
2078
+ self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ),
2079
+ atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
2077
2080
self .assertEqual (i .grad .data ,
2078
2081
torch .cat ([i1 .grad .data , i2 .grad .data ], 1 ),
2079
2082
atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
@@ -2082,7 +2085,7 @@ def test_Conv2d_naive_groups(self, device, dtype):
2082
2085
atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
2083
2086
self .assertEqual (m .weight .grad .data ,
2084
2087
torch .cat ([m1 .weight .grad .data , m2 .weight .grad .data ], 0 ),
2085
- atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
2088
+ atol = 1e-1 if dtype == torch . half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
2086
2089
2087
2090
@dtypes (torch .double , torch .cdouble )
2088
2091
def test_Conv2d_backward_depthwise (self , device , dtype ):
0 commit comments