Skip to content

Commit 3888d53

Browse files
committed
Fix bug in Truncated with identity inputs
When using `Deterministic`, variables get wrapped in an `identity` operation. When attempting to define an `icdf`, the logp graph rewrites would remove this useless operation from the graph of the underlying RV and cause a mismatch between explict and implicit inputs of the inner graph of TruncatedRV
1 parent 4e8e986 commit 3888d53

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

pymc/distributions/truncated.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ def rv_op(cls, dist, lower, upper, max_n_steps, *, size=None):
121121
rng=uniform_rng,
122122
size=rv.shape,
123123
).owner.outputs
124-
truncated_rv = icdf(rv, uniform, warn_rvs=False)
124+
# So icdf does not see the random graph of uniform
125+
uniform_type = uniform.type()
126+
truncated_rv = graph_replace(icdf(rv, uniform_type), {uniform_type: uniform})
125127
return TruncatedRV(
126128
base_rv_op=dist.owner.op,
127129
inputs=graph_inputs,

tests/distributions/test_truncated.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pytest
1818
import scipy
1919

20+
from pytensor.scalar import Identity
2021
from pytensor.tensor.random.basic import GeometricRV, NormalRV
2122
from pytensor.tensor.random.type import RandomType
2223

@@ -573,3 +574,14 @@ def maxwell_dist(scale, size):
573574
logp(trunc_x, test_value).eval(),
574575
expected_logp,
575576
)
577+
578+
579+
@pytest.mark.parametrize("dist_op", [icdf_normal, rejection_normal])
580+
def test_truncated_identity_input(dist_op):
581+
# Regression test for https://github.com/pymc-devs/pymc/issues/7312
582+
mu = Exponential.dist(scale=0.5)
583+
mu_identity = mu.copy()
584+
assert isinstance(mu_identity.owner.op.scalar_op, Identity)
585+
586+
rv_out = Truncated.dist(dist=dist_op(mu_identity, 5), lower=0, upper=1)
587+
assert np.ptp(draw(rv_out, draws=500)) < 1

0 commit comments

Comments
 (0)