6
6
from collections import defaultdict , deque
7
7
from collections .abc import Generator , Sequence
8
8
from functools import cache , reduce
9
+ from operator import or_
9
10
from typing import Literal
10
11
from warnings import warn
11
12
29
30
)
30
31
from pytensor .graph .rewriting .db import SequenceDB
31
32
from pytensor .graph .rewriting .unify import OpPattern
32
- from pytensor .graph .traversal import ancestors , toposort
33
+ from pytensor .graph .traversal import toposort
33
34
from pytensor .graph .utils import InconsistencyError , MethodNotDefined
34
35
from pytensor .scalar .math import Grad2F1Loop , _grad_2f1_loop
35
36
from pytensor .tensor .basic import (
@@ -663,16 +664,9 @@ def find_fuseable_subgraph(
663
664
visited_nodes : set [Apply ],
664
665
fuseable_clients : FUSEABLE_MAPPING ,
665
666
unfuseable_clients : UNFUSEABLE_MAPPING ,
667
+ ancestors_bitset : dict [Apply , int ],
666
668
toposort_index : dict [Apply , int ],
667
669
) -> tuple [list [Variable ], list [Variable ]]:
668
- def variables_depend_on (
669
- variables , depend_on , stop_search_at = None
670
- ) -> bool :
671
- return any (
672
- a in depend_on
673
- for a in ancestors (variables , blockers = stop_search_at )
674
- )
675
-
676
670
for starting_node in toposort_index :
677
671
if starting_node in visited_nodes :
678
672
continue
@@ -684,7 +678,8 @@ def variables_depend_on(
684
678
685
679
subgraph_inputs : dict [Variable , Literal [None ]] = {} # ordered set
686
680
subgraph_outputs : dict [Variable , Literal [None ]] = {} # ordered set
687
- unfuseable_clients_subgraph : set [Variable ] = set ()
681
+ subgraph_inputs_ancestors_bitset = 0
682
+ unfuseable_clients_subgraph_bitset = 0
688
683
689
684
# If we need to manipulate the maps in place, we'll do a shallow copy later
690
685
# For now we query on the original ones
@@ -716,50 +711,32 @@ def variables_depend_on(
716
711
if must_become_output :
717
712
subgraph_outputs .pop (next_out , None )
718
713
719
- required_unfuseable_inputs = [
720
- inp
721
- for inp in next_node .inputs
722
- if next_node in unfuseable_clients_clone .get (inp )
723
- ]
724
- new_required_unfuseable_inputs = [
725
- inp
726
- for inp in required_unfuseable_inputs
727
- if inp not in subgraph_inputs
728
- ]
729
-
730
- must_backtrack = False
731
- if new_required_unfuseable_inputs and subgraph_outputs :
732
- # We need to check that any new inputs required by this node
733
- # do not depend on other outputs of the current subgraph,
734
- # via an unfuseable path.
735
- if variables_depend_on (
736
- [next_out ],
737
- depend_on = unfuseable_clients_subgraph ,
738
- stop_search_at = subgraph_outputs ,
739
- ):
740
- must_backtrack = True
714
+ # We need to check that any inputs required by this node
715
+ # do not depend on other outputs of the current subgraph,
716
+ # via an unfuseable path.
717
+ must_backtrack = (
718
+ ancestors_bitset [next_node ]
719
+ & unfuseable_clients_subgraph_bitset
720
+ )
741
721
742
722
if not must_backtrack :
743
- implied_unfuseable_clients = {
744
- c
745
- for client in unfuseable_clients_clone .get (next_out )
746
- if not isinstance (client .op , Output )
747
- for c in client .outputs
748
- }
749
-
750
- new_implied_unfuseable_clients = (
751
- implied_unfuseable_clients - unfuseable_clients_subgraph
723
+ implied_unfuseable_clients_bitset = reduce (
724
+ or_ ,
725
+ (
726
+ 1 << toposort_index [client ]
727
+ for client in unfuseable_clients_clone .get (next_out )
728
+ if not isinstance (client .op , Output )
729
+ ),
730
+ 0 ,
752
731
)
753
732
754
- if new_implied_unfuseable_clients and subgraph_inputs :
755
- # We need to check that any inputs of the current subgraph
756
- # do not depend on other clients of this node,
757
- # via an unfuseable path.
758
- if variables_depend_on (
759
- subgraph_inputs ,
760
- depend_on = new_implied_unfuseable_clients ,
761
- ):
762
- must_backtrack = True
733
+ # We need to check that any inputs of the current subgraph
734
+ # do not depend on other clients of this node,
735
+ # via an unfuseable path.
736
+ must_backtrack = (
737
+ subgraph_inputs_ancestors_bitset
738
+ & implied_unfuseable_clients_bitset
739
+ )
763
740
764
741
if must_backtrack :
765
742
for inp in next_node .inputs :
@@ -800,29 +777,24 @@ def variables_depend_on(
800
777
# immediate dependency problems. Update subgraph
801
778
# mappings as if it next_node was part of it.
802
779
# Useless inputs will be removed by the useless Composite rewrite
803
- for inp in new_required_unfuseable_inputs :
804
- subgraph_inputs [inp ] = None
805
-
806
780
if must_become_output :
807
781
subgraph_outputs [next_out ] = None
808
- unfuseable_clients_subgraph . update (
809
- new_implied_unfuseable_clients
782
+ unfuseable_clients_subgraph_bitset |= (
783
+ implied_unfuseable_clients_bitset
810
784
)
811
785
812
- # Expand through unvisited fuseable ancestors
813
- fuseable_nodes_to_visit .extendleft (
814
- sorted (
815
- (
816
- inp .owner
817
- for inp in next_node .inputs
818
- if (
819
- inp not in required_unfuseable_inputs
820
- and inp .owner not in visited_nodes
821
- )
822
- ),
823
- key = toposort_index .get , # type: ignore[arg-type]
824
- )
825
- )
786
+ for inp in sorted (
787
+ next_node .inputs ,
788
+ key = lambda x : toposort_index .get (x .owner , - 1 ),
789
+ ):
790
+ if next_node in unfuseable_clients_clone .get (inp , ()):
791
+ # input must become an input of the subgraph since it's unfuseable with new node
792
+ subgraph_inputs_ancestors_bitset |= (
793
+ ancestors_bitset .get (inp .owner , 0 )
794
+ )
795
+ subgraph_inputs [inp ] = None
796
+ elif inp .owner not in visited_nodes :
797
+ fuseable_nodes_to_visit .appendleft (inp .owner )
826
798
827
799
# Expand through unvisited fuseable clients
828
800
fuseable_nodes_to_visit .extend (
@@ -859,6 +831,8 @@ def update_fuseable_mappings_after_fg_replace(
859
831
visited_nodes : set [Apply ],
860
832
fuseable_clients : FUSEABLE_MAPPING ,
861
833
unfuseable_clients : UNFUSEABLE_MAPPING ,
834
+ toposort_index : dict [Apply , int ],
835
+ ancestors_bitset : dict [Apply , int ],
862
836
starting_nodes : set [Apply ],
863
837
updated_nodes : set [Apply ],
864
838
) -> None :
@@ -869,11 +843,25 @@ def update_fuseable_mappings_after_fg_replace(
869
843
dropped_nodes = starting_nodes - updated_nodes
870
844
871
845
# Remove intermediate Composite nodes from mappings
846
+ # And compute the ancestors bitset of the new composite node
847
+ # As well as the new toposort index for the new node
848
+ new_node_ancestor_bitset = 0
849
+ new_node_toposort_index = len (toposort_index )
872
850
for dropped_node in dropped_nodes :
873
851
(dropped_out ,) = dropped_node .outputs
874
852
fuseable_clients .pop (dropped_out , None )
875
853
unfuseable_clients .pop (dropped_out , None )
876
854
visited_nodes .remove (dropped_node )
855
+ # The new composite ancestor bitset is the union
856
+ # of the ancestors of all the dropped nodes
857
+ new_node_ancestor_bitset |= ancestors_bitset [dropped_node ]
858
+ # The new composite node can have the same order as the latest node that was absorbed into it
859
+ new_node_toposort_index = max (
860
+ new_node_toposort_index , toposort_index [dropped_node ]
861
+ )
862
+
863
+ ancestors_bitset [new_composite_node ] = new_node_ancestor_bitset
864
+ toposort_index [new_composite_node ] = new_node_toposort_index
877
865
878
866
# Update fuseable information for subgraph inputs
879
867
for inp in subgraph_inputs :
@@ -905,12 +893,23 @@ def update_fuseable_mappings_after_fg_replace(
905
893
fuseable_clients , unfuseable_clients = initialize_fuseable_mappings (fg = fg )
906
894
visited_nodes : set [Apply ] = set ()
907
895
toposort_index = {node : i for i , node in enumerate (fgraph .toposort ())}
896
+ # Create a bitset for each node of all its ancestors
897
+ # This allows to quickly check if a variable depends on a set
898
+ ancestors_bitset : dict [Apply , int ] = {}
899
+ for node , index in toposort_index .items ():
900
+ node_ancestor_bitset = 1 << index
901
+ for inp in node .inputs :
902
+ if (inp_node := inp .owner ) is not None :
903
+ node_ancestor_bitset |= ancestors_bitset [inp_node ]
904
+ ancestors_bitset [node ] = node_ancestor_bitset
905
+
908
906
while True :
909
907
try :
910
908
subgraph_inputs , subgraph_outputs = find_fuseable_subgraph (
911
909
visited_nodes = visited_nodes ,
912
910
fuseable_clients = fuseable_clients ,
913
911
unfuseable_clients = unfuseable_clients ,
912
+ ancestors_bitset = ancestors_bitset ,
914
913
toposort_index = toposort_index ,
915
914
)
916
915
except ValueError :
@@ -929,6 +928,8 @@ def update_fuseable_mappings_after_fg_replace(
929
928
visited_nodes = visited_nodes ,
930
929
fuseable_clients = fuseable_clients ,
931
930
unfuseable_clients = unfuseable_clients ,
931
+ toposort_index = toposort_index ,
932
+ ancestors_bitset = ancestors_bitset ,
932
933
starting_nodes = starting_nodes ,
933
934
updated_nodes = fg .apply_nodes ,
934
935
)
0 commit comments