Skip to content

Commit f768c33

Browse files
committed
Copy on write in FusionOptimizer
1 parent adc74fe commit f768c33

File tree

1 file changed

+61
-21
lines changed

1 file changed

+61
-21
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import itertools
33
import operator
44
import sys
5+
import typing
56
from collections import defaultdict, deque
67
from collections.abc import Generator, Sequence
78
from functools import cache, reduce
@@ -522,6 +523,43 @@ def elemwise_max_operands_fct(node) -> int:
522523
return 1024
523524

524525

526+
class CopyOnWriteDictOfSets:
527+
__slots__ = ("d", "d_copy")
528+
529+
def __init__(self, d: dict[typing.Any, set]):
530+
self.d = d
531+
self.d_copy: dict[typing.Any, set] = {}
532+
533+
def __getitem__(self, key):
534+
try:
535+
return self.d_copy[key]
536+
except KeyError:
537+
return self.d[key]
538+
539+
def get(self, key, default=frozenset()):
540+
try:
541+
return self.d_copy[key]
542+
except KeyError:
543+
try:
544+
return self.d[key]
545+
except KeyError:
546+
return default
547+
548+
def remove_from_key(self, key, value):
549+
try:
550+
self.d_copy[key].remove(value)
551+
except KeyError:
552+
self.d_copy[key] = copied_value = self.d[key].copy()
553+
copied_value.remove(value)
554+
555+
def add_to_key(self, key, value):
556+
try:
557+
self.d_copy[key].add(value)
558+
except KeyError:
559+
self.d_copy[key] = copied_value = self.d[key].copy()
560+
copied_value.add(value)
561+
562+
525563
class FusionOptimizer(GraphRewriter):
526564
"""Graph optimizer that fuses consecutive Elemwise operations."""
527565

@@ -644,15 +682,10 @@ def variables_depend_on(
644682
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
645683
unfuseable_clients_subgraph: set[Variable] = set()
646684

647-
# Shallow cloning of maps so that they can be manipulated in place
648-
fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set)
649-
fuseable_clients_clone.update(
650-
{k: v.copy() for k, v in fuseable_clients.items()}
651-
)
652-
unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set)
653-
unfuseable_clients_clone.update(
654-
{k: v.copy() for k, v in unfuseable_clients.items()}
655-
)
685+
# If we need to manipulate the maps in place, we'll do a shallow copy later
686+
# For now we query on the original ones
687+
fuseable_clients_clone = CopyOnWriteDictOfSets(fuseable_clients)
688+
unfuseable_clients_clone = CopyOnWriteDictOfSets(unfuseable_clients)
656689

657690
# We now try to expand as much as possible towards the potentially
658691
# fuseable clients and ancestors to detect the largest possible
@@ -682,7 +715,7 @@ def variables_depend_on(
682715
required_unfuseable_inputs = [
683716
inp
684717
for inp in next_node.inputs
685-
if next_node in unfuseable_clients_clone.get(inp, ())
718+
if next_node in unfuseable_clients_clone.get(inp)
686719
]
687720
new_required_unfuseable_inputs = [
688721
inp
@@ -705,7 +738,7 @@ def variables_depend_on(
705738
if not must_backtrack:
706739
implied_unfuseable_clients = {
707740
c
708-
for client in unfuseable_clients_clone.get(next_out, ())
741+
for client in unfuseable_clients_clone.get(next_out)
709742
if not isinstance(client.op, Output)
710743
for c in client.outputs
711744
}
@@ -726,13 +759,15 @@ def variables_depend_on(
726759

727760
if must_backtrack:
728761
for inp in next_node.inputs:
729-
if (
730-
inp.owner in visited_nodes
731-
# next_node could have the same input repeated
732-
and next_node in fuseable_clients_clone[inp]
733-
):
734-
fuseable_clients_clone[inp].remove(next_node)
735-
unfuseable_clients_clone[inp].add(next_node)
762+
if inp.owner in visited_nodes:
763+
if next_node not in fuseable_clients_clone[inp]:
764+
# This can happen when next node has repeated inputs
765+
continue
766+
fuseable_clients_clone.remove_from_key(
767+
inp, next_node
768+
)
769+
unfuseable_clients_clone.add_to_key(inp, next_node)
770+
736771
# This input must become an output of the subgraph,
737772
# because it can't be merged with next_node.
738773
# We will revisit it to make sure this is safe.
@@ -741,8 +776,13 @@ def variables_depend_on(
741776
# need to convert to tuple not to change set size during iteration
742777
for client in tuple(fuseable_clients_clone[next_out]):
743778
if client in visited_nodes:
744-
fuseable_clients_clone[next_out].remove(client)
745-
unfuseable_clients_clone[next_out].add(client)
779+
fuseable_clients_clone.remove_from_key(
780+
next_out, client
781+
)
782+
unfuseable_clients_clone.add_to_key(
783+
next_out, client
784+
)
785+
746786
# next_out must become an input of the subgraph.
747787
# We will revisit any of its clients currently
748788
# in the subgraph to make sure this is safe.
@@ -785,7 +825,7 @@ def variables_depend_on(
785825
sorted(
786826
(
787827
node
788-
for node in fuseable_clients_clone.get(next_out, ())
828+
for node in fuseable_clients_clone.get(next_out)
789829
if node not in visited_nodes
790830
),
791831
key=toposort_index.get, # type: ignore[arg-type]

0 commit comments

Comments
 (0)