Skip to content

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Sep 20, 2025

FusionOptimizer can be one of the slower rewrites during compilation. This PR speedups it up by a factor of 4-3x in the benchmarked graphs.

Each commit provides a substantial speedup (except maybe for 2-> 3).

The main speedups come from:

  1. Reducing number of inner graph clonings when creating CompositeOp
  2. Reducing number of toposort computation
  3. Using bitsets and bitflags to efficiently compute multiset ancestor dependencies (to ask: do these variables depend on these others?)

The logic for finding valid fused kernels is also more clear now imo, avoiding the need for backtracking.

Benchmark per commit
HEAD is now at 774792356 Benchmark another FusionOptimizer graph
----------------------------------------------------------------------------------------------------------- benchmark: 2 tests -----------------------------------------------------------------------------------------------------------
Name (time in ms)                                                         Min                 Max                Mean             StdDev              Median                IQR            Outliers      OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rewrite_benchmark[deep_small_kernels-20-expected_n_repl0]        23.0197 (1.0)       54.2629 (1.0)       36.1126 (1.0)      12.1105 (1.99)      41.8545 (1.0)      18.0821 (2.25)          2;0  27.6912 (1.0)           7           5
test_rewrite_benchmark[large_fuseable_graph-25-expected_n_repl1]     485.6301 (21.10)    503.2956 (9.28)     496.1585 (13.74)     6.0741 (1.0)      496.9919 (11.87)     8.0230 (1.0)           2;0   2.0155 (0.07)          7           5
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

HEAD is now at c7b287f39 Short-circuit `as_scalar` common cases faster
---------------------------------------------------------------------------------------------------------- benchmark: 2 tests ----------------------------------------------------------------------------------------------------------
Name (time in ms)                                                         Min                 Max                Mean            StdDev              Median               IQR            Outliers      OPS            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rewrite_benchmark[deep_small_kernels-20-expected_n_repl0]        23.2032 (1.0)       30.3229 (1.0)       27.0646 (1.0)      3.2161 (1.0)       28.9586 (1.0)      5.9681 (1.0)           3;0  36.9486 (1.0)           7           5
test_rewrite_benchmark[large_fuseable_graph-25-expected_n_repl1]     487.7418 (21.02)    499.8394 (16.48)    494.1633 (18.26)    4.6597 (1.45)     494.5370 (17.08)    8.1017 (1.36)          3;0   2.0236 (0.05)          7           5
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

HEAD is now at 598d9fcb9 Speedup supports c_code
---------------------------------------------------------------------------------------------------------- benchmark: 2 tests ----------------------------------------------------------------------------------------------------------
Name (time in ms)                                                         Min                 Max                Mean            StdDev              Median               IQR            Outliers      OPS            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rewrite_benchmark[deep_small_kernels-20-expected_n_repl0]        22.3723 (1.0)       30.8797 (1.0)       26.4816 (1.0)      3.6947 (1.0)       28.3610 (1.0)      6.6708 (1.0)           3;0  37.7620 (1.0)           7           5
test_rewrite_benchmark[large_fuseable_graph-25-expected_n_repl1]     489.3477 (21.87)    504.9612 (16.35)    495.2749 (18.70)    6.1049 (1.65)     492.2781 (17.36)    9.8775 (1.48)          1;0   2.0191 (0.05)          7           5
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

HEAD is now at 41707795b Speedup FusionOptimizer.elemwise_to_scalar
----------------------------------------------------------------------------------------------------------- benchmark: 2 tests ----------------------------------------------------------------------------------------------------------
Name (time in ms)                                                         Min                 Max                Mean             StdDev              Median               IQR            Outliers      OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rewrite_benchmark[deep_small_kernels-20-expected_n_repl0]        19.5081 (1.0)       27.7007 (1.0)       21.9665 (1.0)       3.4888 (1.0)       20.2496 (1.0)      5.1938 (1.0)           2;0  45.5238 (1.0)           7           5
test_rewrite_benchmark[large_fuseable_graph-25-expected_n_repl1]     419.0247 (21.48)    617.6116 (22.30)    461.7518 (21.02)    69.2312 (19.84)    439.2438 (21.69)    9.9747 (1.92)          1;2   2.1657 (0.05)          7           5
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

