From 54fc211d725a321a033233aba90ac94d9163eab4 Mon Sep 17 00:00:00 2001 From: Michiel Cottaar <MichielCottaar@gmail.com> Date: Wed, 30 Sep 2015 17:21:21 +0100 Subject: [PATCH 1/4] BUG allow Dirichlet to be multi-dimensional The resulting variables summed over the first axis will be one. --- pymc3/distributions/multivariate.py | 2 +- pymc3/distributions/transforms.py | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py index b6c48aeeb2..01b4e1266f 100644 --- a/pymc3/distributions/multivariate.py +++ b/pymc3/distributions/multivariate.py @@ -106,7 +106,7 @@ def logp(self, value): # only defined for sum(value) == 1 return bound( - sum(logpow(value, a - 1) - gammaln(a), axis=0) + gammaln(sum(a)), + sum(logpow(value, a - 1) - gammaln(a), axis=0) + gammaln(sum(a, axis=0)), k > 1, all(a > 0), all(value >= 0), diff --git a/pymc3/distributions/transforms.py b/pymc3/distributions/transforms.py index 930898671d..23255585b5 100644 --- a/pymc3/distributions/transforms.py +++ b/pymc3/distributions/transforms.py @@ -150,32 +150,32 @@ def __init__(self): def forward(self, x): #reverse cumsum x0 = x[:-1] - s = t.extra_ops.cumsum(x0[::-1])[::-1] + x[-1] + s = t.extra_ops.cumsum(x0[::-1], 0)[::-1] + x[-1] z = x0/s Km1 = x.shape[0] - 1 - k = arange(Km1) + k = arange(Km1)[(slice(None), ) + (None, ) * (x.ndim - 1)] eq_share = - t.log(Km1 - k) # logit(1./(Km1 + 1 - k)) y = logit(z) - eq_share return y def backward(self, y): Km1 = y.shape[0] - k = arange(Km1) + k = arange(Km1)[(slice(None), ) + (None, ) * (y.ndim - 1)] eq_share = - t.log(Km1 - k) # logit(1./(Km1 + 1 - k)) z = inverse_logit(y + eq_share) - yl = concatenate([z, [1]]) - yu = concatenate([[1], 1-z]) - S = t.extra_ops.cumprod(yu) + yl = concatenate([z, ones(y[:1].shape)]) + yu = concatenate([ones(y[:1].shape), 1-z]) + S = t.extra_ops.cumprod(yu, 0) x = S * yl return x def jacobian_det(self, y): Km1 = y.shape[0] - k = arange(Km1) + k = arange(Km1)[(slice(None), ) + (None, ) * (y.ndim - 1)] eq_share = -t.log(Km1 - k) #logit(1./(Km1 + 1 - k)) yl = y + eq_share - yu = concatenate([[1], 1-inverse_logit(yl)]) - S = t.extra_ops.cumprod(yu) - return sum(t.log(S[:-1]) - t.log(1+exp(yl)) - t.log(1+exp(-yl))) + yu = concatenate([ones(y[:1].shape), 1-inverse_logit(yl)]) + S = t.extra_ops.cumprod(yu, 0) + return t.log(S[:-1]) - t.log(1+exp(yl)) - t.log(1+exp(-yl)) stick_breaking = StickBreaking() From 78d5fecfeafa5b18ea81063e4bcdac344e3fa06c Mon Sep 17 00:00:00 2001 From: Michiel Cottaar <MichielCottaar@gmail.com> Date: Wed, 30 Sep 2015 17:42:11 +0100 Subject: [PATCH 2/4] re-added summing over first dimension to fix test --- pymc3/distributions/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/distributions/transforms.py b/pymc3/distributions/transforms.py index 23255585b5..f6b9f6a3f5 100644 --- a/pymc3/distributions/transforms.py +++ b/pymc3/distributions/transforms.py @@ -176,6 +176,6 @@ def jacobian_det(self, y): yl = y + eq_share yu = concatenate([ones(y[:1].shape), 1-inverse_logit(yl)]) S = t.extra_ops.cumprod(yu, 0) - return t.log(S[:-1]) - t.log(1+exp(yl)) - t.log(1+exp(-yl)) + return sum(t.log(S[:-1]) - t.log(1+exp(yl)) - t.log(1+exp(-yl)), 0) stick_breaking = StickBreaking() From 82fef258fdb58fb84670837df5602a21b8c9216a Mon Sep 17 00:00:00 2001 From: Michiel Cottaar <MichielCottaar@gmail.com> Date: Sun, 18 Oct 2015 12:02:35 +0100 Subject: [PATCH 3/4] Added test for 2D stick breaking transformation --- pymc3/tests/test_distributions.py | 9 +++++++++ pymc3/tests/test_transforms.py | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 496da87d07..c509dedff4 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -101,6 +101,15 @@ def __init__(self, n): self.dtype = Unit.dtype return +class MultiSimplex(object): + def __init__(self, n_dependent, n_independent): + transposed_vals = list(itertools.product(list(simplex_values(n_dependent)), repeat=n_independent)) + self.vals = list(np.transpose(transposed_vals, (0, 2, 1))) + + self.shape = (n_dependent, n_independent) + self.dtype = Unit.dtype + return + def PdMatrix(n): if n == 1: return PdMatrix1 diff --git a/pymc3/tests/test_transforms.py b/pymc3/tests/test_transforms.py index 6b6f3fdabc..c6d86506ab 100644 --- a/pymc3/tests/test_transforms.py +++ b/pymc3/tests/test_transforms.py @@ -2,7 +2,7 @@ import pymc3.distributions.transforms as tr import theano import theano.tensor as t -from .test_distributions import Simplex, Rplusbig, Unit, R, Vector +from .test_distributions import Simplex, Rplusbig, Unit, R, Vector, MultiSimplex from .checks import * from ..theanof import jacobian @@ -30,6 +30,7 @@ def get_values(transform, domain=R, constructor=t.dscalar, test=0): def test_simplex(): check_vector_transform_identity(tr.stick_breaking, Simplex(2)) check_vector_transform_identity(tr.stick_breaking, Simplex(4)) + check_transform_identity(tr.stick_breaking, MultiSimplex(3, 2), constructor=t.dmatrix, test=np.zeros((2, 2))) def test_simplex_bounds(): vals = get_values(tr.stick_breaking, Vector(R, 2), t.dvector, np.array([0,0])) From e83ed18ba7398145ff96371098fdf299fd498139 Mon Sep 17 00:00:00 2001 From: Michiel Cottaar <MichielCottaar@gmail.com> Date: Sun, 18 Oct 2015 12:19:38 +0100 Subject: [PATCH 4/4] tested log(p) of 2D dirichlet The main goal of this test is to show that the variables are independent across the second dimension. Note that betafn and dirichlet_logpdf had to be changed so that internally they only sum over the first dimension, and only in the final summation is the log(p) summed over all dimensions. --- pymc3/tests/test_distributions.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index c509dedff4..aeb075537d 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -412,17 +412,18 @@ def check_lkj(x, n, p, lp): lp, decimal=6, err_msg=str(pt)) def betafn(a): - return scipy.special.gammaln(a).sum() - scipy.special.gammaln(a.sum()) + return scipy.special.gammaln(a).sum(0) - scipy.special.gammaln(a.sum(0)) def logpow(v, p): return np.choose(v==0, [p * np.log(v), 0]) def dirichlet_logpdf(value, a): - return -betafn(a) + logpow(value, a-1).sum() + return (-betafn(a) + logpow(value, a-1).sum(0)).sum() def test_dirichlet(): for n in [2,3]: yield check_dirichlet, n + yield check_dirichlet2D, 2, 2 def check_dirichlet(n): pymc3_matches_scipy( @@ -430,6 +431,12 @@ def check_dirichlet(n): dirichlet_logpdf ) +def check_dirichlet2D(ndep, nind): + pymc3_matches_scipy( + Dirichlet, MultiSimplex(ndep, nind), {'a': Vector(Vector(Rplus, nind), ndep) }, + dirichlet_logpdf + ) + def multinomial_logpdf(value, n, p): if value.sum() == n and all(value >= 0) and all(value <= n) : return scipy.special.gammaln(n+1) - scipy.special.gammaln(value+1).sum() + logpow(p, value).sum()