Skip to content

Extend mma dimension and layout checking to support strided batched matmul and tensor contractions #1761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 27, 2022

Conversation

shmsong
Copy link

@shmsong shmsong commented Jun 14, 2022

The mma interface was initially simplified to only support matmuls. The current interface re-defines the axes matching analysis using root domain mapping to support more flexible mapping of M,N,K dimensions. This will make it possible to implement strided batched matmul with mma ops (see test case FusionAmpereStridedBatchedMatmulTN), and quite a few variants of tensor contractions should also fall into this generalization.

TODO:

This is still only a further step into full expressiveness of mma interface. There is another simplification that mma only packs the innermost dimension of each of the M, N, and K axes. This is currrently not the highest priority but would be further extended in a follow up PR to support optimizing tensor contraction kernels.

Copy link
Owner

@csarofeen csarofeen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, some more comments could be helpful. Curious how far this can take us today when we think about fully generic tensor contractions. It seems we couldn't have a non-used broadcast axis, or a trivial reduction in the fusion. However, it seems if we correctly use rfactor, we could probably have an extra reduction dimension, we'd just have to rfactor it. Is that true?


// Merge the outer dims:
tv2->merge(0);
tv2->merge(0);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a preference to put the batch dim into the M dim? I guess I imagined we'd just leave the batched dim out of the scheduling more or less.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there anything interesting following this, or is it now just a standard matmul schedule?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a preference to put the batch dim into the M dim? I guess I imagined we'd just leave the batched dim out of the scheduling more or less.

No preference at the moment, and yes all this step currently does is merge the batched dims and attach it to a block dim.

The actual difference will be layout dependent though, which we could revisit when we go for tensor contraction perf.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there anything interesting following this, or is it now just a standard matmul schedule?

I'm just using a standard matmul schedule for the inner loops now. Most of the standard matmul optimization should just apply here.

is_broadcast_in_a && !is_broadcast_in_b && !is_reduction_id;
break;

default:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we ignore dimensions that are broadcasted in both a and b?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes in this check. This utility only grabs mma_dimensions, which are the dimensions that the mma instruction can and will realize.

In that case the broadcasted dimension isn't concretized in this mma op so it's not a mma_dimension, it's kind of a "batching" dimension.

MmaDimension dimension) {
// Build a fusion-level root domain map
// so we can use the mma swizzles on non-immediate tensor operands, for
// example loadstore staging ops.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uncertain what you mean here about being able to use the mma swizzles on non-immediate tensor operands.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for staging and prolog fusion, in the pure matmul compute, the operand path from smem to mma we have:
ldmatrix -> serial broadcast -> mma

with prolog fusion we could also have:
ldmatrix -> unary op -> serial broadcast -> mma etc.

The mma swizzle information is on the mma op while the operand swizzle we want to apply at ldmatrix op, so the connection is done through root domain mapping, in this case.

// a lot of mma ops in a fusion this could be lower priority.
// First it'd be nice not having to build root map every time this function
// is called. That'd require some explicit boundary where we "lock" the
// compute in the fusion so the root map stays valid.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be an appropriate optimization to build into scheduling with the caching system. Other than that we should just rebulid it, and for now it seems reasonable to do so. If we need to do this during heuristics checking, then makes sense to cache sooner than later. However, until we have a scheduler that can handle matmul's, it seems moot.

// compute in the fusion so the root map stays valid.
// Second it'd reduce complexity of the below matching by an order if we have
// something similar to "disjointSetOf" in idGraph, for just the root domains
// at scheduler composing time.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the problem with literally using id graph? We use compute at maps in scheduling and that seems reasonable to me.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All these utilities are used in scheduling time. RootDomainMap should be completely defined since the compute is defined. But computeAt may still change after this stage so not sure if we'd want to be building an IdGraph at this stage. Also seems to be quite compute heavy to do in sync with scheduler primitives.

auto mma_root_dimensions = getMmaDomains(mma, dimension);
auto mma_accumulator_tv = mma->out()->as<TensorView>();

std::vector<IterDomain*> result;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you comment this double loop a bit, that you're just checking if any of the returned dimensions from the getMmaDomain function matches any dimension in the tv of interest, and if so is just accumulating the dims into the result.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You give good context in the comment above the function, but don't really specify how you're doing this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it's just accumulating matching root id's to the result vector. Added comments here. Thanks.

"MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain");
}

void validateResultInnerMN(TensorView* tv, int m, int n) {
void validateMmaRootInnerMN(TensorView* tv, MmaOptions options, int m, int n) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you comment a bit more of what you're validating here. I don't recall when you'd be validating here. Is this just validating after scheduling the formats are correctly in the format of the hardware mma operation?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is right based on the comment of canValidateIsInnerDim

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added comment. Yes this is validating that the innermost axes are scheduled according to the mma instruction size and mma swizzler will further apply the mma swizzle patterns to it.

@shmsong
Copy link
Author

shmsong commented Jun 27, 2022

LGTM, some more comments could be helpful. Curious how far this can take us today when we think about fully generic tensor contractions. It seems we couldn't have a non-used broadcast axis, or a trivial reduction in the fusion. However, it seems if we correctly use rfactor, we could probably have an extra reduction dimension, we'd just have to rfactor it. Is that true?

The current handling of mma should be quite generic and should be able to generate most contractions that can be realized with matmuls, with possibly some permutations required when multiple dimensions need to be reduced.

The main limitation is that it might not have covered all the optimization strategies we want to do when scheduling tensor contraction kernels. We'd need to further generalize the mma matching pattern to support not just packing the innermost dimensions of the m,n,k axes when we use mma.

On the extra broadcast/reduction dimensions:

  • any broadcast that is concretized at the mma op is either a M or a N dimension, borrowing the matmul terminology.

  • any broadcast that's not concretized at the mma op could be treated as a batch dimension and left out after the mma.

  • any reduction that's concrete and mapped across the operands, is a K dimension and mma handles them.

  • any reduction that isn't concretely matched coming from the operand would be a prolog reduction that mma wouldn't be able to handle any ways and that'd be a separate reduction op.

  • further reductions on the M and N dimensions are epilog reductions, that should be a separate reduction op and I believe binary tensor contractions wouldn't directly involve that.

We should be able to handle all above and I think this should be all possible cases plus minus trivial reductions and broadcasts that we'd need to revisit when we fully generalize.

@shmsong shmsong changed the title WIP: Extend mma dimension and layout checking to support strided batched matmul and tensor contractions Extend mma dimension and layout checking to support strided batched matmul and tensor contractions Jun 27, 2022
@shmsong shmsong merged commit ecc7a87 into devel Jun 27, 2022
@shmsong shmsong deleted the mma_stride0 branch June 27, 2022 20:49
naoyam added a commit that referenced this pull request Jul 11, 2022
* Refactor TransormPropagator to allow specifying a position and propagating to part of the DAG (#1775)

`MaxInfoPropagator` is renamed to `MaxInfoSpanningTree`, it now only does path-finding, and the propagation is in a separate class `MaxInfoSpanningTree::Propagator`. Same for `MaxRootDomainInfoPropagator`.

`MaxInfoSpanningTree` and `MaxRootDomainInfoSpanningTree`  now allow specifying a selector, which controls which subgraph should be included in path-finding.

`MaxRootDomainInfoSpanningTree` also gets a few new constructors for convenience to use.

`TransormPropagator` is now a subclass of `MaxInfoSpanningTree::Propagator`, so the way to use it has changed.

Now `MaxInfoSpanningTree` and `MaxRootDomainInfoSpanningTree` will store the path after generation so that the same path can be traversed multiple times. This will be useful to support use cases like new `computeAt`. Pseudo-code:
```C++
void TensorView::computeAt(TensorView tv, int pos) {
  auto ComputeAtSubgraphSelector selector(this, tv);
  MaxRootDomainInfoSpanningTree path(tv, pos, &selector);
  TransformPropagator propagator(tv, pos);
  path.traverse(&propagator);
  ComputeAtPosPropagator ca_propagator(tv, pos);
  path.traverse(&ca_propagator);
}
```

* Revert scheduling changes. Cleanup only.

* Start drafting grid persistent kernels.

* Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)

Co-authored-by: Christian Sarofeen <[email protected]>

* Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)

