@@ -506,10 +506,11 @@ def test_Conv2d_groups_nobias(self):
506
506
output2 = m2 (i2 )
507
507
output2 .backward (grad_output [:, 2 :].contiguous ())
508
508
509
- self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ))
509
+ self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ),
510
+ atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
510
511
self .assertEqual (i .grad .data ,
511
512
torch .cat ([i1 .grad .data , i2 .grad .data ], 1 ),
512
- atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
513
+ atol = 1e-1 if dtype == torch . half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
513
514
self .assertEqual (m .weight .grad .data ,
514
515
torch .cat ([m1 .weight .grad .data , m2 .weight .grad .data ], 0 ),
515
516
atol = 1e-1 if dtype == torch .half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
@@ -545,13 +546,14 @@ def test_Conv2d_groups_nobias_v2(self):
545
546
output2 = m2 (i2 )
546
547
output2 .backward (grad_output [:, 8 :].contiguous ())
547
548
548
- self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ))
549
+ self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ),
550
+ atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
549
551
self .assertEqual (i .grad .data ,
550
552
torch .cat ([i1 .grad .data , i2 .grad .data ], 1 ),
551
- atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
553
+ atol = 1e-1 if dtype == torch . half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
552
554
self .assertEqual (m .weight .grad .data ,
553
555
torch .cat ([m1 .weight .grad .data , m2 .weight .grad .data ], 0 ),
554
- atol = 1e -1 if dtype == torch .half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
556
+ atol = 2e -1 if dtype in [ torch .half , torch . bfloat16 ] else dtype2prec_DONTUSE [dtype ], rtol = 0 )
555
557
556
558
# CPU-only test for group conv3d fast implementation using bmm
557
559
# See: https://github.com/pytorch/pytorch/pull/36355
@@ -2076,7 +2078,8 @@ def test_Conv2d_naive_groups(self, device, dtype):
2076
2078
output2 = m2 (i2 )
2077
2079
output2 .backward (grad_output [:, 2 :].contiguous ())
2078
2080
2079
- self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ))
2081
+ self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ),
2082
+ atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
2080
2083
self .assertEqual (i .grad .data ,
2081
2084
torch .cat ([i1 .grad .data , i2 .grad .data ], 1 ),
2082
2085
atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
@@ -2085,7 +2088,7 @@ def test_Conv2d_naive_groups(self, device, dtype):
2085
2088
atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
2086
2089
self .assertEqual (m .weight .grad .data ,
2087
2090
torch .cat ([m1 .weight .grad .data , m2 .weight .grad .data ], 0 ),
2088
- atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
2091
+ atol = 1e-1 if dtype == torch . half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
2089
2092
2090
2093
@dtypes (torch .double , torch .cdouble )
2091
2094
def test_Conv2d_backward_depthwise (self , device , dtype ):
0 commit comments