44
44
from pytensor .graph .rewriting .basic import node_rewriter
45
45
from pytensor .scalar .basic import Ceil , Clip , Floor , RoundHalfToEven
46
46
from pytensor .scalar .basic import clip as scalar_clip
47
- from pytensor .tensor .elemwise import Elemwise
47
+ from pytensor .tensor .math import ceil , clip , floor , round_half_to_even
48
48
from pytensor .tensor .var import TensorConstant
49
49
50
50
from pymc .logprob .abstract import (
@@ -67,7 +67,7 @@ class MeasurableClip(MeasurableElemwise):
67
67
measurable_clip = MeasurableClip (scalar_clip )
68
68
69
69
70
- @node_rewriter (tracks = [Elemwise ])
70
+ @node_rewriter (tracks = [clip ])
71
71
def find_measurable_clips (fgraph : FunctionGraph , node : Node ) -> Optional [List [MeasurableClip ]]:
72
72
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)
73
73
@@ -78,9 +78,6 @@ def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[List[Me
78
78
if isinstance (node .op , MeasurableClip ):
79
79
return None # pragma: no cover
80
80
81
- if not (isinstance (node .op , Elemwise ) and isinstance (node .op .scalar_op , Clip )):
82
- return None
83
-
84
81
clipped_var = node .outputs [0 ]
85
82
base_var , lower_bound , upper_bound = node .inputs
86
83
@@ -179,7 +176,7 @@ class MeasurableRound(MeasurableElemwise):
179
176
valid_scalar_types = (RoundHalfToEven , Floor , Ceil )
180
177
181
178
182
- @node_rewriter (tracks = [Elemwise ])
179
+ @node_rewriter (tracks = [ceil , floor , round_half_to_even ])
183
180
def find_measurable_roundings (fgraph : FunctionGraph , node : Node ) -> Optional [List [MeasurableRound ]]:
184
181
185
182
rv_map_feature = getattr (fgraph , "preserve_rv_mappings" , None )
@@ -189,12 +186,6 @@ def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[Lis
189
186
if isinstance (node .op , MeasurableRound ):
190
187
return None # pragma: no cover
191
188
192
- if not (
193
- isinstance (node .op , Elemwise )
194
- and isinstance (node .op .scalar_op , MeasurableRound .valid_scalar_types )
195
- ):
196
- return None
197
-
198
189
(rounded_var ,) = node .outputs
199
190
(base_var ,) = node .inputs
200
191
0 commit comments