@@ -236,6 +236,7 @@ def _build_params_dict(self, weight, bias, **kwargs):
236
236
def _build_params_dict_single (self , weight , bias , ** kwargs ):
237
237
return [dict (params = bias , ** kwargs )]
238
238
239
+ @unittest .skipIf (TEST_WITH_ROCM , "test doesn't currently work on the ROCm stack" )
239
240
def test_sgd (self ):
240
241
self ._test_rosenbrock (
241
242
lambda params : optim .SGD (params , lr = 1e-3 ),
@@ -272,6 +273,7 @@ def test_sgd_sparse(self):
272
273
lambda params : optim .SGD (params , lr = 5e-3 )
273
274
)
274
275
276
+ @unittest .skipIf (TEST_WITH_ROCM , "test doesn't currently work on the ROCm stack" )
275
277
def test_adam (self ):
276
278
self ._test_rosenbrock (
277
279
lambda params : optim .Adam (params , lr = 1e-2 ),
@@ -309,6 +311,7 @@ def test_sparse_adam(self):
309
311
with self .assertRaisesRegex (ValueError , "Invalid beta parameter at index 0: 1.0" ):
310
312
optim .SparseAdam (None , lr = 1e-2 , betas = (1.0 , 0.0 ))
311
313
314
+ @unittest .skipIf (TEST_WITH_ROCM , "test doesn't currently work on the ROCm stack" )
312
315
def test_adadelta (self ):
313
316
self ._test_rosenbrock (
314
317
lambda params : optim .Adadelta (params ),
@@ -332,6 +335,7 @@ def test_adadelta(self):
332
335
with self .assertRaisesRegex (ValueError , "Invalid rho value: 1.1" ):
333
336
optim .Adadelta (None , lr = 1e-2 , rho = 1.1 )
334
337
338
+ @unittest .skipIf (TEST_WITH_ROCM , "test doesn't currently work on the ROCm stack" )
335
339
def test_adagrad (self ):
336
340
self ._test_rosenbrock (
337
341
lambda params : optim .Adagrad (params , lr = 1e-1 ),
@@ -365,6 +369,7 @@ def test_adagrad_sparse(self):
365
369
lambda params : optim .Adagrad (params , lr = 1e-1 )
366
370
)
367
371
372
+ @unittest .skipIf (TEST_WITH_ROCM , "test doesn't currently work on the ROCm stack" )
368
373
def test_adamax (self ):
369
374
self ._test_rosenbrock (
370
375
lambda params : optim .Adamax (params , lr = 1e-1 ),
@@ -389,6 +394,7 @@ def test_adamax(self):
389
394
with self .assertRaisesRegex (ValueError , "Invalid beta parameter at index 1: 1.0" ):
390
395
optim .Adamax (None , lr = 1e-2 , betas = (0.0 , 1.0 ))
391
396
397
+ @unittest .skipIf (TEST_WITH_ROCM , "test doesn't currently work on the ROCm stack" )
392
398
def test_rmsprop (self ):
393
399
self ._test_rosenbrock (
394
400
lambda params : optim .RMSprop (params , lr = 1e-2 ),
@@ -413,6 +419,7 @@ def test_rmsprop(self):
413
419
with self .assertRaisesRegex (ValueError , "Invalid momentum value: -1.0" ):
414
420
optim .RMSprop (None , lr = 1e-2 , momentum = - 1.0 )
415
421
422
+ @unittest .skipIf (TEST_WITH_ROCM , "test doesn't currently work on the ROCm stack" )
416
423
def test_asgd (self ):
417
424
self ._test_rosenbrock (
418
425
lambda params : optim .ASGD (params , lr = 1e-3 ),
@@ -462,6 +469,7 @@ def test_rprop(self):
462
469
with self .assertRaisesRegex (ValueError , "Invalid eta values: 1.0, 0.5" ):
463
470
optim .Rprop (None , lr = 1e-2 , etas = (1.0 , 0.5 ))
464
471
472
+ @unittest .skipIf (TEST_WITH_ROCM , "test doesn't currently work on the ROCm stack" )
465
473
def test_lbfgs (self ):
466
474
self ._test_rosenbrock (
467
475
lambda params : optim .LBFGS (params ),
0 commit comments