-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Allow multi-dimensional dirichet (correct pull) #844
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The resulting variables summed over the first axis will be one.
z = x0/s | ||
Km1 = x.shape[0] - 1 | ||
k = arange(Km1) | ||
k = arange(Km1)[(slice(None), ) + (None, ) * (x.ndim - 1)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What exactly is this doing? Might be good to leave a note
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This ensures that k (and hence eq_share) has the same number of dimensions as x, which is necessary for it to be subtracted from logit(z).
Thanks for this Michiel! Super helpful. Looks good to me. Wan we also add a test for 2-d variables? |
The most obvious test to add seems to me to be to compare the log(p) of a dirichlet of shape (3, 5) with 5 individual dirichlets of shape (3,) and ensure that they are the same (where 3 and 5 are just example numbers). |
Since most of the change was in the transform, I would add a test in https://github.com/pymc-devs/pymc3/blob/master/pymc3/tests/test_transforms.py. Maybe just add a line to https://github.com/pymc-devs/pymc3/blob/master/pymc3/tests/test_transforms.py#L30 Where you pass in a different You can probably copy this and modify it slightly. You can also look at the other domains defined nearby such as |
@MichielCottaar Any updates on this? |
The main goal of this test is to show that the variables are independent across the second dimension. Note that betafn and dirichlet_logpdf had to be changed so that internally they only sum over the first dimension, and only in the final summation is the log(p) summed over all dimensions.
I've added both the test suggested by John, as well as a test to compute the log(p) of a 2D dirichlet. Unfortunately the second test takes about a minute (mainly because the parameter space explored is big). |
Great, thanks @MichielCottaar! I'm not worried about the length of the test. |
Merged with 4a62e8c. |
This alters the StickBreaking transformation to allow for multi-dimensional Dirichlet object. This should resolve #792 . The first dimension is treated as special and the sum along this dimension of the resultant variables are consistenlty one. Both the log(p) and the jacobian of the transformation are now multi-dimensional array with summation only over the first dimension, which means that a useful logp_elemwiset is created.