Skip to content

Commit eea9cdc

Browse files
authored
Merge pull request #134 from iotamudelta/distri_tests
Enable distribution tests
2 parents b5cc4eb + 861923d commit eea9cdc

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

test/run_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
'c10d',
4747
'cpp_extensions',
4848
'distributed',
49-
'distributions',
5049
'multiprocessing',
5150
'nccl',
5251
'thd_distributed',

test/test_distributions.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
import torch
3333
from torch._six import inf
34-
from common import TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN
34+
from common import TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN, skipIfRocm
3535
from common_cuda import TEST_CUDA
3636
from torch.autograd import grad, gradcheck
3737
from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
@@ -724,6 +724,7 @@ def test_repr(self):
724724
dist = Dist(**param)
725725
self.assertTrue(repr(dist).startswith(dist.__class__.__name__))
726726

727+
@skipIfRocm
727728
def test_sample_detached(self):
728729
for Dist, params in EXAMPLES:
729730
for i, param in enumerate(params):
@@ -736,6 +737,7 @@ def test_sample_detached(self):
736737
msg='{} example {}/{}, .sample() is not detached'.format(
737738
Dist.__name__, i + 1, len(params)))
738739

740+
@skipIfRocm
739741
def test_rsample_requires_grad(self):
740742
for Dist, params in EXAMPLES:
741743
for i, param in enumerate(params):
@@ -749,6 +751,7 @@ def test_rsample_requires_grad(self):
749751
msg='{} example {}/{}, .rsample() does not require grad'.format(
750752
Dist.__name__, i + 1, len(params)))
751753

754+
@skipIfRocm
752755
def test_enumerate_support_type(self):
753756
for Dist, params in EXAMPLES:
754757
for i, param in enumerate(params):
@@ -760,6 +763,7 @@ def test_enumerate_support_type(self):
760763
except NotImplementedError:
761764
pass
762765

766+
@skipIfRocm
763767
def test_lazy_property_grad(self):
764768
x = torch.randn(1, requires_grad=True)
765769

@@ -1129,6 +1133,7 @@ def test_poisson_sample(self):
11291133

11301134
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
11311135
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1136+
@skipIfRocm
11321137
def test_poisson_gpu_sample(self):
11331138
set_rng_seed(1)
11341139
for rate in [0.12, 0.9, 4.0]:
@@ -1536,6 +1541,7 @@ def test_normal_sample(self):
15361541
scipy.stats.norm(loc=loc, scale=scale),
15371542
'Normal(mean={}, std={})'.format(loc, scale))
15381543

1544+
@skipIfRocm
15391545
def test_lowrank_multivariate_normal_shape(self):
15401546
mean = torch.randn(5, 3, requires_grad=True)
15411547
mean_no_batch = torch.randn(3, requires_grad=True)
@@ -1584,6 +1590,7 @@ def test_lowrank_multivariate_normal_shape(self):
15841590
(mean_multi_batch, cov_factor_batched, cov_diag_batched))
15851591

15861592
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1593+
@skipIfRocm
15871594
def test_lowrank_multivariate_normal_log_prob(self):
15881595
mean = torch.randn(3, requires_grad=True)
15891596
cov_factor = torch.randn(3, 1, requires_grad=True)
@@ -1617,6 +1624,7 @@ def test_lowrank_multivariate_normal_log_prob(self):
16171624
self.assertAlmostEqual(0.0, (batched_prob - unbatched_prob).abs().max(), places=3)
16181625

16191626
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1627+
@skipIfRocm
16201628
def test_lowrank_multivariate_normal_sample(self):
16211629
set_rng_seed(0) # see Note [Randomized statistical tests]
16221630
mean = torch.randn(5, requires_grad=True)
@@ -1629,6 +1637,7 @@ def test_lowrank_multivariate_normal_sample(self):
16291637
'LowRankMultivariateNormal(loc={}, cov_factor={}, cov_diag={})'
16301638
.format(mean, cov_factor, cov_diag), multivariate=True)
16311639

1640+
@skipIfRocm
16321641
def test_lowrank_multivariate_normal_properties(self):
16331642
loc = torch.randn(5)
16341643
cov_factor = torch.randn(5, 2)
@@ -1643,6 +1652,7 @@ def test_lowrank_multivariate_normal_properties(self):
16431652
self.assertEqual(m1.precision_matrix, m2.precision_matrix)
16441653
self.assertEqual(m1.entropy(), m2.entropy())
16451654

1655+
@skipIfRocm
16461656
def test_lowrank_multivariate_normal_moments(self):
16471657
set_rng_seed(0) # see Note [Randomized statistical tests]
16481658
mean = torch.randn(5)
@@ -1655,6 +1665,7 @@ def test_lowrank_multivariate_normal_moments(self):
16551665
empirical_var = samples.var(0)
16561666
self.assertEqual(d.variance, empirical_var, prec=0.02)
16571667

