Skip to content

Commit ae0336d

Browse files
authored
Merge pull request #101 from jithunnair-amd/enable_optim_unit_tests
Enable test_optim unit tests …
2 parents 18d0351 + b13d164 commit ae0336d

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

test/run_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
'multiprocessing',
5353
'nccl',
5454
'nn',
55-
'optim',
5655
'sparse',
5756
'torch',
5857
'utils',

test/test_optim.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def _build_params_dict(self, weight, bias, **kwargs):
236236
def _build_params_dict_single(self, weight, bias, **kwargs):
237237
return [dict(params=bias, **kwargs)]
238238

239+
@unittest.skipIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")
239240
def test_sgd(self):
240241
self._test_rosenbrock(
241242
lambda params: optim.SGD(params, lr=1e-3),
@@ -272,6 +273,7 @@ def test_sgd_sparse(self):
272273
lambda params: optim.SGD(params, lr=5e-3)
273274
)
274275

276+
@unittest.skipIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")
275277
def test_adam(self):
276278
self._test_rosenbrock(
277279
lambda params: optim.Adam(params, lr=1e-2),
@@ -309,6 +311,7 @@ def test_sparse_adam(self):
309311
with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
310312
optim.SparseAdam(None, lr=1e-2, betas=(1.0, 0.0))
311313

314+
@unittest.skipIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")
312315
def test_adadelta(self):
313316
self._test_rosenbrock(
314317
lambda params: optim.Adadelta(params),
@@ -332,6 +335,7 @@ def test_adadelta(self):
332335
with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"):
333336
optim.Adadelta(None, lr=1e-2, rho=1.1)
334337

338+
@unittest.skipIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")
335339
def test_adagrad(self):
336340
self._test_rosenbrock(
337341
lambda params: optim.Adagrad(params, lr=1e-1),
@@ -365,6 +369,7 @@ def test_adagrad_sparse(self):
365369
lambda params: optim.Adagrad(params, lr=1e-1)
366370
)
367371

372+
@unittest.skipIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")
368373
def test_adamax(self):
369374
self._test_rosenbrock(
370375
lambda params: optim.Adamax(params, lr=1e-1),
@@ -389,6 +394,7 @@ def test_adamax(self):
389394
with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 1: 1.0"):
390395
optim.Adamax(None, lr=1e-2, betas=(0.0, 1.0))
391396

397+
@unittest.skipIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")
392398
def test_rmsprop(self):
393399
self._test_rosenbrock(
394400
lambda params: optim.RMSprop(params, lr=1e-2),
@@ -413,6 +419,7 @@ def test_rmsprop(self):
413419
with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"):
414420
optim.RMSprop(None, lr=1e-2, momentum=-1.0)
415421

422+
@unittest.skipIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")
416423
def test_asgd(self):
417424
self._test_rosenbrock(
418425
lambda params: optim.ASGD(params, lr=1e-3),
@@ -462,6 +469,7 @@ def test_rprop(self):
462469
with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"):
463470
optim.Rprop(None, lr=1e-2, etas=(1.0, 0.5))
464471

472+
@unittest.skipIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")
465473
def test_lbfgs(self):
466474
self._test_rosenbrock(
467475
lambda params: optim.LBFGS(params),

0 commit comments

Comments
 (0)