Skip to content

Commit 034da48

Browse files
Fix test_Conv2d_groups related errors (SWDEV-416489) (#1269) (#1281)
* Fix test_Conv2d_groups related errors (SWDEV-416489) This is due to incorrect atol/rtol settings for torch.half and torch.bfloat16 data types. * More adjustment on Navi 32 Co-authored-by: Xinya Zhang <[email protected]>
1 parent 4c8bc42 commit 034da48

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

test/nn/test_convolution.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -506,10 +506,11 @@ def test_Conv2d_groups_nobias(self):
506506
output2 = m2(i2)
507507
output2.backward(grad_output[:, 2:].contiguous())
508508

509-
self.assertEqual(output, torch.cat([output1, output2], 1))
509+
self.assertEqual(output, torch.cat([output1, output2], 1),
510+
atol=dtype2prec_DONTUSE[dtype], rtol=0)
510511
self.assertEqual(i.grad.data,
511512
torch.cat([i1.grad.data, i2.grad.data], 1),
512-
atol=dtype2prec_DONTUSE[dtype], rtol=0)
513+
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
513514
self.assertEqual(m.weight.grad.data,
514515
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
515516
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
@@ -545,13 +546,14 @@ def test_Conv2d_groups_nobias_v2(self):
545546
output2 = m2(i2)
546547
output2.backward(grad_output[:, 8:].contiguous())
547548

548-
self.assertEqual(output, torch.cat([output1, output2], 1))
549+
self.assertEqual(output, torch.cat([output1, output2], 1),
550+
atol=dtype2prec_DONTUSE[dtype], rtol=0)
549551
self.assertEqual(i.grad.data,
550552
torch.cat([i1.grad.data, i2.grad.data], 1),
551-
atol=dtype2prec_DONTUSE[dtype], rtol=0)
553+
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
552554
self.assertEqual(m.weight.grad.data,
553555
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
554-
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
556+
atol=2e-1 if dtype in [torch.half, torch.bfloat16] else dtype2prec_DONTUSE[dtype], rtol=0)
555557

556558
# CPU-only test for group conv3d fast implementation using bmm
557559
# See: https://github.com/pytorch/pytorch/pull/36355
@@ -2075,7 +2077,8 @@ def test_Conv2d_naive_groups(self, device, dtype):
20752077
output2 = m2(i2)
20762078
output2.backward(grad_output[:, 2:].contiguous())
20772079

2078-
self.assertEqual(output, torch.cat([output1, output2], 1))
2080+
self.assertEqual(output, torch.cat([output1, output2], 1),
2081+
atol=dtype2prec_DONTUSE[dtype], rtol=0)
20792082
self.assertEqual(i.grad.data,
20802083
torch.cat([i1.grad.data, i2.grad.data], 1),
20812084
atol=dtype2prec_DONTUSE[dtype], rtol=0)
@@ -2084,7 +2087,7 @@ def test_Conv2d_naive_groups(self, device, dtype):
20842087
atol=dtype2prec_DONTUSE[dtype], rtol=0)
20852088
self.assertEqual(m.weight.grad.data,
20862089
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
2087-
atol=dtype2prec_DONTUSE[dtype], rtol=0)
2090+
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
20882091

20892092
@dtypes(torch.double, torch.cdouble)
20902093
def test_Conv2d_backward_depthwise(self, device, dtype):

0 commit comments

Comments
 (0)