HEAD is now at 34de75ff5 Avoid double cloning of Composite Ops created by FusionOptimizer
----------------------------------------------------------------------------------------------------------- benchmark: 2 tests -----------------------------------------------------------------------------------------------------------
Name (time in ms)                                                         Min                 Max                Mean             StdDev              Median                IQR            Outliers      OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rewrite_benchmark[deep_small_kernels-20-expected_n_repl0]        17.0792 (1.0)       24.1214 (1.0)       19.3538 (1.0)       3.1298 (1.0)       17.6328 (1.0)       5.0325 (1.0)           2;0  51.6693 (1.0)           7           5
test_rewrite_benchmark[large_fuseable_graph-25-expected_n_repl1]     381.0516 (22.31)    455.6597 (18.89)    407.9326 (21.08)    27.3474 (8.74)     398.1366 (22.58)    37.5527 (7.46)          1;0   2.4514 (0.05)          7           5
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

HEAD is now at 543a8a23a Do not recompute toposort in every iteration of FusionOptimizer
----------------------------------------------------------------------------------------------------------- benchmark: 2 tests -----------------------------------------------------------------------------------------------------------
Name (time in ms)                                                         Min                 Max                Mean             StdDev              Median                IQR            Outliers      OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rewrite_benchmark[deep_small_kernels-20-expected_n_repl0]        14.3812 (1.0)       21.8006 (1.0)       16.3498 (1.0)       3.1517 (1.0)       14.5783 (1.0)       4.2066 (1.0)           2;0  61.1627 (1.0)           7           5
test_rewrite_benchmark[large_fuseable_graph-25-expected_n_repl1]     244.2117 (16.98)    279.3751 (12.82)    261.6797 (16.01)    11.5636 (3.67)     264.0385 (18.11)    14.4312 (3.43)          2;0   3.8215 (0.06)          7           5
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

HEAD is now at dd75569a1 Cleanup FusionOptimizer code
----------------------------------------------------------------------------------------------------------- benchmark: 2 tests -----------------------------------------------------------------------------------------------------------
Name (time in ms)                                                         Min                 Max                Mean             StdDev              Median                IQR            Outliers      OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rewrite_benchmark[deep_small_kernels-20-expected_n_repl0]        13.8777 (1.0)       21.8236 (1.0)       15.9468 (1.0)       3.5019 (1.0)       13.9371 (1.0)       4.7736 (1.0)           2;0  62.7085 (1.0)           7           5
test_rewrite_benchmark[large_fuseable_graph-25-expected_n_repl1]     243.7067 (17.56)    276.4169 (12.67)    257.8971 (16.17)    12.8118 (3.66)     256.3974 (18.40)    22.7781 (4.77)          3;0   3.8775 (0.06)          7           5
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

HEAD is now at 68ca3cf36 Copy on write in FusionOptimizer
----------------------------------------------------------------------------------------------------------- benchmark: 2 tests -----------------------------------------------------------------------------------------------------------
Name (time in ms)                                                         Min                 Max                Mean             StdDev              Median                IQR            Outliers      OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rewrite_benchmark[deep_small_kernels-20-expected_n_repl0]        14.1523 (1.0)       20.7984 (1.0)       16.0264 (1.0)       2.8238 (1.0)       14.4556 (1.0)       3.9848 (1.0)           2;0  62.3970 (1.0)           7           5
test_rewrite_benchmark[large_fuseable_graph-25-expected_n_repl1]     195.4694 (13.81)    264.7950 (12.73)    221.6116 (13.83)    22.3072 (7.90)     213.4827 (14.77)    19.7748 (4.96)          2;1   4.5124 (0.07)          7           5
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

HEAD is now at bb3e54c57 Use bitset to check ancestors more efficiently
------------------------------------------------------------------------------------------------------------ benchmark: 2 tests -----------------------------------------------------------------------------------------------------------
Name (time in ms)                                                         Min                 Max                Mean             StdDev              Median                IQR            Outliers       OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rewrite_benchmark[deep_small_kernels-20-expected_n_repl0]         8.1186 (1.0)       14.0557 (1.0)        9.9162 (1.0)       2.6646 (1.0)        8.5550 (1.0)       4.1075 (1.0)           2;0  100.8451 (1.0)           7           5
test_rewrite_benchmark[large_fuseable_graph-25-expected_n_repl1]     152.0318 (18.73)    176.6330 (12.57)    163.5286 (16.49)    10.0701 (3.78)     165.0426 (19.29)    18.2174 (4.44)          2;0    6.1151 (0.06)          7           5
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

