Skip to content

Commit 7e0e9cd

Browse files
authored
ROCm: Fix test_nadam (#1006)
Change the rtol level Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
1 parent cbf022a commit 7e0e9cd

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

test/test_optim.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
_LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR, ChainedScheduler, \
2121
EPOCH_DEPRECATION_WARNING
2222
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
23-
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \
23+
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, TEST_WITH_UBSAN, load_tests, \
2424
skipIfRocm
2525
# load_tests from common_utils is used to automatically filter tests for
2626
# sharding on sandcastle. This line silences flake warnings
@@ -663,6 +663,8 @@ def test_adadelta_complex(self):
663663
)
664664

665665
def test_nadam(self):
666+
if TEST_WITH_ROCM:
667+
self.rel_tol = 1e-5
666668
for optimizer in [optim.NAdam, optim_mt.NAdam]:
667669
self._test_basic_cases(
668670
lambda weight, bias: optimizer([weight, bias], lr=1e-3)

0 commit comments

Comments
 (0)