Improve matmul instruction scheduling with loop rotation #2488

Merged: 36 commits, Mar 2, 2023
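Context for the title: "loop rotation" restructures the matmul main loop so that operand loads for iteration k+1 are issued before the math of iteration k, hiding load latency behind the MMA instructions. A minimal scalar sketch of the transformation (my illustration, not nvfuser's generated code; all names below are made up):

// Unrotated: each trip loads, then computes, so the math stalls on the load.
//   for (int k = 0; k < K; ++k) { r = load(k); acc += f(r); }
// Rotated: the load is peeled into a prologue and re-issued one trip early,
// so the load for k + 1 overlaps the math for k.
float rotated_dot(const float* a, const float* b, int K) {
  float acc = 0.0f;
  if (K == 0) return acc;
  float ra = a[0], rb = b[0];    // prologue: loads for iteration 0
  for (int k = 0; k < K; ++k) {
    float ca = ra, cb = rb;      // values loaded one iteration ago
    if (k + 1 < K) {             // rotated part: issue next loads early
      ra = a[k + 1];
      rb = b[k + 1];
    }
    acc += ca * cb;              // math for the current iteration
  }
  return acc;
}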
Commits (36):
6433b65  Loop rotation WIP  (zasdfgbnm, Feb 17, 2023)
c8c40db  save  (zasdfgbnm, Feb 18, 2023)
9c544d8  save  (zasdfgbnm, Feb 18, 2023)
62fc4de  use saved ptr  (zasdfgbnm, Feb 18, 2023)
ee0af68  lower predicate  (zasdfgbnm, Feb 18, 2023)
c48f27b  more fixes  (zasdfgbnm, Feb 18, 2023)
21787cf  move param to fusion  (zasdfgbnm, Feb 19, 2023)
cdca970  indexing and predicate - step 1  (zasdfgbnm, Feb 19, 2023)
ed2f9f1  cleanup  (zasdfgbnm, Feb 19, 2023)
301c7c5  working  (zasdfgbnm, Feb 19, 2023)
da6a426  Merge branch 'devel' of github.com:csarofeen/pytorch into loop-rotation  (zasdfgbnm, Feb 19, 2023)
f868b7a  misc improvements  (zasdfgbnm, Feb 19, 2023)
e445a0f  prepare matmul schdule for loop rotation  (zasdfgbnm, Feb 19, 2023)
caafc29  save  (zasdfgbnm, Feb 19, 2023)
5e87b2f  save  (zasdfgbnm, Feb 19, 2023)
0ffdeb6  save  (zasdfgbnm, Feb 19, 2023)
645d4bd  save  (zasdfgbnm, Feb 19, 2023)
b56ee20  setAssertOutOfBound  (zasdfgbnm, Feb 19, 2023)
3b9531c  fix  (zasdfgbnm, Feb 19, 2023)
d09a2a0  sass test  (zasdfgbnm, Feb 19, 2023)
76c5326  double buffer fixes  (zasdfgbnm, Feb 20, 2023)
d25e941  fix  (zasdfgbnm, Feb 20, 2023)
4581cb5  save  (zasdfgbnm, Feb 20, 2023)
7e91718  fix  (zasdfgbnm, Feb 20, 2023)
b92cbbf  fixes  (zasdfgbnm, Feb 20, 2023)
4c5ff2b  rename  (zasdfgbnm, Feb 20, 2023)
51d7311  Do not assert not used  (zasdfgbnm, Feb 20, 2023)
325abad  fix  (zasdfgbnm, Feb 20, 2023)
676b791  remove predicates  (zasdfgbnm, Feb 20, 2023)
8eff58c  save  (zasdfgbnm, Feb 20, 2023)
044889c  save  (zasdfgbnm, Feb 20, 2023)
a706ee4  cleanup  (zasdfgbnm, Feb 20, 2023)
e7eadc3  Merge branch 'devel' of github.com:csarofeen/pytorch into loop-rotation  (zasdfgbnm, Feb 21, 2023)
26d2509  Merge branch 'devel' of github.com:csarofeen/pytorch into loop-rotation  (zasdfgbnm, Mar 1, 2023)
3ea82fe  save  (zasdfgbnm, Mar 1, 2023)
cbb60a2  fix  (zasdfgbnm, Mar 1, 2023)
13 changes: 9 additions & 4 deletions third_party/nvfuser/csrc/scheduler/matmul.cpp
@@ -282,12 +282,12 @@ void scheduleMatmul(
 acr->axis(-1)->parallelize(ParallelType::Vectorize);
 bcr->axis(-1)->parallelize(ParallelType::Vectorize);

-// 0 1 2 3 4 5 6 7 8 9 10
-// [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
+// 0 1 2 3 4 5 6 7 8 9 10
+// [Mo No Ko Kwo Mwo Nwo Mw Nw (Mi Ni Ki)]
 cc->axis(0)->parallelize(ParallelType::BIDx);
 cc->axis(1)->parallelize(ParallelType::BIDy);
-cc->axis(3)->parallelize(ParallelType::TIDz);
-cc->axis(4)->parallelize(ParallelType::TIDy);
+cc->axis(4)->parallelize(ParallelType::TIDz);
+cc->axis(5)->parallelize(ParallelType::TIDy);

@@ -318,6 +318,11 @@ void scheduleMatmul(
 scheduler_utils::BoundedDirectionalTransformPropagator::Options()
     .propagateParallelType()
     .propagateToBoundary());
+
+if (params.double_buffer_options.double_buffer_smem_read &&
+    params.double_buffer_options.double_buffer_smem_write) {
+  scheduler_utils::rotateLoop(cc, 2, {acr, bcr});
+}
 }

 } // namespace nvfuser
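An editor's note on the new call above, hedged since only this excerpt is visible: in the updated loop order [Mo No Ko Kwo Mwo Nwo Mw Nw (Mi Ni Ki)], position 2 is Ko, the serial main loop over K tiles, and acr/bcr are the register buffers vectorized earlier in this hunk for the ldmatrix reads. A plausible annotated reading (the parameter meanings are my interpretation, not a documented signature):

if (params.double_buffer_options.double_buffer_smem_read &&
    params.double_buffer_options.double_buffer_smem_write) {
  // Rotate the loop at position 2 (Ko) so the reads that produce acr and
  // bcr run one iteration ahead of the MMA math; presumably this is only
  // legal when both the smem write and the smem read are double buffered,
  // hence the guard.
  scheduler_utils::rotateLoop(cc, 2, {acr, bcr});
}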
3 changes: 3 additions & 0 deletions third_party/nvfuser/csrc/scheduler/matmul.h
@@ -18,6 +18,9 @@ class MatmulParam {
   int smem_double_buffer_stage = 2;
 };

+//! Whether to rotate the ldmatrix out of the main loop
+bool rotate_ldmatrix_out_of_main_loop = true;
+
 //! (Ampere+) Use cp.async to load operands.
 bool async_gmem_load_operands = false;
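A hedged usage sketch for the new knob. The field names below come from this header and from the scheduler hunk above; the entry-point call is assumed for illustration. Note also that the scheduler excerpt gates rotation on the double-buffer options; whether it additionally consults rotate_ldmatrix_out_of_main_loop is not visible in this diff.

MatmulParam params;
params.rotate_ldmatrix_out_of_main_loop = true;  // new flag, defaults to true
params.async_gmem_load_operands = true;          // Ampere+: cp.async operand loads
params.double_buffer_options.double_buffer_smem_read = true;
params.double_buffer_options.double_buffer_smem_write = true;
params.double_buffer_options.smem_double_buffer_stage = 2;  // assumed nesting
// scheduleMatmul(c, a, b, params);  // assumed call shape; see matmul.cpp above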
14 changes: 7 additions & 7 deletions third_party/nvfuser/csrc/scheduler/utils.cpp
@@ -1571,12 +1571,12 @@ void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) {
 tv->split(-2, instruction_tile.n);
 tv->split(-1, instruction_tile.k);

-// -8 -7 -6 -5 -4 -3 -2 -1
-// [Mwo Mw Mi Nwo Nw Ni Ko Ki]
+// -8 -7 -6 -5 -4 -3 -2 -1
+// [Mwo Mw Mi Nwo Nw Ni Kwo Ki]

-tv->reorder({{-7, -5}, {-6, -3}, {-5, -7}, {-3, -2}, {-2, -6}});
-// -8 -7 -6 -5 -4 -3 -2 -1
-// [Mwo Nwo Ko Mw Nw Mi Ni Ki]
+tv->reorder({{-7, -5}, {-6, -3}, {-5, -6}, {-3, -2}, {-2, -8}, {-8, -7}});
+// -8 -7 -6 -5 -4 -3 -2 -1
+// [Kwo Mwo Nwo Mw Nw Mi Ni Ki]
 } else {
   // Split K over warp case:
   // Main difference is that an additional

@@ -1589,8 +1589,8 @@ void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) {
 tv->split(-2, warp_tile.n);
 tv->split(-1, warp_tile.k);

-// -6 -5 -4 -3 -2 -1
-// [Mwo Mw Nwo Nw K, Kw]
+// -6 -5 -4 -3 -2 -1
+// [Mwo Mw Nwo Nw Kwo Kw]
 tv->split(-5, instruction_tile.m);
 tv->split(-3, instruction_tile.n);
 tv->split(-1, instruction_tile.k);
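The new reorder map is easy to misread, so here is a self-contained check (plain C++, illustrative only) that applying the {old position, new position} pairs from the diff to [Mwo Mw Mi Nwo Nw Ni Kwo Ki] really yields [Kwo Mwo Nwo Mw Nw Mi Ni Ki]:

#include <array>
#include <cassert>
#include <string>

int main() {
  using Axes = std::array<std::string, 8>;
  Axes old_order{"Mwo", "Mw", "Mi", "Nwo", "Nw", "Ni", "Kwo", "Ki"};
  Axes new_order;
  auto at = [](int pos) { return pos + 8; };  // map -8..-1 to array index 0..7
  // The {old_pos, new_pos} pairs from tv->reorder(...) above; axes not
  // listed keep their slot.
  new_order[at(-5)] = old_order[at(-7)];  // Mw
  new_order[at(-3)] = old_order[at(-6)];  // Mi
  new_order[at(-6)] = old_order[at(-5)];  // Nwo
  new_order[at(-2)] = old_order[at(-3)];  // Ni
  new_order[at(-8)] = old_order[at(-2)];  // Kwo
  new_order[at(-7)] = old_order[at(-8)];  // Mwo
  new_order[at(-4)] = old_order[at(-4)];  // Nw, unmoved
  new_order[at(-1)] = old_order[at(-1)];  // Ki, unmoved
  assert((new_order == Axes{"Kwo", "Mwo", "Nwo", "Mw", "Nw", "Mi", "Ni", "Ki"}));
  return 0;
}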
3 changes: 2 additions & 1 deletion third_party/nvfuser/runtime/tensor.cu
@@ -7,7 +7,8 @@ struct Tensor {
 for (int i = 0; i < N; i++) {
   max_ind += (size[i] - 1) * stride[i];
 }
-assert(ind >= 0 && ind <= max_ind);
+assert(ind >= 0);
+assert(ind <= max_ind);
Review comment from @zasdfgbnm (Collaborator, Author) on lines +10 to +11, Mar 1, 2023:
Unrelated to this PR, but asserting different conditions separately provides a better error message. (The line number in the error message will tell me which is violated).
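A minimal sketch of the point being made (illustrative, not the actual runtime code):

#include <cassert>

float read_checked(const float* data, long ind, long max_ind) {
  assert(ind >= 0);        // a failure here pinpoints an underflowing index
  assert(ind <= max_ind);  // a failure here pinpoints an overflowing index
  return data[ind];
}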

 #endif
 return data[ind];
 };
36 changes: 29 additions & 7 deletions third_party/nvfuser/test/test_gpu_matmul_sass.cpp
@@ -156,18 +156,27 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSASSModifiersCheck_CUDA) {
 using T = std::decay_t<decltype(i)>;
 if constexpr (std::is_same_v<sass::Instruction, T>) {
   if (i.opCode() == "LDGSTS") {
-    const std::vector<std::string> expect = {"E", "BYPASS", "128"};
+    const std::vector<std::string> expect = {
+        "E", "BYPASS", "LTC128B", "128"};
     TORCH_CHECK(
         i.modifiers() == expect,
         "Modifiers for LDGSTS has changed. "
-        "Please manually check if the new modifiers makes sense and update this test.");
+        "Please manually check if the new modifiers makes sense and update this test. "
+        "Expect: ",
+        expect,
+        " Get: ",
+        i.modifiers());
     found_LDGSTS = true;
   } else if (i.opCode() == "LDGDEPBAR") {
     const std::vector<std::string> expect;
     TORCH_CHECK(
         i.modifiers() == expect,
         "Modifiers for LDGDEPBAR has changed. "
-        "Please manually check if the new modifiers makes sense and update this test.");
+        "Please manually check if the new modifiers makes sense and update this test. "
+        "Expect: ",
+        expect,
+        " Get: ",
+        i.modifiers());
     found_LDGDEPBAR = true;
   } else if (i.opCode() == "LDSM") {
     const std::vector<std::string> expect1 = {"16", "M88", "2"};

@@ -185,21 +185,34 @@ TEST_F(NVFuserTest, FusionAmpereMatmulSASSModifiersCheck_CUDA) {
     TORCH_CHECK(
         i.modifiers() == expect,
         "Modifiers for HMMA has changed. "
-        "Please manually check if the new modifiers makes sense and update this test.");
+        "Please manually check if the new modifiers makes sense and update this test. "
+        "Expect: ",
+        expect,
+        " Get: ",
+        i.modifiers());
     found_HMMA = true;
   } else if (i.opCode() == "BAR") {
-    const std::vector<std::string> expect = {"SYNC"};
+    const std::vector<std::string> expect = {
+        "SYNC", "DEFER_BLOCKING"};
     TORCH_CHECK(
         i.modifiers() == expect,
         "Modifiers for BAR has changed. "
-        "Please manually check if the new modifiers makes sense and update this test.");
+        "Please manually check if the new modifiers makes sense and update this test. "
+        "Expect: ",
+        expect,
+        " Get: ",
+        i.modifiers());
     found_BAR = true;
   } else if (i.opCode() == "DEPBAR") {
     const std::vector<std::string> expect = {"LE"};
     TORCH_CHECK(
         i.modifiers() == expect,
         "Modifiers for DEPBAR has changed. "
-        "Please manually check if the new modifiers makes sense and update this test.");
+        "Please manually check if the new modifiers makes sense and update this test. "
+        "Expect: ",
+        expect,
+        " Get: ",
+        i.modifiers());
     found_DEPBAR = true;
   }
 }
92 changes: 48 additions & 44 deletions third_party/nvfuser/test/test_gpu_tensorcore.cpp
@@ -906,7 +906,7 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) {
 scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
     tv4, gemm_tile2);
 // -8 -7 -6 -5 -4 -3 -2 -1
-// [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
+// [Mo No Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki]
 tv3cr->computeAt(tv4c, -4);
 tv2cr->computeAt(tv4c, -4);

@@ -1023,8 +1023,8 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) {
 // 0 1 2 3 4 5 6 7
 // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
 // Gemm 1
-tv3c->axis(3)->parallelize(ParallelType::TIDz);
-tv3c->axis(4)->parallelize(ParallelType::TIDy);
+tv3c->axis(4)->parallelize(ParallelType::TIDz);
+tv3c->axis(5)->parallelize(ParallelType::TIDy);

 tv3->computeAt(tv3cw, -2);
 tv3cw->axis(2)->parallelize(ParallelType::TIDz);

@@ -1033,8 +1033,8 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) {
 // Gemm 2
 tv4->axis(2)->parallelize(ParallelType::TIDz);
 tv4->axis(3)->parallelize(ParallelType::TIDy);
-tv4c->axis(3)->parallelize(ParallelType::TIDz);
-tv4c->axis(4)->parallelize(ParallelType::TIDy);
+tv4c->axis(4)->parallelize(ParallelType::TIDz);
+tv4c->axis(5)->parallelize(ParallelType::TIDy);

 tv4->axis(0)->parallelize(ParallelType::BIDx);
 tv4->axis(1)->parallelize(ParallelType::BIDy);

@@ -1211,8 +1211,8 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) {
 scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv4c, gemm_tile);
 scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
     tv4, gemm_tile);
-// -8 -7 -6 -5 -4 -3 -2 -1
-// [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
+// -8 -7 -6 -5 -4 -3 -2 -1
+// [Mo No Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki]
 tv3cr->computeAt(tv4c, -4);
 tv2cr->computeAt(tv4c, -4);

@@ -1388,8 +1388,8 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) {
 // 0 1 2 3 4 5 6 7
 // [Mo No Mwo Nwo Mw Nw (Mi Ni)]
 // Gemm 1
-tv3c->axis(3)->parallelize(ParallelType::TIDz);
-tv3c->axis(4)->parallelize(ParallelType::TIDy);
+tv3c->axis(4)->parallelize(ParallelType::TIDz);
+tv3c->axis(5)->parallelize(ParallelType::TIDy);
 tv3->axis(2)->parallelize(ParallelType::TIDz);
 tv3->axis(3)->parallelize(ParallelType::TIDy);

@@ -1421,8 +1421,8 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) {
 // Gemm 2
 tv4->axis(2)->parallelize(ParallelType::TIDz);
 tv4->axis(3)->parallelize(ParallelType::TIDy);
-tv4c->axis(3)->parallelize(ParallelType::TIDz);
-tv4c->axis(4)->parallelize(ParallelType::TIDy);
+tv4c->axis(4)->parallelize(ParallelType::TIDz);
+tv4c->axis(5)->parallelize(ParallelType::TIDy);

 auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
 auto t0 = at::randn({M1, K1}, options);

@@ -1789,8 +1789,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTNcpAsync_CUDA) {
 scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
 scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
     tv2, gemm_tile);
-// -8 -7 -6 -5 -4 -3 -2 -1
-// [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
+// -8 -7 -6 -5 -4 -3 -2 -1
+// [Mo No Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki]
 tv0cr->computeAt(tv2c, -4);
 tv1cr->computeAt(tv2c, -4);

@@ -1827,10 +1827,10 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTNcpAsync_CUDA) {
     mma_builder.operand(MmaOptions::Operand::Accumulator).build());

 // Parallelize
-// 0 1 2 3 4 5 6 7 8 9 10
-// [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
-tv2c->axis(3)->parallelize(ParallelType::TIDz);
-tv2c->axis(4)->parallelize(ParallelType::TIDy);
+// 0 1 2 3 4 5 6 7 8 9 10
+// [Mo No Ko Kwo Mwo Nwo Mw Nw (Mi Ni Ki)]
+tv2c->axis(4)->parallelize(ParallelType::TIDz);
+tv2c->axis(5)->parallelize(ParallelType::TIDy);

 // Parallelize
 // 0 1 2 3 4 5 6 7

@@ -1949,8 +1949,8 @@ TEST_F(NVFuserTest, FusionAmpereStridedBatchedMatmulTN_CUDA) {
 scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
 scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
     tv2, gemm_tile);
-// -8 -7 -6 -5 -4 -3 -2 -1
-// [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
+// -8 -7 -6 -5 -4 -3 -2 -1
+// [Mo No Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki]
 tv0cr->computeAt(tv2c, -4);
 tv1cr->computeAt(tv2c, -4);

@@ -1994,10 +1994,10 @@ TEST_F(NVFuserTest, FusionAmpereStridedBatchedMatmulTN_CUDA) {
     mma_builder.operand(MmaOptions::Operand::Accumulator).build());

 // Parallelize
-// 0 1 2 3 4 5 6 7 8 9 10
-// [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
-tv2c->axis(3)->parallelize(ParallelType::TIDz);
-tv2c->axis(4)->parallelize(ParallelType::TIDy);
+// 0 1 2 3 4 5 6 7 8 9 10
+// [Mo No Ko Kwo Mwo Nwo Mw Nw (Mi Ni Ki)]
+tv2c->axis(4)->parallelize(ParallelType::TIDz);
+tv2c->axis(5)->parallelize(ParallelType::TIDy);

 // Parallelize
 // 0 1 2 3 4 5 6 7

@@ -2116,8 +2116,8 @@ TEST_F(NVFuserTest, FusionAmpereViewMatmulTN_CUDA) {
 scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
 scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
     tv2, gemm_tile);
-// -8 -7 -6 -5 -4 -3 -2 -1
-// [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
+// -8 -7 -6 -5 -4 -3 -2 -1
+// [Mo No Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki]
 tv0cr->computeAt(tv2c, -4);
 tv1cr->computeAt(tv2c, -4);

@@ -2166,10 +2166,10 @@ TEST_F(NVFuserTest, FusionAmpereViewMatmulTN_CUDA) {
 tv0_reshape->computeAt(tv0cw, -2);

 // Parallelize
-// 0 1 2 3 4 5 6 7 8 9 10
-// [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
-tv2c->axis(3)->parallelize(ParallelType::TIDz);
-tv2c->axis(4)->parallelize(ParallelType::TIDy);
+// 0 1 2 3 4 5 6 7 8 9 10
+// [Mo No Ko Kwo Mwo Nwo Mw Nw (Mi Ni Ki)]
+tv2c->axis(4)->parallelize(ParallelType::TIDz);
+tv2c->axis(5)->parallelize(ParallelType::TIDy);

 // Parallelize
 // 0 1 2 3 4 5 6 7

@@ -2199,7 +2199,7 @@ TEST_F(NVFuserTest, FusionAmpereViewMatmulTN_CUDA) {
 }

 // Initial test case for in-CTA split K with VoltaMMA
-TEST_F(NVFuserTest, FusionVoltaMatMulTNCrossWarp_CUDA) {
+TEST_F(NVFuserTest, FusionVoltaMatmulTNCrossWarp_CUDA) {
   NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0);

   Fusion fusion;

@@ -2361,7 +2361,7 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTNCrossWarp_CUDA) {
 }

 // Initial test case for cross-CTA split K with VoltaMMA
-TEST_F(NVFuserTest, FusionVoltaMatMulTNCrossCTA_CUDA) {
+TEST_F(NVFuserTest, FusionVoltaMatmulTNCrossCTA_CUDA) {
   NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0);

   Fusion fusion;

@@ -2436,7 +2436,9 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTNCrossCTA_CUDA) {
 // Make warp tile:
 // -------------------------------------------------------------------------
 scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
-auto tv2c_rf = tv2c->rFactor({-9, -6, -1});
+// -9 -8 -7 -6 -5 -4 -3 -2 -1
+// [Mo No K2CTA Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki]
+auto tv2c_rf = tv2c->rFactor({-9, -8, -1});

 // tv2c_rf is the actual output of the mma op after
 // Rfactoring.

@@ -2445,8 +2447,8 @@
 scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
     tv2, gemm_tile);

-// -8 -7 -6 -5 -4 -3 -2 -1
-// [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
+// -8 -7 -6 -5 -4 -3 -2 -1
+// [Mo No K2CTA Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki]
 tv0cr->computeAt(tv2c_rf, -4);
 tv1cr->computeAt(tv2c_rf, -4);

@@ -2496,14 +2498,16 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTNCrossCTA_CUDA) {
 tv0cr->axis(-1)->parallelize(ParallelType::Vectorize);
 tv1cr->axis(-1)->parallelize(ParallelType::Vectorize);
 // Parallelize
-// 0 1 2 3 4 5 6 7 8 9 10
-// [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
+// 0 1 2 3 4 5 6 7 8 9 10 11
+// [Mo No K2CTA Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki]
 tv2c_rf->axis(0)->parallelize(ParallelType::BIDx);
 tv2c_rf->axis(1)->parallelize(ParallelType::BIDy);
 tv2c_rf->axis(2)->parallelize(ParallelType::BIDz);
-tv2c_rf->axis(4)->parallelize(ParallelType::TIDz);
-tv2c_rf->axis(5)->parallelize(ParallelType::TIDy);
+tv2c_rf->axis(5)->parallelize(ParallelType::TIDz);
+tv2c_rf->axis(6)->parallelize(ParallelType::TIDy);

+// 0 1 2 3 4 5 6 7 8
+// [Mo No K2CTA Mwo Nwo Mw Nw Mi Ni]
 tv2c->axis(0)->parallelize(ParallelType::BIDx);
 tv2c->axis(1)->parallelize(ParallelType::BIDy);
 tv2c->axis(2)->parallelize(ParallelType::BIDz);

@@ -2605,8 +2609,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTNSwizzled_CUDA) {
 scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile);
 scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction(
     tv2, gemm_tile);
-// -8 -7 -6 -5 -4 -3 -2 -1
-// [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki]
+// -8 -7 -6 -5 -4 -3 -2 -1
+// [Mo No Ko Kwo Mwo Nwo Mw Nw Mi Ni Ki]
 tv0cr->computeAt(tv2c, -4);
 tv1cr->computeAt(tv2c, -4);

@@ -2675,10 +2679,10 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTNSwizzled_CUDA) {
     mma_builder.operand(MmaOptions::Operand::Accumulator).build());

 // Parallelize
-// 0 1 2 3 4 5 6 7 8 9 10
-// [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
-tv2c->axis(3)->parallelize(ParallelType::TIDz);
-tv2c->axis(4)->parallelize(ParallelType::TIDy);
+// 0 1 2 3 4 5 6 7 8 9 10
+// [Mo No Ko Kwo Mwo Nwo Mw Nw (Mi Ni Ki)]
+tv2c->axis(4)->parallelize(ParallelType::TIDz);
+tv2c->axis(5)->parallelize(ParallelType::TIDy);

 // Parallelize
 // 0 1 2 3 4 5 6 7
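Editorial summary of the pattern repeated through the test updates above (my consolidation of the diffs, not text from the PR): scheduleWarpTileWithReduction now places the serial Kwo split ahead of the warp dimensions, so every parallelize() index for the warp axes shifts by one.

// old order: [Mo No Ko Mwo Nwo Kwo Mw Nw (Mi Ni Ki)] -> TIDz = axis 3 (Mwo), TIDy = axis 4 (Nwo)
// new order: [Mo No Ko Kwo Mwo Nwo Mw Nw (Mi Ni Ki)] -> TIDz = axis 4 (Mwo), TIDy = axis 5 (Nwo)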