@@ -8,16 +8,15 @@
 from operator import or_
 from warnings import warn

-import pytensor.scalar.basic as ps
-from pytensor import clone_replace, compile
 from pytensor.compile.function.types import Supervisor
-from pytensor.compile.mode import get_target_language
+from pytensor.compile.mode import get_target_language, optdb
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
 from pytensor.graph.features import ReplaceValidate
 from pytensor.graph.fg import FunctionGraph, Output
 from pytensor.graph.op import Op
+from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.basic import (
     GraphRewriter,
     copy_stack_trace,
@@ -30,11 +29,21 @@
 from pytensor.graph.rewriting.unify import OpPattern
 from pytensor.graph.traversal import toposort
 from pytensor.graph.utils import InconsistencyError, MethodNotDefined
-from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
-from pytensor.tensor.basic import (
-    MakeVector,
-    constant,
+from pytensor.scalar import (
+    Add,
+    Composite,
+    Mul,
+    ScalarOp,
+    get_scalar_type,
+    transfer_type,
+    upcast_out,
+    upgrade_to_float,
 )
+from pytensor.scalar import cast as scalar_cast
+from pytensor.scalar import constant as scalar_constant
+from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
+from pytensor.tensor.basic import MakeVector
+from pytensor.tensor.basic import constant as tensor_constant
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import add, exp, mul
 from pytensor.tensor.rewriting.basic import (
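
Note on the import reshuffle above: `constant` exists in both `pytensor.scalar` and `pytensor.tensor.basic`, and the two build different kinds of graph literals, so the module now imports each under its own alias instead of reaching through the removed `ps` namespace. A minimal sketch of the distinction (values are illustrative):

    from pytensor.scalar import constant as scalar_constant
    from pytensor.tensor.basic import constant as tensor_constant

    s = scalar_constant(1.0)           # ScalarConstant: literal in a scalar graph
    t = tensor_constant([1.0, 2.0])    # TensorConstant: literal in a tensor graph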
@@ -280,7 +289,7 @@ def create_inplace_node(self, node, inplace_pattern):
         inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
         if hasattr(scalar_op, "make_new_inplace"):
             new_scalar_op = scalar_op.make_new_inplace(
-                ps.transfer_type(
+                transfer_type(
                     *[
                         inplace_pattern.get(i, o.dtype)
                         for i, o in enumerate(node.outputs)
@@ -289,14 +298,14 @@ def create_inplace_node(self, node, inplace_pattern):
                 )
             )
         else:
             new_scalar_op = type(scalar_op)(
-                ps.transfer_type(
+                transfer_type(
                     *[inplace_pattern.get(i, None) for i in range(len(node.outputs))]
                 )
             )
         return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)


-compile.optdb.register(
+optdb.register(
     "inplace_elemwise",
     InplaceElemwiseOptimizer(),
     "inplace_elemwise_opt",  # for historic reason
@@ -428,10 +437,8 @@ def local_useless_dimshuffle_makevector(fgraph, node):
428
437
@register_canonicalize
429
438
@node_rewriter (
430
439
[
431
- elemwise_of (
432
- OpPattern (ps .ScalarOp , output_types_preference = ps .upgrade_to_float )
433
- ),
434
- elemwise_of (OpPattern (ps .ScalarOp , output_types_preference = ps .upcast_out )),
440
+ elemwise_of (OpPattern (ScalarOp , output_types_preference = upgrade_to_float )),
441
+ elemwise_of (OpPattern (ScalarOp , output_types_preference = upcast_out )),
435
442
]
436
443
)
437
444
def local_upcast_elemwise_constant_inputs (fgraph , node ):
@@ -452,7 +459,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
452
459
changed = False
453
460
for i , inp in enumerate (node .inputs ):
454
461
if inp .type .dtype != output_dtype and isinstance (inp , TensorConstant ):
455
- new_inputs [i ] = constant (inp .data .astype (output_dtype ))
462
+ new_inputs [i ] = tensor_constant (inp .data .astype (output_dtype ))
456
463
changed = True
457
464
458
465
if not changed :
@@ -531,7 +538,7 @@ def add_requirements(self, fgraph):
531
538
@staticmethod
532
539
def elemwise_to_scalar (inputs , outputs ):
533
540
replacement = {
534
- inp : ps . get_scalar_type (inp .type .dtype ).make_variable () for inp in inputs
541
+ inp : get_scalar_type (inp .type .dtype ).make_variable () for inp in inputs
535
542
}
536
543
for node in toposort (outputs , blockers = inputs ):
537
544
scalar_inputs = [replacement [inp ] for inp in node .inputs ]
@@ -853,7 +860,7 @@ def elemwise_scalar_op_has_c_code(
853
860
scalar_inputs , scalar_outputs = self .elemwise_to_scalar (inputs , outputs )
854
861
composite_outputs = Elemwise (
855
862
# No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables
856
- ps . Composite (scalar_inputs , scalar_outputs , clone_graph = False )
863
+ Composite (scalar_inputs , scalar_outputs , clone_graph = False )
857
864
)(* inputs , return_list = True )
858
865
assert len (outputs ) == len (composite_outputs )
859
866
for old_out , composite_out in zip (outputs , composite_outputs ):
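
The hunk above is the core of the fusion rewrite: the scalar graph extracted by `elemwise_to_scalar` is wrapped in a `Composite` scalar op and lifted back to tensors with `Elemwise`. A minimal, self-contained sketch of that wrapping (the fused graph `exp(x) + x` is an arbitrary example):

    from pytensor.scalar import Composite, exp, float64
    from pytensor.tensor.elemwise import Elemwise

    x = float64("x")                                    # fresh scalar variable
    comp = Composite(inputs=[x], outputs=[exp(x) + x])  # fused scalar graph
    fused_op = Elemwise(comp)                           # applied elementwise to tensors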
@@ -913,7 +920,7 @@ def print_profile(stream, prof, level=0):
913
920
914
921
@register_canonicalize
915
922
@register_specialize
916
- @node_rewriter ([elemwise_of (ps . Composite )])
923
+ @node_rewriter ([elemwise_of (Composite )])
917
924
def local_useless_composite_outputs (fgraph , node ):
918
925
"""Remove inputs and outputs of Composite Ops that are not used anywhere."""
919
926
comp = node .op .scalar_op
@@ -934,7 +941,7 @@ def local_useless_composite_outputs(fgraph, node):
934
941
node .outputs
935
942
):
936
943
used_inputs = [node .inputs [i ] for i in used_inputs_idxs ]
937
- c = ps . Composite (inputs = used_inner_inputs , outputs = used_inner_outputs )
944
+ c = Composite (inputs = used_inner_inputs , outputs = used_inner_outputs )
938
945
e = Elemwise (scalar_op = c )(* used_inputs , return_list = True )
939
946
return dict (zip ([node .outputs [i ] for i in used_outputs_idxs ], e , strict = True ))
940
947
@@ -948,7 +955,7 @@ def local_careduce_fusion(fgraph, node):
948
955
949
956
# FIXME: This check is needed because of the faulty logic in the FIXME below!
950
957
# Right now, rewrite only works for `Sum`/`Prod`
951
- if not isinstance (car_scalar_op , ps . Add | ps . Mul ):
958
+ if not isinstance (car_scalar_op , Add | Mul ):
952
959
return None
953
960
954
961
elm_node = car_input .owner
@@ -992,19 +999,19 @@ def local_careduce_fusion(fgraph, node):
     car_acc_dtype = node.op.acc_dtype

     scalar_elm_inputs = [
-        ps.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
+        get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
     ]

     elm_output = elm_scalar_op(*scalar_elm_inputs)

     # This input represents the previous value in the `CAReduce` binary reduction
-    carried_car_input = ps.get_scalar_type(car_acc_dtype).make_variable()
+    carried_car_input = get_scalar_type(car_acc_dtype).make_variable()

     scalar_fused_output = car_scalar_op(carried_car_input, elm_output)
     if scalar_fused_output.type.dtype != car_acc_dtype:
-        scalar_fused_output = ps.cast(scalar_fused_output, car_acc_dtype)
+        scalar_fused_output = scalar_cast(scalar_fused_output, car_acc_dtype)

-    fused_scalar_op = ps.Composite(
+    fused_scalar_op = Composite(
         inputs=[carried_car_input, *scalar_elm_inputs], outputs=[scalar_fused_output]
     )

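To make the fused reduction concrete: for a graph like `exp(x).sum()`, the `Composite` built above computes one reduction step fused with the elementwise work, i.e. `carried + exp(elm)`. A hedged sketch, assuming float64 throughout:

    from pytensor.scalar import Composite, exp, get_scalar_type

    carried = get_scalar_type("float64").make_variable()  # running accumulator
    elm = get_scalar_type("float64").make_variable()      # current input element
    fused = Composite(inputs=[carried, elm], outputs=[carried + exp(elm)])
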
@@ -1025,7 +1032,7 @@ def local_careduce_fusion(fgraph, node):
     return [new_car_op(*elm_inputs)]


-@node_rewriter([elemwise_of(ps.Composite)])
+@node_rewriter([elemwise_of(Composite)])
 def local_inline_composite_constants(fgraph, node):
     """Inline scalar constants in Composite graphs."""
     composite_op = node.op.scalar_op
@@ -1041,7 +1048,7 @@ def local_inline_composite_constants(fgraph, node):
             and "complex" not in outer_inp.type.dtype
         ):
             if outer_inp.unique_value is not None:
-                inner_replacements[inner_inp] = ps.constant(
+                inner_replacements[inner_inp] = scalar_constant(
                     outer_inp.unique_value, dtype=inner_inp.dtype
                 )
                 continue
@@ -1054,7 +1061,7 @@ def local_inline_composite_constants(fgraph, node):
     new_inner_outs = clone_replace(
         composite_op.fgraph.outputs, replace=inner_replacements
     )
-    new_composite_op = ps.Composite(new_inner_inputs, new_inner_outs)
+    new_composite_op = Composite(new_inner_inputs, new_inner_outs)
     new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs

     # Some of the inlined constants were broadcasting the output shape
@@ -1095,7 +1102,7 @@ def constant_fold_branches_of_add_mul(fgraph, node):
         if other_inps:
             python_op = operator.mul if node.op == mul else operator.add
             folded_inputs = [reference_inp, *other_inps]
-            new_inp = constant(
+            new_inp = tensor_constant(
                 reduce(python_op, (const.data for const in folded_inputs))
             )
             new_constants = [
@@ -1119,7 +1126,7 @@ def constant_fold_branches_of_add_mul(fgraph, node):


 add_mul_fusion_seqopt = SequenceDB()
-compile.optdb.register(
+optdb.register(
     "add_mul_fusion",
     add_mul_fusion_seqopt,
     "fast_run",
@@ -1140,7 +1147,7 @@ def constant_fold_branches_of_add_mul(fgraph, node):

 # Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
 fuse_seqopt = SequenceDB()
-compile.optdb.register(
+optdb.register(
     "elemwise_fusion",
     fuse_seqopt,
     "fast_run",
@@ -1271,7 +1278,7 @@ def split_2f1grad_loop(fgraph, node):
     return replacements


-compile.optdb["py_only"].register(
+optdb["py_only"].register(
     "split_2f1grad_loop",
     split_2f1grad_loop,
     "fast_compile",