From 54fc211d725a321a033233aba90ac94d9163eab4 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <MichielCottaar@gmail.com>
Date: Wed, 30 Sep 2015 17:21:21 +0100
Subject: [PATCH 1/4] BUG allow Dirichlet to be multi-dimensional

For a multi-dimensional Dirichlet, the resulting variables sum to one
along the first axis.
---
 pymc3/distributions/multivariate.py |  2 +-
 pymc3/distributions/transforms.py   | 20 ++++++++++----------
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py
index b6c48aeeb2..01b4e1266f 100644
--- a/pymc3/distributions/multivariate.py
+++ b/pymc3/distributions/multivariate.py
@@ -106,7 +106,7 @@ def logp(self, value):
 
         # only defined for sum(value) == 1
         return bound(
-            sum(logpow(value, a - 1) - gammaln(a), axis=0) + gammaln(sum(a)),
+            sum(logpow(value, a - 1) - gammaln(a), axis=0) + gammaln(sum(a, axis=0)),
             k > 1,
             all(a > 0),
             all(value >= 0),
diff --git a/pymc3/distributions/transforms.py b/pymc3/distributions/transforms.py
index 930898671d..23255585b5 100644
--- a/pymc3/distributions/transforms.py
+++ b/pymc3/distributions/transforms.py
@@ -150,32 +150,32 @@ def __init__(self):
     def forward(self, x):
         #reverse cumsum
         x0 = x[:-1]
-        s = t.extra_ops.cumsum(x0[::-1])[::-1] + x[-1]
+        s = t.extra_ops.cumsum(x0[::-1], 0)[::-1] + x[-1]
         z = x0/s
         Km1 = x.shape[0] - 1
-        k = arange(Km1)
+        k = arange(Km1)[(slice(None), ) + (None, ) * (x.ndim - 1)]
         eq_share = - t.log(Km1 - k) # logit(1./(Km1 + 1 - k)) 
         y =  logit(z) - eq_share
         return y
 
     def backward(self, y):
         Km1 = y.shape[0]
-        k = arange(Km1)
+        k = arange(Km1)[(slice(None), ) + (None, ) * (y.ndim - 1)]
         eq_share = - t.log(Km1 - k) # logit(1./(Km1 + 1 - k)) 
         z = inverse_logit(y + eq_share)
-        yl = concatenate([z, [1]])
-        yu = concatenate([[1], 1-z])
-        S = t.extra_ops.cumprod(yu)
+        yl = concatenate([z, ones(y[:1].shape)])
+        yu = concatenate([ones(y[:1].shape), 1-z])
+        S = t.extra_ops.cumprod(yu, 0)
         x = S * yl
         return x
 
     def jacobian_det(self, y): 
         Km1 = y.shape[0]
-        k = arange(Km1)
+        k = arange(Km1)[(slice(None), ) + (None, ) * (y.ndim - 1)]
         eq_share =  -t.log(Km1 - k) #logit(1./(Km1 + 1 - k)) 
         yl = y + eq_share
-        yu = concatenate([[1], 1-inverse_logit(yl)])
-        S = t.extra_ops.cumprod(yu)
-        return sum(t.log(S[:-1]) - t.log(1+exp(yl)) - t.log(1+exp(-yl)))
+        yu = concatenate([ones(y[:1].shape), 1-inverse_logit(yl)])
+        S = t.extra_ops.cumprod(yu, 0)
+        return t.log(S[:-1]) - t.log(1+exp(yl)) - t.log(1+exp(-yl))
 
 stick_breaking = StickBreaking()

From 78d5fecfeafa5b18ea81063e4bcdac344e3fa06c Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <MichielCottaar@gmail.com>
Date: Wed, 30 Sep 2015 17:42:11 +0100
Subject: [PATCH 2/4] re-added summing over first dimension to fix test

---
 pymc3/distributions/transforms.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pymc3/distributions/transforms.py b/pymc3/distributions/transforms.py
index 23255585b5..f6b9f6a3f5 100644
--- a/pymc3/distributions/transforms.py
+++ b/pymc3/distributions/transforms.py
@@ -176,6 +176,6 @@ def jacobian_det(self, y):
         yl = y + eq_share
         yu = concatenate([ones(y[:1].shape), 1-inverse_logit(yl)])
         S = t.extra_ops.cumprod(yu, 0)
-        return t.log(S[:-1]) - t.log(1+exp(yl)) - t.log(1+exp(-yl))
+        return sum(t.log(S[:-1]) - t.log(1+exp(yl)) - t.log(1+exp(-yl)), 0)
 
 stick_breaking = StickBreaking()

From 82fef258fdb58fb84670837df5602a21b8c9216a Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <MichielCottaar@gmail.com>
Date: Sun, 18 Oct 2015 12:02:35 +0100
Subject: [PATCH 3/4] Added test for 2D stick breaking transformation

---
 pymc3/tests/test_distributions.py | 9 +++++++++
 pymc3/tests/test_transforms.py    | 3 ++-
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py
index 496da87d07..c509dedff4 100644
--- a/pymc3/tests/test_distributions.py
+++ b/pymc3/tests/test_distributions.py
@@ -101,6 +101,15 @@ def __init__(self, n):
         self.dtype = Unit.dtype
         return
 
+class MultiSimplex(object):
+    def __init__(self, n_dependent, n_independent):
+        transposed_vals = list(itertools.product(list(simplex_values(n_dependent)), repeat=n_independent))
+        self.vals = list(np.transpose(transposed_vals, (0, 2, 1)))
+
+        self.shape = (n_dependent, n_independent)
+        self.dtype = Unit.dtype
+        return
+
 def PdMatrix(n):
     if n == 1:
         return PdMatrix1 
diff --git a/pymc3/tests/test_transforms.py b/pymc3/tests/test_transforms.py
index 6b6f3fdabc..c6d86506ab 100644
--- a/pymc3/tests/test_transforms.py
+++ b/pymc3/tests/test_transforms.py
@@ -2,7 +2,7 @@
 import pymc3.distributions.transforms as tr
 import theano
 import theano.tensor as t 
-from .test_distributions import Simplex, Rplusbig, Unit, R, Vector
+from .test_distributions import Simplex, Rplusbig, Unit, R, Vector, MultiSimplex
 
 from .checks import *
 from ..theanof import jacobian
@@ -30,6 +30,7 @@ def get_values(transform, domain=R, constructor=t.dscalar, test=0):
 def test_simplex():
     check_vector_transform_identity(tr.stick_breaking, Simplex(2))
     check_vector_transform_identity(tr.stick_breaking, Simplex(4))
+    check_transform_identity(tr.stick_breaking, MultiSimplex(3, 2), constructor=t.dmatrix, test=np.zeros((2, 2)))
     
 def test_simplex_bounds():
     vals = get_values(tr.stick_breaking, Vector(R, 2), t.dvector, np.array([0,0]))

From e83ed18ba7398145ff96371098fdf299fd498139 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <MichielCottaar@gmail.com>
Date: Sun, 18 Oct 2015 12:19:38 +0100
Subject: [PATCH 4/4] tested log(p) of 2D dirichlet

The main goal of this test is to show that the variables are independent
across the second dimension.
Note that the reference helpers betafn and dirichlet_logpdf had to be
changed so that internally they sum only over the first dimension; the
log(p) is summed over all dimensions only in the final summation.
---
 pymc3/tests/test_distributions.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py
index c509dedff4..aeb075537d 100644
--- a/pymc3/tests/test_distributions.py
+++ b/pymc3/tests/test_distributions.py
@@ -412,17 +412,18 @@ def check_lkj(x, n, p, lp):
                         lp, decimal=6, err_msg=str(pt))
 
 def betafn(a):
