Commit 14e673f

Better coverage for float32 tests (#6780)

* create a failing test
* fix the bug
* simplify
* add float32 test to transforms

Parent commit: f91dd1c

File tree: 3 files changed (+67, -38 lines)

.github/workflows/tests.yml (+1, -1)

@@ -413,7 +413,7 @@ jobs:
         floatx: [float32]
         python-version: ["3.11"]
         test-subset:
-          - tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py
+          - tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py tests/distributions/test_transform.py
       fail-fast: false
     runs-on: ${{ matrix.os }}
     env:

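The workflow change only extends the float32 job's test subset; the matrix entry floatx: [float32] is what switches that job to single precision (how the workflow maps the value onto PyTensor's configuration is not shown in this diff). A rough local-reproduction sketch, assuming PYTENSOR_FLAGS is honored before PyTensor is first imported:

import os

# Assumption: setting PYTENSOR_FLAGS before anything imports pytensor mirrors
# the CI job's floatx: [float32] matrix entry.
os.environ.setdefault("PYTENSOR_FLAGS", "floatX=float32")

import pytest

# Run the newly added test module under the float32 configuration.
pytest.main(["tests/distributions/test_transform.py", "-q"])
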
pymc/logprob/transforms.py (+5, -1)

@@ -958,8 +958,10 @@ class SimplexTransform(RVTransform):
     name = "simplex"

     def forward(self, value, *inputs):
+        value = pt.as_tensor(value)
         log_value = pt.log(value)
-        shift = pt.sum(log_value, -1, keepdims=True) / value.shape[-1]
+        N = value.shape[-1].astype(value.dtype)
+        shift = pt.sum(log_value, -1, keepdims=True) / N
         return log_value[..., :-1] - shift

     def backward(self, value, *inputs):
@@ -968,7 +970,9 @@ def backward(self, value, *inputs):
         return exp_value_max / pt.sum(exp_value_max, -1, keepdims=True)

     def log_jac_det(self, value, *inputs):
+        value = pt.as_tensor(value)
         N = value.shape[-1] + 1
+        N = N.astype(value.dtype)
         sum_value = pt.sum(value, -1, keepdims=True)
         value_sum_expanded = value + sum_value
         value_sum_expanded = pt.concatenate([value_sum_expanded, pt.zeros(sum_value.shape)], -1)

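The fix casts the symbolic length to the value's dtype before dividing. Without the cast, the int64 shape element promotes a float32 numerator to float64, so SimplexTransform.forward (and log_jac_det) returns a different dtype than it was given. A minimal sketch of that promotion, assuming PyTensor's NumPy-style upcasting rules:

import pytensor.tensor as pt

value = pt.vector("value", dtype="float32")
log_value = pt.log(value)

# Old behaviour: dividing by the raw int64 shape element upcasts the result.
shift_old = pt.sum(log_value, -1, keepdims=True) / value.shape[-1]
# Fixed behaviour: cast the length to the value's dtype first.
shift_new = pt.sum(log_value, -1, keepdims=True) / value.shape[-1].astype(value.dtype)

print(shift_old.dtype)  # float64
print(shift_new.dtype)  # float32
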
tests/distributions/test_transform.py (+61, -36)

@@ -44,10 +44,10 @@

 # some transforms (stick breaking) require addition of small slack in order to be numerically
 # stable. The minimal addable slack for float32 is higher thus we need to be less strict
-tol = 1e-7 if pytensor.config.floatX == "float64" else 1e-6
+tol = 1e-7 if pytensor.config.floatX == "float64" else 1e-5


-def check_transform(transform, domain, constructor=pt.dscalar, test=0, rv_var=None):
+def check_transform(transform, domain, constructor=pt.scalar, test=0, rv_var=None):
     x = constructor("x")
     x.tag.test_value = test
     if rv_var is None:
@@ -57,18 +57,20 @@ def check_transform(transform, domain, constructor=pt.dscalar, test=0, rv_var=No
     # FIXME: What's being tested here? That the transformed graph can compile?
     forward_f = pytensor.function([x], transform.forward(x, *rv_inputs))
     # test transform identity
-    identity_f = pytensor.function(
-        [x], transform.backward(transform.forward(x, *rv_inputs), *rv_inputs)
-    )
+    z = transform.backward(transform.forward(x, *rv_inputs))
+    assert z.type == x.type
+    identity_f = pytensor.function([x], z, *rv_inputs)
     for val in domain.vals:
         close_to(val, identity_f(val), tol)


 def check_vector_transform(transform, domain, rv_var=None):
-    return check_transform(transform, domain, pt.dvector, test=np.array([0, 0]), rv_var=rv_var)
+    return check_transform(
+        transform, domain, pt.vector, test=floatX(np.array([0, 0])), rv_var=rv_var
+    )


-def get_values(transform, domain=R, constructor=pt.dscalar, test=0, rv_var=None):
+def get_values(transform, domain=R, constructor=pt.scalar, test=0, rv_var=None):
     x = constructor("x")
     x.tag.test_value = test
     if rv_var is None:
@@ -81,7 +83,7 @@ def get_values(transform, domain=R, constructor=pt.dscalar, test=0, rv_var=None)
 def check_jacobian_det(
     transform,
     domain,
-    constructor=pt.dscalar,
+    constructor=pt.scalar,
     test=0,
     make_comparable=None,
     elemwise=False,
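
These helper changes carry most of the float32 coverage: the d-prefixed constructors (pt.dscalar, pt.dvector, pt.dmatrix) always build float64 variables, while pt.scalar, pt.vector and pt.matrix follow pytensor.config.floatX, and the test values are wrapped in floatX so they agree with the symbolic inputs. The new assert z.type == x.type is the failing test from the commit message: before the transforms.py fix, the simplex round trip silently came back as float64 on the float32 job. A small illustration (the floatX import path is an assumption; the test module may obtain it differently):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import floatX  # assumed location of the helper used here

print(pt.dvector("x").dtype)  # always "float64"
print(pt.vector("x").dtype)   # follows pytensor.config.floatX

# floatX casts NumPy inputs to the configured precision so test values and
# symbolic variables agree on dtype.
print(floatX(np.zeros(3)).dtype == np.dtype(pytensor.config.floatX))  # True
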
@@ -119,22 +121,26 @@ def test_simplex():
     check_vector_transform(tr.simplex, Simplex(2))
     check_vector_transform(tr.simplex, Simplex(4))

-    check_transform(tr.simplex, MultiSimplex(3, 2), constructor=pt.dmatrix, test=np.zeros((2, 2)))
+    check_transform(
+        tr.simplex, MultiSimplex(3, 2), constructor=pt.matrix, test=floatX(np.zeros((2, 2)))
+    )


 def test_simplex_bounds():
-    vals = get_values(tr.simplex, Vector(R, 2), pt.dvector, np.array([0, 0]))
+    vals = get_values(tr.simplex, Vector(R, 2), pt.vector, floatX(np.array([0, 0])))

     close_to(vals.sum(axis=1), 1, tol)
     close_to_logical(vals > 0, True, tol)
     close_to_logical(vals < 1, True, tol)

-    check_jacobian_det(tr.simplex, Vector(R, 2), pt.dvector, np.array([0, 0]), lambda x: x[:-1])
+    check_jacobian_det(
+        tr.simplex, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), lambda x: x[:-1]
+    )


 def test_simplex_accuracy():
-    val = np.array([-30])
-    x = pt.dvector("x")
+    val = floatX(np.array([-30]))
+    x = pt.vector("x")
     x.tag.test_value = val
     identity_f = pytensor.function([x], tr.simplex.forward(x, tr.simplex.backward(x, x)))
     close_to(val, identity_f(val), tol)
@@ -148,28 +154,39 @@ def test_sum_to_1():
     tr.SumTo1(2)

     check_jacobian_det(
-        tr.univariate_sum_to_1, Vector(Unit, 2), pt.dvector, np.array([0, 0]), lambda x: x[:-1]
+        tr.univariate_sum_to_1,
+        Vector(Unit, 2),
+        pt.vector,
+        floatX(np.array([0, 0])),
+        lambda x: x[:-1],
     )
     check_jacobian_det(
-        tr.multivariate_sum_to_1, Vector(Unit, 2), pt.dvector, np.array([0, 0]), lambda x: x[:-1]
+        tr.multivariate_sum_to_1,
+        Vector(Unit, 2),
+        pt.vector,
+        floatX(np.array([0, 0])),
+        lambda x: x[:-1],
     )


 def test_log():
     check_transform(tr.log, Rplusbig)

     check_jacobian_det(tr.log, Rplusbig, elemwise=True)
-    check_jacobian_det(tr.log, Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=True)
+    check_jacobian_det(tr.log, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True)

     vals = get_values(tr.log)
     close_to_logical(vals > 0, True, tol)


+@pytest.mark.skipif(
+    pytensor.config.floatX == "float32", reason="Test is designed for 64bit precision"
+)
 def test_log_exp_m1():
     check_transform(tr.log_exp_m1, Rplusbig)

     check_jacobian_det(tr.log_exp_m1, Rplusbig, elemwise=True)
-    check_jacobian_det(tr.log_exp_m1, Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=True)
+    check_jacobian_det(tr.log_exp_m1, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True)

     vals = get_values(tr.log_exp_m1)
     close_to_logical(vals > 0, True, tol)
@@ -179,7 +196,7 @@ def test_logodds():
     check_transform(tr.logodds, Unit)

     check_jacobian_det(tr.logodds, Unit, elemwise=True)
-    check_jacobian_det(tr.logodds, Vector(Unit, 2), pt.dvector, [0.5, 0.5], elemwise=True)
+    check_jacobian_det(tr.logodds, Vector(Unit, 2), pt.vector, [0.5, 0.5], elemwise=True)

     vals = get_values(tr.logodds)
     close_to_logical(vals > 0, True, tol)
@@ -191,7 +208,7 @@ def test_lowerbound():
     check_transform(trans, Rplusbig)

     check_jacobian_det(trans, Rplusbig, elemwise=True)
-    check_jacobian_det(trans, Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=True)
+    check_jacobian_det(trans, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True)

     vals = get_values(trans)
     close_to_logical(vals > 0, True, tol)
@@ -202,7 +219,7 @@ def test_upperbound():
     check_transform(trans, Rminusbig)

     check_jacobian_det(trans, Rminusbig, elemwise=True)
-    check_jacobian_det(trans, Vector(Rminusbig, 2), pt.dvector, [-1, -1], elemwise=True)
+    check_jacobian_det(trans, Vector(Rminusbig, 2), pt.vector, [-1, -1], elemwise=True)

     vals = get_values(trans)
     close_to_logical(vals < 0, True, tol)
@@ -234,7 +251,7 @@ def test_interval_near_boundary():
         pm.Uniform("x", initval=x0, lower=lb, upper=ub)

     log_prob = model.point_logps()
-    np.testing.assert_allclose(list(log_prob.values()), np.array([-52.68]))
+    np.testing.assert_allclose(list(log_prob.values()), floatX(np.array([-52.68])))


 def test_circular():
@@ -257,19 +274,19 @@ def test_ordered():
     tr.Ordered(2)

     check_jacobian_det(
-        tr.univariate_ordered, Vector(R, 2), pt.dvector, np.array([0, 0]), elemwise=False
+        tr.univariate_ordered, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), elemwise=False
     )
     check_jacobian_det(
-        tr.multivariate_ordered, Vector(R, 2), pt.dvector, np.array([0, 0]), elemwise=False
+        tr.multivariate_ordered, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), elemwise=False
     )

-    vals = get_values(tr.univariate_ordered, Vector(R, 3), pt.dvector, np.zeros(3))
+    vals = get_values(tr.univariate_ordered, Vector(R, 3), pt.vector, floatX(np.zeros(3)))
     close_to_logical(np.diff(vals) >= 0, True, tol)


 def test_chain_values():
     chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
-    vals = get_values(chain_tranf, Vector(R, 5), pt.dvector, np.zeros(5))
+    vals = get_values(chain_tranf, Vector(R, 5), pt.vector, floatX(np.zeros(5)))
     close_to_logical(np.diff(vals) >= 0, True, tol)


@@ -281,7 +298,7 @@ def test_chain_vector_transform():
 @pytest.mark.xfail(reason="Fails due to precision issue. Values just close to expected.")
 def test_chain_jacob_det():
     chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
-    check_jacobian_det(chain_tranf, Vector(R, 4), pt.dvector, np.zeros(4), elemwise=False)
+    check_jacobian_det(chain_tranf, Vector(R, 4), pt.vector, floatX(np.zeros(4)), elemwise=False)


 class TestElementWiseLogp(SeededTest):
@@ -432,7 +449,7 @@ def transform_params(*inputs):
         [
             (0.0, 1.0, 2.0, 2),
             (-10, 0, 200, (2, 3)),
-            (np.zeros(3), np.ones(3), np.ones(3), (4, 3)),
+            (floatX(np.zeros(3)), floatX(np.ones(3)), floatX(np.ones(3)), (4, 3)),
         ],
     )
     def test_triangular(self, lower, c, upper, size):
@@ -449,7 +466,8 @@ def transform_params(*inputs):
         self.check_transform_elementwise_logp(model)

     @pytest.mark.parametrize(
-        "mu,kappa,size", [(0.0, 1.0, 2), (-0.5, 5.5, (2, 3)), (np.zeros(3), np.ones(3), (4, 3))]
+        "mu,kappa,size",
+        [(0.0, 1.0, 2), (-0.5, 5.5, (2, 3)), (floatX(np.zeros(3)), floatX(np.ones(3)), (4, 3))],
     )
     def test_vonmises(self, mu, kappa, size):
         model = self.build_model(
@@ -549,7 +567,9 @@ def transform_params(*inputs):
         )
         self.check_vectortransform_elementwise_logp(model)

-    @pytest.mark.parametrize("mu,kappa,size", [(0.0, 1.0, (2,)), (np.zeros(3), np.ones(3), (4, 3))])
+    @pytest.mark.parametrize(
+        "mu,kappa,size", [(0.0, 1.0, (2,)), (floatX(np.zeros(3)), floatX(np.ones(3)), (4, 3))]
+    )
     def test_vonmises_ordered(self, mu, kappa, size):
         initval = np.sort(np.abs(np.random.rand(*size)))
         model = self.build_model(
@@ -566,7 +586,12 @@ def test_vonmises_ordered(self, mu, kappa, size):
         [
             (0.0, 1.0, (2,), tr.simplex),
             (0.5, 5.5, (2, 3), tr.simplex),
-            (np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.univariate_sum_to_1, tr.logodds])),
+            (
+                floatX(np.zeros(3)),
+                floatX(np.ones(3)),
+                (4, 3),
+                tr.Chain([tr.univariate_sum_to_1, tr.logodds]),
+            ),
         ],
     )
     def test_uniform_other(self, lower, upper, size, transform):
@@ -583,8 +608,8 @@ def test_uniform_other(self, lower, upper, size, transform):
     @pytest.mark.parametrize(
         "mu,cov,size,shape",
         [
-            (np.zeros(2), np.diag(np.ones(2)), None, (2,)),
-            (np.zeros(3), np.diag(np.ones(3)), (4,), (4, 3)),
+            (floatX(np.zeros(2)), floatX(np.diag(np.ones(2))), None, (2,)),
+            (floatX(np.zeros(3)), floatX(np.diag(np.ones(3))), (4,), (4, 3)),
         ],
     )
     def test_mvnormal_ordered(self, mu, cov, size, shape):
@@ -643,7 +668,7 @@ def test_2d_univariate_ordered():
     )

     log_p = model.compile_logp(sum=False)(
-        {"x_1d_ordered__": np.zeros((4,)), "x_2d_ordered__": np.zeros((10, 4))}
+        {"x_1d_ordered__": floatX(np.zeros((4,))), "x_2d_ordered__": floatX(np.zeros((10, 4)))}
     )
     np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1])

@@ -667,7 +692,7 @@ def test_2d_multivariate_ordered():
     )

     log_p = model.compile_logp(sum=False)(
-        {"x_1d_ordered__": np.zeros((2,)), "x_2d_ordered__": np.zeros((2, 2))}
+        {"x_1d_ordered__": floatX(np.zeros((2,))), "x_2d_ordered__": floatX(np.zeros((2, 2)))}
     )
     np.testing.assert_allclose(log_p[0], log_p[1])

@@ -690,7 +715,7 @@ def test_2d_univariate_sum_to_1():
     )

     log_p = model.compile_logp(sum=False)(
-        {"x_1d_sumto1__": np.zeros(3), "x_2d_sumto1__": np.zeros((10, 3))}
+        {"x_1d_sumto1__": floatX(np.zeros(3)), "x_2d_sumto1__": floatX(np.zeros((10, 3)))}
     )
     np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1])

@@ -712,6 +737,6 @@ def test_2d_multivariate_sum_to_1():
     )

     log_p = model.compile_logp(sum=False)(
-        {"x_1d_sumto1__": np.zeros(1), "x_2d_sumto1__": np.zeros((2, 1))}
+        {"x_1d_sumto1__": floatX(np.zeros(1)), "x_2d_sumto1__": floatX(np.zeros((2, 1)))}
     )
     np.testing.assert_allclose(log_p[0], log_p[1])
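
Taken together with the transforms.py change, the round trip exercised by check_transform should now preserve the input dtype. A rough end-to-end sketch, not part of the test file (the tr alias mirrors what the test module appears to import):

import pytensor.tensor as pt
import pymc.distributions.transforms as tr  # assumed import behind the tests' "tr"

x = pt.vector("x", dtype="float32")
z = tr.simplex.backward(tr.simplex.forward(x))

# With the dtype cast in SimplexTransform, both should print float32.
print(x.dtype, z.dtype)
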
