Skip to content

Commit 514832e

Browse files
committed
Use bitset to check ancestors more efficiently
1 parent d73debf commit 514832e

File tree

2 files changed

+77
-76
lines changed

2 files changed

+77
-76
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 70 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections import defaultdict, deque
77
from collections.abc import Generator, Sequence
88
from functools import cache, reduce
9+
from operator import or_
910
from typing import Literal
1011
from warnings import warn
1112

@@ -29,7 +30,7 @@
2930
)
3031
from pytensor.graph.rewriting.db import SequenceDB
3132
from pytensor.graph.rewriting.unify import OpPattern
32-
from pytensor.graph.traversal import ancestors, toposort
33+
from pytensor.graph.traversal import toposort
3334
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
3435
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
3536
from pytensor.tensor.basic import (
@@ -663,16 +664,9 @@ def find_fuseable_subgraph(
663664
visited_nodes: set[Apply],
664665
fuseable_clients: FUSEABLE_MAPPING,
665666
unfuseable_clients: UNFUSEABLE_MAPPING,
667+
ancestors_bitset: dict[Apply, int],
666668
toposort_index: dict[Apply, int],
667669
) -> tuple[list[Variable], list[Variable]]:
668-
def variables_depend_on(
669-
variables, depend_on, stop_search_at=None
670-
) -> bool:
671-
return any(
672-
a in depend_on
673-
for a in ancestors(variables, blockers=stop_search_at)
674-
)
675-
676670
for starting_node in toposort_index:
677671
if starting_node in visited_nodes:
678672
continue
@@ -684,7 +678,8 @@ def variables_depend_on(
684678

685679
subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set
686680
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
687-
unfuseable_clients_subgraph: set[Variable] = set()
681+
subgraph_inputs_ancestors_bitset = 0
682+
unfuseable_clients_subgraph_bitset = 0
688683

689684
# If we need to manipulate the maps in place, we'll do a shallow copy later
690685
# For now we query on the original ones
@@ -716,50 +711,32 @@ def variables_depend_on(
716711
if must_become_output:
717712
subgraph_outputs.pop(next_out, None)
718713

719-
required_unfuseable_inputs = [
720-
inp
721-
for inp in next_node.inputs
722-
if next_node in unfuseable_clients_clone.get(inp)
723-
]
724-
new_required_unfuseable_inputs = [
725-
inp
726-
for inp in required_unfuseable_inputs
727-
if inp not in subgraph_inputs
728-
]
729-
730-
must_backtrack = False
731-
if new_required_unfuseable_inputs and subgraph_outputs:
732-
# We need to check that any new inputs required by this node
733-
# do not depend on other outputs of the current subgraph,
734-
# via an unfuseable path.
735-
if variables_depend_on(
736-
[next_out],
737-
depend_on=unfuseable_clients_subgraph,
738-
stop_search_at=subgraph_outputs,
739-
):
740-
must_backtrack = True
714+
# We need to check that any inputs required by this node
715+
# do not depend on other outputs of the current subgraph,
716+
# via an unfuseable path.
717+
must_backtrack = (
718+
ancestors_bitset[next_node]
719+
& unfuseable_clients_subgraph_bitset
720+
)
741721

742722
if not must_backtrack:
743-
implied_unfuseable_clients = {
744-
c
745-
for client in unfuseable_clients_clone.get(next_out)
746-
if not isinstance(client.op, Output)
747-
for c in client.outputs
748-
}
749-
750-
new_implied_unfuseable_clients = (
751-
implied_unfuseable_clients - unfuseable_clients_subgraph
723+
implied_unfuseable_clients_bitset = reduce(
724+
or_,
725+
(
726+
1 << toposort_index[client]
727+
for client in unfuseable_clients_clone.get(next_out)
728+
if not isinstance(client.op, Output)
729+
),
730+
0,
752731
)
753732

754-
if new_implied_unfuseable_clients and subgraph_inputs:
755-
# We need to check that any inputs of the current subgraph
756-
# do not depend on other clients of this node,
757-
# via an unfuseable path.
758-
if variables_depend_on(
759-
subgraph_inputs,
760-
depend_on=new_implied_unfuseable_clients,
761-
):
762-
must_backtrack = True
733+
# We need to check that any inputs of the current subgraph
734+
# do not depend on other clients of this node,
735+
# via an unfuseable path.
736+
must_backtrack = (
737+
subgraph_inputs_ancestors_bitset
738+
& implied_unfuseable_clients_bitset
739+
)
763740

764741
if must_backtrack:
765742
for inp in next_node.inputs:
@@ -800,29 +777,24 @@ def variables_depend_on(
800777
# immediate dependency problems. Update subgraph
801778
# mappings as if next_node were part of it.
802779
# Useless inputs will be removed by the useless Composite rewrite
803-
for inp in new_required_unfuseable_inputs:
804-
subgraph_inputs[inp] = None
805-
806780
if must_become_output:
807781
subgraph_outputs[next_out] = None
808-
unfuseable_clients_subgraph.update(
809-
new_implied_unfuseable_clients
782+
unfuseable_clients_subgraph_bitset |= (
783+
implied_unfuseable_clients_bitset
810784
)
811785

812-
# Expand through unvisited fuseable ancestors
813-
fuseable_nodes_to_visit.extendleft(
814-
sorted(
815-
(
816-
inp.owner
817-
for inp in next_node.inputs
818-
if (
819-
inp not in required_unfuseable_inputs
820-
and inp.owner not in visited_nodes
821-
)
822-
),
823-
key=toposort_index.get, # type: ignore[arg-type]
824-
)
825-
)
786+
for inp in sorted(
787+
next_node.inputs,
788+
key=lambda x: toposort_index.get(x.owner, -1),
789+
):
790+
if next_node in unfuseable_clients_clone.get(inp, ()):
791+
# input must become an input of the subgraph since it's unfuseable with new node
792+
subgraph_inputs_ancestors_bitset |= (
793+
ancestors_bitset.get(inp.owner, 0)
794+
)
795+
subgraph_inputs[inp] = None
796+
elif inp.owner not in visited_nodes:
797+
fuseable_nodes_to_visit.appendleft(inp.owner)
826798

827799
# Expand through unvisited fuseable clients
828800
fuseable_nodes_to_visit.extend(
@@ -859,6 +831,8 @@ def update_fuseable_mappings_after_fg_replace(
859831
visited_nodes: set[Apply],
860832
fuseable_clients: FUSEABLE_MAPPING,
861833
unfuseable_clients: UNFUSEABLE_MAPPING,
834+
toposort_index: dict[Apply, int],
835+
ancestors_bitset: dict[Apply, int],
862836
starting_nodes: set[Apply],
863837
updated_nodes: set[Apply],
864838
) -> None:
@@ -869,11 +843,25 @@ def update_fuseable_mappings_after_fg_replace(
869843
dropped_nodes = starting_nodes - updated_nodes
870844

871845
# Remove intermediate Composite nodes from mappings
846+
# And compute the ancestors bitset of the new composite node
847+
# As well as the new toposort index for the new node
848+
new_node_ancestor_bitset = 0
849+
new_node_toposort_index = len(toposort_index)
872850
for dropped_node in dropped_nodes:
873851
(dropped_out,) = dropped_node.outputs
874852
fuseable_clients.pop(dropped_out, None)
875853
unfuseable_clients.pop(dropped_out, None)
876854
visited_nodes.remove(dropped_node)
855+
# The new composite ancestor bitset is the union
856+
# of the ancestors of all the dropped nodes
857+
new_node_ancestor_bitset |= ancestors_bitset[dropped_node]
858+
# The new composite node can have the same order as the latest node that was absorbed into it
859+
new_node_toposort_index = max(
860+
new_node_toposort_index, toposort_index[dropped_node]
861+
)
862+
863+
ancestors_bitset[new_composite_node] = new_node_ancestor_bitset
864+
toposort_index[new_composite_node] = new_node_toposort_index
877865

878866
# Update fuseable information for subgraph inputs
879867
for inp in subgraph_inputs:
@@ -905,12 +893,23 @@ def update_fuseable_mappings_after_fg_replace(
905893
fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
906894
visited_nodes: set[Apply] = set()
907895
toposort_index = {node: i for i, node in enumerate(fgraph.toposort())}
896+
# Create a bitset for each node of all its ancestors
897+
# This allows to quickly check if a variable depends on a set
898+
ancestors_bitset: dict[Apply, int] = {}
899+
for node, index in toposort_index.items():
900+
node_ancestor_bitset = 1 << index
901+
for inp in node.inputs:
902+
if (inp_node := inp.owner) is not None:
903+
node_ancestor_bitset |= ancestors_bitset[inp_node]
904+
ancestors_bitset[node] = node_ancestor_bitset
905+
908906
while True:
909907
try:
910908
subgraph_inputs, subgraph_outputs = find_fuseable_subgraph(
911909
visited_nodes=visited_nodes,
912910
fuseable_clients=fuseable_clients,
913911
unfuseable_clients=unfuseable_clients,
912+
ancestors_bitset=ancestors_bitset,
914913
toposort_index=toposort_index,
915914
)
916915
except ValueError:
@@ -929,6 +928,8 @@ def update_fuseable_mappings_after_fg_replace(
929928
visited_nodes=visited_nodes,
930929
fuseable_clients=fuseable_clients,
931930
unfuseable_clients=unfuseable_clients,
931+
toposort_index=toposort_index,
932+
ancestors_bitset=ancestors_bitset,
932933
starting_nodes=starting_nodes,
933934
updated_nodes=fg.apply_nodes,
934935
)

tests/test_printing.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,8 @@ def test_debugprint():
301301
Gemv_op_name = "CGemv" if pytensor.config.blas__ldflags else "Gemv"
302302
exp_res = dedent(
303303
r"""
304-
Composite{(i2 + (i0 - i1))} 4
304+
Composite{(i0 + (i1 - i2))} 4
305+
├─ A
305306
├─ ExpandDims{axis=0} v={0: [0]} 3
306307
"""
307308
f" │ └─ {Gemv_op_name}{{inplace}} d={{0: [0]}} 2"
@@ -313,17 +314,16 @@ def test_debugprint():
313314
│ ├─ B
314315
│ ├─ <Vector(float64, shape=(?,))>
315316
│ └─ 0.0
316-
├─ D
317-
└─ A
317+
└─ D
318318
319319
Inner graphs:
320320
321-
Composite{(i2 + (i0 - i1))}
321+
Composite{(i0 + (i1 - i2))}
322322
← add 'o0'
323-
├─ i2
324-
└─ sub
325323
├─ i0
326-
└─ i1
324+
└─ sub
325+
├─ i1
326+
└─ i2
327327
"""
328328
).lstrip()
329329

0 commit comments

Comments
 (0)