-    return scipy.special.gammaln(a).sum() - scipy.special.gammaln(a.sum())
+    return scipy.special.gammaln(a).sum(0) - scipy.special.gammaln(a.sum(0))
 
 def logpow(v, p):
     return np.choose(v==0, [p * np.log(v), 0])
 
 def dirichlet_logpdf(value, a):
-    return -betafn(a) + logpow(value, a-1).sum()
+    return (-betafn(a) + logpow(value, a-1).sum(0)).sum()
 
 def test_dirichlet():
     for n in [2,3]:
         yield check_dirichlet, n
+    yield check_dirichlet2D, 2, 2
 
 def check_dirichlet(n):
         pymc3_matches_scipy(
@@ -430,6 +431,12 @@ def check_dirichlet(n):
                 dirichlet_logpdf
                 )
 
+def check_dirichlet2D(ndep, nind):
+        pymc3_matches_scipy(
+                Dirichlet, MultiSimplex(ndep, nind), {'a': Vector(Vector(Rplus, nind), ndep) },
+                dirichlet_logpdf
+                )
+
 def multinomial_logpdf(value, n, p):
     if value.sum() == n and all(value >= 0) and all(value <= n) :
         return scipy.special.gammaln(n+1) - scipy.special.gammaln(value+1).sum() + logpow(p, value).sum()