Commit 46f8227

Use more direct imports in rewriting/elemwise.py

1 parent: d26374c

1 file changed

pytensor/tensor/rewriting/elemwise.py
Lines changed: 38 additions & 31 deletions
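
The whole diff follows one pattern: drop the module alias `ps` (`import pytensor.scalar.basic as ps`) and import each scalar helper by name, aliasing the few names that would collide with their tensor-side counterparts. A minimal sketch of the before/after style, using `transfer_type` as a representative helper taken from the diff below:

    # Before: every scalar helper is reached through the module alias
    import pytensor.scalar.basic as ps

    pref = ps.transfer_type(0)  # the single output reuses the dtype of input 0

    # After: the helper is imported directly from pytensor.scalar
    from pytensor.scalar import transfer_type

    pref = transfer_type(0)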
@@ -8,16 +8,15 @@
 from operator import or_
 from warnings import warn

-import pytensor.scalar.basic as ps
-from pytensor import clone_replace, compile
 from pytensor.compile.function.types import Supervisor
-from pytensor.compile.mode import get_target_language
+from pytensor.compile.mode import get_target_language, optdb
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
 from pytensor.graph.features import ReplaceValidate
 from pytensor.graph.fg import FunctionGraph, Output
 from pytensor.graph.op import Op
+from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.basic import (
     GraphRewriter,
     copy_stack_trace,
@@ -30,11 +29,21 @@
 from pytensor.graph.rewriting.unify import OpPattern
 from pytensor.graph.traversal import toposort
 from pytensor.graph.utils import InconsistencyError, MethodNotDefined
-from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
-from pytensor.tensor.basic import (
-    MakeVector,
-    constant,
+from pytensor.scalar import (
+    Add,
+    Composite,
+    Mul,
+    ScalarOp,
+    get_scalar_type,
+    transfer_type,
+    upcast_out,
+    upgrade_to_float,
 )
+from pytensor.scalar import cast as scalar_cast
+from pytensor.scalar import constant as scalar_constant
+from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
+from pytensor.tensor.basic import MakeVector
+from pytensor.tensor.basic import constant as tensor_constant
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import add, exp, mul
 from pytensor.tensor.rewriting.basic import (
@@ -280,7 +289,7 @@ def create_inplace_node(self, node, inplace_pattern):
         inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
         if hasattr(scalar_op, "make_new_inplace"):
             new_scalar_op = scalar_op.make_new_inplace(
-                ps.transfer_type(
+                transfer_type(
                     *[
                         inplace_pattern.get(i, o.dtype)
                         for i, o in enumerate(node.outputs)
@@ -289,14 +298,14 @@ def create_inplace_node(self, node, inplace_pattern):
             )
         else:
             new_scalar_op = type(scalar_op)(
-                ps.transfer_type(
+                transfer_type(
                     *[inplace_pattern.get(i, None) for i in range(len(node.outputs))]
                 )
             )
         return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)


-compile.optdb.register(
+optdb.register(
     "inplace_elemwise",
     InplaceElemwiseOptimizer(),
     "inplace_elemwise_opt",  # for historic reason
@@ -428,10 +437,8 @@ def local_useless_dimshuffle_makevector(fgraph, node):
 @register_canonicalize
 @node_rewriter(
     [
-        elemwise_of(
-            OpPattern(ps.ScalarOp, output_types_preference=ps.upgrade_to_float)
-        ),
-        elemwise_of(OpPattern(ps.ScalarOp, output_types_preference=ps.upcast_out)),
+        elemwise_of(OpPattern(ScalarOp, output_types_preference=upgrade_to_float)),
+        elemwise_of(OpPattern(ScalarOp, output_types_preference=upcast_out)),
     ]
 )
 def local_upcast_elemwise_constant_inputs(fgraph, node):
@@ -452,7 +459,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
     changed = False
     for i, inp in enumerate(node.inputs):
         if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant):
-            new_inputs[i] = constant(inp.data.astype(output_dtype))
+            new_inputs[i] = tensor_constant(inp.data.astype(output_dtype))
             changed = True

     if not changed:
@@ -531,7 +538,7 @@ def add_requirements(self, fgraph):
     @staticmethod
     def elemwise_to_scalar(inputs, outputs):
         replacement = {
-            inp: ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
+            inp: get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
         }
         for node in toposort(outputs, blockers=inputs):
             scalar_inputs = [replacement[inp] for inp in node.inputs]
@@ -853,7 +860,7 @@ def elemwise_scalar_op_has_c_code(
         scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs)
         composite_outputs = Elemwise(
             # No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables
-            ps.Composite(scalar_inputs, scalar_outputs, clone_graph=False)
+            Composite(scalar_inputs, scalar_outputs, clone_graph=False)
         )(*inputs, return_list=True)
         assert len(outputs) == len(composite_outputs)
         for old_out, composite_out in zip(outputs, composite_outputs):
@@ -913,7 +920,7 @@ def print_profile(stream, prof, level=0):

 @register_canonicalize
 @register_specialize
-@node_rewriter([elemwise_of(ps.Composite)])
+@node_rewriter([elemwise_of(Composite)])
 def local_useless_composite_outputs(fgraph, node):
     """Remove inputs and outputs of Composite Ops that are not used anywhere."""
     comp = node.op.scalar_op
@@ -934,7 +941,7 @@ def local_useless_composite_outputs(fgraph, node):
         node.outputs
     ):
         used_inputs = [node.inputs[i] for i in used_inputs_idxs]
-        c = ps.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs)
+        c = Composite(inputs=used_inner_inputs, outputs=used_inner_outputs)
         e = Elemwise(scalar_op=c)(*used_inputs, return_list=True)
         return dict(zip([node.outputs[i] for i in used_outputs_idxs], e, strict=True))

@@ -948,7 +955,7 @@ def local_careduce_fusion(fgraph, node):

     # FIXME: This check is needed because of the faulty logic in the FIXME below!
     # Right now, rewrite only works for `Sum`/`Prod`
-    if not isinstance(car_scalar_op, ps.Add | ps.Mul):
+    if not isinstance(car_scalar_op, Add | Mul):
         return None

     elm_node = car_input.owner
@@ -992,19 +999,19 @@ def local_careduce_fusion(fgraph, node):
     car_acc_dtype = node.op.acc_dtype

     scalar_elm_inputs = [
-        ps.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
+        get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
     ]

     elm_output = elm_scalar_op(*scalar_elm_inputs)

     # This input represents the previous value in the `CAReduce` binary reduction
-    carried_car_input = ps.get_scalar_type(car_acc_dtype).make_variable()
+    carried_car_input = get_scalar_type(car_acc_dtype).make_variable()

     scalar_fused_output = car_scalar_op(carried_car_input, elm_output)
     if scalar_fused_output.type.dtype != car_acc_dtype:
-        scalar_fused_output = ps.cast(scalar_fused_output, car_acc_dtype)
+        scalar_fused_output = scalar_cast(scalar_fused_output, car_acc_dtype)

-    fused_scalar_op = ps.Composite(
+    fused_scalar_op = Composite(
         inputs=[carried_car_input, *scalar_elm_inputs], outputs=[scalar_fused_output]
     )

@@ -1025,7 +1032,7 @@ def local_careduce_fusion(fgraph, node):
     return [new_car_op(*elm_inputs)]


-@node_rewriter([elemwise_of(ps.Composite)])
+@node_rewriter([elemwise_of(Composite)])
 def local_inline_composite_constants(fgraph, node):
     """Inline scalar constants in Composite graphs."""
     composite_op = node.op.scalar_op
@@ -1041,7 +1048,7 @@ def local_inline_composite_constants(fgraph, node):
            and "complex" not in outer_inp.type.dtype
        ):
            if outer_inp.unique_value is not None:
-                inner_replacements[inner_inp] = ps.constant(
+                inner_replacements[inner_inp] = scalar_constant(
                    outer_inp.unique_value, dtype=inner_inp.dtype
                )
                continue
@@ -1054,7 +1061,7 @@ def local_inline_composite_constants(fgraph, node):
     new_inner_outs = clone_replace(
         composite_op.fgraph.outputs, replace=inner_replacements
     )
-    new_composite_op = ps.Composite(new_inner_inputs, new_inner_outs)
+    new_composite_op = Composite(new_inner_inputs, new_inner_outs)
     new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs

     # Some of the inlined constants were broadcasting the output shape
@@ -1095,7 +1102,7 @@ def constant_fold_branches_of_add_mul(fgraph, node):
     if other_inps:
         python_op = operator.mul if node.op == mul else operator.add
         folded_inputs = [reference_inp, *other_inps]
-        new_inp = constant(
+        new_inp = tensor_constant(
             reduce(python_op, (const.data for const in folded_inputs))
         )
         new_constants = [
@@ -1119,7 +1126,7 @@ def constant_fold_branches_of_add_mul(fgraph, node):


 add_mul_fusion_seqopt = SequenceDB()
-compile.optdb.register(
+optdb.register(
     "add_mul_fusion",
     add_mul_fusion_seqopt,
     "fast_run",
@@ -1140,7 +1147,7 @@ def constant_fold_branches_of_add_mul(fgraph, node):

 # Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
 fuse_seqopt = SequenceDB()
-compile.optdb.register(
+optdb.register(
     "elemwise_fusion",
     fuse_seqopt,
     "fast_run",
@@ -1271,7 +1278,7 @@ def split_2f1grad_loop(fgraph, node):
     return replacements


-compile.optdb["py_only"].register(
+optdb["py_only"].register(
     "split_2f1grad_loop",
     split_2f1grad_loop,
     "fast_compile",

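`cast` and `constant` are the names that exist on both the scalar and the tensor side, hence the `scalar_cast`, `scalar_constant`, and `tensor_constant` aliases in the new imports. A short sketch of the collision the aliases avoid:

    from pytensor.scalar import constant as scalar_constant
    from pytensor.tensor.basic import constant as tensor_constant

    s = scalar_constant(2, dtype="int64")  # a ScalarConstant (0-d scalar graph node)
    t = tensor_constant([1, 2, 3])         # a TensorConstant backed by an array
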