HEAD is now at c392faec9 Avoid backtracking in FusionOptimizer
----------------------------------------------------------------------------------------------------------- benchmark: 2 tests ----------------------------------------------------------------------------------------------------------
Name (time in ms)                                                         Min                 Max                Mean            StdDev              Median               IQR            Outliers       OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rewrite_benchmark[deep_small_kernels-20-expected_n_repl0]         6.9413 (1.0)       14.3935 (1.0)        8.9253 (1.0)      3.2081 (1.0)        7.1780 (1.0)      4.3681 (1.0)           2;0  112.0412 (1.0)           7           5
test_rewrite_benchmark[large_fuseable_graph-25-expected_n_repl1]     140.8090 (20.29)    155.4400 (10.80)    149.6261 (16.76)    5.7609 (1.80)     151.9854 (21.17)    9.8091 (2.25)          3;0    6.6833 (0.06)          7           5
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

@ricardoV94 ricardoV94 force-pushed the faster_fusion_optimizer_based_on_edges branch 4 times, most recently from 19748ac to 97cef0b Compare September 20, 2025 11:10
Copy link

codecov bot commented Sep 20, 2025

Codecov Report

❌ Patch coverage is 97.04142% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.64%. Comparing base (96122d1) to head (e5e58b2).
⚠️ Report is 11 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/rewriting/elemwise.py 97.14% 2 Missing and 2 partials ⚠️
pytensor/scalar/basic.py 96.55% 1 Missing ⚠️

