2
2
import itertools
3
3
import operator
4
4
import sys
5
+ import typing
5
6
from collections import defaultdict , deque
6
7
from collections .abc import Generator , Sequence
7
8
from functools import cache , reduce
@@ -520,6 +521,43 @@ def elemwise_max_operands_fct(node) -> int:
520
521
return 1024
521
522
522
523
524
+ class CopyOnWriteDictOfSets :
525
+ __slots__ = ("d" , "d_copy" )
526
+
527
+ def __init__ (self , d : dict [typing .Any , set ]):
528
+ self .d = d
529
+ self .d_copy : dict [typing .Any , set ] = {}
530
+
531
+ def __getitem__ (self , key ):
532
+ try :
533
+ return self .d_copy [key ]
534
+ except KeyError :
535
+ return self .d [key ]
536
+
537
+ def get (self , key , default = frozenset ()):
538
+ try :
539
+ return self .d_copy [key ]
540
+ except KeyError :
541
+ try :
542
+ return self .d [key ]
543
+ except KeyError :
544
+ return default
545
+
546
+ def remove_from_key (self , key , value ):
547
+ try :
548
+ self .d_copy [key ].remove (value )
549
+ except KeyError :
550
+ self .d_copy [key ] = copied_value = self .d [key ].copy ()
551
+ copied_value .remove (value )
552
+
553
+ def add_to_key (self , key , value ):
554
+ try :
555
+ self .d_copy [key ].add (value )
556
+ except KeyError :
557
+ self .d_copy [key ] = copied_value = self .d [key ].copy ()
558
+ copied_value .add (value )
559
+
560
+
523
561
class FusionOptimizer (GraphRewriter ):
524
562
"""Graph optimizer that fuses consecutive Elemwise operations."""
525
563
@@ -646,15 +684,10 @@ def variables_depend_on(
646
684
subgraph_outputs : dict [Variable , Literal [None ]] = {} # ordered set
647
685
unfuseable_clients_subgraph : set [Variable ] = set ()
648
686
649
- # Shallow cloning of maps so that they can be manipulated in place
650
- fuseable_clients_clone : FUSEABLE_MAPPING = defaultdict (set )
651
- fuseable_clients_clone .update (
652
- {k : v .copy () for k , v in fuseable_clients .items ()}
653
- )
654
- unfuseable_clients_clone : UNFUSEABLE_MAPPING = defaultdict (set )
655
- unfuseable_clients_clone .update (
656
- {k : v .copy () for k , v in unfuseable_clients .items ()}
657
- )
687
+ # If we need to manipulate the maps in place, we'll do a shallow copy later
688
+ # For now we query on the original ones
689
+ fuseable_clients_clone = CopyOnWriteDictOfSets (fuseable_clients )
690
+ unfuseable_clients_clone = CopyOnWriteDictOfSets (unfuseable_clients )
658
691
659
692
# We now try to expand as much as possible towards the potentially
660
693
# fuseable clients and ancestors to detect the largest possible
@@ -684,7 +717,7 @@ def variables_depend_on(
684
717
required_unfuseable_inputs = [
685
718
inp
686
719
for inp in next_node .inputs
687
- if next_node in unfuseable_clients_clone .get (inp , () )
720
+ if next_node in unfuseable_clients_clone .get (inp )
688
721
]
689
722
new_required_unfuseable_inputs = [
690
723
inp
@@ -707,7 +740,7 @@ def variables_depend_on(
707
740
if not must_backtrack :
708
741
implied_unfuseable_clients = {
709
742
c
710
- for client in unfuseable_clients_clone .get (next_out , () )
743
+ for client in unfuseable_clients_clone .get (next_out )
711
744
if not isinstance (client .op , Output )
712
745
for c in client .outputs
713
746
}
@@ -728,13 +761,15 @@ def variables_depend_on(
728
761
729
762
if must_backtrack :
730
763
for inp in next_node .inputs :
731
- if (
732
- inp .owner in visited_nodes
733
- # next_node could have the same input repeated
734
- and next_node in fuseable_clients_clone [inp ]
735
- ):
736
- fuseable_clients_clone [inp ].remove (next_node )
737
- unfuseable_clients_clone [inp ].add (next_node )
764
+ if inp .owner in visited_nodes :
765
+ if next_node not in fuseable_clients_clone [inp ]:
766
+ # This can happen when next node has repeated inputs
767
+ continue
768
+ fuseable_clients_clone .remove_from_key (
769
+ inp , next_node
770
+ )
771
+ unfuseable_clients_clone .add_to_key (inp , next_node )
772
+
738
773
# This input must become an output of the subgraph,
739
774
# because it can't be merged with next_node.
740
775
# We will revisit it to make sure this is safe.
@@ -743,8 +778,13 @@ def variables_depend_on(
743
778
# need to convert to tuple not to change set size during iteration
744
779
for client in tuple (fuseable_clients_clone [next_out ]):
745
780
if client in visited_nodes :
746
- fuseable_clients_clone [next_out ].remove (client )
747
- unfuseable_clients_clone [next_out ].add (client )
781
+ fuseable_clients_clone .remove_from_key (
782
+ next_out , client
783
+ )
784
+ unfuseable_clients_clone .add_to_key (
785
+ next_out , client
786
+ )
787
+
748
788
# next_out must become an input of the subgraph.
749
789
# We will revisit any of its clients currently
750
790
# in the subgraph to make sure this is safe.
@@ -787,7 +827,7 @@ def variables_depend_on(
787
827
sorted (
788
828
(
789
829
node
790
- for node in fuseable_clients_clone .get (next_out , () )
830
+ for node in fuseable_clients_clone .get (next_out )
791
831
if node not in visited_nodes
792
832
),
793
833
key = toposort_index .get , # type: ignore[arg-type]
0 commit comments