A number of operations, including index_fill and index_add, have this quirk: they are composite, but their in-place variant has an autograd formula.
We have trouble doing vmap(grad(foo)) where foo includes one of these operations. This is because Autograd ends up decomposing e.g. index_fill into tensor.clone().index_fill_(...), and the in-place operation is not vmap-compatible.
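For concreteness, here is a minimal sketch of the failure mode (foo and the shapes are illustrative; this assumes the functorch-style vmap/grad API available at the time):

```python
import torch
from functorch import vmap, grad

def foo(x):
    # index_fill is composite: under autograd it decomposes into
    # x.clone().index_fill_(...), and that in-place op is what vmap rejects.
    index = torch.tensor([0])
    return x.index_fill(0, index, 1.0).sum()

x = torch.randn(3, 5)
# grad needs a scalar output per example; vmap maps over the leading dim.
# At the time of this issue, this call fails inside the in-place index_fill_.
per_example_grads = vmap(grad(foo))(x)
```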
I've brainstormed two ways to solve this. There are tradeoffs for each and I'm wondering if someone else has a preference.
Approach 1: DECOMPOSE_FUNCTIONAL
(cc @bdhirsh)
We use the functionalization pass to functionalize index_fill. Unfortunately, this results in the following code:
self.clone(...).index_fill(dim, index, source)
which contains an unnecessary clone(), which is not good for performance. I'm not sure whether we want to make the functionalization pass smarter in the future; that sounds complicated.
Approach 2: Add backward formulas for index_fill and all of the operations above (i.e., turn them into primitive operations). (cc @albanD)
This means that both index_fill and index_fill_ get backward formulas (could we get away with only giving index_fill a backward formula?). This is a pretty simple solution; the tradeoff is that we need to duplicate the formulas, and we are setting the precedent that "operations that have an out-of-place variant must have a backward formula defined on the out-of-place variant".
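To make the tradeoff concrete, here is a rough, illustrative sketch (written as a standalone autograd.Function rather than the actual derivatives.yaml entry) of what an out-of-place backward formula for index_fill with a scalar value would compute:

```python
import torch

class IndexFillScalar(torch.autograd.Function):
    # Illustrative only: a hand-rolled out-of-place index_fill with its own
    # backward, to show what the "primitive" formula in Approach 2 would compute.
    @staticmethod
    def forward(ctx, input, dim, index, value):
        ctx.save_for_backward(index)
        ctx.dim = dim
        return input.index_fill(dim, index, value)

    @staticmethod
    def backward(ctx, grad_output):
        (index,) = ctx.saved_tensors
        # Positions overwritten by `value` receive zero gradient;
        # everywhere else the incoming gradient passes through unchanged.
        grad_input = grad_output.index_fill(ctx.dim, index, 0)
        return grad_input, None, None, None
```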
Discussion
I prefer Approach 2 for its simplicity. To address the code duplication we can put the formulas into helper functions. Thoughts?
index_add will stop being a CompositeImplicitAutograd kernel after the above structured porting PR lands. I think the same will probably happen to the other ops on the list, once they're all eventually ported to structured.
I had the same question about index_add and index_add_ both getting backward formulas in that PR (currently in the PR, they do). @albanD I think the answer is that we still want a formula directly for index_add_ due to perf, right? Perf is only affected in cases where a user calls the inplace index_add_ op directly and uses it with autograd.
So my take (definitely open to other opinions, though), is that all of the above ops will stop having math-composite kernels at some point, which should fix this problem for functorch (once they're all structured). Maybe we should prioritize porting them if they're important to functorch.
Thanks for the link, @bdhirsh! It's very relevant.
It sounds like the solution is to have two sets of backward formulas, one for the out-of-place version (index_add) and one for the in-place version (index_add_) for perf reasons. Furthermore, it looks like that is also a prereq for making the operation structured.
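Once that happens (i.e. once index_add stops decomposing into index_add_ under autograd), the originally failing pattern should go through. A quick sanity check might look like this (again assuming the functorch-style API; shapes are illustrative):

```python
import torch
from functorch import vmap, grad

def foo(x):
    # With index_add as a primitive (non-composite) op, no in-place
    # index_add_ shows up in the autograd graph, so vmap(grad(...)) is fine.
    return x.index_add(0, torch.tensor([0]), torch.ones(1)).sum()

x = torch.randn(3, 5)                    # foo sees one row of shape (5,)
per_example_grads = vmap(grad(foo))(x)   # expected shape: (3, 5)
```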