Enable distribution tests #134

Merged 22 commits on Sep 13, 2018
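
This PR re-enables test_distributions.py in the top-level test runner and marks the subtests that are still unstable on the ROCm stack with the @skipIfRocm decorator imported from common. As a rough guide to what that helper does, here is a minimal sketch of such a decorator, assuming an environment-driven TEST_WITH_ROCM flag (the flag name, variable names, and skip message are illustrative, not necessarily the exact common.py implementation):

```python
import os
import unittest
from functools import wraps

# Assumed flag: the PyTorch test harness derives a similar TEST_WITH_ROCM
# constant from the environment.
TEST_WITH_ROCM = os.getenv('PYTORCH_TEST_WITH_ROCM', '0') == '1'

def skipIfRocm(fn):
    """Skip the decorated test when running on the ROCm stack."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if TEST_WITH_ROCM:
            raise unittest.SkipTest("test doesn't currently work on ROCm")
        return fn(*args, **kwargs)
    return wrapper
```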
1 change: 0 additions & 1 deletion test/run_test.py
@@ -46,7 +46,6 @@
'c10d',
'cpp_extensions',
'distributed',
-'distributions',
'multiprocessing',
'nccl',
'thd_distributed',
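
The run_test.py change deletes a single entry: 'distributions' comes off the list of excluded tests, so test_distributions.py is picked up by the top-level runner again. A sketch of how such an exclusion list is typically consumed (names and contents are illustrative, not the actual run_test.py internals):

```python
# Hypothetical master list and exclusion list; the real run_test.py keeps
# similar structures.
ALL_TESTS = ['c10d', 'distributed', 'distributions', 'multiprocessing', 'nccl']
EXCLUDED = ['c10d', 'distributed', 'nccl']

def selected_tests(all_tests, excluded):
    # Keep every test module that is not named on the exclusion list.
    return [t for t in all_tests if t not in excluded]

print(selected_tests(ALL_TESTS, EXCLUDED))  # ['distributions', 'multiprocessing']
```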
38 changes: 37 additions & 1 deletion test/test_distributions.py
@@ -31,7 +31,7 @@

import torch
from torch._six import inf
-from common import TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN
+from common import TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN, skipIfRocm
from common_cuda import TEST_CUDA
from torch.autograd import grad, gradcheck
from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
@@ -713,6 +713,7 @@ def _check_enumerate_support(self, dist, examples):
actual = dist(param).enumerate_support()
self.assertEqual(actual, expected)

+@skipIfRocm
def test_sample_detached(self):
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
@@ -725,6 +726,7 @@ def test_sample_detached(self):
msg='{} example {}/{}, .sample() is not detached'.format(
Dist.__name__, i + 1, len(params)))

+@skipIfRocm
def test_rsample_requires_grad(self):
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
@@ -738,6 +740,7 @@ def test_rsample_requires_grad(self):
msg='{} example {}/{}, .rsample() does not require grad'.format(
Dist.__name__, i + 1, len(params)))

+@skipIfRocm
def test_enumerate_support_type(self):
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
@@ -749,6 +752,7 @@ def test_enumerate_support_type(self):
except NotImplementedError:
pass

+@skipIfRocm
def test_lazy_property_grad(self):
x = torch.randn(1, requires_grad=True)

@@ -1117,6 +1121,7 @@ def test_poisson_sample(self):

@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+@skipIfRocm
def test_poisson_gpu_sample(self):
set_rng_seed(1)
for rate in [0.12, 0.9, 4.0]:
@@ -1524,6 +1529,7 @@ def test_normal_sample(self):
scipy.stats.norm(loc=loc, scale=scale),
'Normal(mean={}, std={})'.format(loc, scale))

+@skipIfRocm
def test_lowrank_multivariate_normal_shape(self):
mean = torch.randn(5, 3, requires_grad=True)
mean_no_batch = torch.randn(3, requires_grad=True)
@@ -1572,6 +1578,7 @@ def test_lowrank_multivariate_normal_shape(self):
(mean_multi_batch, cov_factor_batched, cov_diag_batched))

@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+@skipIfRocm
def test_lowrank_multivariate_normal_log_prob(self):
mean = torch.randn(3, requires_grad=True)
cov_factor = torch.randn(3, 1, requires_grad=True)
@@ -1605,6 +1612,7 @@ def test_lowrank_multivariate_normal_log_prob(self):
self.assertAlmostEqual(0.0, (batched_prob - unbatched_prob).abs().max(), places=3)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+@skipIfRocm
def test_lowrank_multivariate_normal_sample(self):
set_rng_seed(0) # see Note [Randomized statistical tests]
mean = torch.randn(5, requires_grad=True)
@@ -1617,6 +1625,7 @@ def test_lowrank_multivariate_normal_sample(self):
'LowRankMultivariateNormal(loc={}, cov_factor={}, cov_diag={})'
.format(mean, cov_factor, cov_diag), multivariate=True)

+@skipIfRocm
def test_lowrank_multivariate_normal_properties(self):
loc = torch.randn(5)
cov_factor = torch.randn(5, 2)
@@ -1631,6 +1640,7 @@ def test_lowrank_multivariate_normal_properties(self):
self.assertEqual(m1.precision_matrix, m2.precision_matrix)
self.assertEqual(m1.entropy(), m2.entropy())

+@skipIfRocm
def test_lowrank_multivariate_normal_moments(self):
set_rng_seed(0) # see Note [Randomized statistical tests]
mean = torch.randn(5)
@@ -1643,6 +1653,7 @@ def test_lowrank_multivariate_normal_moments(self):
empirical_var = samples.var(0)
self.assertEqual(d.variance, empirical_var, prec=0.02)

+@skipIfRocm
def test_multivariate_normal_shape(self):
mean = torch.randn(5, 3, requires_grad=True)
mean_no_batch = torch.randn(3, requires_grad=True)
@@ -1690,6 +1701,7 @@ def test_multivariate_normal_shape(self):
self._gradcheck_log_prob(MultivariateNormal, (mean_no_batch, None, None, scale_tril_batched))

@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+@skipIfRocm
def test_multivariate_normal_log_prob(self):
mean = torch.randn(3, requires_grad=True)
tmp = torch.randn(3, 10)
@@ -1727,6 +1739,7 @@ def test_multivariate_normal_log_prob(self):
self.assertAlmostEqual(0.0, (batched_prob - unbatched_prob).abs().max(), places=3)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+@skipIfRocm
def test_multivariate_normal_sample(self):
set_rng_seed(0) # see Note [Randomized statistical tests]
mean = torch.randn(3, requires_grad=True)
@@ -1748,6 +1761,7 @@ def test_multivariate_normal_sample(self):
'MultivariateNormal(loc={}, scale_tril={})'.format(mean, scale_tril),
multivariate=True)

+@skipIfRocm
def test_multivariate_normal_properties(self):
loc = torch.randn(5)
scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
@@ -1876,6 +1890,7 @@ def ref_log_prob(idx, x, log_prob):

@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+@skipIfRocm
def test_gamma_gpu_shape(self):
alpha = torch.tensor(torch.exp(torch.randn(2, 3).cuda()), requires_grad=True)
beta = torch.tensor(torch.exp(torch.randn(2, 3).cuda()), requires_grad=True)
@@ -1906,6 +1921,7 @@ def test_gamma_sample(self):

@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+@skipIfRocm
def test_gamma_gpu_sample(self):
set_rng_seed(0)
for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
@@ -2138,6 +2154,7 @@ def test_beta_sample(self):
x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
self.assertTrue(np.isfinite(x) and x > 0, 'Invalid Beta.sample(): {}'.format(x))

+@skipIfRocm
def test_independent_shape(self):
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
@@ -2166,6 +2183,7 @@ def test_independent_shape(self):
except NotImplementedError:
pass

+@skipIfRocm
def test_cdf_icdf_inverse(self):
# Tests the invertibility property on the distributions
for Dist, params in EXAMPLES:
@@ -2185,6 +2203,7 @@ def test_cdf_icdf_inverse(self):
'icdf(cdf(x)) = {}'.format(actual),
]))

+@skipIfRocm
def test_cdf_log_prob(self):
# Tests if the differentiation of the CDF gives the PDF at a given value
for Dist, params in EXAMPLES:
@@ -2576,6 +2595,7 @@ def tearDown(self):
super(TestCase, self).tearDown()
Distribution.set_default_validate_args(False)

+@skipIfRocm
def test_entropy_shape(self):
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
@@ -3129,6 +3149,7 @@ def test_kl_monte_carlo(self):

# Multivariate normal has a separate Monte Carlo based test due to the requirement of random generation of
# positive (semi) definite matrices. n is set to 5, but can be increased during testing.
+@skipIfRocm
def test_kl_multivariate_normal(self):
set_rng_seed(0) # see Note [Randomized statistical tests]
n = 5 # Number of tests for multivariate_normal
@@ -3154,6 +3175,7 @@ def test_kl_multivariate_normal(self):
'Actual (analytic): {}'.format(actual),
]))

+@skipIfRocm
def test_kl_multivariate_normal_batched(self):
b = 7 # Number of batches
loc = [torch.randn(b, 3) for _ in range(0, 2)]
@@ -3165,6 +3187,7 @@ def test_kl_multivariate_normal_batched(self):
MultivariateNormal(loc[1], scale_tril=scale_tril[1]))
self.assertEqual(expected_kl, actual_kl)

+@skipIfRocm
def test_kl_multivariate_normal_batched_broadcasted(self):
b = 7 # Number of batches
loc = [torch.randn(b, 3) for _ in range(0, 2)]
@@ -3177,6 +3200,7 @@ def test_kl_multivariate_normal_batched_broadcasted(self):
MultivariateNormal(loc[1], scale_tril=scale_tril[1]))
self.assertEqual(expected_kl, actual_kl)

+@skipIfRocm
def test_kl_lowrank_multivariate_normal(self):
set_rng_seed(0) # see Note [Randomized statistical tests]
n = 5 # Number of tests for lowrank_multivariate_normal
@@ -3217,6 +3241,7 @@ def test_kl_lowrank_multivariate_normal(self):
'Actual (analytic): {}'.format(actual_full_lowrank),
]))

+@skipIfRocm
def test_kl_lowrank_multivariate_normal_batched(self):
b = 7 # Number of batches
loc = [torch.randn(b, 3) for _ in range(0, 2)]
@@ -3252,6 +3277,7 @@ def test_kl_edgecases(self):
self.assertEqual(kl_divergence(Bernoulli(1), Bernoulli(1)), 0)
self.assertEqual(kl_divergence(Categorical(torch.tensor([0., 1.])), Categorical(torch.tensor([0., 1.]))), 0)

+@skipIfRocm
def test_kl_shape(self):
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
@@ -3267,6 +3293,7 @@ def test_kl_shape(self):
'Actual {}'.format(kl.shape),
]))

+@skipIfRocm
def test_entropy_monte_carlo(self):
set_rng_seed(0) # see Note [Randomized statistical tests]
for Dist, params in EXAMPLES:
@@ -3310,6 +3337,7 @@ def test_entropy_exponential_family(self):


class TestConstraints(TestCase):
+@skipIfRocm
def test_params_contains(self):
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
@@ -3333,6 +3361,7 @@ def test_params_contains(self):
Dist.__name__, i + 1, len(params), name, value)
self.assertTrue(constraint.check(value).all(), msg=message)

+@skipIfRocm
def test_support_contains(self):
for Dist, params in EXAMPLES:
self.assertIsInstance(Dist.support, Constraint)
@@ -3612,6 +3641,7 @@ def setUp(self):
)
]

+@skipIfRocm
def test_mean(self):
for pytorch_dist, scipy_dist in self.distribution_pairs:
if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
@@ -3622,6 +3652,7 @@ def test_mean(self):
else:
self.assertEqual(pytorch_dist.mean, scipy_dist.mean(), allow_inf=True, message=pytorch_dist)

+@skipIfRocm
def test_variance_stddev(self):
for pytorch_dist, scipy_dist in self.distribution_pairs:
if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
@@ -3637,6 +3668,7 @@ def test_variance_stddev(self):
self.assertEqual(pytorch_dist.variance, scipy_dist.var(), allow_inf=True, message=pytorch_dist)
self.assertEqual(pytorch_dist.stddev, scipy_dist.var() ** 0.5, message=pytorch_dist)

+@skipIfRocm
def test_cdf(self):
for pytorch_dist, scipy_dist in self.distribution_pairs:
samples = pytorch_dist.sample((5,))
@@ -3646,6 +3678,7 @@ def test_cdf(self):
continue
self.assertEqual(cdf, scipy_dist.cdf(samples), message=pytorch_dist)

+@skipIfRocm
def test_icdf(self):
for pytorch_dist, scipy_dist in self.distribution_pairs:
samples = torch.rand((5,) + pytorch_dist.batch_shape)
@@ -3951,6 +3984,7 @@ def test_biject_to(self):
self.assertEqual(j.shape, x.shape[:x.dim() - t.event_dim])

@unittest.skipIf(not TEST_CUDA, "CUDA not found")
+@skipIfRocm
def test_biject_to_cuda(self):
for constraint in self.get_constraints(is_cuda=True):
try:
@@ -3983,6 +4017,7 @@ def test_transform_to(self):
self.assertEqual(y, y2, message="Error in transform_to({}) pseudoinverse".format(constraint))

@unittest.skipIf(not TEST_CUDA, "CUDA not found")
+@skipIfRocm
def test_transform_to_cuda(self):
for constraint in self.get_constraints(is_cuda=True):
t = transform_to(constraint)
@@ -4000,6 +4035,7 @@ def setUp(self):
super(TestCase, self).setUp()
Distribution.set_default_validate_args(True)

+@skipIfRocm
def test_valid(self):
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
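
The comment above test_kl_multivariate_normal notes that the test depends on randomly generated positive (semi-)definite matrices. A standard construction, shown here as a sketch rather than the test's actual code: A @ A.T is positive semi-definite for any square A, and adding a small multiple of the identity makes it strictly positive definite.

```python
import torch
from torch.distributions import MultivariateNormal

def random_spd_matrix(n, jitter=1e-3):
    # a.mm(a.t()) is positive semi-definite for any square a; the jitter
    # term pushes the eigenvalues away from zero, making the result
    # strictly positive definite and safe for Cholesky factorization.
    a = torch.randn(n, n)
    return a.mm(a.t()) + jitter * torch.eye(n)

d = MultivariateNormal(torch.zeros(3), covariance_matrix=random_spd_matrix(3))
samples = d.sample((1000,))  # draws for Monte Carlo KL estimates
```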