Skip to content

Commit 4a24466

Browse files
committed
Copy on write in FusionOptimizer
1 parent f3ba7bc commit 4a24466

File tree

1 file changed

+61
-21
lines changed

1 file changed

+61
-21
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import itertools
33
import operator
44
import sys
5+
import typing
56
from collections import defaultdict, deque
67
from collections.abc import Generator, Sequence
78
from functools import cache, reduce
@@ -520,6 +521,43 @@ def elemwise_max_operands_fct(node) -> int:
520521
return 1024
521522

522523

524+
class CopyOnWriteDictOfSets:
    """Read-through view over a dict of sets that copies a set on first write.

    Lookups are answered from ``d`` until a key is modified through
    :meth:`remove_from_key` or :meth:`add_to_key`. The first modification of a
    key stores a copy of that key's set in ``d_copy``; the copy receives the
    change and serves every later read and write of that key. The set objects
    stored in ``d`` are never mutated by those two methods.

    NOTE(review): ``d`` is indexed with ``d[key]``. If callers pass a
    ``defaultdict``, indexing a missing key would insert into ``d`` itself --
    confirm against the call sites.
    """

    __slots__ = ("d", "d_copy")

    def __init__(self, d: dict[typing.Any, set]):
        """Wrap ``d`` without copying anything yet."""
        self.d = d
        # Per-key private copies, created lazily on first modification.
        self.d_copy: dict[typing.Any, set] = {}

    def __getitem__(self, key):
        """Return the current set for ``key``.

        The private copy is preferred when one exists; otherwise the lookup is
        delegated to ``d`` (which raises ``KeyError`` for a missing key when
        ``d`` is a plain dict).
        """
        try:
            return self.d_copy[key]
        except KeyError:
            return self.d[key]

    def get(self, key, default=frozenset()):
        """Return the current set for ``key``, or ``default`` if it is unknown.

        The default ``default`` is an empty ``frozenset``, so callers can
        iterate or test membership on the result without a ``None`` check.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            try:
                return self.d[key]
            except KeyError:
                return default

    def _writable_set(self, key):
        """Return the private set for ``key``, copying it from ``d`` if needed.

        Raises ``KeyError`` (from ``d``) if ``key`` has no private copy and is
        not present in ``d``.
        """
        try:
            return self.d_copy[key]
        except KeyError:
            copied_value = self.d_copy[key] = self.d[key].copy()
            return copied_value

    def remove_from_key(self, key, value):
        """Remove ``value`` from the set of ``key`` without touching ``d``.

        Raises ``KeyError`` if ``key`` is unknown or ``value`` is not in its
        current set. The key lookup is done separately from ``set.remove`` so
        that a missing *value* is not mistaken for a missing *key*, which would
        otherwise replace the private copy with a fresh copy of the original
        set and discard earlier modifications.
        """
        self._writable_set(key).remove(value)

    def add_to_key(self, key, value):
        """Add ``value`` to the set of ``key`` without touching ``d``.

        Raises ``KeyError`` if ``key`` has no private copy and is not in ``d``.
        """
        self._writable_set(key).add(value)
559+
560+
523561
class FusionOptimizer(GraphRewriter):
524562
"""Graph optimizer that fuses consecutive Elemwise operations."""
525563

@@ -646,15 +684,10 @@ def variables_depend_on(
646684
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
647685
unfuseable_clients_subgraph: set[Variable] = set()
648686

649-
# Shallow cloning of maps so that they can be manipulated in place
650-
fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set)
651-
fuseable_clients_clone.update(
652-
{k: v.copy() for k, v in fuseable_clients.items()}
653-
)
654-
unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set)
655-
unfuseable_clients_clone.update(
656-
{k: v.copy() for k, v in unfuseable_clients.items()}
657-
)
687+
# If we need to manipulate the maps in place, we'll do a shallow copy later
688+
# For now we query on the original ones
689+
fuseable_clients_clone = CopyOnWriteDictOfSets(fuseable_clients)
690+
unfuseable_clients_clone = CopyOnWriteDictOfSets(unfuseable_clients)
658691

659692
# We now try to expand as much as possible towards the potentially
660693
# fuseable clients and ancestors to detect the largest possible
@@ -684,7 +717,7 @@ def variables_depend_on(
684717
required_unfuseable_inputs = [
685718
inp
686719
for inp in next_node.inputs
687-
if next_node in unfuseable_clients_clone.get(inp, ())
720+
if next_node in unfuseable_clients_clone.get(inp)
688721
]
689722
new_required_unfuseable_inputs = [
690723
inp
@@ -707,7 +740,7 @@ def variables_depend_on(
707740
if not must_backtrack:
708741
implied_unfuseable_clients = {
709742
c
710-
for client in unfuseable_clients_clone.get(next_out, ())
743+
for client in unfuseable_clients_clone.get(next_out)
711744
if not isinstance(client.op, Output)
712745
for c in client.outputs
713746
}
@@ -728,13 +761,15 @@ def variables_depend_on(
728761

729762
if must_backtrack:
730763
for inp in next_node.inputs:
731-
if (
732-
inp.owner in visited_nodes
733-
# next_node could have the same input repeated
734-
and next_node in fuseable_clients_clone[inp]
735-
):
736-
fuseable_clients_clone[inp].remove(next_node)
737-
unfuseable_clients_clone[inp].add(next_node)
764+
if inp.owner in visited_nodes:
765+
if next_node not in fuseable_clients_clone[inp]:
766+
# This can happen when next_node has repeated inputs
767+
continue
768+
fuseable_clients_clone.remove_from_key(
769+
inp, next_node
770+
)
771+
unfuseable_clients_clone.add_to_key(inp, next_node)
772+
738773
# This input must become an output of the subgraph,
739774
# because it can't be merged with next_node.
740775
# We will revisit it to make sure this is safe.
@@ -743,8 +778,13 @@ def variables_depend_on(
743778
# need to convert to tuple not to change set size during iteration
744779
for client in tuple(fuseable_clients_clone[next_out]):
745780
if client in visited_nodes:
746-
fuseable_clients_clone[next_out].remove(client)
747-
unfuseable_clients_clone[next_out].add(client)
781+
fuseable_clients_clone.remove_from_key(
782+
next_out, client
783+
)
784+
unfuseable_clients_clone.add_to_key(
785+
next_out, client
786+
)
787+
748788
# next_out must become an input of the subgraph.
749789
# We will revisit any of its clients currently
750790
# in the subgraph to make sure this is safe.
@@ -787,7 +827,7 @@ def variables_depend_on(
787827
sorted(
788828
(
789829
node
790-
for node in fuseable_clients_clone.get(next_out, ())
830+
for node in fuseable_clients_clone.get(next_out)
791831
if node not in visited_nodes
792832
),
793833
key=toposort_index.get, # type: ignore[arg-type]

0 commit comments

Comments
 (0)