Commit c4a818a

Extend RandomWalk to handle multivariate cases
Also create abstract class for predefined RandomWalks

1 parent 042b906
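In short, this change lets RandomWalk accept initial and innovation distributions with non-zero support dimensionality (e.g. MvNormal) instead of only univariate ones. A minimal usage sketch, assuming the PyMC v4 API as of this commit (pm.MvNormal.dist and pm.draw are existing API; the (batch, time, support) shape convention is explained in the diff below):

import numpy as np
import pymc as pm

from pymc.distributions.timeseries import RandomWalk

# Both init_dist and innovation_dist are 3-dimensional MvNormals (ndim_supp == 1),
# so the resulting walk has ndim_supp == 2: the time dim plus one support dim.
init = pm.MvNormal.dist(mu=np.zeros(3), cov=np.eye(3))
innovations = pm.MvNormal.dist(mu=np.zeros(3), cov=np.eye(3))

rw = RandomWalk.dist(init_dist=init, innovation_dist=innovations, steps=10)

# steps=10 innovations plus the initial value give 11 time points,
# stacked just before the support dimension: shape (11, 3)
print(pm.draw(rw).shape)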

File tree: 2 files changed, +404 −177 lines changed

pymc/distributions/timeseries.py (+121 −58)
@@ -11,8 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import warnings
 
+from abc import ABCMeta
 from typing import Optional
 
 import aesara
@@ -37,6 +39,7 @@
 from pymc.distributions.shape_utils import (
     _change_dist_size,
     change_dist_size,
+    get_support_shape,
     get_support_shape_1d,
     to_tuple,
 )
@@ -69,94 +72,156 @@ class RandomWalk(Distribution):
 
     rv_type = RandomWalkRV
 
-    def __new__(cls, *args, steps=None, **kwargs):
-        steps = get_support_shape_1d(
-            support_shape=steps,
+    def __new__(cls, *args, innovation_dist, steps=None, **kwargs):
+        steps = cls.get_steps(
+            innovation_dist=innovation_dist,
+            steps=steps,
             shape=None,  # Shape will be checked in `cls.dist`
-            dims=kwargs.get("dims", None),
-            observed=kwargs.get("observed", None),
-            support_shape_offset=1,
+            dims=kwargs.get("dims"),
+            observed=kwargs.get("observed"),
         )
-        return super().__new__(cls, *args, steps=steps, **kwargs)
+
+        return super().__new__(cls, *args, innovation_dist=innovation_dist, steps=steps, **kwargs)
 
     @classmethod
     def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> at.TensorVariable:
-        steps = get_support_shape_1d(
-            support_shape=steps,
-            shape=kwargs.get("shape"),
-            support_shape_offset=1,
-        )
-        if steps is None:
-            raise ValueError("Must specify steps or shape parameter")
-        steps = at.as_tensor_variable(intX(steps))
-
         if not (
             isinstance(init_dist, at.TensorVariable)
             and init_dist.owner is not None
             and isinstance(init_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
-            # TODO: Lift univariate constraint on init_dist
-            and init_dist.owner.op.ndim_supp == 0
         ):
-            raise TypeError("init_dist must be a univariate distribution variable")
+            raise TypeError("init_dist must be a distribution variable")
         check_dist_not_registered(init_dist)
 
         if not (
             isinstance(innovation_dist, at.TensorVariable)
             and innovation_dist.owner is not None
             and isinstance(innovation_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
-            and innovation_dist.owner.op.ndim_supp == 0
         ):
-            raise TypeError("innovation_dist must be a univariate distribution variable")
+            raise TypeError("innovation_dist must be a distribution variable")
         check_dist_not_registered(innovation_dist)
 
+        if init_dist.owner.op.ndim_supp != innovation_dist.owner.op.ndim_supp:
+            raise TypeError(
+                "init_dist and innovation_dist must have the same support dimensionality"
+            )
+
+        steps = cls.get_steps(
+            innovation_dist=innovation_dist,
+            steps=steps,
+            shape=kwargs.get("shape"),
+            dims=None,
+            observed=None,
+        )
+        if steps is None:
+            raise ValueError("Must specify steps or shape parameter")
+        steps = at.as_tensor_variable(intX(steps))
+
         return super().dist([init_dist, innovation_dist, steps], **kwargs)
 
+    @classmethod
+    def get_steps(cls, innovation_dist, steps, shape, dims, observed):
+        # We need to know the ndim_supp of the innovation_dist
+        if not (
+            isinstance(innovation_dist, at.TensorVariable)
+            and innovation_dist.owner is not None
+            and isinstance(innovation_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
+        ):
+            raise TypeError("innovation_dist must be a distribution variable")
+
+        dist_ndim_supp = innovation_dist.owner.op.ndim_supp
+        dist_shape = tuple(innovation_dist.shape)
+        support_shape = None
+        if steps is not None:
+            support_shape = (steps,) + (dist_shape[len(dist_shape) - dist_ndim_supp :])
+        support_shape = get_support_shape(
+            support_shape=support_shape,
+            shape=shape,
+            dims=dims,
+            observed=observed,
+            support_shape_offset=1,
+            ndim_supp=dist_ndim_supp + 1,
+        )
+        if support_shape is not None:
+            steps = support_shape[-dist_ndim_supp - 1]
+        return steps
+
     @classmethod
     def rv_op(cls, init_dist, innovation_dist, steps, size=None):
         if not steps.ndim == 0 or not steps.dtype.startswith("int"):
             raise ValueError("steps must be an integer scalar (ndim=0).")
 
+        dist_ndim_supp = init_dist.owner.op.ndim_supp
+        init_dist_shape = tuple(init_dist.shape)
+        init_dist_batch_shape = init_dist_shape[: len(init_dist_shape) - dist_ndim_supp]
+        innovation_dist_shape = tuple(innovation_dist.shape)
+        innovation_batch_shape = innovation_dist_shape[
+            : len(innovation_dist_shape) - dist_ndim_supp
+        ]
+
+        ndim_supp = dist_ndim_supp + 1
+
         # If not explicit, size is determined by the shapes of the input distributions
         if size is None:
-            size = at.broadcast_shape(init_dist, at.atleast_1d(innovation_dist)[..., 0])
-            innovation_size = tuple(size) + (steps,)
+            size = at.broadcast_shape(
+                init_dist_batch_shape, innovation_batch_shape, arrays_are_shapes=True
+            )
 
-        # Resize input distributions
-        init_dist = change_dist_size(init_dist, size)
-        innovation_dist = change_dist_size(innovation_dist, innovation_size)
+        # Resize input distributions. We will size them to (T, B, S) in order
+        # to safely take random draws. We later swap the steps dimension so
+        # that the final distribution will follow (B, T, S)
+        # init_dist must have shape (1, B, S)
+        init_dist = change_dist_size(init_dist, (1, *size))
+        # innovation_dist must have shape (T-1, B, S)
+        innovation_dist = change_dist_size(innovation_dist, (steps, *size))
 
         # Create SymbolicRV
         init_dist_, innovation_dist_, steps_ = (
             init_dist.type(),
             innovation_dist.type(),
             steps.type(),
         )
-        grw_ = at.concatenate([init_dist_[..., None], innovation_dist_], axis=-1)
-        grw_ = at.cumsum(grw_, axis=-1)
+        # Aeppl can only infer the logp of a dimshuffled variable if the dimshuffle is
+        # done directly on top of a RandomVariable. Because of this we dimshuffle the
+        # distributions and only then concatenate them, instead of the other way around.
+        # shape = (B, 1, S)
+        init_dist_dimswapped_ = at.moveaxis(init_dist_, 0, -ndim_supp)
+        # shape = (B, T-1, S)
+        innovation_dist_dimswapped_ = at.moveaxis(innovation_dist_, 0, -ndim_supp)
+        # shape = (B, T, S)
+        grw_ = at.concatenate([init_dist_dimswapped_, innovation_dist_dimswapped_], axis=-ndim_supp)
+        grw_ = at.cumsum(grw_, axis=-ndim_supp)
         return RandomWalkRV(
             [init_dist_, innovation_dist_, steps_],
             # We pass steps_ through just so we can keep a reference to it, even though
             # it's no longer needed at this point
             [grw_, steps_],
-            ndim_supp=1,
+            ndim_supp=ndim_supp,
         )(init_dist, innovation_dist, steps)
 
 
 @_change_dist_size.register(RandomWalkRV)
 def change_random_walk_size(op, dist, new_size, expand):
     init_dist, innovation_dist, steps = dist.owner.inputs
     if expand:
-        old_size = init_dist.shape
+        old_shape = tuple(dist.shape)
+        old_size = old_shape[: len(old_shape) - op.ndim_supp]
         new_size = tuple(new_size) + tuple(old_size)
     return RandomWalk.rv_op(init_dist, innovation_dist, steps, size=new_size)
 
 
 @_moment.register(RandomWalkRV)
 def random_walk_moment(op, rv, init_dist, innovation_dist, steps):
-    grw_moment = at.zeros_like(rv)
-    grw_moment = at.set_subtensor(grw_moment[..., 0], moment(init_dist))
-    grw_moment = at.set_subtensor(grw_moment[..., 1:], moment(innovation_dist))
-    return at.cumsum(grw_moment, axis=-1)
+    # shape = (1, B, S)
+    init_moment = moment(init_dist)
+    # shape = (T-1, B, S)
+    innovation_moment = moment(innovation_dist)
+    # shape = (T, B, S)
+    grw_moment = at.concatenate([init_moment, innovation_moment], axis=0)
+    grw_moment = at.cumsum(grw_moment, axis=0)
+    # shape = (B, T, S)
+    grw_moment = at.moveaxis(grw_moment, 0, -op.ndim_supp)
+    return grw_moment
 
 
 @_logprob.register(RandomWalkRV)
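A note on the shape comments in the hunk above: T is the number of time points (steps + 1), B the batch dimensions, and S the support dimensions of the innovation distribution. Draws are taken time-first as (T, B, S), and the time axis is then moved next to the support dims so the final variable is laid out as (B, T, S). A NumPy analogue of that bookkeeping (illustrative only; the real graph dimshuffles before concatenating so that Aeppl can still infer the logp):

import numpy as np

rng = np.random.default_rng(42)
B, T, S = (2,), 5, (3,)  # batch dims, time points, support dims

init = rng.normal(size=(1, *B, *S))             # (1, B, S)
innovations = rng.normal(size=(T - 1, *B, *S))  # (T-1, B, S)

# Stack along time and accumulate the innovations: (T, B, S)
walk = np.concatenate([init, innovations], axis=0).cumsum(axis=0)

# Move the time axis next to the support dims: (B, T, S)
ndim_supp = len(S) + 1
walk = np.moveaxis(walk, 0, -ndim_supp)
print(walk.shape)  # (2, 5, 3)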
@@ -173,7 +238,25 @@ def random_walk_logp(op, values, *inputs, **kwargs):
     return logp(rv, value).sum(axis=-1)
 
 
-class GaussianRandomWalk:
+class PredefinedRandomWalk(ABCMeta):
+    """Base class for predefined RandomWalk distributions"""
+
+    def __new__(cls, name, *args, **kwargs):
+        init_dist, innovation_dist, kwargs = cls.get_dists(*args, **kwargs)
+        return RandomWalk(name, init_dist=init_dist, innovation_dist=innovation_dist, **kwargs)
+
+    @classmethod
+    def dist(cls, *args, **kwargs) -> at.TensorVariable:
+        init_dist, innovation_dist, kwargs = cls.get_dists(*args, **kwargs)
+        return RandomWalk.dist(init_dist=init_dist, innovation_dist=innovation_dist, **kwargs)
+
+    @classmethod
+    @abc.abstractmethod
+    def get_dists(cls, *args, **kwargs):
+        pass
+
+
+class GaussianRandomWalk(PredefinedRandomWalk):
     r"""Random Walk with Normal innovations.
 
     Parameters
@@ -186,40 +269,22 @@ class GaussianRandomWalk:
         Unnamed univariate distribution of the initial value. Unnamed refers to distributions
         created with the ``.dist()`` API.
 
-        .. warning:: init will be cloned, rendering them independent of the ones passed as input.
+        .. warning:: init_dist will be cloned, rendering them independent of the ones passed as input.
 
     steps : int, optional
         Number of steps in Gaussian Random Walk (steps > 0). Only needed if shape is not
        provided.
     """
 
-    def __new__(cls, name, mu=0.0, sigma=1.0, *, init_dist=None, steps=None, **kwargs):
-        init_dist, innovation_dist, kwargs = cls.get_dists(
-            mu=mu, sigma=sigma, init_dist=init_dist, **kwargs
-        )
-        return RandomWalk(
-            name, init_dist=init_dist, innovation_dist=innovation_dist, steps=steps, **kwargs
-        )
-
-    @classmethod
-    def dist(cls, mu=0.0, sigma=1.0, *, init_dist=None, steps=None, **kwargs) -> at.TensorVariable:
-        init_dist, innovation_dist, kwargs = cls.get_dists(
-            mu=mu, sigma=sigma, init_dist=init_dist, **kwargs
-        )
-        return RandomWalk.dist(
-            init_dist=init_dist, innovation_dist=innovation_dist, steps=steps, **kwargs
-        )
-
     @classmethod
-    def get_dists(cls, *, mu, sigma, init_dist, **kwargs):
+    def get_dists(cls, mu=0.0, sigma=1.0, *, init_dist=None, **kwargs):
         if "init" in kwargs:
             warnings.warn(
                 "init parameter is now called init_dist. Using init will raise an error in a future release.",
                 FutureWarning,
             )
             init_dist = kwargs.pop("init")
 
-        # If no scalar distribution is passed then initialize with a Normal of same mu and sigma
         if init_dist is None:
             warnings.warn(
                 "Initial distribution not specified, defaulting to `Normal.dist(0, 100)`."
@@ -228,11 +293,9 @@ def get_dists(cls, *, mu, sigma, init_dist, **kwargs):
             )
             init_dist = Normal.dist(0, 100)
 
-        # Add one dimension to the right, so that mu and sigma broadcast safely along
-        # the steps dimension
         mu = at.as_tensor_variable(mu)
         sigma = at.as_tensor_variable(sigma)
-        innovation_dist = Normal.dist(mu=mu[..., None], sigma=sigma[..., None])
+        innovation_dist = Normal.dist(mu=mu, sigma=sigma)
 
         return init_dist, innovation_dist, kwargs
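With PredefinedRandomWalk in place, GaussianRandomWalk shrinks to a single get_dists method, and other pre-packaged walks can follow the same pattern. A hedged sketch of a hypothetical subclass (this MvGaussianRandomWalk is not part of this commit; it only illustrates the extension point):

import numpy as np

from pymc.distributions.multivariate import MvNormal
from pymc.distributions.timeseries import PredefinedRandomWalk


class MvGaussianRandomWalk(PredefinedRandomWalk):
    """Random walk with MvNormal innovations (illustrative sketch only)."""

    @classmethod
    def get_dists(cls, mu, cov, *, init_dist=None, **kwargs):
        mu = np.asarray(mu)
        if init_dist is None:
            # Mirror GaussianRandomWalk's wide default for the initial value
            init_dist = MvNormal.dist(np.zeros(mu.shape[-1]), 100 * np.eye(mu.shape[-1]))
        innovation_dist = MvNormal.dist(mu=mu, cov=cov)
        return init_dist, innovation_dist, kwargs

Because PredefinedRandomWalk.__new__ and .dist delegate to RandomWalk, such a subclass would be used like any other distribution, e.g. MvGaussianRandomWalk("walk", mu=np.zeros(3), cov=np.eye(3), steps=10) inside a model.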
