Skip to content

Commit 8653d18

Browse files
committed
Enable nested SymbolicRandomVariables
1 parent f6c8aac commit 8653d18

File tree

7 files changed

+52
-18
lines changed

7 files changed

+52
-18
lines changed

pymc/distributions/censored.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ class Censored(Distribution):
8686

8787
@classmethod
8888
def dist(cls, dist, lower, upper, **kwargs):
89-
if not isinstance(dist, TensorVariable) or not isinstance(dist.owner.op, RandomVariable):
89+
if not isinstance(dist, TensorVariable) or not isinstance(
90+
dist.owner.op, (RandomVariable, SymbolicRandomVariable)
91+
):
9092
raise ValueError(
9193
f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"
9294
)

pymc/distributions/mixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def dist(cls, w, comp_dists, **kwargs):
191191
# TODO: Allow these to not be a RandomVariable as long as we can call `ndim_supp` on them
192192
# and resize them
193193
if not isinstance(dist, TensorVariable) or not isinstance(
194-
dist.owner.op, RandomVariable
194+
dist.owner.op, (RandomVariable, SymbolicRandomVariable)
195195
):
196196
raise ValueError(
197197
f"Component dist must be a distribution created via the `.dist()` API, got {type(dist)}"

pymc/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1180,7 +1180,7 @@ def dist(cls, n, eta, sd_dist, **kwargs):
11801180
if not (
11811181
isinstance(sd_dist, Variable)
11821182
and sd_dist.owner is not None
1183-
and isinstance(sd_dist.owner.op, RandomVariable)
1183+
and isinstance(sd_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
11841184
and sd_dist.owner.op.ndim_supp < 2
11851185
):
11861186
raise TypeError("sd_dist must be a scalar or vector distribution variable")

pymc/distributions/timeseries.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> at.TensorVari
155155
if not (
156156
isinstance(init_dist, at.TensorVariable)
157157
and init_dist.owner is not None
158-
and isinstance(init_dist.owner.op, RandomVariable)
158+
and isinstance(init_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
159159
# TODO: Lift univariate constraint on init_dist
160160
and init_dist.owner.op.ndim_supp == 0
161161
):
@@ -165,7 +165,7 @@ def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> at.TensorVari
165165
if not (
166166
isinstance(innovation_dist, at.TensorVariable)
167167
and init_dist.owner is not None
168-
and isinstance(init_dist.owner.op, RandomVariable)
168+
and isinstance(init_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
169169
# TODO: Lift univariate constraint on inovvation_dist
170170
and init_dist.owner.op.ndim_supp == 0
171171
):
@@ -434,7 +434,7 @@ def dist(
434434

435435
if init_dist is not None:
436436
if not isinstance(init_dist, TensorVariable) or not isinstance(
437-
init_dist.owner.op, RandomVariable
437+
init_dist.owner.op, (RandomVariable, SymbolicRandomVariable)
438438
):
439439
raise ValueError(
440440
f"Init dist must be a distribution created via the `.dist()` API, "

pymc/printing.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ def str_for_potential_or_deterministic(
151151

152152

153153
def _str_for_input_var(var: Variable, formatting: str) -> str:
154+
# Avoid circular import
155+
from pymc.distributions.distribution import SymbolicRandomVariable
156+
154157
def _is_potential_or_determinstic(var: Variable) -> bool:
155158
try:
156159
return var.str_repr.__func__.func is str_for_potential_or_deterministic
@@ -160,7 +163,9 @@ def _is_potential_or_determinstic(var: Variable) -> bool:
160163

161164
if isinstance(var, TensorConstant):
162165
return _str_for_constant(var, formatting)
163-
elif isinstance(var.owner.op, RandomVariable) or _is_potential_or_determinstic(var):
166+
elif isinstance(
167+
var.owner.op, (RandomVariable, SymbolicRandomVariable)
168+
) or _is_potential_or_determinstic(var):
164169
# show the names for RandomVariables, Deterministics, and Potentials, rather
165170
# than the full expression
166171
return _str_for_input_rv(var, formatting)
@@ -194,15 +199,18 @@ def _str_for_constant(var: TensorConstant, formatting: str) -> str:
194199

195200

196201
def _str_for_expression(var: Variable, formatting: str) -> str:
202+
# Avoid circular import
203+
from pymc.distributions.distribution import SymbolicRandomVariable
204+
197205
# construct a string like f(a1, ..., aN) listing all random variables a as arguments
198206
def _expand(x):
199-
if x.owner and (not isinstance(x.owner.op, RandomVariable)):
207+
if x.owner and (not isinstance(x.owner.op, (RandomVariable, SymbolicRandomVariable))):
200208
return reversed(x.owner.inputs)
201209

202210
parents = [
203211
x
204212
for x in walk(nodes=var.owner.inputs, expand=_expand)
205-
if x.owner and isinstance(x.owner.op, RandomVariable)
213+
if x.owner and isinstance(x.owner.op, (RandomVariable, SymbolicRandomVariable))
206214
]
207215
names = [x.name for x in parents]
208216

pymc/tests/distributions/test_mixture.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -599,13 +599,17 @@ def test_list_mvnormals_predictive_sampling_shape(self):
599599
assert prior["mu0"].shape == (n_samples, D)
600600
assert prior["chol_cov_0"].shape == (n_samples, D * (D + 1) // 2)
601601

602-
@pytest.mark.xfail(reason="Nested mixtures not refactored yet")
603602
def test_nested_mixture(self):
604603
if aesara.config.floatX == "float32":
605604
rtol = 1e-4
606605
else:
607606
rtol = 1e-7
608607
nbr = 4
608+
609+
norm_x = generate_normal_mixture_data(
610+
np.r_[0.75, 0.25], np.r_[0.0, 5.0], np.r_[1.0, 1.0], size=1000
611+
)
612+
609613
with Model() as model:
610614
# mixtures components
611615
g_comp = Normal.dist(
@@ -622,7 +626,7 @@ def test_nested_mixture(self):
622626
l_mix = Mixture.dist(w=l_w, comp_dists=l_comp)
623627
# mixture of mixtures
624628
mix_w = Dirichlet("mix_w", a=floatX(np.ones(2)), transform=None, shape=(2,))
625-
mix = Mixture("mix", w=mix_w, comp_dists=[g_mix, l_mix], observed=np.exp(self.norm_x))
629+
mix = Mixture("mix", w=mix_w, comp_dists=[g_mix, l_mix], observed=np.exp(norm_x))
626630

627631
test_point = model.initial_point()
628632

@@ -658,21 +662,23 @@ def mixmixlogp(value, point):
658662
)
659663
return priorlogp, mixmixlogpg
660664

661-
value = np.exp(self.norm_x)[:, None]
665+
value = np.exp(norm_x)[:, None]
662666
priorlogp, mixmixlogpg = mixmixlogp(value, test_point)
663667

664668
# check logp of mixture
665-
assert_allclose(mixmixlogpg, mix.logp_elemwise(test_point), rtol=rtol)
669+
mix_logp_fn = model.compile_logp(vars=[mix], sum=False)
670+
assert_allclose(mixmixlogpg, mix_logp_fn(test_point)[0], rtol=rtol)
666671

667672
# check model logp
668-
assert_allclose(priorlogp + mixmixlogpg.sum(), model.logp(test_point), rtol=rtol)
673+
model_logp_fn = model.compile_logp()
674+
assert_allclose(priorlogp + mixmixlogpg.sum(), model_logp_fn(test_point), rtol=rtol)
669675

670676
# check input and check logp again
671677
test_point["g_w"] = np.asarray([0.1, 0.1, 0.2, 0.6])
672678
test_point["mu_g"] = np.exp(np.random.randn(nbr))
673679
priorlogp, mixmixlogpg = mixmixlogp(value, test_point)
674-
assert_allclose(mixmixlogpg, mix.logp_elemwise(test_point), rtol=rtol)
675-
assert_allclose(priorlogp + mixmixlogpg.sum(), model.logp(test_point), rtol=rtol)
680+
assert_allclose(mixmixlogpg, mix_logp_fn(test_point)[0], rtol=rtol)
681+
assert_allclose(priorlogp + mixmixlogpg.sum(), model_logp_fn(test_point), rtol=rtol)
676682

677683
def test_iterable_single_component_warning(self):
678684
with warnings.catch_warnings():

pymc/tests/test_printing.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22

3+
from pymc import Bernoulli, Censored, Mixture
34
from pymc.aesaraf import floatX
45
from pymc.distributions import (
56
DirichletMultinomial,
@@ -48,9 +49,14 @@ def setup_class(self):
4849
# )
4950
nb2 = NegativeBinomial("nb_with_p_n", p=Uniform("nbp"), n=10)
5051

51-
# Symbolic distribution
52+
# SymbolicRV
5253
zip = ZeroInflatedPoisson("zip", 0.5, 5)
5354

55+
# Nested SymbolicRV
56+
comp_1 = ZeroInflatedPoisson.dist(0.5, 5)
57+
comp_2 = Censored.dist(Bernoulli.dist(0.5), -1, 1)
58+
nested_mix = Mixture("nested_mix", [0.5, 0.5], [comp_1, comp_2])
59+
5460
# Expected value of outcome
5561
mu = Deterministic("mu", floatX(alpha + dot(X, b)))
5662

@@ -80,7 +86,7 @@ def setup_class(self):
8086
# add a potential as well
8187
pot = Potential("pot", mu**2)
8288

83-
self.distributions = [alpha, sigma, mu, b, Z, nb2, zip, Y_obs, pot]
89+
self.distributions = [alpha, sigma, mu, b, Z, nb2, zip, nested_mix, Y_obs, pot]
8490
self.deterministics_or_potentials = [mu, pot]
8591
# tuples of (formatting, include_params
8692
self.formats = [("plain", True), ("plain", False), ("latex", True), ("latex", False)]
@@ -93,6 +99,11 @@ def setup_class(self):
9399
r"Z ~ N(f(), f())",
94100
r"nb_with_p_n ~ NB(10, nbp)",
95101
r"zip ~ MarginalMixture(f(), DiracDelta(0), Pois(5))",
102+
(
103+
r"nested_mix ~ MarginalMixture(<constant>, "
104+
r"MarginalMixture(f(), DiracDelta(0), Pois(5)), "
105+
r"Censored(Bern(0.5), -1, 1))"
106+
),
96107
r"Y_obs ~ N(mu, sigma)",
97108
r"pot ~ Potential(f(beta, alpha))",
98109
],
@@ -104,6 +115,7 @@ def setup_class(self):
104115
r"Z ~ N",
105116
r"nb_with_p_n ~ NB",
106117
r"zip ~ MarginalMixture",
118+
r"nested_mix ~ MarginalMixture",
107119
r"Y_obs ~ N",
108120
r"pot ~ Potential",
109121
],
@@ -115,6 +127,11 @@ def setup_class(self):
115127
r"$\text{Z} \sim \operatorname{N}(f(),~f())$",
116128
r"$\text{nb_with_p_n} \sim \operatorname{NB}(10,~\text{nbp})$",
117129
r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\text{\$\operatorname{DiracDelta}(0)\$},~\text{\$\operatorname{Pois}(5)\$})$",
130+
(
131+
r"$\text{nested_mix} \sim \operatorname{MarginalMixture}(\text{<constant>},"
132+
r"~\text{\$\operatorname{MarginalMixture}(f(),~\text{\\$\operatorname{DiracDelta}(0)\\$},~\text{\\$\operatorname{Pois}(5)\\$})\$},"
133+
r"~\text{\$\operatorname{Censored}(\text{\\$\operatorname{Bern}(0.5)\\$},~-1,~1)\$})$"
134+
),
118135
r"$\text{Y_obs} \sim \operatorname{N}(\text{mu},~\text{sigma})$",
119136
r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$",
120137
],
@@ -126,6 +143,7 @@ def setup_class(self):
126143
r"$\text{Z} \sim \operatorname{N}$",
127144
r"$\text{nb_with_p_n} \sim \operatorname{NB}$",
128145
r"$\text{zip} \sim \operatorname{MarginalMixture}$",
146+
r"$\text{nested_mix} \sim \operatorname{MarginalMixture}$",
129147
r"$\text{Y_obs} \sim \operatorname{N}$",
130148
r"$\text{pot} \sim \operatorname{Potential}$",
131149
],

0 commit comments

Comments
 (0)