Skip to content

Commit 54fd083

Browse files
xinyazhangpruthvistony
authored andcommitted
Fix test_Conv2d_groups related errors (SWDEV-416489) (#1269)
* Fix test_Conv2d_groups related errors (SWDEV-416489) This is due to incorrect atol/rtol settings for torch.half and torch.bfloat16 data types. * More adjustment on Navi 32
1 parent a0bd2c9 commit 54fd083

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

test/nn/test_convolution.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -507,10 +507,11 @@ def test_Conv2d_groups_nobias(self):
507507
output2 = m2(i2)
508508
output2.backward(grad_output[:, 2:].contiguous())
509509

510-
self.assertEqual(output, torch.cat([output1, output2], 1))
510+
self.assertEqual(output, torch.cat([output1, output2], 1),
511+
atol=dtype2prec_DONTUSE[dtype], rtol=0)
511512
self.assertEqual(i.grad.data,
512513
torch.cat([i1.grad.data, i2.grad.data], 1),
513-
atol=dtype2prec_DONTUSE[dtype], rtol=0)
514+
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
514515
self.assertEqual(m.weight.grad.data,
515516
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
516517
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
@@ -547,13 +548,14 @@ def test_Conv2d_groups_nobias_v2(self):
547548
output2 = m2(i2)
548549
output2.backward(grad_output[:, 8:].contiguous())
549550

550-
self.assertEqual(output, torch.cat([output1, output2], 1))
551+
self.assertEqual(output, torch.cat([output1, output2], 1),
552+
atol=dtype2prec_DONTUSE[dtype], rtol=0)
551553
self.assertEqual(i.grad.data,
552554
torch.cat([i1.grad.data, i2.grad.data], 1),
553-
atol=dtype2prec_DONTUSE[dtype], rtol=0)
555+
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
554556
self.assertEqual(m.weight.grad.data,
555557
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
556-
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
558+
atol=2e-1 if dtype in [torch.half, torch.bfloat16] else dtype2prec_DONTUSE[dtype], rtol=0)
557559

558560
# CPU-only test for group conv3d fast implementation using bmm
559561
# See: https://github.com/pytorch/pytorch/pull/36355
@@ -2073,7 +2075,8 @@ def test_Conv2d_naive_groups(self, device, dtype):
20732075
output2 = m2(i2)
20742076
output2.backward(grad_output[:, 2:].contiguous())
20752077

2076-
self.assertEqual(output, torch.cat([output1, output2], 1))
2078+
self.assertEqual(output, torch.cat([output1, output2], 1),
2079+
atol=dtype2prec_DONTUSE[dtype], rtol=0)
20772080
self.assertEqual(i.grad.data,
20782081
torch.cat([i1.grad.data, i2.grad.data], 1),
20792082
atol=dtype2prec_DONTUSE[dtype], rtol=0)
@@ -2082,7 +2085,7 @@ def test_Conv2d_naive_groups(self, device, dtype):
20822085
atol=dtype2prec_DONTUSE[dtype], rtol=0)
20832086
self.assertEqual(m.weight.grad.data,
20842087
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
2085-
atol=dtype2prec_DONTUSE[dtype], rtol=0)
2088+
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
20862089

20872090
@dtypes(torch.double, torch.cdouble)
20882091
def test_Conv2d_backward_depthwise(self, device, dtype):

0 commit comments

Comments
 (0)