1668+
@skipIfRocm
16581669
def test_multivariate_normal_shape(self):
16591670
mean = torch.randn(5, 3, requires_grad=True)
16601671
mean_no_batch = torch.randn(3, requires_grad=True)
@@ -1702,6 +1713,7 @@ def test_multivariate_normal_shape(self):
17021713
self._gradcheck_log_prob(MultivariateNormal, (mean_no_batch, None, None, scale_tril_batched))
17031714

17041715
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1716+
@skipIfRocm
17051717
def test_multivariate_normal_log_prob(self):
17061718
mean = torch.randn(3, requires_grad=True)
17071719
tmp = torch.randn(3, 10)
@@ -1739,6 +1751,7 @@ def test_multivariate_normal_log_prob(self):
17391751
self.assertAlmostEqual(0.0, (batched_prob - unbatched_prob).abs().max(), places=3)
17401752

17411753
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1754+
@skipIfRocm
17421755
def test_multivariate_normal_sample(self):
17431756
set_rng_seed(0) # see Note [Randomized statistical tests]
17441757
mean = torch.randn(3, requires_grad=True)
@@ -1760,6 +1773,7 @@ def test_multivariate_normal_sample(self):
17601773
'MultivariateNormal(loc={}, scale_tril={})'.format(mean, scale_tril),
17611774
multivariate=True)
17621775

1776+
@skipIfRocm
17631777
def test_multivariate_normal_properties(self):
17641778
loc = torch.randn(5)
17651779
scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
@@ -1888,6 +1902,7 @@ def ref_log_prob(idx, x, log_prob):
18881902

18891903
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
18901904
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1905+
@skipIfRocm
18911906
def test_gamma_gpu_shape(self):
18921907
alpha = torch.tensor(torch.exp(torch.randn(2, 3).cuda()), requires_grad=True)
18931908
beta = torch.tensor(torch.exp(torch.randn(2, 3).cuda()), requires_grad=True)
@@ -1918,6 +1933,7 @@ def test_gamma_sample(self):
19181933

19191934
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
19201935
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1936+
@skipIfRocm
19211937
def test_gamma_gpu_sample(self):
19221938
set_rng_seed(0)
19231939
for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
@@ -2150,6 +2166,7 @@ def test_beta_sample(self):
21502166
x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
21512167
self.assertTrue(np.isfinite(x) and x > 0, 'Invalid Beta.sample(): {}'.format(x))
21522168

2169+
@skipIfRocm
21532170
def test_independent_shape(self):
21542171
for Dist, params in EXAMPLES:
21552172
for i, param in enumerate(params):
@@ -2178,6 +2195,7 @@ def test_independent_shape(self):
21782195
except NotImplementedError:
21792196
pass
21802197

2198+
@skipIfRocm
21812199
def test_cdf_icdf_inverse(self):
21822200
# Tests the invertibility property on the distributions
21832201
for Dist, params in EXAMPLES:
@@ -2197,6 +2215,7 @@ def test_cdf_icdf_inverse(self):
21972215
'icdf(cdf(x)) = {}'.format(actual),
21982216
]))
21992217

2218+
@skipIfRocm
22002219
def test_cdf_log_prob(self):
22012220
# Tests if the differentiation of the CDF gives the PDF at a given value
22022221
for Dist, params in EXAMPLES:
@@ -2588,6 +2607,7 @@ def tearDown(self):
25882607
super(TestCase, self).tearDown()
25892608
Distribution.set_default_validate_args(False)
25902609

2610+
@skipIfRocm
25912611
def test_entropy_shape(self):
25922612
for Dist, params in EXAMPLES:
25932613
for i, param in enumerate(params):
@@ -3141,6 +3161,7 @@ def test_kl_monte_carlo(self):
31413161

31423162
# Multivariate normal has a separate Monte Carlo based test due to the requirement of random generation of
31433163
# positive (semi) definite matrices. n is set to 5, but can be increased during testing.
3164+
@skipIfRocm
31443165
def test_kl_multivariate_normal(self):
31453166
set_rng_seed(0) # see Note [Randomized statistical tests]
31463167
n = 5 # Number of tests for multivariate_normal
@@ -3166,6 +3187,7 @@ def test_kl_multivariate_normal(self):
31663187
'Actual (analytic): {}'.format(actual),
31673188
]))
31683189

3190+
@skipIfRocm
31693191
def test_kl_multivariate_normal_batched(self):
31703192
b = 7 # Number of batches
31713193
loc = [torch.randn(b, 3) for _ in range(0, 2)]
@@ -3177,6 +3199,7 @@ def test_kl_multivariate_normal_batched(self):
31773199
MultivariateNormal(loc[1], scale_tril=scale_tril[1]))
31783200
self.assertEqual(expected_kl, actual_kl)
31793201

3202+
@skipIfRocm
31803203
def test_kl_multivariate_normal_batched_broadcasted(self):
31813204
b = 7 # Number of batches
31823205
loc = [torch.randn(b, 3) for _ in range(0, 2)]
@@ -3189,6 +3212,7 @@ def test_kl_multivariate_normal_batched_broadcasted(self):
31893212
MultivariateNormal(loc[1], scale_tril=scale_tril[1]))
31903213
self.assertEqual(expected_kl, actual_kl)
31913214

