2
2
import itertools
3
3
import operator
4
4
import sys
5
+ import typing
5
6
from collections import defaultdict , deque
6
7
from collections .abc import Generator , Sequence
7
8
from functools import cache , reduce
@@ -522,6 +523,43 @@ def elemwise_max_operands_fct(node) -> int:
522
523
return 1024
523
524
524
525
526
class CopyOnWriteDictOfSets:
    """Overlay on a mapping of sets that copies a set only when it is first written.

    Reads are answered from the private overlay ``d_copy`` when the key has
    already been written through this object, and from the wrapped mapping
    ``d`` otherwise. ``remove_from_key`` and ``add_to_key`` only ever mutate
    sets stored in ``d_copy``, so the sets held by ``d`` are left untouched.

    Sets returned by ``__getitem__`` and ``get`` may be the very objects held
    by ``d``; callers must treat them as read-only and go through
    ``remove_from_key`` / ``add_to_key`` to modify them.
    """

    __slots__ = ("d", "d_copy")

    def __init__(self, d: dict[typing.Any, set]):
        # Wrapped mapping; queried as a fallback, its sets are never mutated here.
        self.d = d
        # Private copies of the sets that were modified through this object.
        self.d_copy: dict[typing.Any, set] = {}

    def __getitem__(self, key):
        """Return the current set for ``key``, preferring the private copy.

        Falls back to ``self.d[key]``, so a key unknown to both mappings
        behaves exactly like subscripting the wrapped mapping.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            return self.d[key]

    def get(self, key, default=frozenset()):
        """Return the current set for ``key``, or ``default`` if it is unknown.

        The default is an (immutable) empty ``frozenset`` so the result can
        always be iterated or used for membership tests.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            try:
                return self.d[key]
            except KeyError:
                return default

    def remove_from_key(self, key, value):
        """Remove ``value`` from the set stored under ``key``.

        The set is copied into the private overlay on the first write.

        Raises
        ------
        KeyError
            If ``key`` is unknown to both mappings, or if ``value`` is not in
            the current set for ``key``.
        """
        # Only the overlay lookup lives in the ``try``. ``set.remove`` also
        # raises KeyError, and mixing both in one handler would replace an
        # already-modified private copy with a fresh copy of the original.
        try:
            owned = self.d_copy[key]
        except KeyError:
            # ``dict.get`` does not trigger a defaultdict factory, so the
            # wrapped mapping is never extended by this lookup.
            shared = self.d.get(key)
            if shared is None:
                raise KeyError(key) from None
            owned = self.d_copy[key] = shared.copy()
        owned.remove(value)

    def add_to_key(self, key, value):
        """Add ``value`` to the set stored under ``key``.

        The set is copied into the private overlay on the first write. A key
        unknown to both mappings starts from an empty set that lives only in
        the overlay.
        """
        try:
            owned = self.d_copy[key]
        except KeyError:
            # ``dict.get`` does not trigger a defaultdict factory, so the
            # wrapped mapping is never extended by this lookup.
            shared = self.d.get(key)
            owned = self.d_copy[key] = set() if shared is None else shared.copy()
        owned.add(value)
561
+
562
+
525
563
class FusionOptimizer (GraphRewriter ):
526
564
"""Graph optimizer that fuses consecutive Elemwise operations."""
527
565
@@ -644,15 +682,10 @@ def variables_depend_on(
644
682
subgraph_outputs : dict [Variable , Literal [None ]] = {} # ordered set
645
683
unfuseable_clients_subgraph : set [Variable ] = set ()
646
684
647
- # Shallow cloning of maps so that they can be manipulated in place
648
- fuseable_clients_clone : FUSEABLE_MAPPING = defaultdict (set )
649
- fuseable_clients_clone .update (
650
- {k : v .copy () for k , v in fuseable_clients .items ()}
651
- )
652
- unfuseable_clients_clone : UNFUSEABLE_MAPPING = defaultdict (set )
653
- unfuseable_clients_clone .update (
654
- {k : v .copy () for k , v in unfuseable_clients .items ()}
655
- )
685
+ # If we need to manipulate the maps in place, we'll do a shallow copy later
686
+ # For now we query on the original ones
687
+ fuseable_clients_clone = CopyOnWriteDictOfSets (fuseable_clients )
688
+ unfuseable_clients_clone = CopyOnWriteDictOfSets (unfuseable_clients )
656
689
657
690
# We now try to expand as much as possible towards the potentially
658
691
# fuseable clients and ancestors to detect the largest possible
@@ -682,7 +715,7 @@ def variables_depend_on(
682
715
required_unfuseable_inputs = [
683
716
inp
684
717
for inp in next_node .inputs
685
- if next_node in unfuseable_clients_clone .get (inp , () )
718
+ if next_node in unfuseable_clients_clone .get (inp )
686
719
]
687
720
new_required_unfuseable_inputs = [
688
721
inp
@@ -705,7 +738,7 @@ def variables_depend_on(
705
738
if not must_backtrack :
706
739
implied_unfuseable_clients = {
707
740
c
708
- for client in unfuseable_clients_clone .get (next_out , () )
741
+ for client in unfuseable_clients_clone .get (next_out )
709
742
if not isinstance (client .op , Output )
710
743
for c in client .outputs
711
744
}
@@ -726,13 +759,15 @@ def variables_depend_on(
726
759
727
760
if must_backtrack :
728
761
for inp in next_node .inputs :
729
- if (
730
- inp .owner in visited_nodes
731
- # next_node could have the same input repeated
732
- and next_node in fuseable_clients_clone [inp ]
733
- ):
734
- fuseable_clients_clone [inp ].remove (next_node )
735
- unfuseable_clients_clone [inp ].add (next_node )
762
+ if inp .owner in visited_nodes :
763
+ if next_node not in fuseable_clients_clone [inp ]:
764
+ # This can happen when next node has repeated inputs
765
+ continue
766
+ fuseable_clients_clone .remove_from_key (
767
+ inp , next_node
768
+ )
769
+ unfuseable_clients_clone .add_to_key (inp , next_node )
770
+
736
771
# This input must become an output of the subgraph,
737
772
# because it can't be merged with next_node.
738
773
# We will revisit it to make sure this is safe.
@@ -741,8 +776,13 @@ def variables_depend_on(
741
776
# need to convert to tuple not to change set size during iteration
742
777
for client in tuple (fuseable_clients_clone [next_out ]):
743
778
if client in visited_nodes :
744
- fuseable_clients_clone [next_out ].remove (client )
745
- unfuseable_clients_clone [next_out ].add (client )
779
+ fuseable_clients_clone .remove_from_key (
780
+ next_out , client
781
+ )
782
+ unfuseable_clients_clone .add_to_key (
783
+ next_out , client
784
+ )
785
+
746
786
# next_out must become an input of the subgraph.
747
787
# We will revisit any of its clients currently
748
788
# in the subgraph to make sure this is safe.
@@ -785,7 +825,7 @@ def variables_depend_on(
785
825
sorted (
786
826
(
787
827
node
788
- for node in fuseable_clients_clone .get (next_out , () )
828
+ for node in fuseable_clients_clone .get (next_out )
789
829
if node not in visited_nodes
790
830
),
791
831
key = toposort_index .get , # type: ignore[arg-type]
0 commit comments