Skip to content

Commit 22b4446

Browse files
committed
Avoid missing_rvs warning when using RandomWalk
1 parent 6007e84 commit 22b4446

File tree

2 files changed

+45
-12
lines changed

2 files changed

+45
-12
lines changed

pymc/distributions/timeseries.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pytensor
2222
import pytensor.tensor as at
2323

24-
from pytensor.graph.basic import Node
24+
from pytensor.graph.basic import Node, ancestors
2525
from pytensor.graph.replace import clone_replace
2626
from pytensor.tensor import TensorVariable
2727
from pytensor.tensor.random.op import RandomVariable
@@ -33,7 +33,7 @@
3333
_moment,
3434
moment,
3535
)
36-
from pymc.distributions.logprob import ignore_logprob, logp
36+
from pymc.distributions.logprob import ignore_logprob, logp, reconsider_logprob
3737
from pymc.distributions.multivariate import MvNormal, MvStudentT
3838
from pymc.distributions.shape_utils import (
3939
_change_dist_size,
@@ -106,6 +106,15 @@ def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> at.TensorVari
106106
"init_dist and innovation_dist must have the same support dimensionality"
107107
)
108108

109+
# We need to check this, because we clone the variables when we ignore their logprob next
110+
if init_dist in ancestors([innovation_dist]) or innovation_dist in ancestors([init_dist]):
111+
raise ValueError("init_dist and innovation_dist must be completely independent")
112+
113+
# PyMC should not be concerned that these variables don't have values, as they will be
114+
# accounted for in the logp of RandomWalk
115+
init_dist = ignore_logprob(init_dist)
116+
innovation_dist = ignore_logprob(innovation_dist)
117+
109118
steps = cls.get_steps(
110119
innovation_dist=innovation_dist,
111120
steps=steps,
@@ -225,12 +234,14 @@ def random_walk_moment(op, rv, init_dist, innovation_dist, steps):
225234

226235

227236
@_logprob.register(RandomWalkRV)
228-
def random_walk_logp(op, values, *inputs, **kwargs):
237+
def random_walk_logp(op, values, init_dist, innovation_dist, steps, **kwargs):
229238
# Although we can derive the logprob of random walks, it does not collapse
230239
# what we consider the core dimension of steps. We do it manually here.
231240
(value,) = values
232241
# Recreate RV and obtain inner graph
233-
rv_node = op.make_node(*inputs)
242+
rv_node = op.make_node(
243+
reconsider_logprob(init_dist), reconsider_logprob(innovation_dist), steps
244+
)
234245
rv = clone_replace(
235246
op.inner_outputs, replace={u: v for u, v in zip(op.inner_inputs, rv_node.inputs)}
236247
)[op.default_output]

pymc/tests/distributions/test_timeseries.py

+30-8
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@
4747
from pymc.tests.distributions.util import assert_moment_is_expected
4848
from pymc.tests.helpers import select_by_precision
4949

50+
# Turn all warnings into errors for this module
51+
# Ignoring NumPy deprecation warning tracked in https://github.com/pymc-devs/pytensor/issues/146
52+
pytestmark = pytest.mark.filterwarnings("error", "ignore: NumPy will stop allowing conversion")
53+
5054

5155
class TestRandomWalk:
5256
def test_dists_types(self):
@@ -92,6 +96,14 @@ def test_dists_not_registered_check(self):
9296
):
9397
RandomWalk("rw", init_dist=init_dist, innovation_dist=innovation, steps=5)
9498

99+
def test_dists_independent_check(self):
100+
init_dist = Normal.dist()
101+
innovation_dist = Normal.dist(init_dist)
102+
with pytest.raises(
103+
ValueError, match="init_dist and innovation_dist must be completely independent"
104+
):
105+
RandomWalk.dist(init_dist=init_dist, innovation_dist=innovation_dist)
106+
95107
@pytest.mark.parametrize(
96108
"init_dist, innovation_dist, steps, size, shape, "
97109
"init_dist_size, innovation_dist_size, rw_shape",
@@ -423,15 +435,18 @@ def test_mvgaussian_with_chol_cov_rv(self, param):
423435
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
424436
)
425437
# pylint: enable=unpacking-non-sequence
426-
if param == "chol":
427-
mv = MvGaussianRandomWalk("mv", mu, chol=chol, shape=(10, 7, 3))
428-
elif param == "cov":
429-
mv = MvGaussianRandomWalk("mv", mu, cov=pm.math.dot(chol, chol.T), shape=(10, 7, 3))
430-
else:
431-
raise ValueError
438+
with pytest.warns(UserWarning, match="Initial distribution not specified"):
439+
if param == "chol":
440+
mv = MvGaussianRandomWalk("mv", mu, chol=chol, shape=(10, 7, 3))
441+
elif param == "cov":
442+
mv = MvGaussianRandomWalk(
443+
"mv", mu, cov=pm.math.dot(chol, chol.T), shape=(10, 7, 3)
444+
)
445+
else:
446+
raise ValueError
432447
assert draw(mv, draws=5).shape == (5, 10, 7, 3)
433448

434-
@pytest.mark.parametrize("param", ["cov", "chol", "tau"])
449+
@pytest.mark.parametrize("param", ["scale", "chol", "tau"])
435450
def test_mvstudentt(self, param):
436451
x = MvStudentTRandomWalk.dist(
437452
nu=4,
@@ -853,7 +868,13 @@ def sde_fn(x, k, d, s):
853868
with Model() as t0:
854869
init_dist = pm.Normal.dist(0, 10, shape=(batch_size,))
855870
y = EulerMaruyama(
856-
"y", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, init_dist=init_dist, **kwargs
871+
"y",
872+
dt=0.02,
873+
sde_fn=sde_fn,
874+
sde_pars=sde_pars,
875+
init_dist=init_dist,
876+
initval="prior",
877+
**kwargs,
857878
)
858879

859880
y_eval = draw(y, draws=2, random_seed=numpy_rng)
@@ -873,6 +894,7 @@ def sde_fn(x, k, d, s):
873894
sde_fn=sde_fn,
874895
sde_pars=sde_pars_slice,
875896
init_dist=init_dist,
897+
initval="prior",
876898
**kwargs,
877899
)
878900

0 commit comments

Comments
 (0)