Skip to content

Commit ddc6b65

Browse files
committed
Track specific Elemwise Ops in logprob rewrites
1 parent 7c32d3c commit ddc6b65

File tree

2 files changed

+6
-17
lines changed

2 files changed

+6
-17
lines changed

pymc/logprob/censoring.py

+3-12
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from pytensor.graph.rewriting.basic import node_rewriter
4545
from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven
4646
from pytensor.scalar.basic import clip as scalar_clip
47-
from pytensor.tensor.elemwise import Elemwise
47+
from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
4848
from pytensor.tensor.var import TensorConstant
4949

5050
from pymc.logprob.abstract import (
@@ -67,7 +67,7 @@ class MeasurableClip(MeasurableElemwise):
6767
measurable_clip = MeasurableClip(scalar_clip)
6868

6969

70-
@node_rewriter(tracks=[Elemwise])
70+
@node_rewriter(tracks=[clip])
7171
def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableClip]]:
7272
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)
7373

@@ -78,9 +78,6 @@ def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[List[Me
7878
if isinstance(node.op, MeasurableClip):
7979
return None # pragma: no cover
8080

81-
if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Clip)):
82-
return None
83-
8481
clipped_var = node.outputs[0]
8582
base_var, lower_bound, upper_bound = node.inputs
8683

@@ -179,7 +176,7 @@ class MeasurableRound(MeasurableElemwise):
179176
valid_scalar_types = (RoundHalfToEven, Floor, Ceil)
180177

181178

182-
@node_rewriter(tracks=[Elemwise])
179+
@node_rewriter(tracks=[ceil, floor, round_half_to_even])
183180
def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableRound]]:
184181

185182
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
@@ -189,12 +186,6 @@ def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[Lis
189186
if isinstance(node.op, MeasurableRound):
190187
return None # pragma: no cover
191188

192-
if not (
193-
isinstance(node.op, Elemwise)
194-
and isinstance(node.op.scalar_op, MeasurableRound.valid_scalar_types)
195-
):
196-
return None
197-
198189
(rounded_var,) = node.outputs
199190
(base_var,) = node.inputs
200191

pymc/logprob/transforms.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from pytensor.graph.op import Op
4949
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
5050
from pytensor.scalar import Add, Exp, Log, Mul
51-
from pytensor.tensor.elemwise import Elemwise
51+
from pytensor.tensor.math import add, exp, log, mul
5252
from pytensor.tensor.rewriting.basic import (
5353
register_specialize,
5454
register_stabilize,
@@ -256,12 +256,9 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
256256
return input_logprob + jacobian
257257

258258

259-
@node_rewriter([Elemwise])
259+
@node_rewriter([exp, log, add, mul])
260260
def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
261261
"""Find measurable transformations from Elemwise operators."""
262-
scalar_op = node.op.scalar_op
263-
if not isinstance(scalar_op, MeasurableTransform.valid_scalar_types):
264-
return None
265262

266263
# Node was already converted
267264
if isinstance(node.op, MeasurableVariable):
@@ -311,6 +308,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
311308
# This seems to be the only thing preventing nested rewrites from being erased
312309
measurable_input = assign_custom_measurable_outputs(measurable_input.owner)
313310

311+
scalar_op = node.op.scalar_op
314312
measurable_input_idx = 0
315313
transform_inputs: Tuple[TensorVariable, ...] = (measurable_input,)
316314
transform: RVTransform

0 commit comments

Comments
 (0)