❌ Your patch status has failed because the patch coverage (97.04%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1615   +/-   ##
=======================================
  Coverage   81.64%   81.64%           
=======================================
  Files         231      231           
  Lines       52997    52952   -45     
  Branches     9395     9388    -7     
=======================================
- Hits        43267    43235   -32     
+ Misses       7282     7273    -9     
+ Partials     2448     2444    -4     
Files with missing lines Coverage Δ
pytensor/scalar/basic.py 80.57% <96.55%> (-0.01%) ⬇️
pytensor/tensor/rewriting/elemwise.py 93.62% <97.14%> (+1.01%) ⬆️

... and 1 file with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ricardoV94 ricardoV94 marked this pull request as ready for review September 20, 2025 14:55
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR optimizes the FusionOptimizer to improve compilation performance by a factor of 3-4x on benchmarked graphs. The optimization focuses on reducing redundant operations and using more efficient data structures for graph analysis.

  • Implemented bitset-based ancestor dependency tracking for faster subgraph convexity checks
  • Eliminated redundant graph cloning and toposort computations during fusion analysis
  • Streamlined the fusion algorithm to avoid backtracking by using a more direct approach

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
pytensor/tensor/rewriting/elemwise.py Complete rewrite of FusionOptimizer logic with bitset-based dependency tracking and elimination of redundant operations
pytensor/scalar/basic.py Performance optimizations for scalar type creation, graph cleanup, and C code validation
tests/tensor/rewriting/test_elemwise.py Updated benchmarks with new test cases and expected fusion counts
tests/test_printing.py Updated expected test output reflecting changes in Composite operation ordering
pytensor/tensor/conv/abstract_conv.py Removed type checking comment for scipy import

Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

@ricardoV94 ricardoV94 force-pushed the faster_fusion_optimizer_based_on_edges branch 8 times, most recently from 986ba6f to eb010b7 Compare September 23, 2025 06:08
@ricardoV94
Copy link
Member Author

Tests are passing and conflicts sorted

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few small changes in the review so far.

I also want to have a bigger discussion about the bitset algorithm being used in the FusionOptimizer. I think it's obviously better: about 100 lines of code shorter, and faster. But I also think it's highly non-obvious what's going on, and I think it introduces a future maintenance burden and the possibility for subtle bugs.

I think one thing that would help would be to factor out things into some helpers. I'm thinking about something like a BitSet class, and something like a BitSetTraverser class, so that the code inside the fusion optimizer becomes more readable, like:

while traverser.queue:
    node_bitflag, node, is_ancestor = traverser.get_next_node()
    if traverser.is_in_subgraph(node_bitflag): 
         continue
    if traverser.has_unfusable_ancestors(node_bitflag, is_ancestor):
        continue
    elif traverser.has_unfusible_clients(node_bitflag, is_ancestor):
        continue
    

And so on. These methods can have docstrings corresponding to your (very very excellent) comments as documentation if someone wants to drill down into what's actually happening, but the FusionOptimizer itself will remain readable at a high level.

return [x], [out]

@staticmethod
def diamond_graph(n):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test case is commented out below, just remove it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, it was just for debugging

@ricardoV94
Copy link
Member Author

I think one thing that would help would be to factor out things into some helpers.

I disagree because:

  1. Python methods are incredibly slow and this is a hot loop
  2. Functionality is not meant to be reused elsewhere, it's an implementation detail
  3. Comments can achieve the same and you don't have to go somewhere else to read it.

@ricardoV94
Copy link
Member Author

But I also think it's highly non-obvious what's going on, and I think it introduces a future maintenance burden and the possibility for subtle bugs.

Agree, but all that applies even more to the old code. Try and read how it looked like before.

@jessegrabowski
Copy link
Member

Agree, but all that applies even more to the old code. Try and read how it looked like before.

I 100% agree that this PR improves on the status quo. I frankly don't even know what you had to do to get a deep enough understanding of the old code to do this refactor.

If you want to spend time implementing and timing something, I would rather it be a version of this algorithm that pulls out logical steps into functions/methods and makes the hot loop as readable as possible. This would open up the possibility to test each component subroutine individually, and make it clear to future readers exactly what is going on.

If it ends up being unacceptably slow, fine I'll back off. But might as well have a conversation about what our compute budget is for readability. If it's zero, fine. But I think some of these routines were essentially abandoned by the devs because they're so dense and arcane. Having something that's 1% slower but which offers hope for someone to come along and understand it in the future would be worth it imo.

@ricardoV94
Copy link
Member Author

I frankly don't even know what you had to do to get a deep enough understanding of the old code to do this refactor.

I wrote the previous crappy code :D, that's probably how

@ricardoV94
Copy link
Member Author

ricardoV94 commented Sep 27, 2025

But might as well have a conversation about what our compute budget is for readability. If it's zero, fine.

I want us to have compile times similar to jax (ignoring when it goes crazy with constant folding).

If I compile the graph from deep_small_kernels(n=20) on (one of my machines, not sure if the same as the one with timings above) it was taking 150ms, JAX would take way less, maybe 20ms.

The FusionOptimizer takes 40% of the compile time before this PR and 10% after. Of what remains it's mostly unrelated to the rewrite and just the cost of doing fgraph replacements and creating composite nodes. With this PR and previous ones and some more changes in #1607 I can get the function to compile in ~50-40ms.

Compile times are critical to make PyTensor attractive. It was a reason why preliz decided to just use numba directly, and why they are now willing to reconsider. It's a big part of why the soccer modelling library switched from pymc to stan (they also complained about import times which we improved a lot and c backend installation ofc).

Now that example is not realistic but it's a worst case scenario (as well as the pre-existing bench) for the rewrite. Fusion is our most complex rewrite, after the scan rewrites, certainly the most commonly triggered of the complex rewrites. It is absolutely necessary to generate fast C/Numba code (it's not used in jax, because they do fusion themselves).

@jessegrabowski
Copy link
Member

Fine, I'm convinced on speed. I've also been frustrated with pytensor compile times, so I definitely get it.

@ricardoV94 ricardoV94 force-pushed the faster_fusion_optimizer_based_on_edges branch from eb010b7 to e8095ef Compare September 29, 2025 16:33
@ricardoV94 ricardoV94 force-pushed the faster_fusion_optimizer_based_on_edges branch 2 times, most recently from 7bfd4f9 to e0d0bd5 Compare September 29, 2025 16:53
@ricardoV94
Copy link
Member Author

@jessegrabowski I added more description, removed the print statements (changed my mind on that), and addressed your other comments.

I added a TestFusion.test_expansion_order that would fail if we didn't expand in the right order in the current algorithm. I tested it fails sometimes if I randomize the pop from the queue instead of using the intended ordering

The change in number of fused kernels has to do with the order of iteration, and could be replicated in the old approach by iterating in topological order. It was an accident that it happen to visit in an order where it connected two branches, instead of keeping them separate. The underlying limitation already existed and is described in pymc-devs#249
@ricardoV94 ricardoV94 force-pushed the faster_fusion_optimizer_based_on_edges branch from e0d0bd5 to e5e58b2 Compare September 29, 2025 16:58
@ricardoV94 ricardoV94 merged commit 46f8227 into pymc-devs:main Sep 30, 2025
63 of 64 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants