vmap over composite out-of-place ops whose in-place variant is non-composite #260

Open
zou3519 opened this issue Nov 12, 2021 · 2 comments
Labels
actionable It is clear what should be done for this issue

Comments

@zou3519
Contributor

zou3519 commented Nov 12, 2021

The following operations:

  • index_add
  • index_copy
  • index_fill
  • masked_fill
  • masked_scatter

have the quirk that the out-of-place op is composite while the in-place variant has an autograd formula. For example, index_fill is implemented as:

Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Scalar& source) {
  return self.clone(at::MemoryFormat::Preserve).index_fill_(dim, index, source);
}

We have trouble doing vmap(grad(foo)) where foo includes one of the above operations. This is because Autograd ends up decomposing e.g. index_fill into tensor.clone().index_fill_(...) and the in-place operation is not vmap-compatible.
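
For concreteness, here is a minimal repro sketch of the failure mode using functorch's vmap and grad transforms (the exact error depends on the functorch version):

import torch
from functorch import grad, vmap

def foo(x, index):
    # index_fill is composite, so under autograd it decomposes into
    # x.clone().index_fill_(...); the in-place kernel is what vmap cannot handle.
    return x.index_fill(0, index, 1.0).sum()

x = torch.randn(4, 5)            # batch of 4 vectors
index = torch.tensor([0, 2])
grad(foo)(x[0], index)           # works: plain autograd is fine
vmap(grad(foo), in_dims=(0, None))(x, index)  # fails on the in-place decomposition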

I've brainstormed two ways to solve this. There are tradeoffs for each and I'm wondering if someone else has a preference.

Approach 1: DECOMPOSE_FUNCTIONAL

(cc @bdhirsh)

We use the functionalization pass to functionalize index_fill. Unfortunately this results in the following code:

self.clone(...).index_fill(dim, index, source)

which contains an unnecessary clone() that is not good for performance. I don't know if we want to make the functionalization pass smart enough to elide the clone in the future; that sounds complicated.
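
In plain PyTorch terms, the functionalized decomposition amounts to something like the sketch below (the function name is just for illustration); the out-of-place fill returns a fresh tensor anyway, so the outer clone is a redundant copy:

import torch

def functionalized_index_fill(self, dim, index, source):
    # the clone from the composite kernel is still emitted...
    tmp = self.clone(memory_format=torch.preserve_format)
    # ...but the functionalized, out-of-place fill returns a new tensor,
    # so the extra copy buys us nothing.
    return tmp.index_fill(dim, index, source)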

Approach 2: Add backward formulas for index_fill and all of the operations above (i.e., turn them into primitive operations).

(cc @albanD)

This means that both index_fill and index_fill_ get backward formulas (could we get away with only giving index_fill a backward formula?). This is a pretty simple solution; the tradeoff is that we need to duplicate the formulas, and we would be setting the precedent that "operations that have an out-of-place variant must have a backward formula defined on the out-of-place variant".
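
For reference, a rough sketch (in plain PyTorch, helper name hypothetical) of what the out-of-place formula would have to encode for the self-gradient; the real formula would live in derivatives.yaml:

import torch

def index_fill_backward_self(grad_out, dim, index):
    # d(out)/d(self) is the identity everywhere except the filled positions,
    # so the self-gradient is the upstream grad with those positions zeroed.
    return grad_out.index_fill(dim, index, 0)

# sanity check against the existing composite behavior
x = torch.randn(3, 5, requires_grad=True)
index = torch.tensor([1, 3])
x.index_fill(1, index, 0.0).backward(torch.ones(3, 5))
assert torch.equal(x.grad, index_fill_backward_self(torch.ones(3, 5), 1, index))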

Discussion

I prefer Approach 2 for its simplicity. To address the code duplication we can put the formulas into helper functions. Thoughts?

@bdhirsh
Contributor

bdhirsh commented Nov 12, 2021

The discussion here is probably relevant: https://github.com/pytorch/pytorch/pull/65993/files#r729912376

index_add will stop being a CompositeImplicitAutograd kernel after the above structured porting PR lands. I think the same will probably happen to the other ops on the list, once they're all eventually ported to structured.

I had the same question about index_add and index_add_ both getting backward formulas in that PR (currently, in the PR, they both do). @albanD, I think the answer is that we still want a formula directly for index_add_ for perf reasons, right? Perf is only affected in cases where a user calls the in-place index_add_ op directly and uses it with autograd.

So my take (definitely open to other opinions, though) is that all of the above ops will stop having math-composite kernels at some point, which should fix this problem for functorch (once they're all structured). Maybe we should prioritize porting them if they're important to functorch.

@zou3519
Contributor Author

zou3519 commented Nov 15, 2021

Thanks for the link, @bdhirsh! It's very relevant.

It sounds like the solution is to have two sets of backward formulas, one for the out-of-place version (index_add) and one for the in-place version (index_add_) for perf reasons. Furthermore, it looks like that is also a prereq for making the operation structured.
