diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py index b6c48aeeb2..01b4e1266f 100644 --- a/pymc3/distributions/multivariate.py +++ b/pymc3/distributions/multivariate.py @@ -106,7 +106,7 @@ def logp(self, value): # only defined for sum(value) == 1 return bound( - sum(logpow(value, a - 1) - gammaln(a), axis=0) + gammaln(sum(a)), + sum(logpow(value, a - 1) - gammaln(a), axis=0) + gammaln(sum(a, axis=0)), k > 1, all(a > 0), all(value >= 0), diff --git a/pymc3/distributions/transforms.py b/pymc3/distributions/transforms.py index 930898671d..f6b9f6a3f5 100644 --- a/pymc3/distributions/transforms.py +++ b/pymc3/distributions/transforms.py @@ -150,32 +150,32 @@ def __init__(self): def forward(self, x): #reverse cumsum x0 = x[:-1] - s = t.extra_ops.cumsum(x0[::-1])[::-1] + x[-1] + s = t.extra_ops.cumsum(x0[::-1], 0)[::-1] + x[-1] z = x0/s Km1 = x.shape[0] - 1 - k = arange(Km1) + k = arange(Km1)[(slice(None), ) + (None, ) * (x.ndim - 1)] eq_share = - t.log(Km1 - k) # logit(1./(Km1 + 1 - k)) y = logit(z) - eq_share return y def backward(self, y): Km1 = y.shape[0] - k = arange(Km1) + k = arange(Km1)[(slice(None), ) + (None, ) * (y.ndim - 1)] eq_share = - t.log(Km1 - k) # logit(1./(Km1 + 1 - k)) z = inverse_logit(y + eq_share) - yl = concatenate([z, [1]]) - yu = concatenate([[1], 1-z]) - S = t.extra_ops.cumprod(yu) + yl = concatenate([z, ones(y[:1].shape)]) + yu = concatenate([ones(y[:1].shape), 1-z]) + S = t.extra_ops.cumprod(yu, 0) x = S * yl return x def jacobian_det(self, y): Km1 = y.shape[0] - k = arange(Km1) + k = arange(Km1)[(slice(None), ) + (None, ) * (y.ndim - 1)] eq_share = -t.log(Km1 - k) #logit(1./(Km1 + 1 - k)) yl = y + eq_share - yu = concatenate([[1], 1-inverse_logit(yl)]) - S = t.extra_ops.cumprod(yu) - return sum(t.log(S[:-1]) - t.log(1+exp(yl)) - t.log(1+exp(-yl))) + yu = concatenate([ones(y[:1].shape), 1-inverse_logit(yl)]) + S = t.extra_ops.cumprod(yu, 0) + return sum(t.log(S[:-1]) - t.log(1+exp(yl)) - t.log(1+exp(-yl)), 0) stick_breaking = StickBreaking() diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 496da87d07..aeb075537d 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -101,6 +101,15 @@ def __init__(self, n): self.dtype = Unit.dtype return +class MultiSimplex(object): + def __init__(self, n_dependent, n_independent): + transposed_vals = list(itertools.product(list(simplex_values(n_dependent)), repeat=n_independent)) + self.vals = list(np.transpose(transposed_vals, (0, 2, 1))) + + self.shape = (n_dependent, n_independent) + self.dtype = Unit.dtype + return + def PdMatrix(n): if n == 1: return PdMatrix1 @@ -403,17 +412,18 @@ def check_lkj(x, n, p, lp): lp, decimal=6, err_msg=str(pt)) def betafn(a): - return scipy.special.gammaln(a).sum() - scipy.special.gammaln(a.sum()) + return scipy.special.gammaln(a).sum(0) - scipy.special.gammaln(a.sum(0)) def logpow(v, p): return np.choose(v==0, [p * np.log(v), 0]) def dirichlet_logpdf(value, a): - return -betafn(a) + logpow(value, a-1).sum() + return (-betafn(a) + logpow(value, a-1).sum(0)).sum() def test_dirichlet(): for n in [2,3]: yield check_dirichlet, n + yield check_dirichlet2D, 2, 2 def check_dirichlet(n): pymc3_matches_scipy( @@ -421,6 +431,12 @@ def check_dirichlet(n): dirichlet_logpdf ) +def check_dirichlet2D(ndep, nind): + pymc3_matches_scipy( + Dirichlet, MultiSimplex(ndep, nind), {'a': Vector(Vector(Rplus, nind), ndep) }, + dirichlet_logpdf + ) + def multinomial_logpdf(value, n, p): if value.sum() == n and all(value >= 0) and all(value <= n) : return scipy.special.gammaln(n+1) - scipy.special.gammaln(value+1).sum() + logpow(p, value).sum() diff --git a/pymc3/tests/test_transforms.py b/pymc3/tests/test_transforms.py index 6b6f3fdabc..c6d86506ab 100644 --- a/pymc3/tests/test_transforms.py +++ b/pymc3/tests/test_transforms.py @@ -2,7 +2,7 @@ import pymc3.distributions.transforms as tr import theano import theano.tensor as t -from .test_distributions import Simplex, Rplusbig, Unit, R, Vector +from .test_distributions import Simplex, Rplusbig, Unit, R, Vector, MultiSimplex from .checks import * from ..theanof import jacobian @@ -30,6 +30,7 @@ def get_values(transform, domain=R, constructor=t.dscalar, test=0): def test_simplex(): check_vector_transform_identity(tr.stick_breaking, Simplex(2)) check_vector_transform_identity(tr.stick_breaking, Simplex(4)) + check_transform_identity(tr.stick_breaking, MultiSimplex(3, 2), constructor=t.dmatrix, test=np.zeros((2, 2))) def test_simplex_bounds(): vals = get_values(tr.stick_breaking, Vector(R, 2), t.dvector, np.array([0,0]))