6
6
from collections import defaultdict , deque
7
7
from collections .abc import Generator , Sequence
8
8
from functools import cache , reduce
9
+ from operator import or_
9
10
from typing import Literal
10
11
from warnings import warn
11
12
15
16
from pytensor .compile .mode import get_target_language
16
17
from pytensor .configdefaults import config
17
18
from pytensor .graph import FunctionGraph , Op
18
- from pytensor .graph .basic import Apply , Variable , ancestors , io_toposort
19
+ from pytensor .graph .basic import Apply , Variable , io_toposort
19
20
from pytensor .graph .destroyhandler import DestroyHandler , inplace_candidates
20
21
from pytensor .graph .features import ReplaceValidate
21
22
from pytensor .graph .fg import Output
@@ -661,16 +662,9 @@ def find_fuseable_subgraph(
661
662
visited_nodes : set [Apply ],
662
663
fuseable_clients : FUSEABLE_MAPPING ,
663
664
unfuseable_clients : UNFUSEABLE_MAPPING ,
665
+ ancestors_bitset : dict [Apply , int ],
664
666
toposort_index : dict [Apply , int ],
665
667
) -> tuple [list [Variable ], list [Variable ]]:
666
- def variables_depend_on (
667
- variables , depend_on , stop_search_at = None
668
- ) -> bool :
669
- return any (
670
- a in depend_on
671
- for a in ancestors (variables , blockers = stop_search_at )
672
- )
673
-
674
668
for starting_node in toposort_index :
675
669
if starting_node in visited_nodes :
676
670
continue
@@ -682,7 +676,8 @@ def variables_depend_on(
682
676
683
677
subgraph_inputs : dict [Variable , Literal [None ]] = {} # ordered set
684
678
subgraph_outputs : dict [Variable , Literal [None ]] = {} # ordered set
685
- unfuseable_clients_subgraph : set [Variable ] = set ()
679
+ subgraph_inputs_ancestors_bitset = 0
680
+ unfuseable_clients_subgraph_bitset = 0
686
681
687
682
# If we need to manipulate the maps in place, we'll do a shallow copy later
688
683
# For now we query on the original ones
@@ -714,50 +709,32 @@ def variables_depend_on(
714
709
if must_become_output :
715
710
subgraph_outputs .pop (next_out , None )
716
711
717
- required_unfuseable_inputs = [
718
- inp
719
- for inp in next_node .inputs
720
- if next_node in unfuseable_clients_clone .get (inp )
721
- ]
722
- new_required_unfuseable_inputs = [
723
- inp
724
- for inp in required_unfuseable_inputs
725
- if inp not in subgraph_inputs
726
- ]
727
-
728
- must_backtrack = False
729
- if new_required_unfuseable_inputs and subgraph_outputs :
730
- # We need to check that any new inputs required by this node
731
- # do not depend on other outputs of the current subgraph,
732
- # via an unfuseable path.
733
- if variables_depend_on (
734
- [next_out ],
735
- depend_on = unfuseable_clients_subgraph ,
736
- stop_search_at = subgraph_outputs ,
737
- ):
738
- must_backtrack = True
712
+ # We need to check that any inputs required by this node
713
+ # do not depend on other outputs of the current subgraph,
714
+ # via an unfuseable path.
715
+ must_backtrack = (
716
+ ancestors_bitset [next_node ]
717
+ & unfuseable_clients_subgraph_bitset
718
+ )
739
719
740
720
if not must_backtrack :
741
- implied_unfuseable_clients = {
742
- c
743
- for client in unfuseable_clients_clone .get (next_out )
744
- if not isinstance (client .op , Output )
745
- for c in client .outputs
746
- }
747
-
748
- new_implied_unfuseable_clients = (
749
- implied_unfuseable_clients - unfuseable_clients_subgraph
721
+ implied_unfuseable_clients_bitset = reduce (
722
+ or_ ,
723
+ (
724
+ 1 << toposort_index [client ]
725
+ for client in unfuseable_clients_clone .get (next_out )
726
+ if not isinstance (client .op , Output )
727
+ ),
728
+ 0 ,
750
729
)
751
730
752
- if new_implied_unfuseable_clients and subgraph_inputs :
753
- # We need to check that any inputs of the current subgraph
754
- # do not depend on other clients of this node,
755
- # via an unfuseable path.
756
- if variables_depend_on (
757
- subgraph_inputs ,
758
- depend_on = new_implied_unfuseable_clients ,
759
- ):
760
- must_backtrack = True
731
+ # We need to check that any inputs of the current subgraph
732
+ # do not depend on other clients of this node,
733
+ # via an unfuseable path.
734
+ must_backtrack = (
735
+ subgraph_inputs_ancestors_bitset
736
+ & implied_unfuseable_clients_bitset
737
+ )
761
738
762
739
if must_backtrack :
763
740
for inp in next_node .inputs :
@@ -798,29 +775,24 @@ def variables_depend_on(
798
775
# immediate dependency problems. Update subgraph
799
776
# mappings as if it next_node was part of it.
800
777
# Useless inputs will be removed by the useless Composite rewrite
801
- for inp in new_required_unfuseable_inputs :
802
- subgraph_inputs [inp ] = None
803
-
804
778
if must_become_output :
805
779
subgraph_outputs [next_out ] = None
806
- unfuseable_clients_subgraph . update (
807
- new_implied_unfuseable_clients
780
+ unfuseable_clients_subgraph_bitset |= (
781
+ implied_unfuseable_clients_bitset
808
782
)
809
783
810
- # Expand through unvisited fuseable ancestors
811
- fuseable_nodes_to_visit .extendleft (
812
- sorted (
813
- (
814
- inp .owner
815
- for inp in next_node .inputs
816
- if (
817
- inp not in required_unfuseable_inputs
818
- and inp .owner not in visited_nodes
819
- )
820
- ),
821
- key = toposort_index .get , # type: ignore[arg-type]
822
- )
823
- )
784
+ for inp in sorted (
785
+ next_node .inputs ,
786
+ key = lambda x : toposort_index .get (x .owner , - 1 ),
787
+ ):
788
+ if next_node in unfuseable_clients_clone .get (inp , ()):
789
+ # input must become an input of the subgraph since it's unfuseable with new node
790
+ subgraph_inputs_ancestors_bitset |= (
791
+ ancestors_bitset .get (inp .owner , 0 )
792
+ )
793
+ subgraph_inputs [inp ] = None
794
+ elif inp .owner not in visited_nodes :
795
+ fuseable_nodes_to_visit .appendleft (inp .owner )
824
796
825
797
# Expand through unvisited fuseable clients
826
798
fuseable_nodes_to_visit .extend (
@@ -857,6 +829,8 @@ def update_fuseable_mappings_after_fg_replace(
857
829
visited_nodes : set [Apply ],
858
830
fuseable_clients : FUSEABLE_MAPPING ,
859
831
unfuseable_clients : UNFUSEABLE_MAPPING ,
832
+ toposort_index : dict [Apply , int ],
833
+ ancestors_bitset : dict [Apply , int ],
860
834
starting_nodes : set [Apply ],
861
835
updated_nodes : set [Apply ],
862
836
) -> None :
@@ -867,11 +841,25 @@ def update_fuseable_mappings_after_fg_replace(
867
841
dropped_nodes = starting_nodes - updated_nodes
868
842
869
843
# Remove intermediate Composite nodes from mappings
844
+ # And compute the ancestors bitset of the new composite node
845
+ # As well as the new toposort index for the new node
846
+ new_node_ancestor_bitset = 0
847
+ new_node_toposort_index = len (toposort_index )
870
848
for dropped_node in dropped_nodes :
871
849
(dropped_out ,) = dropped_node .outputs
872
850
fuseable_clients .pop (dropped_out , None )
873
851
unfuseable_clients .pop (dropped_out , None )
874
852
visited_nodes .remove (dropped_node )
853
+ # The new composite ancestor bitset is the union
854
+ # of the ancestors of all the dropped nodes
855
+ new_node_ancestor_bitset |= ancestors_bitset [dropped_node ]
856
+ # The new composite node can have the same order as the latest node that was absorbed into it
857
+ new_node_toposort_index = max (
858
+ new_node_toposort_index , toposort_index [dropped_node ]
859
+ )
860
+
861
+ ancestors_bitset [new_composite_node ] = new_node_ancestor_bitset
862
+ toposort_index [new_composite_node ] = new_node_toposort_index
875
863
876
864
# Update fuseable information for subgraph inputs
877
865
for inp in subgraph_inputs :
@@ -903,12 +891,23 @@ def update_fuseable_mappings_after_fg_replace(
903
891
fuseable_clients , unfuseable_clients = initialize_fuseable_mappings (fg = fg )
904
892
visited_nodes : set [Apply ] = set ()
905
893
toposort_index = {node : i for i , node in enumerate (fgraph .toposort ())}
894
+ # Create a bitset for each node of all its ancestors
895
+ # This allows to quickly check if a variable depends on a set
896
+ ancestors_bitset : dict [Apply , int ] = {}
897
+ for node , index in toposort_index .items ():
898
+ node_ancestor_bitset = 1 << index
899
+ for inp in node .inputs :
900
+ if (inp_node := inp .owner ) is not None :
901
+ node_ancestor_bitset |= ancestors_bitset [inp_node ]
902
+ ancestors_bitset [node ] = node_ancestor_bitset
903
+
906
904
while True :
907
905
try :
908
906
subgraph_inputs , subgraph_outputs = find_fuseable_subgraph (
909
907
visited_nodes = visited_nodes ,
910
908
fuseable_clients = fuseable_clients ,
911
909
unfuseable_clients = unfuseable_clients ,
910
+ ancestors_bitset = ancestors_bitset ,
912
911
toposort_index = toposort_index ,
913
912
)
914
913
except ValueError :
@@ -927,6 +926,8 @@ def update_fuseable_mappings_after_fg_replace(
927
926
visited_nodes = visited_nodes ,
928
927
fuseable_clients = fuseable_clients ,
929
928
unfuseable_clients = unfuseable_clients ,
929
+ toposort_index = toposort_index ,
930
+ ancestors_bitset = ancestors_bitset ,
930
931
starting_nodes = starting_nodes ,
931
932
updated_nodes = fg .apply_nodes ,
932
933
)
0 commit comments