@@ -506,10 +506,11 @@ def test_Conv2d_groups_nobias(self):
506
506
output2 = m2 (i2 )
507
507
output2 .backward (grad_output [:, 2 :].contiguous ())
508
508
509
- self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ))
509
+ self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ),
510
+ atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
510
511
self .assertEqual (i .grad .data ,
511
512
torch .cat ([i1 .grad .data , i2 .grad .data ], 1 ),
512
- atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
513
+ atol = 1e-1 if dtype == torch . half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
513
514
self .assertEqual (m .weight .grad .data ,
514
515
torch .cat ([m1 .weight .grad .data , m2 .weight .grad .data ], 0 ),
515
516
atol = 1e-1 if dtype == torch .half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
@@ -545,13 +546,14 @@ def test_Conv2d_groups_nobias_v2(self):
545
546
output2 = m2 (i2 )
546
547
output2 .backward (grad_output [:, 8 :].contiguous ())
547
548
548
- self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ))
549
+ self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ),
550
+ atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
549
551
self .assertEqual (i .grad .data ,
550
552
torch .cat ([i1 .grad .data , i2 .grad .data ], 1 ),
551
- atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
553
+ atol = 1e-1 if dtype == torch . half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
552
554
self .assertEqual (m .weight .grad .data ,
553
555
torch .cat ([m1 .weight .grad .data , m2 .weight .grad .data ], 0 ),
554
- atol = 1e -1 if dtype == torch .half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
556
+ atol = 2e -1 if dtype in [ torch .half , torch . bfloat16 ] else dtype2prec_DONTUSE [dtype ], rtol = 0 )
555
557
556
558
# CPU-only test for group conv3d fast implementation using bmm
557
559
# See: https://github.com/pytorch/pytorch/pull/36355
@@ -2075,7 +2077,8 @@ def test_Conv2d_naive_groups(self, device, dtype):
2075
2077
output2 = m2 (i2 )
2076
2078
output2 .backward (grad_output [:, 2 :].contiguous ())
2077
2079
2078
- self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ))
2080
+ self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ),
2081
+ atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
2079
2082
self .assertEqual (i .grad .data ,
2080
2083
torch .cat ([i1 .grad .data , i2 .grad .data ], 1 ),
2081
2084
atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
@@ -2084,7 +2087,7 @@ def test_Conv2d_naive_groups(self, device, dtype):
2084
2087
atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
2085
2088
self .assertEqual (m .weight .grad .data ,
2086
2089
torch .cat ([m1 .weight .grad .data , m2 .weight .grad .data ], 0 ),
2087
- atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
2090
+ atol = 1e-1 if dtype == torch . half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
2088
2091
2089
2092
@dtypes (torch .double , torch .cdouble )
2090
2093
def test_Conv2d_backward_depthwise (self , device , dtype ):
0 commit comments