* Fix div(Val, TensorView) (#1778)

* Fix div(scalar, tensor)

* lintrunner: clang-format

* Adding sibling path for MaxInfoSpanningTree (#1776)

The sibling path is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector. For example, when the producer of a Welford is excluded from the propagation section. See test `FusionTransformPropagateSelectorSibling_CUDA` for a detailed example. Besides, since we know that siblings should be transformed exactly the same, the sibling path is a perfect next hop for preserving information.

If you want a spanning tree without a sibling path, you can override `allowSibling` as `return false` in your selector;

* Save.

* Disable register reuse across serial broadcast ops (#1787)

Disable memory aliasing for inner sharing across serial broadcast.

* Fix isIntegralType error msg (#1789)

* Transform propagator skip replay when possible (#1782)

This comment in the code describes what this PR is doing:

```C++
  // Note: [Using multiple TransformPropagators]
  // There are cases that we use multiple TransformPropagators along different
  // spanning trees with different references in the same fusion. Some of these
  // spanning trees could overlap. In cases when there are overlapping nodes,
  // TransformPropagator needs to respect the replay of others, because the
  // current TransformPropagator might not contain the most amount of
  // information on how to do the correct transformation. The logic below tells
  // TransformPropagator to skip the replay when not necessary.
```

* Output allocate patch (#1790)

Caching strides along with sizes. This is to support current expand, which introduces non-contiguous output tensor

* Add SpanningTreePrinter (#1786)

* New compute at interface (#1743)

Rewrite of the compute at pass to rely on the new propagation mechanisms.

* Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)

* Some further cleanup for the new computeAt interface (#1793)

Revert MaxProducerPosUpdater to old algo.

* Use TransformPropagatorWithCheck in many tests (#1795)

* validateDomain in TransformPropagator (#1796)

* InlinePropagator please don't replay (#1797)

This PR makes `InlinePropagator` just set compute-at positions. It will not replay any tensor. If you want to replay, please use `TransformPropagator` and friends to do so.

Currently, `InlinePropagator` is already asserting no replay for standard and best effort compute at. So this PR is mostly about making most inlined compute at works as well.

This PR also does a lot of cleanups to remove the word "replay" from comments and variable and function names from `InlinePropagator`.

I also cleaned up `recordReplayedPos` and `retrieveReplayedPos`, now the logic is much easier to understand.

* Coding style cleanups (#1798)

Per offline discussion with @csarofeen, this PR does many renaming for better coding style: For all propagation-related things, I am now using the names `P2C` and `C2P` instead of `CasP` and `PasC`. Because "A as B" somewhat implies we want to replay A the same as B, but "B to A" sounds more general and is a better word for this case. Also, I modified the order of function arguments to match the order in its name. For example `PasC` should have `(producer, consumer)` or `(to, from)`, but not `(consumer, producer)` or `(from, to)`, and `C2P` should have `(consumer, producer)` or `(from, to)`, but not `(producer, consumer)` or `(to, from)`.

* Add parsing support for `_to_copy` to handle AMP casts. (#1756)

1. Add support for _to_copy() to support AMP casts.
2. refactored cast, accept none for dtype
3. python tests

Co-authored-by: jjsjann123 <[email protected]>

* MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)

* Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)

Co-authored-by: Christian Sarofeen <[email protected]>

* More cleanup on InlinePropagator (#1800)

I just realized that `InlinePropagator` can be further simplified because it no longer replays.

Since `InlinePropagator` is no longer doing replay, it is more like a "for each" problem rather than a propagation problem:

For each tensor `tv`, if we already know what is the max position of `tv` that is mapped to the reference tensor's selected outer dimensions(stored in `mapped_reference_pos_` in the code), setting the CA position is a very local operation, and is as simple as checking `tv` itself and all its consumers to determine the inline position.

`InlinePropagator` is not completely a "for each" problem only because the computation of `mapped_reference_pos_` is a propagation problem.

This cleanup reorganizes the code of `InlinePropagator` so it is clear that `InlinePropagator` is nothing but a two-step process:
Step 1: Do a propagation to find the `mapped_reference_pos_` for all tensors.
Step 2: For each tensor, check itself and its consumers to determine the CA position.

Conceptually, I would like to split step 1 with step 2. Because this split makes these concepts decoupled. Especially, this PR makes `mapped_reference_pos_` only contain info about the reference tensor, and is independent of the CA position (Currently, this is not true for best effort and most inlined computeAt without this PR). Now, in my view, `InlinePropagator` is conceptually very simple and easy to understand.

In terms of implementation, step 1 and step 2 can be interleaved, because when we don't need to know the `mapped_reference_pos_` for `tv`'s consumer in order to compute the CA position of `tv`. So a one-pass traverse could do both step 1 and step 2 altogether.

* Temporarily disable test requring large shared memory. (#1802)

* Grouping grid allreduces across iterations (#1755)

* Extend the grouped grid reduction kernel

The kernel itself should work with an arbitrary number of inputs, but
the underlying data structure, Tuple, still explicitly needs to be
specialized for the number of values, which is currently limited to 8.

* Check siblings in getMaxPosAll (#1805)

* remove dead indexing code (#1806)

* Broadcast in dim with expand (#1794)

Fixes #1788

Added expand in broadcast_in_dim to support expanding to concrete size. Note that we are not supporting dynamic shape for concrete size at this moment.

* spam nvrtc options (#1783)

TORCH_WARN on nvrtc debug option impacting performance.

Co-authored-by: Gao, Xiang <[email protected]>
Co-authored-by: S. Song <[email protected]>
Co-authored-by: Ivan Yashchuk <[email protected]>
Co-authored-by: Sergey Lebedev <[email protected]>
Co-authored-by: jjsjann123 <[email protected]>
Co-authored-by: Kevin Stephano <[email protected]>
Co-authored-by: Naoya Maruyama <[email protected]>
shmsong pushed a commit to shmsong/pytorch that referenced this pull request Jul 24, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (csarofeen#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (csarofeen#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (csarofeen#1811)
03180aa improve broadcast resolution (csarofeen#1792)
bee6c69 bug fix (csarofeen#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (csarofeen#1812)
de6b7ca Fix negative position in InlinePropagator (csarofeen#1813)
10a996c Remove redundant check in schedulePointwise (csarofeen#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (csarofeen#1441)
3ed8330 Kernel args patch to show zero_init buffer (csarofeen#1809)
037a75a Dropout prob extremal patch (csarofeen#1804)
282c429 spam nvrtc options (csarofeen#1783)
3ba6a5f Broadcast in dim with expand (csarofeen#1794)
fd4be12 remove dead indexing code (csarofeen#1806)
fa4e6a4 Check siblings in getMaxPosAll (csarofeen#1805)
025c840 Grouping grid allreduces across iterations (csarofeen#1755)
37c579e Temporarily disable test requring large shared memory. (csarofeen#1802)
5f375d0 More cleanup on InlinePropagator (csarofeen#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (csarofeen#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (csarofeen#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (csarofeen#1756)
ef04f6c Coding style cleanups (csarofeen#1798)
38c7f3c InlinePropagator please don't replay (csarofeen#1797)
3f2c263 validateDomain in TransformPropagator (csarofeen#1796)
c077085 Use TransformPropagatorWithCheck in many tests (csarofeen#1795)
d0d0908 Some further cleanup for the new computeAt interface (csarofeen#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (csarofeen#1791)
28cbaf9 New compute at interface (csarofeen#1743)
635ebfc Add SpanningTreePrinter (csarofeen#1786)
59f3c32 Output allocate patch (csarofeen#1790)
fe93bf5 Transform propagator skip replay when possible (csarofeen#1782)
ebf23a5 Fix isIntegralType error msg (csarofeen#1789)
0c82ecf Disable register reuse across serial broadcast ops (csarofeen#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (csarofeen#1776)
86f46aa Fix div(Val, TensorView) (csarofeen#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (csarofeen#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (csarofeen#1761)
```

[ghstack-poisoned]
shmsong pushed a commit to shmsong/pytorch that referenced this pull request Jul 24, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (csarofeen#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (csarofeen#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (csarofeen#1811)
03180aa improve broadcast resolution (csarofeen#1792)
bee6c69 bug fix (csarofeen#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (csarofeen#1812)
de6b7ca Fix negative position in InlinePropagator (csarofeen#1813)
10a996c Remove redundant check in schedulePointwise (csarofeen#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (csarofeen#1441)
3ed8330 Kernel args patch to show zero_init buffer (csarofeen#1809)
037a75a Dropout prob extremal patch (csarofeen#1804)
282c429 spam nvrtc options (csarofeen#1783)
3ba6a5f Broadcast in dim with expand (csarofeen#1794)
fd4be12 remove dead indexing code (csarofeen#1806)
fa4e6a4 Check siblings in getMaxPosAll (csarofeen#1805)
025c840 Grouping grid allreduces across iterations (csarofeen#1755)
37c579e Temporarily disable test requring large shared memory. (csarofeen#1802)
5f375d0 More cleanup on InlinePropagator (csarofeen#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (csarofeen#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (csarofeen#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (csarofeen#1756)
ef04f6c Coding style cleanups (csarofeen#1798)
38c7f3c InlinePropagator please don't replay (csarofeen#1797)
3f2c263 validateDomain in TransformPropagator (csarofeen#1796)
c077085 Use TransformPropagatorWithCheck in many tests (csarofeen#1795)
d0d0908 Some further cleanup for the new computeAt interface (csarofeen#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (csarofeen#1791)
28cbaf9 New compute at interface (csarofeen#1743)
635ebfc Add SpanningTreePrinter (csarofeen#1786)
59f3c32 Output allocate patch (csarofeen#1790)
fe93bf5 Transform propagator skip replay when possible (csarofeen#1782)
ebf23a5 Fix isIntegralType error msg (csarofeen#1789)
0c82ecf Disable register reuse across serial broadcast ops (csarofeen#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (csarofeen#1776)
86f46aa Fix div(Val, TensorView) (csarofeen#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (csarofeen#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (csarofeen#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
csarofeen pushed a commit that referenced this pull request Aug 4, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)
Pull Request resolved: pytorch#81861
Approved by: https://github.com/davidberard98
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants