@@ -236,7 +236,6 @@ def _build_params_dict(self, weight, bias, **kwargs):
236
236
def _build_params_dict_single (self , weight , bias , ** kwargs ):
237
237
return [dict (params = bias , ** kwargs )]
238
238
239
- @skipIfRocm
240
239
def test_sgd (self ):
241
240
self ._test_rosenbrock (
242
241
lambda params : optim .SGD (params , lr = 1e-3 ),
@@ -273,7 +272,6 @@ def test_sgd_sparse(self):
273
272
lambda params : optim .SGD (params , lr = 5e-3 )
274
273
)
275
274
276
- @skipIfRocm
277
275
def test_adam (self ):
278
276
self ._test_rosenbrock (
279
277
lambda params : optim .Adam (params , lr = 1e-2 ),
@@ -311,7 +309,6 @@ def test_sparse_adam(self):
311
309
with self .assertRaisesRegex (ValueError , "Invalid beta parameter at index 0: 1.0" ):
312
310
optim .SparseAdam (None , lr = 1e-2 , betas = (1.0 , 0.0 ))
313
311
314
- @skipIfRocm
315
312
def test_adadelta (self ):
316
313
self ._test_rosenbrock (
317
314
lambda params : optim .Adadelta (params ),
@@ -335,7 +332,6 @@ def test_adadelta(self):
335
332
with self .assertRaisesRegex (ValueError , "Invalid rho value: 1.1" ):
336
333
optim .Adadelta (None , lr = 1e-2 , rho = 1.1 )
337
334
338
- @skipIfRocm
339
335
def test_adagrad (self ):
340
336
self ._test_rosenbrock (
341
337
lambda params : optim .Adagrad (params , lr = 1e-1 ),
@@ -394,7 +390,6 @@ def test_adamax(self):
394
390
with self .assertRaisesRegex (ValueError , "Invalid beta parameter at index 1: 1.0" ):
395
391
optim .Adamax (None , lr = 1e-2 , betas = (0.0 , 1.0 ))
396
392
397
- @skipIfRocm
398
393
def test_rmsprop (self ):
399
394
self ._test_rosenbrock (
400
395
lambda params : optim .RMSprop (params , lr = 1e-2 ),
@@ -419,7 +414,6 @@ def test_rmsprop(self):
419
414
with self .assertRaisesRegex (ValueError , "Invalid momentum value: -1.0" ):
420
415
optim .RMSprop (None , lr = 1e-2 , momentum = - 1.0 )
421
416
422
- @skipIfRocm
423
417
def test_asgd (self ):
424
418
self ._test_rosenbrock (
425
419
lambda params : optim .ASGD (params , lr = 1e-3 ),
@@ -469,7 +463,6 @@ def test_rprop(self):
469
463
with self .assertRaisesRegex (ValueError , "Invalid eta values: 1.0, 0.5" ):
470
464
optim .Rprop (None , lr = 1e-2 , etas = (1.0 , 0.5 ))
471
465
472
- @skipIfRocm
473
466
def test_lbfgs (self ):
474
467
self ._test_rosenbrock (
475
468
lambda params : optim .LBFGS (params ),
0 commit comments