@@ -507,10 +507,11 @@ def test_Conv2d_groups_nobias(self):
507
507
output2 = m2 (i2 )
508
508
output2 .backward (grad_output [:, 2 :].contiguous ())
509
509
510
- self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ))
510
+ self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ),
511
+ atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
511
512
self .assertEqual (i .grad .data ,
512
513
torch .cat ([i1 .grad .data , i2 .grad .data ], 1 ),
513
- atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
514
+ atol = 1e-1 if dtype == torch . half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
514
515
self .assertEqual (m .weight .grad .data ,
515
516
torch .cat ([m1 .weight .grad .data , m2 .weight .grad .data ], 0 ),
516
517
atol = 1e-1 if dtype == torch .half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
@@ -547,13 +548,14 @@ def test_Conv2d_groups_nobias_v2(self):
547
548
output2 = m2 (i2 )
548
549
output2 .backward (grad_output [:, 8 :].contiguous ())
549
550
550
- self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ))
551
+ self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ),
552
+ atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
551
553
self .assertEqual (i .grad .data ,
552
554
torch .cat ([i1 .grad .data , i2 .grad .data ], 1 ),
553
- atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
555
+ atol = 1e-1 if dtype == torch . half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
554
556
self .assertEqual (m .weight .grad .data ,
555
557
torch .cat ([m1 .weight .grad .data , m2 .weight .grad .data ], 0 ),
556
- atol = 1e -1 if dtype == torch .half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
558
+ atol = 2e -1 if dtype in [ torch .half , torch . bfloat16 ] else dtype2prec_DONTUSE [dtype ], rtol = 0 )
557
559
558
560
# CPU-only test for group conv3d fast implementation using bmm
559
561
# See: https://github.com/pytorch/pytorch/pull/36355
@@ -2080,7 +2082,8 @@ def test_Conv2d_naive_groups(self, device, dtype):
2080
2082
output2 = m2 (i2 )
2081
2083
output2 .backward (grad_output [:, 2 :].contiguous ())
2082
2084
2083
- self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ))
2085
+ self .assertEqual (output , torch .cat ([output1 , output2 ], 1 ),
2086
+ atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
2084
2087
self .assertEqual (i .grad .data ,
2085
2088
torch .cat ([i1 .grad .data , i2 .grad .data ], 1 ),
2086
2089
atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
@@ -2089,7 +2092,7 @@ def test_Conv2d_naive_groups(self, device, dtype):
2089
2092
atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
2090
2093
self .assertEqual (m .weight .grad .data ,
2091
2094
torch .cat ([m1 .weight .grad .data , m2 .weight .grad .data ], 0 ),
2092
- atol = dtype2prec_DONTUSE [dtype ], rtol = 0 )
2095
+ atol = 1e-1 if dtype == torch . half else dtype2prec_DONTUSE [dtype ], rtol = 0 )
2093
2096
2094
2097
@dtypes (torch .double , torch .cdouble )
2095
2098
def test_Conv2d_backward_depthwise (self , device , dtype ):
0 commit comments