6
6
from collections import defaultdict , deque
7
7
from collections .abc import Generator , Sequence
8
8
from functools import cache , reduce
9
+ from operator import or_
9
10
from typing import Literal
10
11
from warnings import warn
11
12
29
30
)
30
31
from pytensor .graph .rewriting .db import SequenceDB
31
32
from pytensor .graph .rewriting .unify import OpPattern
32
- from pytensor .graph .traversal import ancestors , toposort
33
+ from pytensor .graph .traversal import toposort
33
34
from pytensor .graph .utils import InconsistencyError , MethodNotDefined
34
35
from pytensor .scalar .math import Grad2F1Loop , _grad_2f1_loop
35
36
from pytensor .tensor .basic import (
@@ -659,16 +660,9 @@ def find_fuseable_subgraph(
659
660
visited_nodes : set [Apply ],
660
661
fuseable_clients : FUSEABLE_MAPPING ,
661
662
unfuseable_clients : UNFUSEABLE_MAPPING ,
663
+ ancestors_bitset : dict [Apply , int ],
662
664
toposort_index : dict [Apply , int ],
663
665
) -> tuple [list [Variable ], list [Variable ]]:
664
- def variables_depend_on (
665
- variables , depend_on , stop_search_at = None
666
- ) -> bool :
667
- return any (
668
- a in depend_on
669
- for a in ancestors (variables , blockers = stop_search_at )
670
- )
671
-
672
666
for starting_node in toposort_index :
673
667
if starting_node in visited_nodes :
674
668
continue
@@ -680,7 +674,8 @@ def variables_depend_on(
680
674
681
675
subgraph_inputs : dict [Variable , Literal [None ]] = {} # ordered set
682
676
subgraph_outputs : dict [Variable , Literal [None ]] = {} # ordered set
683
- unfuseable_clients_subgraph : set [Variable ] = set ()
677
+ subgraph_inputs_ancestors_bitset = 0
678
+ unfuseable_clients_subgraph_bitset = 0
684
679
685
680
# If we need to manipulate the maps in place, we'll do a shallow copy later
686
681
# For now we query on the original ones
@@ -712,50 +707,32 @@ def variables_depend_on(
712
707
if must_become_output :
713
708
subgraph_outputs .pop (next_out , None )
714
709
715
- required_unfuseable_inputs = [
716
- inp
717
- for inp in next_node .inputs
718
- if next_node in unfuseable_clients_clone .get (inp )
719
- ]
720
- new_required_unfuseable_inputs = [
721
- inp
722
- for inp in required_unfuseable_inputs
723
- if inp not in subgraph_inputs
724
- ]
725
-
726
- must_backtrack = False
727
- if new_required_unfuseable_inputs and subgraph_outputs :
728
- # We need to check that any new inputs required by this node
729
- # do not depend on other outputs of the current subgraph,
730
- # via an unfuseable path.
731
- if variables_depend_on (
732
- [next_out ],
733
- depend_on = unfuseable_clients_subgraph ,
734
- stop_search_at = subgraph_outputs ,
735
- ):
736
- must_backtrack = True
710
+ # We need to check that any inputs required by this node
711
+ # do not depend on other outputs of the current subgraph,
712
+ # via an unfuseable path.
713
+ must_backtrack = (
714
+ ancestors_bitset [next_node ]
715
+ & unfuseable_clients_subgraph_bitset
716
+ )
737
717
738
718
if not must_backtrack :
739
- implied_unfuseable_clients = {
740
- c
741
- for client in unfuseable_clients_clone .get (next_out )
742
- if not isinstance (client .op , Output )
743
- for c in client .outputs
744
- }
745
-
746
- new_implied_unfuseable_clients = (
747
- implied_unfuseable_clients - unfuseable_clients_subgraph
719
+ implied_unfuseable_clients_bitset = reduce (
720
+ or_ ,
721
+ (
722
+ 1 << toposort_index [client ]
723
+ for client in unfuseable_clients_clone .get (next_out )
724
+ if not isinstance (client .op , Output )
725
+ ),
726
+ 0 ,
748
727
)
749
728
750
- if new_implied_unfuseable_clients and subgraph_inputs :
751
- # We need to check that any inputs of the current subgraph
752
- # do not depend on other clients of this node,
753
- # via an unfuseable path.
754
- if variables_depend_on (
755
- subgraph_inputs ,
756
- depend_on = new_implied_unfuseable_clients ,
757
- ):
758
- must_backtrack = True
729
+ # We need to check that any inputs of the current subgraph
730
+ # do not depend on other clients of this node,
731
+ # via an unfuseable path.
732
+ must_backtrack = (
733
+ subgraph_inputs_ancestors_bitset
734
+ & implied_unfuseable_clients_bitset
735
+ )
759
736
760
737
if must_backtrack :
761
738
for inp in next_node .inputs :
@@ -796,29 +773,24 @@ def variables_depend_on(
796
773
# immediate dependency problems. Update subgraph
797
774
# mappings as if it next_node was part of it.
798
775
# Useless inputs will be removed by the useless Composite rewrite
799
- for inp in new_required_unfuseable_inputs :
800
- subgraph_inputs [inp ] = None
801
-
802
776
if must_become_output :
803
777
subgraph_outputs [next_out ] = None
804
- unfuseable_clients_subgraph . update (
805
- new_implied_unfuseable_clients
778
+ unfuseable_clients_subgraph_bitset |= (
779
+ implied_unfuseable_clients_bitset
806
780
)
807
781
808
- # Expand through unvisited fuseable ancestors
809
- fuseable_nodes_to_visit .extendleft (
810
- sorted (
811
- (
812
- inp .owner
813
- for inp in next_node .inputs
814
- if (
815
- inp not in required_unfuseable_inputs
816
- and inp .owner not in visited_nodes
817
- )
818
- ),
819
- key = toposort_index .get , # type: ignore[arg-type]
820
- )
821
- )
782
+ for inp in sorted (
783
+ next_node .inputs ,
784
+ key = lambda x : toposort_index .get (x .owner , - 1 ),
785
+ ):
786
+ if next_node in unfuseable_clients_clone .get (inp , ()):
787
+ # input must become an input of the subgraph since it's unfuseable with new node
788
+ subgraph_inputs_ancestors_bitset |= (
789
+ ancestors_bitset .get (inp .owner , 0 )
790
+ )
791
+ subgraph_inputs [inp ] = None
792
+ elif inp .owner not in visited_nodes :
793
+ fuseable_nodes_to_visit .appendleft (inp .owner )
822
794
823
795
# Expand through unvisited fuseable clients
824
796
fuseable_nodes_to_visit .extend (
@@ -855,6 +827,8 @@ def update_fuseable_mappings_after_fg_replace(
855
827
visited_nodes : set [Apply ],
856
828
fuseable_clients : FUSEABLE_MAPPING ,
857
829
unfuseable_clients : UNFUSEABLE_MAPPING ,
830
+ toposort_index : dict [Apply , int ],
831
+ ancestors_bitset : dict [Apply , int ],
858
832
starting_nodes : set [Apply ],
859
833
updated_nodes : set [Apply ],
860
834
) -> None :
@@ -865,11 +839,25 @@ def update_fuseable_mappings_after_fg_replace(
865
839
dropped_nodes = starting_nodes - updated_nodes
866
840
867
841
# Remove intermediate Composite nodes from mappings
842
+ # And compute the ancestors bitset of the new composite node
843
+ # As well as the new toposort index for the new node
844
+ new_node_ancestor_bitset = 0
845
+ new_node_toposort_index = len (toposort_index )
868
846
for dropped_node in dropped_nodes :
869
847
(dropped_out ,) = dropped_node .outputs
870
848
fuseable_clients .pop (dropped_out , None )
871
849
unfuseable_clients .pop (dropped_out , None )
872
850
visited_nodes .remove (dropped_node )
851
+ # The new composite ancestor bitset is the union
852
+ # of the ancestors of all the dropped nodes
853
+ new_node_ancestor_bitset |= ancestors_bitset [dropped_node ]
854
+ # The new composite node can have the same order as the latest node that was absorbed into it
855
+ new_node_toposort_index = max (
856
+ new_node_toposort_index , toposort_index [dropped_node ]
857
+ )
858
+
859
+ ancestors_bitset [new_composite_node ] = new_node_ancestor_bitset
860
+ toposort_index [new_composite_node ] = new_node_toposort_index
873
861
874
862
# Update fuseable information for subgraph inputs
875
863
for inp in subgraph_inputs :
@@ -901,12 +889,23 @@ def update_fuseable_mappings_after_fg_replace(
901
889
fuseable_clients , unfuseable_clients = initialize_fuseable_mappings (fg = fg )
902
890
visited_nodes : set [Apply ] = set ()
903
891
toposort_index = {node : i for i , node in enumerate (fgraph .toposort ())}
892
+ # Create a bitset for each node of all its ancestors
893
+ # This allows to quickly check if a variable depends on a set
894
+ ancestors_bitset : dict [Apply , int ] = {}
895
+ for node , index in toposort_index .items ():
896
+ node_ancestor_bitset = 1 << index
897
+ for inp in node .inputs :
898
+ if (inp_node := inp .owner ) is not None :
899
+ node_ancestor_bitset |= ancestors_bitset [inp_node ]
900
+ ancestors_bitset [node ] = node_ancestor_bitset
901
+
904
902
while True :
905
903
try :
906
904
subgraph_inputs , subgraph_outputs = find_fuseable_subgraph (
907
905
visited_nodes = visited_nodes ,
908
906
fuseable_clients = fuseable_clients ,
909
907
unfuseable_clients = unfuseable_clients ,
908
+ ancestors_bitset = ancestors_bitset ,
910
909
toposort_index = toposort_index ,
911
910
)
912
911
except ValueError :
@@ -925,6 +924,8 @@ def update_fuseable_mappings_after_fg_replace(
925
924
visited_nodes = visited_nodes ,
926
925
fuseable_clients = fuseable_clients ,
927
926
unfuseable_clients = unfuseable_clients ,
927
+ toposort_index = toposort_index ,
928
+ ancestors_bitset = ancestors_bitset ,
928
929
starting_nodes = starting_nodes ,
929
930
updated_nodes = fg .apply_nodes ,
930
931
)
0 commit comments