Skip to content

Commit deeeb91

Browse files
xinyazhangpruthvistony
authored andcommitted
Fix test_Conv2d_groups related errors (SWDEV-416489) (#1269)
* Fix test_Conv2d_groups related errors (SWDEV-416489) This is due to incorrect atol/rtol settings for torch.half and torch.bfloat16 data types. * More adjustment on Navi 32
1 parent 45996e7 commit deeeb91

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

test/nn/test_convolution.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -507,10 +507,11 @@ def test_Conv2d_groups_nobias(self):
507507
output2 = m2(i2)
508508
output2.backward(grad_output[:, 2:].contiguous())
509509

510-
self.assertEqual(output, torch.cat([output1, output2], 1))
510+
self.assertEqual(output, torch.cat([output1, output2], 1),
511+
atol=dtype2prec_DONTUSE[dtype], rtol=0)
511512
self.assertEqual(i.grad.data,
512513
torch.cat([i1.grad.data, i2.grad.data], 1),
513-
atol=dtype2prec_DONTUSE[dtype], rtol=0)
514+
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
514515
self.assertEqual(m.weight.grad.data,
515516
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
516517
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
@@ -547,13 +548,14 @@ def test_Conv2d_groups_nobias_v2(self):
547548
output2 = m2(i2)
548549
output2.backward(grad_output[:, 8:].contiguous())
549550

550-
self.assertEqual(output, torch.cat([output1, output2], 1))
551+
self.assertEqual(output, torch.cat([output1, output2], 1),
552+
atol=dtype2prec_DONTUSE[dtype], rtol=0)
551553
self.assertEqual(i.grad.data,
552554
torch.cat([i1.grad.data, i2.grad.data], 1),
553-
atol=dtype2prec_DONTUSE[dtype], rtol=0)
555+
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
554556
self.assertEqual(m.weight.grad.data,
555557
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
556-
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
558+
atol=2e-1 if dtype in [torch.half, torch.bfloat16] else dtype2prec_DONTUSE[dtype], rtol=0)
557559

558560
# CPU-only test for group conv3d fast implementation using bmm
559561
# See: https://github.com/pytorch/pytorch/pull/36355
@@ -2080,7 +2082,8 @@ def test_Conv2d_naive_groups(self, device, dtype):
20802082
output2 = m2(i2)
20812083
output2.backward(grad_output[:, 2:].contiguous())
20822084

2083-
self.assertEqual(output, torch.cat([output1, output2], 1))
2085+
self.assertEqual(output, torch.cat([output1, output2], 1),
2086+
atol=dtype2prec_DONTUSE[dtype], rtol=0)
20842087
self.assertEqual(i.grad.data,
20852088
torch.cat([i1.grad.data, i2.grad.data], 1),
20862089
atol=dtype2prec_DONTUSE[dtype], rtol=0)
@@ -2089,7 +2092,7 @@ def test_Conv2d_naive_groups(self, device, dtype):
20892092
atol=dtype2prec_DONTUSE[dtype], rtol=0)
20902093
self.assertEqual(m.weight.grad.data,
20912094
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
2092-
atol=dtype2prec_DONTUSE[dtype], rtol=0)
2095+
atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0)
20932096

20942097
@dtypes(torch.double, torch.cdouble)
20952098
def test_Conv2d_backward_depthwise(self, device, dtype):

0 commit comments

Comments
 (0)