Skip to content

Commit cc2361b

Browse files
authored
Merge pull request #209 from iotamudelta/distri_tests
Distribution tests: more tests
2 parents 0974a83 + 34b0864 commit cc2361b

File tree

1 file changed

+0
-33
lines changed

1 file changed

+0
-33
lines changed

test/test_distributions.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,6 @@ def test_repr(self):
724724
dist = Dist(**param)
725725
self.assertTrue(repr(dist).startswith(dist.__class__.__name__))
726726

727-
@skipIfRocm
728727
def test_sample_detached(self):
729728
for Dist, params in EXAMPLES:
730729
for i, param in enumerate(params):
@@ -737,7 +736,6 @@ def test_sample_detached(self):
737736
msg='{} example {}/{}, .sample() is not detached'.format(
738737
Dist.__name__, i + 1, len(params)))
739738

740-
@skipIfRocm
741739
def test_rsample_requires_grad(self):
742740
for Dist, params in EXAMPLES:
743741
for i, param in enumerate(params):
@@ -751,7 +749,6 @@ def test_rsample_requires_grad(self):
751749
msg='{} example {}/{}, .rsample() does not require grad'.format(
752750
Dist.__name__, i + 1, len(params)))
753751

754-
@skipIfRocm
755752
def test_enumerate_support_type(self):
756753
for Dist, params in EXAMPLES:
757754
for i, param in enumerate(params):
@@ -763,7 +760,6 @@ def test_enumerate_support_type(self):
763760
except NotImplementedError:
764761
pass
765762

766-
@skipIfRocm
767763
def test_lazy_property_grad(self):
768764
x = torch.randn(1, requires_grad=True)
769765

@@ -1541,7 +1537,6 @@ def test_normal_sample(self):
15411537
scipy.stats.norm(loc=loc, scale=scale),
15421538
'Normal(mean={}, std={})'.format(loc, scale))
15431539

1544-
@skipIfRocm
15451540
def test_lowrank_multivariate_normal_shape(self):
15461541
mean = torch.randn(5, 3, requires_grad=True)
15471542
mean_no_batch = torch.randn(3, requires_grad=True)
@@ -1590,7 +1585,6 @@ def test_lowrank_multivariate_normal_shape(self):
15901585
(mean_multi_batch, cov_factor_batched, cov_diag_batched))
15911586

15921587
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1593-
@skipIfRocm
15941588
def test_lowrank_multivariate_normal_log_prob(self):
15951589
mean = torch.randn(3, requires_grad=True)
15961590
cov_factor = torch.randn(3, 1, requires_grad=True)
@@ -1624,7 +1618,6 @@ def test_lowrank_multivariate_normal_log_prob(self):
16241618
self.assertAlmostEqual(0.0, (batched_prob - unbatched_prob).abs().max(), places=3)
16251619

16261620
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1627-
@skipIfRocm
16281621
def test_lowrank_multivariate_normal_sample(self):
16291622
set_rng_seed(0) # see Note [Randomized statistical tests]
16301623
mean = torch.randn(5, requires_grad=True)
@@ -1637,7 +1630,6 @@ def test_lowrank_multivariate_normal_sample(self):
16371630
'LowRankMultivariateNormal(loc={}, cov_factor={}, cov_diag={})'
16381631
.format(mean, cov_factor, cov_diag), multivariate=True)
16391632

1640-
@skipIfRocm
16411633
def test_lowrank_multivariate_normal_properties(self):
16421634
loc = torch.randn(5)
16431635
cov_factor = torch.randn(5, 2)
@@ -1652,7 +1644,6 @@ def test_lowrank_multivariate_normal_properties(self):
16521644
self.assertEqual(m1.precision_matrix, m2.precision_matrix)
16531645
self.assertEqual(m1.entropy(), m2.entropy())
16541646

1655-
@skipIfRocm
16561647
def test_lowrank_multivariate_normal_moments(self):
16571648
set_rng_seed(0) # see Note [Randomized statistical tests]
16581649
mean = torch.randn(5)
@@ -1665,7 +1656,6 @@ def test_lowrank_multivariate_normal_moments(self):
16651656
empirical_var = samples.var(0)
16661657
self.assertEqual(d.variance, empirical_var, prec=0.02)
16671658

1668-
@skipIfRocm
16691659
def test_multivariate_normal_shape(self):
16701660
mean = torch.randn(5, 3, requires_grad=True)
16711661
mean_no_batch = torch.randn(3, requires_grad=True)
@@ -1713,7 +1703,6 @@ def test_multivariate_normal_shape(self):
17131703
self._gradcheck_log_prob(MultivariateNormal, (mean_no_batch, None, None, scale_tril_batched))
17141704

