Skip to content

Commit 9110813

Browse files
jaglinuxpruthvistony
authored andcommitted
ROCm: Fix test_nadam (#1006)
Change the rtol level Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
1 parent d9b8ec4 commit 9110813

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

test/test_optim.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
EPOCH_DEPRECATION_WARNING
2020
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
2121
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \
22-
parametrize, instantiate_parametrized_tests, gradcheck, skipIfRocm
22+
parametrize, instantiate_parametrized_tests, gradcheck, skipIfRocm, TEST_WITH_ROCM
2323
# load_tests from common_utils is used to automatically filter tests for
2424
# sharding on sandcastle. This line silences flake warnings
2525
load_tests = load_tests
@@ -777,6 +777,8 @@ def test_adadelta_complex(self):
777777
)
778778

779779
def test_nadam(self):
780+
if TEST_WITH_ROCM:
781+
self.rel_tol = 1e-5
780782
self._test_basic_cases(
781783
lambda weight, bias, foreach: optim.NAdam([weight, bias], lr=1e-3, foreach=foreach),
782784
constructor_accepts_foreach=True,

0 commit comments

Comments
 (0)