@@ -11,8 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import warnings

+from abc import ABCMeta
 from typing import Optional

 import aesara
@@ -37,6 +39,7 @@
 from pymc.distributions.shape_utils import (
     _change_dist_size,
     change_dist_size,
+    get_support_shape,
     get_support_shape_1d,
     to_tuple,
 )
@@ -69,94 +72,156 @@ class RandomWalk(Distribution):

     rv_type = RandomWalkRV

-    def __new__(cls, *args, steps=None, **kwargs):
-        steps = get_support_shape_1d(
-            support_shape=steps,
+    def __new__(cls, *args, innovation_dist, steps=None, **kwargs):
+        steps = cls.get_steps(
+            innovation_dist=innovation_dist,
+            steps=steps,
             shape=None,  # Shape will be checked in `cls.dist`
-            dims=kwargs.get("dims", None),
-            observed=kwargs.get("observed", None),
-            support_shape_offset=1,
+            dims=kwargs.get("dims"),
+            observed=kwargs.get("observed"),
         )
-        return super().__new__(cls, *args, steps=steps, **kwargs)
+
+        return super().__new__(cls, *args, innovation_dist=innovation_dist, steps=steps, **kwargs)

     @classmethod
     def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> at.TensorVariable:
-        steps = get_support_shape_1d(
-            support_shape=steps,
-            shape=kwargs.get("shape"),
-            support_shape_offset=1,
-        )
-        if steps is None:
-            raise ValueError("Must specify steps or shape parameter")
-        steps = at.as_tensor_variable(intX(steps))
-
         if not (
             isinstance(init_dist, at.TensorVariable)
             and init_dist.owner is not None
             and isinstance(init_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
-            # TODO: Lift univariate constraint on init_dist
-            and init_dist.owner.op.ndim_supp == 0
         ):
-            raise TypeError("init_dist must be a univariate distribution variable")
+            raise TypeError("init_dist must be a distribution variable")
         check_dist_not_registered(init_dist)

         if not (
             isinstance(innovation_dist, at.TensorVariable)
             and innovation_dist.owner is not None
             and isinstance(innovation_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
-            and innovation_dist.owner.op.ndim_supp == 0
         ):
-            raise TypeError("innovation_dist must be a univariate distribution variable")
+            raise TypeError("innovation_dist must be a distribution variable")
         check_dist_not_registered(innovation_dist)

+        if init_dist.owner.op.ndim_supp != innovation_dist.owner.op.ndim_supp:
+            raise TypeError(
+                "init_dist and innovation_dist must have the same support dimensionality"
+            )
+
+        steps = cls.get_steps(
+            innovation_dist=innovation_dist,
+            steps=steps,
+            shape=kwargs.get("shape"),
+            dims=None,
+            observed=None,
+        )
+        if steps is None:
+            raise ValueError("Must specify steps or shape parameter")
+        steps = at.as_tensor_variable(intX(steps))
+
         return super().dist([init_dist, innovation_dist, steps], **kwargs)

+    @classmethod
+    def get_steps(cls, innovation_dist, steps, shape, dims, observed):
+        # We need to know the ndim_supp of the innovation_dist
+        if not (
+            isinstance(innovation_dist, at.TensorVariable)
+            and innovation_dist.owner is not None
+            and isinstance(innovation_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
+        ):
+            raise TypeError("innovation_dist must be a distribution variable")
+
+        dist_ndim_supp = innovation_dist.owner.op.ndim_supp
+        dist_shape = tuple(innovation_dist.shape)
+        support_shape = None
+        if steps is not None:
+            support_shape = (steps,) + (dist_shape[len(dist_shape) - dist_ndim_supp :])
+        support_shape = get_support_shape(
+            support_shape=support_shape,
+            shape=shape,
+            dims=dims,
+            observed=observed,
+            support_shape_offset=1,
+            ndim_supp=dist_ndim_supp + 1,
+        )
+        if support_shape is not None:
+            steps = support_shape[-dist_ndim_supp - 1]
+        return steps
+
     @classmethod
     def rv_op(cls, init_dist, innovation_dist, steps, size=None):
         if not steps.ndim == 0 or not steps.dtype.startswith("int"):
             raise ValueError("steps must be an integer scalar (ndim=0).")

+        dist_ndim_supp = init_dist.owner.op.ndim_supp
+        init_dist_shape = tuple(init_dist.shape)
+        init_dist_batch_shape = init_dist_shape[: len(init_dist_shape) - dist_ndim_supp]
+        innovation_dist_shape = tuple(innovation_dist.shape)
+        innovation_batch_shape = innovation_dist_shape[
+            : len(innovation_dist_shape) - dist_ndim_supp
+        ]
+
+        ndim_supp = dist_ndim_supp + 1
+
         # If not explicit, size is determined by the shapes of the input distributions
         if size is None:
-            size = at.broadcast_shape(init_dist, at.atleast_1d(innovation_dist)[..., 0])
-        innovation_size = tuple(size) + (steps,)
+            size = at.broadcast_shape(
+                init_dist_batch_shape, innovation_batch_shape, arrays_are_shapes=True
+            )

-        # Resize input distributions
-        init_dist = change_dist_size(init_dist, size)
-        innovation_dist = change_dist_size(innovation_dist, innovation_size)
+        # Resize input distributions. We will size them to (T, B, S) in order
+        # to safely take random draws. We later swap the steps dimension so
+        # that the final distribution will follow (B, T, S)
+        # init_dist must have shape (1, B, S)
+        init_dist = change_dist_size(init_dist, (1, *size))
+        # innovation_dist must have shape (T-1, B, S)
+        innovation_dist = change_dist_size(innovation_dist, (steps, *size))

         # Create SymbolicRV
         init_dist_, innovation_dist_, steps_ = (
             init_dist.type(),
             innovation_dist.type(),
             steps.type(),
         )
-        grw_ = at.concatenate([init_dist_[..., None], innovation_dist_], axis=-1)
-        grw_ = at.cumsum(grw_, axis=-1)
+        # Aeppl can only infer the logp of a dimshuffled variable if the dimshuffle is
+        # done directly on top of a RandomVariable. Because of this we dimshuffle the
+        # distributions and only then concatenate them, instead of the other way around.
+        # shape = (B, 1, S)
+        init_dist_dimswapped_ = at.moveaxis(init_dist_, 0, -ndim_supp)
+        # shape = (B, T-1, S)
+        innovation_dist_dimswapped_ = at.moveaxis(innovation_dist_, 0, -ndim_supp)
+        # shape = (B, T, S)
+        grw_ = at.concatenate([init_dist_dimswapped_, innovation_dist_dimswapped_], axis=-ndim_supp)
+        grw_ = at.cumsum(grw_, axis=-ndim_supp)
         return RandomWalkRV(
             [init_dist_, innovation_dist_, steps_],
             # We pass steps_ through just so we can keep a reference to it, even though
             # it's no longer needed at this point
             [grw_, steps_],
-            ndim_supp=1,
+            ndim_supp=ndim_supp,
         )(init_dist, innovation_dist, steps)


 @_change_dist_size.register(RandomWalkRV)
 def change_random_walk_size(op, dist, new_size, expand):
     init_dist, innovation_dist, steps = dist.owner.inputs
     if expand:
-        old_size = init_dist.shape
+        old_shape = tuple(dist.shape)
+        old_size = old_shape[: len(old_shape) - op.ndim_supp]
         new_size = tuple(new_size) + tuple(old_size)
     return RandomWalk.rv_op(init_dist, innovation_dist, steps, size=new_size)


 @_moment.register(RandomWalkRV)
 def random_walk_moment(op, rv, init_dist, innovation_dist, steps):
-    grw_moment = at.zeros_like(rv)
-    grw_moment = at.set_subtensor(grw_moment[..., 0], moment(init_dist))
-    grw_moment = at.set_subtensor(grw_moment[..., 1:], moment(innovation_dist))
-    return at.cumsum(grw_moment, axis=-1)
+    # shape = (1, B, S)
+    init_moment = moment(init_dist)
+    # shape = (T-1, B, S)
+    innovation_moment = moment(innovation_dist)
+    # shape = (T, B, S)
+    grw_moment = at.concatenate([init_moment, innovation_moment], axis=0)
+    grw_moment = at.cumsum(grw_moment, axis=0)
+    # shape = (B, T, S)
+    grw_moment = at.moveaxis(grw_moment, 0, -op.ndim_supp)
+    return grw_moment


 @_logprob.register(RandomWalkRV)
@@ -173,7 +238,25 @@ def random_walk_logp(op, values, *inputs, **kwargs):
     return logp(rv, value).sum(axis=-1)


-class GaussianRandomWalk:
+class PredefinedRandomWalk(ABCMeta):
+    """Base class for predefined RandomWalk distributions"""
+
+    def __new__(cls, name, *args, **kwargs):
+        init_dist, innovation_dist, kwargs = cls.get_dists(*args, **kwargs)
+        return RandomWalk(name, init_dist=init_dist, innovation_dist=innovation_dist, **kwargs)
+
+    @classmethod
+    def dist(cls, *args, **kwargs) -> at.TensorVariable:
+        init_dist, innovation_dist, kwargs = cls.get_dists(*args, **kwargs)
+        return RandomWalk.dist(init_dist=init_dist, innovation_dist=innovation_dist, **kwargs)
+
+    @classmethod
+    @abc.abstractmethod
+    def get_dists(cls, *args, **kwargs):
+        pass
+
+
+class GaussianRandomWalk(PredefinedRandomWalk):
     r"""Random Walk with Normal innovations.

     Parameters
@@ -186,40 +269,22 @@ class GaussianRandomWalk:
         Unnamed univariate distribution of the initial value. Unnamed refers to distributions
         created with the ``.dist()`` API.

-        .. warning:: init will be cloned, rendering them independent of the ones passed as input.
+        .. warning:: init_dist will be cloned, rendering it independent of the one passed as input.

     steps : int, optional
         Number of steps in Gaussian Random Walk (steps > 0). Only needed if shape is not
         provided.
     """

-    def __new__(cls, name, mu=0.0, sigma=1.0, *, init_dist=None, steps=None, **kwargs):
-        init_dist, innovation_dist, kwargs = cls.get_dists(
-            mu=mu, sigma=sigma, init_dist=init_dist, **kwargs
-        )
-        return RandomWalk(
-            name, init_dist=init_dist, innovation_dist=innovation_dist, steps=steps, **kwargs
-        )
-
-    @classmethod
-    def dist(cls, mu=0.0, sigma=1.0, *, init_dist=None, steps=None, **kwargs) -> at.TensorVariable:
-        init_dist, innovation_dist, kwargs = cls.get_dists(
-            mu=mu, sigma=sigma, init_dist=init_dist, **kwargs
-        )
-        return RandomWalk.dist(
-            init_dist=init_dist, innovation_dist=innovation_dist, steps=steps, **kwargs
-        )
-
     @classmethod
-    def get_dists(cls, *, mu, sigma, init_dist, **kwargs):
+    def get_dists(cls, mu=0.0, sigma=1.0, *, init_dist=None, **kwargs):
         if "init" in kwargs:
             warnings.warn(
                 "init parameter is now called init_dist. Using init will raise an error in a future release.",
                 FutureWarning,
             )
             init_dist = kwargs.pop("init")

-        # If no scalar distribution is passed then initialize with a Normal of same mu and sigma
         if init_dist is None:
             warnings.warn(
                 "Initial distribution not specified, defaulting to `Normal.dist(0, 100)`."
@@ -228,11 +293,9 @@ def get_dists(cls, *, mu, sigma, init_dist, **kwargs):
             )
             init_dist = Normal.dist(0, 100)

-        # Add one dimension to the right, so that mu and sigma broadcast safely along
-        # the steps dimension
         mu = at.as_tensor_variable(mu)
         sigma = at.as_tensor_variable(sigma)
-        innovation_dist = Normal.dist(mu=mu[..., None], sigma=sigma[..., None])
+        innovation_dist = Normal.dist(mu=mu, sigma=sigma)

         return init_dist, innovation_dist, kwargs
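
For context, a minimal sketch of how the generalized RandomWalk can be exercised after this change. It assumes a pymc version matching this commit (the aesara-based 4.x series) and that pm.draw is available; steps innovations plus the prepended init value give steps + 1 points per walk:

import numpy as np
import pymc as pm
from pymc.distributions.timeseries import RandomWalk

# Univariate walk: the innovations have ndim_supp == 0, so the walk has
# ndim_supp == 1 and draws have shape (steps + 1,)
rw = RandomWalk.dist(
    init_dist=pm.Normal.dist(0, 100),
    innovation_dist=pm.Normal.dist(0, 1),
    steps=10,
)
print(pm.draw(rw).shape)  # (11,)

# Multivariate walk, possible now that the univariate constraint is lifted:
# MvNormal innovations have ndim_supp == 1, so the walk has ndim_supp == 2
# and draws follow the (T, S) layout described in rv_op above
mv_rw = RandomWalk.dist(
    init_dist=pm.MvNormal.dist(mu=np.zeros(3), cov=np.eye(3)),
    innovation_dist=pm.MvNormal.dist(mu=np.zeros(3), cov=np.eye(3)),
    steps=10,
)
print(pm.draw(mv_rw).shape)  # (11, 3)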
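
Because PredefinedRandomWalk factors the __new__/dist boilerplate out of GaussianRandomWalk, a new predefined walk only has to supply get_dists. A hypothetical StudentTRandomWalk, purely illustrative and not part of this commit, would follow the same pattern (assuming StudentT is imported alongside Normal in this module):

class StudentTRandomWalk(PredefinedRandomWalk):
    """Random walk with StudentT innovations (illustrative sketch only)."""

    @classmethod
    def get_dists(cls, nu, mu=0.0, sigma=1.0, *, init_dist=None, **kwargs):
        # Same default init as GaussianRandomWalk.get_dists above
        if init_dist is None:
            init_dist = Normal.dist(0, 100)
        innovation_dist = StudentT.dist(nu=nu, mu=mu, sigma=sigma)
        return init_dist, innovation_dist, kwargs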
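
The rewritten random_walk_moment is easy to sanity-check: the walk's moment is the cumulative sum of the component moments along the leading time axis, moved into place afterwards. With a Normal(0, 100) init (moment 0) and Normal(1, 1) innovations (moment 1) over 3 steps, the moment should be [0, 1, 2, 3]; the import path for the moment dispatcher is assumed from this era of the codebase:

import pymc as pm
from pymc.distributions.distribution import moment
from pymc.distributions.timeseries import RandomWalk

rw = RandomWalk.dist(
    init_dist=pm.Normal.dist(0, 100),      # moment 0
    innovation_dist=pm.Normal.dist(1, 1),  # moment 1
    steps=3,
)
print(moment(rw).eval())  # [0. 1. 2. 3.]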