Skip to content

Support logp derivation of power(base, rv) #6962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Oct 26, 2023
25 changes: 23 additions & 2 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
cleanup_ir_rewrites_db,
measurable_ir_rewrites_db,
)
from pymc.logprob.utils import check_potential_measurability
from pymc.logprob.utils import CheckParameterValue, check_potential_measurability


class TransformedVariable(Op):
Expand Down Expand Up @@ -617,6 +617,21 @@ def measurable_special_exp_to_exp(fgraph, node):
return [1 / (1 + pt.exp(-inp))]


@node_rewriter([pow])
def measurable_power_exponent_to_exp(fgraph, node):
"""Convert power(base, rv) of `MeasurableVariable`s to exp(log(base) * rv) form."""
base, inp_exponent = node.inputs

# When the base is measurable we have `power(rv, exponent)`, which should be handled by `PowerTransform` and needs no further rewrite.
# Here we change only the cases where exponent is measurable `power(base, rv)` which is not supported by the `PowerTransform`
if check_potential_measurability([base], fgraph.preserve_rv_mappings.rv_values.keys()):
return None

base = CheckParameterValue("base >= 0")(base, pt.all(pt.ge(base, 0.0)))

return [pt.exp(pt.log(base) * inp_exponent)]


@node_rewriter(
[
exp,
Expand Down Expand Up @@ -693,7 +708,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
try:
(power,) = other_inputs
power = pt.get_underlying_scalar_constant_value(power).item()
# Power needs to be a constant
# Power needs to be a constant, if not then proceed to the other case power(base, rv)
except NotScalarConstantError:
return None
transform_inputs = (measurable_input, power)
Expand Down Expand Up @@ -769,6 +784,12 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
"transform",
)

measurable_ir_rewrites_db.register(
"measurable_power_expotent_to_exp",
measurable_power_exponent_to_exp,
"basic",
"transform",
)

measurable_ir_rewrites_db.register(
"find_measurable_transforms",
Expand Down
56 changes: 56 additions & 0 deletions tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
TransformValuesMapping,
TransformValuesRewrite,
)
from pymc.logprob.utils import ParameterValueError
from pymc.testing import Rplusbig, Vector, assert_no_rvs
from tests.distributions.test_transform import check_jacobian_det

Expand Down Expand Up @@ -1159,6 +1160,61 @@ def test_special_log_exp_transforms(transform):
assert equal_computations([logp_test], [logp_ref])


def test_measurable_power_exponent_with_constant_base():
# test power(2, rv) = exp2(rv)
# test negative base fails
x_rv_pow = pt.pow(2, pt.random.normal())
x_rv_exp2 = pt.exp2(pt.random.normal())

x_vv_pow = x_rv_pow.clone()
x_vv_exp2 = x_rv_exp2.clone()

x_logp_fn_pow = pytensor.function([x_vv_pow], pt.sum(logp(x_rv_pow, x_vv_pow)))
x_logp_fn_exp2 = pytensor.function([x_vv_exp2], pt.sum(logp(x_rv_exp2, x_vv_exp2)))

np.testing.assert_allclose(x_logp_fn_pow(0.1), x_logp_fn_exp2(0.1))

with pytest.raises(ParameterValueError, match="base >= 0"):
x_rv_neg = pt.pow(-2, pt.random.normal())
x_vv_neg = x_rv_neg.clone()
logp(x_rv_neg, x_vv_neg)


def test_measurable_power_exponent_with_variable_base():
# test with RV when logp(<0) we raise error
base_rv = pt.random.normal([2])
x_raw_rv = pt.random.normal()
x_rv = pt.power(base_rv, x_raw_rv)

x_rv.name = "x"
base_rv.name = "base"
base_vv = base_rv.clone()
x_vv = x_rv.clone()

res = conditional_logp({base_rv: base_vv, x_rv: x_vv})
x_logp = res[x_vv]
logp_vals_fn = pytensor.function([base_vv, x_vv], x_logp)

with pytest.raises(ParameterValueError, match="base >= 0"):
logp_vals_fn(np.array([-2]), np.array([2]))


def test_base_exponent_non_measurable():
# test dual sources of measuravility fails
base_rv = pt.random.normal([2])
x_raw_rv = pt.random.normal()
x_rv = pt.power(base_rv, x_raw_rv)
x_rv.name = "x"

x_vv = x_rv.clone()

with pytest.raises(
RuntimeError,
match="The logprob terms of the following value variables could not be derived: {x}",
):
conditional_logp({x_rv: x_vv})


@pytest.mark.parametrize("shift", [1.5, np.array([-0.5, 1, 0.3])])
@pytest.mark.parametrize("scale", [2.0, np.array([1.5, 3.3, 1.0])])
def test_multivariate_rv_transform(shift, scale):
Expand Down