Speedup FusionOptimizer #1615
Conversation
Codecov Report
❌ Patch coverage is 97.04%. Your patch status has failed because the patch coverage (97.04%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files:

@@            Coverage Diff            @@
##              main    #1615     +/-   ##
==========================================
  Coverage    81.64%   81.64%
==========================================
  Files          231      231
  Lines        52997    52952      -45
  Branches      9395     9388       -7
==========================================
- Hits         43267    43235      -32
+ Misses        7282     7273       -9
+ Partials      2448     2444       -4
Pull Request Overview
This PR optimizes the FusionOptimizer to improve compilation performance by a factor of 3-4x on benchmarked graphs. The optimization focuses on reducing redundant operations and using more efficient data structures for graph analysis.
- Implemented bitset-based ancestor dependency tracking for faster subgraph convexity checks (see the sketch after this list)
- Eliminated redundant graph cloning and toposort computations during fusion analysis
- Streamlined the fusion algorithm to avoid backtracking by using a more direct approach
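To make the first bullet concrete, here is a minimal, hedged sketch of the bitset idea; every name in it is hypothetical rather than the PR's actual API. Each node gets one bit, a node's ancestor set is the OR of its parents' bits and their ancestor sets, and ancestry queries reduce to integer bit operations:

```python
# Hypothetical sketch of bitset-based ancestor tracking on a generic DAG;
# this is illustrative, not the PR's actual implementation.
def compute_ancestor_bitsets(topo_order, parents):
    """Assign one bit per node and accumulate ancestor bitsets in topological order.

    topo_order: nodes ordered so that parents appear before children.
    parents: dict mapping each node to an iterable of its parent nodes.
    """
    bit = {node: 1 << i for i, node in enumerate(topo_order)}
    ancestors = {node: 0 for node in topo_order}
    for node in topo_order:
        for p in parents[node]:
            # A node's ancestors are its parents plus their ancestors.
            ancestors[node] |= bit[p] | ancestors[p]
    return bit, ancestors


def depends_on(node, other, bit, ancestors):
    """Check ancestry with a single AND instead of a graph walk."""
    return bool(ancestors[node] & bit[other])


# Tiny usage example on a diamond graph a -> {b, c} -> d.
parents = {"a": [], "b": ["a"], "c": ["a"], "d": ["b", "c"]}
bit, anc = compute_ancestor_bitsets(["a", "b", "c", "d"], parents)
assert depends_on("d", "a", bit, anc) and not depends_on("b", "c", bit, anc)
```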
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
| File | Description |
| --- | --- |
| pytensor/tensor/rewriting/elemwise.py | Complete rewrite of FusionOptimizer logic with bitset-based dependency tracking and elimination of redundant operations |
| pytensor/scalar/basic.py | Performance optimizations for scalar type creation, graph cleanup, and C code validation |
| tests/tensor/rewriting/test_elemwise.py | Updated benchmarks with new test cases and expected fusion counts |
| tests/test_printing.py | Updated expected test output reflecting changes in Composite operation ordering |
| pytensor/tensor/conv/abstract_conv.py | Removed type checking comment for scipy import |
Tests are passing and conflicts sorted.
A few small changes in the review so far.
I also want to have a bigger discussion about the bitset algorithm being used in the FusionOptimizer. I think it's obviously better: about 100 lines of code shorter, and faster. But I also think it's highly non-obvious what's going on, and I think it introduces a future maintenance burden and the possibility for subtle bugs.
I think one thing that would help would be to factor out things into some helpers. I'm thinking about something like a BitSet class, and something like a BitSetTraverser class, so that the code inside the fusion optimizer becomes more readable, like:
while traverser.queue:
    node_bitflag, node, is_ancestor = traverser.get_next_node()
    if traverser.is_in_subgraph(node_bitflag):
        continue
    if traverser.has_unfusable_ancestors(node_bitflag, is_ancestor):
        continue
    elif traverser.has_unfusible_clients(node_bitflag, is_ancestor):
        continue
And so on. These methods can have docstrings corresponding to your (very very excellent) comments as documentation if someone wants to drill down into what's actually happening, but the FusionOptimizer itself will remain readable at a high level.
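For concreteness, a helper along the suggested lines might look like this; the `BitSet` below is purely illustrative, not code from the PR:

```python
class BitSet:
    """Thin wrapper around an int treated as a set of node indices."""

    __slots__ = ("bits",)

    def __init__(self, bits: int = 0):
        self.bits = bits

    def add(self, idx: int) -> None:
        self.bits |= 1 << idx

    def __contains__(self, idx: int) -> bool:
        return (self.bits >> idx) & 1 == 1

    def __or__(self, other: "BitSet") -> "BitSet":
        return BitSet(self.bits | other.bits)

    def isdisjoint(self, other: "BitSet") -> bool:
        return (self.bits & other.bits) == 0


s = BitSet()
s.add(3)
assert 3 in s and BitSet(0b1).isdisjoint(s)
```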
        return [x], [out]

    @staticmethod
    def diamond_graph(n):
This test case is commented out below, just remove it?
Yup, it was just for debugging
I disagree because: […]
Agree, but all that applies even more to the old code. Try reading how it looked before.
I 100% agree that this PR improves on the status quo. I frankly don't even know what you had to do to get a deep enough understanding of the old code to do this refactor. If you want to spend time implementing and timing something, I would rather it be a version of this algorithm that pulls out logical steps into functions/methods and makes the hot loop as readable as possible. This would open up the possibility of testing each component subroutine individually, and make it clear to future readers exactly what is going on.

If it ends up being unacceptably slow, fine, I'll back off. But we might as well have a conversation about what our compute budget is for readability. If it's zero, fine. But I think some of these routines were essentially abandoned by the devs because they're so dense and arcane. Having something that's 1% slower but which offers hope for someone to come along and understand it in the future would be worth it imo.
I wrote the previous crappy code :D, that's probably how.
I want us to have compile times similar to jax (ignoring when it goes crazy with constant folding). If I compile the graph from […], the FusionOptimizer takes 40% of the compile time before this PR and 10% after. Of what remains, most is unrelated to the rewrite and is just the cost of doing fgraph replacements and creating composite nodes. With this PR, previous ones, and some more changes in #1607, I can get the function to compile in ~40-50ms.

Compile times are critical to make PyTensor attractive. It was a reason why preliz decided to just use numba directly, and why they are now willing to reconsider. It's a big part of why the soccer modelling library switched from pymc to stan (they also complained about import times, which we improved a lot, and about C backend installation, of course).

Now, that example is not realistic, but it's a worst-case scenario (as well as the pre-existing bench) for the rewrite. Fusion is our most complex rewrite after the scan rewrites, and certainly the most commonly triggered of the complex rewrites. It is absolutely necessary to generate fast C/Numba code (it's not used in jax, because they do fusion themselves).
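For anyone who wants to reproduce this kind of measurement, here is a minimal sketch; the graph below is an arbitrary stand-in, not the benchmark referenced above:

```python
import time

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# An arbitrary chain of elemwise ops that the FusionOptimizer can fuse.
out = x
for i in range(20):
    out = pt.sin(out) + pt.cos(out * i)

start = time.perf_counter()
f = pytensor.function([x], out)
print(f"compile time: {(time.perf_counter() - start) * 1e3:.1f} ms")
```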
Fine, I'm convinced on speed. I've also been frustrated with pytensor compile times, so I definitely get it.
Not using `__call__` avoids the test_value computation
It's not really needed as we never expand on the new nodes
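A hedged illustration of the first commit message above: in PyTensor, calling an Op goes through `Op.__call__`, which can eagerly compute test values when `compute_test_value` is enabled, while building the Apply node directly via `make_node` skips that machinery. The specific op below is just an example:

```python
import pytensor.tensor as pt
from pytensor.scalar.basic import exp as scalar_exp
from pytensor.tensor.elemwise import Elemwise

x = pt.vector("x")
op = Elemwise(scalar_exp)

out_via_call = op(x)  # Op.__call__: may trigger test_value computation
out_via_make_node = op.make_node(x).outputs[0]  # builds the node directly
```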
@jessegrabowski I added more description, removed the print statements (changed my mind on that), and addressed your other comments. I added a […]
The change in number of fused kernels has to do with the order of iteration, and could be replicated in the old approach by iterating in topological order. It was an accident that it happened to visit in an order where it connected two branches, instead of keeping them separate. The underlying limitation already existed and is described in pymc-devs#249
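A small, hedged illustration of "iterating in topological order": `FunctionGraph.toposort` is PyTensor's real API, but the graph here is an arbitrary example, not the one from the PR:

```python
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph

x = pt.vector("x")
y = pt.exp(x)
z = pt.sin(y) + pt.cos(y)  # two elemwise branches meeting at one output

fgraph = FunctionGraph([x], [z], clone=False)
# Visiting candidates in topological order makes the grouping of branches
# into fused kernels deterministic rather than an accident of visit order.
for node in fgraph.toposort():
    print(node)
```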
FusionOptimizer can be one of the slower rewrites during compilation. This PR speeds it up by a factor of 3-4x on the benchmarked graphs.
Each commit provides a substantial speedup (except maybe for 2 → 3).
The main speedups come from: […]
The logic for finding valid fused kernels is also clearer now imo, avoiding the need for backtracking.
Benchmark per commit