17151705
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1716-
@skipIfRocm
17171706
def test_multivariate_normal_log_prob(self):
17181707
mean = torch.randn(3, requires_grad=True)
17191708
tmp = torch.randn(3, 10)
@@ -1751,7 +1740,6 @@ def test_multivariate_normal_log_prob(self):
17511740
self.assertAlmostEqual(0.0, (batched_prob - unbatched_prob).abs().max(), places=3)
17521741

17531742
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1754-
@skipIfRocm
17551743
def test_multivariate_normal_sample(self):
17561744
set_rng_seed(0) # see Note [Randomized statistical tests]
17571745
mean = torch.randn(3, requires_grad=True)
@@ -1773,7 +1761,6 @@ def test_multivariate_normal_sample(self):
17731761
'MultivariateNormal(loc={}, scale_tril={})'.format(mean, scale_tril),
17741762
multivariate=True)
17751763

1776-
@skipIfRocm
17771764
def test_multivariate_normal_properties(self):
17781765
loc = torch.randn(5)
17791766
scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
@@ -1902,7 +1889,6 @@ def ref_log_prob(idx, x, log_prob):
19021889

19031890
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
19041891
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1905-
@skipIfRocm
19061892
def test_gamma_gpu_shape(self):
19071893
alpha = torch.tensor(torch.exp(torch.randn(2, 3).cuda()), requires_grad=True)
19081894
beta = torch.tensor(torch.exp(torch.randn(2, 3).cuda()), requires_grad=True)
@@ -2166,7 +2152,6 @@ def test_beta_sample(self):
21662152
x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
21672153
self.assertTrue(np.isfinite(x) and x > 0, 'Invalid Beta.sample(): {}'.format(x))
21682154

2169-
@skipIfRocm
21702155
def test_independent_shape(self):
21712156
for Dist, params in EXAMPLES:
21722157
for i, param in enumerate(params):
@@ -2195,7 +2180,6 @@ def test_independent_shape(self):
21952180
except NotImplementedError:
21962181
pass
21972182

2198-
@skipIfRocm
21992183
def test_cdf_icdf_inverse(self):
22002184
# Tests the invertibility property on the distributions
22012185
for Dist, params in EXAMPLES:
@@ -2215,7 +2199,6 @@ def test_cdf_icdf_inverse(self):
22152199
'icdf(cdf(x)) = {}'.format(actual),
22162200
]))
22172201

2218-
@skipIfRocm
22192202
def test_cdf_log_prob(self):
22202203
# Tests if the differentiation of the CDF gives the PDF at a given value
22212204
for Dist, params in EXAMPLES:
@@ -2607,7 +2590,6 @@ def tearDown(self):
26072590
super(TestCase, self).tearDown()
26082591
Distribution.set_default_validate_args(False)
26092592

2610-
@skipIfRocm
26112593
def test_entropy_shape(self):
26122594
for Dist, params in EXAMPLES:
26132595
for i, param in enumerate(params):
@@ -3161,7 +3143,6 @@ def test_kl_monte_carlo(self):
31613143

31623144
# Multivariate normal has a separate Monte Carlo based test due to the requirement of random generation of
31633145
# positive (semi) definite matrices. n is set to 5, but can be increased during testing.
3164-
@skipIfRocm
31653146
def test_kl_multivariate_normal(self):
31663147
set_rng_seed(0) # see Note [Randomized statistical tests]
31673148
n = 5 # Number of tests for multivariate_normal
@@ -3187,7 +3168,6 @@ def test_kl_multivariate_normal(self):
31873168
'Actual (analytic): {}'.format(actual),
31883169
]))
31893170

3190-
@skipIfRocm
31913171
def test_kl_multivariate_normal_batched(self):
31923172
b = 7 # Number of batches
31933173
loc = [torch.randn(b, 3) for _ in range(0, 2)]
@@ -3199,7 +3179,6 @@ def test_kl_multivariate_normal_batched(self):
31993179
MultivariateNormal(loc[1], scale_tril=scale_tril[1]))
32003180
self.assertEqual(expected_kl, actual_kl)
32013181

3202-
@skipIfRocm
32033182
def test_kl_multivariate_normal_batched_broadcasted(self):
32043183
b = 7 # Number of batches
32053184
loc = [torch.randn(b, 3) for _ in range(0, 2)]
@@ -3212,7 +3191,6 @@ def test_kl_multivariate_normal_batched_broadcasted(self):
32123191
MultivariateNormal(loc[1], scale_tril=scale_tril[1]))
32133192
self.assertEqual(expected_kl, actual_kl)
32143193

