Skip to content

Commit e2384e8

Browse files
committed
Use bitset to check ancestors more efficiently
1 parent 4a24466 commit e2384e8

File tree

2 files changed

+77
-76
lines changed

2 files changed

+77
-76
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 70 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections import defaultdict, deque
77
from collections.abc import Generator, Sequence
88
from functools import cache, reduce
9+
from operator import or_
910
from typing import Literal
1011
from warnings import warn
1112

@@ -15,7 +16,7 @@
1516
from pytensor.compile.mode import get_target_language
1617
from pytensor.configdefaults import config
1718
from pytensor.graph import FunctionGraph, Op
18-
from pytensor.graph.basic import Apply, Variable, ancestors, io_toposort
19+
from pytensor.graph.basic import Apply, Variable, io_toposort
1920
from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
2021
from pytensor.graph.features import ReplaceValidate
2122
from pytensor.graph.fg import Output
@@ -661,16 +662,9 @@ def find_fuseable_subgraph(
661662
visited_nodes: set[Apply],
662663
fuseable_clients: FUSEABLE_MAPPING,
663664
unfuseable_clients: UNFUSEABLE_MAPPING,
665+
ancestors_bitset: dict[Apply, int],
664666
toposort_index: dict[Apply, int],
665667
) -> tuple[list[Variable], list[Variable]]:
666-
def variables_depend_on(
667-
variables, depend_on, stop_search_at=None
668-
) -> bool:
669-
return any(
670-
a in depend_on
671-
for a in ancestors(variables, blockers=stop_search_at)
672-
)
673-
674668
for starting_node in toposort_index:
675669
if starting_node in visited_nodes:
676670
continue
@@ -682,7 +676,8 @@ def variables_depend_on(
682676

683677
subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set
684678
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
685-
unfuseable_clients_subgraph: set[Variable] = set()
679+
subgraph_inputs_ancestors_bitset = 0
680+
unfuseable_clients_subgraph_bitset = 0
686681

687682
# If we need to manipulate the maps in place, we'll do a shallow copy later
688683
# For now we query on the original ones
@@ -714,50 +709,32 @@ def variables_depend_on(
714709
if must_become_output:
715710
subgraph_outputs.pop(next_out, None)
716711

717-
required_unfuseable_inputs = [
718-
inp
719-
for inp in next_node.inputs
720-
if next_node in unfuseable_clients_clone.get(inp)
721-
]
722-
new_required_unfuseable_inputs = [
723-
inp
724-
for inp in required_unfuseable_inputs
725-
if inp not in subgraph_inputs
726-
]
727-
728-
must_backtrack = False
729-
if new_required_unfuseable_inputs and subgraph_outputs:
730-
# We need to check that any new inputs required by this node
731-
# do not depend on other outputs of the current subgraph,
732-
# via an unfuseable path.
733-
if variables_depend_on(
734-
[next_out],
735-
depend_on=unfuseable_clients_subgraph,
736-
stop_search_at=subgraph_outputs,
737-
):
738-
must_backtrack = True
712+
# We need to check that any inputs required by this node
713+
# do not depend on other outputs of the current subgraph,
714+
# via an unfuseable path.
715+
must_backtrack = (
716+
ancestors_bitset[next_node]
717+
& unfuseable_clients_subgraph_bitset
718+
)
739719

740720
if not must_backtrack:
741-
implied_unfuseable_clients = {
742-
c
743-
for client in unfuseable_clients_clone.get(next_out)
744-
if not isinstance(client.op, Output)
745-
for c in client.outputs
746-
}
747-
748-
new_implied_unfuseable_clients = (
749-
implied_unfuseable_clients - unfuseable_clients_subgraph
721+
implied_unfuseable_clients_bitset = reduce(
722+
or_,
723+
(
724+
1 << toposort_index[client]
725+
for client in unfuseable_clients_clone.get(next_out)
726+
if not isinstance(client.op, Output)
727+
),
728+
0,
750729
)
751730

752-
if new_implied_unfuseable_clients and subgraph_inputs:
753-
# We need to check that any inputs of the current subgraph
754-
# do not depend on other clients of this node,
755-
# via an unfuseable path.
756-
if variables_depend_on(
757-
subgraph_inputs,
758-
depend_on=new_implied_unfuseable_clients,
759-
):
760-
must_backtrack = True
731+
# We need to check that any inputs of the current subgraph
732+
# do not depend on other clients of this node,
733+
# via an unfuseable path.
734+
must_backtrack = (
735+
subgraph_inputs_ancestors_bitset
736+
& implied_unfuseable_clients_bitset
737+
)
761738

762739
if must_backtrack:
763740
for inp in next_node.inputs:
@@ -798,29 +775,24 @@ def variables_depend_on(
798775
# immediate dependency problems. Update subgraph
799776
# mappings as if it next_node was part of it.
800777
# Useless inputs will be removed by the useless Composite rewrite
801-
for inp in new_required_unfuseable_inputs:
802-
subgraph_inputs[inp] = None
803-
804778
if must_become_output:
805779
subgraph_outputs[next_out] = None
806-
unfuseable_clients_subgraph.update(
807-
new_implied_unfuseable_clients
780+
unfuseable_clients_subgraph_bitset |= (
781+
implied_unfuseable_clients_bitset
808782
)
809783

810-
# Expand through unvisited fuseable ancestors
811-
fuseable_nodes_to_visit.extendleft(
812-
sorted(
813-
(
814-
inp.owner
815-
for inp in next_node.inputs
816-
if (
817-
inp not in required_unfuseable_inputs
818-
and inp.owner not in visited_nodes
819-
)
820-
),
821-
key=toposort_index.get, # type: ignore[arg-type]
822-
)
823-
)
784+
for inp in sorted(
785+
next_node.inputs,
786+
key=lambda x: toposort_index.get(x.owner, -1),
787+
):
788+
if next_node in unfuseable_clients_clone.get(inp, ()):
789+
# input must become an input of the subgraph since it's unfuseable with new node
790+
subgraph_inputs_ancestors_bitset |= (
791+
ancestors_bitset.get(inp.owner, 0)
792+
)
793+
subgraph_inputs[inp] = None
794+
elif inp.owner not in visited_nodes:
795+
fuseable_nodes_to_visit.appendleft(inp.owner)
824796

825797
# Expand through unvisited fuseable clients
826798
fuseable_nodes_to_visit.extend(
@@ -857,6 +829,8 @@ def update_fuseable_mappings_after_fg_replace(
857829
visited_nodes: set[Apply],
858830
fuseable_clients: FUSEABLE_MAPPING,
859831
unfuseable_clients: UNFUSEABLE_MAPPING,
832+
toposort_index: dict[Apply, int],
833+
ancestors_bitset: dict[Apply, int],
860834
starting_nodes: set[Apply],
861835
updated_nodes: set[Apply],
862836
) -> None:
@@ -867,11 +841,25 @@ def update_fuseable_mappings_after_fg_replace(
867841
dropped_nodes = starting_nodes - updated_nodes
868842

869843
# Remove intermediate Composite nodes from mappings
844+
# And compute the ancestors bitset of the new composite node
845+
# As well as the new toposort index for the new node
846+
new_node_ancestor_bitset = 0
847+
new_node_toposort_index = len(toposort_index)
870848
for dropped_node in dropped_nodes:
871849
(dropped_out,) = dropped_node.outputs
872850
fuseable_clients.pop(dropped_out, None)
873851
unfuseable_clients.pop(dropped_out, None)
874852
visited_nodes.remove(dropped_node)
853+
# The new composite ancestor bitset is the union
854+
# of the ancestors of all the dropped nodes
855+
new_node_ancestor_bitset |= ancestors_bitset[dropped_node]
856+
# The new composite node can have the same order as the latest node that was absorbed into it
857+
new_node_toposort_index = max(
858+
new_node_toposort_index, toposort_index[dropped_node]
859+
)
860+
861+
ancestors_bitset[new_composite_node] = new_node_ancestor_bitset
862+
toposort_index[new_composite_node] = new_node_toposort_index
875863

876864
# Update fuseable information for subgraph inputs
877865
for inp in subgraph_inputs:
@@ -903,12 +891,23 @@ def update_fuseable_mappings_after_fg_replace(
903891
fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
904892
visited_nodes: set[Apply] = set()
905893
toposort_index = {node: i for i, node in enumerate(fgraph.toposort())}
894+
# Create a bitset for each node of all its ancestors
895+
# This allows to quickly check if a variable depends on a set
896+
ancestors_bitset: dict[Apply, int] = {}
897+
for node, index in toposort_index.items():
898+
node_ancestor_bitset = 1 << index
899+
for inp in node.inputs:
900+
if (inp_node := inp.owner) is not None:
901+
node_ancestor_bitset |= ancestors_bitset[inp_node]
902+
ancestors_bitset[node] = node_ancestor_bitset
903+
906904
while True:
907905
try:
908906
subgraph_inputs, subgraph_outputs = find_fuseable_subgraph(
909907
visited_nodes=visited_nodes,
910908
fuseable_clients=fuseable_clients,
911909
unfuseable_clients=unfuseable_clients,
910+
ancestors_bitset=ancestors_bitset,
912911
toposort_index=toposort_index,
913912
)
914913
except ValueError:
@@ -927,6 +926,8 @@ def update_fuseable_mappings_after_fg_replace(
927926
visited_nodes=visited_nodes,
928927
fuseable_clients=fuseable_clients,
929928
unfuseable_clients=unfuseable_clients,
929+
toposort_index=toposort_index,
930+
ancestors_bitset=ancestors_bitset,
930931
starting_nodes=starting_nodes,
931932
updated_nodes=fg.apply_nodes,
932933
)

tests/test_printing.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,8 @@ def test_debugprint():
301301
Gemv_op_name = "CGemv" if pytensor.config.blas__ldflags else "Gemv"
302302
exp_res = dedent(
303303
r"""
304-
Composite{(i2 + (i0 - i1))} 4
304+
Composite{(i0 + (i1 - i2))} 4
305+
├─ A
305306
├─ ExpandDims{axis=0} v={0: [0]} 3
306307
"""
307308
f" │ └─ {Gemv_op_name}{{inplace}} d={{0: [0]}} 2"
@@ -313,17 +314,16 @@ def test_debugprint():
313314
│ ├─ B
314315
│ ├─ <Vector(float64, shape=(?,))>
315316
│ └─ 0.0
316-
├─ D
317-
└─ A
317+
└─ D
318318
319319
Inner graphs:
320320
321-
Composite{(i2 + (i0 - i1))}
321+
Composite{(i0 + (i1 - i2))}
322322
← add 'o0'
323-
├─ i2
324-
└─ sub
325323
├─ i0
326-
└─ i1
324+
└─ sub
325+
├─ i1
326+
└─ i2
327327
"""
328328
).lstrip()
329329

0 commit comments

Comments
 (0)