31
31
32
32
import torch
33
33
from torch ._six import inf
34
- from common import TestCase , run_tests , set_rng_seed , TEST_WITH_UBSAN
34
+ from common import TestCase , run_tests , set_rng_seed , TEST_WITH_UBSAN , skipIfRocm
35
35
from common_cuda import TEST_CUDA
36
36
from torch .autograd import grad , gradcheck
37
37
from torch .distributions import (Bernoulli , Beta , Binomial , Categorical ,
@@ -724,6 +724,7 @@ def test_repr(self):
724
724
dist = Dist (** param )
725
725
self .assertTrue (repr (dist ).startswith (dist .__class__ .__name__ ))
726
726
727
+ @skipIfRocm
727
728
def test_sample_detached (self ):
728
729
for Dist , params in EXAMPLES :
729
730
for i , param in enumerate (params ):
@@ -736,6 +737,7 @@ def test_sample_detached(self):
736
737
msg = '{} example {}/{}, .sample() is not detached' .format (
737
738
Dist .__name__ , i + 1 , len (params )))
738
739
740
+ @skipIfRocm
739
741
def test_rsample_requires_grad (self ):
740
742
for Dist , params in EXAMPLES :
741
743
for i , param in enumerate (params ):
@@ -749,6 +751,7 @@ def test_rsample_requires_grad(self):
749
751
msg = '{} example {}/{}, .rsample() does not require grad' .format (
750
752
Dist .__name__ , i + 1 , len (params )))
751
753
754
+ @skipIfRocm
752
755
def test_enumerate_support_type (self ):
753
756
for Dist , params in EXAMPLES :
754
757
for i , param in enumerate (params ):
@@ -760,6 +763,7 @@ def test_enumerate_support_type(self):
760
763
except NotImplementedError :
761
764
pass
762
765
766
+ @skipIfRocm
763
767
def test_lazy_property_grad (self ):
764
768
x = torch .randn (1 , requires_grad = True )
765
769
@@ -1129,6 +1133,7 @@ def test_poisson_sample(self):
1129
1133
1130
1134
@unittest .skipIf (not TEST_CUDA , "CUDA not found" )
1131
1135
@unittest .skipIf (not TEST_NUMPY , "Numpy not found" )
1136
+ @skipIfRocm
1132
1137
def test_poisson_gpu_sample (self ):
1133
1138
set_rng_seed (1 )
1134
1139
for rate in [0.12 , 0.9 , 4.0 ]:
@@ -1536,6 +1541,7 @@ def test_normal_sample(self):
1536
1541
scipy .stats .norm (loc = loc , scale = scale ),
1537
1542
'Normal(mean={}, std={})' .format (loc , scale ))
1538
1543
1544
+ @skipIfRocm
1539
1545
def test_lowrank_multivariate_normal_shape (self ):
1540
1546
mean = torch .randn (5 , 3 , requires_grad = True )
1541
1547
mean_no_batch = torch .randn (3 , requires_grad = True )
@@ -1584,6 +1590,7 @@ def test_lowrank_multivariate_normal_shape(self):
1584
1590
(mean_multi_batch , cov_factor_batched , cov_diag_batched ))
1585
1591
1586
1592
@unittest .skipIf (not TEST_NUMPY , "Numpy not found" )
1593
+ @skipIfRocm
1587
1594
def test_lowrank_multivariate_normal_log_prob (self ):
1588
1595
mean = torch .randn (3 , requires_grad = True )
1589
1596
cov_factor = torch .randn (3 , 1 , requires_grad = True )
@@ -1617,6 +1624,7 @@ def test_lowrank_multivariate_normal_log_prob(self):
1617
1624
self .assertAlmostEqual (0.0 , (batched_prob - unbatched_prob ).abs ().max (), places = 3 )
1618
1625
1619
1626
@unittest .skipIf (not TEST_NUMPY , "NumPy not found" )
1627
+ @skipIfRocm
1620
1628
def test_lowrank_multivariate_normal_sample (self ):
1621
1629
set_rng_seed (0 ) # see Note [Randomized statistical tests]
1622
1630
mean = torch .randn (5 , requires_grad = True )
@@ -1629,6 +1637,7 @@ def test_lowrank_multivariate_normal_sample(self):
1629
1637
'LowRankMultivariateNormal(loc={}, cov_factor={}, cov_diag={})'
1630
1638
.format (mean , cov_factor , cov_diag ), multivariate = True )
1631
1639
1640
+ @skipIfRocm
1632
1641
def test_lowrank_multivariate_normal_properties (self ):
1633
1642
loc = torch .randn (5 )
1634
1643
cov_factor = torch .randn (5 , 2 )
@@ -1643,6 +1652,7 @@ def test_lowrank_multivariate_normal_properties(self):
1643
1652
self .assertEqual (m1 .precision_matrix , m2 .precision_matrix )
1644
1653
self .assertEqual (m1 .entropy (), m2 .entropy ())
1645
1654
1655
+ @skipIfRocm
1646
1656
def test_lowrank_multivariate_normal_moments (self ):
1647
1657
set_rng_seed (0 ) # see Note [Randomized statistical tests]
1648
1658
mean = torch .randn (5 )
@@ -1655,6 +1665,7 @@ def test_lowrank_multivariate_normal_moments(self):
1655
1665
empirical_var = samples .var (0 )
1656
1666
self .assertEqual (d .variance , empirical_var , prec = 0.02 )
1657
1667
1668
+ @skipIfRocm
1658
1669
def test_multivariate_normal_shape (self ):
1659
1670
mean = torch .randn (5 , 3 , requires_grad = True )
1660
1671
mean_no_batch = torch .randn (3 , requires_grad = True )
@@ -1702,6 +1713,7 @@ def test_multivariate_normal_shape(self):
1702
1713
self ._gradcheck_log_prob (MultivariateNormal , (mean_no_batch , None , None , scale_tril_batched ))
1703
1714
1704
1715
@unittest .skipIf (not TEST_NUMPY , "Numpy not found" )
1716
+ @skipIfRocm
1705
1717
def test_multivariate_normal_log_prob (self ):
1706
1718
mean = torch .randn (3 , requires_grad = True )
1707
1719
tmp = torch .randn (3 , 10 )
@@ -1739,6 +1751,7 @@ def test_multivariate_normal_log_prob(self):
1739
1751
self .assertAlmostEqual (0.0 , (batched_prob - unbatched_prob ).abs ().max (), places = 3 )
1740
1752
1741
1753
@unittest .skipIf (not TEST_NUMPY , "NumPy not found" )
1754
+ @skipIfRocm
1742
1755
def test_multivariate_normal_sample (self ):
1743
1756
set_rng_seed (0 ) # see Note [Randomized statistical tests]
1744
1757
mean = torch .randn (3 , requires_grad = True )
@@ -1760,6 +1773,7 @@ def test_multivariate_normal_sample(self):
1760
1773
'MultivariateNormal(loc={}, scale_tril={})' .format (mean , scale_tril ),
1761
1774
multivariate = True )
1762
1775
1776
+ @skipIfRocm
1763
1777
def test_multivariate_normal_properties (self ):
1764
1778
loc = torch .randn (5 )
1765
1779
scale_tril = transform_to (constraints .lower_cholesky )(torch .randn (5 , 5 ))
@@ -1888,6 +1902,7 @@ def ref_log_prob(idx, x, log_prob):
1888
1902
1889
1903
@unittest .skipIf (not TEST_CUDA , "CUDA not found" )
1890
1904
@unittest .skipIf (not TEST_NUMPY , "NumPy not found" )
1905
+ @skipIfRocm
1891
1906
def test_gamma_gpu_shape (self ):
1892
1907
alpha = torch .tensor (torch .exp (torch .randn (2 , 3 ).cuda ()), requires_grad = True )
1893
1908
beta = torch .tensor (torch .exp (torch .randn (2 , 3 ).cuda ()), requires_grad = True )
@@ -1918,6 +1933,7 @@ def test_gamma_sample(self):
1918
1933
1919
1934
@unittest .skipIf (not TEST_CUDA , "CUDA not found" )
1920
1935
@unittest .skipIf (not TEST_NUMPY , "Numpy not found" )
1936
+ @skipIfRocm
1921
1937
def test_gamma_gpu_sample (self ):
1922
1938
set_rng_seed (0 )
1923
1939
for alpha , beta in product ([0.1 , 1.0 , 5.0 ], [0.1 , 1.0 , 10.0 ]):
@@ -2150,6 +2166,7 @@ def test_beta_sample(self):
2150
2166
x = Beta (Tensor ([1e-6 ]), Tensor ([1e-6 ])).sample ()[0 ]
2151
2167
self .assertTrue (np .isfinite (x ) and x > 0 , 'Invalid Beta.sample(): {}' .format (x ))
2152
2168
2169
+ @skipIfRocm
2153
2170
def test_independent_shape (self ):
2154
2171
for Dist , params in EXAMPLES :
2155
2172
for i , param in enumerate (params ):
@@ -2178,6 +2195,7 @@ def test_independent_shape(self):
2178
2195
except NotImplementedError :
2179
2196
pass
2180
2197
2198
+ @skipIfRocm
2181
2199
def test_cdf_icdf_inverse (self ):
2182
2200
# Tests the invertibility property on the distributions
2183
2201
for Dist , params in EXAMPLES :
@@ -2197,6 +2215,7 @@ def test_cdf_icdf_inverse(self):
2197
2215
'icdf(cdf(x)) = {}' .format (actual ),
2198
2216
]))
2199
2217
2218
+ @skipIfRocm
2200
2219
def test_cdf_log_prob (self ):
2201
2220
# Tests if the differentiation of the CDF gives the PDF at a given value
2202
2221
for Dist , params in EXAMPLES :
@@ -2588,6 +2607,7 @@ def tearDown(self):
2588
2607
super (TestCase , self ).tearDown ()
2589
2608
Distribution .set_default_validate_args (False )
2590
2609
2610
+ @skipIfRocm
2591
2611
def test_entropy_shape (self ):
2592
2612
for Dist , params in EXAMPLES :
2593
2613
for i , param in enumerate (params ):
@@ -3141,6 +3161,7 @@ def test_kl_monte_carlo(self):
3141
3161
3142
3162
# Multivariate normal has a separate Monte Carlo based test due to the requirement of random generation of
3143
3163
# positive (semi) definite matrices. n is set to 5, but can be increased during testing.
3164
+ @skipIfRocm
3144
3165
def test_kl_multivariate_normal (self ):
3145
3166
set_rng_seed (0 ) # see Note [Randomized statistical tests]
3146
3167
n = 5 # Number of tests for multivariate_normal
@@ -3166,6 +3187,7 @@ def test_kl_multivariate_normal(self):
3166
3187
'Actual (analytic): {}' .format (actual ),
3167
3188
]))
3168
3189
3190
+ @skipIfRocm
3169
3191
def test_kl_multivariate_normal_batched (self ):
3170
3192
b = 7 # Number of batches
3171
3193
loc = [torch .randn (b , 3 ) for _ in range (0 , 2 )]
@@ -3177,6 +3199,7 @@ def test_kl_multivariate_normal_batched(self):
3177
3199
MultivariateNormal (loc [1 ], scale_tril = scale_tril [1 ]))
3178
3200
self .assertEqual (expected_kl , actual_kl )
3179
3201
3202
+ @skipIfRocm
3180
3203
def test_kl_multivariate_normal_batched_broadcasted (self ):
3181
3204
b = 7 # Number of batches
3182
3205
loc = [torch .randn (b , 3 ) for _ in range (0 , 2 )]
@@ -3189,6 +3212,7 @@ def test_kl_multivariate_normal_batched_broadcasted(self):
3189
3212
MultivariateNormal (loc [1 ], scale_tril = scale_tril [1 ]))
3190
3213
self .assertEqual (expected_kl , actual_kl )
3191
3214
3215
+ @skipIfRocm
3192
3216
def test_kl_lowrank_multivariate_normal (self ):
3193
3217
set_rng_seed (0 ) # see Note [Randomized statistical tests]
3194
3218
n = 5 # Number of tests for lowrank_multivariate_normal
@@ -3229,6 +3253,7 @@ def test_kl_lowrank_multivariate_normal(self):
3229
3253
'Actual (analytic): {}' .format (actual_full_lowrank ),
3230
3254
]))
3231
3255
3256
+ @skipIfRocm
3232
3257
def test_kl_lowrank_multivariate_normal_batched (self ):
3233
3258
b = 7 # Number of batches
3234
3259
loc = [torch .randn (b , 3 ) for _ in range (0 , 2 )]
@@ -3264,6 +3289,7 @@ def test_kl_edgecases(self):
3264
3289
self .assertEqual (kl_divergence (Bernoulli (1 ), Bernoulli (1 )), 0 )
3265
3290
self .assertEqual (kl_divergence (Categorical (torch .tensor ([0. , 1. ])), Categorical (torch .tensor ([0. , 1. ]))), 0 )
3266
3291
3292
+ @skipIfRocm
3267
3293
def test_kl_shape (self ):
3268
3294
for Dist , params in EXAMPLES :
3269
3295
for i , param in enumerate (params ):
@@ -3279,6 +3305,7 @@ def test_kl_shape(self):
3279
3305
'Actual {}' .format (kl .shape ),
3280
3306
]))
3281
3307
3308
+ @skipIfRocm
3282
3309
def test_entropy_monte_carlo (self ):
3283
3310
set_rng_seed (0 ) # see Note [Randomized statistical tests]
3284
3311
for Dist , params in EXAMPLES :
@@ -3322,6 +3349,7 @@ def test_entropy_exponential_family(self):
3322
3349
3323
3350
3324
3351
class TestConstraints (TestCase ):
3352
+ @skipIfRocm
3325
3353
def test_params_contains (self ):
3326
3354
for Dist , params in EXAMPLES :
3327
3355
for i , param in enumerate (params ):
@@ -3345,6 +3373,7 @@ def test_params_contains(self):
3345
3373
Dist .__name__ , i + 1 , len (params ), name , value )
3346
3374
self .assertTrue (constraint .check (value ).all (), msg = message )
3347
3375
3376
+ @skipIfRocm
3348
3377
def test_support_contains (self ):
3349
3378
for Dist , params in EXAMPLES :
3350
3379
self .assertIsInstance (Dist .support , Constraint )
@@ -3624,6 +3653,7 @@ def setUp(self):
3624
3653
)
3625
3654
]
3626
3655
3656
+ @skipIfRocm
3627
3657
def test_mean (self ):
3628
3658
for pytorch_dist , scipy_dist in self .distribution_pairs :
3629
3659
if isinstance (pytorch_dist , (Cauchy , HalfCauchy )):
@@ -3634,6 +3664,7 @@ def test_mean(self):
3634
3664
else :
3635
3665
self .assertEqual (pytorch_dist .mean , scipy_dist .mean (), allow_inf = True , message = pytorch_dist )
3636
3666
3667
+ @skipIfRocm
3637
3668
def test_variance_stddev (self ):
3638
3669
for pytorch_dist , scipy_dist in self .distribution_pairs :
3639
3670
if isinstance (pytorch_dist , (Cauchy , HalfCauchy )):
@@ -3649,6 +3680,7 @@ def test_variance_stddev(self):
3649
3680
self .assertEqual (pytorch_dist .variance , scipy_dist .var (), allow_inf = True , message = pytorch_dist )
3650
3681
self .assertEqual (pytorch_dist .stddev , scipy_dist .var () ** 0.5 , message = pytorch_dist )
3651
3682
3683
+ @skipIfRocm
3652
3684
def test_cdf (self ):
3653
3685
for pytorch_dist , scipy_dist in self .distribution_pairs :
3654
3686
samples = pytorch_dist .sample ((5 ,))
@@ -3658,6 +3690,7 @@ def test_cdf(self):
3658
3690
continue
3659
3691
self .assertEqual (cdf , scipy_dist .cdf (samples ), message = pytorch_dist )
3660
3692
3693
+ @skipIfRocm
3661
3694
def test_icdf (self ):
3662
3695
for pytorch_dist , scipy_dist in self .distribution_pairs :
3663
3696
samples = torch .rand ((5 ,) + pytorch_dist .batch_shape )
@@ -3963,6 +3996,7 @@ def test_biject_to(self):
3963
3996
self .assertEqual (j .shape , x .shape [:x .dim () - t .event_dim ])
3964
3997
3965
3998
@unittest .skipIf (not TEST_CUDA , "CUDA not found" )
3999
+ @skipIfRocm
3966
4000
def test_biject_to_cuda (self ):
3967
4001
for constraint in self .get_constraints (is_cuda = True ):
3968
4002
try :
@@ -3995,6 +4029,7 @@ def test_transform_to(self):
3995
4029
self .assertEqual (y , y2 , message = "Error in transform_to({}) pseudoinverse" .format (constraint ))
3996
4030
3997
4031
@unittest .skipIf (not TEST_CUDA , "CUDA not found" )
4032
+ @skipIfRocm
3998
4033
def test_transform_to_cuda (self ):
3999
4034
for constraint in self .get_constraints (is_cuda = True ):
4000
4035
t = transform_to (constraint )
@@ -4012,6 +4047,7 @@ def setUp(self):
4012
4047
super (TestCase , self ).setUp ()
4013
4048
Distribution .set_default_validate_args (True )
4014
4049
4050
+ @skipIfRocm
4015
4051
def test_valid (self ):
4016
4052
for Dist , params in EXAMPLES :
4017
4053
for i , param in enumerate (params ):
0 commit comments