3215-
@skipIfRocm
32163194
def test_kl_lowrank_multivariate_normal(self):
32173195
set_rng_seed(0) # see Note [Randomized statistical tests]
32183196
n = 5 # Number of tests for lowrank_multivariate_normal
@@ -3253,7 +3231,6 @@ def test_kl_lowrank_multivariate_normal(self):
32533231
'Actual (analytic): {}'.format(actual_full_lowrank),
32543232
]))
32553233

3256-
@skipIfRocm
32573234
def test_kl_lowrank_multivariate_normal_batched(self):
32583235
b = 7 # Number of batches
32593236
loc = [torch.randn(b, 3) for _ in range(0, 2)]
@@ -3289,7 +3266,6 @@ def test_kl_edgecases(self):
32893266
self.assertEqual(kl_divergence(Bernoulli(1), Bernoulli(1)), 0)
32903267
self.assertEqual(kl_divergence(Categorical(torch.tensor([0., 1.])), Categorical(torch.tensor([0., 1.]))), 0)
32913268

3292-
@skipIfRocm
32933269
def test_kl_shape(self):
32943270
for Dist, params in EXAMPLES:
32953271
for i, param in enumerate(params):
@@ -3305,7 +3281,6 @@ def test_kl_shape(self):
33053281
'Actual {}'.format(kl.shape),
33063282
]))
33073283

3308-
@skipIfRocm
33093284
def test_entropy_monte_carlo(self):
33103285
set_rng_seed(0) # see Note [Randomized statistical tests]
33113286
for Dist, params in EXAMPLES:
@@ -3349,7 +3324,6 @@ def test_entropy_exponential_family(self):
33493324

33503325

33513326
class TestConstraints(TestCase):
3352-
@skipIfRocm
33533327
def test_params_contains(self):
33543328
for Dist, params in EXAMPLES:
33553329
for i, param in enumerate(params):
@@ -3373,7 +3347,6 @@ def test_params_contains(self):
33733347
Dist.__name__, i + 1, len(params), name, value)
33743348
self.assertTrue(constraint.check(value).all(), msg=message)
33753349

3376-
@skipIfRocm
33773350
def test_support_contains(self):
33783351
for Dist, params in EXAMPLES:
33793352
self.assertIsInstance(Dist.support, Constraint)
@@ -3653,7 +3626,6 @@ def setUp(self):
36533626
)
36543627
]
36553628

3656-
@skipIfRocm
36573629
def test_mean(self):
36583630
for pytorch_dist, scipy_dist in self.distribution_pairs:
36593631
if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
@@ -3664,7 +3636,6 @@ def test_mean(self):
36643636
else:
36653637
self.assertEqual(pytorch_dist.mean, scipy_dist.mean(), allow_inf=True, message=pytorch_dist)
36663638

3667-
@skipIfRocm
36683639
def test_variance_stddev(self):
36693640
for pytorch_dist, scipy_dist in self.distribution_pairs:
36703641
if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
@@ -3680,7 +3651,6 @@ def test_variance_stddev(self):
36803651
self.assertEqual(pytorch_dist.variance, scipy_dist.var(), allow_inf=True, message=pytorch_dist)
36813652
self.assertEqual(pytorch_dist.stddev, scipy_dist.var() ** 0.5, message=pytorch_dist)
36823653

3683-
@skipIfRocm
36843654
def test_cdf(self):
36853655
for pytorch_dist, scipy_dist in self.distribution_pairs:
36863656
samples = pytorch_dist.sample((5,))
@@ -3690,7 +3660,6 @@ def test_cdf(self):
36903660
continue
36913661
self.assertEqual(cdf, scipy_dist.cdf(samples), message=pytorch_dist)
36923662

3693-
@skipIfRocm
36943663
def test_icdf(self):
36953664
for pytorch_dist, scipy_dist in self.distribution_pairs:
36963665
samples = torch.rand((5,) + pytorch_dist.batch_shape)
@@ -3996,7 +3965,6 @@ def test_biject_to(self):
39963965
self.assertEqual(j.shape, x.shape[:x.dim() - t.event_dim])
39973966

39983967
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
3999-
@skipIfRocm
40003968
def test_biject_to_cuda(self):
40013969
for constraint in self.get_constraints(is_cuda=True):
40023970
try:
@@ -4047,7 +4015,6 @@ def setUp(self):
40474015
super(TestCase, self).setUp()
40484016
Distribution.set_default_validate_args(True)
40494017

4050-
@skipIfRocm
40514018
def test_valid(self):
40524019
for Dist, params in EXAMPLES:
40534020
for i, param in enumerate(params):

0 commit comments

Comments
 (0)