Skip to content

Commit 5010f39

Browse files
Fix test_Conv2d_groups related errors (SWDEV-416489) (#1269) (#1344)
* Fix test_Conv2d_groups related errors (SWDEV-416489) This is due to incorrect atol/rtol settings for torch.half and torch.bfloat16 data types. * More adjustment on Navi 32 Co-authored-by: Xinya Zhang <[email protected]>
1 parent e6e5c48 commit 5010f39

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

test/nn/test_convolution.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -506,10 +506,11 @@ def test_Conv2d_groups_nobias(self):
506506
output2 = m2(i2)
507507
output2.backward(grad_output[:, 2:].contiguous())
508508

509-
self.assertEqual(output, torch.cat([output1, output2], 1))
509+
self.assertEqual(output, torch.cat([output1, output2], 1),
510+
atol=dtype2prec_DONTUSE[dtype], rtol=0)
510511
self.assertEqual(i.grad.data,
511512
torch.cat([i1.grad.data, i2.grad.data], 1),
512-
atol=dtype2prec_DONTUSE[dtype], rtol=0)
513+
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
513514
self.assertEqual(m.weight.grad.data,
514515
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
515516
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
@@ -545,13 +546,14 @@ def test_Conv2d_groups_nobias_v2(self):
545546
output2 = m2(i2)
546547
output2.backward(grad_output[:, 8:].contiguous())
547548

548-
self.assertEqual(output, torch.cat([output1, output2], 1))
549+
self.assertEqual(output, torch.cat([output1, output2], 1),
550+
atol=dtype2prec_DONTUSE[dtype], rtol=0)
549551
self.assertEqual(i.grad.data,
550552
torch.cat([i1.grad.data, i2.grad.data], 1),
551-
atol=dtype2prec_DONTUSE[dtype], rtol=0)
553+
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
552554
self.assertEqual(m.weight.grad.data,
553555
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
554-
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
556+
atol=2e-1 if dtype in [torch.half, torch.bfloat16] else dtype2prec_DONTUSE[dtype], rtol=0)
555557

556558
# CPU-only test for group conv3d fast implementation using bmm
557559
# See: https://github.com/pytorch/pytorch/pull/36355
@@ -2076,7 +2078,8 @@ def test_Conv2d_naive_groups(self, device, dtype):
20762078
output2 = m2(i2)
20772079
output2.backward(grad_output[:, 2:].contiguous())
20782080

2079-
self.assertEqual(output, torch.cat([output1, output2], 1))
2081+
self.assertEqual(output, torch.cat([output1, output2], 1),
2082+
atol=dtype2prec_DONTUSE[dtype], rtol=0)
20802083
self.assertEqual(i.grad.data,
20812084
torch.cat([i1.grad.data, i2.grad.data], 1),
20822085
atol=dtype2prec_DONTUSE[dtype], rtol=0)
@@ -2085,7 +2088,7 @@ def test_Conv2d_naive_groups(self, device, dtype):
20852088
atol=dtype2prec_DONTUSE[dtype], rtol=0)
20862089
self.assertEqual(m.weight.grad.data,
20872090
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
2088-
atol=dtype2prec_DONTUSE[dtype], rtol=0)
2091+
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
20892092

20902093
@dtypes(torch.double, torch.cdouble)
20912094
def test_Conv2d_backward_depthwise(self, device, dtype):

0 commit comments

Comments
 (0)