3215+
@skipIfRocm
31923216
def test_kl_lowrank_multivariate_normal(self):
31933217
set_rng_seed(0) # see Note [Randomized statistical tests]
31943218
n = 5 # Number of tests for lowrank_multivariate_normal
@@ -3229,6 +3253,7 @@ def test_kl_lowrank_multivariate_normal(self):
32293253
'Actual (analytic): {}'.format(actual_full_lowrank),
32303254
]))
32313255

3256+
@skipIfRocm
32323257
def test_kl_lowrank_multivariate_normal_batched(self):
32333258
b = 7 # Number of batches
32343259
loc = [torch.randn(b, 3) for _ in range(0, 2)]
@@ -3264,6 +3289,7 @@ def test_kl_edgecases(self):
32643289
self.assertEqual(kl_divergence(Bernoulli(1), Bernoulli(1)), 0)
32653290
self.assertEqual(kl_divergence(Categorical(torch.tensor([0., 1.])), Categorical(torch.tensor([0., 1.]))), 0)
32663291

3292+
@skipIfRocm
32673293
def test_kl_shape(self):
32683294
for Dist, params in EXAMPLES:
32693295
for i, param in enumerate(params):
@@ -3279,6 +3305,7 @@ def test_kl_shape(self):
32793305
'Actual {}'.format(kl.shape),
32803306
]))
32813307

3308+
@skipIfRocm
32823309
def test_entropy_monte_carlo(self):
32833310
set_rng_seed(0) # see Note [Randomized statistical tests]
32843311
for Dist, params in EXAMPLES:
@@ -3322,6 +3349,7 @@ def test_entropy_exponential_family(self):
33223349

33233350

33243351
class TestConstraints(TestCase):
3352+
@skipIfRocm
33253353
def test_params_contains(self):
33263354
for Dist, params in EXAMPLES:
33273355
for i, param in enumerate(params):
@@ -3345,6 +3373,7 @@ def test_params_contains(self):
33453373
Dist.__name__, i + 1, len(params), name, value)
33463374
self.assertTrue(constraint.check(value).all(), msg=message)
33473375

3376+
@skipIfRocm
33483377
def test_support_contains(self):
33493378
for Dist, params in EXAMPLES:
33503379
self.assertIsInstance(Dist.support, Constraint)
@@ -3624,6 +3653,7 @@ def setUp(self):
36243653
)
36253654
]
36263655

3656+
@skipIfRocm
36273657
def test_mean(self):
36283658
for pytorch_dist, scipy_dist in self.distribution_pairs:
36293659
if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
@@ -3634,6 +3664,7 @@ def test_mean(self):
36343664
else:
36353665
self.assertEqual(pytorch_dist.mean, scipy_dist.mean(), allow_inf=True, message=pytorch_dist)
36363666

3667+
@skipIfRocm
36373668
def test_variance_stddev(self):
36383669
for pytorch_dist, scipy_dist in self.distribution_pairs:
36393670
if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
@@ -3649,6 +3680,7 @@ def test_variance_stddev(self):
36493680
self.assertEqual(pytorch_dist.variance, scipy_dist.var(), allow_inf=True, message=pytorch_dist)
36503681
self.assertEqual(pytorch_dist.stddev, scipy_dist.var() ** 0.5, message=pytorch_dist)
36513682

3683+
@skipIfRocm
36523684
def test_cdf(self):
36533685
for pytorch_dist, scipy_dist in self.distribution_pairs:
36543686
samples = pytorch_dist.sample((5,))
@@ -3658,6 +3690,7 @@ def test_cdf(self):
36583690
continue
36593691
self.assertEqual(cdf, scipy_dist.cdf(samples), message=pytorch_dist)
36603692

3693+
@skipIfRocm
36613694
def test_icdf(self):
36623695
for pytorch_dist, scipy_dist in self.distribution_pairs:
36633696
samples = torch.rand((5,) + pytorch_dist.batch_shape)
@@ -3963,6 +3996,7 @@ def test_biject_to(self):
39633996
self.assertEqual(j.shape, x.shape[:x.dim() - t.event_dim])
39643997

39653998
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
3999+
@skipIfRocm
39664000
def test_biject_to_cuda(self):
39674001
for constraint in self.get_constraints(is_cuda=True):
39684002
try:
@@ -3995,6 +4029,7 @@ def test_transform_to(self):
39954029
self.assertEqual(y, y2, message="Error in transform_to({}) pseudoinverse".format(constraint))
39964030

39974031
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
4032+
@skipIfRocm
39984033
def test_transform_to_cuda(self):
39994034
for constraint in self.get_constraints(is_cuda=True):
40004035
t = transform_to(constraint)
@@ -4012,6 +4047,7 @@ def setUp(self):
40124047
super(TestCase, self).setUp()
40134048
Distribution.set_default_validate_args(True)
40144049

4050+
@skipIfRocm
40154051
def test_valid(self):
40164052
for Dist, params in EXAMPLES:
40174053
for i, param in enumerate(params):

0 commit comments

Comments
 (0)