Skip to content

Commit 9fd3af7

Browse files
Smit-createricardoV94
authored andcommitted
Add stabilization rewrite for log_diff_exp
Co-authored-by: Ricardo Vieira <[email protected]> .rewrite
1 parent b1332b2 commit 9fd3af7

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

pytensor/tensor/rewriting/math.py

+8
Original file line numberDiff line numberDiff line change
@@ -3604,6 +3604,14 @@ def local_reciprocal_1_plus_exp(fgraph, node):
36043604
)
36053605
register_stabilize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp")
36063606

3607+
# log(exp(a) - exp(b)) -> a + log1mexp(b - a)
3608+
logdiffexp_to_log1mexpdiff = PatternNodeRewriter(
3609+
(log, (sub, (exp, "x"), (exp, "y"))),
3610+
(add, "x", (log1mexp, (sub, "y", "x"))),
3611+
allow_multiple_clients=True,
3612+
)
3613+
register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff")
3614+
36073615

36083616
# log(sigmoid(x) / (1 - sigmoid(x))) -> x
36093617
# i.e logit(sigmoid(x)) -> x

tests/tensor/rewriting/test_math.py

+39
Original file line numberDiff line numberDiff line change
@@ -4136,3 +4136,42 @@ def test_log1mexp_stabilization():
41364136
f(np.array([-0.8, -0.6], dtype=config.floatX)),
41374137
np.log(1 - np.exp([-0.8, -0.6])),
41384138
)
4139+
4140+
4141+
def test_logdiffexp():
4142+
rng = np.random.default_rng(3559)
4143+
mode = Mode("py").including("stabilize").excluding("fusion")
4144+
4145+
x = fmatrix("x")
4146+
y = fmatrix("y")
4147+
f = function([x, y], log(exp(x) - exp(y)), mode=mode)
4148+
4149+
graph = f.maker.fgraph.toposort()
4150+
assert (
4151+
len(
4152+
[
4153+
node
4154+
for node in graph
4155+
if isinstance(node.op, Elemwise)
4156+
and isinstance(node.op.scalar_op, (aes.Exp, aes.Log))
4157+
]
4158+
)
4159+
== 0
4160+
)
4161+
assert (
4162+
len(
4163+
[
4164+
node
4165+
for node in graph
4166+
if isinstance(node.op, Elemwise)
4167+
and isinstance(node.op.scalar_op, aes.Log1mexp)
4168+
]
4169+
)
4170+
== 1
4171+
)
4172+
4173+
y_test = rng.normal(size=(3, 2)).astype("float32")
4174+
x_test = rng.normal(size=(3, 2)).astype("float32") + y_test.max()
4175+
np.testing.assert_almost_equal(
4176+
f(x_test, y_test), np.log(np.exp(x_test) - np.exp(y_test))
4177+
)

0 commit comments

Comments
 (0)