Skip to content

Allow multi-dimensional dirichet (correct pull) #844

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from

Conversation

MichielCottaar
Copy link
Contributor

This alters the StickBreaking transformation to allow for multi-dimensional Dirichlet object. This should resolve #792 . The first dimension is treated as special and the sum along this dimension of the resultant variables are consistenlty one. Both the log(p) and the jacobian of the transformation are now multi-dimensional array with summation only over the first dimension, which means that a useful logp_elemwiset is created.

z = x0/s
Km1 = x.shape[0] - 1
k = arange(Km1)
k = arange(Km1)[(slice(None), ) + (None, ) * (x.ndim - 1)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What exactly is this doing? Might be good to leave a note

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ensures that k (and hence eq_share) has the same number of dimensions as x, which is necessary for it to be subtracted from logit(z).

@jsalvatier
Copy link
Member

Thanks for this Michiel! Super helpful.

Looks good to me. Wan we also add a test for 2-d variables?

@MichielCottaar
Copy link
Contributor Author

The most obvious test to add seems to me to be to compare the log(p) of a dirichlet of shape (3, 5) with 5 individual dirichlets of shape (3,) and ensure that they are the same (where 3 and 5 are just example numbers).
However, I don't see an obvious way to extend the test_dirichlet in test_distributions to do this. Do you have any suggestions, where I could add a new test for this or how I could extend the test_dirichlet to test a multi-dimensional case.

@jsalvatier
Copy link
Member

Since most of the change was in the transform, I would add a test in https://github.com/pymc-devs/pymc3/blob/master/pymc3/tests/test_transforms.py.

Maybe just add a line to https://github.com/pymc-devs/pymc3/blob/master/pymc3/tests/test_transforms.py#L30

Where you pass in a different Domain, instead of Simplex make new "multi-simplex" space that has a bunch simplexes. The simplex domain is defined here: https://github.com/pymc-devs/pymc3/blob/master/pymc3/tests/test_distributions.py#L88

You can probably copy this and modify it slightly. You can also look at the other domains defined nearby such as PdMatrix or R.

@twiecki
Copy link
Member

twiecki commented Oct 16, 2015

@MichielCottaar Any updates on this?

The main goal of this test is to show that the variables are independent
across the second dimension.
Note that betafn and dirichlet_logpdf had to be changed so that
internally they only sum over the first dimension, and only in the final
summation is the log(p) summed over all dimensions.
@MichielCottaar
Copy link
Contributor Author

I've added both the test suggested by John, as well as a test to compute the log(p) of a 2D dirichlet. Unfortunately the second test takes about a minute (mainly because the parameter space explored is big).

@twiecki
Copy link
Member

twiecki commented Oct 18, 2015

Great, thanks @MichielCottaar! I'm not worried about the length of the test.

twiecki added a commit that referenced this pull request Oct 18, 2015
@twiecki
Copy link
Member

twiecki commented Oct 18, 2015

Merged with 4a62e8c.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

stick breaking transform does not work for 2D vars
3 participants