Skip to content

Allow multi-dimensional dirichet (correct pull) #844

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def logp(self, value):

# only defined for sum(value) == 1
return bound(
sum(logpow(value, a - 1) - gammaln(a), axis=0) + gammaln(sum(a)),
sum(logpow(value, a - 1) - gammaln(a), axis=0) + gammaln(sum(a, axis=0)),
k > 1,
all(a > 0),
all(value >= 0),
Expand Down
20 changes: 10 additions & 10 deletions pymc3/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,32 +150,32 @@ def __init__(self):
def forward(self, x):
#reverse cumsum
x0 = x[:-1]
s = t.extra_ops.cumsum(x0[::-1])[::-1] + x[-1]
s = t.extra_ops.cumsum(x0[::-1], 0)[::-1] + x[-1]
z = x0/s
Km1 = x.shape[0] - 1
k = arange(Km1)
k = arange(Km1)[(slice(None), ) + (None, ) * (x.ndim - 1)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What exactly is this doing? Might be good to leave a note

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ensures that k (and hence eq_share) has the same number of dimensions as x, which is necessary for it to be subtracted from logit(z).

eq_share = - t.log(Km1 - k) # logit(1./(Km1 + 1 - k))
y = logit(z) - eq_share
return y

def backward(self, y):
Km1 = y.shape[0]
k = arange(Km1)
k = arange(Km1)[(slice(None), ) + (None, ) * (y.ndim - 1)]
eq_share = - t.log(Km1 - k) # logit(1./(Km1 + 1 - k))
z = inverse_logit(y + eq_share)
yl = concatenate([z, [1]])
yu = concatenate([[1], 1-z])
S = t.extra_ops.cumprod(yu)
yl = concatenate([z, ones(y[:1].shape)])
yu = concatenate([ones(y[:1].shape), 1-z])
S = t.extra_ops.cumprod(yu, 0)
x = S * yl
return x

def jacobian_det(self, y):
Km1 = y.shape[0]
k = arange(Km1)
k = arange(Km1)[(slice(None), ) + (None, ) * (y.ndim - 1)]
eq_share = -t.log(Km1 - k) #logit(1./(Km1 + 1 - k))
yl = y + eq_share
yu = concatenate([[1], 1-inverse_logit(yl)])
S = t.extra_ops.cumprod(yu)
return sum(t.log(S[:-1]) - t.log(1+exp(yl)) - t.log(1+exp(-yl)))
yu = concatenate([ones(y[:1].shape), 1-inverse_logit(yl)])
S = t.extra_ops.cumprod(yu, 0)
return sum(t.log(S[:-1]) - t.log(1+exp(yl)) - t.log(1+exp(-yl)), 0)

stick_breaking = StickBreaking()
20 changes: 18 additions & 2 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ def __init__(self, n):
self.dtype = Unit.dtype
return

class MultiSimplex(object):
def __init__(self, n_dependent, n_independent):
transposed_vals = list(itertools.product(list(simplex_values(n_dependent)), repeat=n_independent))
self.vals = list(np.transpose(transposed_vals, (0, 2, 1)))

self.shape = (n_dependent, n_independent)
self.dtype = Unit.dtype
return

def PdMatrix(n):
if n == 1:
return PdMatrix1
Expand Down Expand Up @@ -403,24 +412,31 @@ def check_lkj(x, n, p, lp):
lp, decimal=6, err_msg=str(pt))

def betafn(a):
return scipy.special.gammaln(a).sum() - scipy.special.gammaln(a.sum())
return scipy.special.gammaln(a).sum(0) - scipy.special.gammaln(a.sum(0))

def logpow(v, p):
return np.choose(v==0, [p * np.log(v), 0])

def dirichlet_logpdf(value, a):
return -betafn(a) + logpow(value, a-1).sum()
return (-betafn(a) + logpow(value, a-1).sum(0)).sum()

def test_dirichlet():
for n in [2,3]:
yield check_dirichlet, n
yield check_dirichlet2D, 2, 2

def check_dirichlet(n):
pymc3_matches_scipy(
Dirichlet, Simplex(n), {'a': Vector(Rplus, n) },
dirichlet_logpdf
)

def check_dirichlet2D(ndep, nind):
pymc3_matches_scipy(
Dirichlet, MultiSimplex(ndep, nind), {'a': Vector(Vector(Rplus, nind), ndep) },
dirichlet_logpdf
)

def multinomial_logpdf(value, n, p):
if value.sum() == n and all(value >= 0) and all(value <= n) :
return scipy.special.gammaln(n+1) - scipy.special.gammaln(value+1).sum() + logpow(p, value).sum()
Expand Down
3 changes: 2 additions & 1 deletion pymc3/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pymc3.distributions.transforms as tr
import theano
import theano.tensor as t
from .test_distributions import Simplex, Rplusbig, Unit, R, Vector
from .test_distributions import Simplex, Rplusbig, Unit, R, Vector, MultiSimplex

from .checks import *
from ..theanof import jacobian
Expand Down Expand Up @@ -30,6 +30,7 @@ def get_values(transform, domain=R, constructor=t.dscalar, test=0):
def test_simplex():
check_vector_transform_identity(tr.stick_breaking, Simplex(2))
check_vector_transform_identity(tr.stick_breaking, Simplex(4))
check_transform_identity(tr.stick_breaking, MultiSimplex(3, 2), constructor=t.dmatrix, test=np.zeros((2, 2)))

def test_simplex_bounds():
vals = get_values(tr.stick_breaking, Vector(R, 2), t.dvector, np.array([0,0]))
Expand Down