# some transforms (stick breaking) require addition of small slack in order to be numerically
# stable. The minimal addable slack for float32 is higher thus we need to be less strict
-tol = 1e-7 if pytensor.config.floatX == "float64" else 1e-6
+tol = 1e-7 if pytensor.config.floatX == "float64" else 1e-5


-def check_transform(transform, domain, constructor=pt.dscalar, test=0, rv_var=None):
+def check_transform(transform, domain, constructor=pt.scalar, test=0, rv_var=None):
    x = constructor("x")
    x.tag.test_value = test
    if rv_var is None:
@@ -57,18 +57,20 @@ def check_transform(transform, domain, constructor=pt.dscalar, test=0, rv_var=No
    # FIXME: What's being tested here? That the transformed graph can compile?
    forward_f = pytensor.function([x], transform.forward(x, *rv_inputs))
    # test transform identity
-    identity_f = pytensor.function(
-        [x], transform.backward(transform.forward(x, *rv_inputs), *rv_inputs)
-    )
+    z = transform.backward(transform.forward(x, *rv_inputs))
+    assert z.type == x.type
+    identity_f = pytensor.function([x], z, *rv_inputs)
    for val in domain.vals:
        close_to(val, identity_f(val), tol)


def check_vector_transform(transform, domain, rv_var=None):
-    return check_transform(transform, domain, pt.dvector, test=np.array([0, 0]), rv_var=rv_var)
+    return check_transform(
+        transform, domain, pt.vector, test=floatX(np.array([0, 0])), rv_var=rv_var
+    )


-def get_values(transform, domain=R, constructor=pt.dscalar, test=0, rv_var=None):
+def get_values(transform, domain=R, constructor=pt.scalar, test=0, rv_var=None):
    x = constructor("x")
    x.tag.test_value = test
    if rv_var is None:
@@ -81,7 +83,7 @@ def get_values(transform, domain=R, constructor=pt.dscalar, test=0, rv_var=None)
def check_jacobian_det(
    transform,
    domain,
-    constructor=pt.dscalar,
+    constructor=pt.scalar,
    test=0,
    make_comparable=None,
    elemwise=False,
@@ -119,22 +121,26 @@ def test_simplex():
    check_vector_transform(tr.simplex, Simplex(2))
    check_vector_transform(tr.simplex, Simplex(4))

-    check_transform(tr.simplex, MultiSimplex(3, 2), constructor=pt.dmatrix, test=np.zeros((2, 2)))
+    check_transform(
+        tr.simplex, MultiSimplex(3, 2), constructor=pt.matrix, test=floatX(np.zeros((2, 2)))
+    )


def test_simplex_bounds():
-    vals = get_values(tr.simplex, Vector(R, 2), pt.dvector, np.array([0, 0]))
+    vals = get_values(tr.simplex, Vector(R, 2), pt.vector, floatX(np.array([0, 0])))

    close_to(vals.sum(axis=1), 1, tol)
    close_to_logical(vals > 0, True, tol)
    close_to_logical(vals < 1, True, tol)

-    check_jacobian_det(tr.simplex, Vector(R, 2), pt.dvector, np.array([0, 0]), lambda x: x[:-1])
+    check_jacobian_det(
+        tr.simplex, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), lambda x: x[:-1]
+    )


def test_simplex_accuracy():
-    val = np.array([-30])
-    x = pt.dvector("x")
+    val = floatX(np.array([-30]))
+    x = pt.vector("x")
    x.tag.test_value = val
    identity_f = pytensor.function([x], tr.simplex.forward(x, tr.simplex.backward(x, x)))
    close_to(val, identity_f(val), tol)
@@ -148,28 +154,39 @@ def test_sum_to_1():
    tr.SumTo1(2)

    check_jacobian_det(
-        tr.univariate_sum_to_1, Vector(Unit, 2), pt.dvector, np.array([0, 0]), lambda x: x[:-1]
+        tr.univariate_sum_to_1,
+        Vector(Unit, 2),
+        pt.vector,
+        floatX(np.array([0, 0])),
+        lambda x: x[:-1],
    )
    check_jacobian_det(
-        tr.multivariate_sum_to_1, Vector(Unit, 2), pt.dvector, np.array([0, 0]), lambda x: x[:-1]
+        tr.multivariate_sum_to_1,
+        Vector(Unit, 2),
+        pt.vector,
+        floatX(np.array([0, 0])),
+        lambda x: x[:-1],
    )


def test_log():
    check_transform(tr.log, Rplusbig)

    check_jacobian_det(tr.log, Rplusbig, elemwise=True)
-    check_jacobian_det(tr.log, Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=True)
+    check_jacobian_det(tr.log, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True)

    vals = get_values(tr.log)
    close_to_logical(vals > 0, True, tol)


+@pytest.mark.skipif(
+    pytensor.config.floatX == "float32", reason="Test is designed for 64bit precision"
+)
def test_log_exp_m1():
    check_transform(tr.log_exp_m1, Rplusbig)

    check_jacobian_det(tr.log_exp_m1, Rplusbig, elemwise=True)
-    check_jacobian_det(tr.log_exp_m1, Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=True)
+    check_jacobian_det(tr.log_exp_m1, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True)

    vals = get_values(tr.log_exp_m1)
    close_to_logical(vals > 0, True, tol)
@@ -179,7 +196,7 @@ def test_logodds():
    check_transform(tr.logodds, Unit)

    check_jacobian_det(tr.logodds, Unit, elemwise=True)
-    check_jacobian_det(tr.logodds, Vector(Unit, 2), pt.dvector, [0.5, 0.5], elemwise=True)
+    check_jacobian_det(tr.logodds, Vector(Unit, 2), pt.vector, [0.5, 0.5], elemwise=True)

    vals = get_values(tr.logodds)
    close_to_logical(vals > 0, True, tol)
@@ -191,7 +208,7 @@ def test_lowerbound():
    check_transform(trans, Rplusbig)

    check_jacobian_det(trans, Rplusbig, elemwise=True)
-    check_jacobian_det(trans, Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=True)
+    check_jacobian_det(trans, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True)

    vals = get_values(trans)
    close_to_logical(vals > 0, True, tol)
@@ -202,7 +219,7 @@ def test_upperbound():
    check_transform(trans, Rminusbig)

    check_jacobian_det(trans, Rminusbig, elemwise=True)
-    check_jacobian_det(trans, Vector(Rminusbig, 2), pt.dvector, [-1, -1], elemwise=True)
+    check_jacobian_det(trans, Vector(Rminusbig, 2), pt.vector, [-1, -1], elemwise=True)

    vals = get_values(trans)
    close_to_logical(vals < 0, True, tol)
@@ -234,7 +251,7 @@ def test_interval_near_boundary():
        pm.Uniform("x", initval=x0, lower=lb, upper=ub)

    log_prob = model.point_logps()
-    np.testing.assert_allclose(list(log_prob.values()), np.array([-52.68]))
+    np.testing.assert_allclose(list(log_prob.values()), floatX(np.array([-52.68])))


def test_circular():
@@ -257,19 +274,19 @@ def test_ordered():
    tr.Ordered(2)

    check_jacobian_det(
-        tr.univariate_ordered, Vector(R, 2), pt.dvector, np.array([0, 0]), elemwise=False
+        tr.univariate_ordered, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), elemwise=False
    )
    check_jacobian_det(
-        tr.multivariate_ordered, Vector(R, 2), pt.dvector, np.array([0, 0]), elemwise=False
+        tr.multivariate_ordered, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), elemwise=False
    )

-    vals = get_values(tr.univariate_ordered, Vector(R, 3), pt.dvector, np.zeros(3))
+    vals = get_values(tr.univariate_ordered, Vector(R, 3), pt.vector, floatX(np.zeros(3)))
    close_to_logical(np.diff(vals) >= 0, True, tol)


def test_chain_values():
    chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
-    vals = get_values(chain_tranf, Vector(R, 5), pt.dvector, np.zeros(5))
+    vals = get_values(chain_tranf, Vector(R, 5), pt.vector, floatX(np.zeros(5)))
    close_to_logical(np.diff(vals) >= 0, True, tol)


@@ -281,7 +298,7 @@ def test_chain_vector_transform():
@pytest.mark.xfail(reason="Fails due to precision issue. Values just close to expected.")
def test_chain_jacob_det():
    chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
-    check_jacobian_det(chain_tranf, Vector(R, 4), pt.dvector, np.zeros(4), elemwise=False)
+    check_jacobian_det(chain_tranf, Vector(R, 4), pt.vector, floatX(np.zeros(4)), elemwise=False)


class TestElementWiseLogp(SeededTest):
@@ -432,7 +449,7 @@ def transform_params(*inputs):
        [
            (0.0, 1.0, 2.0, 2),
            (-10, 0, 200, (2, 3)),
-            (np.zeros(3), np.ones(3), np.ones(3), (4, 3)),
+            (floatX(np.zeros(3)), floatX(np.ones(3)), floatX(np.ones(3)), (4, 3)),
        ],
    )
    def test_triangular(self, lower, c, upper, size):
@@ -449,7 +466,8 @@ def transform_params(*inputs):
        self.check_transform_elementwise_logp(model)

    @pytest.mark.parametrize(
-        "mu,kappa,size", [(0.0, 1.0, 2), (-0.5, 5.5, (2, 3)), (np.zeros(3), np.ones(3), (4, 3))]
+        "mu,kappa,size",
+        [(0.0, 1.0, 2), (-0.5, 5.5, (2, 3)), (floatX(np.zeros(3)), floatX(np.ones(3)), (4, 3))],
    )
    def test_vonmises(self, mu, kappa, size):
        model = self.build_model(
@@ -549,7 +567,9 @@ def transform_params(*inputs):
        )
        self.check_vectortransform_elementwise_logp(model)

-    @pytest.mark.parametrize("mu,kappa,size", [(0.0, 1.0, (2,)), (np.zeros(3), np.ones(3), (4, 3))])
+    @pytest.mark.parametrize(
+        "mu,kappa,size", [(0.0, 1.0, (2,)), (floatX(np.zeros(3)), floatX(np.ones(3)), (4, 3))]
+    )
    def test_vonmises_ordered(self, mu, kappa, size):
        initval = np.sort(np.abs(np.random.rand(*size)))
        model = self.build_model(
@@ -566,7 +586,12 @@ def test_vonmises_ordered(self, mu, kappa, size):
        [
            (0.0, 1.0, (2,), tr.simplex),
            (0.5, 5.5, (2, 3), tr.simplex),
-            (np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.univariate_sum_to_1, tr.logodds])),
+            (
+                floatX(np.zeros(3)),
+                floatX(np.ones(3)),
+                (4, 3),
+                tr.Chain([tr.univariate_sum_to_1, tr.logodds]),
+            ),
        ],
    )
    def test_uniform_other(self, lower, upper, size, transform):
@@ -583,8 +608,8 @@ def test_uniform_other(self, lower, upper, size, transform):
    @pytest.mark.parametrize(
        "mu,cov,size,shape",
        [
-            (np.zeros(2), np.diag(np.ones(2)), None, (2,)),
-            (np.zeros(3), np.diag(np.ones(3)), (4,), (4, 3)),
+            (floatX(np.zeros(2)), floatX(np.diag(np.ones(2))), None, (2,)),
+            (floatX(np.zeros(3)), floatX(np.diag(np.ones(3))), (4,), (4, 3)),
        ],
    )
    def test_mvnormal_ordered(self, mu, cov, size, shape):
@@ -643,7 +668,7 @@ def test_2d_univariate_ordered():
    )

    log_p = model.compile_logp(sum=False)(
-        {"x_1d_ordered__": np.zeros((4,)), "x_2d_ordered__": np.zeros((10, 4))}
+        {"x_1d_ordered__": floatX(np.zeros((4,))), "x_2d_ordered__": floatX(np.zeros((10, 4)))}
    )
    np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1])

@@ -667,7 +692,7 @@ def test_2d_multivariate_ordered():
    )

    log_p = model.compile_logp(sum=False)(
-        {"x_1d_ordered__": np.zeros((2,)), "x_2d_ordered__": np.zeros((2, 2))}
+        {"x_1d_ordered__": floatX(np.zeros((2,))), "x_2d_ordered__": floatX(np.zeros((2, 2)))}
    )
    np.testing.assert_allclose(log_p[0], log_p[1])

@@ -690,7 +715,7 @@ def test_2d_univariate_sum_to_1():
    )

    log_p = model.compile_logp(sum=False)(
-        {"x_1d_sumto1__": np.zeros(3), "x_2d_sumto1__": np.zeros((10, 3))}
+        {"x_1d_sumto1__": floatX(np.zeros(3)), "x_2d_sumto1__": floatX(np.zeros((10, 3)))}
    )
    np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1])

@@ -712,6 +737,6 @@ def test_2d_multivariate_sum_to_1():
    )

    log_p = model.compile_logp(sum=False)(
-        {"x_1d_sumto1__": np.zeros(1), "x_2d_sumto1__": np.zeros((2, 1))}
+        {"x_1d_sumto1__": floatX(np.zeros(1)), "x_2d_sumto1__": floatX(np.zeros((2, 1)))}
    )
    np.testing.assert_allclose(log_p[0], log_p[1])
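
Note on the pattern applied throughout this diff (a minimal sketch, not part of the PR itself): the change swaps the dtype-fixed constructors (pt.dscalar, pt.dvector, pt.dmatrix) for dtype-agnostic ones (pt.scalar, pt.vector, pt.matrix) and casts test values with floatX, so the tests also run when pytensor.config.floatX is "float32". The sketch below assumes floatX is pymc.pytensorf.floatX; it also illustrates why the identity-check tolerance is relaxed for float32.

    # Minimal sketch (assumptions: floatX is pymc.pytensorf.floatX, pt is pytensor.tensor).
    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    from pymc.pytensorf import floatX

    # pt.vector follows pytensor.config.floatX; pt.dvector is always float64.
    x = pt.vector("x")
    assert x.dtype == pytensor.config.floatX

    # Casting test values the same way avoids dtype mismatches when the
    # compiled function expects float32 inputs.
    test_val = floatX(np.array([0.0, 0.0]))
    assert test_val.dtype == pytensor.config.floatX

    # Why a looser tolerance under float32: single precision has a much
    # larger machine epsilon than double precision.
    print(np.finfo(np.float32).eps)  # ~1.2e-07
    print(np.finfo(np.float64).eps)  # ~2.2e-16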