Skip to content

Commit 9b90343

Browse files
committed
Use bitset to check ancestors more efficiently
1 parent f768c33 commit 9b90343

File tree

2 files changed

+77
-76
lines changed

2 files changed

+77
-76
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 70 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections import defaultdict, deque
77
from collections.abc import Generator, Sequence
88
from functools import cache, reduce
9+
from operator import or_
910
from typing import Literal
1011
from warnings import warn
1112

@@ -29,7 +30,7 @@
2930
)
3031
from pytensor.graph.rewriting.db import SequenceDB
3132
from pytensor.graph.rewriting.unify import OpPattern
32-
from pytensor.graph.traversal import ancestors, toposort
33+
from pytensor.graph.traversal import toposort
3334
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
3435
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
3536
from pytensor.tensor.basic import (
@@ -659,16 +660,9 @@ def find_fuseable_subgraph(
659660
visited_nodes: set[Apply],
660661
fuseable_clients: FUSEABLE_MAPPING,
661662
unfuseable_clients: UNFUSEABLE_MAPPING,
663+
ancestors_bitset: dict[Apply, int],
662664
toposort_index: dict[Apply, int],
663665
) -> tuple[list[Variable], list[Variable]]:
664-
def variables_depend_on(
665-
variables, depend_on, stop_search_at=None
666-
) -> bool:
667-
return any(
668-
a in depend_on
669-
for a in ancestors(variables, blockers=stop_search_at)
670-
)
671-
672666
for starting_node in toposort_index:
673667
if starting_node in visited_nodes:
674668
continue
@@ -680,7 +674,8 @@ def variables_depend_on(
680674

681675
subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set
682676
subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set
683-
unfuseable_clients_subgraph: set[Variable] = set()
677+
subgraph_inputs_ancestors_bitset = 0
678+
unfuseable_clients_subgraph_bitset = 0
684679

685680
# If we need to manipulate the maps in place, we'll do a shallow copy later
686681
# For now we query on the original ones
@@ -712,50 +707,32 @@ def variables_depend_on(
712707
if must_become_output:
713708
subgraph_outputs.pop(next_out, None)
714709

715-
required_unfuseable_inputs = [
716-
inp
717-
for inp in next_node.inputs
718-
if next_node in unfuseable_clients_clone.get(inp)
719-
]
720-
new_required_unfuseable_inputs = [
721-
inp
722-
for inp in required_unfuseable_inputs
723-
if inp not in subgraph_inputs
724-
]
725-
726-
must_backtrack = False
727-
if new_required_unfuseable_inputs and subgraph_outputs:
728-
# We need to check that any new inputs required by this node
729-
# do not depend on other outputs of the current subgraph,
730-
# via an unfuseable path.
731-
if variables_depend_on(
732-
[next_out],
733-
depend_on=unfuseable_clients_subgraph,
734-
stop_search_at=subgraph_outputs,
735-
):
736-
must_backtrack = True
710+
# We need to check that any inputs required by this node
711+
# do not depend on other outputs of the current subgraph,
712+
# via an unfuseable path.
713+
must_backtrack = (
714+
ancestors_bitset[next_node]
715+
& unfuseable_clients_subgraph_bitset
716+
)
737717

738718
if not must_backtrack:
739-
implied_unfuseable_clients = {
740-
c
741-
for client in unfuseable_clients_clone.get(next_out)
742-
if not isinstance(client.op, Output)
743-
for c in client.outputs
744-
}
745-
746-
new_implied_unfuseable_clients = (
747-
implied_unfuseable_clients - unfuseable_clients_subgraph
719+
implied_unfuseable_clients_bitset = reduce(
720+
or_,
721+
(
722+
1 << toposort_index[client]
723+
for client in unfuseable_clients_clone.get(next_out)
724+
if not isinstance(client.op, Output)
725+
),
726+
0,
748727
)
749728

750-
if new_implied_unfuseable_clients and subgraph_inputs:
751-
# We need to check that any inputs of the current subgraph
752-
# do not depend on other clients of this node,
753-
# via an unfuseable path.
754-
if variables_depend_on(
755-
subgraph_inputs,
756-
depend_on=new_implied_unfuseable_clients,
757-
):
758-
must_backtrack = True
729+
# We need to check that any inputs of the current subgraph
730+
# do not depend on other clients of this node,
731+
# via an unfuseable path.
732+
must_backtrack = (
733+
subgraph_inputs_ancestors_bitset
734+
& implied_unfuseable_clients_bitset
735+
)
759736

760737
if must_backtrack:
761738
for inp in next_node.inputs:
@@ -796,29 +773,24 @@ def variables_depend_on(
796773
# immediate dependency problems. Update subgraph
797774
# mappings as if it next_node was part of it.
798775
# Useless inputs will be removed by the useless Composite rewrite
799-
for inp in new_required_unfuseable_inputs:
800-
subgraph_inputs[inp] = None
801-
802776
if must_become_output:
803777
subgraph_outputs[next_out] = None
804-
unfuseable_clients_subgraph.update(
805-
new_implied_unfuseable_clients
778+
unfuseable_clients_subgraph_bitset |= (
779+
implied_unfuseable_clients_bitset
806780
)
807781

808-
# Expand through unvisited fuseable ancestors
809-
fuseable_nodes_to_visit.extendleft(
810-
sorted(
811-
(
812-
inp.owner
813-
for inp in next_node.inputs
814-
if (
815-
inp not in required_unfuseable_inputs
816-
and inp.owner not in visited_nodes
817-
)
818-
),
819-
key=toposort_index.get, # type: ignore[arg-type]
820-
)
821-
)
782+
for inp in sorted(
783+
next_node.inputs,
784+
key=lambda x: toposort_index.get(x.owner, -1),
785+
):
786+
if next_node in unfuseable_clients_clone.get(inp, ()):
787+
# input must become an input of the subgraph since it's unfuseable with new node
788+
subgraph_inputs_ancestors_bitset |= (
789+
ancestors_bitset.get(inp.owner, 0)
790+
)
791+
subgraph_inputs[inp] = None
792+
elif inp.owner not in visited_nodes:
793+
fuseable_nodes_to_visit.appendleft(inp.owner)
822794

823795
# Expand through unvisited fuseable clients
824796
fuseable_nodes_to_visit.extend(
@@ -855,6 +827,8 @@ def update_fuseable_mappings_after_fg_replace(
855827
visited_nodes: set[Apply],
856828
fuseable_clients: FUSEABLE_MAPPING,
857829
unfuseable_clients: UNFUSEABLE_MAPPING,
830+
toposort_index: dict[Apply, int],
831+
ancestors_bitset: dict[Apply, int],
858832
starting_nodes: set[Apply],
859833
updated_nodes: set[Apply],
860834
) -> None:
@@ -865,11 +839,25 @@ def update_fuseable_mappings_after_fg_replace(
865839
dropped_nodes = starting_nodes - updated_nodes
866840

867841
# Remove intermediate Composite nodes from mappings
842+
# And compute the ancestors bitset of the new composite node
843+
# As well as the new toposort index for the new node
844+
new_node_ancestor_bitset = 0
845+
new_node_toposort_index = len(toposort_index)
868846
for dropped_node in dropped_nodes:
869847
(dropped_out,) = dropped_node.outputs
870848
fuseable_clients.pop(dropped_out, None)
871849
unfuseable_clients.pop(dropped_out, None)
872850
visited_nodes.remove(dropped_node)
851+
# The new composite ancestor bitset is the union
852+
# of the ancestors of all the dropped nodes
853+
new_node_ancestor_bitset |= ancestors_bitset[dropped_node]
854+
# The new composite node can have the same order as the latest node that was absorbed into it
855+
new_node_toposort_index = max(
856+
new_node_toposort_index, toposort_index[dropped_node]
857+
)
858+
859+
ancestors_bitset[new_composite_node] = new_node_ancestor_bitset
860+
toposort_index[new_composite_node] = new_node_toposort_index
873861

874862
# Update fuseable information for subgraph inputs
875863
for inp in subgraph_inputs:
@@ -901,12 +889,23 @@ def update_fuseable_mappings_after_fg_replace(
901889
fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
902890
visited_nodes: set[Apply] = set()
903891
toposort_index = {node: i for i, node in enumerate(fgraph.toposort())}
892+
# Create a bitset for each node of all its ancestors
893+
# This allows to quickly check if a variable depends on a set
894+
ancestors_bitset: dict[Apply, int] = {}
895+
for node, index in toposort_index.items():
896+
node_ancestor_bitset = 1 << index
897+
for inp in node.inputs:
898+
if (inp_node := inp.owner) is not None:
899+
node_ancestor_bitset |= ancestors_bitset[inp_node]
900+
ancestors_bitset[node] = node_ancestor_bitset
901+
904902
while True:
905903
try:
906904
subgraph_inputs, subgraph_outputs = find_fuseable_subgraph(
907905
visited_nodes=visited_nodes,
908906
fuseable_clients=fuseable_clients,
909907
unfuseable_clients=unfuseable_clients,
908+
ancestors_bitset=ancestors_bitset,
910909
toposort_index=toposort_index,
911910
)
912911
except ValueError:
@@ -925,6 +924,8 @@ def update_fuseable_mappings_after_fg_replace(
925924
visited_nodes=visited_nodes,
926925
fuseable_clients=fuseable_clients,
927926
unfuseable_clients=unfuseable_clients,
927+
toposort_index=toposort_index,
928+
ancestors_bitset=ancestors_bitset,
928929
starting_nodes=starting_nodes,
929930
updated_nodes=fg.apply_nodes,
930931
)

tests/test_printing.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,8 @@ def test_debugprint():
301301
Gemv_op_name = "CGemv" if pytensor.config.blas__ldflags else "Gemv"
302302
exp_res = dedent(
303303
r"""
304-
Composite{(i2 + (i0 - i1))} 4
304+
Composite{(i0 + (i1 - i2))} 4
305+
├─ A
305306
├─ ExpandDims{axis=0} v={0: [0]} 3
306307
"""
307308
f" │ └─ {Gemv_op_name}{{inplace}} d={{0: [0]}} 2"
@@ -313,17 +314,16 @@ def test_debugprint():
313314
│ ├─ B
314315
│ ├─ <Vector(float64, shape=(?,))>
315316
│ └─ 0.0
316-
├─ D
317-
└─ A
317+
└─ D
318318
319319
Inner graphs:
320320
321-
Composite{(i2 + (i0 - i1))}
321+
Composite{(i0 + (i1 - i2))}
322322
← add 'o0'
323-
├─ i2
324-
└─ sub
325323
├─ i0
326-
└─ i1
324+
└─ sub
325+
├─ i1
326+
└─ i2
327327
"""
328328
).lstrip()
329329

0 commit comments

Comments
 (0)